agent: support providing certificate information in cert's config map (#9819)

* agent: support providing certificate information in cert's config map

* update TestCertEndToEnd

* remove URL reference on warning message
This commit is contained in:
Calvin Leung Huang
2020-08-25 14:26:06 -07:00
committed by GitHub
parent 4e69f5bf35
commit cca11493ce
7 changed files with 678 additions and 503 deletions

View File

@@ -344,14 +344,14 @@ func (c *AgentCommand) Run(args []string) int {
switch sc.Type { switch sc.Type {
case "file": case "file":
config := &sink.SinkConfig{ config := &sink.SinkConfig{
Logger: c.logger.Named("sink.file"), Logger: c.logger.Named("sink.file"),
Config: sc.Config, Config: sc.Config,
Client: client, Client: client,
WrapTTL: sc.WrapTTL, WrapTTL: sc.WrapTTL,
DHType: sc.DHType, DHType: sc.DHType,
DeriveKey: sc.DeriveKey, DeriveKey: sc.DeriveKey,
DHPath: sc.DHPath, DHPath: sc.DHPath,
AAD: sc.AAD, AAD: sc.AAD,
} }
s, err := file.NewFileSink(config) s, err := file.NewFileSink(config)
if err != nil { if err != nil {
@@ -411,9 +411,25 @@ func (c *AgentCommand) Run(args []string) int {
} }
} }
// Output the header that the server has started // Warn if cache _and_ cert auto-auth is enabled but certificates were not
// provided in the auto_auth.method["cert"].config stanza.
if config.Cache != nil && (config.AutoAuth != nil && config.AutoAuth.Method != nil && config.AutoAuth.Method.Type == "cert") {
_, okCertFile := config.AutoAuth.Method.Config["client_cert"]
_, okCertKey := config.AutoAuth.Method.Config["client_key"]
// If neither of these exists in the cert stanza, agent will use the
// certs from the vault stanza.
if !okCertFile && !okCertKey {
c.UI.Warn(wrapAtLength("WARNING! Cache is enabled and using the same certificates " +
"from the 'cert' auto-auth method specified in the 'vault' stanza. Consider " +
"specifying certificate information in the 'cert' auto-auth's config stanza."))
}
}
// Output the header that the agent has started
if !c.flagCombineLogs { if !c.flagCombineLogs {
c.UI.Output("==> Vault server started! Log data will stream in below:\n") c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
} }
// Inform any tests that the server is ready // Inform any tests that the server is ready

View File

@@ -11,6 +11,8 @@ import (
"github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/jsonutil"
) )
// AuthMethod is the interface that auto-auth methods implement for the agent
// to use.
type AuthMethod interface { type AuthMethod interface {
// Authenticate returns a mount path, header, request body, and error. // Authenticate returns a mount path, header, request body, and error.
// The header may be nil if no special header is needed. // The header may be nil if no special header is needed.
@@ -20,6 +22,13 @@ type AuthMethod interface {
Shutdown() Shutdown()
} }
// AuthMethodWithClient is an extended interface that can return an API client
// for use during the authentication call.
type AuthMethodWithClient interface {
AuthMethod
AuthClient(client *api.Client) (*api.Client, error)
}
type AuthConfig struct { type AuthConfig struct {
Logger hclog.Logger Logger hclog.Logger
MountPath string MountPath string
@@ -122,6 +131,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second)) backoff := 2*time.Second + time.Duration(ah.random.Int63()%int64(time.Second*2)-int64(time.Second))
ah.logger.Info("authenticating") ah.logger.Info("authenticating")
path, header, data, err := am.Authenticate(ctx, ah.client) path, header, data, err := am.Authenticate(ctx, ah.client)
if err != nil { if err != nil {
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds()) ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoff.Seconds())
@@ -129,9 +139,22 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
continue continue
} }
clientToUse := ah.client var clientToUse *api.Client
switch am.(type) {
case AuthMethodWithClient:
clientToUse, err = am.(AuthMethodWithClient).AuthClient(ah.client)
if err != nil {
ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff)
continue
}
default:
clientToUse = ah.client
}
if ah.wrapTTL > 0 { if ah.wrapTTL > 0 {
wrapClient, err := ah.client.Clone() wrapClient, err := clientToUse.Clone()
if err != nil { if err != nil {
ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds()) ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoff.Seconds())
backoffOrQuit(ctx, backoff) backoffOrQuit(ctx, backoff)
@@ -216,7 +239,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) {
watcher.Stop() watcher.Stop()
} }
watcher, err = ah.client.NewLifetimeWatcher(&api.LifetimeWatcherInput{ watcher, err = clientToUse.NewLifetimeWatcher(&api.LifetimeWatcherInput{
Secret: secret, Secret: secret,
}) })
if err != nil { if err != nil {

View File

@@ -15,8 +15,17 @@ type certMethod struct {
logger hclog.Logger logger hclog.Logger
mountPath string mountPath string
name string name string
caCert string
clientCert string
clientKey string
// Client is the cached client to use if cert info was provided.
client *api.Client
} }
var _ auth.AuthMethodWithClient = &certMethod{}
func NewCertAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) { func NewCertAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if conf == nil { if conf == nil {
return nil, errors.New("empty config") return nil, errors.New("empty config")
@@ -28,7 +37,6 @@ func NewCertAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
c := &certMethod{ c := &certMethod{
logger: conf.Logger, logger: conf.Logger,
mountPath: conf.MountPath, mountPath: conf.MountPath,
name: "",
} }
if conf.Config != nil { if conf.Config != nil {
@@ -40,6 +48,30 @@ func NewCertAuthMethod(conf *auth.AuthConfig) (auth.AuthMethod, error) {
if !ok { if !ok {
return nil, errors.New("could not convert 'name' config value to string") return nil, errors.New("could not convert 'name' config value to string")
} }
caCertRaw, ok := conf.Config["ca_cert"]
if ok {
c.caCert, ok = caCertRaw.(string)
if !ok {
return nil, errors.New("could not convert 'ca_cert' config value to string")
}
}
clientCertRaw, ok := conf.Config["client_cert"]
if ok {
c.clientCert, ok = clientCertRaw.(string)
if !ok {
return nil, errors.New("could not convert 'cert_file' config value to string")
}
}
clientKeyRaw, ok := conf.Config["client_key"]
if ok {
c.clientKey, ok = clientKeyRaw.(string)
if !ok {
return nil, errors.New("could not convert 'cert_key' config value to string")
}
}
} }
return c, nil return c, nil
@@ -64,3 +96,47 @@ func (c *certMethod) NewCreds() chan struct{} {
func (c *certMethod) CredSuccess() {} func (c *certMethod) CredSuccess() {}
func (c *certMethod) Shutdown() {} func (c *certMethod) Shutdown() {}
// AuthClient uses the existing client's address and returns a new client with
// the auto-auth method's certificate information if that's provided in its
// config map.
func (c *certMethod) AuthClient(client *api.Client) (*api.Client, error) {
c.logger.Trace("deriving auth client to use")
clientToAuth := client
if c.caCert != "" || (c.clientKey != "" && c.clientCert != "") {
// Return cached client if present
if c.client != nil {
return client, nil
}
config := api.DefaultConfig()
if config.Error != nil {
return nil, config.Error
}
config.Address = client.Address()
t := &api.TLSConfig{
CACert: c.caCert,
ClientCert: c.clientCert,
ClientKey: c.clientKey,
}
// Setup TLS config
if err := config.ConfigureTLS(t); err != nil {
return nil, err
}
var err error
clientToAuth, err = api.NewClient(config)
if err != nil {
return nil, err
}
// Cache the client for future use
c.client = clientToAuth
}
return clientToAuth, nil
}

View File

@@ -0,0 +1,549 @@
package agent
import (
"context"
"encoding/pem"
"fmt"
"io/ioutil"
"os"
"testing"
"time"
"github.com/hashicorp/vault/builtin/logical/pki"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
vaultcert "github.com/hashicorp/vault/builtin/credential/cert"
"github.com/hashicorp/vault/command/agent/auth"
agentcert "github.com/hashicorp/vault/command/agent/auth/cert"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/helper/dhutil"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
)
func TestCertEndToEnd(t *testing.T) {
cases := []struct {
name string
withCertRoleName bool
ahWrapping bool
}{
{
"with name with wrapping",
true,
true,
},
{
"with name without wrapping",
true,
false,
},
{
"without name with wrapping",
false,
true,
},
{
"without name without wrapping",
false,
false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
testCertEndToEnd(t, tc.withCertRoleName, tc.ahWrapping)
})
}
}
func testCertEndToEnd(t *testing.T, withCertRoleName, ahWrapping bool) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"cert": vaultcert.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
// Setup Vault
err := client.Sys().EnableAuthWithOptions("cert", &api.EnableAuthOptions{
Type: "cert",
})
if err != nil {
t.Fatal(err)
}
certificatePEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cluster.CACert.Raw})
certRoleName := "test"
_, err = client.Logical().Write(fmt.Sprintf("auth/cert/certs/%s", certRoleName), map[string]interface{}{
"certificate": string(certificatePEM),
"policies": "default",
})
if err != nil {
t.Fatal(err)
}
// Generate encryption params
pub, pri, err := dhutil.GeneratePublicPrivateKey()
if err != nil {
t.Fatal(err)
}
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
if err != nil {
t.Fatal(err)
}
out := ouf.Name()
ouf.Close()
os.Remove(out)
t.Logf("output: %s", out)
dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.")
if err != nil {
t.Fatal(err)
}
dhpath := dhpathf.Name()
dhpathf.Close()
os.Remove(dhpath)
// Write DH public key to file
mPubKey, err := jsonutil.EncodeJSON(&dhutil.PublicKeyInfo{
Curve25519PublicKey: pub,
})
if err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(dhpath, mPubKey, 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote dh param file", "path", dhpath)
}
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
aaConfig := map[string]interface{}{}
if withCertRoleName {
aaConfig["name"] = certRoleName
}
am, err := agentcert.NewCertAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.cert"),
MountPath: "auth/cert",
Config: aaConfig,
})
if err != nil {
t.Fatal(err)
}
ahConfig := &auth.AuthHandlerConfig{
Logger: logger.Named("auth.handler"),
Client: client,
EnableReauthOnNewCredentials: true,
}
if ahWrapping {
ahConfig.WrapTTL = 10 * time.Second
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
defer func() {
<-ah.DoneCh
}()
config := &sink.SinkConfig{
Logger: logger.Named("sink.file"),
AAD: "foobar",
DHType: "curve25519",
DHPath: dhpath,
DeriveKey: true,
Config: map[string]interface{}{
"path": out,
},
}
if !ahWrapping {
config.WrapTTL = 10 * time.Second
}
fs, err := file.NewFileSink(config)
if err != nil {
t.Fatal(err)
}
config.Sink = fs
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
cloned, err := client.Clone()
if err != nil {
t.Fatal(err)
}
checkToken := func() string {
timeout := time.Now().Add(5 * time.Second)
for {
if time.Now().After(timeout) {
t.Fatal("did not find a written token after timeout")
}
val, err := ioutil.ReadFile(out)
if err == nil {
os.Remove(out)
if len(val) == 0 {
t.Fatal("written token was empty")
}
// First decrypt it
resp := new(dhutil.Envelope)
if err := jsonutil.DecodeJSON(val, resp); err != nil {
continue
}
shared, err := dhutil.GenerateSharedSecret(pri, resp.Curve25519PublicKey)
if err != nil {
t.Fatal(err)
}
aesKey, err := dhutil.DeriveSharedKey(shared, pub, resp.Curve25519PublicKey)
if err != nil {
t.Fatal(err)
}
if len(aesKey) == 0 {
t.Fatal("got empty aes key")
}
val, err = dhutil.DecryptAES(aesKey, resp.EncryptedPayload, resp.Nonce, []byte("foobar"))
if err != nil {
t.Fatalf("error: %v\nresp: %v", err, string(val))
}
// Now unwrap it
wrapInfo := new(api.SecretWrapInfo)
if err := jsonutil.DecodeJSON(val, wrapInfo); err != nil {
t.Fatal(err)
}
switch {
case wrapInfo.TTL != 10:
t.Fatalf("bad wrap info: %v", wrapInfo.TTL)
case !ahWrapping && wrapInfo.CreationPath != "sys/wrapping/wrap":
t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
case ahWrapping && wrapInfo.CreationPath != "auth/cert/login":
t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
case wrapInfo.Token == "":
t.Fatal("wrap token is empty")
}
cloned.SetToken(wrapInfo.Token)
secret, err := cloned.Logical().Unwrap("")
if err != nil {
t.Fatal(err)
}
if ahWrapping {
switch {
case secret.Auth == nil:
t.Fatal("unwrap secret auth is nil")
case secret.Auth.ClientToken == "":
t.Fatal("unwrap token is nil")
}
return secret.Auth.ClientToken
} else {
switch {
case secret.Data == nil:
t.Fatal("unwrap secret data is nil")
case secret.Data["token"] == nil:
t.Fatal("unwrap token is nil")
}
return secret.Data["token"].(string)
}
}
time.Sleep(250 * time.Millisecond)
}
}
checkToken()
}
func TestCertEndToEnd_CertsInConfig(t *testing.T) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"cert": vaultcert.Factory,
},
LogicalBackends: map[string]logical.Factory{
"pki": pki.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
// /////////////
// PKI setup
// /////////////
// Mount /pki as a root CA
err := client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "32h",
},
})
if err != nil {
t.Fatal(err)
}
// Set the cluster's certificate as the root CA in /pki
pemBundleRootCA := string(cluster.CACertPEM) + string(cluster.CAKeyPEM)
_, err = client.Logical().Write("pki/config/ca", map[string]interface{}{
"pem_bundle": pemBundleRootCA,
})
if err != nil {
t.Fatal(err)
}
// Mount /pki2 to operate as an intermediate CA
err = client.Sys().Mount("pki2", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "32h",
},
})
if err != nil {
t.Fatal(err)
}
// Create a CSR for the intermediate CA
secret, err := client.Logical().Write("pki2/intermediate/generate/internal", nil)
if err != nil {
t.Fatal(err)
}
intermediateCSR := secret.Data["csr"].(string)
// Sign the intermediate CSR using /pki
secret, err = client.Logical().Write("pki/root/sign-intermediate", map[string]interface{}{
"permitted_dns_domains": ".myvault.com",
"csr": intermediateCSR,
})
if err != nil {
t.Fatal(err)
}
intermediateCertPEM := secret.Data["certificate"].(string)
// Configure the intermediate cert as the CA in /pki2
_, err = client.Logical().Write("pki2/intermediate/set-signed", map[string]interface{}{
"certificate": intermediateCertPEM,
})
if err != nil {
t.Fatal(err)
}
// Create a role on the intermediate CA mount
_, err = client.Logical().Write("pki2/roles/myvault-dot-com", map[string]interface{}{
"allowed_domains": "myvault.com",
"allow_subdomains": "true",
"max_ttl": "5m",
})
if err != nil {
t.Fatal(err)
}
// Issue a leaf cert using the intermediate CA
secret, err = client.Logical().Write("pki2/issue/myvault-dot-com", map[string]interface{}{
"common_name": "cert.myvault.com",
"format": "pem",
"ip_sans": "127.0.0.1",
})
if err != nil {
t.Fatal(err)
}
leafCertPEM := secret.Data["certificate"].(string)
leafCertKeyPEM := secret.Data["private_key"].(string)
// Create temporary files for CA cert, client cert and client cert key.
// This is used to configure TLS in the api client.
caCertFile, err := ioutil.TempFile("", "caCert")
if err != nil {
t.Fatal(err)
}
defer os.Remove(caCertFile.Name())
if _, err := caCertFile.Write([]byte(cluster.CACertPEM)); err != nil {
t.Fatal(err)
}
if err := caCertFile.Close(); err != nil {
t.Fatal(err)
}
leafCertFile, err := ioutil.TempFile("", "leafCert")
if err != nil {
t.Fatal(err)
}
defer os.Remove(leafCertFile.Name())
if _, err := leafCertFile.Write([]byte(leafCertPEM)); err != nil {
t.Fatal(err)
}
if err := leafCertFile.Close(); err != nil {
t.Fatal(err)
}
leafCertKeyFile, err := ioutil.TempFile("", "leafCertKey")
if err != nil {
t.Fatal(err)
}
defer os.Remove(leafCertKeyFile.Name())
if _, err := leafCertKeyFile.Write([]byte(leafCertKeyPEM)); err != nil {
t.Fatal(err)
}
if err := leafCertKeyFile.Close(); err != nil {
t.Fatal(err)
}
// /////////////
// Cert auth setup
// /////////////
// Enable the cert auth method
err = client.Sys().EnableAuthWithOptions("cert", &api.EnableAuthOptions{
Type: "cert",
})
if err != nil {
t.Fatal(err)
}
// Set the intermediate CA cert as a trusted certificate in the backend
_, err = client.Logical().Write("auth/cert/certs/myvault-dot-com", map[string]interface{}{
"display_name": "myvault.com",
"policies": "default",
"certificate": intermediateCertPEM,
})
if err != nil {
t.Fatal(err)
}
// /////////////
// Auth handler (auto-auth) setup
// /////////////
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
am, err := agentcert.NewCertAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.cert"),
MountPath: "auth/cert",
Config: map[string]interface{}{
"ca_cert": caCertFile.Name(),
"client_cert": leafCertFile.Name(),
"client_key": leafCertKeyFile.Name(),
},
})
if err != nil {
t.Fatal(err)
}
ahConfig := &auth.AuthHandlerConfig{
Logger: logger.Named("auth.handler"),
Client: client,
EnableReauthOnNewCredentials: true,
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
defer func() {
<-ah.DoneCh
}()
// /////////////
// Sink setup
// /////////////
// Use TempFile to get us a generated file name to use for the sink.
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
if err != nil {
t.Fatal(err)
}
ouf.Close()
out := ouf.Name()
os.Remove(out)
t.Logf("output: %s", out)
config := &sink.SinkConfig{
Logger: logger.Named("sink.file"),
Config: map[string]interface{}{
"path": out,
},
}
fs, err := file.NewFileSink(config)
if err != nil {
t.Fatal(err)
}
config.Sink = fs
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
// Read the token from the sink
timeout := time.Now().Add(5 * time.Second)
for {
if time.Now().After(timeout) {
t.Fatal("did not find a written token after timeout")
}
// Attempt to read the sink file until we get a token or the timeout is
// reached.
val, err := ioutil.ReadFile(out)
if err == nil {
os.Remove(out)
if len(val) == 0 {
t.Fatal("written token was empty")
}
t.Logf("sink token: %s", val)
break
}
time.Sleep(250 * time.Millisecond)
}
}

