mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 09:42:25 +00:00
Support reloading database plugins across multiple mounts (#24512)
* Support reloading database plugins across multiple mounts * Add clarifying comment to MountEntry.Path field * Tests: Replace non-parallelisable t.Setenv with plugin env settings
This commit is contained in:
@@ -6,12 +6,15 @@ package database
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -35,12 +38,26 @@ import (
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
func getClusterPostgresDBWithFactory(t *testing.T, factory logical.Factory) (*vault.TestCluster, logical.SystemView) {
|
||||
t.Helper()
|
||||
cluster, sys := getClusterWithFactory(t, factory)
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed",
|
||||
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})
|
||||
return cluster, sys
|
||||
}
|
||||
|
||||
func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
||||
t.Helper()
|
||||
cluster, sys := getClusterPostgresDBWithFactory(t, Factory)
|
||||
return cluster, sys
|
||||
}
|
||||
|
||||
func getClusterWithFactory(t *testing.T, factory logical.Factory) (*vault.TestCluster, logical.SystemView) {
|
||||
t.Helper()
|
||||
pluginDir := corehelpers.MakeTestPluginDir(t)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"database": Factory,
|
||||
"database": factory,
|
||||
},
|
||||
BuiltinRegistry: builtinplugins.Registry,
|
||||
PluginDirectory: pluginDir,
|
||||
@@ -53,36 +70,14 @@ func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView)
|
||||
cores := cluster.Cores
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
||||
|
||||
sys := vault.TestDynamicSystemView(cores[0].Core, nil)
|
||||
vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed", []string{})
|
||||
|
||||
return cluster, sys
|
||||
}
|
||||
|
||||
func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
||||
t.Helper()
|
||||
pluginDir := corehelpers.MakeTestPluginDir(t)
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"database": Factory,
|
||||
},
|
||||
BuiltinRegistry: builtinplugins.Registry,
|
||||
PluginDirectory: pluginDir,
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
cluster.Start()
|
||||
cores := cluster.Cores
|
||||
vault.TestWaitActive(t, cores[0].Core)
|
||||
|
||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
||||
|
||||
sys := vault.TestDynamicSystemView(cores[0].Core, nil)
|
||||
|
||||
cluster, sys := getClusterWithFactory(t, Factory)
|
||||
return cluster, sys
|
||||
}
|
||||
|
||||
@@ -515,7 +510,7 @@ func TestBackend_basic(t *testing.T) {
|
||||
if credsResp.Secret.TTL != 5*time.Minute {
|
||||
t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL)
|
||||
}
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
@@ -535,7 +530,7 @@ func TestBackend_basic(t *testing.T) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
if testCredsExist(t, credsResp, connURL) {
|
||||
if testCredsExist(t, credsResp.Data, connURL) {
|
||||
t.Fatalf("Creds should not exist")
|
||||
}
|
||||
}
|
||||
@@ -553,7 +548,7 @@ func TestBackend_basic(t *testing.T) {
|
||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
@@ -586,108 +581,118 @@ func TestBackend_basic(t *testing.T) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
if testCredsExist(t, credsResp, connURL) {
|
||||
if testCredsExist(t, credsResp.Data, connURL) {
|
||||
t.Fatalf("Creds should not exist")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_connectionCrud(t *testing.T) {
|
||||
cluster, sys := getClusterPostgresDB(t)
|
||||
defer cluster.Cleanup()
|
||||
// singletonDBFactory allows us to reach into the internals of a databaseBackend
|
||||
// even when it's been created by a call to the sys mount. The factory method
|
||||
// satisfies the logical.Factory type, and lazily creates the databaseBackend
|
||||
// once the SystemView has been provided because the factory method itself is an
|
||||
// input for creating the test cluster and its system view.
|
||||
type singletonDBFactory struct {
|
||||
once sync.Once
|
||||
db *databaseBackend
|
||||
|
||||
sys logical.SystemView
|
||||
}
|
||||
|
||||
// factory satisfies the logical.Factory type.
|
||||
func (s *singletonDBFactory) factory(context.Context, *logical.BackendConfig) (logical.Backend, error) {
|
||||
if s.sys == nil {
|
||||
return nil, errors.New("sys is nil")
|
||||
}
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
config.System = s.sys
|
||||
|
||||
b, err := Factory(context.Background(), config)
|
||||
var err error
|
||||
s.once.Do(func() {
|
||||
var b logical.Backend
|
||||
b, err = Factory(context.Background(), config)
|
||||
s.db = b.(*databaseBackend)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil, err
|
||||
}
|
||||
defer b.Cleanup(context.Background())
|
||||
if s.db == nil {
|
||||
return nil, errors.New("db is nil")
|
||||
}
|
||||
return s.db, nil
|
||||
}
|
||||
|
||||
func TestBackend_connectionCrud(t *testing.T) {
|
||||
dbFactory := &singletonDBFactory{}
|
||||
cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
dbFactory.sys = sys
|
||||
client := cluster.Cores[0].Client.Logical()
|
||||
|
||||
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "13.4-buster")
|
||||
defer cleanup()
|
||||
|
||||
// Mount the database plugin.
|
||||
resp, err := client.Write("sys/mounts/database", map[string]interface{}{
|
||||
"type": "database",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Configure a connection
|
||||
data := map[string]interface{}{
|
||||
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
|
||||
"connection_url": "test",
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"verify_connection": false,
|
||||
}
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Configure a second connection to confirm below it doesn't get restarted.
|
||||
data = map[string]interface{}{
|
||||
resp, err = client.Write("database/config/plugin-test-hana", map[string]interface{}{
|
||||
"connection_url": "test",
|
||||
"plugin_name": "hana-database-plugin",
|
||||
"verify_connection": false,
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test-hana",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Create a role
|
||||
data = map[string]interface{}{
|
||||
resp, err = client.Write("database/roles/plugin-role-test", map[string]interface{}{
|
||||
"db_name": "plugin-test",
|
||||
"creation_statements": testRole,
|
||||
"revocation_statements": defaultRevocationSQL,
|
||||
"default_ttl": "5m",
|
||||
"max_ttl": "10m",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "roles/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Update the connection
|
||||
data = map[string]interface{}{
|
||||
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"allowed_roles": []string{"plugin-role-test"},
|
||||
"username": "postgres",
|
||||
"password": "secret",
|
||||
"private_key": "PRIVATE_KEY",
|
||||
}
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
if len(resp.Warnings) == 0 {
|
||||
t.Fatalf("expected warning about password in url %s, resp:%#v\n", connURL, resp)
|
||||
}
|
||||
|
||||
req.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
resp, err = client.Read("database/config/plugin-test")
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{})
|
||||
@@ -703,11 +708,16 @@ func TestBackend_connectionCrud(t *testing.T) {
|
||||
}
|
||||
|
||||
// Replace connection url with templated version
|
||||
req.Operation = logical.UpdateOperation
|
||||
connURL = strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}")
|
||||
data["connection_url"] = connURL
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
templatedConnURL := strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}")
|
||||
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
|
||||
"connection_url": templatedConnURL,
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"allowed_roles": []string{"plugin-role-test"},
|
||||
"username": "postgres",
|
||||
"password": "secret",
|
||||
"private_key": "PRIVATE_KEY",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
@@ -716,36 +726,38 @@ func TestBackend_connectionCrud(t *testing.T) {
|
||||
"plugin_name": "postgresql-database-plugin",
|
||||
"connection_details": map[string]interface{}{
|
||||
"username": "postgres",
|
||||
"connection_url": connURL,
|
||||
"connection_url": templatedConnURL,
|
||||
},
|
||||
"allowed_roles": []string{"plugin-role-test"},
|
||||
"root_credentials_rotate_statements": []string(nil),
|
||||
"allowed_roles": []any{"plugin-role-test"},
|
||||
"root_credentials_rotate_statements": []any{},
|
||||
"password_policy": "",
|
||||
"plugin_version": "",
|
||||
}
|
||||
req.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
resp, err = client.Read("database/config/plugin-test")
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
|
||||
if diff := deep.Equal(resp.Data, expected); diff != nil {
|
||||
t.Fatal(diff)
|
||||
t.Fatal(strings.Join(diff, "\n"))
|
||||
}
|
||||
|
||||
// Test endpoints for reloading plugins.
|
||||
for _, reloadPath := range []string{
|
||||
"reset/plugin-test",
|
||||
"reload/postgresql-database-plugin",
|
||||
for _, reload := range []struct {
|
||||
path string
|
||||
data map[string]any
|
||||
checkCount bool
|
||||
}{
|
||||
{"database/reset/plugin-test", nil, false},
|
||||
{"database/reload/postgresql-database-plugin", nil, true},
|
||||
{"sys/plugins/reload/backend", map[string]any{
|
||||
"plugin": "postgresql-database-plugin",
|
||||
}, false},
|
||||
} {
|
||||
getConnectionID := func(name string) string {
|
||||
t.Helper()
|
||||
dbBackend, ok := b.(*databaseBackend)
|
||||
if !ok {
|
||||
t.Fatal("could not convert logical.Backend to databaseBackend")
|
||||
}
|
||||
dbi := dbBackend.connections.Get(name)
|
||||
dbi := dbFactory.db.connections.Get(name)
|
||||
if dbi == nil {
|
||||
t.Fatal("no plugin-test dbi")
|
||||
}
|
||||
@@ -753,14 +765,8 @@ func TestBackend_connectionCrud(t *testing.T) {
|
||||
}
|
||||
initialID := getConnectionID("plugin-test")
|
||||
hanaID := getConnectionID("plugin-test-hana")
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: reloadPath,
|
||||
Storage: config.StorageView,
|
||||
Data: map[string]interface{}{},
|
||||
}
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
resp, err = client.Write(reload.path, reload.data)
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
if initialID == getConnectionID("plugin-test") {
|
||||
@@ -769,54 +775,43 @@ func TestBackend_connectionCrud(t *testing.T) {
|
||||
if hanaID != getConnectionID("plugin-test-hana") {
|
||||
t.Fatal("hana plugin got restarted but shouldn't have been")
|
||||
}
|
||||
if strings.HasPrefix(reloadPath, "reload/") {
|
||||
if expected := 1; expected != resp.Data["count"] {
|
||||
t.Fatalf("expected %d but got %d", expected, resp.Data["count"])
|
||||
if reload.checkCount {
|
||||
actual, err := resp.Data["count"].(json.Number).Int64()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if expected := []string{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) {
|
||||
if expected := 1; expected != int(actual) {
|
||||
t.Fatalf("expected %d but got %d", expected, resp.Data["count"].(int))
|
||||
}
|
||||
if expected := []any{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) {
|
||||
t.Fatalf("expected %v but got %v", expected, resp.Data["connections"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get creds
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "creds/plugin-role-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
credsResp, err := b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
||||
credsResp, err := client.Read("database/creds/plugin-role-test")
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
credCheckURL := dbutil.QueryHelper(connURL, map[string]string{
|
||||
credCheckURL := dbutil.QueryHelper(templatedConnURL, map[string]string{
|
||||
"username": "postgres",
|
||||
"password": "secret",
|
||||
})
|
||||
if !testCredsExist(t, credsResp, credCheckURL) {
|
||||
if !testCredsExist(t, credsResp.Data, credCheckURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
// Delete Connection
|
||||
data = map[string]interface{}{}
|
||||
req = &logical.Request{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "config/plugin-test",
|
||||
Storage: config.StorageView,
|
||||
Data: data,
|
||||
}
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
resp, err = client.Delete("database/config/plugin-test")
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
// Read connection
|
||||
req.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
resp, err = client.Read("database/config/plugin-test")
|
||||
if err != nil {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
@@ -1190,7 +1185,7 @@ func TestBackend_allowedRoles(t *testing.T) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
@@ -1224,7 +1219,7 @@ func TestBackend_allowedRoles(t *testing.T) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
|
||||
@@ -1271,7 +1266,7 @@ func TestBackend_allowedRoles(t *testing.T) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||
}
|
||||
|
||||
if !testCredsExist(t, credsResp, connURL) {
|
||||
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||
t.Fatalf("Creds should exist")
|
||||
}
|
||||
}
|
||||
@@ -1581,13 +1576,13 @@ func TestNewDatabaseWrapper_IgnoresBuiltinVersion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool {
|
||||
func testCredsExist(t *testing.T, data map[string]any, connURL string) bool {
|
||||
t.Helper()
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
if err := mapstructure.Decode(data, &d); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||
|
||||
@@ -25,9 +25,10 @@ func TestPlugin_lifecycle(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", []string{})
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{})
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", []string{})
|
||||
env := []string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)}
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", env)
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", env)
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", env)
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
||||
@@ -140,9 +140,8 @@ func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func())
|
||||
},
|
||||
}
|
||||
|
||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
||||
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd, []string{})
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd,
|
||||
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})
|
||||
|
||||
return config, func() {
|
||||
cluster.Cleanup()
|
||||
|
||||
6
changelog/24512.txt
Normal file
6
changelog/24512.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
```release-note:change
|
||||
plugins: Add a warning to the response from sys/plugins/reload/backend if no plugins were reloaded.
|
||||
```
|
||||
```release-note:improvement
|
||||
secrets/database: Support reloading named database plugins using the sys/plugins/reload/backend API endpoint.
|
||||
```
|
||||
@@ -5,6 +5,7 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -55,10 +56,9 @@ func getPluginClusterAndCore(t *testing.T, logger log.Logger) (*vault.TestCluste
|
||||
cores := cluster.Cores
|
||||
core := cores[0]
|
||||
|
||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
||||
|
||||
vault.TestWaitActive(benchhelpers.TBtoT(t), core.Core)
|
||||
vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain", []string{})
|
||||
vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain",
|
||||
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})
|
||||
|
||||
// Mount the mock plugin
|
||||
err = core.Client.Sys().Mount("mock", &api.MountInput{
|
||||
|
||||
@@ -560,6 +560,9 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
||||
if resp.Data["reload_id"] == nil {
|
||||
t.Fatal("no reload_id in response")
|
||||
}
|
||||
if len(resp.Warnings) != 0 {
|
||||
t.Fatal(resp.Warnings)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
// Ensure internal backed value is reset
|
||||
@@ -578,6 +581,35 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemBackend_PluginReload_WarningIfNoneReloaded(t *testing.T) {
|
||||
cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical, "v5")
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
client := core.Client
|
||||
|
||||
for _, backendType := range []logical.BackendType{logical.TypeLogical, logical.TypeCredential} {
|
||||
t.Run(backendType.String(), func(t *testing.T) {
|
||||
// Perform plugin reload
|
||||
resp, err := client.Logical().Write("sys/plugins/reload/backend", map[string]any{
|
||||
"plugin": "does-not-exist",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
if resp.Data["reload_id"] == nil {
|
||||
t.Fatal("no reload_id in response")
|
||||
}
|
||||
if len(resp.Warnings) == 0 {
|
||||
t.Fatal("expected warning")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testSystemBackendMock returns a systemBackend with the desired number
|
||||
// of mounted mock plugin backends. numMounts alternates between different
|
||||
// ways of providing the plugin_name.
|
||||
|
||||
@@ -738,11 +738,24 @@ func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logic
|
||||
return logical.ErrorResponse("plugin or mounts must be provided"), nil
|
||||
}
|
||||
|
||||
resp := logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"reload_id": req.ID,
|
||||
},
|
||||
}
|
||||
|
||||
if pluginName != "" {
|
||||
err := b.Core.reloadMatchingPlugin(ctx, pluginName)
|
||||
reloaded, err := b.Core.reloadMatchingPlugin(ctx, pluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reloaded == 0 {
|
||||
if scope == globalScope {
|
||||
resp.AddWarning("no plugins were reloaded locally (but they may be reloaded on other nodes)")
|
||||
} else {
|
||||
resp.AddWarning("no plugins were reloaded")
|
||||
}
|
||||
}
|
||||
} else if len(pluginMounts) > 0 {
|
||||
err := b.Core.reloadMatchingPluginMounts(ctx, pluginMounts)
|
||||
if err != nil {
|
||||
@@ -750,20 +763,14 @@ func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logic
|
||||
}
|
||||
}
|
||||
|
||||
r := logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"reload_id": req.ID,
|
||||
},
|
||||
}
|
||||
|
||||
if scope == globalScope {
|
||||
err := handleGlobalPluginReload(ctx, b.Core, req.ID, pluginName, pluginMounts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logical.RespondWithStatusCode(&r, req, http.StatusAccepted)
|
||||
return logical.RespondWithStatusCode(&resp, req, http.StatusAccepted)
|
||||
}
|
||||
return &r, nil
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handlePluginRuntimeCatalogUpdate(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
|
||||
@@ -322,7 +322,7 @@ const mountStateUnmounting = "unmounting"
|
||||
// MountEntry is used to represent a mount table entry
|
||||
type MountEntry struct {
|
||||
Table string `json:"table"` // The table it belongs to
|
||||
Path string `json:"path"` // Mount Path
|
||||
Path string `json:"path"` // Mount Path, as provided in the mount API call but with a trailing slash, i.e. no auth/ or namespace prefix.
|
||||
Type string `json:"type"` // Logical backend Type. NB: This is the plugin name, e.g. my-vault-plugin, NOT plugin type (e.g. auth).
|
||||
Description string `json:"description"` // User-provided description
|
||||
UUID string `json:"uuid"` // Barrier view UUID
|
||||
|
||||
@@ -70,10 +70,10 @@ func (c *Core) reloadMatchingPluginMounts(ctx context.Context, mounts []string)
|
||||
return errors
|
||||
}
|
||||
|
||||
// reloadPlugin reloads all mounted backends that are of
|
||||
// plugin pluginName (name of the plugin as registered in
|
||||
// the plugin catalog).
|
||||
func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) error {
|
||||
// reloadMatchingPlugin reloads all mounted backends that are named pluginName
|
||||
// (name of the plugin as registered in the plugin catalog). It returns the
|
||||
// number of plugins that were reloaded and an error if any.
|
||||
func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) (reloaded int, err error) {
|
||||
c.mountsLock.RLock()
|
||||
defer c.mountsLock.RUnlock()
|
||||
c.authLock.RLock()
|
||||
@@ -81,25 +81,49 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro
|
||||
|
||||
ns, err := namespace.FromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return reloaded, err
|
||||
}
|
||||
|
||||
// Filter mount entries that only matches the plugin name
|
||||
for _, entry := range c.mounts.Entries {
|
||||
// We dont reload mounts that are not in the same namespace
|
||||
if ns.ID != entry.Namespace().ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
|
||||
err := c.reloadBackendCommon(ctx, entry, false)
|
||||
if err != nil {
|
||||
return err
|
||||
return reloaded, err
|
||||
}
|
||||
reloaded++
|
||||
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.Version)
|
||||
} else if entry.Type == "database" {
|
||||
// The combined database plugin is itself a secrets engine, but
|
||||
// knowledge of whether a database plugin is in use within a particular
|
||||
// mount is internal to the combined database plugin's storage, so
|
||||
// we delegate the reload request with an internally routed request.
|
||||
req := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: entry.Path + "reload/" + pluginName,
|
||||
}
|
||||
resp, err := c.router.Route(ctx, req)
|
||||
if err != nil {
|
||||
return reloaded, err
|
||||
}
|
||||
if resp == nil {
|
||||
return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s", pluginName, entry.Path)
|
||||
}
|
||||
if resp.IsError() {
|
||||
return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s: %s", pluginName, entry.Path, resp.Error())
|
||||
}
|
||||
|
||||
if count, ok := resp.Data["count"].(int); ok && count > 0 {
|
||||
c.logger.Info("successfully reloaded database plugin(s)", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "connections", resp.Data["connections"])
|
||||
reloaded += count
|
||||
}
|
||||
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "path", entry.Path, "version", entry.Version)
|
||||
}
|
||||
}
|
||||
|
||||
// Filter auth mount entries that ony matches the plugin name
|
||||
for _, entry := range c.auth.Entries {
|
||||
// We dont reload mounts that are not in the same namespace
|
||||
if ns.ID != entry.Namespace().ID {
|
||||
@@ -109,13 +133,14 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro
|
||||
if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
|
||||
err := c.reloadBackendCommon(ctx, entry, true)
|
||||
if err != nil {
|
||||
return err
|
||||
return reloaded, err
|
||||
}
|
||||
reloaded++
|
||||
c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return reloaded, nil
|
||||
}
|
||||
|
||||
// reloadBackendCommon is a generic method to reload a backend provided a
|
||||
|
||||
Reference in New Issue
Block a user