From 6e537bb376d29e7054ca7eab854f02eb481808e9 Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Mon, 8 Jan 2024 12:21:13 +0000 Subject: [PATCH] Support reloading database plugins across multiple mounts (#24512) * Support reloading database plugins across multiple mounts * Add clarifying comment to MountEntry.Path field * Tests: Replace non-parallelisable t.Setenv with plugin env settings --- builtin/logical/database/backend_test.go | 271 +++++++++--------- .../logical/database/versioning_large_test.go | 7 +- builtin/plugin/backend_test.go | 5 +- changelog/24512.txt | 6 + http/plugin_test.go | 6 +- vault/external_tests/plugin/plugin_test.go | 32 +++ vault/logical_system.go | 25 +- vault/mount.go | 2 +- vault/plugin_reload.go | 47 ++- 9 files changed, 233 insertions(+), 168 deletions(-) create mode 100644 changelog/24512.txt diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 0554365e1c..f330e5e5c2 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -6,12 +6,15 @@ package database import ( "context" "database/sql" + "encoding/json" + "errors" "fmt" "log" "net/url" "os" "reflect" "strings" + "sync" "testing" "time" @@ -35,12 +38,26 @@ import ( "github.com/mitchellh/mapstructure" ) +func getClusterPostgresDBWithFactory(t *testing.T, factory logical.Factory) (*vault.TestCluster, logical.SystemView) { + t.Helper() + cluster, sys := getClusterWithFactory(t, factory) + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed", + []string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)}) + return cluster, sys +} + func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView) { + t.Helper() + cluster, sys := getClusterPostgresDBWithFactory(t, Factory) + return cluster, sys +} + +func getClusterWithFactory(t *testing.T, factory logical.Factory) (*vault.TestCluster, logical.SystemView) { t.Helper() pluginDir := corehelpers.MakeTestPluginDir(t) coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ - "database": Factory, + "database": factory, }, BuiltinRegistry: builtinplugins.Registry, PluginDirectory: pluginDir, @@ -53,36 +70,14 @@ func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView) cores := cluster.Cores vault.TestWaitActive(t, cores[0].Core) - os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - sys := vault.TestDynamicSystemView(cores[0].Core, nil) - vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed", []string{}) return cluster, sys } func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { t.Helper() - pluginDir := corehelpers.MakeTestPluginDir(t) - coreConfig := &vault.CoreConfig{ - LogicalBackends: map[string]logical.Factory{ - "database": Factory, - }, - BuiltinRegistry: builtinplugins.Registry, - PluginDirectory: pluginDir, - } - - cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ - HandlerFunc: vaulthttp.Handler, - }) - cluster.Start() - cores := cluster.Cores - vault.TestWaitActive(t, cores[0].Core) - - os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - - sys := vault.TestDynamicSystemView(cores[0].Core, nil) - + cluster, sys := getClusterWithFactory(t, Factory) return cluster, sys } @@ -515,7 +510,7 @@ func TestBackend_basic(t *testing.T) { if credsResp.Secret.TTL != 5*time.Minute { t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL) } - if !testCredsExist(t, credsResp, connURL) { + if !testCredsExist(t, credsResp.Data, connURL) { t.Fatalf("Creds should exist") } @@ -535,7 +530,7 @@ func TestBackend_basic(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - if testCredsExist(t, credsResp, connURL) { + if testCredsExist(t, credsResp.Data, connURL) { t.Fatalf("Creds should not exist") } } @@ -553,7 +548,7 @@ func TestBackend_basic(t *testing.T) { if err != nil || (credsResp != nil && credsResp.IsError()) { t.Fatalf("err:%s resp:%#v\n", err, credsResp) } - if !testCredsExist(t, credsResp, connURL) { + if !testCredsExist(t, credsResp.Data, connURL) { t.Fatalf("Creds should exist") } @@ -586,108 +581,118 @@ func TestBackend_basic(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - if testCredsExist(t, credsResp, connURL) { + if testCredsExist(t, credsResp.Data, connURL) { t.Fatalf("Creds should not exist") } } } -func TestBackend_connectionCrud(t *testing.T) { - cluster, sys := getClusterPostgresDB(t) - defer cluster.Cleanup() +// singletonDBFactory allows us to reach into the internals of a databaseBackend +// even when it's been created by a call to the sys mount. The factory method +// satisfies the logical.Factory type, and lazily creates the databaseBackend +// once the SystemView has been provided because the factory method itself is an +// input for creating the test cluster and its system view. +type singletonDBFactory struct { + once sync.Once + db *databaseBackend + + sys logical.SystemView +} + +// factory satisfies the logical.Factory type. +func (s *singletonDBFactory) factory(context.Context, *logical.BackendConfig) (logical.Backend, error) { + if s.sys == nil { + return nil, errors.New("sys is nil") + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} - config.System = sys + config.System = s.sys - b, err := Factory(context.Background(), config) + var err error + s.once.Do(func() { + var b logical.Backend + b, err = Factory(context.Background(), config) + s.db = b.(*databaseBackend) + }) if err != nil { - t.Fatal(err) + return nil, err } - defer b.Cleanup(context.Background()) + if s.db == nil { + return nil, errors.New("db is nil") + } + return s.db, nil +} + +func TestBackend_connectionCrud(t *testing.T) { + dbFactory := &singletonDBFactory{} + cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory) + defer cluster.Cleanup() + + dbFactory.sys = sys + client := cluster.Cores[0].Client.Logical() cleanup, connURL := postgreshelper.PrepareTestContainer(t, "13.4-buster") defer cleanup() + // Mount the database plugin. + resp, err := client.Write("sys/mounts/database", map[string]interface{}{ + "type": "database", + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + // Configure a connection - data := map[string]interface{}{ + resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ "connection_url": "test", "plugin_name": "postgresql-database-plugin", "verify_connection": false, - } - req := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/plugin-test", - Storage: config.StorageView, - Data: data, - } - resp, err := b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + }) + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } // Configure a second connection to confirm below it doesn't get restarted. - data = map[string]interface{}{ + resp, err = client.Write("database/config/plugin-test-hana", map[string]interface{}{ "connection_url": "test", "plugin_name": "hana-database-plugin", "verify_connection": false, - } - req = &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/plugin-test-hana", - Storage: config.StorageView, - Data: data, - } - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + }) + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } // Create a role - data = map[string]interface{}{ + resp, err = client.Write("database/roles/plugin-role-test", map[string]interface{}{ "db_name": "plugin-test", "creation_statements": testRole, "revocation_statements": defaultRevocationSQL, "default_ttl": "5m", "max_ttl": "10m", - } - req = &logical.Request{ - Operation: logical.UpdateOperation, - Path: "roles/plugin-role-test", - Storage: config.StorageView, - Data: data, - } - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + }) + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } // Update the connection - data = map[string]interface{}{ + resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ "connection_url": connURL, "plugin_name": "postgresql-database-plugin", "allowed_roles": []string{"plugin-role-test"}, "username": "postgres", "password": "secret", "private_key": "PRIVATE_KEY", - } - req = &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/plugin-test", - Storage: config.StorageView, - Data: data, - } - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + }) + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } if len(resp.Warnings) == 0 { t.Fatalf("expected warning about password in url %s, resp:%#v\n", connURL, resp) } - req.Operation = logical.ReadOperation - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + resp, err = client.Read("database/config/plugin-test") + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{}) @@ -703,11 +708,16 @@ func TestBackend_connectionCrud(t *testing.T) { } // Replace connection url with templated version - req.Operation = logical.UpdateOperation - connURL = strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}") - data["connection_url"] = connURL - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + templatedConnURL := strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}") + resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ + "connection_url": templatedConnURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, + "username": "postgres", + "password": "secret", + "private_key": "PRIVATE_KEY", + }) + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } @@ -716,36 +726,38 @@ func TestBackend_connectionCrud(t *testing.T) { "plugin_name": "postgresql-database-plugin", "connection_details": map[string]interface{}{ "username": "postgres", - "connection_url": connURL, + "connection_url": templatedConnURL, }, - "allowed_roles": []string{"plugin-role-test"}, - "root_credentials_rotate_statements": []string(nil), + "allowed_roles": []any{"plugin-role-test"}, + "root_credentials_rotate_statements": []any{}, "password_policy": "", "plugin_version": "", } - req.Operation = logical.ReadOperation - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + resp, err = client.Read("database/config/plugin-test") + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } delete(resp.Data["connection_details"].(map[string]interface{}), "name") if diff := deep.Equal(resp.Data, expected); diff != nil { - t.Fatal(diff) + t.Fatal(strings.Join(diff, "\n")) } // Test endpoints for reloading plugins. - for _, reloadPath := range []string{ - "reset/plugin-test", - "reload/postgresql-database-plugin", + for _, reload := range []struct { + path string + data map[string]any + checkCount bool + }{ + {"database/reset/plugin-test", nil, false}, + {"database/reload/postgresql-database-plugin", nil, true}, + {"sys/plugins/reload/backend", map[string]any{ + "plugin": "postgresql-database-plugin", + }, false}, } { getConnectionID := func(name string) string { t.Helper() - dbBackend, ok := b.(*databaseBackend) - if !ok { - t.Fatal("could not convert logical.Backend to databaseBackend") - } - dbi := dbBackend.connections.Get(name) + dbi := dbFactory.db.connections.Get(name) if dbi == nil { t.Fatal("no plugin-test dbi") } @@ -753,14 +765,8 @@ func TestBackend_connectionCrud(t *testing.T) { } initialID := getConnectionID("plugin-test") hanaID := getConnectionID("plugin-test-hana") - req = &logical.Request{ - Operation: logical.UpdateOperation, - Path: reloadPath, - Storage: config.StorageView, - Data: map[string]interface{}{}, - } - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + resp, err = client.Write(reload.path, reload.data) + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } if initialID == getConnectionID("plugin-test") { @@ -769,54 +775,43 @@ func TestBackend_connectionCrud(t *testing.T) { if hanaID != getConnectionID("plugin-test-hana") { t.Fatal("hana plugin got restarted but shouldn't have been") } - if strings.HasPrefix(reloadPath, "reload/") { - if expected := 1; expected != resp.Data["count"] { - t.Fatalf("expected %d but got %d", expected, resp.Data["count"]) + if reload.checkCount { + actual, err := resp.Data["count"].(json.Number).Int64() + if err != nil { + t.Fatal(err) } - if expected := []string{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) { + if expected := 1; expected != int(actual) { + t.Fatalf("expected %d but got %d", expected, resp.Data["count"].(int)) + } + if expected := []any{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) { t.Fatalf("expected %v but got %v", expected, resp.Data["connections"]) } } } // Get creds - data = map[string]interface{}{} - req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "creds/plugin-role-test", - Storage: config.StorageView, - Data: data, - } - credsResp, err := b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (credsResp != nil && credsResp.IsError()) { + credsResp, err := client.Read("database/creds/plugin-role-test") + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, credsResp) } - credCheckURL := dbutil.QueryHelper(connURL, map[string]string{ + credCheckURL := dbutil.QueryHelper(templatedConnURL, map[string]string{ "username": "postgres", "password": "secret", }) - if !testCredsExist(t, credsResp, credCheckURL) { + if !testCredsExist(t, credsResp.Data, credCheckURL) { t.Fatalf("Creds should exist") } // Delete Connection - data = map[string]interface{}{} - req = &logical.Request{ - Operation: logical.DeleteOperation, - Path: "config/plugin-test", - Storage: config.StorageView, - Data: data, - } - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + resp, err = client.Delete("database/config/plugin-test") + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } // Read connection - req.Operation = logical.ReadOperation - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil || (resp != nil && resp.IsError()) { + resp, err = client.Read("database/config/plugin-test") + if err != nil { t.Fatalf("err:%s resp:%#v\n", err, resp) } @@ -1190,7 +1185,7 @@ func TestBackend_allowedRoles(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, credsResp) } - if !testCredsExist(t, credsResp, connURL) { + if !testCredsExist(t, credsResp.Data, connURL) { t.Fatalf("Creds should exist") } @@ -1224,7 +1219,7 @@ func TestBackend_allowedRoles(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, credsResp) } - if !testCredsExist(t, credsResp, connURL) { + if !testCredsExist(t, credsResp.Data, connURL) { t.Fatalf("Creds should exist") } @@ -1271,7 +1266,7 @@ func TestBackend_allowedRoles(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, credsResp) } - if !testCredsExist(t, credsResp, connURL) { + if !testCredsExist(t, credsResp.Data, connURL) { t.Fatalf("Creds should exist") } } @@ -1581,13 +1576,13 @@ func TestNewDatabaseWrapper_IgnoresBuiltinVersion(t *testing.T) { } } -func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool { +func testCredsExist(t *testing.T, data map[string]any, connURL string) bool { t.Helper() var d struct { Username string `mapstructure:"username"` Password string `mapstructure:"password"` } - if err := mapstructure.Decode(resp.Data, &d); err != nil { + if err := mapstructure.Decode(data, &d); err != nil { t.Fatal(err) } log.Printf("[TRACE] Generated credentials: %v", d) diff --git a/builtin/logical/database/versioning_large_test.go b/builtin/logical/database/versioning_large_test.go index bacb4a6a71..be936c7603 100644 --- a/builtin/logical/database/versioning_large_test.go +++ b/builtin/logical/database/versioning_large_test.go @@ -25,9 +25,10 @@ func TestPlugin_lifecycle(t *testing.T) { cluster, sys := getCluster(t) defer cluster.Cleanup() - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", []string{}) - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{}) - vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", []string{}) + env := []string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)} + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", env) + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", env) + vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", env) config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} diff --git a/builtin/plugin/backend_test.go b/builtin/plugin/backend_test.go index c1b7a83d15..7134440612 100644 --- a/builtin/plugin/backend_test.go +++ b/builtin/plugin/backend_test.go @@ -140,9 +140,8 @@ func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func()) }, } - os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd, []string{}) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd, + []string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)}) return config, func() { cluster.Cleanup() diff --git a/changelog/24512.txt b/changelog/24512.txt new file mode 100644 index 0000000000..efed04a225 --- /dev/null +++ b/changelog/24512.txt @@ -0,0 +1,6 @@ +```release-note:change +plugins: Add a warning to the response from sys/plugins/reload/backend if no plugins were reloaded. +``` +```release-note:improvement +secrets/database: Support reloading named database plugins using the sys/plugins/reload/backend API endpoint. +``` diff --git a/http/plugin_test.go b/http/plugin_test.go index fa67187621..b215a6b1c6 100644 --- a/http/plugin_test.go +++ b/http/plugin_test.go @@ -5,6 +5,7 @@ package http import ( "encoding/json" + "fmt" "io/ioutil" "os" "reflect" @@ -55,10 +56,9 @@ func getPluginClusterAndCore(t *testing.T, logger log.Logger) (*vault.TestCluste cores := cluster.Cores core := cores[0] - os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - vault.TestWaitActive(benchhelpers.TBtoT(t), core.Core) - vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain", []string{}) + vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain", + []string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)}) // Mount the mock plugin err = core.Client.Sys().Mount("mock", &api.MountInput{ diff --git a/vault/external_tests/plugin/plugin_test.go b/vault/external_tests/plugin/plugin_test.go index e29340a17c..affe8fa6b8 100644 --- a/vault/external_tests/plugin/plugin_test.go +++ b/vault/external_tests/plugin/plugin_test.go @@ -560,6 +560,9 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} if resp.Data["reload_id"] == nil { t.Fatal("no reload_id in response") } + if len(resp.Warnings) != 0 { + t.Fatal(resp.Warnings) + } for i := 0; i < 2; i++ { // Ensure internal backed value is reset @@ -578,6 +581,35 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} } } +func TestSystemBackend_PluginReload_WarningIfNoneReloaded(t *testing.T) { + cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical, "v5") + defer cluster.Cleanup() + + core := cluster.Cores[0] + client := core.Client + + for _, backendType := range []logical.BackendType{logical.TypeLogical, logical.TypeCredential} { + t.Run(backendType.String(), func(t *testing.T) { + // Perform plugin reload + resp, err := client.Logical().Write("sys/plugins/reload/backend", map[string]any{ + "plugin": "does-not-exist", + }) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: %v", resp) + } + if resp.Data["reload_id"] == nil { + t.Fatal("no reload_id in response") + } + if len(resp.Warnings) == 0 { + t.Fatal("expected warning") + } + }) + } +} + // testSystemBackendMock returns a systemBackend with the desired number // of mounted mock plugin backends. numMounts alternates between different // ways of providing the plugin_name. diff --git a/vault/logical_system.go b/vault/logical_system.go index 1ae623546e..6215954850 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -738,11 +738,24 @@ func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logic return logical.ErrorResponse("plugin or mounts must be provided"), nil } + resp := logical.Response{ + Data: map[string]interface{}{ + "reload_id": req.ID, + }, + } + if pluginName != "" { - err := b.Core.reloadMatchingPlugin(ctx, pluginName) + reloaded, err := b.Core.reloadMatchingPlugin(ctx, pluginName) if err != nil { return nil, err } + if reloaded == 0 { + if scope == globalScope { + resp.AddWarning("no plugins were reloaded locally (but they may be reloaded on other nodes)") + } else { + resp.AddWarning("no plugins were reloaded") + } + } } else if len(pluginMounts) > 0 { err := b.Core.reloadMatchingPluginMounts(ctx, pluginMounts) if err != nil { @@ -750,20 +763,14 @@ func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logic } } - r := logical.Response{ - Data: map[string]interface{}{ - "reload_id": req.ID, - }, - } - if scope == globalScope { err := handleGlobalPluginReload(ctx, b.Core, req.ID, pluginName, pluginMounts) if err != nil { return nil, err } - return logical.RespondWithStatusCode(&r, req, http.StatusAccepted) + return logical.RespondWithStatusCode(&resp, req, http.StatusAccepted) } - return &r, nil + return &resp, nil } func (b *SystemBackend) handlePluginRuntimeCatalogUpdate(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) { diff --git a/vault/mount.go b/vault/mount.go index 158dbd9356..d70e819360 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -322,7 +322,7 @@ const mountStateUnmounting = "unmounting" // MountEntry is used to represent a mount table entry type MountEntry struct { Table string `json:"table"` // The table it belongs to - Path string `json:"path"` // Mount Path + Path string `json:"path"` // Mount Path, as provided in the mount API call but with a trailing slash, i.e. no auth/ or namespace prefix. Type string `json:"type"` // Logical backend Type. NB: This is the plugin name, e.g. my-vault-plugin, NOT plugin type (e.g. auth). Description string `json:"description"` // User-provided description UUID string `json:"uuid"` // Barrier view UUID diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go index 7fa6e936e6..938c47eb34 100644 --- a/vault/plugin_reload.go +++ b/vault/plugin_reload.go @@ -70,10 +70,10 @@ func (c *Core) reloadMatchingPluginMounts(ctx context.Context, mounts []string) return errors } -// reloadPlugin reloads all mounted backends that are of -// plugin pluginName (name of the plugin as registered in -// the plugin catalog). -func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) error { +// reloadMatchingPlugin reloads all mounted backends that are named pluginName +// (name of the plugin as registered in the plugin catalog). It returns the +// number of plugins that were reloaded and an error if any. +func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) (reloaded int, err error) { c.mountsLock.RLock() defer c.mountsLock.RUnlock() c.authLock.RLock() @@ -81,25 +81,49 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro ns, err := namespace.FromContext(ctx) if err != nil { - return err + return reloaded, err } - // Filter mount entries that only matches the plugin name for _, entry := range c.mounts.Entries { // We dont reload mounts that are not in the same namespace if ns.ID != entry.Namespace().ID { continue } + if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) { err := c.reloadBackendCommon(ctx, entry, false) if err != nil { - return err + return reloaded, err + } + reloaded++ + c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.Version) + } else if entry.Type == "database" { + // The combined database plugin is itself a secrets engine, but + // knowledge of whether a database plugin is in use within a particular + // mount is internal to the combined database plugin's storage, so + // we delegate the reload request with an internally routed request. + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: entry.Path + "reload/" + pluginName, + } + resp, err := c.router.Route(ctx, req) + if err != nil { + return reloaded, err + } + if resp == nil { + return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s", pluginName, entry.Path) + } + if resp.IsError() { + return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s: %s", pluginName, entry.Path, resp.Error()) + } + + if count, ok := resp.Data["count"].(int); ok && count > 0 { + c.logger.Info("successfully reloaded database plugin(s)", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "connections", resp.Data["connections"]) + reloaded += count } - c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "path", entry.Path, "version", entry.Version) } } - // Filter auth mount entries that ony matches the plugin name for _, entry := range c.auth.Entries { // We dont reload mounts that are not in the same namespace if ns.ID != entry.Namespace().ID { @@ -109,13 +133,14 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) { err := c.reloadBackendCommon(ctx, entry, true) if err != nil { - return err + return reloaded, err } + reloaded++ c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version) } } - return nil + return reloaded, nil } // reloadBackendCommon is a generic method to reload a backend provided a