[DBPW 4/X] Update DB engine to support v4 and v5 interfaces with password policies (#9878)

This commit is contained in:
Michael Golowka
2020-09-18 15:10:54 -06:00
committed by GitHub
parent 7c49c094fa
commit 1cd0c0599b
76 changed files with 21485 additions and 424 deletions

View File

@@ -8,12 +8,12 @@ import (
"sync"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/errwrap"
uuid "github.com/hashicorp/go-uuid"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/locksutil"
"github.com/hashicorp/vault/sdk/helper/strutil"
@@ -30,7 +30,7 @@ const (
type dbPluginInstance struct {
sync.RWMutex
dbplugin.Database
database databaseVersionWrapper
id string
name string
@@ -46,7 +46,7 @@ func (dbi *dbPluginInstance) Close() error {
}
dbi.closed = true
return dbi.Database.Close()
return dbi.database.Close()
}
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
@@ -89,7 +89,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
pathListRoles(&b),
pathRoles(&b),
pathCredsCreate(&b),
pathRotateCredentials(&b),
pathRotateRootCredentials(&b),
),
Secrets: []*framework.Secret{
@@ -240,9 +240,9 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
unlockFunc := b.RUnlock
defer func() { unlockFunc() }()
db, ok := b.connections[name]
dbi, ok := b.connections[name]
if ok {
return db, nil
return dbi, nil
}
// Upgrade lock
@@ -250,20 +250,9 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
b.Lock()
unlockFunc = b.Unlock
db, ok = b.connections[name]
dbi, ok = b.connections[name]
if ok {
return db, nil
}
dbp, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
return nil, err
}
_, err = dbp.Init(ctx, config.ConnectionDetails, true)
if err != nil {
dbp.Close()
return nil, err
return dbi, nil
}
id, err := uuid.GenerateUUID()
@@ -271,14 +260,28 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
return nil, err
}
db = &dbPluginInstance{
Database: dbp,
name: name,
id: id,
dbw, err := newDatabaseWrapper(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
return nil, fmt.Errorf("unable to create database instance: %w", err)
}
b.connections[name] = db
return db, nil
initReq := newdbplugin.InitializeRequest{
Config: config.ConnectionDetails,
VerifyConnection: true,
}
_, err = dbw.Initialize(ctx, initReq)
if err != nil {
dbw.Close()
return nil, err
}
dbi = &dbPluginInstance{
database: dbw,
id: id,
name: name,
}
b.connections[name] = dbi
return dbi, nil
}
// invalidateQueue cancels any background queue loading and destroys the queue.

View File

@@ -11,7 +11,7 @@ import (
"time"
"github.com/go-test/deep"
"github.com/hashicorp/vault-plugin-database-mongodbatlas"
mongodbatlas "github.com/hashicorp/vault-plugin-database-mongodbatlas"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/namespace"
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
@@ -231,6 +231,7 @@ func TestBackend_config_connection(t *testing.T) {
},
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
@@ -283,6 +284,7 @@ func TestBackend_config_connection(t *testing.T) {
},
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
@@ -324,6 +326,7 @@ func TestBackend_config_connection(t *testing.T) {
},
"allowed_roles": []string{"flu", "barre"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
@@ -711,6 +714,7 @@ func TestBackend_connectionCrud(t *testing.T) {
},
"allowed_roles": []string{"plugin-role-test"},
"root_credentials_rotate_statements": []string(nil),
"password_policy": "",
}
req.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), req)

View File

@@ -0,0 +1,101 @@
package database
import (
"context"
"time"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/stretchr/testify/mock"
)
var _ newdbplugin.Database = &mockNewDatabase{}
type mockNewDatabase struct {
mock.Mock
}
func (m *mockNewDatabase) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
args := m.Called(ctx, req)
return args.Get(0).(newdbplugin.InitializeResponse), args.Error(1)
}
func (m *mockNewDatabase) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
args := m.Called(ctx, req)
return args.Get(0).(newdbplugin.NewUserResponse), args.Error(1)
}
func (m *mockNewDatabase) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
args := m.Called(ctx, req)
return args.Get(0).(newdbplugin.UpdateUserResponse), args.Error(1)
}
func (m *mockNewDatabase) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
args := m.Called(ctx, req)
return args.Get(0).(newdbplugin.DeleteUserResponse), args.Error(1)
}
func (m *mockNewDatabase) Type() (string, error) {
args := m.Called()
return args.String(0), args.Error(1)
}
func (m *mockNewDatabase) Close() error {
args := m.Called()
return args.Error(0)
}
var _ dbplugin.Database = &mockLegacyDatabase{}
type mockLegacyDatabase struct {
mock.Mock
}
func (m *mockLegacyDatabase) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
args := m.Called(ctx, statements, usernameConfig, expiration)
return args.String(0), args.String(1), args.Error(2)
}
func (m *mockLegacyDatabase) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
args := m.Called(ctx, statements, username, expiration)
return args.Error(0)
}
func (m *mockLegacyDatabase) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
args := m.Called(ctx, statements, username)
return args.Error(0)
}
func (m *mockLegacyDatabase) RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error) {
args := m.Called(ctx, statements)
return args.Get(0).(map[string]interface{}), args.Error(1)
}
func (m *mockLegacyDatabase) GenerateCredentials(ctx context.Context) (string, error) {
args := m.Called(ctx)
return args.String(0), args.Error(1)
}
func (m *mockLegacyDatabase) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticConfig dbplugin.StaticUserConfig) (username string, password string, err error) {
args := m.Called(ctx, statements, staticConfig)
return args.String(0), args.String(1), args.Error(2)
}
func (m *mockLegacyDatabase) Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error) {
args := m.Called(ctx, config, verifyConnection)
return args.Get(0).(map[string]interface{}), args.Error(1)
}
func (m *mockLegacyDatabase) Type() (string, error) {
args := m.Called()
return args.String(0), args.Error(1)
}
func (m *mockLegacyDatabase) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *mockLegacyDatabase) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error) {
panic("Initialize should not be called")
}

View File

@@ -0,0 +1,118 @@
package database
import (
"context"
"fmt"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/dbplugin"
)
const mockV4Type = "mockv4"
// MockDatabaseV4 is an implementation of Database interface
type MockDatabaseV4 struct {
config map[string]interface{}
}
var _ dbplugin.Database = &MockDatabaseV4{}
// New returns a new in-memory instance
func NewV4() (interface{}, error) {
return MockDatabaseV4{}, nil
}
// RunV4 instantiates a MongoDB object, and runs the RPC server for the plugin
func RunV4(apiTLSConfig *api.TLSConfig) error {
dbType, err := NewV4()
if err != nil {
return err
}
dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
return nil
}
func (m MockDatabaseV4) Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error) {
log.Default().Info("Init called",
"config", config,
"verifyConnection", verifyConnection)
return config, nil
}
func (m MockDatabaseV4) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error) {
_, err = m.Init(ctx, config, verifyConnection)
return err
}
func (m MockDatabaseV4) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
log.Default().Info("CreateUser called",
"statements", statements,
"usernameConfig", usernameConfig,
"expiration", expiration)
now := time.Now()
user := fmt.Sprintf("mockv4_user_%s", now.Format(time.RFC3339))
pass, err := m.GenerateCredentials(ctx)
if err != nil {
return "", "", fmt.Errorf("failed to generate credentials: %w", err)
}
return user, pass, nil
}
func (m MockDatabaseV4) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
log.Default().Info("RenewUser called",
"statements", statements,
"username", username,
"expiration", expiration)
return nil
}
func (m MockDatabaseV4) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
log.Default().Info("RevokeUser called",
"statements", statements,
"username", username)
return nil
}
func (m MockDatabaseV4) RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error) {
log.Default().Info("RotateRootCredentials called",
"statements", statements)
newPassword, err := m.GenerateCredentials(ctx)
if err != nil {
return config, fmt.Errorf("failed to generate credentials: %w", err)
}
config["password"] = newPassword
return m.config, nil
}
func (m MockDatabaseV4) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticConfig dbplugin.StaticUserConfig) (username string, password string, err error) {
log.Default().Info("SetCredentials called",
"statements", statements,
"staticConfig", staticConfig)
return "", "", nil
}
func (m MockDatabaseV4) GenerateCredentials(ctx context.Context) (password string, err error) {
now := time.Now()
pass := fmt.Sprintf("mockv4_password_%s", now.Format(time.RFC3339))
return pass, nil
}
func (m MockDatabaseV4) Type() (string, error) {
log.Default().Info("Type called")
return mockV4Type, nil
}
func (m MockDatabaseV4) Close() error {
log.Default().Info("Close called")
return nil
}

