mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 11:38:02 +00:00
allow a TLS server name to be configured for SSH agents (#1720)
This commit is contained in:
@@ -62,6 +62,7 @@ type SSHHelperConfig struct {
|
||||
AllowedCidrList string `hcl:"allowed_cidr_list"`
|
||||
AllowedRoles string `hcl:"allowed_roles"`
|
||||
TLSSkipVerify bool `hcl:"tls_skip_verify"`
|
||||
TLSServerName string `hcl:"tls_server_name"`
|
||||
}
|
||||
|
||||
// SetTLSParameters sets the TLS parameters for this SSH agent.
|
||||
@@ -70,6 +71,7 @@ func (c *SSHHelperConfig) SetTLSParameters(clientConfig *Config, certPool *x509.
|
||||
InsecureSkipVerify: c.TLSSkipVerify,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
RootCAs: certPool,
|
||||
ServerName: c.TLSServerName,
|
||||
}
|
||||
|
||||
transport := cleanhttp.DefaultTransport()
|
||||
@@ -77,6 +79,16 @@ func (c *SSHHelperConfig) SetTLSParameters(clientConfig *Config, certPool *x509.
|
||||
clientConfig.HttpClient.Transport = transport
|
||||
}
|
||||
|
||||
// Returns true if any of the following conditions are true:
|
||||
// * CA cert is configured
|
||||
// * CA path is configured
|
||||
// * configured to skip certificate verification
|
||||
// * TLS server name is configured
|
||||
//
|
||||
func (c *SSHHelperConfig) shouldSetTLSParameters() bool {
|
||||
return c.CACert != "" || c.CAPath != "" || c.TLSServerName != "" || c.TLSSkipVerify
|
||||
}
|
||||
|
||||
// NewClient returns a new client for the configuration. This client will be used by the
|
||||
// vault-ssh-helper to communicate with Vault server and verify the OTP entered by user.
|
||||
// If the configuration supplies Vault SSL certificates, then the client will
|
||||
@@ -89,7 +101,7 @@ func (c *SSHHelperConfig) NewClient() (*Client, error) {
|
||||
clientConfig.Address = c.VaultAddr
|
||||
|
||||
// Check if certificates are provided via config file.
|
||||
if c.CACert != "" || c.CAPath != "" || c.TLSSkipVerify {
|
||||
if c.shouldSetTLSParameters() {
|
||||
rootConfig := &rootcerts.Config{
|
||||
CAFile: c.CACert,
|
||||
CAPath: c.CAPath,
|
||||
@@ -145,6 +157,7 @@ func ParseSSHHelperConfig(contents string) (*SSHHelperConfig, error) {
|
||||
"allowed_cidr_list",
|
||||
"allowed_roles",
|
||||
"tls_skip_verify",
|
||||
"tls_server_name",
|
||||
}
|
||||
if err := checkHCLKeys(list, valid); err != nil {
|
||||
return nil, multierror.Prefix(err, "ssh_helper:")
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func TestSSH_CreateTLSClient(t *testing.T) {
|
||||
@@ -30,6 +31,29 @@ func TestSSH_CreateTLSClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSH_CreateTLSClient_tlsServerName(t *testing.T) {
|
||||
// Ensure that the HTTP client is associated with the configured TLS server name.
|
||||
var tlsServerName = "tls.server.name"
|
||||
|
||||
config, err := ParseSSHHelperConfig(fmt.Sprintf(`
|
||||
vault_addr = "1.2.3.4"
|
||||
tls_server_name = "%s"
|
||||
`, tlsServerName))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error loading config: %s", err))
|
||||
}
|
||||
|
||||
client, err := config.NewClient()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating the client: %s", err))
|
||||
}
|
||||
|
||||
actualTLSServerName := client.config.HttpClient.Transport.(*http.Transport).TLSClientConfig.ServerName
|
||||
if actualTLSServerName != tlsServerName {
|
||||
panic(fmt.Sprintf("incorrect TLS server name. expected: %s actual: %s", tlsServerName, actualTLSServerName))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSHHelperConfig(t *testing.T) {
|
||||
config, err := ParseSSHHelperConfig(`
|
||||
vault_addr = "1.2.3.4"
|
||||
@@ -67,3 +91,20 @@ nope = "bad"
|
||||
t.Errorf("bad error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSHHelperConfig_tlsServerName(t *testing.T) {
|
||||
var tlsServerName = "tls.server.name"
|
||||
|
||||
config, err := ParseSSHHelperConfig(fmt.Sprintf(`
|
||||
vault_addr = "1.2.3.4"
|
||||
tls_server_name = "%s"
|
||||
`, tlsServerName))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if config.TLSServerName != tlsServerName {
|
||||
t.Errorf("incorrect TLS server name. expected: %s actual: %s", tlsServerName, config.TLSServerName)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user