From a4180c193b5aa74bb173b058238585bec7e9ae8c Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Thu, 7 Dec 2023 12:36:17 +0000 Subject: [PATCH] Refactor plugin catalog and plugin runtime catalog into their own package (#24403) * Refactor plugin catalog into its own package * Fix some unnecessarily slow tests due to accidentally running multiple plugin processes * Clean up MakeTestPluginDir helper * Move getBackendVersion tests to plugin catalog package * Use corehelpers.MakeTestPlugin consistently * Fix semgrep failure: check for nil value from logical.Storage --- builtin/logical/database/backend_test.go | 11 +- .../logical/database/dbplugin/plugin_test.go | 9 +- .../logical/database/versioning_large_test.go | 14 +- builtin/plugin/backend_test.go | 9 +- command/auth_tune_test.go | 3 +- command/plugin_deregister_test.go | 12 +- command/plugin_info_test.go | 9 +- command/plugin_register_test.go | 6 +- command/plugin_reload_test.go | 3 +- command/secrets_tune_test.go | 3 +- command/server.go | 3 +- helper/testhelpers/corehelpers/corehelpers.go | 21 +- http/plugin_test.go | 8 +- vault/auth.go | 42 +- vault/core.go | 33 +- vault/dynamic_system_view.go | 3 +- vault/external_plugin_container_test.go | 48 -- vault/external_plugin_test.go | 115 ++--- .../plugin/external_plugin_test.go | 9 +- vault/external_tests/plugin/plugin_test.go | 38 +- vault/logical_system.go | 27 +- vault/logical_system_test.go | 151 ++----- vault/mount.go | 3 +- vault/plugincatalog/builtin_registry.go | 17 + vault/{ => plugincatalog}/plugin_catalog.go | 109 +++-- .../plugin_catalog_test.go | 427 ++++++++++++++---- .../plugin_runtime_catalog.go | 19 +- .../plugin_runtime_catalog_test.go | 34 +- vault/plugincatalog/testing.go | 94 ++++ vault/testing.go | 143 +----- 30 files changed, 747 insertions(+), 676 deletions(-) create mode 100644 vault/plugincatalog/builtin_registry.go rename vault/{ => plugincatalog}/plugin_catalog.go (95%) rename vault/{ => plugincatalog}/plugin_catalog_test.go (60%) rename vault/{ => plugincatalog}/plugin_runtime_catalog.go (90%) rename vault/{ => plugincatalog}/plugin_runtime_catalog_test.go (61%) create mode 100644 vault/plugincatalog/testing.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 424464934a..173a6eea93 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/plugins/database/postgresql" @@ -35,11 +36,14 @@ import ( ) func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView) { + t.Helper() + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "database": Factory, }, BuiltinRegistry: builtinplugins.Registry, + PluginDirectory: pluginDir, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ @@ -52,17 +56,20 @@ func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView) os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) sys := vault.TestDynamicSystemView(cores[0].Core, nil) - vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed", []string{}, "") + vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed", []string{}) return cluster, sys } func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { + t.Helper() + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "database": Factory, }, BuiltinRegistry: builtinplugins.Registry, + PluginDirectory: pluginDir, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ @@ -1473,7 +1480,7 @@ func TestBackend_AsyncClose(t *testing.T) { // Test that having a plugin that takes a LONG time to close will not cause the cleanup function to take // longer than 750ms. cluster, sys := getCluster(t) - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "hanging-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_Hanging", []string{}, "") + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "hanging-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_Hanging", []string{}) t.Cleanup(cluster.Cleanup) config := logical.TestBackendConfig() diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 43f4e2d144..8c9737cdbf 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -13,6 +13,7 @@ import ( log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/database/dbplugin" "github.com/hashicorp/vault/sdk/helper/consts" @@ -106,14 +107,18 @@ func (m *mockPlugin) SetCredentials(ctx context.Context, statements dbplugin.Sta } func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { - cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + t.Helper() + pluginDir := corehelpers.MakeTestPluginDir(t) + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + PluginDirectory: pluginDir, + }, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, }) cluster.Start() cores := cluster.Cores sys := vault.TestDynamicSystemView(cores[0].Core, nil) - vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", consts.PluginTypeDatabase, "", "TestPlugin_GRPC_Main", []string{}, "") + vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", consts.PluginTypeDatabase, "", "TestPlugin_GRPC_Main", []string{}) return cluster, sys } diff --git a/builtin/logical/database/versioning_large_test.go b/builtin/logical/database/versioning_large_test.go index 482e5f3533..bacb4a6a71 100644 --- a/builtin/logical/database/versioning_large_test.go +++ b/builtin/logical/database/versioning_large_test.go @@ -25,9 +25,9 @@ func TestPlugin_lifecycle(t *testing.T) { cluster, sys := getCluster(t) defer cluster.Cleanup() - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", []string{}, "") - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{}, "") - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", []string{}, "") + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", []string{}) + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{}) + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", []string{}) config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -226,7 +226,7 @@ func TestPlugin_VersionSelection(t *testing.T) { defer cluster.Cleanup() for _, version := range []string{"v11.0.0", "v11.0.1-rc1", "v2.0.0"} { - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, version, "TestBackend_PluginMain_MockV5", []string{}, "") + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, version, "TestBackend_PluginMain_MockV5", []string{}) } config := logical.TestBackendConfig() @@ -312,11 +312,11 @@ func TestPlugin_VersionSelection(t *testing.T) { } // Register a newer version of the plugin, and ensure that's the new default version selected. - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "v11.0.1", "TestBackend_PluginMain_MockV5", []string{}, "") + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "v11.0.1", "TestBackend_PluginMain_MockV5", []string{}) t.Run("no version specified, new latest version selected", test(t, "", "v11.0.1")) // Register an unversioned plugin and ensure that is now selected when no version is specified. - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{}, "") + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{}) for name, tc := range map[string]struct { selectVersion string expectedVersion string @@ -397,7 +397,7 @@ func TestPlugin_VersionMustBeExplicitlyUpgraded(t *testing.T) { } // Register versioned plugin, and check that a new write to existing config doesn't upgrade the plugin implicitly. - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mysql-database-plugin", consts.PluginTypeDatabase, "v1.0.0", "TestBackend_PluginMain_MockV5", []string{}, "") + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mysql-database-plugin", consts.PluginTypeDatabase, "v1.0.0", "TestBackend_PluginMain_MockV5", []string{}) resp, err = b.HandleRequest(context.Background(), &logical.Request{ Operation: logical.UpdateOperation, Path: "config/db", diff --git a/builtin/plugin/backend_test.go b/builtin/plugin/backend_test.go index 2d06c2b318..c1b7a83d15 100644 --- a/builtin/plugin/backend_test.go +++ b/builtin/plugin/backend_test.go @@ -12,6 +12,7 @@ import ( log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/plugin" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/logging" @@ -115,7 +116,11 @@ func TestBackend_PluginMain_Multiplexed(t *testing.T) { } func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func()) { - cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + t.Helper() + pluginDir := corehelpers.MakeTestPluginDir(t) + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + PluginDirectory: pluginDir, + }, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, }) cluster.Start() @@ -137,7 +142,7 @@ func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func()) os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd, []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd, []string{}) return config, func() { cluster.Cleanup() diff --git a/command/auth_tune_test.go b/command/auth_tune_test.go index 4bc8ddf465..a06f0d291b 100644 --- a/command/auth_tune_test.go +++ b/command/auth_tune_test.go @@ -78,8 +78,7 @@ func TestAuthTuneCommand_Run(t *testing.T) { t.Run("integration", func(t *testing.T) { t.Run("flags_all", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() diff --git a/command/plugin_deregister_test.go b/command/plugin_deregister_test.go index d80ed9f3fe..46e52df797 100644 --- a/command/plugin_deregister_test.go +++ b/command/plugin_deregister_test.go @@ -80,8 +80,7 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { t.Run("integration", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() @@ -138,8 +137,7 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { t.Run("integration with version", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() @@ -186,8 +184,7 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { t.Run("integration with missing version", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() @@ -233,8 +230,7 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { t.Run("deregister builtin", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() diff --git a/command/plugin_info_test.go b/command/plugin_info_test.go index f0e66d8be5..58525312d7 100644 --- a/command/plugin_info_test.go +++ b/command/plugin_info_test.go @@ -85,8 +85,7 @@ func TestPluginInfoCommand_Run(t *testing.T) { t.Run("default", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() @@ -116,8 +115,7 @@ func TestPluginInfoCommand_Run(t *testing.T) { t.Run("version flag", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() @@ -162,8 +160,7 @@ func TestPluginInfoCommand_Run(t *testing.T) { t.Run("field", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() diff --git a/command/plugin_register_test.go b/command/plugin_register_test.go index ed46a16caa..c50644ae43 100644 --- a/command/plugin_register_test.go +++ b/command/plugin_register_test.go @@ -85,8 +85,7 @@ func TestPluginRegisterCommand_Run(t *testing.T) { t.Run("integration", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() @@ -134,8 +133,7 @@ func TestPluginRegisterCommand_Run(t *testing.T) { t.Run("integration with version", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() diff --git a/command/plugin_reload_test.go b/command/plugin_reload_test.go index f3af275eb4..edbca3e4e9 100644 --- a/command/plugin_reload_test.go +++ b/command/plugin_reload_test.go @@ -85,8 +85,7 @@ func TestPluginReloadCommand_Run(t *testing.T) { t.Run("integration", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() diff --git a/command/secrets_tune_test.go b/command/secrets_tune_test.go index 5c1670db92..8b7965ff29 100644 --- a/command/secrets_tune_test.go +++ b/command/secrets_tune_test.go @@ -152,8 +152,7 @@ func TestSecretsTuneCommand_Run(t *testing.T) { t.Run("integration", func(t *testing.T) { t.Run("flags_all", func(t *testing.T) { t.Parallel() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - defer cleanup(t) + pluginDir := corehelpers.MakeTestPluginDir(t) client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() diff --git a/command/server.go b/command/server.go index 26cec2afbb..761ab74508 100644 --- a/command/server.go +++ b/command/server.go @@ -61,6 +61,7 @@ import ( sr "github.com/hashicorp/vault/serviceregistration" "github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault/hcp_link" + "github.com/hashicorp/vault/vault/plugincatalog" vaultseal "github.com/hashicorp/vault/vault/seal" "github.com/hashicorp/vault/version" "github.com/mitchellh/go-testing-interface" @@ -3152,7 +3153,7 @@ func initDevCore(c *ServerCommand, coreConfig *vault.CoreConfig, config *server. for _, name := range list { path := filepath.Join(f.Name(), name) if err := c.addPlugin(path, init.RootToken, core); err != nil { - if !errwrap.Contains(err, vault.ErrPluginBadType.Error()) { + if !errwrap.Contains(err, plugincatalog.ErrPluginBadType.Error()) { return fmt.Errorf("Error enabling plugin %s: %s", name, err) } pluginsNotLoaded = append(pluginsNotLoaded, name) diff --git a/helper/testhelpers/corehelpers/corehelpers.go b/helper/testhelpers/corehelpers/corehelpers.go index 79b5611644..582a4d9ef6 100644 --- a/helper/testhelpers/corehelpers/corehelpers.go +++ b/helper/testhelpers/corehelpers/corehelpers.go @@ -49,36 +49,27 @@ func RetryUntil(t testing.T, timeout time.Duration, f func() error) { // MakeTestPluginDir creates a temporary directory suitable for holding plugins. // This helper also resolves symlinks to make tests happy on OS X. -func MakeTestPluginDir(t testing.T) (string, func(t testing.T)) { - if t != nil { - t.Helper() - } +func MakeTestPluginDir(t testing.T) string { + t.Helper() dir, err := os.MkdirTemp("", "") if err != nil { - if t == nil { - panic(err) - } t.Fatal(err) } // OSX tempdir are /var, but actually symlinked to /private/var dir, err = filepath.EvalSymlinks(dir) if err != nil { - if t == nil { - panic(err) - } t.Fatal(err) } - return dir, func(t testing.T) { + t.Cleanup(func() { if err := os.RemoveAll(dir); err != nil { - if t == nil { - panic(err) - } t.Fatal(err) } - } + }) + + return dir } func NewMockBuiltinRegistry() *mockBuiltinRegistry { diff --git a/http/plugin_test.go b/http/plugin_test.go index 4606edfce6..fa67187621 100644 --- a/http/plugin_test.go +++ b/http/plugin_test.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/vault/api" bplugin "github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/helper/benchhelpers" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" @@ -25,7 +26,8 @@ import ( "github.com/hashicorp/vault/vault" ) -func getPluginClusterAndCore(t testing.TB, logger log.Logger) (*vault.TestCluster, *vault.TestClusterCore) { +func getPluginClusterAndCore(t *testing.T, logger log.Logger) (*vault.TestCluster, *vault.TestClusterCore) { + t.Helper() inm, err := inmem.NewTransactionalInmem(nil, logger) if err != nil { t.Fatal(err) @@ -35,12 +37,14 @@ func getPluginClusterAndCore(t testing.TB, logger log.Logger) (*vault.TestCluste t.Fatal(err) } + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ Physical: inm, HAPhysical: inmha.(physical.HABackend), LogicalBackends: map[string]logical.Factory{ "plugin": bplugin.Factory, }, + PluginDirectory: pluginDir, } cluster := vault.NewTestCluster(benchhelpers.TBtoT(t), coreConfig, &vault.TestClusterOptions{ @@ -54,7 +58,7 @@ func getPluginClusterAndCore(t testing.TB, logger log.Logger) (*vault.TestCluste os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) vault.TestWaitActive(benchhelpers.TBtoT(t), core.Core) - vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain", []string{}, "") + vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain", []string{}) // Mount the mock plugin err = core.Client.Sys().Mount("mock", &api.MountInput{ diff --git a/vault/auth.go b/vault/auth.go index 1ab5a887bc..7c62e77706 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -8,6 +8,7 @@ import ( "encoding/hex" "errors" "fmt" + "path/filepath" "strings" "github.com/hashicorp/go-secure-stdlib/strutil" @@ -18,6 +19,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/plugincatalog" ) const ( @@ -969,7 +971,7 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV if entry.Version != "" { errContext += fmt.Sprintf(", version=%s", entry.Version) } - return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + return nil, "", fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) } if len(plug.Sha256) > 0 { runningSha = hex.EncodeToString(plug.Sha256) @@ -1027,6 +1029,44 @@ func (c *Core) newCredentialBackend(ctx context.Context, entry *MountEntry, sysV return b, runningSha, nil } +func wrapFactoryCheckPerms(core *Core, f logical.Factory) logical.Factory { + return func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { + pluginName := conf.Config["plugin_name"] + pluginVersion := conf.Config["plugin_version"] + pluginTypeRaw := conf.Config["plugin_type"] + pluginType, err := consts.ParsePluginType(pluginTypeRaw) + if err != nil { + return nil, err + } + + pluginDescription := fmt.Sprintf("%s plugin %s", pluginTypeRaw, pluginName) + if pluginVersion != "" { + pluginDescription += " version " + pluginVersion + } + + plugin, err := core.pluginCatalog.Get(ctx, pluginName, pluginType, pluginVersion) + if err != nil { + return nil, fmt.Errorf("failed to find %s in plugin catalog: %w", pluginDescription, err) + } + if plugin == nil { + return nil, fmt.Errorf("failed to find %s in plugin catalog", pluginDescription) + } + if plugin.OCIImage != "" { + return f(ctx, conf) + } + + command, err := filepath.Rel(core.pluginDirectory, plugin.Command) + if err != nil { + return nil, fmt.Errorf("failed to compute plugin command: %w", err) + } + + if err := core.CheckPluginPerms(command); err != nil { + return nil, err + } + return f(ctx, conf) + } +} + // defaultAuthTable creates a default auth table func (c *Core) defaultAuthTable() *MountTable { table := &MountTable{ diff --git a/vault/core.go b/vault/core.go index e02e53a82e..020ca71ebc 100644 --- a/vault/core.go +++ b/vault/core.go @@ -63,6 +63,7 @@ import ( "github.com/hashicorp/vault/shamir" "github.com/hashicorp/vault/vault/cluster" "github.com/hashicorp/vault/vault/eventbus" + "github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/vault/quotas" vaultseal "github.com/hashicorp/vault/vault/seal" "github.com/hashicorp/vault/version" @@ -94,6 +95,11 @@ const ( // how policies should be applied coreGroupPolicyApplicationPath = "core/group-policy-application-mode" + // Path in storage for the plugin catalog. + pluginCatalogPath = "core/plugin-catalog/" + // Path in storage for the plugin runtime catalog. + pluginRuntimeCatalogPath = "core/plugin-runtime-catalog/" + // groupPolicyApplicationModeWithinNamespaceHierarchy is a configuration option for group // policy application modes, which allows only in-namespace-hierarchy policy application groupPolicyApplicationModeWithinNamespaceHierarchy = "within_namespace_hierarchy" @@ -242,7 +248,7 @@ type Core struct { // The registry of builtin plugins is passed in here as an interface because // if it's used directly, it results in import cycles. - builtinRegistry BuiltinRegistry + builtinRegistry plugincatalog.BuiltinRegistry // N.B.: This is used to populate a dev token down replication, as // otherwise, after replication is started, a dev would have to go through @@ -537,10 +543,10 @@ type Core struct { pluginFilePermissions int // pluginCatalog is used to manage plugin configurations - pluginCatalog *PluginCatalog + pluginCatalog *plugincatalog.PluginCatalog // pluginRuntimeCatalog is used to manage plugin runtime configurations - pluginRuntimeCatalog *PluginRuntimeCatalog + pluginRuntimeCatalog *plugincatalog.PluginRuntimeCatalog // The userFailedLoginInfo map has user failed login information. // It has user information (alias-name and mount accessor) as a key @@ -736,7 +742,7 @@ type CoreConfig struct { DevToken string - BuiltinRegistry BuiltinRegistry + BuiltinRegistry plugincatalog.BuiltinRegistry LogicalBackends map[string]logical.Factory @@ -2394,11 +2400,15 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c return err } } - if err := c.setupPluginRuntimeCatalog(ctx); err != nil { + if pluginRuntimeCatalog, err := plugincatalog.SetupPluginRuntimeCatalog(ctx, c.logger, NewBarrierView(c.barrier, pluginRuntimeCatalogPath)); err != nil { return err + } else { + c.pluginRuntimeCatalog = pluginRuntimeCatalog } - if err := c.setupPluginCatalog(ctx); err != nil { + if pluginCatalog, err := plugincatalog.SetupPluginCatalog(ctx, c.logger, c.builtinRegistry, NewBarrierView(c.barrier, pluginCatalogPath), c.pluginDirectory, c.enableMlock, c.pluginRuntimeCatalog); err != nil { return err + } else { + c.pluginCatalog = pluginCatalog } if err := c.loadMounts(ctx); err != nil { return err @@ -3386,17 +3396,6 @@ func (c *Core) MetricSink() *metricsutil.ClusterMetricSink { return c.metricSink } -// BuiltinRegistry is an interface that allows the "vault" package to use -// the registry of builtin plugins without getting an import cycle. It -// also allows for mocking the registry easily. -type BuiltinRegistry interface { - Contains(name string, pluginType consts.PluginType) bool - Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) - Keys(pluginType consts.PluginType) []string - DeprecationStatus(name string, pluginType consts.PluginType) (consts.DeprecationStatus, bool) - IsBuiltinEntPlugin(name string, pluginType consts.PluginType) bool -} - func (c *Core) AuditLogger() AuditLogger { return &basicAuditor{c: c} } diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 9b03d3303a..e881bad2bb 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/version" ) @@ -283,7 +284,7 @@ func (d dynamicSystemView) LookupPluginVersion(ctx context.Context, name string, if version != "" { errContext += fmt.Sprintf(", version=%s", version) } - return nil, fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + return nil, fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) } return r, nil diff --git a/vault/external_plugin_container_test.go b/vault/external_plugin_container_test.go index 03aca4ca6d..b77545372c 100644 --- a/vault/external_plugin_container_test.go +++ b/vault/external_plugin_container_test.go @@ -5,7 +5,6 @@ package vault import ( "context" - "encoding/hex" "fmt" "os" "os/exec" @@ -15,8 +14,6 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" "github.com/hashicorp/vault/sdk/helper/consts" - "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" - "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" ) @@ -111,51 +108,6 @@ func mountAndUnmountContainerPlugin_WithRuntime(t *testing.T, c *Core, plugin pl routeRequest(false) } -func TestExternalPluginInContainer_GetBackendTypeVersion(t *testing.T) { - c, plugins := testClusterWithContainerPlugins(t, []consts.PluginType{ - consts.PluginTypeCredential, - consts.PluginTypeSecrets, - consts.PluginTypeDatabase, - }) - for _, plugin := range plugins { - t.Run(plugin.Typ.String(), func(t *testing.T) { - for _, ociRuntime := range []string{"runc", "runsc"} { - t.Run(ociRuntime, func(t *testing.T) { - if _, err := exec.LookPath(ociRuntime); err != nil { - t.Skipf("Skipping test as %s not found on path", ociRuntime) - } - shaBytes, _ := hex.DecodeString(plugin.ImageSha256) - entry := &pluginutil.PluginRunner{ - Name: plugin.Name, - OCIImage: plugin.Image, - Args: nil, - Sha256: shaBytes, - Builtin: false, - Runtime: ociRuntime, - RuntimeConfig: &pluginruntimeutil.PluginRuntimeConfig{ - OCIRuntime: ociRuntime, - }, - } - - var version logical.PluginVersion - var err error - if plugin.Typ == consts.PluginTypeDatabase { - version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) - } else { - version, err = c.pluginCatalog.getBackendRunningVersion(context.Background(), entry) - } - if err != nil { - t.Fatal(err) - } - if version.Version != plugin.Version { - t.Errorf("Expected to get version %v but got %v", plugin.Version, version.Version) - } - }) - } - }) - } -} - func registerContainerPlugin(t *testing.T, sys *SystemBackend, pluginName, pluginType, version, sha, image, runtime string) { t.Helper() req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("plugins/catalog/%s/%s", pluginType, pluginName)) diff --git a/vault/external_plugin_test.go b/vault/external_plugin_test.go index e0fa2172b5..03151d985e 100644 --- a/vault/external_plugin_test.go +++ b/vault/external_plugin_test.go @@ -5,12 +5,10 @@ package vault import ( "context" - "encoding/hex" "errors" "fmt" "os" "path" - "path/filepath" "strings" "testing" @@ -19,10 +17,10 @@ import ( "github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" - "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/plugin" "github.com/hashicorp/vault/sdk/plugin/mock" + "github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/version" ) @@ -31,8 +29,7 @@ const vaultTestingMockPluginEnv = "VAULT_TESTING_MOCK_PLUGIN" // version is used to override the plugin's self-reported version func testCoreWithPlugins(t *testing.T, typ consts.PluginType, versions ...string) (*Core, []pluginhelpers.TestPlugin) { t.Helper() - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) var plugins []pluginhelpers.TestPlugin for _, version := range versions { @@ -186,8 +183,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { } func TestCore_EnableExternalPlugin_Deregister_SealUnseal(t *testing.T) { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) // create an external plugin to shadow the builtin "pending-removal-test-plugin" pluginName := "therug" @@ -207,7 +203,7 @@ func TestCore_EnableExternalPlugin_Deregister_SealUnseal(t *testing.T) { // Register a plugin registerPlugin(t, c.systemBackend, pluginName, consts.PluginTypeCredential.String(), "", plugin.Sha256, plugin.FileName) mountPlugin(t, c.systemBackend, pluginName, consts.PluginTypeCredential, "", "") - plugct := len(c.pluginCatalog.externalPlugins) + plugct := c.pluginCatalog.Processes() if plugct != 1 { t.Fatalf("expected a single external plugin entry after registering, got: %d", plugct) } @@ -228,7 +224,7 @@ func TestCore_EnableExternalPlugin_Deregister_SealUnseal(t *testing.T) { } } - plugct = len(c.pluginCatalog.externalPlugins) + plugct = c.pluginCatalog.Processes() if plugct != 0 { t.Fatalf("expected no plugin entries after unseal, got: %d", plugct) } @@ -258,8 +254,7 @@ func TestCore_EnableExternalPlugin_Deregister_SealUnseal(t *testing.T) { // version store is cleared. Vault sees the next unseal as a major upgrade and // should immediately shut down. func TestCore_Unseal_isMajorVersionFirstMount_PendingRemoval_Plugin(t *testing.T) { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) // create an external plugin to shadow the builtin "pending-removal-test-plugin" pluginName := "pending-removal-test-plugin" @@ -337,8 +332,7 @@ func TestCore_Unseal_isMajorVersionFirstMount_PendingRemoval_Plugin(t *testing.T } func TestCore_EnableExternalPlugin_PendingRemoval(t *testing.T) { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) // create an external plugin to shadow the builtin "pending-removal-test-plugin" pluginName := "pending-removal-test-plugin" @@ -372,8 +366,7 @@ func TestCore_EnableExternalPlugin_PendingRemoval(t *testing.T) { } func TestCore_EnableExternalPlugin_ShadowBuiltin(t *testing.T) { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) // create an external plugin to shadow the builtin "approle" plugin := pluginhelpers.CompilePlugin(t, consts.PluginTypeCredential, "v1.2.3", pluginDir) @@ -451,8 +444,7 @@ func TestCore_EnableExternalPlugin_ShadowBuiltin(t *testing.T) { } func TestCore_EnableExternalKv_MultipleVersions(t *testing.T) { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) // new kv plugin can be registered but not mounted plugin := pluginhelpers.CompilePlugin(t, consts.PluginTypeSecrets, "v1.2.3", pluginDir) @@ -504,8 +496,7 @@ func TestCore_EnableExternalKv_MultipleVersions(t *testing.T) { } func TestCore_EnableExternalNoop_MultipleVersions(t *testing.T) { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) // new noop plugin can be registered but not mounted plugin := pluginhelpers.CompilePlugin(t, consts.PluginTypeCredential, "v1.2.3", pluginDir) @@ -626,7 +617,7 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) { }, } resp, _ := c.systemBackend.HandleRequest(namespace.RootContext(nil), req) - if resp == nil || !resp.IsError() || !strings.Contains(resp.Error().Error(), ErrPluginNotFound.Error()) { + if resp == nil || !resp.IsError() || !strings.Contains(resp.Error().Error(), plugincatalog.ErrPluginNotFound.Error()) { t.Fatalf("Expected to get plugin not found but got: %v", resp.Error()) } }) @@ -663,55 +654,6 @@ func TestCore_EnableExternalCredentialPlugin_InvalidName(t *testing.T) { } } -func TestExternalPlugin_getBackendTypeVersion(t *testing.T) { - for name, tc := range map[string]struct { - pluginType consts.PluginType - setRunningVersion string - }{ - "external credential plugin": { - pluginType: consts.PluginTypeCredential, - setRunningVersion: "v1.2.3", - }, - "external secrets plugin": { - pluginType: consts.PluginTypeSecrets, - setRunningVersion: "v1.2.3", - }, - "external database plugin": { - pluginType: consts.PluginTypeDatabase, - setRunningVersion: "v1.2.3", - }, - } { - t.Run(name, func(t *testing.T) { - c, plugins := testCoreWithPlugins(t, tc.pluginType, tc.setRunningVersion) - registerPlugin(t, c.systemBackend, plugins[0].Name, tc.pluginType.String(), tc.setRunningVersion, plugins[0].Sha256, plugins[0].FileName) - - shaBytes, _ := hex.DecodeString(plugins[0].Sha256) - commandFull := filepath.Join(c.pluginCatalog.directory, plugins[0].FileName) - entry := &pluginutil.PluginRunner{ - Name: plugins[0].Name, - Command: commandFull, - Args: nil, - Sha256: shaBytes, - Builtin: false, - } - - var version logical.PluginVersion - var err error - if tc.pluginType == consts.PluginTypeDatabase { - version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) - } else { - version, err = c.pluginCatalog.getBackendRunningVersion(context.Background(), entry) - } - if err != nil { - t.Fatal(err) - } - if version.Version != tc.setRunningVersion { - t.Errorf("Expected to get version %v but got %v", tc.setRunningVersion, version.Version) - } - }) - } -} - func TestExternalPlugin_CheckFilePermissions(t *testing.T) { // Turn on the check. if err := os.Setenv(consts.VaultEnableFilePermissionsCheckEnv, "true"); err != nil { @@ -788,7 +730,10 @@ func TestExternalPlugin_CheckFilePermissions(t *testing.T) { func TestExternalPlugin_DifferentVersionsAndArgs_AreNotMultiplexed(t *testing.T) { env := []string{fmt.Sprintf("%s=yes", vaultTestingMockPluginEnv)} - core, _, _ := TestCoreUnsealed(t) + pluginDir := corehelpers.MakeTestPluginDir(t) + core, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: pluginDir, + }) for i, tc := range []struct { version string @@ -798,29 +743,32 @@ func TestExternalPlugin_DifferentVersionsAndArgs_AreNotMultiplexed(t *testing.T) {"v1.2.4", "TestBackend_PluginMain_Multiplexed_Logical_v124"}, } { // Register and mount plugins. - TestAddTestPlugin(t, core, "mux-secret", consts.PluginTypeSecrets, tc.version, tc.testName, env, "") + TestAddTestPlugin(t, core, "mux-secret", consts.PluginTypeSecrets, tc.version, tc.testName, env) mountPlugin(t, core.systemBackend, "mux-secret", consts.PluginTypeSecrets, tc.version, fmt.Sprintf("foo%d", i)) } - if len(core.pluginCatalog.externalPlugins) != 2 { - t.Fatalf("expected 2 external plugins, but got %d", len(core.pluginCatalog.externalPlugins)) + if core.pluginCatalog.Processes() != 2 { + t.Fatalf("expected 2 external plugins, but got %d", core.pluginCatalog.Processes()) } } func TestExternalPlugin_DifferentTypes_AreNotMultiplexed(t *testing.T) { const version = "v1.2.3" env := []string{fmt.Sprintf("%s=yes", vaultTestingMockPluginEnv)} - core, _, _ := TestCoreUnsealed(t) + pluginDir := corehelpers.MakeTestPluginDir(t) + core, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: pluginDir, + }) // Register and mount plugins. - TestAddTestPlugin(t, core, "mux-aws", consts.PluginTypeSecrets, version, "TestBackend_PluginMain_Multiplexed_Logical_v123", env, "") - TestAddTestPlugin(t, core, "mux-aws", consts.PluginTypeCredential, version, "TestBackend_PluginMain_Multiplexed_Credential_v123", env, "") + TestAddTestPlugin(t, core, "mux-aws", consts.PluginTypeSecrets, version, "TestBackend_PluginMain_Multiplexed_Logical_v123", env) + TestAddTestPlugin(t, core, "mux-aws", consts.PluginTypeCredential, version, "TestBackend_PluginMain_Multiplexed_Credential_v123", env) mountPlugin(t, core.systemBackend, "mux-aws", consts.PluginTypeSecrets, version, "") mountPlugin(t, core.systemBackend, "mux-aws", consts.PluginTypeCredential, version, "") - if len(core.pluginCatalog.externalPlugins) != 2 { - t.Fatalf("expected 2 external plugins, but got %d", len(core.pluginCatalog.externalPlugins)) + if core.pluginCatalog.Processes() != 2 { + t.Fatalf("expected 2 external plugins, but got %d", core.pluginCatalog.Processes()) } } @@ -834,16 +782,19 @@ func TestExternalPlugin_DifferentEnv_AreNotMultiplexed(t *testing.T) { "FOO=BAR", } - core, _, _ := TestCoreUnsealed(t) + pluginDir := corehelpers.MakeTestPluginDir(t) + core, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: pluginDir, + }) // Register and mount plugins. for i, env := range [][]string{baseEnv, alteredEnv} { - TestAddTestPlugin(t, core, "mux-secret", consts.PluginTypeSecrets, version, "TestBackend_PluginMain_Multiplexed_Logical_v123", env, "") + TestAddTestPlugin(t, core, "mux-secret", consts.PluginTypeSecrets, version, "TestBackend_PluginMain_Multiplexed_Logical_v123", env) mountPlugin(t, core.systemBackend, "mux-secret", consts.PluginTypeSecrets, version, fmt.Sprintf("foo%d", i)) } - if len(core.pluginCatalog.externalPlugins) != 2 { - t.Fatalf("expected 2 external plugins, but got %d", len(core.pluginCatalog.externalPlugins)) + if core.pluginCatalog.Processes() != 2 { + t.Fatalf("expected 2 external plugins, but got %d", core.pluginCatalog.Processes()) } } diff --git a/vault/external_tests/plugin/external_plugin_test.go b/vault/external_tests/plugin/external_plugin_test.go index b0c00c46da..1e78cf0be1 100644 --- a/vault/external_tests/plugin/external_plugin_test.go +++ b/vault/external_tests/plugin/external_plugin_test.go @@ -32,8 +32,7 @@ import ( ) func getClusterWithFileAuditBackend(t *testing.T, typ consts.PluginType, numCores int) *vault.TestCluster { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ PluginDirectory: pluginDir, LogicalBackends: map[string]logical.Factory{ @@ -63,8 +62,7 @@ func getClusterWithFileAuditBackend(t *testing.T, typ consts.PluginType, numCore } func getCluster(t *testing.T, typ consts.PluginType, numCores int) *vault.TestCluster { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ PluginDirectory: pluginDir, LogicalBackends: map[string]logical.Factory{ @@ -94,8 +92,7 @@ func getCluster(t *testing.T, typ consts.PluginType, numCores int) *vault.TestCl // rollback and reload a plugin without triggering race conditions by the go // race detector func TestExternalPlugin_RollbackAndReload(t *testing.T) { - pluginDir, cleanup := corehelpers.MakeTestPluginDir(t) - t.Cleanup(func() { cleanup(t) }) + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ // set rollback period to a short interval to make conditions more "racy" RollbackPeriod: 1 * time.Second, diff --git a/vault/external_tests/plugin/plugin_test.go b/vault/external_tests/plugin/plugin_test.go index 80486f862b..e29340a17c 100644 --- a/vault/external_tests/plugin/plugin_test.go +++ b/vault/external_tests/plugin/plugin_test.go @@ -6,15 +6,14 @@ package plugin_test import ( "context" "fmt" - "io/ioutil" "os" "path/filepath" "testing" - "time" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginutil" @@ -246,7 +245,7 @@ func TestSystemBackend_Plugin_MismatchType(t *testing.T) { core := cluster.Cores[0] // Add a credential backend with the same name - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "", "TestBackend_PluginMainCredentials", []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "", "TestBackend_PluginMainCredentials", []string{}) // Make a request to lazy load the now-credential plugin // and expect an error @@ -256,9 +255,6 @@ func TestSystemBackend_Plugin_MismatchType(t *testing.T) { if err != nil { t.Fatalf("adding a same-named plugin of a different type should be no problem: %s", err) } - - // Sleep a bit before cleanup is called - time.Sleep(1 * time.Second) }) } } @@ -344,13 +340,13 @@ func testPlugin_CatalogRemoved(t *testing.T, btype logical.BackendType, testMoun switch btype { case logical.TypeLogical: // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", logicalVersionMap[tc.pluginVersion], []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", logicalVersionMap[tc.pluginVersion], []string{}) _, err = core.Client.Logical().Write("sys/mounts/mock-0", map[string]interface{}{ "type": "test", }) case logical.TypeCredential: // Add plugin back to the catalog - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "", credentialVersionMap[tc.pluginVersion], []string{}, "") + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "", credentialVersionMap[tc.pluginVersion], []string{}) _, err = core.Client.Logical().Write("sys/auth/mock-0", map[string]interface{}{ "type": "test", }) @@ -588,6 +584,8 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} // // The mounts are mounted at sys/mounts/mock-[numMounts] or sys/auth/mock-[numMounts] func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType logical.BackendType, pluginVersion string) *vault.TestCluster { + t.Helper() + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "plugin": plugin.Factory, @@ -595,19 +593,14 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo CredentialBackends: map[string]logical.Factory{ "plugin": plugin.Factory, }, - } - - // Create a tempdir, cluster.Cleanup will clean up this directory - tempDir, err := ioutil.TempDir("", "vault-test-cluster") - if err != nil { - t.Fatal(err) + PluginDirectory: pluginDir, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, KeepStandbysSealed: true, NumCores: numCores, - TempDir: tempDir, + TempDir: pluginDir, }) cluster.Start() @@ -620,7 +613,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo switch backendType { case logical.TypeLogical: plugin := logicalVersionMap[pluginVersion] - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", plugin, env, tempDir) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", plugin, env) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ @@ -636,7 +629,7 @@ func testSystemBackendMock(t *testing.T, numCores, numMounts int, backendType lo } case logical.TypeCredential: plugin := credentialVersionMap[pluginVersion] - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "", plugin, env, tempDir) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeCredential, "", plugin, env) for i := 0; i < numMounts; i++ { // Alternate input styles for plugin_name on every other mount options := map[string]interface{}{ @@ -666,22 +659,19 @@ func TestSystemBackend_Plugin_Env(t *testing.T) { // testSystemBackend_SingleCluster_Env is a helper func that returns a single // cluster and a single mounted plugin logical backend. func testSystemBackend_SingleCluster_Env(t *testing.T, env []string) *vault.TestCluster { + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "test": plugin.Factory, }, - } - // Create a tempdir, cluster.Cleanup will clean up this directory - tempDir, err := ioutil.TempDir("", "vault-test-cluster") - if err != nil { - t.Fatal(err) + PluginDirectory: pluginDir, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, KeepStandbysSealed: true, NumCores: 1, - TempDir: tempDir, + TempDir: pluginDir, }) cluster.Start() @@ -690,7 +680,7 @@ func testSystemBackend_SingleCluster_Env(t *testing.T, env []string) *vault.Test client := core.Client env = append([]string{pluginutil.PluginCACertPEMEnv + "=" + cluster.CACertPEMFile}, env...) - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestBackend_PluginMainEnv", env, tempDir) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestBackend_PluginMainEnv", env) options := map[string]interface{}{ "type": "mock-plugin", } diff --git a/vault/logical_system.go b/vault/logical_system.go index bd5c93e227..bf3a37719a 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -49,6 +49,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/roottoken" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/version" "github.com/mitchellh/mapstructure" "golang.org/x/crypto/sha3" @@ -453,9 +454,6 @@ func (b *SystemBackend) handlePluginCatalogUntypedList(ctx context.Context, _ *l return nil, err } - // Sort for consistent ordering - sortVersionedPlugins(versioned) - versionedPlugins = append(versionedPlugins, versioned...) } @@ -495,23 +493,6 @@ func (b *SystemBackend) handlePluginCatalogUntypedList(ctx context.Context, _ *l }, nil } -func sortVersionedPlugins(versionedPlugins []pluginutil.VersionedPlugin) { - sort.SliceStable(versionedPlugins, func(i, j int) bool { - left, right := versionedPlugins[i], versionedPlugins[j] - if left.Type != right.Type { - return left.Type < right.Type - } - if left.Name != right.Name { - return left.Name < right.Name - } - if left.Version != right.Version { - return right.SemanticVersion.GreaterThan(left.SemanticVersion) - } - - return false - }) -} - func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) { pluginName := d.Get("name").(string) if pluginName == "" { @@ -604,7 +585,7 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica Sha256: sha256Bytes, }) if err != nil { - if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") { + if errors.Is(err, plugincatalog.ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") { return logical.ErrorResponse(err.Error()), nil } return nil, err @@ -649,7 +630,7 @@ func (b *SystemBackend) handlePluginCatalogRead(ctx context.Context, _ *logical. command := plugin.Command if !plugin.Builtin && plugin.OCIImage == "" { - command, err = filepath.Rel(b.Core.pluginCatalog.directory, command) + command, err = filepath.Rel(b.Core.pluginDirectory, command) if err != nil { return nil, err } @@ -877,7 +858,7 @@ func (b *SystemBackend) handlePluginRuntimeCatalogRead(ctx context.Context, _ *l } conf, err := b.Core.pluginRuntimeCatalog.Get(ctx, runtimeName, runtimeType) - if err != nil && !errors.Is(err, ErrPluginRuntimeNotFound) { + if err != nil && !errors.Is(err, plugincatalog.ErrPluginRuntimeNotFound) { return nil, err } if conf == nil { diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index baa77c74f0..9b28f1bfc3 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -24,7 +24,6 @@ import ( "github.com/hashicorp/go-hclog" wrapping "github.com/hashicorp/go-kms-wrapping/v2" aeadwrapper "github.com/hashicorp/go-kms-wrapping/wrappers/aead/v2" - semver "github.com/hashicorp/go-version" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/experiments" @@ -43,6 +42,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/helper/testhelpers/schema" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/vault/seal" "github.com/hashicorp/vault/version" "github.com/mitchellh/mapstructure" @@ -2173,7 +2173,14 @@ func TestSystemBackend_disableAuth(t *testing.T) { } func TestSystemBackend_tuneAuth(t *testing.T) { - c, b, _ := testCoreSystemBackend(t) + tempDir, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + c, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: tempDir, + }) + b := c.systemBackend c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { return &NoopBackend{BackendType: logical.TypeCredential}, nil } @@ -2188,7 +2195,7 @@ func TestSystemBackend_tuneAuth(t *testing.T) { } schema.ValidateResponse( t, - schema.GetResponseSchema(t, b.(*SystemBackend).Route(req.Path), req.Operation), + schema.GetResponseSchema(t, b.Route(req.Path), req.Operation), resp, true, ) @@ -2209,24 +2216,19 @@ func TestSystemBackend_tuneAuth(t *testing.T) { req.Data["description"] = "" req.Data["plugin_version"] = "v1.0.0" resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err == nil || resp == nil || !resp.IsError() || !strings.Contains(resp.Error().Error(), ErrPluginNotFound.Error()) { + if err == nil || resp == nil || !resp.IsError() || !strings.Contains(resp.Error().Error(), plugincatalog.ErrPluginNotFound.Error()) { t.Fatalf("expected tune request to fail, but got resp: %#v, err: %s", resp, err) } schema.ValidateResponse( t, - schema.GetResponseSchema(t, b.(*SystemBackend).Route(req.Path), req.Operation), + schema.GetResponseSchema(t, b.Route(req.Path), req.Operation), resp, true, ) // Register the plugin in the catalog, and then try the same request again. { - tempDir, err := filepath.EvalSymlinks(t.TempDir()) - if err != nil { - t.Fatal(err) - } - c.pluginCatalog.directory = tempDir file, err := os.Create(filepath.Join(tempDir, "foo")) if err != nil { t.Fatal(err) @@ -2263,7 +2265,7 @@ func TestSystemBackend_tuneAuth(t *testing.T) { } schema.ValidateResponse( t, - schema.GetResponseSchema(t, b.(*SystemBackend).Route(req.Path), req.Operation), + schema.GetResponseSchema(t, b.Route(req.Path), req.Operation), resp, true, ) @@ -3323,13 +3325,14 @@ func testCoreSystemBackendRaw(t *testing.T) (*Core, logical.Backend, string) { } func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { - c, b, _ := testCoreSystemBackend(t) - // Bootstrap the pluginCatalog sym, err := filepath.EvalSymlinks(os.TempDir()) if err != nil { t.Fatalf("error: %v", err) } - c.pluginCatalog.directory = sym + c, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: sym, + }) + b := c.systemBackend req := logical.TestRequest(t, logical.ListOperation, "plugins/catalog/database") resp, err := b.HandleRequest(namespace.RootContext(nil), req) @@ -3339,7 +3342,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { schema.ValidateResponse( t, - schema.GetResponseSchema(t, b.(*SystemBackend).Route(req.Path), req.Operation), + schema.GetResponseSchema(t, b.Route(req.Path), req.Operation), resp, true, ) @@ -3356,7 +3359,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { schema.ValidateResponse( t, - schema.GetResponseSchema(t, b.(*SystemBackend).Route(req.Path), req.Operation), + schema.GetResponseSchema(t, b.Route(req.Path), req.Operation), resp, true, ) @@ -3401,7 +3404,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { schema.ValidateResponse( t, - schema.GetResponseSchema(t, b.(*SystemBackend).Route(req.Path), req.Operation), + schema.GetResponseSchema(t, b.Route(req.Path), req.Operation), resp, true, ) @@ -3440,7 +3443,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { schema.ValidateResponse( t, - schema.GetResponseSchema(t, b.(*SystemBackend).Route(req.Path), req.Operation), + schema.GetResponseSchema(t, b.Route(req.Path), req.Operation), resp, true, ) @@ -3500,13 +3503,14 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { // TestSystemBackend_PluginCatalog_ContainerCRUD tests that plugins registered // with oci_image set get recorded properly in the catalog. func TestSystemBackend_PluginCatalog_ContainerCRUD(t *testing.T) { - c, b, _ := testCoreSystemBackend(t) - // Bootstrap the pluginCatalog sym, err := filepath.EvalSymlinks(os.TempDir()) if err != nil { t.Fatalf("error: %v", err) } - c.pluginCatalog.directory = sym + c, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: sym, + }) + b := c.systemBackend for name, tc := range map[string]struct { in, expected map[string]any @@ -3618,13 +3622,14 @@ func TestSystemBackend_PluginCatalog_ListPlugins_SucceedsWithAuditLogEnabled(t * } func TestSystemBackend_PluginCatalog_CannotRegisterBuiltinPlugins(t *testing.T) { - c, b, _ := testCoreSystemBackend(t) - // Bootstrap the pluginCatalog sym, err := filepath.EvalSymlinks(os.TempDir()) if err != nil { t.Fatalf("error: %v", err) } - c.pluginCatalog.directory = sym + c, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: sym, + }) + b := c.systemBackend // Set a plugin req := logical.TestRequest(t, logical.UpdateOperation, "plugins/catalog/database/test-plugin") @@ -5816,79 +5821,6 @@ func TestSystemBackend_LoggersByName(t *testing.T) { } } -func TestSortVersionedPlugins(t *testing.T) { - versionedPlugin := func(typ consts.PluginType, name string, version string, builtin bool) pluginutil.VersionedPlugin { - return pluginutil.VersionedPlugin{ - Type: typ.String(), - Name: name, - Version: version, - SHA256: "", - Builtin: builtin, - SemanticVersion: func() *semver.Version { - if version != "" { - return semver.Must(semver.NewVersion(version)) - } - - return semver.Must(semver.NewVersion("0.0.0")) - }(), - } - } - - differingTypes := []pluginutil.VersionedPlugin{ - versionedPlugin(consts.PluginTypeSecrets, "c", "1.0.0", false), - versionedPlugin(consts.PluginTypeDatabase, "c", "1.0.0", false), - versionedPlugin(consts.PluginTypeCredential, "c", "1.0.0", false), - } - differingNames := []pluginutil.VersionedPlugin{ - versionedPlugin(consts.PluginTypeCredential, "c", "1.0.0", false), - versionedPlugin(consts.PluginTypeCredential, "b", "1.0.0", false), - versionedPlugin(consts.PluginTypeCredential, "a", "1.0.0", false), - } - differingVersions := []pluginutil.VersionedPlugin{ - versionedPlugin(consts.PluginTypeCredential, "c", "10.0.0", false), - versionedPlugin(consts.PluginTypeCredential, "c", "2.0.1", false), - versionedPlugin(consts.PluginTypeCredential, "c", "2.1.0", false), - } - versionedUnversionedAndBuiltin := []pluginutil.VersionedPlugin{ - versionedPlugin(consts.PluginTypeCredential, "c", "1.0.0", false), - versionedPlugin(consts.PluginTypeCredential, "c", "", false), - versionedPlugin(consts.PluginTypeCredential, "c", "1.0.0", true), - } - - for name, tc := range map[string][]pluginutil.VersionedPlugin{ - "ascending types": differingTypes, - "ascending names": differingNames, - "ascending versions": differingVersions, - // Include differing versions twice so we can test out equality too. - "differing types, names and versions": append(differingTypes, - append(differingNames, - append(differingVersions, differingVersions...)...)...), - "mix of unversioned, versioned, and builtin": versionedUnversionedAndBuiltin, - } { - t.Run(name, func(t *testing.T) { - sortVersionedPlugins(tc) - for i := 1; i < len(tc); i++ { - previous := tc[i-1] - current := tc[i] - if current.Type > previous.Type { - continue - } - if current.Name > previous.Name { - continue - } - if current.SemanticVersion.GreaterThan(previous.SemanticVersion) { - continue - } - if current.Type == previous.Type && current.Name == previous.Name && current.SemanticVersion.Equal(previous.SemanticVersion) { - continue - } - - t.Fatalf("versioned plugins at index %d and %d were not properly sorted: %+v, %+v", i-1, i, previous, current) - } - }) - } -} - func TestValidateVersion(t *testing.T) { b := testSystemBackend(t).(*SystemBackend) k8sAuthBuiltin := versions.GetBuiltinVersion(consts.PluginTypeCredential, "kubernetes") @@ -5927,12 +5859,13 @@ func TestValidateVersion(t *testing.T) { } func TestValidateVersion_HelpfulErrorWhenBuiltinOverridden(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) - tempDir, err := filepath.EvalSymlinks(t.TempDir()) + tempDir, err := filepath.EvalSymlinks(os.TempDir()) if err != nil { - t.Fatal(err) + t.Fatalf("error: %v", err) } - core.pluginCatalog.directory = tempDir + core, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: tempDir, + }) b := core.systemBackend // Shadow a builtin and test getting a helpful error back. @@ -6378,7 +6311,14 @@ func TestSystemBackend_pluginRuntime_CannotDeleteRuntimeWithReferencingPlugins(t if runtime.GOOS != "linux" { t.Skip("Currently plugincontainer only supports linux") } - c, b, _ := testCoreSystemBackend(t) + sym, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + c, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ + PluginDirectory: sym, + }) + b := c.systemBackend conf := pluginruntimeutil.PluginRuntimeConfig{ Name: "foo", @@ -6406,13 +6346,6 @@ func TestSystemBackend_pluginRuntime_CannotDeleteRuntimeWithReferencingPlugins(t t.Fatalf("bad: %#v", resp) } - // Bootstrap the pluginCatalog - sym, err := filepath.EvalSymlinks(os.TempDir()) - if err != nil { - t.Fatalf("error: %v", err) - } - c.pluginCatalog.directory = sym - // Register the plugin referencing the runtime. req = logical.TestRequest(t, logical.UpdateOperation, "plugins/catalog/database/test-plugin") req.Data["version"] = "v0.16.0" diff --git a/vault/mount.go b/vault/mount.go index d06c331ea1..158dbd9356 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -24,6 +24,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/plugincatalog" "github.com/mitchellh/copystructure" ) @@ -1690,7 +1691,7 @@ func (c *Core) newLogicalBackend(ctx context.Context, entry *MountEntry, sysView if entry.Version != "" { errContext += fmt.Sprintf(", version=%s", entry.Version) } - return nil, "", fmt.Errorf("%w: %s", ErrPluginNotFound, errContext) + return nil, "", fmt.Errorf("%w: %s", plugincatalog.ErrPluginNotFound, errContext) } if len(plug.Sha256) > 0 { runningSha = hex.EncodeToString(plug.Sha256) diff --git a/vault/plugincatalog/builtin_registry.go b/vault/plugincatalog/builtin_registry.go new file mode 100644 index 0000000000..38bbb34013 --- /dev/null +++ b/vault/plugincatalog/builtin_registry.go @@ -0,0 +1,17 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package plugincatalog + +import "github.com/hashicorp/vault/sdk/helper/consts" + +// BuiltinRegistry is an interface that allows the "vault" package to use +// the registry of builtin plugins without getting an import cycle. It +// also allows for mocking the registry easily. +type BuiltinRegistry interface { + Contains(name string, pluginType consts.PluginType) bool + Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) + Keys(pluginType consts.PluginType) []string + DeprecationStatus(name string, pluginType consts.PluginType) (consts.DeprecationStatus, bool) + IsBuiltinEntPlugin(name string, pluginType consts.PluginType) bool +} diff --git a/vault/plugin_catalog.go b/vault/plugincatalog/plugin_catalog.go similarity index 95% rename from vault/plugin_catalog.go rename to vault/plugincatalog/plugin_catalog.go index 68936cbd5e..2c39009df7 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugincatalog/plugin_catalog.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package vault +package plugincatalog import ( "context" @@ -11,6 +11,7 @@ import ( "fmt" "path" "path/filepath" + "sort" "strings" "sync" @@ -33,7 +34,6 @@ import ( ) var ( - pluginCatalogPath = "core/plugin-catalog/" ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured") ErrPluginNotFound = errors.New("plugin not found in the catalog") ErrPluginConnectionNotFound = errors.New("plugin connection not found for client") @@ -45,7 +45,7 @@ var ( // plugins are automatically detected and included in the catalog. type PluginCatalog struct { builtinRegistry BuiltinRegistry - catalogView *BarrierView + catalogView logical.Storage directory string logger log.Logger @@ -137,67 +137,37 @@ type pluginClient struct { plugin.ClientProtocol } -func wrapFactoryCheckPerms(core *Core, f logical.Factory) logical.Factory { - return func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { - pluginName := conf.Config["plugin_name"] - pluginVersion := conf.Config["plugin_version"] - pluginTypeRaw := conf.Config["plugin_type"] - pluginType, err := consts.ParsePluginType(pluginTypeRaw) - if err != nil { - return nil, err - } - - pluginDescription := fmt.Sprintf("%s plugin %s", pluginTypeRaw, pluginName) - if pluginVersion != "" { - pluginDescription += " version " + pluginVersion - } - - plugin, err := core.pluginCatalog.Get(ctx, pluginName, pluginType, pluginVersion) - if err != nil { - return nil, fmt.Errorf("failed to find %s in plugin catalog: %w", pluginDescription, err) - } - if plugin == nil { - return nil, fmt.Errorf("failed to find %s in plugin catalog", pluginDescription) - } - if plugin.OCIImage != "" { - return f(ctx, conf) - } - - command, err := filepath.Rel(core.pluginCatalog.directory, plugin.Command) - if err != nil { - return nil, fmt.Errorf("failed to compute plugin command: %w", err) - } - - if err := core.CheckPluginPerms(command); err != nil { - return nil, err - } - return f(ctx, conf) - } -} - -func (c *Core) setupPluginCatalog(ctx context.Context) error { - c.pluginCatalog = &PluginCatalog{ - builtinRegistry: c.builtinRegistry, - catalogView: NewBarrierView(c.barrier, pluginCatalogPath), - directory: c.pluginDirectory, - logger: c.logger, - mlockPlugins: c.enableMlock, +func SetupPluginCatalog( + ctx context.Context, + logger log.Logger, + builtinRegistry BuiltinRegistry, + catalogView logical.Storage, + pluginDirectory string, + enableMlock bool, + pluginRuntimeCatalog *PluginRuntimeCatalog, +) (*PluginCatalog, error) { + pluginCatalog := &PluginCatalog{ + builtinRegistry: builtinRegistry, + catalogView: catalogView, + directory: pluginDirectory, + logger: logger, + mlockPlugins: enableMlock, wrapper: logical.StaticSystemView{VersionString: version.GetVersion().Version}, - runtimeCatalog: c.pluginRuntimeCatalog, + runtimeCatalog: pluginRuntimeCatalog, } // Run upgrade if untyped plugins exist - err := c.pluginCatalog.UpgradePlugins(ctx, c.logger) + err := pluginCatalog.UpgradePlugins(ctx, logger) if err != nil { - c.logger.Error("error while upgrading plugin storage", "error", err) - return err + logger.Error("error while upgrading plugin storage", "error", err) + return nil, err } - if c.logger.IsInfo() { - c.logger.Info("successfully setup plugin catalog", "plugin-directory", c.pluginDirectory) + if logger.IsInfo() { + logger.Info("successfully setup plugin catalog", "plugin-directory", pluginDirectory) } - return nil + return pluginCatalog, nil } type pluginClientConn struct { @@ -232,6 +202,10 @@ func (p *pluginClient) Reload() error { return p.reloadFunc() } +func (c *PluginCatalog) Processes() int { + return len(c.externalPlugins) +} + // reloadExternalPlugin // This should be called with the write lock held. func (c *PluginCatalog) reloadExternalPlugin(key externalPluginsKey, id, pluginBinaryRef string) error { @@ -781,6 +755,11 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e continue } + if pluginRaw == nil { + retErr = multierror.Append(fmt.Errorf("%q plugin entry was nil", pluginName)) + continue + } + plugin := new(pluginutil.PluginRunner) if err := jsonutil.DecodeJSON(pluginRaw.Value, plugin); err != nil { retErr = multierror.Append(fmt.Errorf("failed to decode plugin entry: %w", err)) @@ -1187,5 +1166,25 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug }) } + // Sort the result for consistent ordering. + sortVersionedPlugins(result) + return result, nil } + +func sortVersionedPlugins(versionedPlugins []pluginutil.VersionedPlugin) { + sort.SliceStable(versionedPlugins, func(i, j int) bool { + left, right := versionedPlugins[i], versionedPlugins[j] + if left.Type != right.Type { + return left.Type < right.Type + } + if left.Name != right.Name { + return left.Name < right.Name + } + if left.Version != right.Version { + return right.SemanticVersion.GreaterThan(left.SemanticVersion) + } + + return false + }) +} diff --git a/vault/plugin_catalog_test.go b/vault/plugincatalog/plugin_catalog_test.go similarity index 60% rename from vault/plugin_catalog_test.go rename to vault/plugincatalog/plugin_catalog_test.go index 5dd62886ed..9704a37426 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugincatalog/plugin_catalog_test.go @@ -1,52 +1,85 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package vault +package plugincatalog import ( "context" "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io/ioutil" "os" + "os/exec" "path/filepath" "reflect" + "runtime" "sort" "testing" + "github.com/hashicorp/go-hclog" + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-version" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/credential/userpass" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" + "github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" "github.com/hashicorp/vault/helper/versions" "github.com/hashicorp/vault/plugins/database/postgresql" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/physical/inmem" backendplugin "github.com/hashicorp/vault/sdk/plugin" "github.com/hashicorp/vault/helper/builtinplugins" ) -func TestPluginCatalog_CRUD(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) - tempDir, err := filepath.EvalSymlinks(t.TempDir()) +func testPluginCatalog(t *testing.T) *PluginCatalog { + logger := hclog.New(&hclog.LoggerOptions{ + Level: hclog.Trace, + }) + storage, err := inmem.NewInmem(nil, logger) if err != nil { t.Fatal(err) } - core.pluginCatalog.directory = tempDir + testDir, err := filepath.EvalSymlinks(filepath.Dir(os.Args[0])) + if err != nil { + t.Fatal(err) + } + pluginRuntimeCatalog := testPluginRuntimeCatalog(t) + pluginCatalog, err := SetupPluginCatalog( + context.Background(), + logger, + corehelpers.NewMockBuiltinRegistry(), + logical.NewLogicalStorage(storage), + testDir, + false, + pluginRuntimeCatalog, + ) + if err != nil { + t.Fatal(err) + } + return pluginCatalog +} +func TestPluginCatalog_CRUD(t *testing.T) { const pluginName = "mysql-database-plugin" + pluginCatalog := testPluginCatalog(t) + // Get builtin plugin - p, err := core.pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "") + p, err := pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "") if err != nil { t.Fatalf("unexpected error %v", err) } // Get it again, explicitly specifying builtin version builtinVersion := versions.GetBuiltinVersion(consts.PluginTypeDatabase, pluginName) - p2, err := core.pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, builtinVersion) + p2, err := pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, builtinVersion) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -73,14 +106,14 @@ func TestPluginCatalog_CRUD(t *testing.T) { } // Set a plugin, test overwriting a builtin plugin - file, err := os.CreateTemp(tempDir, "temp") + file, err := os.CreateTemp(pluginCatalog.directory, "temp") if err != nil { t.Fatal(err) } defer file.Close() command := filepath.Base(file.Name()) - err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: pluginName, Type: consts.PluginTypeDatabase, Version: "", @@ -94,14 +127,14 @@ func TestPluginCatalog_CRUD(t *testing.T) { } // Get the plugin - p, err = core.pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "") + p, err = pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "") if err != nil { t.Fatalf("unexpected error %v", err) } // Get it again, explicitly specifying builtin version. // This time it should fail because it was overwritten. - p2, err = core.pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, builtinVersion) + p2, err = pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, builtinVersion) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -112,7 +145,7 @@ func TestPluginCatalog_CRUD(t *testing.T) { expected := &pluginutil.PluginRunner{ Name: pluginName, Type: consts.PluginTypeDatabase, - Command: filepath.Join(tempDir, filepath.Base(file.Name())), + Command: filepath.Join(pluginCatalog.directory, filepath.Base(file.Name())), Args: []string{"--test"}, Env: []string{"FOO=BAR"}, Sha256: []byte{'1'}, @@ -125,13 +158,13 @@ func TestPluginCatalog_CRUD(t *testing.T) { } // Delete the plugin - err = core.pluginCatalog.Delete(context.Background(), pluginName, consts.PluginTypeDatabase, "") + err = pluginCatalog.Delete(context.Background(), pluginName, consts.PluginTypeDatabase, "") if err != nil { t.Fatalf("unexpected err: %v", err) } // Get builtin plugin - p, err = core.pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "") + p, err = pluginCatalog.Get(context.Background(), pluginName, consts.PluginTypeDatabase, "") if err != nil { t.Fatalf("unexpected error %v", err) } @@ -155,15 +188,10 @@ func TestPluginCatalog_CRUD(t *testing.T) { } func TestPluginCatalog_VersionedCRUD(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) - tempDir, err := filepath.EvalSymlinks(t.TempDir()) - if err != nil { - t.Fatal(err) - } - core.pluginCatalog.directory = tempDir + pluginCatalog := testPluginCatalog(t) // Set a versioned plugin. - file, err := ioutil.TempFile(tempDir, "temp") + file, err := os.CreateTemp(pluginCatalog.directory, "temp") if err != nil { t.Fatal(err) } @@ -172,7 +200,7 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) { const name = "mysql-database-plugin" const version = "1.0.0" command := fmt.Sprintf("%s", filepath.Base(file.Name())) - err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: name, Type: consts.PluginTypeDatabase, Version: version, @@ -186,7 +214,7 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) { } // Get the plugin - plugin, err := core.pluginCatalog.Get(context.Background(), name, consts.PluginTypeDatabase, version) + plugin, err := pluginCatalog.Get(context.Background(), name, consts.PluginTypeDatabase, version) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -195,7 +223,7 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) { Name: name, Type: consts.PluginTypeDatabase, Version: version, - Command: filepath.Join(tempDir, filepath.Base(file.Name())), + Command: filepath.Join(pluginCatalog.directory, filepath.Base(file.Name())), Args: []string{"--test"}, Env: []string{"FOO=BAR"}, Sha256: []byte{'1'}, @@ -208,7 +236,7 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) { // Also get the builtin version to check we can still access that. builtinVersion := versions.GetBuiltinVersion(consts.PluginTypeDatabase, name) - plugin, err = core.pluginCatalog.Get(context.Background(), name, consts.PluginTypeDatabase, builtinVersion) + plugin, err = pluginCatalog.Get(context.Background(), name, consts.PluginTypeDatabase, builtinVersion) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -237,13 +265,13 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) { } // Delete the plugin - err = core.pluginCatalog.Delete(context.Background(), name, consts.PluginTypeDatabase, version) + err = pluginCatalog.Delete(context.Background(), name, consts.PluginTypeDatabase, version) if err != nil { t.Fatalf("unexpected err: %v", err) } // Get plugin - should fail - plugin, err = core.pluginCatalog.Get(context.Background(), name, consts.PluginTypeDatabase, version) + plugin, err = pluginCatalog.Get(context.Background(), name, consts.PluginTypeDatabase, version) if err != nil { t.Fatal(err) } @@ -253,19 +281,14 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) { } func TestPluginCatalog_List(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) - tempDir, err := filepath.EvalSymlinks(t.TempDir()) - if err != nil { - t.Fatal(err) - } - core.pluginCatalog.directory = tempDir + pluginCatalog := testPluginCatalog(t) // Get builtin plugins and sort them builtinKeys := builtinplugins.Registry.Keys(consts.PluginTypeDatabase) sort.Strings(builtinKeys) // List only builtin plugins - plugins, err := core.pluginCatalog.List(context.Background(), consts.PluginTypeDatabase) + plugins, err := pluginCatalog.List(context.Background(), consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -280,14 +303,14 @@ func TestPluginCatalog_List(t *testing.T) { } // Set a plugin, test overwriting a builtin plugin - file, err := ioutil.TempFile(tempDir, "temp") + file, err := os.CreateTemp(pluginCatalog.directory, "temp") if err != nil { t.Fatal(err) } defer file.Close() command := filepath.Base(file.Name()) - err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: "mysql-database-plugin", Type: consts.PluginTypeDatabase, Version: "", @@ -301,7 +324,7 @@ func TestPluginCatalog_List(t *testing.T) { } // Set another plugin - err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: "aaaaaaa", Type: consts.PluginTypeDatabase, Version: "", @@ -315,7 +338,7 @@ func TestPluginCatalog_List(t *testing.T) { } // List the plugins - plugins, err = core.pluginCatalog.List(context.Background(), consts.PluginTypeDatabase) + plugins, err = pluginCatalog.List(context.Background(), consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -338,19 +361,14 @@ func TestPluginCatalog_List(t *testing.T) { } func TestPluginCatalog_ListVersionedPlugins(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) - tempDir, err := filepath.EvalSymlinks(t.TempDir()) - if err != nil { - t.Fatal(err) - } - core.pluginCatalog.directory = tempDir + pluginCatalog := testPluginCatalog(t) // Get builtin plugins and sort them builtinKeys := builtinplugins.Registry.Keys(consts.PluginTypeDatabase) sort.Strings(builtinKeys) // List only builtin plugins - plugins, err := core.pluginCatalog.ListVersionedPlugins(context.Background(), consts.PluginTypeDatabase) + plugins, err := pluginCatalog.ListVersionedPlugins(context.Background(), consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -367,14 +385,14 @@ func TestPluginCatalog_ListVersionedPlugins(t *testing.T) { } // Set a plugin, test overwriting a builtin plugin - file, err := ioutil.TempFile(tempDir, "temp") + file, err := ioutil.TempFile(pluginCatalog.directory, "temp") if err != nil { t.Fatal(err) } defer file.Close() command := filepath.Base(file.Name()) - err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: "mysql-database-plugin", Type: consts.PluginTypeDatabase, Version: "", @@ -388,7 +406,7 @@ func TestPluginCatalog_ListVersionedPlugins(t *testing.T) { } // Set another plugin, with version information - err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: "aaaaaaa", Type: consts.PluginTypeDatabase, Version: "1.1.0", @@ -402,7 +420,7 @@ func TestPluginCatalog_ListVersionedPlugins(t *testing.T) { } // List the plugins - plugins, err = core.pluginCatalog.ListVersionedPlugins(context.Background(), consts.PluginTypeDatabase) + plugins, err = pluginCatalog.ListVersionedPlugins(context.Background(), consts.PluginTypeDatabase) if err != nil { t.Fatalf("unexpected error %v", err) } @@ -450,14 +468,9 @@ func TestPluginCatalog_ListVersionedPlugins(t *testing.T) { } func TestPluginCatalog_ListHandlesPluginNamesWithSlashes(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) - tempDir, err := filepath.EvalSymlinks(t.TempDir()) - if err != nil { - t.Fatal(err) - } - core.pluginCatalog.directory = tempDir + pluginCatalog := testPluginCatalog(t) - file, err := ioutil.TempFile(tempDir, "temp") + file, err := os.CreateTemp(pluginCatalog.directory, "temp") if err != nil { t.Fatal(err) } @@ -489,7 +502,7 @@ func TestPluginCatalog_ListHandlesPluginNamesWithSlashes(t *testing.T) { }, } for _, entry := range pluginsToRegister { - err = core.pluginCatalog.Set(ctx, pluginutil.SetPluginInput{ + err = pluginCatalog.Set(ctx, pluginutil.SetPluginInput{ Name: entry.Name, Type: consts.PluginTypeCredential, Version: entry.Version, @@ -503,7 +516,7 @@ func TestPluginCatalog_ListHandlesPluginNamesWithSlashes(t *testing.T) { } } - plugins, err := core.pluginCatalog.ListVersionedPlugins(ctx, consts.PluginTypeCredential) + plugins, err := pluginCatalog.ListVersionedPlugins(ctx, consts.PluginTypeCredential) if err != nil { t.Fatal(err) } @@ -524,30 +537,25 @@ func TestPluginCatalog_ListHandlesPluginNamesWithSlashes(t *testing.T) { } func TestPluginCatalog_NewPluginClient(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) - tempDir, err := filepath.EvalSymlinks(t.TempDir()) - if err != nil { - t.Fatal(err) - } - core.pluginCatalog.directory = tempDir + pluginCatalog := testPluginCatalog(t) - if extPlugins := len(core.pluginCatalog.externalPlugins); extPlugins != 0 { + if extPlugins := len(pluginCatalog.externalPlugins); extPlugins != 0 { t.Fatalf("expected externalPlugins map to be of len 0 but got %d", extPlugins) } // register plugins - TestAddTestPlugin(t, core, "mux-postgres", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_PostgresMultiplexed", []string{}, "") - TestAddTestPlugin(t, core, "single-postgres-1", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Postgres", []string{}, "") - TestAddTestPlugin(t, core, "single-postgres-2", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Postgres", []string{}, "") + TestAddTestPlugin(t, pluginCatalog, "mux-postgres", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_PostgresMultiplexed", []string{}) + TestAddTestPlugin(t, pluginCatalog, "single-postgres-1", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Postgres", []string{}) + TestAddTestPlugin(t, pluginCatalog, "single-postgres-2", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Postgres", []string{}) - TestAddTestPlugin(t, core, "mux-userpass", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_UserpassMultiplexed", []string{}, "") - TestAddTestPlugin(t, core, "single-userpass-1", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Userpass", []string{}, "") - TestAddTestPlugin(t, core, "single-userpass-2", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Userpass", []string{}, "") + TestAddTestPlugin(t, pluginCatalog, "mux-userpass", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_UserpassMultiplexed", []string{}) + TestAddTestPlugin(t, pluginCatalog, "single-userpass-1", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Userpass", []string{}) + TestAddTestPlugin(t, pluginCatalog, "single-userpass-2", consts.PluginTypeUnknown, "", "TestPluginCatalog_PluginMain_Userpass", []string{}) getKey := func(pluginName string, pluginType consts.PluginType) externalPluginsKey { t.Helper() ctx := context.Background() - plugin, err := core.pluginCatalog.Get(ctx, pluginName, pluginType, "") + plugin, err := pluginCatalog.Get(ctx, pluginName, pluginType, "") if err != nil { t.Fatal(err) } @@ -565,27 +573,27 @@ func TestPluginCatalog_NewPluginClient(t *testing.T) { // run plugins // run "mux-postgres" twice which will start a single plugin for 2 // distinct connections - c := TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "mux-postgres") + c := testRunTestPlugin(t, pluginCatalog, consts.PluginTypeDatabase, "mux-postgres") pluginClients = append(pluginClients, c) - c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "mux-postgres") + c = testRunTestPlugin(t, pluginCatalog, consts.PluginTypeDatabase, "mux-postgres") pluginClients = append(pluginClients, c) - c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "single-postgres-1") + c = testRunTestPlugin(t, pluginCatalog, consts.PluginTypeDatabase, "single-postgres-1") pluginClients = append(pluginClients, c) - c = TestRunTestPlugin(t, core, consts.PluginTypeDatabase, "single-postgres-2") + c = testRunTestPlugin(t, pluginCatalog, consts.PluginTypeDatabase, "single-postgres-2") pluginClients = append(pluginClients, c) // run "mux-userpass" twice which will start a single plugin for 2 // distinct connections - c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "mux-userpass") + c = testRunTestPlugin(t, pluginCatalog, consts.PluginTypeCredential, "mux-userpass") pluginClients = append(pluginClients, c) - c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "mux-userpass") + c = testRunTestPlugin(t, pluginCatalog, consts.PluginTypeCredential, "mux-userpass") pluginClients = append(pluginClients, c) - c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "single-userpass-1") + c = testRunTestPlugin(t, pluginCatalog, consts.PluginTypeCredential, "single-userpass-1") pluginClients = append(pluginClients, c) - c = TestRunTestPlugin(t, core, consts.PluginTypeCredential, "single-userpass-2") + c = testRunTestPlugin(t, pluginCatalog, consts.PluginTypeCredential, "single-userpass-2") pluginClients = append(pluginClients, c) - externalPlugins := core.pluginCatalog.externalPlugins + externalPlugins := pluginCatalog.externalPlugins if len(externalPlugins) != 6 { t.Fatalf("expected externalPlugins map to be of len 6 but got %d", len(externalPlugins)) } @@ -654,13 +662,10 @@ func TestPluginCatalog_MakeExternalPluginsKey_Comparable(t *testing.T) { // always allow registration of container plugins (rejecting on non-Linux happens // in the logical system API handler). func TestPluginCatalog_ErrDirectoryNotConfigured(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) - tempDir, err := filepath.EvalSymlinks(t.TempDir()) - if err != nil { - t.Fatal(err) - } + catalog := testPluginCatalog(t) + tempDir := catalog.directory + catalog.directory = "" - catalog := core.pluginCatalog tests := map[string]func(t *testing.T){ "set binary plugin": func(t *testing.T) { file, err := os.CreateTemp(tempDir, "temp") @@ -694,7 +699,7 @@ func TestPluginCatalog_ErrDirectoryNotConfigured(t *testing.T) { t.Fatal("expected error without directory set") } // Make sure we can still get builtins too - _, err = core.pluginCatalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "") + _, err = catalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "") if err != nil { t.Fatalf("unexpected error %v", err) } @@ -702,7 +707,7 @@ func TestPluginCatalog_ErrDirectoryNotConfigured(t *testing.T) { "set container plugin": func(t *testing.T) { // Should never error. const image = "does-not-exist" - err = catalog.Set(context.Background(), pluginutil.SetPluginInput{ + err := catalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: "container", Type: consts.PluginTypeDatabase, OCIImage: image, @@ -719,7 +724,7 @@ func TestPluginCatalog_ErrDirectoryNotConfigured(t *testing.T) { t.Fatalf("Expected %s, got %s", image, p.OCIImage) } // Make sure we can still get builtins too - _, err = core.pluginCatalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "") + _, err = catalog.Get(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "") if err != nil { t.Fatalf("unexpected error %v", err) } @@ -732,7 +737,7 @@ func TestPluginCatalog_ErrDirectoryNotConfigured(t *testing.T) { } }) - core.pluginCatalog.directory = tempDir + catalog.directory = tempDir t.Run("directory set", func(t *testing.T) { for name, test := range tests { @@ -745,10 +750,10 @@ func TestPluginCatalog_ErrDirectoryNotConfigured(t *testing.T) { // are returned with their container runtime config populated if it was // specified. func TestRuntimeConfigPopulatedIfSpecified(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) + pluginCatalog := testPluginCatalog(t) const image = "does-not-exist" const runtime = "custom-runtime" - err := core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + err := pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: "container", Type: consts.PluginTypeDatabase, OCIImage: image, @@ -759,7 +764,7 @@ func TestRuntimeConfigPopulatedIfSpecified(t *testing.T) { } const ociRuntime = "some-other-oci-runtime" - err = core.pluginRuntimeCatalog.Set(context.Background(), &pluginruntimeutil.PluginRuntimeConfig{ + err = pluginCatalog.runtimeCatalog.Set(context.Background(), &pluginruntimeutil.PluginRuntimeConfig{ Name: runtime, Type: consts.PluginRuntimeTypeContainer, OCIRuntime: ociRuntime, @@ -769,7 +774,7 @@ func TestRuntimeConfigPopulatedIfSpecified(t *testing.T) { } // Now setting the plugin with a runtime should succeed. - err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ Name: "container", Type: consts.PluginTypeDatabase, OCIImage: image, @@ -779,7 +784,7 @@ func TestRuntimeConfigPopulatedIfSpecified(t *testing.T) { t.Fatal(err) } - p, err := core.pluginCatalog.Get(context.Background(), "container", consts.PluginTypeDatabase, "") + p, err := pluginCatalog.Get(context.Background(), "container", consts.PluginTypeDatabase, "") if err != nil { t.Fatal(err) } @@ -874,3 +879,233 @@ func expectMultiplexingSupport(t *testing.T, expected, actual bool) { t.Fatalf("expected external plugin multiplexing support to be %t", expected) } } + +func TestSortVersionedPlugins(t *testing.T) { + versionedPlugin := func(typ consts.PluginType, name string, pluginVersion string, builtin bool) pluginutil.VersionedPlugin { + return pluginutil.VersionedPlugin{ + Type: typ.String(), + Name: name, + Version: pluginVersion, + SHA256: "", + Builtin: builtin, + SemanticVersion: func() *version.Version { + if pluginVersion != "" { + return version.Must(version.NewVersion(pluginVersion)) + } + + return version.Must(version.NewVersion("0.0.0")) + }(), + } + } + + differingTypes := []pluginutil.VersionedPlugin{ + versionedPlugin(consts.PluginTypeSecrets, "c", "1.0.0", false), + versionedPlugin(consts.PluginTypeDatabase, "c", "1.0.0", false), + versionedPlugin(consts.PluginTypeCredential, "c", "1.0.0", false), + } + differingNames := []pluginutil.VersionedPlugin{ + versionedPlugin(consts.PluginTypeCredential, "c", "1.0.0", false), + versionedPlugin(consts.PluginTypeCredential, "b", "1.0.0", false), + versionedPlugin(consts.PluginTypeCredential, "a", "1.0.0", false), + } + differingVersions := []pluginutil.VersionedPlugin{ + versionedPlugin(consts.PluginTypeCredential, "c", "10.0.0", false), + versionedPlugin(consts.PluginTypeCredential, "c", "2.0.1", false), + versionedPlugin(consts.PluginTypeCredential, "c", "2.1.0", false), + } + versionedUnversionedAndBuiltin := []pluginutil.VersionedPlugin{ + versionedPlugin(consts.PluginTypeCredential, "c", "1.0.0", false), + versionedPlugin(consts.PluginTypeCredential, "c", "", false), + versionedPlugin(consts.PluginTypeCredential, "c", "1.0.0", true), + } + + for name, tc := range map[string][]pluginutil.VersionedPlugin{ + "ascending types": differingTypes, + "ascending names": differingNames, + "ascending versions": differingVersions, + // Include differing versions twice so we can test out equality too. + "differing types, names and versions": append(differingTypes, + append(differingNames, + append(differingVersions, differingVersions...)...)...), + "mix of unversioned, versioned, and builtin": versionedUnversionedAndBuiltin, + } { + t.Run(name, func(t *testing.T) { + sortVersionedPlugins(tc) + for i := 1; i < len(tc); i++ { + previous := tc[i-1] + current := tc[i] + if current.Type > previous.Type { + continue + } + if current.Name > previous.Name { + continue + } + if current.SemanticVersion.GreaterThan(previous.SemanticVersion) { + continue + } + if current.Type == previous.Type && current.Name == previous.Name && current.SemanticVersion.Equal(previous.SemanticVersion) { + continue + } + + t.Fatalf("versioned plugins at index %d and %d were not properly sorted: %+v, %+v", i-1, i, previous, current) + } + }) + } +} + +func TestExternalPlugin_getBackendTypeVersion(t *testing.T) { + for name, tc := range map[string]struct { + pluginType consts.PluginType + setRunningVersion string + }{ + "external credential plugin": { + pluginType: consts.PluginTypeCredential, + setRunningVersion: "v1.2.3", + }, + "external secrets plugin": { + pluginType: consts.PluginTypeSecrets, + setRunningVersion: "v1.2.3", + }, + "external database plugin": { + pluginType: consts.PluginTypeDatabase, + setRunningVersion: "v1.2.3", + }, + } { + t.Run(name, func(t *testing.T) { + pluginCatalog := testPluginCatalog(t) + plugin := pluginhelpers.CompilePlugin(t, tc.pluginType, tc.setRunningVersion, pluginCatalog.directory) + + shaBytes, _ := hex.DecodeString(plugin.Sha256) + commandFull := filepath.Join(pluginCatalog.directory, plugin.FileName) + entry := &pluginutil.PluginRunner{ + Name: plugin.Name, + Command: commandFull, + Args: nil, + Sha256: shaBytes, + Builtin: false, + } + + var version logical.PluginVersion + var err error + if tc.pluginType == consts.PluginTypeDatabase { + version, err = pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) + } else { + version, err = pluginCatalog.getBackendRunningVersion(context.Background(), entry) + } + if err != nil { + t.Fatal(err) + } + if version.Version != tc.setRunningVersion { + t.Errorf("Expected to get version %v but got %v", tc.setRunningVersion, version.Version) + } + }) + } +} + +func TestExternalPluginInContainer_GetBackendTypeVersion(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Running plugins in containers is only supported on linux") + } + + pluginCatalog := testPluginCatalog(t) + + var plugins []pluginhelpers.TestPlugin + for _, pluginType := range []consts.PluginType{ + consts.PluginTypeCredential, + consts.PluginTypeSecrets, + consts.PluginTypeDatabase, + } { + plugin := pluginhelpers.CompilePlugin(t, pluginType, "v1.2.3", pluginCatalog.directory) + plugin.Image, plugin.ImageSha256 = pluginhelpers.BuildPluginContainerImage(t, plugin, pluginCatalog.directory) + plugins = append(plugins, plugin) + } + + for _, plugin := range plugins { + t.Run(plugin.Typ.String(), func(t *testing.T) { + for _, ociRuntime := range []string{"runc", "runsc"} { + t.Run(ociRuntime, func(t *testing.T) { + if _, err := exec.LookPath(ociRuntime); err != nil { + t.Skipf("Skipping test as %s not found on path", ociRuntime) + } + + shaBytes, _ := hex.DecodeString(plugin.ImageSha256) + entry := &pluginutil.PluginRunner{ + Name: plugin.Name, + OCIImage: plugin.Image, + Args: nil, + Sha256: shaBytes, + Builtin: false, + Runtime: ociRuntime, + RuntimeConfig: &pluginruntimeutil.PluginRuntimeConfig{ + OCIRuntime: ociRuntime, + }, + } + + var version logical.PluginVersion + var err error + if plugin.Typ == consts.PluginTypeDatabase { + version, err = pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) + } else { + version, err = pluginCatalog.getBackendRunningVersion(context.Background(), entry) + } + if err != nil { + t.Fatal(err) + } + if version.Version != plugin.Version { + t.Errorf("Expected to get version %v but got %v", plugin.Version, version.Version) + } + }) + } + }) + } +} + +// testRunTestPlugin runs the testFunc which has already been registered to the +// plugin catalog and returns a pluginClient. This can be called after calling +// TestAddTestPlugin. +func testRunTestPlugin(t *testing.T, pluginCatalog *PluginCatalog, pluginType consts.PluginType, pluginName string) *pluginClient { + t.Helper() + config := testPluginClientConfig(pluginCatalog, pluginType, pluginName) + client, err := pluginCatalog.NewPluginClient(context.Background(), config) + if err != nil { + t.Fatal(err) + } + + return client +} + +func testPluginClientConfig(pluginCatalog *PluginCatalog, pluginType consts.PluginType, pluginName string) pluginutil.PluginClientConfig { + config := pluginutil.PluginClientConfig{ + Name: pluginName, + PluginType: pluginType, + Logger: log.NewNullLogger(), + AutoMTLS: true, + IsMetadataMode: false, + Wrapper: pluginCatalogStaticSystemView{ + pluginCatalog: pluginCatalog, + StaticSystemView: logical.StaticSystemView{ + VersionString: "testVersion", + }, + }, + } + + switch pluginType { + case consts.PluginTypeCredential, consts.PluginTypeSecrets: + config.PluginSets = backendplugin.PluginSet + config.HandshakeConfig = backendplugin.HandshakeConfig + case consts.PluginTypeDatabase: + config.PluginSets = v5.PluginSets + config.HandshakeConfig = v5.HandshakeConfig + } + + return config +} + +type pluginCatalogStaticSystemView struct { + logical.StaticSystemView + pluginCatalog *PluginCatalog +} + +func (p pluginCatalogStaticSystemView) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { + return p.pluginCatalog.NewPluginClient(ctx, config) +} diff --git a/vault/plugin_runtime_catalog.go b/vault/plugincatalog/plugin_runtime_catalog.go similarity index 90% rename from vault/plugin_runtime_catalog.go rename to vault/plugincatalog/plugin_runtime_catalog.go index 9ce5a6e844..e82c52ccd9 100644 --- a/vault/plugin_runtime_catalog.go +++ b/vault/plugincatalog/plugin_runtime_catalog.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package vault +package plugincatalog import ( "context" @@ -20,7 +20,6 @@ import ( ) var ( - pluginRuntimeCatalogPath = "core/plugin-runtime-catalog/" ErrPluginRuntimeNotFound = errors.New("plugin runtime not found") ErrPluginRuntimeBadType = errors.New("unable to determine plugin runtime type") ErrPluginRuntimeBadContainerConfig = errors.New("bad container config") @@ -29,23 +28,23 @@ var ( // PluginRuntimeCatalog keeps a record of plugin runtimes. Plugin runtimes need // to be registered to the catalog before they can be used in backends when registering plugins with runtimes type PluginRuntimeCatalog struct { - catalogView *BarrierView + catalogView logical.Storage logger log.Logger lock sync.RWMutex } -func (c *Core) setupPluginRuntimeCatalog(ctx context.Context) error { - c.pluginRuntimeCatalog = &PluginRuntimeCatalog{ - catalogView: NewBarrierView(c.barrier, pluginRuntimeCatalogPath), - logger: c.logger, +func SetupPluginRuntimeCatalog(ctx context.Context, logger log.Logger, catalogView logical.Storage) (*PluginRuntimeCatalog, error) { + pluginRuntimeCatalog := &PluginRuntimeCatalog{ + catalogView: catalogView, + logger: logger, } - if c.logger.IsInfo() { - c.logger.Info("successfully setup plugin runtime catalog") + if logger.IsInfo() { + logger.Info("successfully setup plugin runtime catalog") } - return nil + return pluginRuntimeCatalog, nil } // Get retrieves a plugin runtime with the specified name from the catalog diff --git a/vault/plugin_runtime_catalog_test.go b/vault/plugincatalog/plugin_runtime_catalog_test.go similarity index 61% rename from vault/plugin_runtime_catalog_test.go rename to vault/plugincatalog/plugin_runtime_catalog_test.go index 50bded62fc..43dcfffde3 100644 --- a/vault/plugin_runtime_catalog_test.go +++ b/vault/plugincatalog/plugin_runtime_catalog_test.go @@ -1,18 +1,34 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -package vault +package plugincatalog import ( "context" "reflect" "testing" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/physical/inmem" ) +func testPluginRuntimeCatalog(t *testing.T) *PluginRuntimeCatalog { + logger := hclog.Default() + storage, err := inmem.NewInmem(nil, logger) + if err != nil { + t.Fatal(err) + } + pluginRuntimeCatalog, err := SetupPluginRuntimeCatalog(context.Background(), logger, logical.NewLogicalStorage(storage)) + if err != nil { + t.Fatal(err) + } + return pluginRuntimeCatalog +} + func TestPluginRuntimeCatalog_CRUD(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) + pluginRuntimeCatalog := testPluginRuntimeCatalog(t) ctx := context.Background() expected := &pluginruntimeutil.PluginRuntimeConfig{ @@ -24,13 +40,13 @@ func TestPluginRuntimeCatalog_CRUD(t *testing.T) { } // Set new plugin runtime - err := core.pluginRuntimeCatalog.Set(ctx, expected) + err := pluginRuntimeCatalog.Set(ctx, expected) if err != nil { t.Fatalf("err: %v", err) } // Get plugin runtime - runner, err := core.pluginRuntimeCatalog.Get(ctx, expected.Name, expected.Type) + runner, err := pluginRuntimeCatalog.Get(ctx, expected.Name, expected.Type) if err != nil { t.Fatalf("err: %v", err) } @@ -42,13 +58,13 @@ func TestPluginRuntimeCatalog_CRUD(t *testing.T) { expected.CgroupParent = "memorylimit-cgroup" expected.CPU = 2 expected.Memory = 5000 - err = core.pluginRuntimeCatalog.Set(ctx, expected) + err = pluginRuntimeCatalog.Set(ctx, expected) if err != nil { t.Fatalf("err: %v", err) } // Get plugin runtime again - runner, err = core.pluginRuntimeCatalog.Get(ctx, expected.Name, expected.Type) + runner, err = pluginRuntimeCatalog.Get(ctx, expected.Name, expected.Type) if err != nil { t.Fatalf("err: %v", err) } @@ -57,7 +73,7 @@ func TestPluginRuntimeCatalog_CRUD(t *testing.T) { t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", runner, expected) } - configs, err := core.pluginRuntimeCatalog.List(ctx, expected.Type) + configs, err := pluginRuntimeCatalog.List(ctx, expected.Type) if err != nil { t.Fatalf("err: %v", err) } @@ -66,13 +82,13 @@ func TestPluginRuntimeCatalog_CRUD(t *testing.T) { } // Delete plugin runtime - err = core.pluginRuntimeCatalog.Delete(ctx, expected.Name, expected.Type) + err = pluginRuntimeCatalog.Delete(ctx, expected.Name, expected.Type) if err != nil { t.Fatalf("err: %v", err) } // Assert the plugin runtime catalog is empty - configs, err = core.pluginRuntimeCatalog.List(ctx, expected.Type) + configs, err = pluginRuntimeCatalog.List(ctx, expected.Type) if err != nil { t.Fatalf("err: %v", err) } diff --git a/vault/plugincatalog/testing.go b/vault/plugincatalog/testing.go new file mode 100644 index 0000000000..cb5a52e7f8 --- /dev/null +++ b/vault/plugincatalog/testing.go @@ -0,0 +1,94 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package plugincatalog + +import ( + "context" + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/mitchellh/go-testing-interface" +) + +// TestAddTestPlugin registers the testFunc as part of the plugin command to the +// plugin catalog. The plugin catalog must be configured with a pluginDirectory. +// NB: The test func you pass in MUST be in the same package as the parent test, +// or the test func won't be compiled into the test binary being run and the output +// will be something like: +// stderr (ignored by go-plugin): "testing: warning: no tests to run" +// stdout: "PASS" +func TestAddTestPlugin(t testing.T, pluginCatalog *PluginCatalog, name string, pluginType consts.PluginType, version string, testFunc string, env []string) { + t.Helper() + if pluginCatalog.directory == "" { + t.Fatal("plugin catalog must have a plugin directory set to add plugins") + } + file, err := os.Open(os.Args[0]) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + fileName := filepath.Base(os.Args[0]) + + fi, err := file.Stat() + if err != nil { + t.Fatal(err) + } + + // Copy over the file to the temp dir + dst := filepath.Join(pluginCatalog.directory, fileName) + + // delete the file first to avoid notary failures in macOS + _ = os.Remove(dst) // ignore error + out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) + if err != nil { + t.Fatal(err) + } + defer out.Close() + + if _, err = io.Copy(out, file); err != nil { + t.Fatal(err) + } + err = out.Sync() + if err != nil { + t.Fatal(err) + } + // Ensure that the file is closed and written. This seems to be + // necessary on Linux systems. + out.Close() + + // Copied the file, now seek to the start again to calculate its sha256 hash. + _, err = file.Seek(0, 0) + if err != nil { + t.Fatal(err) + } + + hash := sha256.New() + _, err = io.Copy(hash, file) + if err != nil { + t.Fatal(err) + } + sum := hash.Sum(nil) + + // The flag is a regex, so use ^$ to make sure we only run a single test + // with an exact match. + args := []string{fmt.Sprintf("--test.run=^%s$", testFunc)} + err = pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ + Name: name, + Type: pluginType, + Version: version, + Command: fileName, + Args: args, + Env: env, + Sha256: sum, + }) + if err != nil { + t.Fatal(err) + } +} diff --git a/vault/testing.go b/vault/testing.go index 0a092c1465..893cdcb650 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -9,7 +9,6 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "crypto/sha256" "crypto/tls" "crypto/x509" "crypto/x509/pkix" @@ -17,7 +16,6 @@ import ( "encoding/pem" "errors" "fmt" - "io" "io/ioutil" "math/big" mathrand "math/rand" @@ -52,17 +50,15 @@ import ( "github.com/hashicorp/vault/helper/testhelpers/corehelpers" "github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" "github.com/hashicorp/vault/internalshared/configutil" - v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/logging" - "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/helper/testcluster" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" physInmem "github.com/hashicorp/vault/sdk/physical/inmem" - backendplugin "github.com/hashicorp/vault/sdk/plugin" "github.com/hashicorp/vault/vault/cluster" + "github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/vault/seal" ) @@ -543,136 +539,9 @@ func TestDynamicSystemView(c *Core, ns *namespace.Namespace) *dynamicSystemView return &dynamicSystemView{c, me, c.perfStandby} } -// TestAddTestPlugin registers the testFunc as part of the plugin command to the -// plugin catalog. If provided, uses tmpDir as the plugin directory. -// NB: The test func you pass in MUST be in the same package as the parent test, -// or the test func won't be compiled into the test binary being run and the output -// will be something like: -// stderr (ignored by go-plugin): "testing: warning: no tests to run" -// stdout: "PASS" -func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.PluginType, version string, testFunc string, env []string, tempDir string) { - file, err := os.Open(os.Args[0]) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - dirPath := filepath.Dir(os.Args[0]) - fileName := filepath.Base(os.Args[0]) - - if tempDir != "" { - fi, err := file.Stat() - if err != nil { - t.Fatal(err) - } - - // Copy over the file to the temp dir - dst := filepath.Join(tempDir, fileName) - - // delete the file first to avoid notary failures in macOS - _ = os.Remove(dst) // ignore error - out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) - if err != nil { - t.Fatal(err) - } - defer out.Close() - - if _, err = io.Copy(out, file); err != nil { - t.Fatal(err) - } - err = out.Sync() - if err != nil { - t.Fatal(err) - } - // Ensure that the file is closed and written. This seems to be - // necessary on Linux systems. - out.Close() - - dirPath = tempDir - } - - // Determine plugin directory full path, evaluating potential symlink path - fullPath, err := filepath.EvalSymlinks(dirPath) - if err != nil { - t.Fatal(err) - } - - reader, err := os.Open(filepath.Join(fullPath, fileName)) - if err != nil { - t.Fatal(err) - } - defer reader.Close() - - // Find out the sha256 - hash := sha256.New() - - _, err = io.Copy(hash, reader) - if err != nil { - t.Fatal(err) - } - - sum := hash.Sum(nil) - - // Set core's plugin directory and plugin catalog directory - c.pluginDirectory = fullPath - c.pluginCatalog.directory = fullPath - - args := []string{fmt.Sprintf("--test.run=%s", testFunc)} - err = c.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{ - Name: name, - Type: pluginType, - Version: version, - Command: fileName, - Args: args, - Env: env, - Sha256: sum, - }) - if err != nil { - t.Fatal(err) - } -} - -// TestRunTestPlugin runs the testFunc which has already been registered to the -// plugin catalog and returns a pluginClient. This can be called after calling -// TestAddTestPlugin. -func TestRunTestPlugin(t testing.T, c *Core, pluginType consts.PluginType, pluginName string) *pluginClient { +func TestAddTestPlugin(t testing.T, core *Core, name string, pluginType consts.PluginType, version string, testFunc string, env []string) { t.Helper() - config := TestPluginClientConfig(c, pluginType, pluginName) - client, err := c.pluginCatalog.NewPluginClient(context.Background(), config) - if err != nil { - t.Fatal(err) - } - - return client -} - -func TestPluginClientConfig(c *Core, pluginType consts.PluginType, pluginName string) pluginutil.PluginClientConfig { - dsv := TestDynamicSystemView(c, nil) - switch pluginType { - case consts.PluginTypeCredential, consts.PluginTypeSecrets: - return pluginutil.PluginClientConfig{ - Name: pluginName, - PluginType: pluginType, - PluginSets: backendplugin.PluginSet, - HandshakeConfig: backendplugin.HandshakeConfig, - Logger: log.NewNullLogger(), - AutoMTLS: true, - IsMetadataMode: false, - Wrapper: dsv, - } - case consts.PluginTypeDatabase: - return pluginutil.PluginClientConfig{ - Name: pluginName, - PluginType: pluginType, - PluginSets: v5.PluginSets, - HandshakeConfig: v5.HandshakeConfig, - Logger: log.NewNullLogger(), - AutoMTLS: true, - IsMetadataMode: false, - Wrapper: dsv, - } - } - return pluginutil.PluginClientConfig{} + plugincatalog.TestAddTestPlugin(t, core.pluginCatalog, name, pluginType, version, testFunc, env) } var ( @@ -1707,13 +1576,9 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te t.Skip("Running plugins in containers is only supported on linux") } - var pluginDir string - var cleanup func(t testing.T) - if coreConfig.PluginDirectory == "" { - pluginDir, cleanup = corehelpers.MakeTestPluginDir(t) + pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig.PluginDirectory = pluginDir - t.Cleanup(func() { cleanup(t) }) } for _, version := range pluginType.Versions {