mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 17:52:32 +00:00
Add -dev-tls-san flag (#22657)
* Add -dev-tls-san flag This is helpful when wanting to set up a dev server with TLS in Kubernetes and any other situations where the dev server may not be the same machine as the Vault client (e.g. in combination with some /etc/hosts entries) * Automatically add (best-effort only) -dev-listen-address host to extraSANs
This commit is contained in:
3
changelog/22657.txt
Normal file
3
changelog/22657.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
```release-note:improvement
|
||||
command/server: add `-dev-tls-san` flag to configure subject alternative names for the certificate generated when using `-dev-tls`.
|
||||
```
|
||||
@@ -132,6 +132,7 @@ type ServerCommand struct {
|
||||
flagDev bool
|
||||
flagDevTLS bool
|
||||
flagDevTLSCertDir string
|
||||
flagDevTLSSANs []string
|
||||
flagDevRootTokenID string
|
||||
flagDevListenAddr string
|
||||
flagDevNoStoreToken bool
|
||||
@@ -256,6 +257,18 @@ func (c *ServerCommand) Flags() *FlagSets {
|
||||
"specified. If left unset, files are generated in a temporary directory.",
|
||||
})
|
||||
|
||||
f.StringSliceVar(&StringSliceVar{
|
||||
Name: "dev-tls-san",
|
||||
Target: &c.flagDevTLSSANs,
|
||||
Default: nil,
|
||||
Usage: "Additional Subject Alternative Name (as a DNS name or IP address) " +
|
||||
"to generate the certificate with if `-dev-tls` is specified. The " +
|
||||
"certificate will always use localhost, localhost4, localhost6, " +
|
||||
"localhost.localdomain, and the host name as alternate DNS names, " +
|
||||
"and 127.0.0.1 as an alternate IP address. This flag can be specified " +
|
||||
"multiple times to specify multiple SANs.",
|
||||
})
|
||||
|
||||
f.StringVar(&StringVar{
|
||||
Name: "dev-root-token-id",
|
||||
Target: &c.flagDevRootTokenID,
|
||||
@@ -977,7 +990,17 @@ func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) {
|
||||
return nil, nil, certDir, err
|
||||
}
|
||||
}
|
||||
config, err = server.DevTLSConfig(devStorageType, certDir)
|
||||
extraSANs := c.flagDevTLSSANs
|
||||
host, _, err := net.SplitHostPort(c.flagDevListenAddr)
|
||||
if err == nil {
|
||||
// 127.0.0.1 is the default, and already included in the SANs.
|
||||
// Empty host means listen on all interfaces, but users should use the
|
||||
// -dev-tls-san flag to get the right SANs in that case.
|
||||
if host != "" && host != "127.0.0.1" {
|
||||
extraSANs = append(extraSANs, host)
|
||||
}
|
||||
}
|
||||
config, err = server.DevTLSConfig(devStorageType, certDir, extraSANs)
|
||||
|
||||
f = func() {
|
||||
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil {
|
||||
|
||||
@@ -176,13 +176,13 @@ ui = true
|
||||
}
|
||||
|
||||
// DevTLSConfig is a Config that is used for dev tls mode of Vault.
|
||||
func DevTLSConfig(storageType, certDir string) (*Config, error) {
|
||||
func DevTLSConfig(storageType, certDir string, extraSANs []string) (*Config, error) {
|
||||
ca, err := GenerateCA()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert, key, err := GenerateCert(ca.Template, ca.Signer)
|
||||
cert, key, err := generateCert(ca.Template, ca.Signer, extraSANs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -27,8 +27,8 @@ type CaCert struct {
|
||||
Signer crypto.Signer
|
||||
}
|
||||
|
||||
// GenerateCert creates a new leaf cert from provided CA template and signer
|
||||
func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (string, string, error) {
|
||||
// generateCert creates a new leaf cert from provided CA template and signer
|
||||
func generateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer, extraSANs []string) (string, string, error) {
|
||||
// Create the private key
|
||||
signer, keyPEM, err := privateKey()
|
||||
if err != nil {
|
||||
@@ -80,6 +80,13 @@ func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (str
|
||||
if !foundHostname {
|
||||
template.DNSNames = append(template.DNSNames, hostname)
|
||||
}
|
||||
for _, san := range extraSANs {
|
||||
if ip := net.ParseIP(san); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
} else {
|
||||
template.DNSNames = append(template.DNSNames, san)
|
||||
}
|
||||
}
|
||||
|
||||
bs, err := x509.CreateCertificate(
|
||||
rand.Reader, &template, caCertTemplate, signer.Public(), caSigner)
|
||||
|
||||
80
command/server/tls_util_test.go
Normal file
80
command/server/tls_util_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
)
|
||||
|
||||
// TestGenerateCertExtraSans ensures the implementation backing the flag
|
||||
// -dev-tls-san populates alternate DNS and IP address names in the generated
|
||||
// certificate as expected.
|
||||
func TestGenerateCertExtraSans(t *testing.T) {
|
||||
ca, err := GenerateCA()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for name, tc := range map[string]struct {
|
||||
extraSans []string
|
||||
expectedDNSNames []string
|
||||
expectedIPAddresses []string
|
||||
}{
|
||||
"empty": {},
|
||||
"DNS names": {
|
||||
extraSans: []string{"foo", "foo.bar"},
|
||||
expectedDNSNames: []string{"foo", "foo.bar"},
|
||||
},
|
||||
"IP addresses": {
|
||||
extraSans: []string{"0.0.0.0", "::1"},
|
||||
expectedIPAddresses: []string{"0.0.0.0", "::1"},
|
||||
},
|
||||
"mixed": {
|
||||
extraSans: []string{"bar", "0.0.0.0", "::1"},
|
||||
expectedDNSNames: []string{"bar"},
|
||||
expectedIPAddresses: []string{"0.0.0.0", "::1"},
|
||||
},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
certStr, _, err := generateCert(ca.Template, ca.Signer, tc.extraSans)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(certStr))
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedDNSNamesLen := len(tc.expectedDNSNames) + 5
|
||||
if len(cert.DNSNames) != expectedDNSNamesLen {
|
||||
t.Errorf("Wrong number of DNS names, expected %d but got %v", expectedDNSNamesLen, cert.DNSNames)
|
||||
}
|
||||
expectedIPAddrLen := len(tc.expectedIPAddresses) + 1
|
||||
if len(cert.IPAddresses) != expectedIPAddrLen {
|
||||
t.Errorf("Wrong number of IP addresses, expected %d but got %v", expectedIPAddrLen, cert.IPAddresses)
|
||||
}
|
||||
|
||||
for _, expected := range tc.expectedDNSNames {
|
||||
if !strutil.StrListContains(cert.DNSNames, expected) {
|
||||
t.Errorf("Missing DNS name %s", expected)
|
||||
}
|
||||
}
|
||||
for _, expected := range tc.expectedIPAddresses {
|
||||
var found bool
|
||||
for _, ip := range cert.IPAddresses {
|
||||
if ip.String() == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Missing IP address %s", expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/hashicorp/vault/sdk/physical"
|
||||
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
|
||||
"github.com/mitchellh/cli"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -383,13 +384,19 @@ func TestConfigureDevTLS(t *testing.T) {
|
||||
fun()
|
||||
}
|
||||
|
||||
require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription)
|
||||
require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription)
|
||||
if testcase.ConfigNotNil {
|
||||
require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription)
|
||||
require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription)
|
||||
}
|
||||
require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription)
|
||||
require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription)
|
||||
t.Run(testcase.TestDescription, func(t *testing.T) {
|
||||
assert.Equal(t, testcase.DeferFuncNotNil, (fun != nil))
|
||||
assert.Equal(t, testcase.ConfigNotNil, cfg != nil)
|
||||
if testcase.ConfigNotNil && cfg != nil {
|
||||
assert.True(t, len(cfg.Listeners) > 0)
|
||||
assert.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable)
|
||||
}
|
||||
assert.Equal(t, testcase.CertPathEmpty, len(certPath) == 0)
|
||||
if testcase.ErrNotNil {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user