From 87649219ffed19b85dcd7a0ce102d2bd1a9ed6d5 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Thu, 31 Aug 2023 23:31:42 +0100 Subject: [PATCH] Add -dev-tls-san flag (#22657) * Add -dev-tls-san flag This is helpful when wanting to set up a dev server with TLS in Kubernetes and any other situations where the dev server may not be the same machine as the Vault client (e.g. in combination with some /etc/hosts entries) * Automatically add (best-effort only) -dev-listen-address host to extraSANs --- changelog/22657.txt | 3 ++ command/server.go | 25 ++++++++++- command/server/config.go | 4 +- command/server/tls_util.go | 11 ++++- command/server/tls_util_test.go | 80 +++++++++++++++++++++++++++++++++ command/server_test.go | 23 ++++++---- 6 files changed, 133 insertions(+), 13 deletions(-) create mode 100644 changelog/22657.txt create mode 100644 command/server/tls_util_test.go diff --git a/changelog/22657.txt b/changelog/22657.txt new file mode 100644 index 0000000000..89a8ab4409 --- /dev/null +++ b/changelog/22657.txt @@ -0,0 +1,3 @@ +```release-note:improvement +command/server: add `-dev-tls-san` flag to configure subject alternative names for the certificate generated when using `-dev-tls`. +``` diff --git a/command/server.go b/command/server.go index eb8767b4d9..5758b3d13e 100644 --- a/command/server.go +++ b/command/server.go @@ -132,6 +132,7 @@ type ServerCommand struct { flagDev bool flagDevTLS bool flagDevTLSCertDir string + flagDevTLSSANs []string flagDevRootTokenID string flagDevListenAddr string flagDevNoStoreToken bool @@ -256,6 +257,18 @@ func (c *ServerCommand) Flags() *FlagSets { "specified. If left unset, files are generated in a temporary directory.", }) + f.StringSliceVar(&StringSliceVar{ + Name: "dev-tls-san", + Target: &c.flagDevTLSSANs, + Default: nil, + Usage: "Additional Subject Alternative Name (as a DNS name or IP address) " + + "to generate the certificate with if `-dev-tls` is specified. The " + + "certificate will always use localhost, localhost4, localhost6, " + + "localhost.localdomain, and the host name as alternate DNS names, " + + "and 127.0.0.1 as an alternate IP address. This flag can be specified " + + "multiple times to specify multiple SANs.", + }) + f.StringVar(&StringVar{ Name: "dev-root-token-id", Target: &c.flagDevRootTokenID, @@ -977,7 +990,17 @@ func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) { return nil, nil, certDir, err } } - config, err = server.DevTLSConfig(devStorageType, certDir) + extraSANs := c.flagDevTLSSANs + host, _, err := net.SplitHostPort(c.flagDevListenAddr) + if err == nil { + // 127.0.0.1 is the default, and already included in the SANs. + // Empty host means listen on all interfaces, but users should use the + // -dev-tls-san flag to get the right SANs in that case. + if host != "" && host != "127.0.0.1" { + extraSANs = append(extraSANs, host) + } + } + config, err = server.DevTLSConfig(devStorageType, certDir, extraSANs) f = func() { if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil { diff --git a/command/server/config.go b/command/server/config.go index 1b4a5e8f99..fc9ab7cfd6 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -176,13 +176,13 @@ ui = true } // DevTLSConfig is a Config that is used for dev tls mode of Vault. -func DevTLSConfig(storageType, certDir string) (*Config, error) { +func DevTLSConfig(storageType, certDir string, extraSANs []string) (*Config, error) { ca, err := GenerateCA() if err != nil { return nil, err } - cert, key, err := GenerateCert(ca.Template, ca.Signer) + cert, key, err := generateCert(ca.Template, ca.Signer, extraSANs) if err != nil { return nil, err } diff --git a/command/server/tls_util.go b/command/server/tls_util.go index 3782370df6..cd07dde927 100644 --- a/command/server/tls_util.go +++ b/command/server/tls_util.go @@ -27,8 +27,8 @@ type CaCert struct { Signer crypto.Signer } -// GenerateCert creates a new leaf cert from provided CA template and signer -func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (string, string, error) { +// generateCert creates a new leaf cert from provided CA template and signer +func generateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer, extraSANs []string) (string, string, error) { // Create the private key signer, keyPEM, err := privateKey() if err != nil { @@ -80,6 +80,13 @@ func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (str if !foundHostname { template.DNSNames = append(template.DNSNames, hostname) } + for _, san := range extraSANs { + if ip := net.ParseIP(san); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, san) + } + } bs, err := x509.CreateCertificate( rand.Reader, &template, caCertTemplate, signer.Public(), caSigner) diff --git a/command/server/tls_util_test.go b/command/server/tls_util_test.go new file mode 100644 index 0000000000..31ef5a3fdb --- /dev/null +++ b/command/server/tls_util_test.go @@ -0,0 +1,80 @@ +package server + +import ( + "crypto/x509" + "encoding/pem" + "testing" + + "github.com/hashicorp/go-secure-stdlib/strutil" +) + +// TestGenerateCertExtraSans ensures the implementation backing the flag +// -dev-tls-san populates alternate DNS and IP address names in the generated +// certificate as expected. +func TestGenerateCertExtraSans(t *testing.T) { + ca, err := GenerateCA() + if err != nil { + t.Fatal(err) + } + + for name, tc := range map[string]struct { + extraSans []string + expectedDNSNames []string + expectedIPAddresses []string + }{ + "empty": {}, + "DNS names": { + extraSans: []string{"foo", "foo.bar"}, + expectedDNSNames: []string{"foo", "foo.bar"}, + }, + "IP addresses": { + extraSans: []string{"0.0.0.0", "::1"}, + expectedIPAddresses: []string{"0.0.0.0", "::1"}, + }, + "mixed": { + extraSans: []string{"bar", "0.0.0.0", "::1"}, + expectedDNSNames: []string{"bar"}, + expectedIPAddresses: []string{"0.0.0.0", "::1"}, + }, + } { + t.Run(name, func(t *testing.T) { + certStr, _, err := generateCert(ca.Template, ca.Signer, tc.extraSans) + if err != nil { + t.Fatal(err) + } + + block, _ := pem.Decode([]byte(certStr)) + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatal(err) + } + + expectedDNSNamesLen := len(tc.expectedDNSNames) + 5 + if len(cert.DNSNames) != expectedDNSNamesLen { + t.Errorf("Wrong number of DNS names, expected %d but got %v", expectedDNSNamesLen, cert.DNSNames) + } + expectedIPAddrLen := len(tc.expectedIPAddresses) + 1 + if len(cert.IPAddresses) != expectedIPAddrLen { + t.Errorf("Wrong number of IP addresses, expected %d but got %v", expectedIPAddrLen, cert.IPAddresses) + } + + for _, expected := range tc.expectedDNSNames { + if !strutil.StrListContains(cert.DNSNames, expected) { + t.Errorf("Missing DNS name %s", expected) + } + } + for _, expected := range tc.expectedIPAddresses { + var found bool + for _, ip := range cert.IPAddresses { + if ip.String() == expected { + found = true + break + } + } + if !found { + t.Errorf("Missing IP address %s", expected) + } + } + }) + } +} diff --git a/command/server_test.go b/command/server_test.go index 9436fdfa8a..677705176b 100644 --- a/command/server_test.go +++ b/command/server_test.go @@ -24,6 +24,7 @@ import ( "github.com/hashicorp/vault/sdk/physical" physInmem "github.com/hashicorp/vault/sdk/physical/inmem" "github.com/mitchellh/cli" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -383,13 +384,19 @@ func TestConfigureDevTLS(t *testing.T) { fun() } - require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription) - require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription) - if testcase.ConfigNotNil { - require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription) - require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription) - } - require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription) - require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription) + t.Run(testcase.TestDescription, func(t *testing.T) { + assert.Equal(t, testcase.DeferFuncNotNil, (fun != nil)) + assert.Equal(t, testcase.ConfigNotNil, cfg != nil) + if testcase.ConfigNotNil && cfg != nil { + assert.True(t, len(cfg.Listeners) > 0) + assert.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable) + } + assert.Equal(t, testcase.CertPathEmpty, len(certPath) == 0) + if testcase.ErrNotNil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) } }