View File

@@ -0,0 +1,85 @@
package database
import (
"context"
"fmt"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
)
const mockV5Type = "mockv5"
// MockDatabaseV5 is an implementation of Database interface
type MockDatabaseV5 struct {
config map[string]interface{}
}
var _ newdbplugin.Database = &MockDatabaseV5{}
// New returns a new in-memory instance
func New() (interface{}, error) {
db := MockDatabaseV5{}
return db, nil
}
// Run instantiates a MongoDB object, and runs the RPC server for the plugin
func RunV5(apiTLSConfig *api.TLSConfig) error {
dbType, err := New()
if err != nil {
return err
}
newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
return nil
}
func (m MockDatabaseV5) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
log.Default().Info("Initialize called",
"req", req)
config := req.Config
config["from-plugin"] = "this value is from the plugin itself"
resp := newdbplugin.InitializeResponse{
Config: req.Config,
}
return resp, nil
}
func (m MockDatabaseV5) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
log.Default().Info("NewUser called",
"req", req)
now := time.Now()
user := fmt.Sprintf("mockv5_user_%s", now.Format(time.RFC3339))
resp := newdbplugin.NewUserResponse{
Username: user,
}
return resp, nil
}
func (m MockDatabaseV5) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
log.Default().Info("UpdateUser called",
"req", req)
return newdbplugin.UpdateUserResponse{}, nil
}
func (m MockDatabaseV5) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
log.Default().Info("DeleteUser called",
"req", req)
return newdbplugin.DeleteUserResponse{}, nil
}
func (m MockDatabaseV5) Type() (string, error) {
log.Default().Info("Type called")
return mockV5Type, nil
}
func (m MockDatabaseV5) Close() error {
log.Default().Info("Close called")
return nil
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/fatih/structs"
"github.com/hashicorp/errwrap"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -30,6 +30,8 @@ type DatabaseConfig struct {
AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"`
RootCredentialsRotateStatements []string `json:"root_credentials_rotate_statements" structs:"root_credentials_rotate_statements" mapstructure:"root_credentials_rotate_statements"`
PasswordPolicy string `json:"password_policy" structs:"password_policy" mapstructure:"password_policy"`
}
// pathResetConnection configures a path to reset a plugin.
@@ -114,6 +116,10 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
page for more information on support and formatting for this
parameter.`,
},
"password_policy": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Password policy to use when generating passwords.`,
},
},
ExistenceCheck: b.connectionExistenceCheck(),
@@ -138,7 +144,7 @@ func (b *databaseBackend) connectionExistenceCheck() framework.ExistenceFunc {
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return false, errors.New("failed to read connection configuration")
return false, fmt.Errorf("failed to read connection configuration: %w", err)
}
return entry != nil, nil
@@ -179,7 +185,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, errors.New("failed to read connection configuration")
return nil, fmt.Errorf("failed to read connection configuration: %w", err)
}
if entry == nil {
return nil, nil
@@ -245,7 +251,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, errors.New("failed to read connection configuration")
return nil, fmt.Errorf("failed to read connection configuration: %w", err)
}
if entry != nil {
if err := entry.DecodeJSON(config); err != nil {
@@ -274,6 +280,10 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
config.RootCredentialsRotateStatements = data.Get("root_rotation_statements").([]string)
}
if passwordPolicyRaw, ok := data.GetOk("password_policy"); ok {
config.PasswordPolicy = passwordPolicyRaw.(string)
}
// Remove these entries from the data before we store it keyed under
// ConnectionDetails.
delete(data.Raw, "name")
@@ -281,11 +291,11 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
delete(data.Raw, "allowed_roles")
delete(data.Raw, "verify_connection")
delete(data.Raw, "root_rotation_statements")
delete(data.Raw, "password_policy")
// Create a database plugin and initialize it.
db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
id, err := uuid.GenerateUUID()
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
return nil, err
}
// If this is an update, take any new values, overwrite what was there
@@ -302,37 +312,39 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
}
}
config.ConnectionDetails, err = db.Init(ctx, config.ConnectionDetails, verifyConnection)
// Create a database plugin and initialize it.
dbw, err := newDatabaseWrapper(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
db.Close()
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
return logical.ErrorResponse("error creating database object: %s", err), nil
}
initReq := newdbplugin.InitializeRequest{
Config: config.ConnectionDetails,
VerifyConnection: verifyConnection,
}
initResp, err := dbw.Initialize(ctx, initReq)
if err != nil {
dbw.Close()
return logical.ErrorResponse("error creating database object: %s", err), nil
}
config.ConnectionDetails = initResp.Config
b.Lock()
defer b.Unlock()
// Close and remove the old connection
b.clearConnection(name)
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
b.connections[name] = &dbPluginInstance{
Database: db,
database: dbw,
name: name,
id: id,
}
// Store it
entry, err = logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
err = storeConfig(ctx, req.Storage, name, config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
resp := &logical.Response{}
@@ -346,10 +358,30 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
}
}
// If using a legacy DB plugin and set the `password_policy` field, send a warning to the user indicating
// the `password_policy` will not be used
if dbw.isV4() && config.PasswordPolicy != "" {
resp.AddWarning(fmt.Sprintf("%s does not support password policies - upgrade to the latest version of "+
"Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName))
}
return resp, nil
}
}
func storeConfig(ctx context.Context, storage logical.Storage, name string, config *DatabaseConfig) error {
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
if err != nil {
return fmt.Errorf("unable to marshal object to JSON: %w", err)
}
err = storage.Put(ctx, entry)
if err != nil {
return fmt.Errorf("failed to save object: %w", err)
}
return nil
}
const pathConfigConnectionHelpSyn = `
Configure connection details to a database plugin.
`

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
@@ -73,13 +73,13 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
}
// Get the Database object
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
dbi, err := b.GetConnection(ctx, req.Storage, role.DBName)
if err != nil {
return nil, err
}
db.RLock()
defer db.RUnlock()
dbi.RLock()
defer dbi.RUnlock()
ttl, _, err := framework.CalculateTTL(b.System(), 0, role.DefaultTTL, 0, role.MaxTTL, 0, time.Time{})
if err != nil {
@@ -90,27 +90,44 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
// to ensure the database credential does not expire before the lease
expiration = expiration.Add(5 * time.Second)
usernameConfig := dbplugin.UsernameConfig{
DisplayName: req.DisplayName,
RoleName: name,
password, err := dbi.database.GeneratePassword(ctx, b.System(), dbConfig.PasswordPolicy)
if err != nil {
return nil, fmt.Errorf("unable to generate password: %w", err)
}
// Create the user
username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration)
newUserReq := newdbplugin.NewUserRequest{
UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: req.DisplayName,
RoleName: name,
},
Statements: newdbplugin.Statements{
Commands: role.Statements.Creation,
},
RollbackStatements: newdbplugin.Statements{
Commands: role.Statements.Rollback,
},
Password: password,
Expiration: expiration,
}
// Overwriting the password in the event this is a legacy database plugin and the provided password is ignored
newUserResp, password, err := dbi.database.NewUser(ctx, newUserReq)
if err != nil {
b.CloseIfShutdown(db, err)
b.CloseIfShutdown(dbi, err)
return nil, err
}
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username,
respData := map[string]interface{}{
"username": newUserResp.Username,
"password": password,
}, map[string]interface{}{
"username": username,
}
internal := map[string]interface{}{
"username": newUserResp.Username,
"role": name,
"db_name": role.DBName,
"revocation_statements": role.Statements.Revocation,
})
}
resp := b.Secret(SecretCredsType).Response(respData, internal)
resp.Secret.TTL = role.DefaultTTL
resp.Secret.MaxTTL = role.MaxTTL
return resp, nil

