Lazy-load plugin mounts (#3255)

* Lazy load plugins to avoid setup-unwrap cycle

* Remove commented blocks

* Refactor NewTestCluster, use single core cluster on basic plugin tests

* Set c.pluginDirectory in TestAddTestPlugin for setupPluginCatalog to work properly

* Add special path to mock plugin

* Move ensureCoresSealed to vault/testing.go

* Use same method for EnsureCoresSealed and Cleanup

* Bump ensureCoresSealed timeout to 60s

* Correctly handle nil opts on NewTestCluster

* Add metadata flag to APIClientMeta, use meta-enabled plugin when mounting to bootstrap

* Check metadata flag directly on the plugin process

* Plumb isMetadataMode down to PluginRunner

* Add NOOP shims when running in metadata mode

* Remove unused flag from the APIMetadata object

* Remove setupSecretPlugins and setupCredentialPlugins functions

* Move when we setup rollback manager to after the plugins are initialized

* Fix tests

* Fix merge issue

* start rollback manager after the credential setup

* Add guards against running certain client and server functions while in metadata mode

* Call initialize once a plugin is loaded on the fly

* Add more tests, update basic secret/auth plugin tests to trigger lazy loading

* Skip mount if plugin removed from catalog

* Fixup

* Remove commented line on LookupPlugin

* Fail on mount operation if plugin is re-added to catalog and mount is on existing path

* Check type and special paths on startBackend

* Fix merge conflicts

* Refactor PluginRunner run methods to use runCommon, fix TestSystemBackend_Plugin_auth
This commit is contained in:
Calvin Leung Huang
2017-09-01 01:02:03 -04:00
committed by GitHub
parent 3f4a593ec2
commit 3b8b68097d
22 changed files with 816 additions and 480 deletions

View File

@@ -3,13 +3,20 @@ package plugin
import ( import (
"fmt" "fmt"
"net/rpc" "net/rpc"
"reflect"
"sync" "sync"
uuid "github.com/hashicorp/go-uuid" uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
bplugin "github.com/hashicorp/vault/logical/plugin" bplugin "github.com/hashicorp/vault/logical/plugin"
) )
var (
ErrMismatchType = fmt.Errorf("mismatch on mounted backend and plugin backend type")
ErrMismatchPaths = fmt.Errorf("mismatch on mounted backend and plugin backend special paths")
)
// Factory returns a configured plugin logical.Backend. // Factory returns a configured plugin logical.Backend.
func Factory(conf *logical.BackendConfig) (logical.Backend, error) { func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
_, ok := conf.Config["plugin_name"] _, ok := conf.Config["plugin_name"]
@@ -31,14 +38,33 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
// or as a concrete implementation if builtin, casted as logical.Backend. // or as a concrete implementation if builtin, casted as logical.Backend.
func Backend(conf *logical.BackendConfig) (logical.Backend, error) { func Backend(conf *logical.BackendConfig) (logical.Backend, error) {
var b backend var b backend
name := conf.Config["plugin_name"] name := conf.Config["plugin_name"]
sys := conf.System sys := conf.System
raw, err := bplugin.NewBackend(name, sys, conf.Logger) // NewBackend with isMetadataMode set to true
raw, err := bplugin.NewBackend(name, sys, conf.Logger, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b.Backend = raw err = raw.Setup(conf)
if err != nil {
return nil, err
}
// Get SpecialPaths and BackendType
paths := raw.SpecialPaths()
btype := raw.Type()
// Cleanup meta plugin backend
raw.Cleanup()
// Initialize b.Backend with dummy backend since plugin
// backends will need to be lazy loaded.
b.Backend = &framework.Backend{
PathsSpecial: paths,
BackendType: btype,
}
b.config = conf b.config = conf
return &b, nil return &b, nil
@@ -53,16 +79,24 @@ type backend struct {
// Used to detect if we already reloaded // Used to detect if we already reloaded
canary string canary string
// Used to detect if plugin is set
loaded bool
} }
func (b *backend) reloadBackend() error { func (b *backend) reloadBackend() error {
b.Logger().Trace("plugin: reloading plugin backend", "plugin", b.config.Config["plugin_name"])
return b.startBackend()
}
// startBackend starts a plugin backend
func (b *backend) startBackend() error {
pluginName := b.config.Config["plugin_name"] pluginName := b.config.Config["plugin_name"]
b.Logger().Trace("plugin: reloading plugin backend", "plugin", pluginName)
// Ensure proper cleanup of the backend (i.e. call client.Kill()) // Ensure proper cleanup of the backend (i.e. call client.Kill())
b.Backend.Cleanup() b.Backend.Cleanup()
nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger) nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger, false)
if err != nil { if err != nil {
return err return err
} }
@@ -70,7 +104,29 @@ func (b *backend) reloadBackend() error {
if err != nil { if err != nil {
return err return err
} }
// If the backend has not been loaded (i.e. still in metadata mode),
// check if type and special paths still matches
if !b.loaded {
if b.Backend.Type() != nb.Type() {
nb.Cleanup()
b.Logger().Warn("plugin: failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType)
return ErrMismatchType
}
if !reflect.DeepEqual(b.Backend.SpecialPaths(), nb.SpecialPaths()) {
nb.Cleanup()
b.Logger().Warn("plugin: failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths)
return ErrMismatchPaths
}
}
b.Backend = nb b.Backend = nb
b.loaded = true
// Call initialize
if err := b.Backend.Initialize(); err != nil {
return err
}
return nil return nil
} }
@@ -79,6 +135,23 @@ func (b *backend) reloadBackend() error {
func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error) { func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error) {
b.RLock() b.RLock()
canary := b.canary canary := b.canary
// Lazy-load backend
if !b.loaded {
// Upgrade lock
b.RUnlock()
b.Lock()
// Check once more after lock swap
if !b.loaded {
err := b.startBackend()
if err != nil {
b.Unlock()
return nil, err
}
}
b.Unlock()
b.RLock()
}
resp, err := b.Backend.HandleRequest(req) resp, err := b.Backend.HandleRequest(req)
b.RUnlock() b.RUnlock()
// Need to compare string value for case were err comes from plugin RPC // Need to compare string value for case were err comes from plugin RPC
@@ -112,6 +185,24 @@ func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error)
func (b *backend) HandleExistenceCheck(req *logical.Request) (bool, bool, error) { func (b *backend) HandleExistenceCheck(req *logical.Request) (bool, bool, error) {
b.RLock() b.RLock()
canary := b.canary canary := b.canary
// Lazy-load backend
if !b.loaded {
// Upgrade lock
b.RUnlock()
b.Lock()
// Check once more after lock swap
if !b.loaded {
err := b.startBackend()
if err != nil {
b.Unlock()
return false, false, err
}
}
b.Unlock()
b.RLock()
}
checkFound, exists, err := b.Backend.HandleExistenceCheck(req) checkFound, exists, err := b.Backend.HandleExistenceCheck(req)
b.RUnlock() b.RUnlock()
if err != nil && err.Error() == rpc.ErrShutdown.Error() { if err != nil && err.Error() == rpc.ErrShutdown.Error() {

View File

@@ -1,6 +1,7 @@
package plugin package plugin
import ( import (
"fmt"
"os" "os"
"testing" "testing"
@@ -39,7 +40,8 @@ func TestBackend_Factory(t *testing.T) {
} }
func TestBackend_PluginMain(t *testing.T) { func TestBackend_PluginMain(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { args := []string{}
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
return return
} }
@@ -48,7 +50,7 @@ func TestBackend_PluginMain(t *testing.T) {
t.Fatal("CA cert not passed in") t.Fatal("CA cert not passed in")
} }
args := []string{"--ca-cert=" + caPEM} args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
apiClientMeta := &pluginutil.APIClientMeta{} apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet() flags := apiClientMeta.FlagSet()

View File

@@ -7,7 +7,7 @@ import (
) )
var ( var (
// PluginUnwrapTokenEnv is the ENV name used to pass the configuration for // PluginMlockEnabled is the ENV name used to pass the configuration for
// enabling mlock // enabling mlock
PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED"
) )

View File

@@ -2,6 +2,7 @@ package pluginutil
import ( import (
"crypto/sha256" "crypto/sha256"
"crypto/tls"
"flag" "flag"
"fmt" "fmt"
"os/exec" "os/exec"
@@ -22,6 +23,7 @@ type Looker interface {
// Wrapper interface defines the functions needed by the runner to wrap the // Wrapper interface defines the functions needed by the runner to wrap the
// metadata needed to run a plugin process. This includes looking up Mlock // metadata needed to run a plugin process. This includes looking up Mlock
// configuration and wrapping data in a respose wrapped token. // configuration and wrapping data in a respose wrapped token.
// logical.SystemView implementataions satisfy this interface.
type RunnerUtil interface { type RunnerUtil interface {
ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error)
MlockEnabled() bool MlockEnabled() bool
@@ -44,56 +46,82 @@ type PluginRunner struct {
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"` BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
} }
// Run takes a wrapper instance, and the go-plugin paramaters and executes a // Run takes a wrapper RunnerUtil instance along with the go-plugin paramaters and
// plugin. // returns a configured plugin.Client with TLS Configured and a wrapping token set
// on PluginUnwrapTokenEnv for plugin process consumption.
func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) { func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) {
// Get a CA TLS Certificate return r.runCommon(wrapper, pluginMap, hs, env, logger, false)
certBytes, key, err := generateCert() }
if err != nil {
return nil, err
}
// Use CA to sign a client cert and return a configured TLS config // RunMetadataMode returns a configured plugin.Client that will dispense a plugin
clientTLSConfig, err := createClientTLSConfig(certBytes, key) // in metadata mode. The PluginMetadaModeEnv is passed in as part of the Cmd to
if err != nil { // plugin.Client, and consumed by the plugin process on pluginutil.VaultPluginTLSProvider.
return nil, err func (r *PluginRunner) RunMetadataMode(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger) (*plugin.Client, error) {
} return r.runCommon(wrapper, pluginMap, hs, env, logger, true)
// Use CA to sign a server cert and wrap the values in a response wrapped }
// token.
wrapToken, err := wrapServerConfig(wrapper, certBytes, key)
if err != nil {
return nil, err
}
func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string, logger log.Logger, isMetadataMode bool) (*plugin.Client, error) {
cmd := exec.Command(r.Command, r.Args...) cmd := exec.Command(r.Command, r.Args...)
cmd.Env = append(cmd.Env, env...) cmd.Env = append(cmd.Env, env...)
// Add the response wrap token to the ENV of the plugin
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken))
// Add the mlock setting to the ENV of the plugin // Add the mlock setting to the ENV of the plugin
if wrapper.MlockEnabled() { if wrapper.MlockEnabled() {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true"))
} }
secureConfig := &plugin.SecureConfig{
Checksum: r.Sha256,
Hash: sha256.New(),
}
// Create logger for the plugin client // Create logger for the plugin client
clogger := &hclogFaker{ clogger := &hclogFaker{
logger: logger, logger: logger,
} }
namedLogger := clogger.ResetNamed("plugin") namedLogger := clogger.ResetNamed("plugin")
client := plugin.NewClient(&plugin.ClientConfig{ var clientTLSConfig *tls.Config
if !isMetadataMode {
// Add the metadata mode ENV and set it to false
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMetadaModeEnv, "false"))
// Get a CA TLS Certificate
certBytes, key, err := generateCert()
if err != nil {
return nil, err
}
// Use CA to sign a client cert and return a configured TLS config
clientTLSConfig, err = createClientTLSConfig(certBytes, key)
if err != nil {
return nil, err
}
// Use CA to sign a server cert and wrap the values in a response wrapped
// token.
wrapToken, err := wrapServerConfig(wrapper, certBytes, key)
if err != nil {
return nil, err
}
// Add the response wrap token to the ENV of the plugin
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken))
} else {
namedLogger = clogger.ResetNamed("plugin.metadata")
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMetadaModeEnv, "true"))
}
secureConfig := &plugin.SecureConfig{
Checksum: r.Sha256,
Hash: sha256.New(),
}
clientConfig := &plugin.ClientConfig{
HandshakeConfig: hs, HandshakeConfig: hs,
Plugins: pluginMap, Plugins: pluginMap,
Cmd: cmd, Cmd: cmd,
TLSConfig: clientTLSConfig,
SecureConfig: secureConfig, SecureConfig: secureConfig,
TLSConfig: clientTLSConfig,
Logger: namedLogger, Logger: namedLogger,
}) }
client := plugin.NewClient(clientConfig)
return client, nil return client, nil
} }
@@ -108,7 +136,7 @@ type APIClientMeta struct {
} }
func (f *APIClientMeta) FlagSet() *flag.FlagSet { func (f *APIClientMeta) FlagSet() *flag.FlagSet {
fs := flag.NewFlagSet("tls settings", flag.ContinueOnError) fs := flag.NewFlagSet("vault plugin settings", flag.ContinueOnError)
fs.StringVar(&f.flagCACert, "ca-cert", "", "") fs.StringVar(&f.flagCACert, "ca-cert", "", "")
fs.StringVar(&f.flagCAPath, "ca-path", "", "") fs.StringVar(&f.flagCAPath, "ca-path", "", "")

View File

@@ -29,6 +29,10 @@ var (
// PluginCACertPEMEnv is an ENV name used for holding a CA PEM-encoded // PluginCACertPEMEnv is an ENV name used for holding a CA PEM-encoded
// string. Used for testing. // string. Used for testing.
PluginCACertPEMEnv = "VAULT_TESTING_PLUGIN_CA_PEM" PluginCACertPEMEnv = "VAULT_TESTING_PLUGIN_CA_PEM"
// PluginMetadaModeEnv is an ENV name used to disable TLS communication
// to bootstrap mounting plugins.
PluginMetadaModeEnv = "VAULT_PLUGIN_METADATA_MODE"
) )
// generateCert is used internally to create certificates for the plugin // generateCert is used internally to create certificates for the plugin
@@ -124,6 +128,10 @@ func wrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (
// VaultPluginTLSProvider is run inside a plugin and retrives the response // VaultPluginTLSProvider is run inside a plugin and retrives the response
// wrapped TLS certificate from vault. It returns a configured TLS Config. // wrapped TLS certificate from vault. It returns a configured TLS Config.
func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) { func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) {
if os.Getenv(PluginMetadaModeEnv) == "true" {
return nil
}
return func() (*tls.Config, error) { return func() (*tls.Config, error) {
unwrapToken := os.Getenv(PluginUnwrapTokenEnv) unwrapToken := os.Getenv(PluginUnwrapTokenEnv)
@@ -157,7 +165,10 @@ func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, er
clientConf := api.DefaultConfig() clientConf := api.DefaultConfig()
clientConf.Address = vaultAddr clientConf.Address = vaultAddr
if apiTLSConfig != nil { if apiTLSConfig != nil {
clientConf.ConfigureTLS(apiTLSConfig) err := clientConf.ConfigureTLS(apiTLSConfig)
if err != nil {
return nil, errwrap.Wrapf("error configuring api client {{err}}", err)
}
} }
client, err := api.NewClient(clientConf) client, err := api.NewClient(clientConf)
if err != nil { if err != nil {

View File

@@ -9,7 +9,8 @@ import (
// BackendPlugin is the plugin.Plugin implementation // BackendPlugin is the plugin.Plugin implementation
type BackendPlugin struct { type BackendPlugin struct {
Factory func(*logical.BackendConfig) (logical.Backend, error) Factory func(*logical.BackendConfig) (logical.Backend, error)
metadataMode bool
} }
// Server gets called when on plugin.Serve() // Server gets called when on plugin.Serve()
@@ -19,5 +20,5 @@ func (b *BackendPlugin) Server(broker *plugin.MuxBroker) (interface{}, error) {
// Client gets called on plugin.NewClient() // Client gets called on plugin.NewClient()
func (b BackendPlugin) Client(broker *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { func (b BackendPlugin) Client(broker *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
return &backendPluginClient{client: c, broker: broker}, nil return &backendPluginClient{client: c, broker: broker, metadataMode: b.metadataMode}, nil
} }

View File

@@ -1,6 +1,7 @@
package plugin package plugin
import ( import (
"errors"
"net/rpc" "net/rpc"
"github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
@@ -8,11 +9,16 @@ import (
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
) )
var (
ErrClientInMetadataMode = errors.New("plugin client can not perform action while in metadata mode")
)
// backendPluginClient implements logical.Backend and is the // backendPluginClient implements logical.Backend and is the
// go-plugin client. // go-plugin client.
type backendPluginClient struct { type backendPluginClient struct {
broker *plugin.MuxBroker broker *plugin.MuxBroker
client *rpc.Client client *rpc.Client
metadataMode bool
system logical.SystemView system logical.SystemView
logger log.Logger logger log.Logger
@@ -83,6 +89,10 @@ type RegisterLicenseReply struct {
} }
func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Response, error) { func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Response, error) {
if b.metadataMode {
return nil, ErrClientInMetadataMode
}
// Do not send the storage, since go-plugin cannot serialize // Do not send the storage, since go-plugin cannot serialize
// interfaces. The server will pick up the storage from the shim. // interfaces. The server will pick up the storage from the shim.
req.Storage = nil req.Storage = nil
@@ -136,6 +146,10 @@ func (b *backendPluginClient) Logger() log.Logger {
} }
func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool, bool, error) { func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool, bool, error) {
if b.metadataMode {
return false, false, ErrClientInMetadataMode
}
// Do not send the storage, since go-plugin cannot serialize // Do not send the storage, since go-plugin cannot serialize
// interfaces. The server will pick up the storage from the shim. // interfaces. The server will pick up the storage from the shim.
req.Storage = nil req.Storage = nil
@@ -172,31 +186,49 @@ func (b *backendPluginClient) Cleanup() {
} }
func (b *backendPluginClient) Initialize() error { func (b *backendPluginClient) Initialize() error {
if b.metadataMode {
return ErrClientInMetadataMode
}
err := b.client.Call("Plugin.Initialize", new(interface{}), &struct{}{}) err := b.client.Call("Plugin.Initialize", new(interface{}), &struct{}{})
return err return err
} }
func (b *backendPluginClient) InvalidateKey(key string) { func (b *backendPluginClient) InvalidateKey(key string) {
if b.metadataMode {
return
}
b.client.Call("Plugin.InvalidateKey", key, &struct{}{}) b.client.Call("Plugin.InvalidateKey", key, &struct{}{})
} }
func (b *backendPluginClient) Setup(config *logical.BackendConfig) error { func (b *backendPluginClient) Setup(config *logical.BackendConfig) error {
// Shim logical.Storage // Shim logical.Storage
storageImpl := config.StorageView
if b.metadataMode {
storageImpl = &NOOPStorage{}
}
storageID := b.broker.NextId() storageID := b.broker.NextId()
go b.broker.AcceptAndServe(storageID, &StorageServer{ go b.broker.AcceptAndServe(storageID, &StorageServer{
impl: config.StorageView, impl: storageImpl,
}) })
// Shim log.Logger // Shim log.Logger
loggerImpl := config.Logger
if b.metadataMode {
loggerImpl = log.NullLog
}
loggerID := b.broker.NextId() loggerID := b.broker.NextId()
go b.broker.AcceptAndServe(loggerID, &LoggerServer{ go b.broker.AcceptAndServe(loggerID, &LoggerServer{
logger: config.Logger, logger: loggerImpl,
}) })
// Shim logical.SystemView // Shim logical.SystemView
sysViewImpl := config.System
if b.metadataMode {
sysViewImpl = &logical.StaticSystemView{}
}
sysViewID := b.broker.NextId() sysViewID := b.broker.NextId()
go b.broker.AcceptAndServe(sysViewID, &SystemViewServer{ go b.broker.AcceptAndServe(sysViewID, &SystemViewServer{
impl: config.System, impl: sysViewImpl,
}) })
args := &SetupArgs{ args := &SetupArgs{
@@ -233,6 +265,10 @@ func (b *backendPluginClient) Type() logical.BackendType {
} }
func (b *backendPluginClient) RegisterLicense(license interface{}) error { func (b *backendPluginClient) RegisterLicense(license interface{}) error {
if b.metadataMode {
return ErrClientInMetadataMode
}
var reply RegisterLicenseReply var reply RegisterLicenseReply
args := RegisterLicenseArgs{ args := RegisterLicenseArgs{
License: license, License: license,

View File

@@ -1,12 +1,19 @@
package plugin package plugin
import ( import (
"errors"
"net/rpc" "net/rpc"
"os"
"github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
var (
ErrServerInMetadataMode = errors.New("plugin server can not perform action while in metadata mode")
)
// backendPluginServer is the RPC server that backendPluginClient talks to, // backendPluginServer is the RPC server that backendPluginClient talks to,
// it methods conforming to requirements by net/rpc // it methods conforming to requirements by net/rpc
type backendPluginServer struct { type backendPluginServer struct {
@@ -19,7 +26,15 @@ type backendPluginServer struct {
storageClient *rpc.Client storageClient *rpc.Client
} }
func inMetadataMode() bool {
return os.Getenv(pluginutil.PluginMetadaModeEnv) == "true"
}
func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *HandleRequestReply) error { func (b *backendPluginServer) HandleRequest(args *HandleRequestArgs, reply *HandleRequestReply) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
storage := &StorageClient{client: b.storageClient} storage := &StorageClient{client: b.storageClient}
args.Request.Storage = storage args.Request.Storage = storage
@@ -40,6 +55,10 @@ func (b *backendPluginServer) SpecialPaths(_ interface{}, reply *SpecialPathsRep
} }
func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArgs, reply *HandleExistenceCheckReply) error { func (b *backendPluginServer) HandleExistenceCheck(args *HandleExistenceCheckArgs, reply *HandleExistenceCheckReply) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
storage := &StorageClient{client: b.storageClient} storage := &StorageClient{client: b.storageClient}
args.Request.Storage = storage args.Request.Storage = storage
@@ -64,11 +83,19 @@ func (b *backendPluginServer) Cleanup(_ interface{}, _ *struct{}) error {
} }
func (b *backendPluginServer) Initialize(_ interface{}, _ *struct{}) error { func (b *backendPluginServer) Initialize(_ interface{}, _ *struct{}) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
err := b.backend.Initialize() err := b.backend.Initialize()
return err return err
} }
func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error { func (b *backendPluginServer) InvalidateKey(args string, _ *struct{}) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
b.backend.InvalidateKey(args) b.backend.InvalidateKey(args)
return nil return nil
} }
@@ -145,6 +172,10 @@ func (b *backendPluginServer) Type(_ interface{}, reply *TypeReply) error {
} }
func (b *backendPluginServer) RegisterLicense(args *RegisterLicenseArgs, reply *RegisterLicenseReply) error { func (b *backendPluginServer) RegisterLicense(args *RegisterLicenseArgs, reply *RegisterLicenseReply) error {
if inMetadataMode() {
return ErrServerInMetadataMode
}
err := b.backend.RegisterLicense(args.License) err := b.backend.RegisterLicense(args.License)
if err != nil { if err != nil {
*reply = RegisterLicenseReply{ *reply = RegisterLicenseReply{

View File

@@ -43,6 +43,7 @@ func Backend() *backend {
kvPaths(&b), kvPaths(&b),
[]*framework.Path{ []*framework.Path{
pathInternal(&b), pathInternal(&b),
pathSpecial(&b),
}, },
), ),
PathsSpecial: &logical.Paths{ PathsSpecial: &logical.Paths{

View File

@@ -13,7 +13,7 @@ import (
func main() { func main() {
apiClientMeta := &pluginutil.APIClientMeta{} apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet() flags := apiClientMeta.FlagSet()
flags.Parse(os.Args) flags.Parse(os.Args[1:]) // Ignore command, strictly parse flags
tlsConfig := apiClientMeta.GetTLSConfig() tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)

View File

@@ -0,0 +1,27 @@
package mock
import (
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
// pathSpecial is used to test special paths.
func pathSpecial(b *backend) *framework.Path {
return &framework.Path{
Pattern: "special",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathSpecialRead,
},
}
}
func (b *backend) pathSpecialRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
// Return the secret
return &logical.Response{
Data: map[string]interface{}{
"data": "foo",
},
}, nil
}

View File

@@ -40,8 +40,9 @@ func (b *BackendPluginClient) Cleanup() {
// NewBackend will return an instance of an RPC-based client implementation of the backend for // NewBackend will return an instance of an RPC-based client implementation of the backend for
// external plugins, or a concrete implementation of the backend if it is a builtin backend. // external plugins, or a concrete implementation of the backend if it is a builtin backend.
// The backend is returned as a logical.Backend interface. // The backend is returned as a logical.Backend interface. The isMetadataMode param determines whether
func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (logical.Backend, error) { // the plugin should run in metadata mode.
func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
// Look for plugin in the plugin catalog // Look for plugin in the plugin catalog
pluginRunner, err := sys.LookupPlugin(pluginName) pluginRunner, err := sys.LookupPlugin(pluginName)
if err != nil { if err != nil {
@@ -65,7 +66,7 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
} else { } else {
// create a backendPluginClient instance // create a backendPluginClient instance
backend, err = newPluginClient(sys, pluginRunner, logger) backend, err = newPluginClient(sys, pluginRunner, logger, isMetadataMode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -74,12 +75,21 @@ func NewBackend(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Log
return backend, nil return backend, nil
} }
func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger) (logical.Backend, error) { func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (logical.Backend, error) {
// pluginMap is the map of plugins we can dispense. // pluginMap is the map of plugins we can dispense.
pluginMap := map[string]plugin.Plugin{ pluginMap := map[string]plugin.Plugin{
"backend": &BackendPlugin{}, "backend": &BackendPlugin{
metadataMode: isMetadataMode,
},
}
var client *plugin.Client
var err error
if isMetadataMode {
client, err = pluginRunner.RunMetadataMode(sys, pluginMap, handshakeConfig, []string{}, logger)
} else {
client, err = pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
} }
client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -20,7 +20,8 @@ type ServeOpts struct {
TLSProviderFunc TLSProdiverFunc TLSProviderFunc TLSProdiverFunc
} }
// Serve is used to serve a backend plugin // Serve is a helper function used to serve a backend plugin. This
// should be ran on the plugin's main process.
func Serve(opts *ServeOpts) error { func Serve(opts *ServeOpts) error {
// pluginMap is the map of plugins we can dispense. // pluginMap is the map of plugins we can dispense.
var pluginMap = map[string]plugin.Plugin{ var pluginMap = map[string]plugin.Plugin{
@@ -34,6 +35,7 @@ func Serve(opts *ServeOpts) error {
return err return err
} }
// If FetchMetadata is true, run without TLSProvider
plugin.Serve(&plugin.ServeConfig{ plugin.Serve(&plugin.ServeConfig{
HandshakeConfig: handshakeConfig, HandshakeConfig: handshakeConfig,
Plugins: pluginMap, Plugins: pluginMap,

View File

@@ -117,3 +117,23 @@ type StoragePutReply struct {
type StorageDeleteReply struct { type StorageDeleteReply struct {
Error *plugin.BasicError Error *plugin.BasicError
} }
// NOOPStorage is used to deny access to the storage interface while running a
// backend plugin in metadata mode.
type NOOPStorage struct{}
func (s *NOOPStorage) List(prefix string) ([]string, error) {
return []string{}, nil
}
func (s *NOOPStorage) Get(key string) (*logical.StorageEntry, error) {
return nil, nil
}
func (s *NOOPStorage) Put(entry *logical.StorageEntry) error {
return nil
}
func (s *NOOPStorage) Delete(key string) error {
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
@@ -397,7 +398,6 @@ func (c *Core) persistAuth(table *MountTable, localOnly bool) error {
// setupCredentials is invoked after we've loaded the auth table to // setupCredentials is invoked after we've loaded the auth table to
// initialize the credential backends and setup the router // initialize the credential backends and setup the router
func (c *Core) setupCredentials() error { func (c *Core) setupCredentials() error {
var backend logical.Backend
var view *BarrierView var view *BarrierView
var err error var err error
var persistNeeded bool var persistNeeded bool
@@ -406,6 +406,7 @@ func (c *Core) setupCredentials() error {
defer c.authLock.Unlock() defer c.authLock.Unlock()
for _, entry := range c.auth.Entries { for _, entry := range c.auth.Entries {
var backend logical.Backend
// Work around some problematic code that existed in master for a while // Work around some problematic code that existed in master for a while
if strings.HasPrefix(entry.Path, credentialRoutePrefix) { if strings.HasPrefix(entry.Path, credentialRoutePrefix) {
entry.Path = strings.TrimPrefix(entry.Path, credentialRoutePrefix) entry.Path = strings.TrimPrefix(entry.Path, credentialRoutePrefix)
@@ -425,6 +426,9 @@ func (c *Core) setupCredentials() error {
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf) backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
if err != nil { if err != nil {
c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err) c.logger.Error("core: failed to create credential entry", "path", entry.Path, "error", err)
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
goto ROUTER_MOUNT
}
return errLoadAuthFailed return errLoadAuthFailed
} }
if backend == nil { if backend == nil {
@@ -432,15 +436,14 @@ func (c *Core) setupCredentials() error {
} }
// Check for the correct backend type // Check for the correct backend type
backendType := backend.Type() if entry.Type == "plugin" && backend.Type() != logical.TypeCredential {
if entry.Type == "plugin" && backendType != logical.TypeCredential { return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backend.Type())
return fmt.Errorf("cannot mount '%s' of type '%s' as an auth backend", entry.Config.PluginName, backendType)
} }
if err := backend.Initialize(); err != nil { if err := backend.Initialize(); err != nil {
return err return err
} }
ROUTER_MOUNT:
// Mount the backend // Mount the backend
path := credentialRoutePrefix + entry.Path path := credentialRoutePrefix + entry.Path
err = c.router.Mount(backend, path, entry, view) err = c.router.Mount(backend, path, entry, view)

View File

@@ -1369,9 +1369,6 @@ func (c *Core) postUnseal() (retErr error) {
if err := c.setupMounts(); err != nil { if err := c.setupMounts(); err != nil {
return err return err
} }
if err := c.startRollback(); err != nil {
return err
}
if err := c.setupPolicyStore(); err != nil { if err := c.setupPolicyStore(); err != nil {
return err return err
} }
@@ -1384,6 +1381,9 @@ func (c *Core) postUnseal() (retErr error) {
if err := c.setupCredentials(); err != nil { if err := c.setupCredentials(); err != nil {
return err return err
} }
if err := c.startRollback(); err != nil {
return err
}
if err := c.setupExpiration(); err != nil { if err := c.setupExpiration(); err != nil {
return err return err
} }

View File

@@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/helper/wrapping"
@@ -132,7 +134,7 @@ func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner,
return nil, err return nil, err
} }
if r == nil { if r == nil {
return nil, fmt.Errorf("no plugin found with name: %s", name) return nil, errwrap.Wrapf(fmt.Sprintf("{{err}}: %s", name), ErrPluginNotFound)
} }
return r, nil return r, nil

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"testing" "testing"
"time"
"github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
@@ -15,17 +16,196 @@ import (
) )
func TestSystemBackend_Plugin_secret(t *testing.T) { func TestSystemBackend_Plugin_secret(t *testing.T) {
cluster := testSystemBackendMock(t, 2, logical.TypeLogical) cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup() defer cluster.Cleanup()
core := cluster.Cores[0]
// Make a request to lazy load the plugin
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
} }
func TestSystemBackend_Plugin_auth(t *testing.T) { func TestSystemBackend_Plugin_auth(t *testing.T) {
cluster := testSystemBackendMock(t, 2, logical.TypeCredential) cluster := testSystemBackendMock(t, 1, 1, logical.TypeCredential)
defer cluster.Cleanup() defer cluster.Cleanup()
core := cluster.Cores[0]
// Make a request to lazy load the plugin
req := logical.TestRequest(t, logical.ReadOperation, "auth/mock-0/internal")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
}
func TestSystemBackend_Plugin_MismatchType(t *testing.T) {
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Replace the plugin with a credential backend
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials")
// Make a request to lazy load the now-credential plugin
// and expect an error
req := logical.TestRequest(t, logical.ReadOperation, "mock-0/internal")
req.ClientToken = core.Client.Token()
_, err := core.HandleRequest(req)
if err == nil {
t.Fatalf("expected error due to mismatch on error type: %s", err)
}
// Sleep a bit before cleanup is called
time.Sleep(1 * time.Second)
}
func TestSystemBackend_Plugin_CatalogRemoved(t *testing.T) {
t.Run("secret", func(t *testing.T) {
testPlugin_CatalogRemoved(t, logical.TypeLogical, false)
})
t.Run("auth", func(t *testing.T) {
testPlugin_CatalogRemoved(t, logical.TypeCredential, false)
})
t.Run("secret-mount-existing", func(t *testing.T) {
testPlugin_CatalogRemoved(t, logical.TypeLogical, true)
})
t.Run("auth-mount-existing", func(t *testing.T) {
testPlugin_CatalogRemoved(t, logical.TypeCredential, true)
})
}
func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMount bool) {
cluster := testSystemBackendMock(t, 1, 1, btype)
defer cluster.Cleanup()
core := cluster.Cores[0]
// Remove the plugin from the catalog
req := logical.TestRequest(t, logical.DeleteOperation, "sys/plugins/catalog/mock-plugin")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
if testMount {
// Add plugin back to the catalog
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical")
// Mount the plugin at the same path after plugin is re-added to the catalog
// and expect an error due to existing path.
var err error
switch btype {
case logical.TypeLogical:
_, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{
"type": "plugin",
"config": map[string]interface{}{
"plugin_name": "mock-plugin",
},
})
case logical.TypeCredential:
_, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{
"type": "plugin",
"plugin_name": "mock-plugin",
})
}
if err == nil {
t.Fatal("expected error when mounting on existing path")
}
}
} }
func TestSystemBackend_Plugin_autoReload(t *testing.T) { func TestSystemBackend_Plugin_autoReload(t *testing.T) {
cluster := testSystemBackendMock(t, 1, logical.TypeLogical) cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup() defer cluster.Cleanup()
core := cluster.Cores[0] core := cluster.Cores[0]
@@ -65,6 +245,35 @@ func TestSystemBackend_Plugin_autoReload(t *testing.T) {
} }
} }
func TestSystemBackend_Plugin_SealUnseal(t *testing.T) {
cluster := testSystemBackendMock(t, 1, 1, logical.TypeLogical)
defer cluster.Cleanup()
// Seal the cluster
cluster.EnsureCoresSealed(t)
// Unseal the cluster
barrierKeys := cluster.BarrierKeys
for _, core := range cluster.Cores {
for _, key := range barrierKeys {
_, err := core.Unseal(vault.TestKeyCopy(key))
if err != nil {
t.Fatal(err)
}
}
sealed, err := core.Sealed()
if err != nil {
t.Fatalf("err checking seal status: %s", err)
}
if sealed {
t.Fatal("should not be sealed")
}
// Wait for active so post-unseal takes place
// If it fails, it means unseal process failed
vault.TestWaitActive(t, core.Core)
}
}
func TestSystemBackend_Plugin_reload(t *testing.T) { func TestSystemBackend_Plugin_reload(t *testing.T) {
data := map[string]interface{}{ data := map[string]interface{}{
"plugin": "mock-plugin", "plugin": "mock-plugin",
@@ -77,8 +286,9 @@ func TestSystemBackend_Plugin_reload(t *testing.T) {
t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) }) t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })
} }
// Helper func to test different reload methods on plugin reload endpoint
func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) { func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) {
cluster := testSystemBackendMock(t, 2, logical.TypeLogical) cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical)
defer cluster.Cleanup() defer cluster.Cleanup()
core := cluster.Cores[0] core := cluster.Cores[0]
@@ -123,7 +333,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
// testSystemBackendMock returns a systemBackend with the desired number // testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends // of mounted mock plugin backends
func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.BackendType) *vault.TestCluster { func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType) *vault.TestCluster {
coreConfig := &vault.CoreConfig{ coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{ LogicalBackends: map[string]logical.Factory{
"plugin": plugin.Factory, "plugin": plugin.Factory,
@@ -134,7 +344,9 @@ func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.Back
} }
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler, HandlerFunc: vaulthttp.Handler,
KeepStandbysSealed: true,
NumCores: numCores,
}) })
cluster.Start() cluster.Start()
@@ -197,7 +409,8 @@ func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.Back
} }
func TestBackend_PluginMainLogical(t *testing.T) { func TestBackend_PluginMainLogical(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { args := []string{}
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
return return
} }
@@ -205,16 +418,16 @@ func TestBackend_PluginMainLogical(t *testing.T) {
if caPEM == "" { if caPEM == "" {
t.Fatal("CA cert not passed in") t.Fatal("CA cert not passed in")
} }
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
factoryFunc := mock.FactoryType(logical.TypeLogical)
args := []string{"--ca-cert=" + caPEM}
apiClientMeta := &pluginutil.APIClientMeta{} apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet() flags := apiClientMeta.FlagSet()
flags.Parse(args) flags.Parse(args)
tlsConfig := apiClientMeta.GetTLSConfig() tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
factoryFunc := mock.FactoryType(logical.TypeLogical)
err := lplugin.Serve(&lplugin.ServeOpts{ err := lplugin.Serve(&lplugin.ServeOpts{
BackendFactoryFunc: factoryFunc, BackendFactoryFunc: factoryFunc,
TLSProviderFunc: tlsProviderFunc, TLSProviderFunc: tlsProviderFunc,
@@ -225,7 +438,8 @@ func TestBackend_PluginMainLogical(t *testing.T) {
} }
func TestBackend_PluginMainCredentials(t *testing.T) { func TestBackend_PluginMainCredentials(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { args := []string{}
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" && os.Getenv(pluginutil.PluginMetadaModeEnv) != "true" {
return return
} }
@@ -233,16 +447,16 @@ func TestBackend_PluginMainCredentials(t *testing.T) {
if caPEM == "" { if caPEM == "" {
t.Fatal("CA cert not passed in") t.Fatal("CA cert not passed in")
} }
args = append(args, fmt.Sprintf("--ca-cert=%s", caPEM))
factoryFunc := mock.FactoryType(logical.TypeCredential)
args := []string{"--ca-cert=" + caPEM}
apiClientMeta := &pluginutil.APIClientMeta{} apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet() flags := apiClientMeta.FlagSet()
flags.Parse(args) flags.Parse(args)
tlsConfig := apiClientMeta.GetTLSConfig() tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
factoryFunc := mock.FactoryType(logical.TypeCredential)
err := lplugin.Serve(&lplugin.ServeOpts{ err := lplugin.Serve(&lplugin.ServeOpts{
BackendFactoryFunc: factoryFunc, BackendFactoryFunc: factoryFunc,
TLSProviderFunc: tlsProviderFunc, TLSProviderFunc: tlsProviderFunc,

View File

@@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/jsonutil"
@@ -663,11 +664,12 @@ func (c *Core) setupMounts() error {
c.mountsLock.Lock() c.mountsLock.Lock()
defer c.mountsLock.Unlock() defer c.mountsLock.Unlock()
var backend logical.Backend
var view *BarrierView var view *BarrierView
var err error var err error
for _, entry := range c.mounts.Entries { for _, entry := range c.mounts.Entries {
var backend logical.Backend
// Initialize the backend, special casing for system // Initialize the backend, special casing for system
barrierPath := backendBarrierPrefix + entry.UUID + "/" barrierPath := backendBarrierPrefix + entry.UUID + "/"
if entry.Type == "system" { if entry.Type == "system" {
@@ -686,6 +688,9 @@ func (c *Core) setupMounts() error {
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf) backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
if err != nil { if err != nil {
c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err) c.logger.Error("core: failed to create mount entry", "path", entry.Path, "error", err)
if errwrap.Contains(err, ErrPluginNotFound.Error()) && entry.Type == "plugin" {
goto ROUTER_MOUNT
}
return errLoadMountsFailed return errLoadMountsFailed
} }
if backend == nil { if backend == nil {
@@ -693,9 +698,8 @@ func (c *Core) setupMounts() error {
} }
// Check for the correct backend type // Check for the correct backend type
backendType := backend.Type() if entry.Type == "plugin" && backend.Type() != logical.TypeLogical {
if entry.Type == "plugin" && backendType != logical.TypeLogical { return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backend.Type())
return fmt.Errorf("cannot mount '%s' of type '%s' as a logical backend", entry.Config.PluginName, backendType)
} }
if err := backend.Initialize(); err != nil { if err := backend.Initialize(); err != nil {
@@ -710,7 +714,7 @@ func (c *Core) setupMounts() error {
ch.saltUUID = entry.UUID ch.saltUUID = entry.UUID
ch.storageView = view ch.storageView = view
} }
ROUTER_MOUNT:
// Mount the backend // Mount the backend
err = c.router.Mount(backend, entry.Path, entry, view) err = c.router.Mount(backend, entry.Path, entry, view)
if err != nil { if err != nil {

View File

@@ -19,6 +19,7 @@ import (
var ( var (
pluginCatalogPath = "core/plugin-catalog/" pluginCatalogPath = "core/plugin-catalog/"
ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured") ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured")
ErrPluginNotFound = errors.New("plugin not found in the catalog")
) )
// PluginCatalog keeps a record of plugins known to vault. External plugins need // PluginCatalog keeps a record of plugins known to vault. External plugins need
@@ -37,6 +38,10 @@ func (c *Core) setupPluginCatalog() error {
directory: c.pluginDirectory, directory: c.pluginDirectory,
} }
if c.logger.IsInfo() {
c.logger.Info("core: successfully setup plugin catalog", "plugin-directory", c.pluginDirectory)
}
return nil return nil
} }

View File

@@ -64,9 +64,12 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount
} }
// Build the paths // Build the paths
paths := backend.SpecialPaths() paths := new(logical.Paths)
if paths == nil { if backend != nil {
paths = new(logical.Paths) specialPaths := backend.SpecialPaths()
if specialPaths != nil {
paths = specialPaths
}
} }
// Create a mount entry // Create a mount entry

View File

@@ -335,11 +335,17 @@ func TestAddTestPlugin(t testing.T, c *Core, name, testFunc string) {
} }
sum := hash.Sum(nil) sum := hash.Sum(nil)
c.pluginCatalog.directory, err = filepath.EvalSymlinks(os.Args[0])
// Determine plugin directory path
fullPath, err := filepath.EvalSymlinks(os.Args[0])
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory) directoryPath := filepath.Dir(fullPath)
// Set core's plugin directory and plugin catalog directory
c.pluginDirectory = directoryPath
c.pluginCatalog.directory = directoryPath
command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc) command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc)
err = c.pluginCatalog.Set(name, command, sum) err = c.pluginCatalog.Set(name, command, sum)
@@ -585,6 +591,7 @@ func GenerateRandBytes(length int) ([]byte, error) {
} }
func TestWaitActive(t testing.T, core *Core) { func TestWaitActive(t testing.T, core *Core) {
t.Helper()
start := time.Now() start := time.Now()
var standby bool var standby bool
var err error var err error
@@ -627,6 +634,13 @@ func (c *TestCluster) Start() {
} }
} }
func (c *TestCluster) EnsureCoresSealed(t testing.T) {
t.Helper()
if err := c.ensureCoresSealed(); err != nil {
t.Fatal(err)
}
}
func (c *TestCluster) Cleanup() { func (c *TestCluster) Cleanup() {
// Close listeners // Close listeners
for _, core := range c.Cores { for _, core := range c.Cores {
@@ -638,25 +652,7 @@ func (c *TestCluster) Cleanup() {
} }
// Seal the cores // Seal the cores
for _, core := range c.Cores { c.ensureCoresSealed()
if err := core.Shutdown(); err != nil {
continue
}
timeout := time.Now().Add(60 * time.Second)
for {
if time.Now().After(timeout) {
continue
}
sealed, err := core.Sealed()
if err != nil {
continue
}
if sealed {
break
}
time.Sleep(250 * time.Millisecond)
}
}
// Remove any temp dir that exists // Remove any temp dir that exists
if c.TempDir != "" { if c.TempDir != "" {
@@ -667,6 +663,29 @@ func (c *TestCluster) Cleanup() {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
func (c *TestCluster) ensureCoresSealed() error {
for _, core := range c.Cores {
if err := core.Shutdown(); err != nil {
return err
}
timeout := time.Now().Add(60 * time.Second)
for {
if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for core to seal")
}
sealed, err := core.Sealed()
if err != nil {
return err
}
if sealed {
break
}
time.Sleep(250 * time.Millisecond)
}
}
return nil
}
type TestListener struct { type TestListener struct {
net.Listener net.Listener
Address *net.TCPAddr Address *net.TCPAddr
@@ -692,9 +711,29 @@ type TestClusterOptions struct {
KeepStandbysSealed bool KeepStandbysSealed bool
HandlerFunc func(*Core) http.Handler HandlerFunc func(*Core) http.Handler
BaseListenAddress string BaseListenAddress string
NumCores int
} }
var DefaultNumCores = 3
type certInfo struct {
cert *x509.Certificate
certPEM []byte
certBytes []byte
key *ecdsa.PrivateKey
keyPEM []byte
}
// NewTestCluster creates a new test cluster based on the provided core config
// and test cluster options.
func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster { func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *TestCluster {
var numCores int
if opts == nil || opts.NumCores == 0 {
numCores = DefaultNumCores
} else {
numCores = opts.NumCores
}
certIPs := []net.IP{ certIPs := []net.IP{
net.IPv6loopback, net.IPv6loopback,
net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1"),
@@ -770,270 +809,131 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
t.Fatal(err) t.Fatal(err)
} }
s1Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) var certInfoSlice []*certInfo
if err != nil {
t.Fatal(err)
}
s1CertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
DNSNames: []string{"localhost"},
IPAddresses: certIPs,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
}
s1CertBytes, err := x509.CreateCertificate(rand.Reader, s1CertTemplate, caCert, s1Key.Public(), caKey)
if err != nil {
t.Fatal(err)
}
s1Cert, err := x509.ParseCertificate(s1CertBytes)
if err != nil {
t.Fatal(err)
}
s1CertPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: s1CertBytes,
}
s1CertPEM := pem.EncodeToMemory(s1CertPEMBlock)
s1MarshaledKey, err := x509.MarshalECPrivateKey(s1Key)
if err != nil {
t.Fatal(err)
}
s1KeyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: s1MarshaledKey,
}
s1KeyPEM := pem.EncodeToMemory(s1KeyPEMBlock)
s2Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) //
if err != nil { // Certs generation
t.Fatal(err) //
} for i := 0; i < numCores; i++ {
s2CertTemplate := &x509.Certificate{ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Subject: pkix.Name{ if err != nil {
CommonName: "localhost", t.Fatal(err)
}, }
DNSNames: []string{"localhost"}, certTemplate := &x509.Certificate{
IPAddresses: certIPs, Subject: pkix.Name{
ExtKeyUsage: []x509.ExtKeyUsage{ CommonName: "localhost",
x509.ExtKeyUsageServerAuth, },
x509.ExtKeyUsageClientAuth, DNSNames: []string{"localhost"},
}, IPAddresses: certIPs,
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, ExtKeyUsage: []x509.ExtKeyUsage{
SerialNumber: big.NewInt(mathrand.Int63()), x509.ExtKeyUsageServerAuth,
NotBefore: time.Now().Add(-30 * time.Second), x509.ExtKeyUsageClientAuth,
NotAfter: time.Now().Add(262980 * time.Hour), },
} KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
s2CertBytes, err := x509.CreateCertificate(rand.Reader, s2CertTemplate, caCert, s2Key.Public(), caKey) SerialNumber: big.NewInt(mathrand.Int63()),
if err != nil { NotBefore: time.Now().Add(-30 * time.Second),
t.Fatal(err) NotAfter: time.Now().Add(262980 * time.Hour),
} }
s2Cert, err := x509.ParseCertificate(s2CertBytes) certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, caCert, key.Public(), caKey)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
s2CertPEMBlock := &pem.Block{ cert, err := x509.ParseCertificate(certBytes)
Type: "CERTIFICATE", if err != nil {
Bytes: s2CertBytes, t.Fatal(err)
} }
s2CertPEM := pem.EncodeToMemory(s2CertPEMBlock) certPEMBlock := &pem.Block{
s2MarshaledKey, err := x509.MarshalECPrivateKey(s2Key) Type: "CERTIFICATE",
if err != nil { Bytes: certBytes,
t.Fatal(err) }
} certPEM := pem.EncodeToMemory(certPEMBlock)
s2KeyPEMBlock := &pem.Block{ marshaledKey, err := x509.MarshalECPrivateKey(key)
Type: "EC PRIVATE KEY", if err != nil {
Bytes: s2MarshaledKey, t.Fatal(err)
} }
s2KeyPEM := pem.EncodeToMemory(s2KeyPEMBlock) keyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: marshaledKey,
}
keyPEM := pem.EncodeToMemory(keyPEMBlock)
s3Key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) certInfoSlice = append(certInfoSlice, &certInfo{
if err != nil { cert: cert,
t.Fatal(err) certPEM: certPEM,
certBytes: certBytes,
key: key,
keyPEM: keyPEM,
})
} }
s3CertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "localhost",
},
DNSNames: []string{"localhost"},
IPAddresses: certIPs,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
x509.ExtKeyUsageClientAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
}
s3CertBytes, err := x509.CreateCertificate(rand.Reader, s3CertTemplate, caCert, s3Key.Public(), caKey)
if err != nil {
t.Fatal(err)
}
s3Cert, err := x509.ParseCertificate(s3CertBytes)
if err != nil {
t.Fatal(err)
}
s3CertPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: s3CertBytes,
}
s3CertPEM := pem.EncodeToMemory(s3CertPEMBlock)
s3MarshaledKey, err := x509.MarshalECPrivateKey(s3Key)
if err != nil {
t.Fatal(err)
}
s3KeyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: s3MarshaledKey,
}
s3KeyPEM := pem.EncodeToMemory(s3KeyPEMBlock)
logger := logformat.NewVaultLogger(log.LevelTrace)
// //
// Listener setup // Listener setup
// //
ports := []int{0, 0, 0} logger := logformat.NewVaultLogger(log.LevelTrace)
ports := make([]int, numCores)
if baseAddr != nil { if baseAddr != nil {
ports = []int{baseAddr.Port, baseAddr.Port + 1, baseAddr.Port + 2} for i := 0; i < numCores; i++ {
ports[i] = baseAddr.Port + i
}
} else { } else {
baseAddr = &net.TCPAddr{ baseAddr = &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: 0, Port: 0,
} }
} }
baseAddr.Port = ports[0]
ln, err := net.ListenTCP("tcp", baseAddr)
if err != nil {
t.Fatal(err)
}
s1CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port))
s1KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node1_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port))
err = ioutil.WriteFile(s1CertFile, s1CertPEM, 0755)
if err != nil {
t.Fatal(err)
}
err = ioutil.WriteFile(s1KeyFile, s1KeyPEM, 0755)
if err != nil {
t.Fatal(err)
}
s1TLSCert, err := tls.X509KeyPair(s1CertPEM, s1KeyPEM)
if err != nil {
t.Fatal(err)
}
s1CertGetter := reload.NewCertificateGetter(s1CertFile, s1KeyFile)
s1TLSConfig := &tls.Config{
Certificates: []tls.Certificate{s1TLSCert},
RootCAs: testCluster.RootCAs,
ClientCAs: testCluster.RootCAs,
ClientAuth: tls.VerifyClientCertIfGiven,
NextProtos: []string{"h2", "http/1.1"},
GetCertificate: s1CertGetter.GetCertificate,
}
s1TLSConfig.BuildNameToCertificate()
c1lns := []*TestListener{&TestListener{
Listener: tls.NewListener(ln, s1TLSConfig),
Address: ln.Addr().(*net.TCPAddr),
},
}
var handler1 http.Handler = http.NewServeMux()
server1 := &http.Server{
Handler: handler1,
}
if err := http2.ConfigureServer(server1, nil); err != nil {
t.Fatal(err)
}
baseAddr.Port = ports[1] listeners := [][]*TestListener{}
ln, err = net.ListenTCP("tcp", baseAddr) servers := []*http.Server{}
if err != nil { handlers := []http.Handler{}
t.Fatal(err) tlsConfigs := []*tls.Config{}
} certGetters := []*reload.CertificateGetter{}
s2CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port)) for i := 0; i < numCores; i++ {
s2KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node2_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port)) baseAddr.Port = ports[i]
err = ioutil.WriteFile(s2CertFile, s2CertPEM, 0755) ln, err := net.ListenTCP("tcp", baseAddr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = ioutil.WriteFile(s2KeyFile, s2KeyPEM, 0755) certFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_cert.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
if err != nil { keyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node%d_port_%d_key.pem", i+1, ln.Addr().(*net.TCPAddr).Port))
t.Fatal(err) err = ioutil.WriteFile(certFile, certInfoSlice[i].certPEM, 0755)
} if err != nil {
s2TLSCert, err := tls.X509KeyPair(s2CertPEM, s2KeyPEM) t.Fatal(err)
if err != nil { }
t.Fatal(err) err = ioutil.WriteFile(keyFile, certInfoSlice[i].keyPEM, 0755)
} if err != nil {
s2CertGetter := reload.NewCertificateGetter(s2CertFile, s2KeyFile) t.Fatal(err)
s2TLSConfig := &tls.Config{ }
Certificates: []tls.Certificate{s2TLSCert}, tlsCert, err := tls.X509KeyPair(certInfoSlice[i].certPEM, certInfoSlice[i].keyPEM)
RootCAs: testCluster.RootCAs, if err != nil {
ClientCAs: testCluster.RootCAs, t.Fatal(err)
ClientAuth: tls.VerifyClientCertIfGiven, }
NextProtos: []string{"h2", "http/1.1"}, certGetter := reload.NewCertificateGetter(certFile, keyFile)
GetCertificate: s2CertGetter.GetCertificate, certGetters = append(certGetters, certGetter)
} tlsConfig := &tls.Config{
s2TLSConfig.BuildNameToCertificate() Certificates: []tls.Certificate{tlsCert},
c2lns := []*TestListener{&TestListener{ RootCAs: testCluster.RootCAs,
Listener: tls.NewListener(ln, s2TLSConfig), ClientCAs: testCluster.RootCAs,
Address: ln.Addr().(*net.TCPAddr), ClientAuth: tls.VerifyClientCertIfGiven,
}, NextProtos: []string{"h2", "http/1.1"},
} GetCertificate: certGetter.GetCertificate,
var handler2 http.Handler = http.NewServeMux() }
server2 := &http.Server{ tlsConfig.BuildNameToCertificate()
Handler: handler2, tlsConfigs = append(tlsConfigs, tlsConfig)
} lns := []*TestListener{&TestListener{
if err := http2.ConfigureServer(server2, nil); err != nil { Listener: tls.NewListener(ln, tlsConfig),
t.Fatal(err) Address: ln.Addr().(*net.TCPAddr),
} },
}
baseAddr.Port = ports[2] listeners = append(listeners, lns)
ln, err = net.ListenTCP("tcp", baseAddr) var handler http.Handler = http.NewServeMux()
if err != nil { handlers = append(handlers, handler)
t.Fatal(err) server := &http.Server{
} Handler: handler,
s3CertFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_cert.pem", ln.Addr().(*net.TCPAddr).Port)) }
s3KeyFile := filepath.Join(testCluster.TempDir, fmt.Sprintf("node3_port_%d_key.pem", ln.Addr().(*net.TCPAddr).Port)) servers = append(servers, server)
err = ioutil.WriteFile(s3CertFile, s3CertPEM, 0755) if err := http2.ConfigureServer(server, nil); err != nil {
if err != nil { t.Fatal(err)
t.Fatal(err) }
}
err = ioutil.WriteFile(s3KeyFile, s3KeyPEM, 0755)
if err != nil {
t.Fatal(err)
}
s3TLSCert, err := tls.X509KeyPair(s3CertPEM, s3KeyPEM)
if err != nil {
t.Fatal(err)
}
s3CertGetter := reload.NewCertificateGetter(s3CertFile, s3KeyFile)
s3TLSConfig := &tls.Config{
Certificates: []tls.Certificate{s3TLSCert},
RootCAs: testCluster.RootCAs,
ClientCAs: testCluster.RootCAs,
ClientAuth: tls.VerifyClientCertIfGiven,
NextProtos: []string{"h2", "http/1.1"},
GetCertificate: s3CertGetter.GetCertificate,
}
s3TLSConfig.BuildNameToCertificate()
c3lns := []*TestListener{&TestListener{
Listener: tls.NewListener(ln, s3TLSConfig),
Address: ln.Addr().(*net.TCPAddr),
},
}
var handler3 http.Handler = http.NewServeMux()
server3 := &http.Server{
Handler: handler3,
}
if err := http2.ConfigureServer(server3, nil); err != nil {
t.Fatal(err)
} }
// Create three cores with the same physical and different redirect/cluster // Create three cores with the same physical and different redirect/cluster
@@ -1049,8 +949,8 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
LogicalBackends: make(map[string]logical.Factory), LogicalBackends: make(map[string]logical.Factory),
CredentialBackends: make(map[string]logical.Factory), CredentialBackends: make(map[string]logical.Factory),
AuditBackends: make(map[string]audit.Factory), AuditBackends: make(map[string]audit.Factory),
RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port), RedirectAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port),
ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", c1lns[0].Address.Port+105), ClusterAddr: fmt.Sprintf("https://127.0.0.1:%d", listeners[0][0].Address.Port+105),
DisableMlock: true, DisableMlock: true,
EnableUI: true, EnableUI: true,
} }
@@ -1126,39 +1026,21 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
coreConfig.HAPhysical = haPhys.(physical.HABackend) coreConfig.HAPhysical = haPhys.(physical.HABackend)
} }
c1, err := NewCore(coreConfig) cores := []*Core{}
if err != nil { for i := 0; i < numCores; i++ {
t.Fatalf("err: %v", err) coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port)
} if coreConfig.ClusterAddr != "" {
if opts != nil && opts.HandlerFunc != nil { coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", listeners[i][0].Address.Port+105)
handler1 = opts.HandlerFunc(c1) }
server1.Handler = handler1 c, err := NewCore(coreConfig)
} if err != nil {
t.Fatalf("err: %v", err)
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port) }
if coreConfig.ClusterAddr != "" { cores = append(cores, c)
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port+105) if opts != nil && opts.HandlerFunc != nil {
} handlers[i] = opts.HandlerFunc(c)
c2, err := NewCore(coreConfig) servers[i].Handler = handlers[i]
if err != nil { }
t.Fatalf("err: %v", err)
}
if opts != nil && opts.HandlerFunc != nil {
handler2 = opts.HandlerFunc(c2)
server2.Handler = handler2
}
coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port)
if coreConfig.ClusterAddr != "" {
coreConfig.ClusterAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port+105)
}
c3, err := NewCore(coreConfig)
if err != nil {
t.Fatalf("err: %v", err)
}
if opts != nil && opts.HandlerFunc != nil {
handler3 = opts.HandlerFunc(c3)
server3.Handler = handler3
} }
// //
@@ -1175,16 +1057,19 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
return ret return ret
} }
c2.SetClusterListenerAddrs(clusterAddrGen(c2lns)) if numCores > 1 {
c2.SetClusterHandler(handler2) for i := 1; i < numCores; i++ {
c3.SetClusterListenerAddrs(clusterAddrGen(c3lns)) cores[i].SetClusterListenerAddrs(clusterAddrGen(listeners[i]))
c3.SetClusterHandler(handler3) cores[i].SetClusterHandler(handlers[i])
}
}
keys, root := TestCoreInitClusterWrapperSetup(t, c1, clusterAddrGen(c1lns), handler1) keys, root := TestCoreInitClusterWrapperSetup(t, cores[0], clusterAddrGen(listeners[0]), handlers[0])
barrierKeys, _ := copystructure.Copy(keys) barrierKeys, _ := copystructure.Copy(keys)
testCluster.BarrierKeys = barrierKeys.([][]byte) testCluster.BarrierKeys = barrierKeys.([][]byte)
testCluster.RootToken = root testCluster.RootToken = root
// Write root token and barrier keys
err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755) err = ioutil.WriteFile(filepath.Join(testCluster.TempDir, "root_token"), []byte(root), 0755)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -1201,14 +1086,15 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
t.Fatal(err) t.Fatal(err)
} }
// Unseal first core
for _, key := range keys { for _, key := range keys {
if _, err := c1.Unseal(TestKeyCopy(key)); err != nil { if _, err := cores[0].Unseal(TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err) t.Fatalf("unseal err: %s", err)
} }
} }
// Verify unsealed // Verify unsealed
sealed, err := c1.Sealed() sealed, err := cores[0].Sealed()
if err != nil { if err != nil {
t.Fatalf("err checking seal status: %s", err) t.Fatalf("err checking seal status: %s", err)
} }
@@ -1216,41 +1102,38 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
t.Fatal("should not be sealed") t.Fatal("should not be sealed")
} }
TestWaitActive(t, c1) TestWaitActive(t, cores[0])
if opts == nil || !opts.KeepStandbysSealed { // Unseal other cores unless otherwise specified
for _, key := range keys { if (opts == nil || !opts.KeepStandbysSealed) && numCores > 1 {
if _, err := c2.Unseal(TestKeyCopy(key)); err != nil { for i := 1; i < numCores; i++ {
t.Fatalf("unseal err: %s", err) for _, key := range keys {
} if _, err := cores[i].Unseal(TestKeyCopy(key)); err != nil {
} t.Fatalf("unseal err: %s", err)
for _, key := range keys { }
if _, err := c3.Unseal(TestKeyCopy(key)); err != nil {
t.Fatalf("unseal err: %s", err)
} }
} }
// Let them come fully up to standby // Let them come fully up to standby
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
// Ensure cluster connection info is populated // Ensure cluster connection info is populated.
isLeader, _, _, err := c2.Leader() // Other cores should not come up as leaders.
if err != nil { for i := 1; i < numCores; i++ {
t.Fatal(err) isLeader, _, _, err := cores[i].Leader()
} if err != nil {
if isLeader { t.Fatal(err)
t.Fatal("c2 should not be leader") }
} if isLeader {
isLeader, _, _, err = c3.Leader() t.Fatalf("core[%d] should not be leader", i)
if err != nil { }
t.Fatal(err)
}
if isLeader {
t.Fatal("c3 should not be leader")
} }
} }
cluster, err := c1.Cluster() //
// Set test cluster core(s) and test cluster
//
cluster, err := cores[0].Cluster()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1278,65 +1161,27 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
} }
var ret []*TestClusterCore var ret []*TestClusterCore
t1 := &TestClusterCore{ for i := 0; i < numCores; i++ {
Core: c1, tcc := &TestClusterCore{
ServerKey: s1Key, Core: cores[i],
ServerKeyPEM: s1KeyPEM, ServerKey: certInfoSlice[i].key,
ServerCert: s1Cert, ServerKeyPEM: certInfoSlice[i].keyPEM,
ServerCertBytes: s1CertBytes, ServerCert: certInfoSlice[i].cert,
ServerCertPEM: s1CertPEM, ServerCertBytes: certInfoSlice[i].certBytes,
Listeners: c1lns, ServerCertPEM: certInfoSlice[i].certPEM,
Handler: handler1, Listeners: listeners[i],
Server: server1, Handler: handlers[i],
TLSConfig: s1TLSConfig, Server: servers[i],
Client: getAPIClient(c1lns[0].Address.Port, s1TLSConfig), TLSConfig: tlsConfigs[i],
Client: getAPIClient(listeners[i][0].Address.Port, tlsConfigs[i]),
}
tcc.ReloadFuncs = &cores[i].reloadFuncs
tcc.ReloadFuncsLock = &cores[i].reloadFuncsLock
tcc.ReloadFuncsLock.Lock()
(*tcc.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{certGetters[i].Reload}
tcc.ReloadFuncsLock.Unlock()
ret = append(ret, tcc)
} }
t1.ReloadFuncs = &c1.reloadFuncs
t1.ReloadFuncsLock = &c1.reloadFuncsLock
t1.ReloadFuncsLock.Lock()
(*t1.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s1CertGetter.Reload}
t1.ReloadFuncsLock.Unlock()
ret = append(ret, t1)
t2 := &TestClusterCore{
Core: c2,
ServerKey: s2Key,
ServerKeyPEM: s2KeyPEM,
ServerCert: s2Cert,
ServerCertBytes: s2CertBytes,
ServerCertPEM: s2CertPEM,
Listeners: c2lns,
Handler: handler2,
Server: server2,
TLSConfig: s2TLSConfig,
Client: getAPIClient(c2lns[0].Address.Port, s2TLSConfig),
}
t2.ReloadFuncs = &c2.reloadFuncs
t2.ReloadFuncsLock = &c2.reloadFuncsLock
t2.ReloadFuncsLock.Lock()
(*t2.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s2CertGetter.Reload}
t2.ReloadFuncsLock.Unlock()
ret = append(ret, t2)
t3 := &TestClusterCore{
Core: c3,
ServerKey: s3Key,
ServerKeyPEM: s3KeyPEM,
ServerCert: s3Cert,
ServerCertBytes: s3CertBytes,
ServerCertPEM: s3CertPEM,
Listeners: c3lns,
Handler: handler3,
Server: server3,
TLSConfig: s3TLSConfig,
Client: getAPIClient(c3lns[0].Address.Port, s3TLSConfig),
}
t3.ReloadFuncs = &c3.reloadFuncs
t3.ReloadFuncsLock = &c3.reloadFuncsLock
t3.ReloadFuncsLock.Lock()
(*t3.ReloadFuncs)["listener|tcp"] = []reload.ReloadFunc{s3CertGetter.Reload}
t3.ReloadFuncsLock.Unlock()
ret = append(ret, t3)
testCluster.Cores = ret testCluster.Cores = ret
return &testCluster return &testCluster