View File

@@ -1,248 +0,0 @@
package agent
import (
"context"
"encoding/pem"
"io/ioutil"
"os"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
vaultcert "github.com/hashicorp/vault/builtin/credential/cert"
"github.com/hashicorp/vault/command/agent/auth"
agentcert "github.com/hashicorp/vault/command/agent/auth/cert"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/helper/dhutil"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
)
func TestCertWithNameEndToEnd(t *testing.T) {
testCertWithNameEndToEnd(t, false)
testCertWithNameEndToEnd(t, true)
}
func testCertWithNameEndToEnd(t *testing.T, ahWrapping bool) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"cert": vaultcert.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
// Setup Vault
err := client.Sys().EnableAuthWithOptions("cert", &api.EnableAuthOptions{
Type: "cert",
})
if err != nil {
t.Fatal(err)
}
certificatePEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cluster.CACert.Raw})
_, err = client.Logical().Write("auth/cert/certs/test", map[string]interface{}{
"name": "test",
"certificate": string(certificatePEM),
"policies": "default",
})
if err != nil {
t.Fatal(err)
}
// Generate encryption params
pub, pri, err := dhutil.GeneratePublicPrivateKey()
if err != nil {
t.Fatal(err)
}
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
if err != nil {
t.Fatal(err)
}
out := ouf.Name()
ouf.Close()
os.Remove(out)
t.Logf("output: %s", out)
dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.")
if err != nil {
t.Fatal(err)
}
dhpath := dhpathf.Name()
dhpathf.Close()
os.Remove(dhpath)
// Write DH public key to file
mPubKey, err := jsonutil.EncodeJSON(&dhutil.PublicKeyInfo{
Curve25519PublicKey: pub,
})
if err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(dhpath, mPubKey, 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote dh param file", "path", dhpath)
}
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
am, err := agentcert.NewCertAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.cert"),
MountPath: "auth/cert",
Config: map[string]interface{}{
"name": "test",
},
})
if err != nil {
t.Fatal(err)
}
ahConfig := &auth.AuthHandlerConfig{
Logger: logger.Named("auth.handler"),
Client: client,
EnableReauthOnNewCredentials: true,
}
if ahWrapping {
ahConfig.WrapTTL = 10 * time.Second
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
defer func() {
<-ah.DoneCh
}()
config := &sink.SinkConfig{
Logger: logger.Named("sink.file"),
AAD: "foobar",
DHType: "curve25519",
DHPath: dhpath,
DeriveKey: true,
Config: map[string]interface{}{
"path": out,
},
}
if !ahWrapping {
config.WrapTTL = 10 * time.Second
}
fs, err := file.NewFileSink(config)
if err != nil {
t.Fatal(err)
}
config.Sink = fs
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
cloned, err := client.Clone()
if err != nil {
t.Fatal(err)
}
checkToken := func() string {
timeout := time.Now().Add(5 * time.Second)
for {
if time.Now().After(timeout) {
t.Fatal("did not find a written token after timeout")
}
val, err := ioutil.ReadFile(out)
if err == nil {
os.Remove(out)
if len(val) == 0 {
t.Fatal("written token was empty")
}
// First decrypt it
resp := new(dhutil.Envelope)
if err := jsonutil.DecodeJSON(val, resp); err != nil {
continue
}
shared, err := dhutil.GenerateSharedSecret(pri, resp.Curve25519PublicKey)
if err != nil {
t.Fatal(err)
}
aesKey, err := dhutil.DeriveSharedKey(shared, pub, resp.Curve25519PublicKey)
if err != nil {
t.Fatal(err)
}
if len(aesKey) == 0 {
t.Fatal("got empty aes key")
}
val, err = dhutil.DecryptAES(aesKey, resp.EncryptedPayload, resp.Nonce, []byte("foobar"))
if err != nil {
t.Fatalf("error: %v\nresp: %v", err, string(val))
}
// Now unwrap it
wrapInfo := new(api.SecretWrapInfo)
if err := jsonutil.DecodeJSON(val, wrapInfo); err != nil {
t.Fatal(err)
}
switch {
case wrapInfo.TTL != 10:
t.Fatalf("bad wrap info: %v", wrapInfo.TTL)
case !ahWrapping && wrapInfo.CreationPath != "sys/wrapping/wrap":
t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
case ahWrapping && wrapInfo.CreationPath != "auth/cert/login":
t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
case wrapInfo.Token == "":
t.Fatal("wrap token is empty")
}
cloned.SetToken(wrapInfo.Token)
secret, err := cloned.Logical().Unwrap("")
if err != nil {
t.Fatal(err)
}
if ahWrapping {
switch {
case secret.Auth == nil:
t.Fatal("unwrap secret auth is nil")
case secret.Auth.ClientToken == "":
t.Fatal("unwrap token is nil")
}
return secret.Auth.ClientToken
} else {
switch {
case secret.Data == nil:
t.Fatal("unwrap secret data is nil")
case secret.Data["token"] == nil:
t.Fatal("unwrap token is nil")
}
return secret.Data["token"].(string)
}
}
time.Sleep(250 * time.Millisecond)
}
}
checkToken()
}

