mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 11:38:02 +00:00
[DBPW 4/X] Update DB engine to support v4 and v5 interfaces with password policies (#9878)
This commit is contained in:
@@ -8,12 +8,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/locksutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
|
||||
type dbPluginInstance struct {
|
||||
sync.RWMutex
|
||||
dbplugin.Database
|
||||
database databaseVersionWrapper
|
||||
|
||||
id string
|
||||
name string
|
||||
@@ -46,7 +46,7 @@ func (dbi *dbPluginInstance) Close() error {
|
||||
}
|
||||
dbi.closed = true
|
||||
|
||||
return dbi.Database.Close()
|
||||
return dbi.database.Close()
|
||||
}
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
@@ -89,7 +89,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathCredsCreate(&b),
|
||||
pathRotateCredentials(&b),
|
||||
pathRotateRootCredentials(&b),
|
||||
),
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
@@ -240,9 +240,9 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
|
||||
unlockFunc := b.RUnlock
|
||||
defer func() { unlockFunc() }()
|
||||
|
||||
db, ok := b.connections[name]
|
||||
dbi, ok := b.connections[name]
|
||||
if ok {
|
||||
return db, nil
|
||||
return dbi, nil
|
||||
}
|
||||
|
||||
// Upgrade lock
|
||||
@@ -250,20 +250,9 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
|
||||
b.Lock()
|
||||
unlockFunc = b.Unlock
|
||||
|
||||
db, ok = b.connections[name]
|
||||
dbi, ok = b.connections[name]
|
||||
if ok {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
dbp, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = dbp.Init(ctx, config.ConnectionDetails, true)
|
||||
if err != nil {
|
||||
dbp.Close()
|
||||
return nil, err
|
||||
return dbi, nil
|
||||
}
|
||||
|
||||
id, err := uuid.GenerateUUID()
|
||||
@@ -271,14 +260,28 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db = &dbPluginInstance{
|
||||
Database: dbp,
|
||||
name: name,
|
||||
id: id,
|
||||
dbw, err := newDatabaseWrapper(ctx, config.PluginName, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create database instance: %w", err)
|
||||
}
|
||||
|
||||
b.connections[name] = db
|
||||
return db, nil
|
||||
initReq := newdbplugin.InitializeRequest{
|
||||
Config: config.ConnectionDetails,
|
||||
VerifyConnection: true,
|
||||
}
|
||||
_, err = dbw.Initialize(ctx, initReq)
|
||||
if err != nil {
|
||||
dbw.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbi = &dbPluginInstance{
|
||||
database: dbw,
|
||||
id: id,
|
||||
name: name,
|
||||
}
|
||||
b.connections[name] = dbi
|
||||
return dbi, nil
|
||||
}
|
||||
|
||||
// invalidateQueue cancels any background queue loading and destroys the queue.
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
"github.com/hashicorp/vault-plugin-database-mongodbatlas"
|
||||
mongodbatlas "github.com/hashicorp/vault-plugin-database-mongodbatlas"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
||||
@@ -231,6 +231,7 @@ func TestBackend_config_connection(t *testing.T) {
|
||||
},
|
||||
"allowed_roles": []string{"*"},
|
||||
"root_credentials_rotate_statements": []string{},
|
||||
"password_policy": "",
|
||||
}
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
|
||||
@@ -283,6 +284,7 @@ func TestBackend_config_connection(t *testing.T) {
|
||||
},
|
||||
"allowed_roles": []string{"*"},
|
||||
"root_credentials_rotate_statements": []string{},
|
||||
"password_policy": "",
|
||||
}
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
|
||||
@@ -324,6 +326,7 @@ func TestBackend_config_connection(t *testing.T) {
|
||||
},
|
||||
"allowed_roles": []string{"flu", "barre"},
|
||||
"root_credentials_rotate_statements": []string{},
|
||||
"password_policy": "",
|
||||
}
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
|
||||
@@ -711,6 +714,7 @@ func TestBackend_connectionCrud(t *testing.T) {
|
||||
},
|
||||
"allowed_roles": []string{"plugin-role-test"},
|
||||
"root_credentials_rotate_statements": []string(nil),
|
||||
"password_policy": "",
|
||||
}
|
||||
req.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
|
||||
|
||||
101
builtin/logical/database/mocks_test.go
Normal file
101
builtin/logical/database/mocks_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var _ newdbplugin.Database = &mockNewDatabase{}
|
||||
|
||||
type mockNewDatabase struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockNewDatabase) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
|
||||
args := m.Called(ctx, req)
|
||||
return args.Get(0).(newdbplugin.InitializeResponse), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockNewDatabase) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
|
||||
args := m.Called(ctx, req)
|
||||
return args.Get(0).(newdbplugin.NewUserResponse), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockNewDatabase) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
|
||||
args := m.Called(ctx, req)
|
||||
return args.Get(0).(newdbplugin.UpdateUserResponse), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockNewDatabase) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
|
||||
args := m.Called(ctx, req)
|
||||
return args.Get(0).(newdbplugin.DeleteUserResponse), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockNewDatabase) Type() (string, error) {
|
||||
args := m.Called()
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockNewDatabase) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
var _ dbplugin.Database = &mockLegacyDatabase{}
|
||||
|
||||
type mockLegacyDatabase struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
args := m.Called(ctx, statements, usernameConfig, expiration)
|
||||
return args.String(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
args := m.Called(ctx, statements, username, expiration)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
args := m.Called(ctx, statements, username)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error) {
|
||||
args := m.Called(ctx, statements)
|
||||
return args.Get(0).(map[string]interface{}), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) GenerateCredentials(ctx context.Context) (string, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticConfig dbplugin.StaticUserConfig) (username string, password string, err error) {
|
||||
args := m.Called(ctx, statements, staticConfig)
|
||||
return args.String(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error) {
|
||||
args := m.Called(ctx, config, verifyConnection)
|
||||
return args.Get(0).(map[string]interface{}), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) Type() (string, error) {
|
||||
args := m.Called()
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockLegacyDatabase) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error) {
|
||||
panic("Initialize should not be called")
|
||||
}
|
||||
118
builtin/logical/database/mockv4.go
Normal file
118
builtin/logical/database/mockv4.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
)
|
||||
|
||||
const mockV4Type = "mockv4"
|
||||
|
||||
// MockDatabaseV4 is an implementation of Database interface
|
||||
type MockDatabaseV4 struct {
|
||||
config map[string]interface{}
|
||||
}
|
||||
|
||||
var _ dbplugin.Database = &MockDatabaseV4{}
|
||||
|
||||
// New returns a new in-memory instance
|
||||
func NewV4() (interface{}, error) {
|
||||
return MockDatabaseV4{}, nil
|
||||
}
|
||||
|
||||
// RunV4 instantiates a MongoDB object, and runs the RPC server for the plugin
|
||||
func RunV4(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := NewV4()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error) {
|
||||
log.Default().Info("Init called",
|
||||
"config", config,
|
||||
"verifyConnection", verifyConnection)
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error) {
|
||||
_, err = m.Init(ctx, config, verifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
log.Default().Info("CreateUser called",
|
||||
"statements", statements,
|
||||
"usernameConfig", usernameConfig,
|
||||
"expiration", expiration)
|
||||
|
||||
now := time.Now()
|
||||
user := fmt.Sprintf("mockv4_user_%s", now.Format(time.RFC3339))
|
||||
pass, err := m.GenerateCredentials(ctx)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate credentials: %w", err)
|
||||
}
|
||||
return user, pass, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
log.Default().Info("RenewUser called",
|
||||
"statements", statements,
|
||||
"username", username,
|
||||
"expiration", expiration)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
log.Default().Info("RevokeUser called",
|
||||
"statements", statements,
|
||||
"username", username)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error) {
|
||||
log.Default().Info("RotateRootCredentials called",
|
||||
"statements", statements)
|
||||
|
||||
newPassword, err := m.GenerateCredentials(ctx)
|
||||
if err != nil {
|
||||
return config, fmt.Errorf("failed to generate credentials: %w", err)
|
||||
}
|
||||
config["password"] = newPassword
|
||||
|
||||
return m.config, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticConfig dbplugin.StaticUserConfig) (username string, password string, err error) {
|
||||
log.Default().Info("SetCredentials called",
|
||||
"statements", statements,
|
||||
"staticConfig", staticConfig)
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) GenerateCredentials(ctx context.Context) (password string, err error) {
|
||||
now := time.Now()
|
||||
pass := fmt.Sprintf("mockv4_password_%s", now.Format(time.RFC3339))
|
||||
return pass, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) Type() (string, error) {
|
||||
log.Default().Info("Type called")
|
||||
return mockV4Type, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV4) Close() error {
|
||||
log.Default().Info("Close called")
|
||||
return nil
|
||||
}
|
||||
85
builtin/logical/database/mockv5.go
Normal file
85
builtin/logical/database/mockv5.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
)
|
||||
|
||||
const mockV5Type = "mockv5"
|
||||
|
||||
// MockDatabaseV5 is an implementation of Database interface
|
||||
type MockDatabaseV5 struct {
|
||||
config map[string]interface{}
|
||||
}
|
||||
|
||||
var _ newdbplugin.Database = &MockDatabaseV5{}
|
||||
|
||||
// New returns a new in-memory instance
|
||||
func New() (interface{}, error) {
|
||||
db := MockDatabaseV5{}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Run instantiates a MongoDB object, and runs the RPC server for the plugin
|
||||
func RunV5(apiTLSConfig *api.TLSConfig) error {
|
||||
dbType, err := New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV5) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
|
||||
log.Default().Info("Initialize called",
|
||||
"req", req)
|
||||
|
||||
config := req.Config
|
||||
config["from-plugin"] = "this value is from the plugin itself"
|
||||
|
||||
resp := newdbplugin.InitializeResponse{
|
||||
Config: req.Config,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV5) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
|
||||
log.Default().Info("NewUser called",
|
||||
"req", req)
|
||||
|
||||
now := time.Now()
|
||||
user := fmt.Sprintf("mockv5_user_%s", now.Format(time.RFC3339))
|
||||
resp := newdbplugin.NewUserResponse{
|
||||
Username: user,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV5) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
|
||||
log.Default().Info("UpdateUser called",
|
||||
"req", req)
|
||||
return newdbplugin.UpdateUserResponse{}, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV5) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
|
||||
log.Default().Info("DeleteUser called",
|
||||
"req", req)
|
||||
return newdbplugin.DeleteUserResponse{}, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV5) Type() (string, error) {
|
||||
log.Default().Info("Type called")
|
||||
return mockV5Type, nil
|
||||
}
|
||||
|
||||
func (m MockDatabaseV5) Close() error {
|
||||
log.Default().Info("Close called")
|
||||
return nil
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/errwrap"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
@@ -30,6 +30,8 @@ type DatabaseConfig struct {
|
||||
AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"`
|
||||
|
||||
RootCredentialsRotateStatements []string `json:"root_credentials_rotate_statements" structs:"root_credentials_rotate_statements" mapstructure:"root_credentials_rotate_statements"`
|
||||
|
||||
PasswordPolicy string `json:"password_policy" structs:"password_policy" mapstructure:"password_policy"`
|
||||
}
|
||||
|
||||
// pathResetConnection configures a path to reset a plugin.
|
||||
@@ -114,6 +116,10 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
|
||||
page for more information on support and formatting for this
|
||||
parameter.`,
|
||||
},
|
||||
"password_policy": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `Password policy to use when generating passwords.`,
|
||||
},
|
||||
},
|
||||
|
||||
ExistenceCheck: b.connectionExistenceCheck(),
|
||||
@@ -138,7 +144,7 @@ func (b *databaseBackend) connectionExistenceCheck() framework.ExistenceFunc {
|
||||
|
||||
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return false, errors.New("failed to read connection configuration")
|
||||
return false, fmt.Errorf("failed to read connection configuration: %w", err)
|
||||
}
|
||||
|
||||
return entry != nil, nil
|
||||
@@ -179,7 +185,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
|
||||
|
||||
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read connection configuration")
|
||||
return nil, fmt.Errorf("failed to read connection configuration: %w", err)
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
@@ -245,7 +251,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
|
||||
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read connection configuration")
|
||||
return nil, fmt.Errorf("failed to read connection configuration: %w", err)
|
||||
}
|
||||
if entry != nil {
|
||||
if err := entry.DecodeJSON(config); err != nil {
|
||||
@@ -274,6 +280,10 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
config.RootCredentialsRotateStatements = data.Get("root_rotation_statements").([]string)
|
||||
}
|
||||
|
||||
if passwordPolicyRaw, ok := data.GetOk("password_policy"); ok {
|
||||
config.PasswordPolicy = passwordPolicyRaw.(string)
|
||||
}
|
||||
|
||||
// Remove these entries from the data before we store it keyed under
|
||||
// ConnectionDetails.
|
||||
delete(data.Raw, "name")
|
||||
@@ -281,11 +291,11 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
delete(data.Raw, "allowed_roles")
|
||||
delete(data.Raw, "verify_connection")
|
||||
delete(data.Raw, "root_rotation_statements")
|
||||
delete(data.Raw, "password_policy")
|
||||
|
||||
// Create a database plugin and initialize it.
|
||||
db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
|
||||
id, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If this is an update, take any new values, overwrite what was there
|
||||
@@ -302,37 +312,39 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
}
|
||||
}
|
||||
|
||||
config.ConnectionDetails, err = db.Init(ctx, config.ConnectionDetails, verifyConnection)
|
||||
// Create a database plugin and initialize it.
|
||||
dbw, err := newDatabaseWrapper(ctx, config.PluginName, b.System(), b.logger)
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||
return logical.ErrorResponse("error creating database object: %s", err), nil
|
||||
}
|
||||
|
||||
initReq := newdbplugin.InitializeRequest{
|
||||
Config: config.ConnectionDetails,
|
||||
VerifyConnection: verifyConnection,
|
||||
}
|
||||
initResp, err := dbw.Initialize(ctx, initReq)
|
||||
if err != nil {
|
||||
dbw.Close()
|
||||
return logical.ErrorResponse("error creating database object: %s", err), nil
|
||||
}
|
||||
config.ConnectionDetails = initResp.Config
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Close and remove the old connection
|
||||
b.clearConnection(name)
|
||||
|
||||
id, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.connections[name] = &dbPluginInstance{
|
||||
Database: db,
|
||||
database: dbw,
|
||||
name: name,
|
||||
id: id,
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err = logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
|
||||
err = storeConfig(ctx, req.Storage, name, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &logical.Response{}
|
||||
|
||||
@@ -346,10 +358,30 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// If using a legacy DB plugin and set the `password_policy` field, send a warning to the user indicating
|
||||
// the `password_policy` will not be used
|
||||
if dbw.isV4() && config.PasswordPolicy != "" {
|
||||
resp.AddWarning(fmt.Sprintf("%s does not support password policies - upgrade to the latest version of "+
|
||||
"Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func storeConfig(ctx context.Context, storage logical.Storage, name string, config *DatabaseConfig) error {
|
||||
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to marshal object to JSON: %w", err)
|
||||
}
|
||||
|
||||
err = storage.Put(ctx, entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save object: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure connection details to a database plugin.
|
||||
`
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
@@ -73,13 +73,13 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
}
|
||||
|
||||
// Get the Database object
|
||||
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
|
||||
dbi, err := b.GetConnection(ctx, req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.RLock()
|
||||
defer db.RUnlock()
|
||||
dbi.RLock()
|
||||
defer dbi.RUnlock()
|
||||
|
||||
ttl, _, err := framework.CalculateTTL(b.System(), 0, role.DefaultTTL, 0, role.MaxTTL, 0, time.Time{})
|
||||
if err != nil {
|
||||
@@ -90,27 +90,44 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
// to ensure the database credential does not expire before the lease
|
||||
expiration = expiration.Add(5 * time.Second)
|
||||
|
||||
usernameConfig := dbplugin.UsernameConfig{
|
||||
DisplayName: req.DisplayName,
|
||||
RoleName: name,
|
||||
password, err := dbi.database.GeneratePassword(ctx, b.System(), dbConfig.PasswordPolicy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to generate password: %w", err)
|
||||
}
|
||||
|
||||
// Create the user
|
||||
username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration)
|
||||
newUserReq := newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: req.DisplayName,
|
||||
RoleName: name,
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: role.Statements.Creation,
|
||||
},
|
||||
RollbackStatements: newdbplugin.Statements{
|
||||
Commands: role.Statements.Rollback,
|
||||
},
|
||||
Password: password,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
// Overwriting the password in the event this is a legacy database plugin and the provided password is ignored
|
||||
newUserResp, password, err := dbi.database.NewUser(ctx, newUserReq)
|
||||
if err != nil {
|
||||
b.CloseIfShutdown(db, err)
|
||||
b.CloseIfShutdown(dbi, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
respData := map[string]interface{}{
|
||||
"username": newUserResp.Username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
}
|
||||
internal := map[string]interface{}{
|
||||
"username": newUserResp.Username,
|
||||
"role": name,
|
||||
"db_name": role.DBName,
|
||||
"revocation_statements": role.Statements.Revocation,
|
||||
})
|
||||
}
|
||||
resp := b.Secret(SecretCredsType).Response(respData, internal)
|
||||
resp.Secret.TTL = role.DefaultTTL
|
||||
resp.Secret.MaxTTL = role.MaxTTL
|
||||
return resp, nil
|
||||
|
||||
@@ -5,16 +5,13 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/queue"
|
||||
)
|
||||
|
||||
func pathRotateCredentials(b *databaseBackend) []*framework.Path {
|
||||
func pathRotateRootCredentials(b *databaseBackend) []*framework.Path {
|
||||
return []*framework.Path{
|
||||
&framework.Path{
|
||||
Pattern: "rotate-root/" + framework.GenericNameRegex("name"),
|
||||
@@ -27,7 +24,7 @@ func pathRotateCredentials(b *databaseBackend) []*framework.Path {
|
||||
|
||||
Operations: map[logical.Operation]framework.OperationHandler{
|
||||
logical.UpdateOperation: &framework.PathOperation{
|
||||
Callback: b.pathRotateCredentialsUpdate(),
|
||||
Callback: b.pathRotateRootCredentialsUpdate(),
|
||||
ForwardPerformanceSecondary: true,
|
||||
ForwardPerformanceStandby: true,
|
||||
},
|
||||
@@ -59,7 +56,7 @@ func pathRotateCredentials(b *databaseBackend) []*framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc {
|
||||
func (b *databaseBackend) pathRotateRootCredentialsUpdate() framework.OperationFunc {
|
||||
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
@@ -71,15 +68,15 @@ func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := b.GetConnection(ctx, req.Storage, name)
|
||||
dbi, err := b.GetConnection(ctx, req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Close the plugin
|
||||
db.closed = true
|
||||
if err := db.Database.Close(); err != nil {
|
||||
dbi.closed = true
|
||||
if err := dbi.database.Close(); err != nil {
|
||||
b.Logger().Error("error closing the database plugin connection", "err", err)
|
||||
}
|
||||
// Even on error, still remove the connection
|
||||
@@ -91,13 +88,13 @@ func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc
|
||||
defer b.Unlock()
|
||||
|
||||
// Take the write lock on the instance
|
||||
db.Lock()
|
||||
defer db.Unlock()
|
||||
dbi.Lock()
|
||||
defer dbi.Unlock()
|
||||
|
||||
// Generate new credentials
|
||||
userName := config.ConnectionDetails["username"].(string)
|
||||
username := config.ConnectionDetails["username"].(string)
|
||||
oldPassword := config.ConnectionDetails["password"].(string)
|
||||
newPassword, err := db.GenerateCredentials(ctx)
|
||||
newPassword, err := dbi.database.GeneratePassword(ctx, b.System(), config.PasswordPolicy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -106,7 +103,7 @@ func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc
|
||||
// Write a WAL entry
|
||||
walID, err := framework.PutWAL(ctx, req.Storage, rotateRootWALKey, &rotateRootCredentialsWAL{
|
||||
ConnectionName: name,
|
||||
UserName: userName,
|
||||
UserName: username,
|
||||
OldPassword: oldPassword,
|
||||
NewPassword: newPassword,
|
||||
})
|
||||
@@ -114,37 +111,32 @@ func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Attempt to use SetCredentials for the root credential rotation
|
||||
statements := dbplugin.Statements{Rotation: config.RootCredentialsRotateStatements}
|
||||
userConfig := dbplugin.StaticUserConfig{
|
||||
Username: userName,
|
||||
Password: newPassword,
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: username,
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: newPassword,
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: config.RootCredentialsRotateStatements,
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, _, err := db.SetCredentials(ctx, statements, userConfig); err != nil {
|
||||
if status.Code(err) == codes.Unimplemented {
|
||||
// Fall back to using RotateRootCredentials if unimplemented
|
||||
config.ConnectionDetails, err = db.RotateRootCredentials(ctx,
|
||||
config.RootCredentialsRotateStatements)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newConfigDetails, err := dbi.database.UpdateUser(ctx, updateReq, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
if newConfigDetails != nil {
|
||||
config.ConnectionDetails = newConfigDetails
|
||||
}
|
||||
|
||||
// Update storage with the new root credentials
|
||||
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
|
||||
err = storeConfig(ctx, req.Storage, name, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete the WAL entry after successfully rotating root credentials
|
||||
if err := framework.DeleteWAL(ctx, req.Storage, walID); err != nil {
|
||||
err = framework.DeleteWAL(ctx, req.Storage, walID)
|
||||
if err != nil {
|
||||
b.Logger().Warn("unable to delete WAL", "error", err, "WAL ID", walID)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -73,13 +73,13 @@ func (b *databaseBackend) walRollback(ctx context.Context, req *logical.Request,
|
||||
}
|
||||
|
||||
// rollbackDatabaseCredentials rolls back root database credentials for
|
||||
// the connection associated with the passed WAL entry. It will creates
|
||||
// the connection associated with the passed WAL entry. It will create
|
||||
// a connection to the database using the WAL entry new password in
|
||||
// order to alter the password to be the WAL entry old password.
|
||||
func (b *databaseBackend) rollbackDatabaseCredentials(ctx context.Context, config *DatabaseConfig, entry rotateRootCredentialsWAL) error {
|
||||
// Attempt to get a connection with the WAL entry new password.
|
||||
config.ConnectionDetails["password"] = entry.NewPassword
|
||||
dbc, err := b.GetConnectionWithConfig(ctx, entry.ConnectionName, config)
|
||||
dbi, err := b.GetConnectionWithConfig(ctx, entry.ConnectionName, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -91,22 +91,21 @@ func (b *databaseBackend) rollbackDatabaseCredentials(ctx context.Context, confi
|
||||
}
|
||||
}()
|
||||
|
||||
// Roll back the database password to the WAL entry old password
|
||||
statements := dbplugin.Statements{Rotation: config.RootCredentialsRotateStatements}
|
||||
userConfig := dbplugin.StaticUserConfig{
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: entry.UserName,
|
||||
Password: entry.OldPassword,
|
||||
}
|
||||
if _, _, err := dbc.SetCredentials(ctx, statements, userConfig); err != nil {
|
||||
// If the database plugin doesn't implement SetCredentials, the root
|
||||
// credentials can't be rolled back. This means the root credential
|
||||
// rotation happened via the plugin RotateRootCredentials RPC.
|
||||
if status.Code(err) == codes.Unimplemented {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: entry.OldPassword,
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: config.RootCredentialsRotateStatements,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
// It actually is the root user here, but we only want to use SetCredentials since
|
||||
// RotateRootCredentials doesn't give any control over what password is used
|
||||
_, err = dbi.database.UpdateUser(ctx, updateReq, false)
|
||||
if status.Code(err) == codes.Unimplemented {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,10 +4,11 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
@@ -92,18 +93,22 @@ func TestBackend_RotateRootCredentials_WAL_rollback(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get a connection to the database plugin
|
||||
pc, err := dbBackend.GetConnection(context.Background(),
|
||||
dbi, err := dbBackend.GetConnection(context.Background(),
|
||||
config.StorageView, "plugin-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Alter the database password so it no longer matches what is in storage
|
||||
_, _, err = pc.SetCredentials(context.Background(), dbplugin.Statements{},
|
||||
dbplugin.StaticUserConfig{
|
||||
Username: databaseUser,
|
||||
Password: "newSecret",
|
||||
})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: databaseUser,
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: "newSecret",
|
||||
},
|
||||
}
|
||||
_, err = dbi.database.UpdateUser(ctx, updateReq, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -335,17 +340,21 @@ func TestBackend_RotateRootCredentials_WAL_no_rollback_2(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get a connection to the database plugin
|
||||
pc, err := dbBackend.GetConnection(context.Background(), config.StorageView, "plugin-test")
|
||||
dbi, err := dbBackend.GetConnection(context.Background(), config.StorageView, "plugin-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Alter the database password
|
||||
_, _, err = pc.SetCredentials(context.Background(), dbplugin.Statements{},
|
||||
dbplugin.StaticUserConfig{
|
||||
Username: databaseUser,
|
||||
Password: "newSecret",
|
||||
})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: databaseUser,
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: "newSecret",
|
||||
},
|
||||
}
|
||||
_, err = dbi.database.UpdateUser(ctx, updateReq, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/locksutil"
|
||||
@@ -319,21 +320,20 @@ func (b *databaseBackend) setStaticAccount(ctx context.Context, s logical.Storag
|
||||
}
|
||||
|
||||
// Get the Database object
|
||||
db, err := b.GetConnection(ctx, s, input.Role.DBName)
|
||||
dbi, err := b.GetConnection(ctx, s, input.Role.DBName)
|
||||
if err != nil {
|
||||
return output, err
|
||||
}
|
||||
|
||||
db.RLock()
|
||||
defer db.RUnlock()
|
||||
dbi.RLock()
|
||||
defer dbi.RUnlock()
|
||||
|
||||
// Use password from input if available. This happens if we're restoring from
|
||||
// a WAL item or processing the rotation queue with an item that has a WAL
|
||||
// associated with it
|
||||
newPassword := input.Password
|
||||
if newPassword == "" {
|
||||
// Generate a new password
|
||||
newPassword, err = db.GenerateCredentials(ctx)
|
||||
newPassword, err = dbi.database.GeneratePassword(ctx, b.System(), dbConfig.PasswordPolicy)
|
||||
if err != nil {
|
||||
return output, err
|
||||
}
|
||||
@@ -358,21 +358,26 @@ func (b *databaseBackend) setStaticAccount(ctx context.Context, s logical.Storag
|
||||
}
|
||||
}
|
||||
|
||||
_, password, err := db.SetCredentials(ctx, input.Role.Statements, config)
|
||||
if err != nil {
|
||||
b.CloseIfShutdown(db, err)
|
||||
return output, errwrap.Wrapf("error setting credentials: {{err}}", err)
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: input.Role.StaticAccount.Username,
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: newPassword,
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: input.Role.Statements.Rotation,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if newPassword != password {
|
||||
return output, errors.New("mismatch passwords returned")
|
||||
_, err = dbi.database.UpdateUser(ctx, updateReq, false)
|
||||
if err != nil {
|
||||
b.CloseIfShutdown(dbi, err)
|
||||
return output, errwrap.Wrapf("error setting credentials: {{err}}", err)
|
||||
}
|
||||
|
||||
// Store updated role information
|
||||
// lvr is the known LastVaultRotation
|
||||
lvr := time.Now()
|
||||
input.Role.StaticAccount.LastVaultRotation = lvr
|
||||
input.Role.StaticAccount.Password = password
|
||||
input.Role.StaticAccount.Password = newPassword
|
||||
output.RotationTime = lvr
|
||||
|
||||
entry, err := logical.StorageEntryJSON(databaseStaticRolePath+input.RoleName, input.Role)
|
||||
@@ -393,7 +398,7 @@ func (b *databaseBackend) setStaticAccount(ctx context.Context, s logical.Storag
|
||||
return &setStaticAccountOutput{RotationTime: lvr}, merr
|
||||
}
|
||||
|
||||
// initQueue preforms the necessary checks and initializations needed to preform
|
||||
// initQueue preforms the necessary checks and initializations needed to perform
|
||||
// automatic credential rotation for roles associated with static accounts. This
|
||||
// method verifies if a queue is needed (primary server or local mount), and if
|
||||
// so initializes the queue and launches a go-routine to periodically invoke a
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
@@ -45,13 +46,13 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||
}
|
||||
|
||||
// Get the Database object
|
||||
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
|
||||
dbi, err := b.GetConnection(ctx, req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.RLock()
|
||||
defer db.RUnlock()
|
||||
dbi.RLock()
|
||||
defer dbi.RUnlock()
|
||||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
ttl, _, err := framework.CalculateTTL(b.System(), req.Secret.Increment, role.DefaultTTL, 0, role.MaxTTL, 0, req.Secret.IssueTime)
|
||||
@@ -63,9 +64,19 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||
// Adding a small buffer since the TTL will be calculated again after this call
|
||||
// to ensure the database credential does not expire before the lease
|
||||
expireTime = expireTime.Add(5 * time.Second)
|
||||
err := db.RenewUser(ctx, role.Statements, username, expireTime)
|
||||
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: username,
|
||||
Expiration: &newdbplugin.ChangeExpiration{
|
||||
NewExpiration: expireTime,
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: role.Statements.Renewal,
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := dbi.database.UpdateUser(ctx, updateReq, false)
|
||||
if err != nil {
|
||||
b.CloseIfShutdown(db, err)
|
||||
b.CloseIfShutdown(dbi, err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -103,41 +114,49 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
||||
dbName = role.DBName
|
||||
statements = role.Statements
|
||||
} else {
|
||||
if dbNameRaw, ok := req.Secret.InternalData["db_name"]; !ok {
|
||||
dbNameRaw, ok := req.Secret.InternalData["db_name"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error during revoke: could not find role with name %q or embedded revocation db name data", req.Secret.InternalData["role"])
|
||||
} else {
|
||||
dbName = dbNameRaw.(string)
|
||||
}
|
||||
if statementsRaw, ok := req.Secret.InternalData["revocation_statements"]; !ok {
|
||||
dbName = dbNameRaw.(string)
|
||||
|
||||
statementsRaw, ok := req.Secret.InternalData["revocation_statements"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error during revoke: could not find role with name %q or embedded revocation statement data", req.Secret.InternalData["role"])
|
||||
} else {
|
||||
// If we don't actually have any statements, because none were
|
||||
// set in the role, we'll end up with an empty one and the
|
||||
// default for the db type will be attempted
|
||||
if statementsRaw != nil {
|
||||
statementsSlice, ok := statementsRaw.([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error during revoke: could not find role with name %q and embedded reovcation data could not be read", req.Secret.InternalData["role"])
|
||||
} else {
|
||||
for _, v := range statementsSlice {
|
||||
statements.Revocation = append(statements.Revocation, v.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we don't actually have any statements, because none were
|
||||
// set in the role, we'll end up with an empty one and the
|
||||
// default for the db type will be attempted
|
||||
if statementsRaw != nil {
|
||||
statementsSlice, ok := statementsRaw.([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error during revoke: could not find role with name %q and embedded reovcation data could not be read", req.Secret.InternalData["role"])
|
||||
}
|
||||
for _, v := range statementsSlice {
|
||||
statements.Revocation = append(statements.Revocation, v.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get our connection
|
||||
db, err := b.GetConnection(ctx, req.Storage, dbName)
|
||||
dbi, err := b.GetConnection(ctx, req.Storage, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.RLock()
|
||||
defer db.RUnlock()
|
||||
dbi.RLock()
|
||||
defer dbi.RUnlock()
|
||||
|
||||
if err := db.RevokeUser(ctx, statements, username); err != nil {
|
||||
b.CloseIfShutdown(db, err)
|
||||
deleteReq := newdbplugin.DeleteUserRequest{
|
||||
Username: username,
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: statements.Revocation,
|
||||
},
|
||||
}
|
||||
_, err = dbi.database.DeleteUser(ctx, deleteReq)
|
||||
if err != nil {
|
||||
b.CloseIfShutdown(dbi, err)
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
|
||||
268
builtin/logical/database/version_wrapper.go
Normal file
268
builtin/logical/database/version_wrapper.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/helper/random"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type databaseVersionWrapper struct {
|
||||
v4 dbplugin.Database
|
||||
v5 newdbplugin.Database
|
||||
}
|
||||
|
||||
// newDatabaseWrapper figures out which version of the database the pluginName is referring to and returns a wrapper object
|
||||
// that can be used to make operations on the underlying database plugin.
|
||||
func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (dbw databaseVersionWrapper, err error) {
|
||||
newDB, err := newdbplugin.PluginFactory(ctx, pluginName, sys, logger)
|
||||
if err == nil {
|
||||
dbw = databaseVersionWrapper{
|
||||
v5: newDB,
|
||||
}
|
||||
return dbw, nil
|
||||
}
|
||||
|
||||
legacyDB, err := dbplugin.PluginFactory(ctx, pluginName, sys, logger)
|
||||
if err == nil {
|
||||
dbw = databaseVersionWrapper{
|
||||
v4: legacyDB,
|
||||
}
|
||||
return dbw, nil
|
||||
}
|
||||
|
||||
return dbw, fmt.Errorf("invalid database version")
|
||||
}
|
||||
|
||||
// Initialize the underlying database. This is analogous to a constructor on the database plugin object.
|
||||
// Errors if the wrapper does not contain an underlying database.
|
||||
func (d databaseVersionWrapper) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
|
||||
if !d.isV5() && !d.isV4() {
|
||||
return newdbplugin.InitializeResponse{}, fmt.Errorf("no underlying database specified")
|
||||
}
|
||||
|
||||
// v5 Database
|
||||
if d.isV5() {
|
||||
return d.v5.Initialize(ctx, req)
|
||||
}
|
||||
|
||||
// v4 Database
|
||||
saveConfig, err := d.v4.Init(ctx, req.Config, req.VerifyConnection)
|
||||
if err != nil {
|
||||
return newdbplugin.InitializeResponse{}, err
|
||||
}
|
||||
resp := newdbplugin.InitializeResponse{
|
||||
Config: saveConfig,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// NewUser in the database. This is different from the v5 Database in that it returns a password as well.
|
||||
// This is done because the v4 Database is expected to generate a password and return it. The NewUserResponse
|
||||
// does not have a way of returning the password so this function signature needs to be different.
|
||||
// The password returned here should be considered the source of truth, not the provided password.
|
||||
// Errors if the wrapper does not contain an underlying database.
|
||||
func (d databaseVersionWrapper) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (resp newdbplugin.NewUserResponse, password string, err error) {
|
||||
if !d.isV5() && !d.isV4() {
|
||||
return newdbplugin.NewUserResponse{}, "", fmt.Errorf("no underlying database specified")
|
||||
}
|
||||
|
||||
// v5 Database
|
||||
if d.isV5() {
|
||||
resp, err = d.v5.NewUser(ctx, req)
|
||||
return resp, req.Password, err
|
||||
}
|
||||
|
||||
// v4 Database
|
||||
stmts := dbplugin.Statements{
|
||||
Creation: req.Statements.Commands,
|
||||
Rollback: req.RollbackStatements.Commands,
|
||||
}
|
||||
usernameConfig := dbplugin.UsernameConfig{
|
||||
DisplayName: req.UsernameConfig.DisplayName,
|
||||
RoleName: req.UsernameConfig.RoleName,
|
||||
}
|
||||
username, password, err := d.v4.CreateUser(ctx, stmts, usernameConfig, req.Expiration)
|
||||
if err != nil {
|
||||
return resp, "", err
|
||||
}
|
||||
|
||||
resp = newdbplugin.NewUserResponse{
|
||||
Username: username,
|
||||
}
|
||||
return resp, password, nil
|
||||
}
|
||||
|
||||
// UpdateUser in the underlying database. This is used to update any information currently supported
|
||||
// in the UpdateUserRequest such as password credentials or user TTL.
|
||||
// Errors if the wrapper does not contain an underlying database.
|
||||
func (d databaseVersionWrapper) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest, isRootUser bool) (saveConfig map[string]interface{}, err error) {
|
||||
if !d.isV5() && !d.isV4() {
|
||||
return nil, fmt.Errorf("no underlying database specified")
|
||||
}
|
||||
|
||||
// v5 Database
|
||||
if d.isV5() {
|
||||
_, err := d.v5.UpdateUser(ctx, req)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// v4 Database
|
||||
if req.Password == nil && req.Expiration == nil {
|
||||
return nil, fmt.Errorf("missing change to be sent to the database")
|
||||
}
|
||||
if req.Password != nil && req.Expiration != nil {
|
||||
// We could support this, but it would require handling partial
|
||||
// errors which I'm punting on since we don't need it for now
|
||||
return nil, fmt.Errorf("cannot specify both password and expiration change at the same time")
|
||||
}
|
||||
|
||||
// Change password
|
||||
if req.Password != nil {
|
||||
return d.changePasswordLegacy(ctx, req.Username, req.Password, isRootUser)
|
||||
}
|
||||
|
||||
// Change expiration date
|
||||
if req.Expiration != nil {
|
||||
stmts := dbplugin.Statements{
|
||||
Renewal: req.Expiration.Statements.Commands,
|
||||
}
|
||||
err := d.v4.RenewUser(ctx, stmts, req.Username, req.Expiration.NewExpiration)
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// changePasswordLegacy attempts to use SetCredentials to change the password for the user with the password provided
|
||||
// in ChangePassword. If that user is the root user and SetCredentials is unimplemented, it will fall back to using
|
||||
// RotateRootCredentials. If not the root user, this will not use RotateRootCredentials.
|
||||
func (d databaseVersionWrapper) changePasswordLegacy(ctx context.Context, username string, passwordChange *newdbplugin.ChangePassword, isRootUser bool) (saveConfig map[string]interface{}, err error) {
|
||||
err = d.changeUserPasswordLegacy(ctx, username, passwordChange)
|
||||
|
||||
// If changing the root user's password but SetCredentials is unimplemented, fall back to RotateRootCredentials
|
||||
if isRootUser && status.Code(err) == codes.Unimplemented {
|
||||
saveConfig, err = d.changeRootUserPasswordLegacy(ctx, passwordChange)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return saveConfig, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d databaseVersionWrapper) changeUserPasswordLegacy(ctx context.Context, username string, passwordChange *newdbplugin.ChangePassword) (err error) {
|
||||
stmts := dbplugin.Statements{
|
||||
Rotation: passwordChange.Statements.Commands,
|
||||
}
|
||||
staticConfig := dbplugin.StaticUserConfig{
|
||||
Username: username,
|
||||
Password: passwordChange.NewPassword,
|
||||
}
|
||||
_, _, err = d.v4.SetCredentials(ctx, stmts, staticConfig)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d databaseVersionWrapper) changeRootUserPasswordLegacy(ctx context.Context, passwordChange *newdbplugin.ChangePassword) (saveConfig map[string]interface{}, err error) {
|
||||
return d.v4.RotateRootCredentials(ctx, passwordChange.Statements.Commands)
|
||||
}
|
||||
|
||||
// DeleteUser in the underlying database. Errors if the wrapper does not contain an underlying database.
|
||||
func (d databaseVersionWrapper) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
|
||||
if !d.isV5() && !d.isV4() {
|
||||
return newdbplugin.DeleteUserResponse{}, fmt.Errorf("no underlying database specified")
|
||||
}
|
||||
|
||||
// v5 Database
|
||||
if d.isV5() {
|
||||
return d.v5.DeleteUser(ctx, req)
|
||||
}
|
||||
|
||||
// v4 Database
|
||||
stmts := dbplugin.Statements{
|
||||
Revocation: req.Statements.Commands,
|
||||
}
|
||||
err := d.v4.RevokeUser(ctx, stmts, req.Username)
|
||||
return newdbplugin.DeleteUserResponse{}, err
|
||||
}
|
||||
|
||||
// Type of the underlying database. Errors if the wrapper does not contain an underlying database.
|
||||
func (d databaseVersionWrapper) Type() (string, error) {
|
||||
if !d.isV5() && !d.isV4() {
|
||||
return "", fmt.Errorf("no underlying database specified")
|
||||
}
|
||||
|
||||
// v5 Database
|
||||
if d.isV5() {
|
||||
return d.v5.Type()
|
||||
}
|
||||
|
||||
// v4 Database
|
||||
return d.v4.Type()
|
||||
}
|
||||
|
||||
// Close the underlying database. Errors if the wrapper does not contain an underlying database.
|
||||
func (d databaseVersionWrapper) Close() error {
|
||||
if !d.isV5() && !d.isV4() {
|
||||
return fmt.Errorf("no underlying database specified")
|
||||
}
|
||||
// v5 Database
|
||||
if d.isV5() {
|
||||
return d.v5.Close()
|
||||
}
|
||||
|
||||
// v4 Database
|
||||
return d.v4.Close()
|
||||
}
|
||||
|
||||
// /////////////////////////////////////////////////////////////////////////////////
|
||||
// Password generation
|
||||
// /////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type passwordGenerator interface {
|
||||
GeneratePasswordFromPolicy(ctx context.Context, policyName string) (password string, err error)
|
||||
}
|
||||
|
||||
var (
|
||||
defaultPasswordGenerator = random.DefaultStringGenerator
|
||||
)
|
||||
|
||||
// GeneratePassword either from the v4 database or by using the provided password policy. If using a v5 database
|
||||
// and no password policy is specified, this will have a reasonable default password generator.
|
||||
func (d databaseVersionWrapper) GeneratePassword(ctx context.Context, generator passwordGenerator, passwordPolicy string) (password string, err error) {
|
||||
if !d.isV5() && !d.isV4() {
|
||||
return "", fmt.Errorf("no underlying database specified")
|
||||
}
|
||||
|
||||
// If using the legacy database, use GenerateCredentials instead of password policies
|
||||
// This will keep the existing behavior even though passwords can be generated with a policy
|
||||
if d.isV4() {
|
||||
password, err := d.v4.GenerateCredentials(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
|
||||
if passwordPolicy == "" {
|
||||
return defaultPasswordGenerator.Generate(ctx, rand.Reader)
|
||||
}
|
||||
return generator.GeneratePasswordFromPolicy(ctx, passwordPolicy)
|
||||
}
|
||||
|
||||
func (d databaseVersionWrapper) isV5() bool {
|
||||
return d.v5 != nil
|
||||
}
|
||||
|
||||
func (d databaseVersionWrapper) isV4() bool {
|
||||
return d.v4 != nil
|
||||
}
|
||||
994
builtin/logical/database/version_wrapper_test.go
Normal file
994
builtin/logical/database/version_wrapper_test.go
Normal file
@@ -0,0 +1,994 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestInitDatabase_missingDB(t *testing.T) {
|
||||
dbw := databaseVersionWrapper{}
|
||||
|
||||
req := newdbplugin.InitializeRequest{}
|
||||
resp, err := dbw.Initialize(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
||||
expectedResp := newdbplugin.InitializeResponse{}
|
||||
if !reflect.DeepEqual(resp, expectedResp) {
|
||||
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, expectedResp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitDatabase_newDB(t *testing.T) {
|
||||
type testCase struct {
|
||||
req newdbplugin.InitializeRequest
|
||||
|
||||
newInitResp newdbplugin.InitializeResponse
|
||||
newInitErr error
|
||||
newInitCalls int
|
||||
|
||||
expectedResp newdbplugin.InitializeResponse
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"success": {
|
||||
req: newdbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
VerifyConnection: true,
|
||||
},
|
||||
newInitResp: newdbplugin.InitializeResponse{
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
newInitCalls: 1,
|
||||
expectedResp: newdbplugin.InitializeResponse{
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
"error": {
|
||||
req: newdbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
VerifyConnection: true,
|
||||
},
|
||||
newInitResp: newdbplugin.InitializeResponse{},
|
||||
newInitErr: fmt.Errorf("test error"),
|
||||
newInitCalls: 1,
|
||||
expectedResp: newdbplugin.InitializeResponse{},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
newDB := new(mockNewDatabase)
|
||||
newDB.On("Initialize", mock.Anything, mock.Anything).
|
||||
Return(test.newInitResp, test.newInitErr)
|
||||
defer newDB.AssertNumberOfCalls(t, "Initialize", test.newInitCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v5: newDB,
|
||||
}
|
||||
|
||||
resp, err := dbw.Initialize(context.Background(), test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp, test.expectedResp) {
|
||||
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitDatabase_legacyDB(t *testing.T) {
|
||||
type testCase struct {
|
||||
req newdbplugin.InitializeRequest
|
||||
|
||||
initConfig map[string]interface{}
|
||||
initErr error
|
||||
initCalls int
|
||||
|
||||
expectedResp newdbplugin.InitializeResponse
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"success": {
|
||||
req: newdbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
VerifyConnection: true,
|
||||
},
|
||||
initConfig: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
initCalls: 1,
|
||||
expectedResp: newdbplugin.InitializeResponse{
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
"error": {
|
||||
req: newdbplugin.InitializeRequest{
|
||||
Config: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
VerifyConnection: true,
|
||||
},
|
||||
initErr: fmt.Errorf("test error"),
|
||||
initCalls: 1,
|
||||
expectedResp: newdbplugin.InitializeResponse{},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
legacyDB := new(mockLegacyDatabase)
|
||||
legacyDB.On("Init", mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(test.initConfig, test.initErr)
|
||||
defer legacyDB.AssertNumberOfCalls(t, "Init", test.initCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v4: legacyDB,
|
||||
}
|
||||
|
||||
resp, err := dbw.Initialize(context.Background(), test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp, test.expectedResp) {
|
||||
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakePasswordGenerator struct {
|
||||
password string
|
||||
err error
|
||||
}
|
||||
|
||||
func (pg fakePasswordGenerator) GeneratePasswordFromPolicy(ctx context.Context, policy string) (string, error) {
|
||||
return pg.password, pg.err
|
||||
}
|
||||
|
||||
func TestGeneratePassword_missingDB(t *testing.T) {
|
||||
dbw := databaseVersionWrapper{}
|
||||
|
||||
gen := fakePasswordGenerator{
|
||||
err: fmt.Errorf("this shouldn't be called"),
|
||||
}
|
||||
pass, err := dbw.GeneratePassword(context.Background(), gen, "policy")
|
||||
if err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
||||
if pass != "" {
|
||||
t.Fatalf("Password should be empty but was: %s", pass)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePassword_legacy(t *testing.T) {
|
||||
type testCase struct {
|
||||
legacyPassword string
|
||||
legacyErr error
|
||||
legacyCalls int
|
||||
|
||||
expectedPassword string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"legacy password generation": {
|
||||
legacyPassword: "legacy_password",
|
||||
legacyErr: nil,
|
||||
legacyCalls: 1,
|
||||
|
||||
expectedPassword: "legacy_password",
|
||||
expectErr: false,
|
||||
},
|
||||
"legacy password failure": {
|
||||
legacyPassword: "",
|
||||
legacyErr: fmt.Errorf("failed :("),
|
||||
legacyCalls: 1,
|
||||
|
||||
expectedPassword: "",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
legacyDB := new(mockLegacyDatabase)
|
||||
legacyDB.On("GenerateCredentials", mock.Anything).
|
||||
Return(test.legacyPassword, test.legacyErr)
|
||||
defer legacyDB.AssertNumberOfCalls(t, "GenerateCredentials", test.legacyCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v4: legacyDB,
|
||||
}
|
||||
|
||||
passGen := fakePasswordGenerator{
|
||||
err: fmt.Errorf("this should not be called"),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
password, err := dbw.GeneratePassword(ctx, passGen, "test_policy")
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
if password != test.expectedPassword {
|
||||
t.Fatalf("Actual password: %s Expected password: %s", password, test.expectedPassword)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePassword_policies(t *testing.T) {
|
||||
type testCase struct {
|
||||
passwordPolicyPassword string
|
||||
passwordPolicyErr error
|
||||
|
||||
expectedPassword string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"password policy generation": {
|
||||
passwordPolicyPassword: "new_password",
|
||||
|
||||
expectedPassword: "new_password",
|
||||
expectErr: false,
|
||||
},
|
||||
"password policy error": {
|
||||
passwordPolicyPassword: "",
|
||||
passwordPolicyErr: fmt.Errorf("test error"),
|
||||
|
||||
expectedPassword: "",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
newDB := new(mockNewDatabase)
|
||||
defer newDB.AssertExpectations(t)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v5: newDB,
|
||||
}
|
||||
|
||||
passGen := fakePasswordGenerator{
|
||||
password: test.passwordPolicyPassword,
|
||||
err: test.passwordPolicyErr,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
password, err := dbw.GeneratePassword(ctx, passGen, "test_policy")
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
if password != test.expectedPassword {
|
||||
t.Fatalf("Actual password: %s Expected password: %s", password, test.expectedPassword)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePassword_no_policy(t *testing.T) {
|
||||
newDB := new(mockNewDatabase)
|
||||
defer newDB.AssertExpectations(t)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v5: newDB,
|
||||
}
|
||||
|
||||
passGen := fakePasswordGenerator{
|
||||
password: "",
|
||||
err: fmt.Errorf("should not be called"),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
password, err := dbw.GeneratePassword(ctx, passGen, "")
|
||||
if err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
if password == "" {
|
||||
t.Fatalf("missing password")
|
||||
}
|
||||
|
||||
rawRegex := "^[a-zA-Z0-9-]{20}$"
|
||||
re := regexp.MustCompile(rawRegex)
|
||||
if !re.MatchString(password) {
|
||||
t.Fatalf("password %q did not match regex: %q", password, rawRegex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUser_missingDB(t *testing.T) {
|
||||
dbw := databaseVersionWrapper{}
|
||||
|
||||
req := newdbplugin.NewUserRequest{}
|
||||
resp, pass, err := dbw.NewUser(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
||||
expectedResp := newdbplugin.NewUserResponse{}
|
||||
if !reflect.DeepEqual(resp, expectedResp) {
|
||||
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, expectedResp)
|
||||
}
|
||||
|
||||
if pass != "" {
|
||||
t.Fatalf("Password should be empty but was: %s", pass)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUser_newDB(t *testing.T) {
|
||||
type testCase struct {
|
||||
req newdbplugin.NewUserRequest
|
||||
|
||||
newUserResp newdbplugin.NewUserResponse
|
||||
newUserErr error
|
||||
newUserCalls int
|
||||
|
||||
expectedResp newdbplugin.NewUserResponse
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"success": {
|
||||
req: newdbplugin.NewUserRequest{
|
||||
Password: "new_password",
|
||||
},
|
||||
|
||||
newUserResp: newdbplugin.NewUserResponse{
|
||||
Username: "newuser",
|
||||
},
|
||||
newUserCalls: 1,
|
||||
|
||||
expectedResp: newdbplugin.NewUserResponse{
|
||||
Username: "newuser",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
"error": {
|
||||
req: newdbplugin.NewUserRequest{
|
||||
Password: "new_password",
|
||||
},
|
||||
|
||||
newUserErr: fmt.Errorf("test error"),
|
||||
newUserCalls: 1,
|
||||
|
||||
expectedResp: newdbplugin.NewUserResponse{},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
newDB := new(mockNewDatabase)
|
||||
newDB.On("NewUser", mock.Anything, mock.Anything).
|
||||
Return(test.newUserResp, test.newUserErr)
|
||||
defer newDB.AssertNumberOfCalls(t, "NewUser", test.newUserCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v5: newDB,
|
||||
}
|
||||
|
||||
resp, password, err := dbw.NewUser(context.Background(), test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp, test.expectedResp) {
|
||||
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
|
||||
}
|
||||
|
||||
if password != test.req.Password {
|
||||
t.Fatalf("Actual password: %s Expected password: %s", password, test.req.Password)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUser_legacyDB(t *testing.T) {
|
||||
type testCase struct {
|
||||
req newdbplugin.NewUserRequest
|
||||
|
||||
createUserUsername string
|
||||
createUserPassword string
|
||||
createUserErr error
|
||||
createUserCalls int
|
||||
|
||||
expectedResp newdbplugin.NewUserResponse
|
||||
expectedPassword string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"success": {
|
||||
req: newdbplugin.NewUserRequest{
|
||||
Password: "new_password",
|
||||
},
|
||||
|
||||
createUserUsername: "newuser",
|
||||
createUserPassword: "securepassword",
|
||||
createUserCalls: 1,
|
||||
|
||||
expectedResp: newdbplugin.NewUserResponse{
|
||||
Username: "newuser",
|
||||
},
|
||||
expectedPassword: "securepassword",
|
||||
expectErr: false,
|
||||
},
|
||||
"error": {
|
||||
req: newdbplugin.NewUserRequest{
|
||||
Password: "new_password",
|
||||
},
|
||||
|
||||
createUserErr: fmt.Errorf("test error"),
|
||||
createUserCalls: 1,
|
||||
|
||||
expectedResp: newdbplugin.NewUserResponse{},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
legacyDB := new(mockLegacyDatabase)
|
||||
legacyDB.On("CreateUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(test.createUserUsername, test.createUserPassword, test.createUserErr)
|
||||
defer legacyDB.AssertNumberOfCalls(t, "CreateUser", test.createUserCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v4: legacyDB,
|
||||
}
|
||||
|
||||
resp, password, err := dbw.NewUser(context.Background(), test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp, test.expectedResp) {
|
||||
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
|
||||
}
|
||||
|
||||
if password != test.expectedPassword {
|
||||
t.Fatalf("Actual password: %s Expected password: %s", password, test.req.Password)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser_missingDB(t *testing.T) {
|
||||
dbw := databaseVersionWrapper{}
|
||||
|
||||
req := newdbplugin.UpdateUserRequest{}
|
||||
resp, err := dbw.UpdateUser(context.Background(), req, false)
|
||||
if err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
||||
expectedConfig := map[string]interface{}(nil)
|
||||
if !reflect.DeepEqual(resp, expectedConfig) {
|
||||
t.Fatalf("Actual config: %#v\nExpected config: %#v", resp, expectedConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser_newDB(t *testing.T) {
|
||||
type testCase struct {
|
||||
req newdbplugin.UpdateUserRequest
|
||||
|
||||
updateUserErr error
|
||||
updateUserCalls int
|
||||
|
||||
expectedResp newdbplugin.UpdateUserResponse
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"success": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
},
|
||||
updateUserCalls: 1,
|
||||
expectErr: false,
|
||||
},
|
||||
"error": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
},
|
||||
updateUserErr: fmt.Errorf("test error"),
|
||||
updateUserCalls: 1,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
newDB := new(mockNewDatabase)
|
||||
newDB.On("UpdateUser", mock.Anything, mock.Anything).
|
||||
Return(newdbplugin.UpdateUserResponse{}, test.updateUserErr)
|
||||
defer newDB.AssertNumberOfCalls(t, "UpdateUser", test.updateUserCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v5: newDB,
|
||||
}
|
||||
|
||||
_, err := dbw.UpdateUser(context.Background(), test.req, false)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser_legacyDB(t *testing.T) {
|
||||
type testCase struct {
|
||||
req newdbplugin.UpdateUserRequest
|
||||
isRootUser bool
|
||||
|
||||
setCredentialsErr error
|
||||
setCredentialsCalls int
|
||||
|
||||
rotateRootConfig map[string]interface{}
|
||||
rotateRootErr error
|
||||
rotateRootCalls int
|
||||
|
||||
renewUserErr error
|
||||
renewUserCalls int
|
||||
|
||||
expectedConfig map[string]interface{}
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"missing changes": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
},
|
||||
isRootUser: false,
|
||||
|
||||
setCredentialsCalls: 0,
|
||||
rotateRootCalls: 0,
|
||||
renewUserCalls: 0,
|
||||
|
||||
expectErr: true,
|
||||
},
|
||||
"both password and expiration changes": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
Password: &newdbplugin.ChangePassword{},
|
||||
Expiration: &newdbplugin.ChangeExpiration{},
|
||||
},
|
||||
isRootUser: false,
|
||||
|
||||
setCredentialsCalls: 0,
|
||||
rotateRootCalls: 0,
|
||||
renewUserCalls: 0,
|
||||
|
||||
expectErr: true,
|
||||
},
|
||||
"change password - SetCredentials": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: "newpassowrd",
|
||||
},
|
||||
},
|
||||
isRootUser: false,
|
||||
|
||||
setCredentialsErr: nil,
|
||||
setCredentialsCalls: 1,
|
||||
rotateRootCalls: 0,
|
||||
renewUserCalls: 0,
|
||||
|
||||
expectedConfig: nil,
|
||||
expectErr: false,
|
||||
},
|
||||
"change password - SetCredentials failed": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: "newpassowrd",
|
||||
},
|
||||
},
|
||||
isRootUser: false,
|
||||
|
||||
setCredentialsErr: fmt.Errorf("set credentials failed"),
|
||||
setCredentialsCalls: 1,
|
||||
rotateRootCalls: 0,
|
||||
renewUserCalls: 0,
|
||||
|
||||
expectedConfig: nil,
|
||||
expectErr: true,
|
||||
},
|
||||
"change password - SetCredentials unimplemented but not a root user": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: "newpassowrd",
|
||||
},
|
||||
},
|
||||
isRootUser: false,
|
||||
|
||||
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
|
||||
setCredentialsCalls: 1,
|
||||
|
||||
rotateRootCalls: 0,
|
||||
renewUserCalls: 0,
|
||||
|
||||
expectedConfig: nil,
|
||||
expectErr: true,
|
||||
},
|
||||
"change password - RotateRootCredentials": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: "newpassowrd",
|
||||
},
|
||||
},
|
||||
isRootUser: true,
|
||||
|
||||
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
|
||||
setCredentialsCalls: 1,
|
||||
|
||||
rotateRootConfig: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
rotateRootCalls: 1,
|
||||
|
||||
renewUserCalls: 0,
|
||||
|
||||
expectedConfig: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
"change password - RotateRootCredentials failed": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: "newpassowrd",
|
||||
},
|
||||
},
|
||||
isRootUser: true,
|
||||
|
||||
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
|
||||
setCredentialsCalls: 1,
|
||||
|
||||
rotateRootErr: fmt.Errorf("rotate root failed"),
|
||||
rotateRootCalls: 1,
|
||||
renewUserCalls: 0,
|
||||
|
||||
expectedConfig: nil,
|
||||
expectErr: true,
|
||||
},
|
||||
|
||||
"change expiration": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
Expiration: &newdbplugin.ChangeExpiration{
|
||||
NewExpiration: time.Now(),
|
||||
},
|
||||
},
|
||||
isRootUser: false,
|
||||
|
||||
setCredentialsCalls: 0,
|
||||
rotateRootCalls: 0,
|
||||
|
||||
renewUserErr: nil,
|
||||
renewUserCalls: 1,
|
||||
|
||||
expectedConfig: nil,
|
||||
expectErr: false,
|
||||
},
|
||||
"change expiration failed": {
|
||||
req: newdbplugin.UpdateUserRequest{
|
||||
Username: "existing_user",
|
||||
Expiration: &newdbplugin.ChangeExpiration{
|
||||
NewExpiration: time.Now(),
|
||||
},
|
||||
},
|
||||
isRootUser: false,
|
||||
|
||||
setCredentialsCalls: 0,
|
||||
rotateRootCalls: 0,
|
||||
|
||||
renewUserErr: fmt.Errorf("test error"),
|
||||
renewUserCalls: 1,
|
||||
|
||||
expectedConfig: nil,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
legacyDB := new(mockLegacyDatabase)
|
||||
legacyDB.On("SetCredentials", mock.Anything, mock.Anything, mock.Anything).
|
||||
Return("", "", test.setCredentialsErr)
|
||||
defer legacyDB.AssertNumberOfCalls(t, "SetCredentials", test.setCredentialsCalls)
|
||||
|
||||
legacyDB.On("RotateRootCredentials", mock.Anything, mock.Anything).
|
||||
Return(test.rotateRootConfig, test.rotateRootErr)
|
||||
defer legacyDB.AssertNumberOfCalls(t, "RotateRootCredentials", test.rotateRootCalls)
|
||||
|
||||
legacyDB.On("RenewUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(test.renewUserErr)
|
||||
defer legacyDB.AssertNumberOfCalls(t, "RenewUser", test.renewUserCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v4: legacyDB,
|
||||
}
|
||||
|
||||
newConfig, err := dbw.UpdateUser(context.Background(), test.req, test.isRootUser)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(newConfig, test.expectedConfig) {
|
||||
t.Fatalf("Actual config: %#v\nExpected config: %#v", newConfig, test.expectedConfig)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser_missingDB(t *testing.T) {
|
||||
dbw := databaseVersionWrapper{}
|
||||
|
||||
req := newdbplugin.DeleteUserRequest{}
|
||||
_, err := dbw.DeleteUser(context.Background(), req)
|
||||
if err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser_newDB(t *testing.T) {
|
||||
type testCase struct {
|
||||
req newdbplugin.DeleteUserRequest
|
||||
|
||||
deleteUserErr error
|
||||
deleteUserCalls int
|
||||
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"success": {
|
||||
req: newdbplugin.DeleteUserRequest{
|
||||
Username: "existing_user",
|
||||
},
|
||||
|
||||
deleteUserErr: nil,
|
||||
deleteUserCalls: 1,
|
||||
|
||||
expectErr: false,
|
||||
},
|
||||
"error": {
|
||||
req: newdbplugin.DeleteUserRequest{
|
||||
Username: "existing_user",
|
||||
},
|
||||
|
||||
deleteUserErr: fmt.Errorf("test error"),
|
||||
deleteUserCalls: 1,
|
||||
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
newDB := new(mockNewDatabase)
|
||||
newDB.On("DeleteUser", mock.Anything, mock.Anything).
|
||||
Return(newdbplugin.DeleteUserResponse{}, test.deleteUserErr)
|
||||
defer newDB.AssertNumberOfCalls(t, "DeleteUser", test.deleteUserCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v5: newDB,
|
||||
}
|
||||
|
||||
_, err := dbw.DeleteUser(context.Background(), test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser_legacyDB(t *testing.T) {
|
||||
type testCase struct {
|
||||
req newdbplugin.DeleteUserRequest
|
||||
|
||||
revokeUserErr error
|
||||
revokeUserCalls int
|
||||
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"success": {
|
||||
req: newdbplugin.DeleteUserRequest{
|
||||
Username: "existing_user",
|
||||
},
|
||||
|
||||
revokeUserErr: nil,
|
||||
revokeUserCalls: 1,
|
||||
|
||||
expectErr: false,
|
||||
},
|
||||
"error": {
|
||||
req: newdbplugin.DeleteUserRequest{
|
||||
Username: "existing_user",
|
||||
},
|
||||
|
||||
revokeUserErr: fmt.Errorf("test error"),
|
||||
revokeUserCalls: 1,
|
||||
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
legacyDB := new(mockLegacyDatabase)
|
||||
legacyDB.On("RevokeUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(test.revokeUserErr)
|
||||
defer legacyDB.AssertNumberOfCalls(t, "RevokeUser", test.revokeUserCalls)
|
||||
|
||||
dbw := databaseVersionWrapper{
|
||||
v4: legacyDB,
|
||||
}
|
||||
|
||||
_, err := dbw.DeleteUser(context.Background(), test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type badValue struct{}
|
||||
|
||||
func (badValue) MarshalJSON() ([]byte, error) {
|
||||
return nil, fmt.Errorf("this value cannot be marshalled to JSON")
|
||||
}
|
||||
|
||||
var _ logical.Storage = fakeStorage{}
|
||||
|
||||
type fakeStorage struct {
|
||||
putErr error
|
||||
}
|
||||
|
||||
func (f fakeStorage) Put(ctx context.Context, entry *logical.StorageEntry) error {
|
||||
return f.putErr
|
||||
}
|
||||
|
||||
func (f fakeStorage) List(ctx context.Context, s string) ([]string, error) {
|
||||
panic("list not implemented")
|
||||
}
|
||||
func (f fakeStorage) Get(ctx context.Context, s string) (*logical.StorageEntry, error) {
|
||||
panic("get not implemented")
|
||||
}
|
||||
func (f fakeStorage) Delete(ctx context.Context, s string) error {
|
||||
panic("delete not implemented")
|
||||
}
|
||||
|
||||
func TestStoreConfig(t *testing.T) {
|
||||
type testCase struct {
|
||||
config *DatabaseConfig
|
||||
putErr error
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"bad config": {
|
||||
config: &DatabaseConfig{
|
||||
PluginName: "testplugin",
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"bad value": badValue{},
|
||||
},
|
||||
},
|
||||
putErr: nil,
|
||||
expectErr: true,
|
||||
},
|
||||
"storage error": {
|
||||
config: &DatabaseConfig{
|
||||
PluginName: "testplugin",
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
putErr: fmt.Errorf("failed to store config"),
|
||||
expectErr: true,
|
||||
},
|
||||
"happy path": {
|
||||
config: &DatabaseConfig{
|
||||
PluginName: "testplugin",
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
putErr: nil,
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
storage := fakeStorage{
|
||||
putErr: test.putErr,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
err := storeConfig(ctx, storage, "testconfig", test.config)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
317
builtin/logical/database/versioning_large_test.go
Normal file
317
builtin/logical/database/versioning_large_test.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package database
|
||||
|
||||
// This file contains all "large"/expensive tests. These are running requests against a running backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func TestPlugin_lifecycle(t *testing.T) {
|
||||
cluster, sys := getCluster(t)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_MockV4", []string{}, "")
|
||||
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_MockV5", []string{}, "")
|
||||
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
config.System = sys
|
||||
lb, err := Factory(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b, ok := lb.(*databaseBackend)
|
||||
if !ok {
|
||||
t.Fatal("could not convert to database backend")
|
||||
}
|
||||
defer b.Cleanup(context.Background())
|
||||
|
||||
type testCase struct {
|
||||
dbName string
|
||||
dbType string
|
||||
configData map[string]interface{}
|
||||
assertDynamicUsername stringAssertion
|
||||
assertDynamicPassword stringAssertion
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"v4": {
|
||||
dbName: "mockv4",
|
||||
dbType: "mock-v4-database-plugin",
|
||||
configData: map[string]interface{}{
|
||||
"name": "mockv4",
|
||||
"plugin_name": "mock-v4-database-plugin",
|
||||
"connection_url": "sample_connection_url",
|
||||
"verify_connection": true,
|
||||
"allowed_roles": []string{"*"},
|
||||
"username": "mockv4-user",
|
||||
"password": "mysecurepassword",
|
||||
},
|
||||
assertDynamicUsername: assertStringPrefix("mockv4_user_"),
|
||||
assertDynamicPassword: assertStringPrefix("mockv4_"),
|
||||
},
|
||||
"v5": {
|
||||
dbName: "mockv5",
|
||||
dbType: "mock-v5-database-plugin",
|
||||
configData: map[string]interface{}{
|
||||
"connection_url": "sample_connection_url",
|
||||
"plugin_name": "mock-v5-database-plugin",
|
||||
"verify_connection": true,
|
||||
"allowed_roles": []string{"*"},
|
||||
"name": "mockv5",
|
||||
"username": "mockv5-user",
|
||||
"password": "mysecurepassword",
|
||||
},
|
||||
assertDynamicUsername: assertStringPrefix("mockv5_user_"),
|
||||
assertDynamicPassword: assertStringRegex("^[a-zA-Z0-9-]{20}"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
cleanupReqs := []*logical.Request{}
|
||||
defer cleanup(t, b, cleanupReqs)
|
||||
|
||||
// /////////////////////////////////////////////////////////////////
|
||||
// Configure
|
||||
req := &logical.Request{
|
||||
Operation: logical.CreateOperation,
|
||||
Path: fmt.Sprintf("config/%s", test.dbName),
|
||||
Storage: config.StorageView,
|
||||
Data: test.configData,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := b.HandleRequest(ctx, req)
|
||||
assertErrIsNil(t, err)
|
||||
assertRespHasNoErr(t, resp)
|
||||
assertNoRespData(t, resp)
|
||||
|
||||
cleanupReqs = append(cleanupReqs, &logical.Request{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: fmt.Sprintf("config/%s", test.dbName),
|
||||
Storage: config.StorageView,
|
||||
})
|
||||
|
||||
// /////////////////////////////////////////////////////////////////
|
||||
// Rotate root credentials
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: fmt.Sprintf("rotate-root/%s", test.dbName),
|
||||
Storage: config.StorageView,
|
||||
}
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err = b.HandleRequest(ctx, req)
|
||||
assertErrIsNil(t, err)
|
||||
assertRespHasNoErr(t, resp)
|
||||
assertNoRespData(t, resp)
|
||||
|
||||
// /////////////////////////////////////////////////////////////////
|
||||
// Dynamic credentials
|
||||
|
||||
// Create role
|
||||
dynamicRoleName := "dynamic-role"
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: fmt.Sprintf("roles/%s", dynamicRoleName),
|
||||
Storage: config.StorageView,
|
||||
Data: map[string]interface{}{
|
||||
"db_name": test.dbName,
|
||||
"default_ttl": "5s",
|
||||
"max_ttl": "1m",
|
||||
},
|
||||
}
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err = b.HandleRequest(ctx, req)
|
||||
assertErrIsNil(t, err)
|
||||
assertRespHasNoErr(t, resp)
|
||||
assertNoRespData(t, resp)
|
||||
|
||||
cleanupReqs = append(cleanupReqs, &logical.Request{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: fmt.Sprintf("roles/%s", dynamicRoleName),
|
||||
Storage: config.StorageView,
|
||||
})
|
||||
|
||||
// Generate credentials
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: fmt.Sprintf("creds/%s", dynamicRoleName),
|
||||
Storage: config.StorageView,
|
||||
}
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err = b.HandleRequest(ctx, req)
|
||||
assertErrIsNil(t, err)
|
||||
assertRespHasNoErr(t, resp)
|
||||
assertRespHasData(t, resp)
|
||||
|
||||
// TODO: Figure out how to make a call to the cluster that gives back a lease ID
|
||||
// And also rotates the secret out after its TTL
|
||||
|
||||
// /////////////////////////////////////////////////////////////////
|
||||
// Static credentials
|
||||
|
||||
// Create static role
|
||||
staticRoleName := "static-role"
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: fmt.Sprintf("static-roles/%s", staticRoleName),
|
||||
Storage: config.StorageView,
|
||||
Data: map[string]interface{}{
|
||||
"db_name": test.dbName,
|
||||
"username": "static-username",
|
||||
"rotation_period": "5",
|
||||
},
|
||||
}
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err = b.HandleRequest(ctx, req)
|
||||
assertErrIsNil(t, err)
|
||||
assertRespHasNoErr(t, resp)
|
||||
assertNoRespData(t, resp)
|
||||
|
||||
cleanupReqs = append(cleanupReqs, &logical.Request{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: fmt.Sprintf("static-roles/%s", staticRoleName),
|
||||
Storage: config.StorageView,
|
||||
})
|
||||
|
||||
// Get credentials
|
||||
req = &logical.Request{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: fmt.Sprintf("static-creds/%s", staticRoleName),
|
||||
Storage: config.StorageView,
|
||||
}
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err = b.HandleRequest(ctx, req)
|
||||
assertErrIsNil(t, err)
|
||||
assertRespHasNoErr(t, resp)
|
||||
assertRespHasData(t, resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup(t *testing.T, b *databaseBackend, reqs []*logical.Request) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Go in stack order so it works similar to defer
|
||||
for i := len(reqs) - 1; i >= 0; i-- {
|
||||
req := reqs[i]
|
||||
resp, err := b.HandleRequest(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Error cleaning up: %s", err)
|
||||
}
|
||||
if resp != nil && resp.IsError() {
|
||||
t.Fatalf("Error cleaning up: %s", resp.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMain_MockV4(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
|
||||
args := []string{"--ca-cert=" + caPEM}
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
RunV4(apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
func TestBackend_PluginMain_MockV5(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
|
||||
args := []string{"--ca-cert=" + caPEM}
|
||||
|
||||
apiClientMeta := &api.PluginAPIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
|
||||
RunV5(apiClientMeta.GetTLSConfig())
|
||||
}
|
||||
|
||||
func assertNoRespData(t *testing.T, resp *logical.Response) {
|
||||
t.Helper()
|
||||
if resp != nil && len(resp.Data) > 0 {
|
||||
t.Fatalf("Response had data when none was expected: %#v", resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func assertRespHasData(t *testing.T, resp *logical.Response) {
|
||||
t.Helper()
|
||||
if resp == nil || len(resp.Data) == 0 {
|
||||
t.Fatalf("Response didn't have any data when some was expected")
|
||||
}
|
||||
}
|
||||
|
||||
type stringAssertion func(t *testing.T, str string)
|
||||
|
||||
func assertStringPrefix(expectedPrefix string) stringAssertion {
|
||||
return func(t *testing.T, str string) {
|
||||
t.Helper()
|
||||
if !strings.HasPrefix(str, expectedPrefix) {
|
||||
t.Fatalf("Missing prefix '%s': Actual: '%s'", expectedPrefix, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertStringRegex(expectedRegex string) stringAssertion {
|
||||
re := regexp.MustCompile(expectedRegex)
|
||||
return func(t *testing.T, str string) {
|
||||
if !re.MatchString(str) {
|
||||
t.Fatalf("Actual: '%s' did not match regexp '%s'", str, expectedRegex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertRespHasNoErr(t *testing.T, resp *logical.Response) {
|
||||
t.Helper()
|
||||
if resp != nil && resp.IsError() {
|
||||
t.Fatalf("response is error: %#v\n", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func assertErrIsNil(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatalf("No error expected, got: %s", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user