diff --git a/api/ssh_agent.go b/api/ssh_agent.go index 5a8192ae95..729fd99c43 100644 --- a/api/ssh_agent.go +++ b/api/ssh_agent.go @@ -62,6 +62,7 @@ type SSHHelperConfig struct { AllowedCidrList string `hcl:"allowed_cidr_list"` AllowedRoles string `hcl:"allowed_roles"` TLSSkipVerify bool `hcl:"tls_skip_verify"` + TLSServerName string `hcl:"tls_server_name"` } // SetTLSParameters sets the TLS parameters for this SSH agent. @@ -70,6 +71,7 @@ func (c *SSHHelperConfig) SetTLSParameters(clientConfig *Config, certPool *x509. InsecureSkipVerify: c.TLSSkipVerify, MinVersion: tls.VersionTLS12, RootCAs: certPool, + ServerName: c.TLSServerName, } transport := cleanhttp.DefaultTransport() @@ -77,6 +79,16 @@ func (c *SSHHelperConfig) SetTLSParameters(clientConfig *Config, certPool *x509. clientConfig.HttpClient.Transport = transport } +// Returns true if any of the following conditions are true: +// * CA cert is configured +// * CA path is configured +// * configured to skip certificate verification +// * TLS server name is configured +// +func (c *SSHHelperConfig) shouldSetTLSParameters() bool { + return c.CACert != "" || c.CAPath != "" || c.TLSServerName != "" || c.TLSSkipVerify +} + // NewClient returns a new client for the configuration. This client will be used by the // vault-ssh-helper to communicate with Vault server and verify the OTP entered by user. // If the configuration supplies Vault SSL certificates, then the client will @@ -89,7 +101,7 @@ func (c *SSHHelperConfig) NewClient() (*Client, error) { clientConfig.Address = c.VaultAddr // Check if certificates are provided via config file. - if c.CACert != "" || c.CAPath != "" || c.TLSSkipVerify { + if c.shouldSetTLSParameters() { rootConfig := &rootcerts.Config{ CAFile: c.CACert, CAPath: c.CAPath, @@ -145,6 +157,7 @@ func ParseSSHHelperConfig(contents string) (*SSHHelperConfig, error) { "allowed_cidr_list", "allowed_roles", "tls_skip_verify", + "tls_server_name", } if err := checkHCLKeys(list, valid); err != nil { return nil, multierror.Prefix(err, "ssh_helper:") diff --git a/api/ssh_agent_test.go b/api/ssh_agent_test.go index 80e4f22aa4..915fbd48e8 100644 --- a/api/ssh_agent_test.go +++ b/api/ssh_agent_test.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" "testing" + "net/http" ) func TestSSH_CreateTLSClient(t *testing.T) { @@ -30,6 +31,29 @@ func TestSSH_CreateTLSClient(t *testing.T) { } } +func TestSSH_CreateTLSClient_tlsServerName(t *testing.T) { + // Ensure that the HTTP client is associated with the configured TLS server name. + var tlsServerName = "tls.server.name" + + config, err := ParseSSHHelperConfig(fmt.Sprintf(` +vault_addr = "1.2.3.4" +tls_server_name = "%s" +`, tlsServerName)) + if err != nil { + panic(fmt.Sprintf("error loading config: %s", err)) + } + + client, err := config.NewClient() + if err != nil { + panic(fmt.Sprintf("error creating the client: %s", err)) + } + + actualTLSServerName := client.config.HttpClient.Transport.(*http.Transport).TLSClientConfig.ServerName + if actualTLSServerName != tlsServerName { + panic(fmt.Sprintf("incorrect TLS server name. expected: %s actual: %s", tlsServerName, actualTLSServerName)) + } +} + func TestParseSSHHelperConfig(t *testing.T) { config, err := ParseSSHHelperConfig(` vault_addr = "1.2.3.4" @@ -67,3 +91,20 @@ nope = "bad" t.Errorf("bad error: %s", err) } } + +func TestParseSSHHelperConfig_tlsServerName(t *testing.T) { + var tlsServerName = "tls.server.name" + + config, err := ParseSSHHelperConfig(fmt.Sprintf(` +vault_addr = "1.2.3.4" +tls_server_name = "%s" +`, tlsServerName)) + + if err != nil { + t.Fatal(err) + } + + if config.TLSServerName != tlsServerName { + t.Errorf("incorrect TLS server name. expected: %s actual: %s", tlsServerName, config.TLSServerName) + } +}