View File

@@ -1,241 +0,0 @@
package agent
import (
"context"
"encoding/pem"
"io/ioutil"
"os"
"testing"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
vaultcert "github.com/hashicorp/vault/builtin/credential/cert"
"github.com/hashicorp/vault/command/agent/auth"
agentcert "github.com/hashicorp/vault/command/agent/auth/cert"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/helper/dhutil"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
)
func TestCertWithNoNAmeEndToEnd(t *testing.T) {
testCertWithNoNAmeEndToEnd(t, false)
testCertWithNoNAmeEndToEnd(t, true)
}
func testCertWithNoNAmeEndToEnd(t *testing.T, ahWrapping bool) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
CredentialBackends: map[string]logical.Factory{
"cert": vaultcert.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
vault.TestWaitActive(t, cluster.Cores[0].Core)
client := cluster.Cores[0].Client
// Setup Vault
err := client.Sys().EnableAuthWithOptions("cert", &api.EnableAuthOptions{
Type: "cert",
})
if err != nil {
t.Fatal(err)
}
certificatePEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cluster.CACert.Raw})
_, err = client.Logical().Write("auth/cert/certs/test", map[string]interface{}{
"name": "test",
"certificate": string(certificatePEM),
"policies": "default",
})
if err != nil {
t.Fatal(err)
}
// Generate encryption params
pub, pri, err := dhutil.GeneratePublicPrivateKey()
if err != nil {
t.Fatal(err)
}
ouf, err := ioutil.TempFile("", "auth.tokensink.test.")
if err != nil {
t.Fatal(err)
}
out := ouf.Name()
ouf.Close()
os.Remove(out)
t.Logf("output: %s", out)
dhpathf, err := ioutil.TempFile("", "auth.dhpath.test.")
if err != nil {
t.Fatal(err)
}
dhpath := dhpathf.Name()
dhpathf.Close()
os.Remove(dhpath)
// Write DH public key to file
mPubKey, err := jsonutil.EncodeJSON(&dhutil.PublicKeyInfo{
Curve25519PublicKey: pub,
})
if err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(dhpath, mPubKey, 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote dh param file", "path", dhpath)
}
ctx, cancelFunc := context.WithCancel(context.Background())
timer := time.AfterFunc(30*time.Second, func() {
cancelFunc()
})
defer timer.Stop()
am, err := agentcert.NewCertAuthMethod(&auth.AuthConfig{
Logger: logger.Named("auth.cert"),
MountPath: "auth/cert",
})
if err != nil {
t.Fatal(err)
}
ahConfig := &auth.AuthHandlerConfig{
Logger: logger.Named("auth.handler"),
Client: client,
EnableReauthOnNewCredentials: true,
}
if ahWrapping {
ahConfig.WrapTTL = 10 * time.Second
}
ah := auth.NewAuthHandler(ahConfig)
go ah.Run(ctx, am)
defer func() {
<-ah.DoneCh
}()
config := &sink.SinkConfig{
Logger: logger.Named("sink.file"),
AAD: "foobar",
DHType: "curve25519",
DHPath: dhpath,
Config: map[string]interface{}{
"path": out,
},
}
if !ahWrapping {
config.WrapTTL = 10 * time.Second
}
fs, err := file.NewFileSink(config)
if err != nil {
t.Fatal(err)
}
config.Sink = fs
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: logger.Named("sink.server"),
Client: client,
})
go ss.Run(ctx, ah.OutputCh, []*sink.SinkConfig{config})
defer func() {
<-ss.DoneCh
}()
// This has to be after the other defers so it happens first
defer cancelFunc()
cloned, err := client.Clone()
if err != nil {
t.Fatal(err)
}
checkToken := func() string {
timeout := time.Now().Add(5 * time.Second)
for {
if time.Now().After(timeout) {
t.Fatal("did not find a written token after timeout")
}
val, err := ioutil.ReadFile(out)
if err == nil {
os.Remove(out)
if len(val) == 0 {
t.Fatal("written token was empty")
}
// First decrypt it
resp := new(dhutil.Envelope)
if err := jsonutil.DecodeJSON(val, resp); err != nil {
continue
}
aesKey, err := dhutil.GenerateSharedSecret(pri, resp.Curve25519PublicKey)
if err != nil {
t.Fatal(err)
}
if len(aesKey) == 0 {
t.Fatal("got empty aes key")
}
val, err = dhutil.DecryptAES(aesKey, resp.EncryptedPayload, resp.Nonce, []byte("foobar"))
if err != nil {
t.Fatalf("error: %v\nresp: %v", err, string(val))
}
// Now unwrap it
wrapInfo := new(api.SecretWrapInfo)
if err := jsonutil.DecodeJSON(val, wrapInfo); err != nil {
t.Fatal(err)
}
switch {
case wrapInfo.TTL != 10:
t.Fatalf("bad wrap info: %v", wrapInfo.TTL)
case !ahWrapping && wrapInfo.CreationPath != "sys/wrapping/wrap":
t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
case ahWrapping && wrapInfo.CreationPath != "auth/cert/login":
t.Fatalf("bad wrap path: %v", wrapInfo.CreationPath)
case wrapInfo.Token == "":
t.Fatal("wrap token is empty")
}
cloned.SetToken(wrapInfo.Token)
secret, err := cloned.Logical().Unwrap("")
if err != nil {
t.Fatal(err)
}
if ahWrapping {
switch {
case secret.Auth == nil:
t.Fatal("unwrap secret auth is nil")
case secret.Auth.ClientToken == "":
t.Fatal("unwrap token is nil")
}
return secret.Auth.ClientToken
} else {
switch {
case secret.Data == nil:
t.Fatal("unwrap secret data is nil")
case secret.Data["token"] == nil:
t.Fatal("unwrap token is nil")
}
return secret.Data["token"].(string)
}
}
time.Sleep(250 * time.Millisecond)
}
}
checkToken()
}

View File

@@ -30,7 +30,7 @@ type Config struct {
Templates []*ctconfig.TemplateConfig `hcl:"templates"` Templates []*ctconfig.TemplateConfig `hcl:"templates"`
} }
// Vault contains configuration for connnecting to Vault servers // Vault contains configuration for connecting to Vault servers
type Vault struct { type Vault struct {
Address string `hcl:"address"` Address string `hcl:"address"`
CACert string `hcl:"ca_cert"` CACert string `hcl:"ca_cert"`