Changes from code review

Major changes are:
* Remove duplicate code
* Check the public key used to configure the backend is a valid one
This commit is contained in:
Will May
2017-02-28 22:08:10 +00:00
committed by Vishal Nayak
parent 59397250da
commit 7d9cb5bffe
6 changed files with 45 additions and 80 deletions

View File

@@ -55,7 +55,6 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
Secrets: []*framework.Secret{
secretDynamicKey(&b),
secretOTP(&b),
secretCerts(&b),
},
Init: b.Initialize,

View File

@@ -36,24 +36,32 @@ For security reasons, the private key cannot be retrieved later.`,
func (b *backend) pathCAWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Put(&logical.StorageEntry{
publicKey := data.Get("public_key").(string)
privateKey := data.Get("private_key").(string)
_, err := ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return nil, errutil.UserError{Err: fmt.Sprintf(`Unable to parse "private_key" as an SSH private key: %s`, err)}
}
_, err = parsePublicSSHKey(publicKey)
if err != nil {
return nil, errutil.UserError{Err: fmt.Sprintf(`Unable to parse "public_key" as an SSH public key: %s`, err)}
}
err = req.Storage.Put(&logical.StorageEntry{
Key: "public_key",
Value: []byte(data.Get("public_key").(string)),
Value: []byte(publicKey),
})
if err != nil {
return nil, err
}
bundle := signingBundle{
Certificate: data.Get("private_key").(string),
Certificate: privateKey,
}
_, err = ssh.ParsePrivateKey([]byte(bundle.Certificate))
if err != nil {
return nil, errutil.UserError{Err: fmt.Sprintf(`Unable to parse "private_key" as an SSH private key: %s`, err)}
}
entry, err := logical.StorageEntryJSON("config/ssh_certificate_bundle", bundle)
entry, err := logical.StorageEntryJSON("config/ca_bundle", bundle)
if err != nil {
return nil, err
}

View File

@@ -194,7 +194,7 @@ func pathRoles(b *backend) *framework.Path {
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
A comma-separated list of extensions that certificates can have when signed.
To allow any critical options, set this to an empty string.
To allow any extensions, set this to an empty string.
`,
},
"default_critical_options": &framework.FieldSchema{
@@ -232,7 +232,7 @@ func pathRoles(b *backend) *framework.Path {
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
If set, host certificates that are requested are allowed to use the base domains listed in
"allowed_users", e.g. "example.com".
"allowed_domains", e.g. "example.com".
This is a separate option as in some cases this can be considered a security threat.
`,
},
@@ -240,7 +240,7 @@ func pathRoles(b *backend) *framework.Path {
Type: framework.TypeBool,
Description: `
[Not applicable for Dynamic type] [Not applicable for OTP type] [Optional for CA type]
If set, host certificates that are requested are allowed to use subdomains of those listed in "allowed_users".
If set, host certificates that are requested are allowed to use subdomains of those listed in "allowed_domains".
`,
},
},
@@ -525,6 +525,7 @@ func (b *backend) pathRoleRead(req *logical.Request, d *framework.FieldData) (*l
"allow_host_certificates": role.AllowHostCertificates,
"allow_bare_domains": role.AllowBareDomains,
"allow_subdomains": role.AllowSubdomains,
"key_type": role.KeyType,
},
}, nil
} else {

View File

@@ -2,17 +2,16 @@ package ssh
import (
"crypto/rand"
"encoding/base64"
"fmt"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/errutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"golang.org/x/crypto/ssh"
"log"
"strconv"
"strings"
"time"
"github.com/hashicorp/vault/helper/strutil"
)
type signingBundle struct {
@@ -105,23 +104,11 @@ func (b *backend) pathSignCertificate(req *logical.Request, data *framework.Fiel
return nil, errutil.UserError{Err: "missing public_key"}
}
keyParts := strings.Split(publicKey, " ")
if len(keyParts) > 1 {
// Someone has sent the 'full' public key rather than just the base64 encoded part that the ssh library wants
publicKey = keyParts[1]
}
decodedKey, err := base64.StdEncoding.DecodeString(publicKey)
userPublicKey, err := parsePublicSSHKey(publicKey)
if err != nil {
return nil, errutil.UserError{Err: "Unable to decode \"public_key\" as SSH key"}
}
userPublicKey, err := ssh.ParsePublicKey([]byte(decodedKey))
if err != nil {
log.Printf("Failed to parse key: %s", err)
return nil, errutil.UserError{Err: "Unable to parse \"public_key\" as SSH key"}
}
keyId := data.Get("key_id").(string)
if keyId == "" {
keyId = req.DisplayName
@@ -139,7 +126,7 @@ func (b *backend) pathSignCertificate(req *logical.Request, data *framework.Fiel
return nil, err
}
} else {
parsedPrincipals, err = b.calculateValidPrincipals(data, role.DefaultUser, role.AllowedUsers, contains)
parsedPrincipals, err = b.calculateValidPrincipals(data, role.DefaultUser, role.AllowedUsers, strutil.StrListContains)
if err != nil {
return nil, err
}
@@ -160,7 +147,7 @@ func (b *backend) pathSignCertificate(req *logical.Request, data *framework.Fiel
return nil, err
}
storedBundle, err := req.Storage.Get("config/ssh_certificate_bundle")
storedBundle, err := req.Storage.Get("config/ca_bundle")
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("unable to fetch local CA certificate/key: %v", err)}
}
@@ -192,15 +179,12 @@ func (b *backend) pathSignCertificate(req *logical.Request, data *framework.Fiel
signedSSHCertificate := string(ssh.MarshalAuthorizedKey(certificate))
response := b.Secret(SecretCertsType).Response(
map[string]interface{}{
response := &logical.Response{
Data: map[string]interface{}{
"serial_number": strconv.FormatUint(certificate.Serial, 16),
"signed_key": signedSSHCertificate,
},
map[string]interface{}{
"serial_number": strconv.FormatUint(certificate.Serial, 16),
"signed_key": signedSSHCertificate,
})
}
return response, nil
}
@@ -291,7 +275,7 @@ func (b *backend) calculateCriticalOptions(data *framework.FieldData, role *sshR
allowedCriticalOptions := strings.Split(role.AllowedCriticalOptions, ",")
for option := range criticalOptions {
if !contains(allowedCriticalOptions, option) {
if !strutil.StrListContains(allowedCriticalOptions, option) {
notAllowedOptions = append(notAllowedOptions, option)
}
}
@@ -317,7 +301,7 @@ func (b *backend) calculateExtensions(data *framework.FieldData, role *sshRole)
allowedExtensions := strings.Split(role.AllowedExtensions, ",")
for extension := range extensions {
if !contains(allowedExtensions, extension) {
if !strutil.StrListContains(allowedExtensions, extension) {
notAllowed = append(notAllowed, extension)
}
}

View File

@@ -1,33 +0,0 @@
package ssh
import (
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
// SecretCertsType is the name used to identify this type
const SecretCertsType = "secret_ssh_ca"
func secretCerts(b *backend) *framework.Secret {
return &framework.Secret{
Type: SecretCertsType,
Fields: map[string]*framework.FieldSchema{
"signed_key": &framework.FieldSchema{
Type: framework.TypeString,
Description: "The signd certificate.",
},
"serial_number": &framework.FieldSchema{
Type: framework.TypeString,
Description: `The serial number of the certificate, for handy
reference`,
},
},
Revoke: b.secretCredsRevoke,
}
}
func (b *backend) secretCredsRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// this backend doesn't support CRL, so there's nothing that can be done when a certificate is revoked
return &logical.Response{}, nil
}

View File

@@ -189,6 +189,21 @@ func createSSHComm(logger log.Logger, username, ip string, port int, hostkey str
return SSHCommNew(fmt.Sprintf("%s:%d", ip, port), config)
}
func parsePublicSSHKey(key string) (ssh.PublicKey, error) {
keyParts := strings.Split(key, " ")
if len(keyParts) > 1 {
// Someone has sent the 'full' public key rather than just the base64 encoded part that the ssh library wants
key = keyParts[1]
}
decodedKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
return ssh.ParsePublicKey([]byte(decodedKey))
}
func convertMapToStringValue(initial map[string]interface{}) map[string]string {
result := map[string]string{}
for key, value := range initial {
@@ -196,12 +211,3 @@ func convertMapToStringValue(initial map[string]interface{}) map[string]string {
}
return result
}
func contains(array []string, needed string) bool {
for _, item := range array {
if item == needed {
return true
}
}
return false
}