mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Support reloading database plugins across multiple mounts (#24512)
* Support reloading database plugins across multiple mounts * Add clarifying comment to MountEntry.Path field * Tests: Replace non-parallelisable t.Setenv with plugin env settings
This commit is contained in:
@@ -6,12 +6,15 @@ package database
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -35,12 +38,26 @@ import (
|
|||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getClusterPostgresDBWithFactory(t *testing.T, factory logical.Factory) (*vault.TestCluster, logical.SystemView) {
|
||||||
|
t.Helper()
|
||||||
|
cluster, sys := getClusterWithFactory(t, factory)
|
||||||
|
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed",
|
||||||
|
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})
|
||||||
|
return cluster, sys
|
||||||
|
}
|
||||||
|
|
||||||
func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
||||||
|
t.Helper()
|
||||||
|
cluster, sys := getClusterPostgresDBWithFactory(t, Factory)
|
||||||
|
return cluster, sys
|
||||||
|
}
|
||||||
|
|
||||||
|
func getClusterWithFactory(t *testing.T, factory logical.Factory) (*vault.TestCluster, logical.SystemView) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
pluginDir := corehelpers.MakeTestPluginDir(t)
|
pluginDir := corehelpers.MakeTestPluginDir(t)
|
||||||
coreConfig := &vault.CoreConfig{
|
coreConfig := &vault.CoreConfig{
|
||||||
LogicalBackends: map[string]logical.Factory{
|
LogicalBackends: map[string]logical.Factory{
|
||||||
"database": Factory,
|
"database": factory,
|
||||||
},
|
},
|
||||||
BuiltinRegistry: builtinplugins.Registry,
|
BuiltinRegistry: builtinplugins.Registry,
|
||||||
PluginDirectory: pluginDir,
|
PluginDirectory: pluginDir,
|
||||||
@@ -53,36 +70,14 @@ func getClusterPostgresDB(t *testing.T) (*vault.TestCluster, logical.SystemView)
|
|||||||
cores := cluster.Cores
|
cores := cluster.Cores
|
||||||
vault.TestWaitActive(t, cores[0].Core)
|
vault.TestWaitActive(t, cores[0].Core)
|
||||||
|
|
||||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
|
||||||
|
|
||||||
sys := vault.TestDynamicSystemView(cores[0].Core, nil)
|
sys := vault.TestDynamicSystemView(cores[0].Core, nil)
|
||||||
vault.TestAddTestPlugin(t, cores[0].Core, "postgresql-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_PostgresMultiplexed", []string{})
|
|
||||||
|
|
||||||
return cluster, sys
|
return cluster, sys
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
pluginDir := corehelpers.MakeTestPluginDir(t)
|
cluster, sys := getClusterWithFactory(t, Factory)
|
||||||
coreConfig := &vault.CoreConfig{
|
|
||||||
LogicalBackends: map[string]logical.Factory{
|
|
||||||
"database": Factory,
|
|
||||||
},
|
|
||||||
BuiltinRegistry: builtinplugins.Registry,
|
|
||||||
PluginDirectory: pluginDir,
|
|
||||||
}
|
|
||||||
|
|
||||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
|
||||||
HandlerFunc: vaulthttp.Handler,
|
|
||||||
})
|
|
||||||
cluster.Start()
|
|
||||||
cores := cluster.Cores
|
|
||||||
vault.TestWaitActive(t, cores[0].Core)
|
|
||||||
|
|
||||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
|
||||||
|
|
||||||
sys := vault.TestDynamicSystemView(cores[0].Core, nil)
|
|
||||||
|
|
||||||
return cluster, sys
|
return cluster, sys
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -515,7 +510,7 @@ func TestBackend_basic(t *testing.T) {
|
|||||||
if credsResp.Secret.TTL != 5*time.Minute {
|
if credsResp.Secret.TTL != 5*time.Minute {
|
||||||
t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL)
|
t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL)
|
||||||
}
|
}
|
||||||
if !testCredsExist(t, credsResp, connURL) {
|
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||||
t.Fatalf("Creds should exist")
|
t.Fatalf("Creds should exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,7 +530,7 @@ func TestBackend_basic(t *testing.T) {
|
|||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if testCredsExist(t, credsResp, connURL) {
|
if testCredsExist(t, credsResp.Data, connURL) {
|
||||||
t.Fatalf("Creds should not exist")
|
t.Fatalf("Creds should not exist")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -553,7 +548,7 @@ func TestBackend_basic(t *testing.T) {
|
|||||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||||
}
|
}
|
||||||
if !testCredsExist(t, credsResp, connURL) {
|
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||||
t.Fatalf("Creds should exist")
|
t.Fatalf("Creds should exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -586,108 +581,118 @@ func TestBackend_basic(t *testing.T) {
|
|||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if testCredsExist(t, credsResp, connURL) {
|
if testCredsExist(t, credsResp.Data, connURL) {
|
||||||
t.Fatalf("Creds should not exist")
|
t.Fatalf("Creds should not exist")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBackend_connectionCrud(t *testing.T) {
|
// singletonDBFactory allows us to reach into the internals of a databaseBackend
|
||||||
cluster, sys := getClusterPostgresDB(t)
|
// even when it's been created by a call to the sys mount. The factory method
|
||||||
defer cluster.Cleanup()
|
// satisfies the logical.Factory type, and lazily creates the databaseBackend
|
||||||
|
// once the SystemView has been provided because the factory method itself is an
|
||||||
|
// input for creating the test cluster and its system view.
|
||||||
|
type singletonDBFactory struct {
|
||||||
|
once sync.Once
|
||||||
|
db *databaseBackend
|
||||||
|
|
||||||
|
sys logical.SystemView
|
||||||
|
}
|
||||||
|
|
||||||
|
// factory satisfies the logical.Factory type.
|
||||||
|
func (s *singletonDBFactory) factory(context.Context, *logical.BackendConfig) (logical.Backend, error) {
|
||||||
|
if s.sys == nil {
|
||||||
|
return nil, errors.New("sys is nil")
|
||||||
|
}
|
||||||
|
|
||||||
config := logical.TestBackendConfig()
|
config := logical.TestBackendConfig()
|
||||||
config.StorageView = &logical.InmemStorage{}
|
config.StorageView = &logical.InmemStorage{}
|
||||||
config.System = sys
|
config.System = s.sys
|
||||||
|
|
||||||
b, err := Factory(context.Background(), config)
|
var err error
|
||||||
|
s.once.Do(func() {
|
||||||
|
var b logical.Backend
|
||||||
|
b, err = Factory(context.Background(), config)
|
||||||
|
s.db = b.(*databaseBackend)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer b.Cleanup(context.Background())
|
if s.db == nil {
|
||||||
|
return nil, errors.New("db is nil")
|
||||||
|
}
|
||||||
|
return s.db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackend_connectionCrud(t *testing.T) {
|
||||||
|
dbFactory := &singletonDBFactory{}
|
||||||
|
cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory)
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
dbFactory.sys = sys
|
||||||
|
client := cluster.Cores[0].Client.Logical()
|
||||||
|
|
||||||
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "13.4-buster")
|
cleanup, connURL := postgreshelper.PrepareTestContainer(t, "13.4-buster")
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
|
// Mount the database plugin.
|
||||||
|
resp, err := client.Write("sys/mounts/database", map[string]interface{}{
|
||||||
|
"type": "database",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
|
}
|
||||||
|
|
||||||
// Configure a connection
|
// Configure a connection
|
||||||
data := map[string]interface{}{
|
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
|
||||||
"connection_url": "test",
|
"connection_url": "test",
|
||||||
"plugin_name": "postgresql-database-plugin",
|
"plugin_name": "postgresql-database-plugin",
|
||||||
"verify_connection": false,
|
"verify_connection": false,
|
||||||
}
|
})
|
||||||
req := &logical.Request{
|
if err != nil {
|
||||||
Operation: logical.UpdateOperation,
|
|
||||||
Path: "config/plugin-test",
|
|
||||||
Storage: config.StorageView,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
resp, err := b.HandleRequest(namespace.RootContext(nil), req)
|
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure a second connection to confirm below it doesn't get restarted.
|
// Configure a second connection to confirm below it doesn't get restarted.
|
||||||
data = map[string]interface{}{
|
resp, err = client.Write("database/config/plugin-test-hana", map[string]interface{}{
|
||||||
"connection_url": "test",
|
"connection_url": "test",
|
||||||
"plugin_name": "hana-database-plugin",
|
"plugin_name": "hana-database-plugin",
|
||||||
"verify_connection": false,
|
"verify_connection": false,
|
||||||
}
|
})
|
||||||
req = &logical.Request{
|
if err != nil {
|
||||||
Operation: logical.UpdateOperation,
|
|
||||||
Path: "config/plugin-test-hana",
|
|
||||||
Storage: config.StorageView,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a role
|
// Create a role
|
||||||
data = map[string]interface{}{
|
resp, err = client.Write("database/roles/plugin-role-test", map[string]interface{}{
|
||||||
"db_name": "plugin-test",
|
"db_name": "plugin-test",
|
||||||
"creation_statements": testRole,
|
"creation_statements": testRole,
|
||||||
"revocation_statements": defaultRevocationSQL,
|
"revocation_statements": defaultRevocationSQL,
|
||||||
"default_ttl": "5m",
|
"default_ttl": "5m",
|
||||||
"max_ttl": "10m",
|
"max_ttl": "10m",
|
||||||
}
|
})
|
||||||
req = &logical.Request{
|
if err != nil {
|
||||||
Operation: logical.UpdateOperation,
|
|
||||||
Path: "roles/plugin-role-test",
|
|
||||||
Storage: config.StorageView,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the connection
|
// Update the connection
|
||||||
data = map[string]interface{}{
|
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
|
||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
"plugin_name": "postgresql-database-plugin",
|
"plugin_name": "postgresql-database-plugin",
|
||||||
"allowed_roles": []string{"plugin-role-test"},
|
"allowed_roles": []string{"plugin-role-test"},
|
||||||
"username": "postgres",
|
"username": "postgres",
|
||||||
"password": "secret",
|
"password": "secret",
|
||||||
"private_key": "PRIVATE_KEY",
|
"private_key": "PRIVATE_KEY",
|
||||||
}
|
})
|
||||||
req = &logical.Request{
|
if err != nil {
|
||||||
Operation: logical.UpdateOperation,
|
|
||||||
Path: "config/plugin-test",
|
|
||||||
Storage: config.StorageView,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
if len(resp.Warnings) == 0 {
|
if len(resp.Warnings) == 0 {
|
||||||
t.Fatalf("expected warning about password in url %s, resp:%#v\n", connURL, resp)
|
t.Fatalf("expected warning about password in url %s, resp:%#v\n", connURL, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Operation = logical.ReadOperation
|
resp, err = client.Read("database/config/plugin-test")
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
if err != nil {
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{})
|
returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{})
|
||||||
@@ -703,11 +708,16 @@ func TestBackend_connectionCrud(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Replace connection url with templated version
|
// Replace connection url with templated version
|
||||||
req.Operation = logical.UpdateOperation
|
templatedConnURL := strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}")
|
||||||
connURL = strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}")
|
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
|
||||||
data["connection_url"] = connURL
|
"connection_url": templatedConnURL,
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
"plugin_name": "postgresql-database-plugin",
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
"allowed_roles": []string{"plugin-role-test"},
|
||||||
|
"username": "postgres",
|
||||||
|
"password": "secret",
|
||||||
|
"private_key": "PRIVATE_KEY",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -716,36 +726,38 @@ func TestBackend_connectionCrud(t *testing.T) {
|
|||||||
"plugin_name": "postgresql-database-plugin",
|
"plugin_name": "postgresql-database-plugin",
|
||||||
"connection_details": map[string]interface{}{
|
"connection_details": map[string]interface{}{
|
||||||
"username": "postgres",
|
"username": "postgres",
|
||||||
"connection_url": connURL,
|
"connection_url": templatedConnURL,
|
||||||
},
|
},
|
||||||
"allowed_roles": []string{"plugin-role-test"},
|
"allowed_roles": []any{"plugin-role-test"},
|
||||||
"root_credentials_rotate_statements": []string(nil),
|
"root_credentials_rotate_statements": []any{},
|
||||||
"password_policy": "",
|
"password_policy": "",
|
||||||
"plugin_version": "",
|
"plugin_version": "",
|
||||||
}
|
}
|
||||||
req.Operation = logical.ReadOperation
|
resp, err = client.Read("database/config/plugin-test")
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
if err != nil {
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
|
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
|
||||||
if diff := deep.Equal(resp.Data, expected); diff != nil {
|
if diff := deep.Equal(resp.Data, expected); diff != nil {
|
||||||
t.Fatal(diff)
|
t.Fatal(strings.Join(diff, "\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test endpoints for reloading plugins.
|
// Test endpoints for reloading plugins.
|
||||||
for _, reloadPath := range []string{
|
for _, reload := range []struct {
|
||||||
"reset/plugin-test",
|
path string
|
||||||
"reload/postgresql-database-plugin",
|
data map[string]any
|
||||||
|
checkCount bool
|
||||||
|
}{
|
||||||
|
{"database/reset/plugin-test", nil, false},
|
||||||
|
{"database/reload/postgresql-database-plugin", nil, true},
|
||||||
|
{"sys/plugins/reload/backend", map[string]any{
|
||||||
|
"plugin": "postgresql-database-plugin",
|
||||||
|
}, false},
|
||||||
} {
|
} {
|
||||||
getConnectionID := func(name string) string {
|
getConnectionID := func(name string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
dbBackend, ok := b.(*databaseBackend)
|
dbi := dbFactory.db.connections.Get(name)
|
||||||
if !ok {
|
|
||||||
t.Fatal("could not convert logical.Backend to databaseBackend")
|
|
||||||
}
|
|
||||||
dbi := dbBackend.connections.Get(name)
|
|
||||||
if dbi == nil {
|
if dbi == nil {
|
||||||
t.Fatal("no plugin-test dbi")
|
t.Fatal("no plugin-test dbi")
|
||||||
}
|
}
|
||||||
@@ -753,14 +765,8 @@ func TestBackend_connectionCrud(t *testing.T) {
|
|||||||
}
|
}
|
||||||
initialID := getConnectionID("plugin-test")
|
initialID := getConnectionID("plugin-test")
|
||||||
hanaID := getConnectionID("plugin-test-hana")
|
hanaID := getConnectionID("plugin-test-hana")
|
||||||
req = &logical.Request{
|
resp, err = client.Write(reload.path, reload.data)
|
||||||
Operation: logical.UpdateOperation,
|
if err != nil {
|
||||||
Path: reloadPath,
|
|
||||||
Storage: config.StorageView,
|
|
||||||
Data: map[string]interface{}{},
|
|
||||||
}
|
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
if initialID == getConnectionID("plugin-test") {
|
if initialID == getConnectionID("plugin-test") {
|
||||||
@@ -769,54 +775,43 @@ func TestBackend_connectionCrud(t *testing.T) {
|
|||||||
if hanaID != getConnectionID("plugin-test-hana") {
|
if hanaID != getConnectionID("plugin-test-hana") {
|
||||||
t.Fatal("hana plugin got restarted but shouldn't have been")
|
t.Fatal("hana plugin got restarted but shouldn't have been")
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(reloadPath, "reload/") {
|
if reload.checkCount {
|
||||||
if expected := 1; expected != resp.Data["count"] {
|
actual, err := resp.Data["count"].(json.Number).Int64()
|
||||||
t.Fatalf("expected %d but got %d", expected, resp.Data["count"])
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if expected := []string{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) {
|
if expected := 1; expected != int(actual) {
|
||||||
|
t.Fatalf("expected %d but got %d", expected, resp.Data["count"].(int))
|
||||||
|
}
|
||||||
|
if expected := []any{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) {
|
||||||
t.Fatalf("expected %v but got %v", expected, resp.Data["connections"])
|
t.Fatalf("expected %v but got %v", expected, resp.Data["connections"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get creds
|
// Get creds
|
||||||
data = map[string]interface{}{}
|
credsResp, err := client.Read("database/creds/plugin-role-test")
|
||||||
req = &logical.Request{
|
if err != nil {
|
||||||
Operation: logical.ReadOperation,
|
|
||||||
Path: "creds/plugin-role-test",
|
|
||||||
Storage: config.StorageView,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
credsResp, err := b.HandleRequest(namespace.RootContext(nil), req)
|
|
||||||
if err != nil || (credsResp != nil && credsResp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
credCheckURL := dbutil.QueryHelper(connURL, map[string]string{
|
credCheckURL := dbutil.QueryHelper(templatedConnURL, map[string]string{
|
||||||
"username": "postgres",
|
"username": "postgres",
|
||||||
"password": "secret",
|
"password": "secret",
|
||||||
})
|
})
|
||||||
if !testCredsExist(t, credsResp, credCheckURL) {
|
if !testCredsExist(t, credsResp.Data, credCheckURL) {
|
||||||
t.Fatalf("Creds should exist")
|
t.Fatalf("Creds should exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete Connection
|
// Delete Connection
|
||||||
data = map[string]interface{}{}
|
resp, err = client.Delete("database/config/plugin-test")
|
||||||
req = &logical.Request{
|
if err != nil {
|
||||||
Operation: logical.DeleteOperation,
|
|
||||||
Path: "config/plugin-test",
|
|
||||||
Storage: config.StorageView,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read connection
|
// Read connection
|
||||||
req.Operation = logical.ReadOperation
|
resp, err = client.Read("database/config/plugin-test")
|
||||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
if err != nil {
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
|
||||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1190,7 +1185,7 @@ func TestBackend_allowedRoles(t *testing.T) {
|
|||||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !testCredsExist(t, credsResp, connURL) {
|
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||||
t.Fatalf("Creds should exist")
|
t.Fatalf("Creds should exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1224,7 +1219,7 @@ func TestBackend_allowedRoles(t *testing.T) {
|
|||||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !testCredsExist(t, credsResp, connURL) {
|
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||||
t.Fatalf("Creds should exist")
|
t.Fatalf("Creds should exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1271,7 +1266,7 @@ func TestBackend_allowedRoles(t *testing.T) {
|
|||||||
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !testCredsExist(t, credsResp, connURL) {
|
if !testCredsExist(t, credsResp.Data, connURL) {
|
||||||
t.Fatalf("Creds should exist")
|
t.Fatalf("Creds should exist")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1581,13 +1576,13 @@ func TestNewDatabaseWrapper_IgnoresBuiltinVersion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool {
|
func testCredsExist(t *testing.T, data map[string]any, connURL string) bool {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
var d struct {
|
var d struct {
|
||||||
Username string `mapstructure:"username"`
|
Username string `mapstructure:"username"`
|
||||||
Password string `mapstructure:"password"`
|
Password string `mapstructure:"password"`
|
||||||
}
|
}
|
||||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
if err := mapstructure.Decode(data, &d); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ func TestPlugin_lifecycle(t *testing.T) {
|
|||||||
cluster, sys := getCluster(t)
|
cluster, sys := getCluster(t)
|
||||||
defer cluster.Cleanup()
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", []string{})
|
env := []string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)}
|
||||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", []string{})
|
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV4", env)
|
||||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", []string{})
|
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV5", env)
|
||||||
|
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v6-database-plugin-muxed", consts.PluginTypeDatabase, "", "TestBackend_PluginMain_MockV6Multiplexed", env)
|
||||||
|
|
||||||
config := logical.TestBackendConfig()
|
config := logical.TestBackendConfig()
|
||||||
config.StorageView = &logical.InmemStorage{}
|
config.StorageView = &logical.InmemStorage{}
|
||||||
|
|||||||
@@ -140,9 +140,8 @@ func testConfig(t *testing.T, pluginCmd string) (*logical.BackendConfig, func())
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd,
|
||||||
|
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})
|
||||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", consts.PluginTypeSecrets, "", pluginCmd, []string{})
|
|
||||||
|
|
||||||
return config, func() {
|
return config, func() {
|
||||||
cluster.Cleanup()
|
cluster.Cleanup()
|
||||||
|
|||||||
6
changelog/24512.txt
Normal file
6
changelog/24512.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
```release-note:change
|
||||||
|
plugins: Add a warning to the response from sys/plugins/reload/backend if no plugins were reloaded.
|
||||||
|
```
|
||||||
|
```release-note:improvement
|
||||||
|
secrets/database: Support reloading named database plugins using the sys/plugins/reload/backend API endpoint.
|
||||||
|
```
|
||||||
@@ -5,6 +5,7 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -55,10 +56,9 @@ func getPluginClusterAndCore(t *testing.T, logger log.Logger) (*vault.TestCluste
|
|||||||
cores := cluster.Cores
|
cores := cluster.Cores
|
||||||
core := cores[0]
|
core := cores[0]
|
||||||
|
|
||||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
|
||||||
|
|
||||||
vault.TestWaitActive(benchhelpers.TBtoT(t), core.Core)
|
vault.TestWaitActive(benchhelpers.TBtoT(t), core.Core)
|
||||||
vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain", []string{})
|
vault.TestAddTestPlugin(benchhelpers.TBtoT(t), core.Core, "mock-plugin", consts.PluginTypeSecrets, "", "TestPlugin_PluginMain",
|
||||||
|
[]string{fmt.Sprintf("%s=%s", pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)})
|
||||||
|
|
||||||
// Mount the mock plugin
|
// Mount the mock plugin
|
||||||
err = core.Client.Sys().Mount("mock", &api.MountInput{
|
err = core.Client.Sys().Mount("mock", &api.MountInput{
|
||||||
|
|||||||
@@ -560,6 +560,9 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
|||||||
if resp.Data["reload_id"] == nil {
|
if resp.Data["reload_id"] == nil {
|
||||||
t.Fatal("no reload_id in response")
|
t.Fatal("no reload_id in response")
|
||||||
}
|
}
|
||||||
|
if len(resp.Warnings) != 0 {
|
||||||
|
t.Fatal(resp.Warnings)
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
// Ensure internal backed value is reset
|
// Ensure internal backed value is reset
|
||||||
@@ -578,6 +581,35 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSystemBackend_PluginReload_WarningIfNoneReloaded(t *testing.T) {
|
||||||
|
cluster := testSystemBackendMock(t, 1, 2, logical.TypeLogical, "v5")
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
core := cluster.Cores[0]
|
||||||
|
client := core.Client
|
||||||
|
|
||||||
|
for _, backendType := range []logical.BackendType{logical.TypeLogical, logical.TypeCredential} {
|
||||||
|
t.Run(backendType.String(), func(t *testing.T) {
|
||||||
|
// Perform plugin reload
|
||||||
|
resp, err := client.Logical().Write("sys/plugins/reload/backend", map[string]any{
|
||||||
|
"plugin": "does-not-exist",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
t.Fatalf("bad: %v", resp)
|
||||||
|
}
|
||||||
|
if resp.Data["reload_id"] == nil {
|
||||||
|
t.Fatal("no reload_id in response")
|
||||||
|
}
|
||||||
|
if len(resp.Warnings) == 0 {
|
||||||
|
t.Fatal("expected warning")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// testSystemBackendMock returns a systemBackend with the desired number
|
// testSystemBackendMock returns a systemBackend with the desired number
|
||||||
// of mounted mock plugin backends. numMounts alternates between different
|
// of mounted mock plugin backends. numMounts alternates between different
|
||||||
// ways of providing the plugin_name.
|
// ways of providing the plugin_name.
|
||||||
|
|||||||
@@ -738,11 +738,24 @@ func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logic
|
|||||||
return logical.ErrorResponse("plugin or mounts must be provided"), nil
|
return logical.ErrorResponse("plugin or mounts must be provided"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp := logical.Response{
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"reload_id": req.ID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
if pluginName != "" {
|
if pluginName != "" {
|
||||||
err := b.Core.reloadMatchingPlugin(ctx, pluginName)
|
reloaded, err := b.Core.reloadMatchingPlugin(ctx, pluginName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if reloaded == 0 {
|
||||||
|
if scope == globalScope {
|
||||||
|
resp.AddWarning("no plugins were reloaded locally (but they may be reloaded on other nodes)")
|
||||||
|
} else {
|
||||||
|
resp.AddWarning("no plugins were reloaded")
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if len(pluginMounts) > 0 {
|
} else if len(pluginMounts) > 0 {
|
||||||
err := b.Core.reloadMatchingPluginMounts(ctx, pluginMounts)
|
err := b.Core.reloadMatchingPluginMounts(ctx, pluginMounts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -750,20 +763,14 @@ func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logic
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r := logical.Response{
|
|
||||||
Data: map[string]interface{}{
|
|
||||||
"reload_id": req.ID,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if scope == globalScope {
|
if scope == globalScope {
|
||||||
err := handleGlobalPluginReload(ctx, b.Core, req.ID, pluginName, pluginMounts)
|
err := handleGlobalPluginReload(ctx, b.Core, req.ID, pluginName, pluginMounts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return logical.RespondWithStatusCode(&r, req, http.StatusAccepted)
|
return logical.RespondWithStatusCode(&resp, req, http.StatusAccepted)
|
||||||
}
|
}
|
||||||
return &r, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *SystemBackend) handlePluginRuntimeCatalogUpdate(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
func (b *SystemBackend) handlePluginRuntimeCatalogUpdate(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||||
|
|||||||
@@ -322,7 +322,7 @@ const mountStateUnmounting = "unmounting"
|
|||||||
// MountEntry is used to represent a mount table entry
|
// MountEntry is used to represent a mount table entry
|
||||||
type MountEntry struct {
|
type MountEntry struct {
|
||||||
Table string `json:"table"` // The table it belongs to
|
Table string `json:"table"` // The table it belongs to
|
||||||
Path string `json:"path"` // Mount Path
|
Path string `json:"path"` // Mount Path, as provided in the mount API call but with a trailing slash, i.e. no auth/ or namespace prefix.
|
||||||
Type string `json:"type"` // Logical backend Type. NB: This is the plugin name, e.g. my-vault-plugin, NOT plugin type (e.g. auth).
|
Type string `json:"type"` // Logical backend Type. NB: This is the plugin name, e.g. my-vault-plugin, NOT plugin type (e.g. auth).
|
||||||
Description string `json:"description"` // User-provided description
|
Description string `json:"description"` // User-provided description
|
||||||
UUID string `json:"uuid"` // Barrier view UUID
|
UUID string `json:"uuid"` // Barrier view UUID
|
||||||
|
|||||||
@@ -70,10 +70,10 @@ func (c *Core) reloadMatchingPluginMounts(ctx context.Context, mounts []string)
|
|||||||
return errors
|
return errors
|
||||||
}
|
}
|
||||||
|
|
||||||
// reloadPlugin reloads all mounted backends that are of
|
// reloadMatchingPlugin reloads all mounted backends that are named pluginName
|
||||||
// plugin pluginName (name of the plugin as registered in
|
// (name of the plugin as registered in the plugin catalog). It returns the
|
||||||
// the plugin catalog).
|
// number of plugins that were reloaded and an error if any.
|
||||||
func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) error {
|
func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) (reloaded int, err error) {
|
||||||
c.mountsLock.RLock()
|
c.mountsLock.RLock()
|
||||||
defer c.mountsLock.RUnlock()
|
defer c.mountsLock.RUnlock()
|
||||||
c.authLock.RLock()
|
c.authLock.RLock()
|
||||||
@@ -81,25 +81,49 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro
|
|||||||
|
|
||||||
ns, err := namespace.FromContext(ctx)
|
ns, err := namespace.FromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return reloaded, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter mount entries that only matches the plugin name
|
|
||||||
for _, entry := range c.mounts.Entries {
|
for _, entry := range c.mounts.Entries {
|
||||||
// We dont reload mounts that are not in the same namespace
|
// We dont reload mounts that are not in the same namespace
|
||||||
if ns.ID != entry.Namespace().ID {
|
if ns.ID != entry.Namespace().ID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
|
if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
|
||||||
err := c.reloadBackendCommon(ctx, entry, false)
|
err := c.reloadBackendCommon(ctx, entry, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return reloaded, err
|
||||||
|
}
|
||||||
|
reloaded++
|
||||||
|
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "version", entry.Version)
|
||||||
|
} else if entry.Type == "database" {
|
||||||
|
// The combined database plugin is itself a secrets engine, but
|
||||||
|
// knowledge of whether a database plugin is in use within a particular
|
||||||
|
// mount is internal to the combined database plugin's storage, so
|
||||||
|
// we delegate the reload request with an internally routed request.
|
||||||
|
req := &logical.Request{
|
||||||
|
Operation: logical.UpdateOperation,
|
||||||
|
Path: entry.Path + "reload/" + pluginName,
|
||||||
|
}
|
||||||
|
resp, err := c.router.Route(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return reloaded, err
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s", pluginName, entry.Path)
|
||||||
|
}
|
||||||
|
if resp.IsError() {
|
||||||
|
return reloaded, fmt.Errorf("failed to reload %q database plugin(s) mounted under %s: %s", pluginName, entry.Path, resp.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if count, ok := resp.Data["count"].(int); ok && count > 0 {
|
||||||
|
c.logger.Info("successfully reloaded database plugin(s)", "plugin", pluginName, "namespace", entry.Namespace(), "path", entry.Path, "connections", resp.Data["connections"])
|
||||||
|
reloaded += count
|
||||||
}
|
}
|
||||||
c.logger.Info("successfully reloaded plugin", "plugin", pluginName, "path", entry.Path, "version", entry.Version)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter auth mount entries that ony matches the plugin name
|
|
||||||
for _, entry := range c.auth.Entries {
|
for _, entry := range c.auth.Entries {
|
||||||
// We dont reload mounts that are not in the same namespace
|
// We dont reload mounts that are not in the same namespace
|
||||||
if ns.ID != entry.Namespace().ID {
|
if ns.ID != entry.Namespace().ID {
|
||||||
@@ -109,13 +133,14 @@ func (c *Core) reloadMatchingPlugin(ctx context.Context, pluginName string) erro
|
|||||||
if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
|
if entry.Type == pluginName || (entry.Type == "plugin" && entry.Config.PluginName == pluginName) {
|
||||||
err := c.reloadBackendCommon(ctx, entry, true)
|
err := c.reloadBackendCommon(ctx, entry, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return reloaded, err
|
||||||
}
|
}
|
||||||
|
reloaded++
|
||||||
c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version)
|
c.logger.Info("successfully reloaded plugin", "plugin", entry.Accessor, "path", entry.Path, "version", entry.Version)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return reloaded, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// reloadBackendCommon is a generic method to reload a backend provided a
|
// reloadBackendCommon is a generic method to reload a backend provided a
|
||||||
|
|||||||
Reference in New Issue
Block a user