mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	Add -dev-tls-san flag (#22657)
* Add -dev-tls-san flag This is helpful when wanting to set up a dev server with TLS in Kubernetes and any other situations where the dev server may not be the same machine as the Vault client (e.g. in combination with some /etc/hosts entries) * Automatically add (best-effort only) -dev-listen-address host to extraSANs
This commit is contained in:
		
							
								
								
									
										3
									
								
								changelog/22657.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/22657.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | ```release-note:improvement | ||||||
|  | command/server: add `-dev-tls-san` flag to configure subject alternative names for the certificate generated when using `-dev-tls`. | ||||||
|  | ``` | ||||||
| @@ -132,6 +132,7 @@ type ServerCommand struct { | |||||||
| 	flagDev                bool | 	flagDev                bool | ||||||
| 	flagDevTLS             bool | 	flagDevTLS             bool | ||||||
| 	flagDevTLSCertDir      string | 	flagDevTLSCertDir      string | ||||||
|  | 	flagDevTLSSANs         []string | ||||||
| 	flagDevRootTokenID     string | 	flagDevRootTokenID     string | ||||||
| 	flagDevListenAddr      string | 	flagDevListenAddr      string | ||||||
| 	flagDevNoStoreToken    bool | 	flagDevNoStoreToken    bool | ||||||
| @@ -256,6 +257,18 @@ func (c *ServerCommand) Flags() *FlagSets { | |||||||
| 			"specified. If left unset, files are generated in a temporary directory.", | 			"specified. If left unset, files are generated in a temporary directory.", | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|  | 	f.StringSliceVar(&StringSliceVar{ | ||||||
|  | 		Name:    "dev-tls-san", | ||||||
|  | 		Target:  &c.flagDevTLSSANs, | ||||||
|  | 		Default: nil, | ||||||
|  | 		Usage: "Additional Subject Alternative Name (as a DNS name or IP address) " + | ||||||
|  | 			"to generate the certificate with if `-dev-tls` is specified. The " + | ||||||
|  | 			"certificate will always use localhost, localhost4, localhost6, " + | ||||||
|  | 			"localhost.localdomain, and the host name as alternate DNS names, " + | ||||||
|  | 			"and 127.0.0.1 as an alternate IP address. This flag can be specified " + | ||||||
|  | 			"multiple times to specify multiple SANs.", | ||||||
|  | 	}) | ||||||
|  |  | ||||||
| 	f.StringVar(&StringVar{ | 	f.StringVar(&StringVar{ | ||||||
| 		Name:    "dev-root-token-id", | 		Name:    "dev-root-token-id", | ||||||
| 		Target:  &c.flagDevRootTokenID, | 		Target:  &c.flagDevRootTokenID, | ||||||
| @@ -977,7 +990,17 @@ func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) { | |||||||
| 				return nil, nil, certDir, err | 				return nil, nil, certDir, err | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		config, err = server.DevTLSConfig(devStorageType, certDir) | 		extraSANs := c.flagDevTLSSANs | ||||||
|  | 		host, _, err := net.SplitHostPort(c.flagDevListenAddr) | ||||||
|  | 		if err == nil { | ||||||
|  | 			// 127.0.0.1 is the default, and already included in the SANs. | ||||||
|  | 			// Empty host means listen on all interfaces, but users should use the | ||||||
|  | 			// -dev-tls-san flag to get the right SANs in that case. | ||||||
|  | 			if host != "" && host != "127.0.0.1" { | ||||||
|  | 				extraSANs = append(extraSANs, host) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		config, err = server.DevTLSConfig(devStorageType, certDir, extraSANs) | ||||||
|  |  | ||||||
| 		f = func() { | 		f = func() { | ||||||
| 			if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil { | 			if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil { | ||||||
|   | |||||||
| @@ -176,13 +176,13 @@ ui = true | |||||||
| } | } | ||||||
|  |  | ||||||
| // DevTLSConfig is a Config that is used for dev tls mode of Vault. | // DevTLSConfig is a Config that is used for dev tls mode of Vault. | ||||||
| func DevTLSConfig(storageType, certDir string) (*Config, error) { | func DevTLSConfig(storageType, certDir string, extraSANs []string) (*Config, error) { | ||||||
| 	ca, err := GenerateCA() | 	ca, err := GenerateCA() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	cert, key, err := GenerateCert(ca.Template, ca.Signer) | 	cert, key, err := generateCert(ca.Template, ca.Signer, extraSANs) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -27,8 +27,8 @@ type CaCert struct { | |||||||
| 	Signer   crypto.Signer | 	Signer   crypto.Signer | ||||||
| } | } | ||||||
|  |  | ||||||
| // GenerateCert creates a new leaf cert from provided CA template and signer | // generateCert creates a new leaf cert from provided CA template and signer | ||||||
| func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (string, string, error) { | func generateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer, extraSANs []string) (string, string, error) { | ||||||
| 	// Create the private key | 	// Create the private key | ||||||
| 	signer, keyPEM, err := privateKey() | 	signer, keyPEM, err := privateKey() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -80,6 +80,13 @@ func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (str | |||||||
| 	if !foundHostname { | 	if !foundHostname { | ||||||
| 		template.DNSNames = append(template.DNSNames, hostname) | 		template.DNSNames = append(template.DNSNames, hostname) | ||||||
| 	} | 	} | ||||||
|  | 	for _, san := range extraSANs { | ||||||
|  | 		if ip := net.ParseIP(san); ip != nil { | ||||||
|  | 			template.IPAddresses = append(template.IPAddresses, ip) | ||||||
|  | 		} else { | ||||||
|  | 			template.DNSNames = append(template.DNSNames, san) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	bs, err := x509.CreateCertificate( | 	bs, err := x509.CreateCertificate( | ||||||
| 		rand.Reader, &template, caCertTemplate, signer.Public(), caSigner) | 		rand.Reader, &template, caCertTemplate, signer.Public(), caSigner) | ||||||
|   | |||||||
							
								
								
									
										80
									
								
								command/server/tls_util_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								command/server/tls_util_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | |||||||
|  | package server | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/x509" | ||||||
|  | 	"encoding/pem" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/go-secure-stdlib/strutil" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // TestGenerateCertExtraSans ensures the implementation backing the flag | ||||||
|  | // -dev-tls-san populates alternate DNS and IP address names in the generated | ||||||
|  | // certificate as expected. | ||||||
|  | func TestGenerateCertExtraSans(t *testing.T) { | ||||||
|  | 	ca, err := GenerateCA() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for name, tc := range map[string]struct { | ||||||
|  | 		extraSans           []string | ||||||
|  | 		expectedDNSNames    []string | ||||||
|  | 		expectedIPAddresses []string | ||||||
|  | 	}{ | ||||||
|  | 		"empty": {}, | ||||||
|  | 		"DNS names": { | ||||||
|  | 			extraSans:        []string{"foo", "foo.bar"}, | ||||||
|  | 			expectedDNSNames: []string{"foo", "foo.bar"}, | ||||||
|  | 		}, | ||||||
|  | 		"IP addresses": { | ||||||
|  | 			extraSans:           []string{"0.0.0.0", "::1"}, | ||||||
|  | 			expectedIPAddresses: []string{"0.0.0.0", "::1"}, | ||||||
|  | 		}, | ||||||
|  | 		"mixed": { | ||||||
|  | 			extraSans:           []string{"bar", "0.0.0.0", "::1"}, | ||||||
|  | 			expectedDNSNames:    []string{"bar"}, | ||||||
|  | 			expectedIPAddresses: []string{"0.0.0.0", "::1"}, | ||||||
|  | 		}, | ||||||
|  | 	} { | ||||||
|  | 		t.Run(name, func(t *testing.T) { | ||||||
|  | 			certStr, _, err := generateCert(ca.Template, ca.Signer, tc.extraSans) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			block, _ := pem.Decode([]byte(certStr)) | ||||||
|  | 			cert, err := x509.ParseCertificate(block.Bytes) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			expectedDNSNamesLen := len(tc.expectedDNSNames) + 5 | ||||||
|  | 			if len(cert.DNSNames) != expectedDNSNamesLen { | ||||||
|  | 				t.Errorf("Wrong number of DNS names, expected %d but got %v", expectedDNSNamesLen, cert.DNSNames) | ||||||
|  | 			} | ||||||
|  | 			expectedIPAddrLen := len(tc.expectedIPAddresses) + 1 | ||||||
|  | 			if len(cert.IPAddresses) != expectedIPAddrLen { | ||||||
|  | 				t.Errorf("Wrong number of IP addresses, expected %d but got %v", expectedIPAddrLen, cert.IPAddresses) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			for _, expected := range tc.expectedDNSNames { | ||||||
|  | 				if !strutil.StrListContains(cert.DNSNames, expected) { | ||||||
|  | 					t.Errorf("Missing DNS name %s", expected) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			for _, expected := range tc.expectedIPAddresses { | ||||||
|  | 				var found bool | ||||||
|  | 				for _, ip := range cert.IPAddresses { | ||||||
|  | 					if ip.String() == expected { | ||||||
|  | 						found = true | ||||||
|  | 						break | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				if !found { | ||||||
|  | 					t.Errorf("Missing IP address %s", expected) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -24,6 +24,7 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/sdk/physical" | 	"github.com/hashicorp/vault/sdk/physical" | ||||||
| 	physInmem "github.com/hashicorp/vault/sdk/physical/inmem" | 	physInmem "github.com/hashicorp/vault/sdk/physical/inmem" | ||||||
| 	"github.com/mitchellh/cli" | 	"github.com/mitchellh/cli" | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -383,13 +384,19 @@ func TestConfigureDevTLS(t *testing.T) { | |||||||
| 			fun() | 			fun() | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription) | 		t.Run(testcase.TestDescription, func(t *testing.T) { | ||||||
| 		require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription) | 			assert.Equal(t, testcase.DeferFuncNotNil, (fun != nil)) | ||||||
| 		if testcase.ConfigNotNil { | 			assert.Equal(t, testcase.ConfigNotNil, cfg != nil) | ||||||
| 			require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription) | 			if testcase.ConfigNotNil && cfg != nil { | ||||||
| 			require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription) | 				assert.True(t, len(cfg.Listeners) > 0) | ||||||
|  | 				assert.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable) | ||||||
| 			} | 			} | ||||||
| 		require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription) | 			assert.Equal(t, testcase.CertPathEmpty, len(certPath) == 0) | ||||||
| 		require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription) | 			if testcase.ErrNotNil { | ||||||
|  | 				assert.Error(t, err) | ||||||
|  | 			} else { | ||||||
|  | 				assert.NoError(t, err) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Tom Proctor
					Tom Proctor