Files
vault/builtin/logical/ssh/path_sign.go
Mike Okner 6f84f7ffd0 Adding allow_user_key_ids field to SSH role config (#2494)
Adding a boolean field that determines whether users will be allowed to
set the ID of the signed SSH key or whether it will always be the token
display name.  Preventing users from changing the ID and always using
the token name is useful for auditing who actually used a key to access
a remote host since sshd logs key IDs.
2017-03-16 08:45:11 -04:00

405 lines
12 KiB
Go

package ssh
import (
"crypto/rand"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"golang.org/x/crypto/ssh"
)
// creationBundle gathers everything needed to build and sign one SSH
// certificate. pathSignCertificate fills it in from the request and the role,
// and sign() turns it into a signed ssh.Certificate.
type creationBundle struct {
// KeyId is copied into the certificate's KeyId field.
KeyId string
// ValidPrincipals is copied into the certificate's ValidPrincipals field.
ValidPrincipals []string
// PublicKey is the caller-supplied key being certified.
PublicKey ssh.PublicKey
// CertificateType is the certificate's CertType; calculateCertificateType
// produces either ssh.UserCert or ssh.HostCert for it.
CertificateType uint32
// TTL is added to the signing time to compute the certificate's ValidBefore.
TTL time.Duration
// Signer is the CA key used to sign the certificate.
Signer ssh.Signer
// Role is the role the request was made against. It is stored by
// pathSignCertificate but not read by sign() in the code visible here.
Role *sshRole
// CriticalOptions is copied into the certificate's Permissions.
CriticalOptions map[string]string
// Extensions is copied into the certificate's Permissions.
Extensions map[string]string
}
// pathSign describes the "sign/<role>" endpoint: the request fields it
// accepts, its help text, and the handler invoked for update operations
// (the backend's pathSign method). No other operation is registered.
func pathSign(b *backend) *framework.Path {
	// Request schema. Every field is a string except the two permission
	// maps; only cert_type carries a default.
	fields := map[string]*framework.FieldSchema{
		"role": {
			Type:        framework.TypeString,
			Description: `The desired role with configuration for this request.`,
		},
		"ttl": {
			Type: framework.TypeString,
			Description: `The requested Time To Live for the SSH certificate;
sets the expiration date. If not specified
the role default, backend default, or system
default TTL is used, in that order. Cannot
be later than the role max TTL.`,
		},
		"public_key": {
			Type:        framework.TypeString,
			Description: `SSH public key that should be signed.`,
		},
		"valid_principals": {
			Type:        framework.TypeString,
			Description: `Valid principals, either usernames or hostnames, that the certificate should be signed for.`,
		},
		"cert_type": {
			Type:        framework.TypeString,
			Description: `Type of certificate to be created; either "user" or "host".`,
			Default:     "user",
		},
		"key_id": {
			Type:        framework.TypeString,
			Description: `Key id that the created certificate should have. If not specified, the display name of the token will be used.`,
		},
		"critical_options": {
			Type:        framework.TypeMap,
			Description: `Critical options that the certificate should be signed for.`,
		},
		"extensions": {
			Type:        framework.TypeMap,
			Description: `Extensions that the certificate should be signed for.`,
		},
	}

	handlers := map[logical.Operation]framework.OperationFunc{
		logical.UpdateOperation: b.pathSign,
	}

	return &framework.Path{
		Pattern:         "sign/" + framework.GenericNameRegex("role"),
		Callbacks:       handlers,
		Fields:          fields,
		HelpSynopsis:    `Request signing an SSH key using a certain role with the provided details.`,
		HelpDescription: `This path allows SSH keys to be signed according to the policy of the given role.`,
	}
}
// pathSign handles an update on "sign/<role>". It loads the role named in the
// request and delegates the actual signing to pathSignCertificate.
//
// A storage failure while loading the role is returned as an error; a role
// that does not exist is reported to the caller as an error response.
func (b *backend) pathSign(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
	name := data.Get("role").(string)

	role, err := b.getRole(req.Storage, name)
	switch {
	case err != nil:
		return nil, err
	case role == nil:
		return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil
	}

	return b.pathSignCertificate(req, data, role)
}
// pathSignCertificate validates the signing request against role, signs the
// supplied public key with the stored CA key, and returns the certificate.
//
// Problems with the request itself (missing or unparsable public_key, values
// rejected by the role) come back as error responses with a nil error.
// Problems reading or parsing the CA key, signing, or marshaling come back as
// a nil response with a non-nil error.
//
// On success the response data holds "serial_number" (the serial in base 16)
// and "signed_key" (the certificate in authorized_keys format).
func (b *backend) pathSignCertificate(req *logical.Request, data *framework.FieldData, role *sshRole) (*logical.Response, error) {
	rawKey := data.Get("public_key").(string)
	if rawKey == "" {
		return logical.ErrorResponse("missing public_key"), nil
	}

	parsedKey, err := parsePublicSSHKey(rawKey)
	if err != nil {
		return logical.ErrorResponse(fmt.Sprintf("failed to parse public_key as SSH key: %s", err)), nil
	}

	// The calculate* helpers below only ever fail because of what the caller
	// asked for, so their errors are surfaced as "user errors" (4xx values)
	// rather than as internal errors.
	userError := func(e error) (*logical.Response, error) {
		return logical.ErrorResponse(e.Error()), nil
	}

	id, err := b.calculateKeyId(data, req, role)
	if err != nil {
		return userError(err)
	}

	certType, err := b.calculateCertificateType(data, role)
	if err != nil {
		return userError(err)
	}

	// Host and user certificates share the principal validation routine and
	// differ only in the default principal, the allow-list consulted, and the
	// matching rule applied to each requested principal.
	var (
		fallbackPrincipal string
		allowedByRole     string
		matcher           func([]string, string) bool
	)
	if certType == ssh.HostCert {
		fallbackPrincipal = ""
		allowedByRole = role.AllowedDomains
		matcher = validateValidPrincipalForHosts(role)
	} else {
		fallbackPrincipal = role.DefaultUser
		allowedByRole = role.AllowedUsers
		matcher = strutil.StrListContains
	}
	principals, err := b.calculateValidPrincipals(data, fallbackPrincipal, allowedByRole, matcher)
	if err != nil {
		return userError(err)
	}

	lifetime, err := b.calculateTTL(data, role)
	if err != nil {
		return userError(err)
	}

	critOpts, err := b.calculateCriticalOptions(data, role)
	if err != nil {
		return userError(err)
	}

	exts, err := b.calculateExtensions(data, role)
	if err != nil {
		return userError(err)
	}

	// Load the CA signing key from storage.
	caEntry, err := caKey(req.Storage, caPrivateKey)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA private key: %v", err)
	}
	if caEntry == nil || caEntry.Key == "" {
		return nil, fmt.Errorf("failed to read CA private key")
	}

	caSigner, err := ssh.ParsePrivateKey([]byte(caEntry.Key))
	if err != nil {
		return nil, fmt.Errorf("failed to parse stored CA private key: %v", err)
	}

	bundle := &creationBundle{
		KeyId:           id,
		ValidPrincipals: principals,
		PublicKey:       parsedKey,
		CertificateType: certType,
		TTL:             lifetime,
		Signer:          caSigner,
		Role:            role,
		CriticalOptions: critOpts,
		Extensions:      exts,
	}

	cert, err := bundle.sign()
	if err != nil {
		return nil, err
	}

	encoded := ssh.MarshalAuthorizedKey(cert)
	if len(encoded) == 0 {
		return nil, fmt.Errorf("error marshaling signed certificate")
	}

	return &logical.Response{
		Data: map[string]interface{}{
			"serial_number": strconv.FormatUint(cert.Serial, 16),
			"signed_key":    string(encoded),
		},
	}, nil
}
// calculateValidPrincipals determines the principals to place in the
// certificate.
//
// The comma-separated "valid_principals" request field is used when present;
// otherwise defaultPrincipal is used. principalsAllowedByRole is the role's
// comma-separated allow-list, and validatePrincipal reports whether a single
// requested principal is acceptable given the parsed allow-list.
//
// Results:
//   - nothing requested and no default: (nil, nil).
//   - principals requested but the role allows none: error.
//   - allow-list is exactly "*": every requested principal is accepted.
//   - otherwise every requested principal must pass validatePrincipal, and the
//     first one that does not is named in the returned error.
//
// Both lists are normalized by strutil.ParseDedupAndSortStrings before use.
func (b *backend) calculateValidPrincipals(data *framework.FieldData, defaultPrincipal, principalsAllowedByRole string, validatePrincipal func([]string, string) bool) ([]string, error) {
	validPrincipals := defaultPrincipal
	if validPrincipalsRaw, ok := data.GetOk("valid_principals"); ok {
		validPrincipals = validPrincipalsRaw.(string)
	}

	parsedPrincipals := strutil.ParseDedupAndSortStrings(validPrincipals, ",")
	allowedPrincipals := strutil.ParseDedupAndSortStrings(principalsAllowedByRole, ",")

	switch {
	case len(parsedPrincipals) == 0:
		// There is nothing to process.
		// NOTE(review): this yields a certificate with an empty principal
		// list, which OpenSSH is documented to treat as valid for any
		// principal. Confirm that this is the intended outcome.
		return nil, nil

	case len(allowedPrincipals) == 0:
		// User has requested principals to be set, but role is not configured
		// with any principals.
		return nil, fmt.Errorf("role is not configured to allow any principals")

	case principalsAllowedByRole == "*":
		// Role was explicitly configured to allow any principal.
		return parsedPrincipals, nil
	}

	for _, principal := range parsedPrincipals {
		if !validatePrincipal(allowedPrincipals, principal) {
			return nil, fmt.Errorf("%v is not a valid value for valid_principals", principal)
		}
	}
	return parsedPrincipals, nil
}
// validateValidPrincipalForHosts returns the matching rule used for host
// certificates. The returned function reports whether validPrincipal is
// acceptable for at least one entry of allowedPrincipals:
//
//   - an exact match with an entry counts only if role.AllowBareDomains is set;
//   - a name ending in "." followed by an entry counts only if
//     role.AllowSubdomains is set.
//
// The role flags are read each time the returned function is called.
func validateValidPrincipalForHosts(role *sshRole) func([]string, string) bool {
	return func(allowedPrincipals []string, validPrincipal string) bool {
		for _, domain := range allowedPrincipals {
			switch {
			case role.AllowBareDomains && validPrincipal == domain:
				return true
			case role.AllowSubdomains && strings.HasSuffix(validPrincipal, "."+domain):
				return true
			}
		}
		return false
	}
}
// calculateCertificateType maps the "cert_type" request field to the
// corresponding SSH certificate type constant and checks that the role
// permits that type.
//
// It returns ssh.UserCert for "user" and ssh.HostCert for "host". Any other
// value, or a type the role does not allow, yields 0 and an error.
func (b *backend) calculateCertificateType(data *framework.FieldData, role *sshRole) (uint32, error) {
	switch requested := data.Get("cert_type").(string); requested {
	case "user":
		if role.AllowUserCertificates {
			return ssh.UserCert, nil
		}
		return 0, errors.New("cert_type 'user' is not allowed by role")

	case "host":
		if role.AllowHostCertificates {
			return ssh.HostCert, nil
		}
		return 0, errors.New("cert_type 'host' is not allowed by role")
	}

	return 0, errors.New("cert_type must be either 'user' or 'host'")
}
// calculateKeyId picks the key ID for the certificate.
//
// When the request carries no "key_id", the request's DisplayName is used.
// When it does, the value is accepted only if the role sets AllowUserKeyIDs;
// otherwise an error is returned.
func (b *backend) calculateKeyId(data *framework.FieldData, req *logical.Request, role *sshRole) (string, error) {
	requested := data.Get("key_id").(string)

	// Nothing requested: fall back to the request's display name.
	if requested == "" {
		return req.DisplayName, nil
	}

	if !role.AllowUserKeyIDs {
		return "", fmt.Errorf("Setting key_id is not allowed by role")
	}
	return requested, nil
}
// calculateCriticalOptions resolves the critical options for the certificate.
//
// With no "critical_options" in the request, the role's
// DefaultCriticalOptions are returned. Otherwise the requested options are
// converted with convertMapToStringValue and, if the role defines
// AllowedCriticalOptions (a comma-separated list), every requested option
// name must appear in that list. Names that do not are collected and reported
// together in the returned error. An empty AllowedCriticalOptions applies no
// restriction.
func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
	requested := data.Get("critical_options").(map[string]interface{})
	if len(requested) == 0 {
		return role.DefaultCriticalOptions, nil
	}

	options := convertMapToStringValue(requested)
	if role.AllowedCriticalOptions == "" {
		return options, nil
	}

	permitted := strings.Split(role.AllowedCriticalOptions, ",")
	var rejected []string
	for name := range options {
		if strutil.StrListContains(permitted, name) {
			continue
		}
		rejected = append(rejected, name)
	}
	if len(rejected) > 0 {
		return nil, fmt.Errorf("Critical options not on allowed list: %v", rejected)
	}

	return options, nil
}
// calculateExtensions resolves the extensions for the certificate.
//
// With no "extensions" in the request, the role's DefaultExtensions are
// returned. Otherwise the requested extensions are converted with
// convertMapToStringValue and, if the role defines AllowedExtensions (a
// comma-separated list), every requested extension name must appear in that
// list. Names that do not are collected and reported together in the returned
// error. An empty AllowedExtensions applies no restriction.
func (b *backend) calculateExtensions(data *framework.FieldData, role *sshRole) (map[string]string, error) {
	requested := data.Get("extensions").(map[string]interface{})
	if len(requested) == 0 {
		return role.DefaultExtensions, nil
	}

	result := convertMapToStringValue(requested)
	if role.AllowedExtensions == "" {
		return result, nil
	}

	permitted := strings.Split(role.AllowedExtensions, ",")
	var rejected []string
	for name := range result {
		if strutil.StrListContains(permitted, name) {
			continue
		}
		rejected = append(rejected, name)
	}
	if len(rejected) > 0 {
		return nil, fmt.Errorf("extensions %v are not on allowed list", rejected)
	}

	return result, nil
}
// calculateTTL works out how long the certificate should be valid.
//
// The TTL string comes from the request's "ttl" field when present, otherwise
// from role.TTL. If that string is empty, the system default lease TTL is
// used. The ceiling comes from role.MaxTTL, or from the system max lease TTL
// when the role has none.
//
// A TTL above the ceiling is an error when it came from an explicit string
// (request or role); when it came from the system default it is silently
// lowered to the ceiling instead. Unparsable TTL strings are errors.
func (b *backend) calculateTTL(data *framework.FieldData, role *sshRole) (time.Duration, error) {
	// Pick the TTL string: the request wins over the role.
	var requested string
	if raw, ok := data.GetOk("ttl"); ok {
		requested = raw.(string)
	} else {
		requested = role.TTL
	}
	usingSystemDefault := len(requested) == 0

	var ttl time.Duration
	if usingSystemDefault {
		ttl = b.System().DefaultLeaseTTL()
	} else {
		parsed, err := parseutil.ParseDurationSecond(requested)
		if err != nil {
			return 0, fmt.Errorf("invalid requested ttl: %s", err)
		}
		ttl = parsed
	}

	var ceiling time.Duration
	if len(role.MaxTTL) == 0 {
		ceiling = b.System().MaxLeaseTTL()
	} else {
		parsed, err := parseutil.ParseDurationSecond(role.MaxTTL)
		if err != nil {
			return 0, fmt.Errorf("invalid requested max ttl: %s", err)
		}
		ceiling = parsed
	}

	switch {
	case ttl <= ceiling:
		return ttl, nil
	case usingSystemDefault:
		// Don't error if they were using system defaults, only error if
		// they specifically chose a bad TTL.
		return ceiling, nil
	default:
		return 0, fmt.Errorf("ttl is larger than maximum allowed (%d)", ceiling/time.Second)
	}
}
// sign builds an SSH certificate from the bundle and signs it with b.Signer.
//
// The certificate gets a freshly generated serial number, becomes valid 30
// seconds before the current time, and expires b.TTL after the current time.
// Key, key ID, principals, certificate type, critical options and extensions
// are copied from the bundle.
//
// It returns the signed certificate, or an error if the serial number cannot
// be generated or signing fails. The signing error now carries the underlying
// cause instead of discarding it.
func (b *creationBundle) sign() (*ssh.Certificate, error) {
	serialNumber, err := certutil.GenerateSerialNumber()
	if err != nil {
		return nil, err
	}

	now := time.Now()

	certificate := &ssh.Certificate{
		// NOTE(review): assumes the generated serial is a *big.Int; if it can
		// exceed 64 bits, the value Uint64 yields should be confirmed.
		Serial:          serialNumber.Uint64(),
		Key:             b.PublicKey,
		KeyId:           b.KeyId,
		ValidPrincipals: b.ValidPrincipals,
		// Backdated by 30 seconds, presumably to tolerate clock skew between
		// this server and the host checking the certificate. Unix() does not
		// depend on the time's location, so no UTC conversion is needed.
		ValidAfter:  uint64(now.Add(-30 * time.Second).Unix()),
		ValidBefore: uint64(now.Add(b.TTL).Unix()),
		CertType:    b.CertificateType,
		Permissions: ssh.Permissions{
			CriticalOptions: b.CriticalOptions,
			Extensions:      b.Extensions,
		},
	}

	if err := certificate.SignCert(rand.Reader, b.Signer); err != nil {
		// Keep the cause so operators can tell why signing failed.
		return nil, fmt.Errorf("failed to generate signed SSH key: %v", err)
	}

	return certificate, nil
}