diff --git a/api/ssh_agent.go b/api/ssh_agent.go index 182e13b873..c5db0671c7 100644 --- a/api/ssh_agent.go +++ b/api/ssh_agent.go @@ -8,18 +8,22 @@ import ( "os" "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/hcl" + "github.com/hashicorp/hcl/hcl/ast" "github.com/mitchellh/mapstructure" ) const ( - // Default path at which SSH backend will be mounted in Vault server + // SSHHelperDefaultMountPoint is the default path at which SSH backend will be + // mounted in the Vault server. SSHHelperDefaultMountPoint = "ssh" - // Echo request message sent as OTP by the vault-ssh-helper + // VerifyEchoRequest is the echo request message sent as OTP by the helper. VerifyEchoRequest = "verify-echo-request" - // Echo response message sent as a response to OTP matching echo request + // VerifyEchoResponse is the echo response message sent as a response to OTP + // matching echo request. VerifyEchoResponse = "verify-echo-response" ) @@ -55,8 +59,7 @@ type SSHHelperConfig struct { TLSSkipVerify bool `hcl:"tls_skip_verify"` } -// TLSClient returns a HTTP client that uses TLS verification (TLS 1.2) for a given -// certificate pool. +// SetTLSParameters sets the TLS parameters for this SSH agent. func (c *SSHHelperConfig) SetTLSParameters(clientConfig *Config, certPool *x509.CertPool) { tlsConfig := &tls.Config{ InsecureSkipVerify: c.TLSSkipVerify, @@ -112,29 +115,48 @@ func (c *SSHHelperConfig) NewClient() (*Client, error) { // Vault address is a required parameter. // Mount point defaults to "ssh". func LoadSSHHelperConfig(path string) (*SSHHelperConfig, error) { - var config SSHHelperConfig contents, err := ioutil.ReadFile(path) - if !os.IsNotExist(err) { - obj, err := hcl.Parse(string(contents)) - if err != nil { - return nil, err - } + if err != nil && !os.IsNotExist(err) { + return nil, multierror.Prefix(err, "ssh_helper:") + } + return ParseSSHHelperConfig(string(contents)) +} - if err := hcl.DecodeObject(&config, obj); err != nil { - return nil, err - } - } else { - return nil, err +// ParseSSHHelperConfig parses the given contents as a string for the SSHHelper +// configuration. +func ParseSSHHelperConfig(contents string) (*SSHHelperConfig, error) { + root, err := hcl.Parse(string(contents)) + if err != nil { + return nil, fmt.Errorf("ssh_helper: error parsing config: %s", err) } - if config.VaultAddr == "" { - return nil, fmt.Errorf("config missing vault_addr") - } - if config.SSHMountPoint == "" { - config.SSHMountPoint = SSHHelperDefaultMountPoint + list, ok := root.Node.(*ast.ObjectList) + if !ok { + return nil, fmt.Errorf("ssh_helper: error parsing config: file doesn't contain a root object") } - return &config, nil + valid := []string{ + "vault_addr", + "ssh_mount_point", + "ca_cert", + "ca_path", + "allowed_cidr_list", + "tls_skip_verify", + } + if err := checkHCLKeys(list, valid); err != nil { + return nil, multierror.Prefix(err, "ssh_helper:") + } + + var c SSHHelperConfig + c.SSHMountPoint = SSHHelperDefaultMountPoint + if err := hcl.DecodeObject(&c, list); err != nil { + return nil, multierror.Prefix(err, "ssh_helper:") + } + + if c.VaultAddr == "" { + return nil, fmt.Errorf("ssh_helper: missing config 'vault_addr'") + } + return &c, nil } // SSHHelper creates an SSHHelper object which can talk to Vault server with SSH backend @@ -189,3 +211,31 @@ func (c *SSHHelper) Verify(otp string) (*SSHVerifyResponse, error) { } return &verifyResp, nil } + +func checkHCLKeys(node ast.Node, valid []string) error { + var list *ast.ObjectList + switch n := node.(type) { + case *ast.ObjectList: + list = n + case *ast.ObjectType: + list = n.List + default: + return fmt.Errorf("cannot check HCL keys of type %T", n) + } + + validMap := make(map[string]struct{}, len(valid)) + for _, v := range valid { + validMap[v] = struct{}{} + } + + var result error + for _, item := range list.Items { + key := item.Keys[0].Token.Value().(string) + if _, ok := validMap[key]; !ok { + result = multierror.Append(result, fmt.Errorf( + "invalid key '%s' on line %d", key, item.Assign.Line)) + } + } + + return result +} diff --git a/api/ssh_agent_test.go b/api/ssh_agent_test.go index 6bdb0456fd..80e4f22aa4 100644 --- a/api/ssh_agent_test.go +++ b/api/ssh_agent_test.go @@ -2,6 +2,7 @@ package api import ( "fmt" + "strings" "testing" ) @@ -28,3 +29,41 @@ func TestSSH_CreateTLSClient(t *testing.T) { panic(fmt.Sprintf("error creating client with TLS transport")) } } + +func TestParseSSHHelperConfig(t *testing.T) { + config, err := ParseSSHHelperConfig(` + vault_addr = "1.2.3.4" +`) + if err != nil { + t.Fatal(err) + } + + if config.SSHMountPoint != SSHHelperDefaultMountPoint { + t.Errorf("expected %q to be %q", config.SSHMountPoint, SSHHelperDefaultMountPoint) + } +} + +func TestParseSSHHelperConfig_missingVaultAddr(t *testing.T) { + _, err := ParseSSHHelperConfig("") + if err == nil { + t.Fatal("expected error") + } + + if !strings.Contains(err.Error(), "ssh_helper: missing config 'vault_addr'") { + t.Errorf("bad error: %s", err) + } +} + +func TestParseSSHHelperConfig_badKeys(t *testing.T) { + _, err := ParseSSHHelperConfig(` +vault_addr = "1.2.3.4" +nope = "bad" +`) + if err == nil { + t.Fatal("expected error") + } + + if !strings.Contains(err.Error(), "ssh_helper: invalid key 'nope' on line 3") { + t.Errorf("bad error: %s", err) + } +}