View File

@@ -5,16 +5,13 @@ import (
"fmt"
"time"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/queue"
)
func pathRotateCredentials(b *databaseBackend) []*framework.Path {
func pathRotateRootCredentials(b *databaseBackend) []*framework.Path {
return []*framework.Path{
&framework.Path{
Pattern: "rotate-root/" + framework.GenericNameRegex("name"),
@@ -27,7 +24,7 @@ func pathRotateCredentials(b *databaseBackend) []*framework.Path {
Operations: map[logical.Operation]framework.OperationHandler{
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathRotateCredentialsUpdate(),
Callback: b.pathRotateRootCredentialsUpdate(),
ForwardPerformanceSecondary: true,
ForwardPerformanceStandby: true,
},
@@ -59,7 +56,7 @@ func pathRotateCredentials(b *databaseBackend) []*framework.Path {
}
}
func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc {
func (b *databaseBackend) pathRotateRootCredentialsUpdate() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
@@ -71,15 +68,15 @@ func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc
return nil, err
}
db, err := b.GetConnection(ctx, req.Storage, name)
dbi, err := b.GetConnection(ctx, req.Storage, name)
if err != nil {
return nil, err
}
defer func() {
// Close the plugin
db.closed = true
if err := db.Database.Close(); err != nil {
dbi.closed = true
if err := dbi.database.Close(); err != nil {
b.Logger().Error("error closing the database plugin connection", "err", err)
}
// Even on error, still remove the connection
@@ -91,13 +88,13 @@ func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc
defer b.Unlock()
// Take the write lock on the instance
db.Lock()
defer db.Unlock()
dbi.Lock()
defer dbi.Unlock()
// Generate new credentials
userName := config.ConnectionDetails["username"].(string)
username := config.ConnectionDetails["username"].(string)
oldPassword := config.ConnectionDetails["password"].(string)
newPassword, err := db.GenerateCredentials(ctx)
newPassword, err := dbi.database.GeneratePassword(ctx, b.System(), config.PasswordPolicy)
if err != nil {
return nil, err
}
@@ -106,7 +103,7 @@ func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc
// Write a WAL entry
walID, err := framework.PutWAL(ctx, req.Storage, rotateRootWALKey, &rotateRootCredentialsWAL{
ConnectionName: name,
UserName: userName,
UserName: username,
OldPassword: oldPassword,
NewPassword: newPassword,
})
@@ -114,37 +111,32 @@ func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc
return nil, err
}
// Attempt to use SetCredentials for the root credential rotation
statements := dbplugin.Statements{Rotation: config.RootCredentialsRotateStatements}
userConfig := dbplugin.StaticUserConfig{
Username: userName,
Password: newPassword,
updateReq := newdbplugin.UpdateUserRequest{
Username: username,
Password: &newdbplugin.ChangePassword{
NewPassword: newPassword,
Statements: newdbplugin.Statements{
Commands: config.RootCredentialsRotateStatements,
},
},
}
if _, _, err := db.SetCredentials(ctx, statements, userConfig); err != nil {
if status.Code(err) == codes.Unimplemented {
// Fall back to using RotateRootCredentials if unimplemented
config.ConnectionDetails, err = db.RotateRootCredentials(ctx,
config.RootCredentialsRotateStatements)
}
if err != nil {
return nil, err
}
newConfigDetails, err := dbi.database.UpdateUser(ctx, updateReq, true)
if err != nil {
return nil, fmt.Errorf("failed to update user: %w", err)
}
if newConfigDetails != nil {
config.ConnectionDetails = newConfigDetails
}
// Update storage with the new root credentials
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
err = storeConfig(ctx, req.Storage, name, config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
// Delete the WAL entry after successfully rotating root credentials
if err := framework.DeleteWAL(ctx, req.Storage, walID); err != nil {
err = framework.DeleteWAL(ctx, req.Storage, walID)
if err != nil {
b.Logger().Warn("unable to delete WAL", "error", err, "WAL ID", walID)
}
return nil, nil
}
}

View File

@@ -4,7 +4,7 @@ import (
"context"
"errors"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/logical"
"github.com/mitchellh/mapstructure"
"google.golang.org/grpc/codes"
@@ -73,13 +73,13 @@ func (b *databaseBackend) walRollback(ctx context.Context, req *logical.Request,
}
// rollbackDatabaseCredentials rolls back root database credentials for
// the connection associated with the passed WAL entry. It will creates
// the connection associated with the passed WAL entry. It will create
// a connection to the database using the WAL entry new password in
// order to alter the password to be the WAL entry old password.
func (b *databaseBackend) rollbackDatabaseCredentials(ctx context.Context, config *DatabaseConfig, entry rotateRootCredentialsWAL) error {
// Attempt to get a connection with the WAL entry new password.
config.ConnectionDetails["password"] = entry.NewPassword
dbc, err := b.GetConnectionWithConfig(ctx, entry.ConnectionName, config)
dbi, err := b.GetConnectionWithConfig(ctx, entry.ConnectionName, config)
if err != nil {
return err
}
@@ -91,22 +91,21 @@ func (b *databaseBackend) rollbackDatabaseCredentials(ctx context.Context, confi
}
}()
// Roll back the database password to the WAL entry old password
statements := dbplugin.Statements{Rotation: config.RootCredentialsRotateStatements}
userConfig := dbplugin.StaticUserConfig{
updateReq := newdbplugin.UpdateUserRequest{
Username: entry.UserName,
Password: entry.OldPassword,
}
if _, _, err := dbc.SetCredentials(ctx, statements, userConfig); err != nil {
// If the database plugin doesn't implement SetCredentials, the root
// credentials can't be rolled back. This means the root credential
// rotation happened via the plugin RotateRootCredentials RPC.
if status.Code(err) == codes.Unimplemented {
return nil
}
return err
Password: &newdbplugin.ChangePassword{
NewPassword: entry.OldPassword,
Statements: newdbplugin.Statements{
Commands: config.RootCredentialsRotateStatements,
},
},
}
return nil
// It actually is the root user here, but we only want to use SetCredentials since
// RotateRootCredentials doesn't give any control over what password is used
_, err = dbi.database.UpdateUser(ctx, updateReq, false)
if status.Code(err) == codes.Unimplemented {
return nil
}
return err
}

View File

@@ -4,10 +4,11 @@ import (
"context"
"strings"
"testing"
"time"
"github.com/hashicorp/vault/helper/namespace"
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -92,18 +93,22 @@ func TestBackend_RotateRootCredentials_WAL_rollback(t *testing.T) {
}
// Get a connection to the database plugin
pc, err := dbBackend.GetConnection(context.Background(),
dbi, err := dbBackend.GetConnection(context.Background(),
config.StorageView, "plugin-test")
if err != nil {
t.Fatal(err)
}
// Alter the database password so it no longer matches what is in storage
_, _, err = pc.SetCredentials(context.Background(), dbplugin.Statements{},
dbplugin.StaticUserConfig{
Username: databaseUser,
Password: "newSecret",
})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
updateReq := newdbplugin.UpdateUserRequest{
Username: databaseUser,
Password: &newdbplugin.ChangePassword{
NewPassword: "newSecret",
},
}
_, err = dbi.database.UpdateUser(ctx, updateReq, false)
if err != nil {
t.Fatal(err)
}
@@ -335,17 +340,21 @@ func TestBackend_RotateRootCredentials_WAL_no_rollback_2(t *testing.T) {
}
// Get a connection to the database plugin
pc, err := dbBackend.GetConnection(context.Background(), config.StorageView, "plugin-test")
dbi, err := dbBackend.GetConnection(context.Background(), config.StorageView, "plugin-test")
if err != nil {
t.Fatal(err)
}
// Alter the database password
_, _, err = pc.SetCredentials(context.Background(), dbplugin.Statements{},
dbplugin.StaticUserConfig{
Username: databaseUser,
Password: "newSecret",
})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
updateReq := newdbplugin.UpdateUserRequest{
Username: databaseUser,
Password: &newdbplugin.ChangePassword{
NewPassword: "newSecret",
},
}
_, err = dbi.database.UpdateUser(ctx, updateReq, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/locksutil"
@@ -319,21 +320,20 @@ func (b *databaseBackend) setStaticAccount(ctx context.Context, s logical.Storag
}
// Get the Database object
db, err := b.GetConnection(ctx, s, input.Role.DBName)
dbi, err := b.GetConnection(ctx, s, input.Role.DBName)
if err != nil {
return output, err
}
db.RLock()
defer db.RUnlock()
dbi.RLock()
defer dbi.RUnlock()
// Use password from input if available. This happens if we're restoring from
// a WAL item or processing the rotation queue with an item that has a WAL
// associated with it
newPassword := input.Password
if newPassword == "" {
// Generate a new password
newPassword, err = db.GenerateCredentials(ctx)
newPassword, err = dbi.database.GeneratePassword(ctx, b.System(), dbConfig.PasswordPolicy)
if err != nil {
return output, err
}
@@ -358,21 +358,26 @@ func (b *databaseBackend) setStaticAccount(ctx context.Context, s logical.Storag
}
}
_, password, err := db.SetCredentials(ctx, input.Role.Statements, config)
if err != nil {
b.CloseIfShutdown(db, err)
return output, errwrap.Wrapf("error setting credentials: {{err}}", err)
updateReq := newdbplugin.UpdateUserRequest{
Username: input.Role.StaticAccount.Username,
Password: &newdbplugin.ChangePassword{
NewPassword: newPassword,
Statements: newdbplugin.Statements{
Commands: input.Role.Statements.Rotation,
},
},
}
if newPassword != password {
return output, errors.New("mismatch passwords returned")
_, err = dbi.database.UpdateUser(ctx, updateReq, false)
if err != nil {
b.CloseIfShutdown(dbi, err)
return output, errwrap.Wrapf("error setting credentials: {{err}}", err)
}
// Store updated role information
// lvr is the known LastVaultRotation
lvr := time.Now()
input.Role.StaticAccount.LastVaultRotation = lvr
input.Role.StaticAccount.Password = password
input.Role.StaticAccount.Password = newPassword
output.RotationTime = lvr
entry, err := logical.StorageEntryJSON(databaseStaticRolePath+input.RoleName, input.Role)
@@ -393,7 +398,7 @@ func (b *databaseBackend) setStaticAccount(ctx context.Context, s logical.Storag
return &setStaticAccountOutput{RotationTime: lvr}, merr
}
// initQueue preforms the necessary checks and initializations needed to preform
// initQueue preforms the necessary checks and initializations needed to perform
// automatic credential rotation for roles associated with static accounts. This
// method verifies if a queue is needed (primary server or local mount), and if
// so initializes the queue and launches a go-routine to periodically invoke a

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -45,13 +46,13 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
}
// Get the Database object
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
dbi, err := b.GetConnection(ctx, req.Storage, role.DBName)
if err != nil {
return nil, err
}
db.RLock()
defer db.RUnlock()
dbi.RLock()
defer dbi.RUnlock()
// Make sure we increase the VALID UNTIL endpoint for this user.
ttl, _, err := framework.CalculateTTL(b.System(), req.Secret.Increment, role.DefaultTTL, 0, role.MaxTTL, 0, req.Secret.IssueTime)
@@ -63,9 +64,19 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
// Adding a small buffer since the TTL will be calculated again after this call
// to ensure the database credential does not expire before the lease
expireTime = expireTime.Add(5 * time.Second)
err := db.RenewUser(ctx, role.Statements, username, expireTime)
updateReq := newdbplugin.UpdateUserRequest{
Username: username,
Expiration: &newdbplugin.ChangeExpiration{
NewExpiration: expireTime,
Statements: newdbplugin.Statements{
Commands: role.Statements.Renewal,
},
},
}
_, err := dbi.database.UpdateUser(ctx, updateReq, false)
if err != nil {
b.CloseIfShutdown(db, err)
b.CloseIfShutdown(dbi, err)
return nil, err
}
}
@@ -103,41 +114,49 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
dbName = role.DBName
statements = role.Statements
} else {
if dbNameRaw, ok := req.Secret.InternalData["db_name"]; !ok {
dbNameRaw, ok := req.Secret.InternalData["db_name"]
if !ok {
return nil, fmt.Errorf("error during revoke: could not find role with name %q or embedded revocation db name data", req.Secret.InternalData["role"])
} else {
dbName = dbNameRaw.(string)
}
if statementsRaw, ok := req.Secret.InternalData["revocation_statements"]; !ok {
dbName = dbNameRaw.(string)
statementsRaw, ok := req.Secret.InternalData["revocation_statements"]
if !ok {
return nil, fmt.Errorf("error during revoke: could not find role with name %q or embedded revocation statement data", req.Secret.InternalData["role"])
} else {
// If we don't actually have any statements, because none were
// set in the role, we'll end up with an empty one and the
// default for the db type will be attempted
if statementsRaw != nil {
statementsSlice, ok := statementsRaw.([]interface{})
if !ok {
return nil, fmt.Errorf("error during revoke: could not find role with name %q and embedded reovcation data could not be read", req.Secret.InternalData["role"])
} else {
for _, v := range statementsSlice {
statements.Revocation = append(statements.Revocation, v.(string))
}
}
}
// If we don't actually have any statements, because none were
// set in the role, we'll end up with an empty one and the
// default for the db type will be attempted
if statementsRaw != nil {
statementsSlice, ok := statementsRaw.([]interface{})
if !ok {
return nil, fmt.Errorf("error during revoke: could not find role with name %q and embedded reovcation data could not be read", req.Secret.InternalData["role"])
}
for _, v := range statementsSlice {
statements.Revocation = append(statements.Revocation, v.(string))
}
}
}
// Get our connection
db, err := b.GetConnection(ctx, req.Storage, dbName)
dbi, err := b.GetConnection(ctx, req.Storage, dbName)
if err != nil {
return nil, err
}
db.RLock()
defer db.RUnlock()
dbi.RLock()
defer dbi.RUnlock()
if err := db.RevokeUser(ctx, statements, username); err != nil {
b.CloseIfShutdown(db, err)
deleteReq := newdbplugin.DeleteUserRequest{
Username: username,
Statements: newdbplugin.Statements{
Commands: statements.Revocation,
},
}
_, err = dbi.database.DeleteUser(ctx, deleteReq)
if err != nil {
b.CloseIfShutdown(dbi, err)
return nil, err
}
return resp, nil

View File

@@ -0,0 +1,268 @@
package database
import (
"context"
"crypto/rand"
"fmt"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/random"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type databaseVersionWrapper struct {
v4 dbplugin.Database
v5 newdbplugin.Database
}
// newDatabaseWrapper figures out which version of the database the pluginName is referring to and returns a wrapper object
// that can be used to make operations on the underlying database plugin.
func newDatabaseWrapper(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (dbw databaseVersionWrapper, err error) {
newDB, err := newdbplugin.PluginFactory(ctx, pluginName, sys, logger)
if err == nil {
dbw = databaseVersionWrapper{
v5: newDB,
}
return dbw, nil
}
legacyDB, err := dbplugin.PluginFactory(ctx, pluginName, sys, logger)
if err == nil {
dbw = databaseVersionWrapper{
v4: legacyDB,
}
return dbw, nil
}
return dbw, fmt.Errorf("invalid database version")
}
// Initialize the underlying database. This is analogous to a constructor on the database plugin object.
// Errors if the wrapper does not contain an underlying database.
func (d databaseVersionWrapper) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
if !d.isV5() && !d.isV4() {
return newdbplugin.InitializeResponse{}, fmt.Errorf("no underlying database specified")
}
// v5 Database
if d.isV5() {
return d.v5.Initialize(ctx, req)
}
// v4 Database
saveConfig, err := d.v4.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return newdbplugin.InitializeResponse{}, err
}
resp := newdbplugin.InitializeResponse{
Config: saveConfig,
}
return resp, nil
}
// NewUser in the database. This is different from the v5 Database in that it returns a password as well.
// This is done because the v4 Database is expected to generate a password and return it. The NewUserResponse
// does not have a way of returning the password so this function signature needs to be different.
// The password returned here should be considered the source of truth, not the provided password.
// Errors if the wrapper does not contain an underlying database.
func (d databaseVersionWrapper) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (resp newdbplugin.NewUserResponse, password string, err error) {
if !d.isV5() && !d.isV4() {
return newdbplugin.NewUserResponse{}, "", fmt.Errorf("no underlying database specified")
}
// v5 Database
if d.isV5() {
resp, err = d.v5.NewUser(ctx, req)
return resp, req.Password, err
}
// v4 Database
stmts := dbplugin.Statements{
Creation: req.Statements.Commands,
Rollback: req.RollbackStatements.Commands,
}
usernameConfig := dbplugin.UsernameConfig{
DisplayName: req.UsernameConfig.DisplayName,
RoleName: req.UsernameConfig.RoleName,
}
username, password, err := d.v4.CreateUser(ctx, stmts, usernameConfig, req.Expiration)
if err != nil {
return resp, "", err
}
resp = newdbplugin.NewUserResponse{
Username: username,
}
return resp, password, nil
}
// UpdateUser in the underlying database. This is used to update any information currently supported
// in the UpdateUserRequest such as password credentials or user TTL.
// Errors if the wrapper does not contain an underlying database.
func (d databaseVersionWrapper) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest, isRootUser bool) (saveConfig map[string]interface{}, err error) {
if !d.isV5() && !d.isV4() {
return nil, fmt.Errorf("no underlying database specified")
}
// v5 Database
if d.isV5() {
_, err := d.v5.UpdateUser(ctx, req)
return nil, err
}
// v4 Database
if req.Password == nil && req.Expiration == nil {
return nil, fmt.Errorf("missing change to be sent to the database")
}
if req.Password != nil && req.Expiration != nil {
// We could support this, but it would require handling partial
// errors which I'm punting on since we don't need it for now
return nil, fmt.Errorf("cannot specify both password and expiration change at the same time")
}
// Change password
if req.Password != nil {
return d.changePasswordLegacy(ctx, req.Username, req.Password, isRootUser)
}
// Change expiration date
if req.Expiration != nil {
stmts := dbplugin.Statements{
Renewal: req.Expiration.Statements.Commands,
}
err := d.v4.RenewUser(ctx, stmts, req.Username, req.Expiration.NewExpiration)
return nil, err
}
return nil, nil
}
// changePasswordLegacy attempts to use SetCredentials to change the password for the user with the password provided
// in ChangePassword. If that user is the root user and SetCredentials is unimplemented, it will fall back to using
// RotateRootCredentials. If not the root user, this will not use RotateRootCredentials.
func (d databaseVersionWrapper) changePasswordLegacy(ctx context.Context, username string, passwordChange *newdbplugin.ChangePassword, isRootUser bool) (saveConfig map[string]interface{}, err error) {
err = d.changeUserPasswordLegacy(ctx, username, passwordChange)
// If changing the root user's password but SetCredentials is unimplemented, fall back to RotateRootCredentials
if isRootUser && status.Code(err) == codes.Unimplemented {
saveConfig, err = d.changeRootUserPasswordLegacy(ctx, passwordChange)
if err != nil {
return nil, err
}
return saveConfig, nil
}
if err != nil {
return nil, err
}
return nil, nil
}
func (d databaseVersionWrapper) changeUserPasswordLegacy(ctx context.Context, username string, passwordChange *newdbplugin.ChangePassword) (err error) {
stmts := dbplugin.Statements{
Rotation: passwordChange.Statements.Commands,
}
staticConfig := dbplugin.StaticUserConfig{
Username: username,
Password: passwordChange.NewPassword,
}
_, _, err = d.v4.SetCredentials(ctx, stmts, staticConfig)
return err
}
func (d databaseVersionWrapper) changeRootUserPasswordLegacy(ctx context.Context, passwordChange *newdbplugin.ChangePassword) (saveConfig map[string]interface{}, err error) {
return d.v4.RotateRootCredentials(ctx, passwordChange.Statements.Commands)
}
// DeleteUser in the underlying database. Errors if the wrapper does not contain an underlying database.
func (d databaseVersionWrapper) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
if !d.isV5() && !d.isV4() {
return newdbplugin.DeleteUserResponse{}, fmt.Errorf("no underlying database specified")
}
// v5 Database
if d.isV5() {
return d.v5.DeleteUser(ctx, req)
}
// v4 Database
stmts := dbplugin.Statements{
Revocation: req.Statements.Commands,
}
err := d.v4.RevokeUser(ctx, stmts, req.Username)
return newdbplugin.DeleteUserResponse{}, err
}
// Type of the underlying database. Errors if the wrapper does not contain an underlying database.
func (d databaseVersionWrapper) Type() (string, error) {
if !d.isV5() && !d.isV4() {
return "", fmt.Errorf("no underlying database specified")
}
// v5 Database
if d.isV5() {
return d.v5.Type()
}
// v4 Database
return d.v4.Type()
}
// Close the underlying database. Errors if the wrapper does not contain an underlying database.
func (d databaseVersionWrapper) Close() error {
if !d.isV5() && !d.isV4() {
return fmt.Errorf("no underlying database specified")
}
// v5 Database
if d.isV5() {
return d.v5.Close()
}
// v4 Database
return d.v4.Close()
}
// /////////////////////////////////////////////////////////////////////////////////
// Password generation
// /////////////////////////////////////////////////////////////////////////////////
type passwordGenerator interface {
GeneratePasswordFromPolicy(ctx context.Context, policyName string) (password string, err error)
}
var (
defaultPasswordGenerator = random.DefaultStringGenerator
)
// GeneratePassword either from the v4 database or by using the provided password policy. If using a v5 database
// and no password policy is specified, this will have a reasonable default password generator.
func (d databaseVersionWrapper) GeneratePassword(ctx context.Context, generator passwordGenerator, passwordPolicy string) (password string, err error) {
if !d.isV5() && !d.isV4() {
return "", fmt.Errorf("no underlying database specified")
}
// If using the legacy database, use GenerateCredentials instead of password policies
// This will keep the existing behavior even though passwords can be generated with a policy
if d.isV4() {
password, err := d.v4.GenerateCredentials(ctx)
if err != nil {
return "", err
}
return password, nil
}
if passwordPolicy == "" {
return defaultPasswordGenerator.Generate(ctx, rand.Reader)
}
return generator.GeneratePasswordFromPolicy(ctx, passwordPolicy)
}
func (d databaseVersionWrapper) isV5() bool {
return d.v5 != nil
}
func (d databaseVersionWrapper) isV4() bool {
return d.v4 != nil
}

View File

@@ -0,0 +1,994 @@
package database
import (
"context"
"fmt"
"reflect"
"regexp"
"testing"
"time"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestInitDatabase_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
req := newdbplugin.InitializeRequest{}
resp, err := dbw.Initialize(context.Background(), req)
if err == nil {
t.Fatalf("err expected, got nil")
}
expectedResp := newdbplugin.InitializeResponse{}
if !reflect.DeepEqual(resp, expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, expectedResp)
}
}
func TestInitDatabase_newDB(t *testing.T) {
type testCase struct {
req newdbplugin.InitializeRequest
newInitResp newdbplugin.InitializeResponse
newInitErr error
newInitCalls int
expectedResp newdbplugin.InitializeResponse
expectErr bool
}
tests := map[string]testCase{
"success": {
req: newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
VerifyConnection: true,
},
newInitResp: newdbplugin.InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
newInitCalls: 1,
expectedResp: newdbplugin.InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
expectErr: false,
},
"error": {
req: newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
VerifyConnection: true,
},
newInitResp: newdbplugin.InitializeResponse{},
newInitErr: fmt.Errorf("test error"),
newInitCalls: 1,
expectedResp: newdbplugin.InitializeResponse{},
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
newDB.On("Initialize", mock.Anything, mock.Anything).
Return(test.newInitResp, test.newInitErr)
defer newDB.AssertNumberOfCalls(t, "Initialize", test.newInitCalls)
dbw := databaseVersionWrapper{
v5: newDB,
}
resp, err := dbw.Initialize(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
}
})
}
}
func TestInitDatabase_legacyDB(t *testing.T) {
type testCase struct {
req newdbplugin.InitializeRequest
initConfig map[string]interface{}
initErr error
initCalls int
expectedResp newdbplugin.InitializeResponse
expectErr bool
}
tests := map[string]testCase{
"success": {
req: newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
VerifyConnection: true,
},
initConfig: map[string]interface{}{
"foo": "bar",
},
initCalls: 1,
expectedResp: newdbplugin.InitializeResponse{
Config: map[string]interface{}{
"foo": "bar",
},
},
expectErr: false,
},
"error": {
req: newdbplugin.InitializeRequest{
Config: map[string]interface{}{
"foo": "bar",
},
VerifyConnection: true,
},
initErr: fmt.Errorf("test error"),
initCalls: 1,
expectedResp: newdbplugin.InitializeResponse{},
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("Init", mock.Anything, mock.Anything, mock.Anything).
Return(test.initConfig, test.initErr)
defer legacyDB.AssertNumberOfCalls(t, "Init", test.initCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
resp, err := dbw.Initialize(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
}
})
}
}
type fakePasswordGenerator struct {
password string
err error
}
func (pg fakePasswordGenerator) GeneratePasswordFromPolicy(ctx context.Context, policy string) (string, error) {
return pg.password, pg.err
}
func TestGeneratePassword_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
gen := fakePasswordGenerator{
err: fmt.Errorf("this shouldn't be called"),
}
pass, err := dbw.GeneratePassword(context.Background(), gen, "policy")
if err == nil {
t.Fatalf("err expected, got nil")
}
if pass != "" {
t.Fatalf("Password should be empty but was: %s", pass)
}
}
func TestGeneratePassword_legacy(t *testing.T) {
type testCase struct {
legacyPassword string
legacyErr error
legacyCalls int
expectedPassword string
expectErr bool
}
tests := map[string]testCase{
"legacy password generation": {
legacyPassword: "legacy_password",
legacyErr: nil,
legacyCalls: 1,
expectedPassword: "legacy_password",
expectErr: false,
},
"legacy password failure": {
legacyPassword: "",
legacyErr: fmt.Errorf("failed :("),
legacyCalls: 1,
expectedPassword: "",
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("GenerateCredentials", mock.Anything).
Return(test.legacyPassword, test.legacyErr)
defer legacyDB.AssertNumberOfCalls(t, "GenerateCredentials", test.legacyCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
passGen := fakePasswordGenerator{
err: fmt.Errorf("this should not be called"),
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
password, err := dbw.GeneratePassword(ctx, passGen, "test_policy")
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if password != test.expectedPassword {
t.Fatalf("Actual password: %s Expected password: %s", password, test.expectedPassword)
}
})
}
}
func TestGeneratePassword_policies(t *testing.T) {
type testCase struct {
passwordPolicyPassword string
passwordPolicyErr error
expectedPassword string
expectErr bool
}
tests := map[string]testCase{
"password policy generation": {
passwordPolicyPassword: "new_password",
expectedPassword: "new_password",
expectErr: false,
},
"password policy error": {
passwordPolicyPassword: "",
passwordPolicyErr: fmt.Errorf("test error"),
expectedPassword: "",
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
defer newDB.AssertExpectations(t)
dbw := databaseVersionWrapper{
v5: newDB,
}
passGen := fakePasswordGenerator{
password: test.passwordPolicyPassword,
err: test.passwordPolicyErr,
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
password, err := dbw.GeneratePassword(ctx, passGen, "test_policy")
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if password != test.expectedPassword {
t.Fatalf("Actual password: %s Expected password: %s", password, test.expectedPassword)
}
})
}
}
func TestGeneratePassword_no_policy(t *testing.T) {
newDB := new(mockNewDatabase)
defer newDB.AssertExpectations(t)
dbw := databaseVersionWrapper{
v5: newDB,
}
passGen := fakePasswordGenerator{
password: "",
err: fmt.Errorf("should not be called"),
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
password, err := dbw.GeneratePassword(ctx, passGen, "")
if err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if password == "" {
t.Fatalf("missing password")
}
rawRegex := "^[a-zA-Z0-9-]{20}$"
re := regexp.MustCompile(rawRegex)
if !re.MatchString(password) {
t.Fatalf("password %q did not match regex: %q", password, rawRegex)
}
}
func TestNewUser_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
req := newdbplugin.NewUserRequest{}
resp, pass, err := dbw.NewUser(context.Background(), req)
if err == nil {
t.Fatalf("err expected, got nil")
}
expectedResp := newdbplugin.NewUserResponse{}
if !reflect.DeepEqual(resp, expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, expectedResp)
}
if pass != "" {
t.Fatalf("Password should be empty but was: %s", pass)
}
}
func TestNewUser_newDB(t *testing.T) {
type testCase struct {
req newdbplugin.NewUserRequest
newUserResp newdbplugin.NewUserResponse
newUserErr error
newUserCalls int
expectedResp newdbplugin.NewUserResponse
expectErr bool
}
tests := map[string]testCase{
"success": {
req: newdbplugin.NewUserRequest{
Password: "new_password",
},
newUserResp: newdbplugin.NewUserResponse{
Username: "newuser",
},
newUserCalls: 1,
expectedResp: newdbplugin.NewUserResponse{
Username: "newuser",
},
expectErr: false,
},
"error": {
req: newdbplugin.NewUserRequest{
Password: "new_password",
},
newUserErr: fmt.Errorf("test error"),
newUserCalls: 1,
expectedResp: newdbplugin.NewUserResponse{},
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
newDB.On("NewUser", mock.Anything, mock.Anything).
Return(test.newUserResp, test.newUserErr)
defer newDB.AssertNumberOfCalls(t, "NewUser", test.newUserCalls)
dbw := databaseVersionWrapper{
v5: newDB,
}
resp, password, err := dbw.NewUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
}
if password != test.req.Password {
t.Fatalf("Actual password: %s Expected password: %s", password, test.req.Password)
}
})
}
}
func TestNewUser_legacyDB(t *testing.T) {
type testCase struct {
req newdbplugin.NewUserRequest
createUserUsername string
createUserPassword string
createUserErr error
createUserCalls int
expectedResp newdbplugin.NewUserResponse
expectedPassword string
expectErr bool
}
tests := map[string]testCase{
"success": {
req: newdbplugin.NewUserRequest{
Password: "new_password",
},
createUserUsername: "newuser",
createUserPassword: "securepassword",
createUserCalls: 1,
expectedResp: newdbplugin.NewUserResponse{
Username: "newuser",
},
expectedPassword: "securepassword",
expectErr: false,
},
"error": {
req: newdbplugin.NewUserRequest{
Password: "new_password",
},
createUserErr: fmt.Errorf("test error"),
createUserCalls: 1,
expectedResp: newdbplugin.NewUserResponse{},
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("CreateUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(test.createUserUsername, test.createUserPassword, test.createUserErr)
defer legacyDB.AssertNumberOfCalls(t, "CreateUser", test.createUserCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
resp, password, err := dbw.NewUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(resp, test.expectedResp) {
t.Fatalf("Actual resp: %#v\nExpected resp: %#v", resp, test.expectedResp)
}
if password != test.expectedPassword {
t.Fatalf("Actual password: %s Expected password: %s", password, test.req.Password)
}
})
}
}
func TestUpdateUser_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
req := newdbplugin.UpdateUserRequest{}
resp, err := dbw.UpdateUser(context.Background(), req, false)
if err == nil {
t.Fatalf("err expected, got nil")
}
expectedConfig := map[string]interface{}(nil)
if !reflect.DeepEqual(resp, expectedConfig) {
t.Fatalf("Actual config: %#v\nExpected config: %#v", resp, expectedConfig)
}
}
func TestUpdateUser_newDB(t *testing.T) {
type testCase struct {
req newdbplugin.UpdateUserRequest
updateUserErr error
updateUserCalls int
expectedResp newdbplugin.UpdateUserResponse
expectErr bool
}
tests := map[string]testCase{
"success": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
},
updateUserCalls: 1,
expectErr: false,
},
"error": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
},
updateUserErr: fmt.Errorf("test error"),
updateUserCalls: 1,
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
newDB.On("UpdateUser", mock.Anything, mock.Anything).
Return(newdbplugin.UpdateUserResponse{}, test.updateUserErr)
defer newDB.AssertNumberOfCalls(t, "UpdateUser", test.updateUserCalls)
dbw := databaseVersionWrapper{
v5: newDB,
}
_, err := dbw.UpdateUser(context.Background(), test.req, false)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
func TestUpdateUser_legacyDB(t *testing.T) {
type testCase struct {
req newdbplugin.UpdateUserRequest
isRootUser bool
setCredentialsErr error
setCredentialsCalls int
rotateRootConfig map[string]interface{}
rotateRootErr error
rotateRootCalls int
renewUserErr error
renewUserCalls int
expectedConfig map[string]interface{}
expectErr bool
}
tests := map[string]testCase{
"missing changes": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
},
isRootUser: false,
setCredentialsCalls: 0,
rotateRootCalls: 0,
renewUserCalls: 0,
expectErr: true,
},
"both password and expiration changes": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
Password: &newdbplugin.ChangePassword{},
Expiration: &newdbplugin.ChangeExpiration{},
},
isRootUser: false,
setCredentialsCalls: 0,
rotateRootCalls: 0,
renewUserCalls: 0,
expectErr: true,
},
"change password - SetCredentials": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
Password: &newdbplugin.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: false,
setCredentialsErr: nil,
setCredentialsCalls: 1,
rotateRootCalls: 0,
renewUserCalls: 0,
expectedConfig: nil,
expectErr: false,
},
"change password - SetCredentials failed": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
Password: &newdbplugin.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: false,
setCredentialsErr: fmt.Errorf("set credentials failed"),
setCredentialsCalls: 1,
rotateRootCalls: 0,
renewUserCalls: 0,
expectedConfig: nil,
expectErr: true,
},
"change password - SetCredentials unimplemented but not a root user": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
Password: &newdbplugin.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: false,
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
setCredentialsCalls: 1,
rotateRootCalls: 0,
renewUserCalls: 0,
expectedConfig: nil,
expectErr: true,
},
"change password - RotateRootCredentials": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
Password: &newdbplugin.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: true,
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
setCredentialsCalls: 1,
rotateRootConfig: map[string]interface{}{
"foo": "bar",
},
rotateRootCalls: 1,
renewUserCalls: 0,
expectedConfig: map[string]interface{}{
"foo": "bar",
},
expectErr: false,
},
"change password - RotateRootCredentials failed": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
Password: &newdbplugin.ChangePassword{
NewPassword: "newpassowrd",
},
},
isRootUser: true,
setCredentialsErr: status.Error(codes.Unimplemented, "SetCredentials is not implemented"),
setCredentialsCalls: 1,
rotateRootErr: fmt.Errorf("rotate root failed"),
rotateRootCalls: 1,
renewUserCalls: 0,
expectedConfig: nil,
expectErr: true,
},
"change expiration": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
Expiration: &newdbplugin.ChangeExpiration{
NewExpiration: time.Now(),
},
},
isRootUser: false,
setCredentialsCalls: 0,
rotateRootCalls: 0,
renewUserErr: nil,
renewUserCalls: 1,
expectedConfig: nil,
expectErr: false,
},
"change expiration failed": {
req: newdbplugin.UpdateUserRequest{
Username: "existing_user",
Expiration: &newdbplugin.ChangeExpiration{
NewExpiration: time.Now(),
},
},
isRootUser: false,
setCredentialsCalls: 0,
rotateRootCalls: 0,
renewUserErr: fmt.Errorf("test error"),
renewUserCalls: 1,
expectedConfig: nil,
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("SetCredentials", mock.Anything, mock.Anything, mock.Anything).
Return("", "", test.setCredentialsErr)
defer legacyDB.AssertNumberOfCalls(t, "SetCredentials", test.setCredentialsCalls)
legacyDB.On("RotateRootCredentials", mock.Anything, mock.Anything).
Return(test.rotateRootConfig, test.rotateRootErr)
defer legacyDB.AssertNumberOfCalls(t, "RotateRootCredentials", test.rotateRootCalls)
legacyDB.On("RenewUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(test.renewUserErr)
defer legacyDB.AssertNumberOfCalls(t, "RenewUser", test.renewUserCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
newConfig, err := dbw.UpdateUser(context.Background(), test.req, test.isRootUser)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
if !reflect.DeepEqual(newConfig, test.expectedConfig) {
t.Fatalf("Actual config: %#v\nExpected config: %#v", newConfig, test.expectedConfig)
}
})
}
}
func TestDeleteUser_missingDB(t *testing.T) {
dbw := databaseVersionWrapper{}
req := newdbplugin.DeleteUserRequest{}
_, err := dbw.DeleteUser(context.Background(), req)
if err == nil {
t.Fatalf("err expected, got nil")
}
}
func TestDeleteUser_newDB(t *testing.T) {
type testCase struct {
req newdbplugin.DeleteUserRequest
deleteUserErr error
deleteUserCalls int
expectErr bool
}
tests := map[string]testCase{
"success": {
req: newdbplugin.DeleteUserRequest{
Username: "existing_user",
},
deleteUserErr: nil,
deleteUserCalls: 1,
expectErr: false,
},
"error": {
req: newdbplugin.DeleteUserRequest{
Username: "existing_user",
},
deleteUserErr: fmt.Errorf("test error"),
deleteUserCalls: 1,
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
newDB := new(mockNewDatabase)
newDB.On("DeleteUser", mock.Anything, mock.Anything).
Return(newdbplugin.DeleteUserResponse{}, test.deleteUserErr)
defer newDB.AssertNumberOfCalls(t, "DeleteUser", test.deleteUserCalls)
dbw := databaseVersionWrapper{
v5: newDB,
}
_, err := dbw.DeleteUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
func TestDeleteUser_legacyDB(t *testing.T) {
type testCase struct {
req newdbplugin.DeleteUserRequest
revokeUserErr error
revokeUserCalls int
expectErr bool
}
tests := map[string]testCase{
"success": {
req: newdbplugin.DeleteUserRequest{
Username: "existing_user",
},
revokeUserErr: nil,
revokeUserCalls: 1,
expectErr: false,
},
"error": {
req: newdbplugin.DeleteUserRequest{
Username: "existing_user",
},
revokeUserErr: fmt.Errorf("test error"),
revokeUserCalls: 1,
expectErr: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
legacyDB := new(mockLegacyDatabase)
legacyDB.On("RevokeUser", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(test.revokeUserErr)
defer legacyDB.AssertNumberOfCalls(t, "RevokeUser", test.revokeUserCalls)
dbw := databaseVersionWrapper{
v4: legacyDB,
}
_, err := dbw.DeleteUser(context.Background(), test.req)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}
type badValue struct{}
func (badValue) MarshalJSON() ([]byte, error) {
return nil, fmt.Errorf("this value cannot be marshalled to JSON")
}
var _ logical.Storage = fakeStorage{}
type fakeStorage struct {
putErr error
}
func (f fakeStorage) Put(ctx context.Context, entry *logical.StorageEntry) error {
return f.putErr
}
func (f fakeStorage) List(ctx context.Context, s string) ([]string, error) {
panic("list not implemented")
}
func (f fakeStorage) Get(ctx context.Context, s string) (*logical.StorageEntry, error) {
panic("get not implemented")
}
func (f fakeStorage) Delete(ctx context.Context, s string) error {
panic("delete not implemented")
}
func TestStoreConfig(t *testing.T) {
type testCase struct {
config *DatabaseConfig
putErr error
expectErr bool
}
tests := map[string]testCase{
"bad config": {
config: &DatabaseConfig{
PluginName: "testplugin",
ConnectionDetails: map[string]interface{}{
"bad value": badValue{},
},
},
putErr: nil,
expectErr: true,
},
"storage error": {
config: &DatabaseConfig{
PluginName: "testplugin",
ConnectionDetails: map[string]interface{}{
"foo": "bar",
},
},
putErr: fmt.Errorf("failed to store config"),
expectErr: true,
},
"happy path": {
config: &DatabaseConfig{
PluginName: "testplugin",
ConnectionDetails: map[string]interface{}{
"foo": "bar",
},
},
putErr: nil,
expectErr: false,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
storage := fakeStorage{
putErr: test.putErr,
}
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := storeConfig(ctx, storage, "testconfig", test.config)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
}
})
}
}

View File

@@ -0,0 +1,317 @@
package database
// This file contains all "large"/expensive tests. These are running requests against a running backend
import (
"context"
"fmt"
"os"
"regexp"
"strings"
"testing"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
)
func TestPlugin_lifecycle(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v4-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_MockV4", []string{}, "")
vault.TestAddTestPlugin(t, cluster.Cores[0].Core, "mock-v5-database-plugin", consts.PluginTypeDatabase, "TestBackend_PluginMain_MockV5", []string{}, "")
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
lb, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
b, ok := lb.(*databaseBackend)
if !ok {
t.Fatal("could not convert to database backend")
}
defer b.Cleanup(context.Background())
type testCase struct {
dbName string
dbType string
configData map[string]interface{}
assertDynamicUsername stringAssertion
assertDynamicPassword stringAssertion
}
tests := map[string]testCase{
"v4": {
dbName: "mockv4",
dbType: "mock-v4-database-plugin",
configData: map[string]interface{}{
"name": "mockv4",
"plugin_name": "mock-v4-database-plugin",
"connection_url": "sample_connection_url",
"verify_connection": true,
"allowed_roles": []string{"*"},
"username": "mockv4-user",
"password": "mysecurepassword",
},
assertDynamicUsername: assertStringPrefix("mockv4_user_"),
assertDynamicPassword: assertStringPrefix("mockv4_"),
},
"v5": {
dbName: "mockv5",
dbType: "mock-v5-database-plugin",
configData: map[string]interface{}{
"connection_url": "sample_connection_url",
"plugin_name": "mock-v5-database-plugin",
"verify_connection": true,
"allowed_roles": []string{"*"},
"name": "mockv5",
"username": "mockv5-user",
"password": "mysecurepassword",
},
assertDynamicUsername: assertStringPrefix("mockv5_user_"),
assertDynamicPassword: assertStringRegex("^[a-zA-Z0-9-]{20}"),
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
cleanupReqs := []*logical.Request{}
defer cleanup(t, b, cleanupReqs)
// /////////////////////////////////////////////////////////////////
// Configure
req := &logical.Request{
Operation: logical.CreateOperation,
Path: fmt.Sprintf("config/%s", test.dbName),
Storage: config.StorageView,
Data: test.configData,
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
resp, err := b.HandleRequest(ctx, req)
assertErrIsNil(t, err)
assertRespHasNoErr(t, resp)
assertNoRespData(t, resp)
cleanupReqs = append(cleanupReqs, &logical.Request{
Operation: logical.DeleteOperation,
Path: fmt.Sprintf("config/%s", test.dbName),
Storage: config.StorageView,
})
// /////////////////////////////////////////////////////////////////
// Rotate root credentials
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: fmt.Sprintf("rotate-root/%s", test.dbName),
Storage: config.StorageView,
}
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
resp, err = b.HandleRequest(ctx, req)
assertErrIsNil(t, err)
assertRespHasNoErr(t, resp)
assertNoRespData(t, resp)
// /////////////////////////////////////////////////////////////////
// Dynamic credentials
// Create role
dynamicRoleName := "dynamic-role"
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: fmt.Sprintf("roles/%s", dynamicRoleName),
Storage: config.StorageView,
Data: map[string]interface{}{
"db_name": test.dbName,
"default_ttl": "5s",
"max_ttl": "1m",
},
}
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
resp, err = b.HandleRequest(ctx, req)
assertErrIsNil(t, err)
assertRespHasNoErr(t, resp)
assertNoRespData(t, resp)
cleanupReqs = append(cleanupReqs, &logical.Request{
Operation: logical.DeleteOperation,
Path: fmt.Sprintf("roles/%s", dynamicRoleName),
Storage: config.StorageView,
})
// Generate credentials
req = &logical.Request{
Operation: logical.ReadOperation,
Path: fmt.Sprintf("creds/%s", dynamicRoleName),
Storage: config.StorageView,
}
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
resp, err = b.HandleRequest(ctx, req)
assertErrIsNil(t, err)
assertRespHasNoErr(t, resp)
assertRespHasData(t, resp)
// TODO: Figure out how to make a call to the cluster that gives back a lease ID
// And also rotates the secret out after its TTL
// /////////////////////////////////////////////////////////////////
// Static credentials
// Create static role
staticRoleName := "static-role"
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: fmt.Sprintf("static-roles/%s", staticRoleName),
Storage: config.StorageView,
Data: map[string]interface{}{
"db_name": test.dbName,
"username": "static-username",
"rotation_period": "5",
},
}
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
resp, err = b.HandleRequest(ctx, req)
assertErrIsNil(t, err)
assertRespHasNoErr(t, resp)
assertNoRespData(t, resp)
cleanupReqs = append(cleanupReqs, &logical.Request{
Operation: logical.DeleteOperation,
Path: fmt.Sprintf("static-roles/%s", staticRoleName),
Storage: config.StorageView,
})
// Get credentials
req = &logical.Request{
Operation: logical.ReadOperation,
Path: fmt.Sprintf("static-creds/%s", staticRoleName),
Storage: config.StorageView,
}
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
resp, err = b.HandleRequest(ctx, req)
assertErrIsNil(t, err)
assertRespHasNoErr(t, resp)
assertRespHasData(t, resp)
})
}
}
func cleanup(t *testing.T, b *databaseBackend, reqs []*logical.Request) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Go in stack order so it works similar to defer
for i := len(reqs) - 1; i >= 0; i-- {
req := reqs[i]
resp, err := b.HandleRequest(ctx, req)
if err != nil {
t.Fatalf("Error cleaning up: %s", err)
}
if resp != nil && resp.IsError() {
t.Fatalf("Error cleaning up: %s", resp.Error())
}
}
}
func TestBackend_PluginMain_MockV4(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
if caPEM == "" {
t.Fatal("CA cert not passed in")
}
args := []string{"--ca-cert=" + caPEM}
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
RunV4(apiClientMeta.GetTLSConfig())
}
func TestBackend_PluginMain_MockV5(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
if caPEM == "" {
t.Fatal("CA cert not passed in")
}
args := []string{"--ca-cert=" + caPEM}
apiClientMeta := &api.PluginAPIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
RunV5(apiClientMeta.GetTLSConfig())
}
func assertNoRespData(t *testing.T, resp *logical.Response) {
t.Helper()
if resp != nil && len(resp.Data) > 0 {
t.Fatalf("Response had data when none was expected: %#v", resp.Data)
}
}
func assertRespHasData(t *testing.T, resp *logical.Response) {
t.Helper()
if resp == nil || len(resp.Data) == 0 {
t.Fatalf("Response didn't have any data when some was expected")
}
}
type stringAssertion func(t *testing.T, str string)
func assertStringPrefix(expectedPrefix string) stringAssertion {
return func(t *testing.T, str string) {
t.Helper()
if !strings.HasPrefix(str, expectedPrefix) {
t.Fatalf("Missing prefix '%s': Actual: '%s'", expectedPrefix, str)
}
}
}
func assertStringRegex(expectedRegex string) stringAssertion {
re := regexp.MustCompile(expectedRegex)
return func(t *testing.T, str string) {
if !re.MatchString(str) {
t.Fatalf("Actual: '%s' did not match regexp '%s'", str, expectedRegex)
}
}
}
func assertRespHasNoErr(t *testing.T, resp *logical.Response) {
t.Helper()
if resp != nil && resp.IsError() {
t.Fatalf("response is error: %#v\n", resp)
}
}
func assertErrIsNil(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatalf("No error expected, got: %s", err)
}
}