DBPW - Migrate Redshift database plugin to v5 interface (#10195)

This commit is contained in:
Tom Proctor
2020-10-23 14:10:57 +01:00
committed by GitHub
parent ee09e54d80
commit be0a3d28f9
3 changed files with 405 additions and 523 deletions

View File

@@ -108,7 +108,7 @@ func newRegistry() *registry {
"mongodbatlas-database-plugin": dbMongoAtlas.New, "mongodbatlas-database-plugin": dbMongoAtlas.New,
"mssql-database-plugin": dbMssql.New, "mssql-database-plugin": dbMssql.New,
"postgresql-database-plugin": dbPostgres.New, "postgresql-database-plugin": dbPostgres.New,
"redshift-database-plugin": dbRedshift.New(true), "redshift-database-plugin": dbRedshift.New,
}, },
logicalBackends: map[string]logical.Factory{ logicalBackends: map[string]logical.Factory{
"ad": logicalAd.Factory, "ad": logicalAd.Factory,

View File

@@ -6,11 +6,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/database/dbplugin" dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil" "github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil"
@@ -31,37 +30,28 @@ const (
ALTER USER "{{name}}" VALID UNTIL '{{expiration}}'; ALTER USER "{{name}}" VALID UNTIL '{{expiration}}';
` `
defaultRotateRootCredentialsSQL = ` defaultRotateRootCredentialsSQL = `
ALTER USER "{{username}}" WITH PASSWORD '{{password}}'; ALTER USER "{{name}}" WITH PASSWORD '{{password}}';
` `
) )
// lowercaseUsername is the reason we wrote this plugin. Redshift implements (mostly) var _ dbplugin.Database = (*RedShift)(nil)
// a postgres 8 interface, and part of that is under the hood, it's lowercasing the
// usernames. // New implements builtinplugins.BuiltinFactory
func New(lowercaseUsername bool) func() (interface{}, error) { // Redshift implements (mostly) a postgres 8 interface, and part of that is
return func() (interface{}, error) { // under the hood, it's lower-casing the usernames.
db := newRedshift(lowercaseUsername) func New() (interface{}, error) {
// Wrap the plugin with middleware to sanitize errors db := newRedshift()
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) // Wrap the plugin with middleware to sanitize errors
return dbType, nil dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
} return dbType, nil
} }
func newRedshift(lowercaseUsername bool) *RedShift { func newRedshift() *RedShift {
connProducer := &connutil.SQLConnectionProducer{} connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = sqlTypeName connProducer.Type = sqlTypeName
credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: 8,
RoleNameLen: 8,
UsernameLen: 63,
Separator: "-",
LowercaseUsername: lowercaseUsername,
}
db := &RedShift{ db := &RedShift{
SQLConnectionProducer: connProducer, SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
} }
return db return db
@@ -69,14 +59,32 @@ func newRedshift(lowercaseUsername bool) *RedShift {
type RedShift struct { type RedShift struct {
*connutil.SQLConnectionProducer *connutil.SQLConnectionProducer
credsutil.CredentialsProducer }
func (r *RedShift) secretValues() map[string]string {
return map[string]string{
r.Password: "[password]",
}
} }
func (r *RedShift) Type() (string, error) { func (r *RedShift) Type() (string, error) {
return middlewareTypeName, nil return middlewareTypeName, nil
} }
// getConnection accepts a context and retuns a new pointer to a sql.DB object. // Initialize must be called on each new RedShift struct before use.
// It uses the connutil.SQLConnectionProducer's Init function to do all the lifting.
func (r *RedShift) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
conf, err := r.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("error initializing db: %w", err)
}
return dbplugin.InitializeResponse{
Config: conf,
}, nil
}
// getConnection accepts a context and returns a new pointer to a sql.DB object.
// It's up to the caller to close the connection or handle reuse logic. // It's up to the caller to close the connection or handle reuse logic.
func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) { func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := r.Connection(ctx) db, err := r.Connection(ctx)
@@ -86,116 +94,44 @@ func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) {
return db.(*sql.DB), nil return db.(*sql.DB), nil
} }
// SetCredentials uses provided information to set/create a user in the // NewUser creates a new user in the database. There is no default statement for
// database. Unlike CreateUser, this method requires a username be provided and // creating users, so one must be specified in the plugin config.
// uses the name given, instead of generating a name. This is used for creating // Generated usernames are of the form v-{display-name}-{role-name}-{UUID}-{timestamp}
// and setting the password of static accounts, as well as rolling back func (r *RedShift) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
// passwords in the database in the event an updated database fails to save in if len(req.Statements.Commands) == 0 {
// Vault's storage. return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
func (r *RedShift) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
if len(statements.Rotation) == 0 {
statements.Rotation = []string{defaultRotateRootCredentialsSQL}
}
username = staticUser.Username
password = staticUser.Password
if username == "" || password == "" {
return "", "", errors.New("must provide both username and password")
} }
// Grab the lock // Grab the lock
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
// Get the connection usernameOpts := []credsutil.UsernameOpt{
db, err := r.getConnection(ctx) credsutil.DisplayName(req.UsernameConfig.DisplayName, 8),
credsutil.RoleName(req.UsernameConfig.RoleName, 8),
credsutil.MaxLength(63),
credsutil.Separator("-"),
credsutil.ToLower(),
}
username, err := credsutil.GenerateUsername(usernameOpts...)
if err != nil { if err != nil {
return "", "", err return dbplugin.NewUserResponse{}, err
}
defer db.Close()
// Check if the role exists
var exists bool
err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
return "", "", err
}
// Vault requires the database user already exist, and that the credentials
// used to execute the rotation statements has sufficient privileges.
stmts := statements.Rotation
// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
}
defer func() {
tx.Rollback()
}()
// Execute each query
for _, stmt := range stmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": staticUser.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return "", "", err
}
return username, password, nil
}
func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
}
// Grab the lock
r.Lock()
defer r.Unlock()
username, err = r.GenerateUsername(usernameConfig)
if err != nil {
return "", "", err
}
password, err = r.GeneratePassword()
if err != nil {
return "", "", err
}
expirationStr, err := r.GenerateExpiration(expiration)
if err != nil {
return "", "", err
} }
password := req.Password
expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700")
// Get the connection // Get the connection
db, err := r.getConnection(ctx) db, err := r.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return dbplugin.NewUserResponse{}, err
} }
defer db.Close() defer db.Close()
// Start a transaction // Start a transaction
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return dbplugin.NewUserResponse{}, err
} }
defer func() { defer func() {
@@ -203,7 +139,7 @@ func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statement
}() }()
// Execute each query // Execute each query
for _, stmt := range statements.Creation { for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
if len(query) == 0 { if len(query) == 0 {
@@ -212,53 +148,81 @@ func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statement
m := map[string]string{ m := map[string]string{
"name": username, "name": username,
"username": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
} }
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err return dbplugin.NewUserResponse{}, err
} }
} }
} }
// Commit the transaction // Commit the transaction
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return "", "", err return dbplugin.NewUserResponse{}, err
} }
return username, password, nil return dbplugin.NewUserResponse{
Username: username,
}, nil
} }
func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { // UpdateUser can update the expiration or the password of a user, or both.
// The updates all happen in a single transaction, so they will either all
// succeed or all fail.
// Both updates support both default and custom statements.
func (r *RedShift) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) {
if req.Password == nil && req.Expiration == nil {
return dbplugin.UpdateUserResponse{}, errors.New("no changes requested")
}
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
renewStmts := statements.Renewal
if len(renewStmts) == 0 {
renewStmts = []string{defaultRenewSQL}
}
db, err := r.getConnection(ctx) db, err := r.getConnection(ctx)
if err != nil { if err != nil {
return err return dbplugin.UpdateUserResponse{}, err
} }
defer db.Close() defer db.Close()
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return dbplugin.UpdateUserResponse{}, err
} }
defer func() { defer func() {
tx.Rollback() tx.Rollback()
}() }()
expirationStr, err := r.GenerateExpiration(expiration) if req.Expiration != nil {
if err != nil { err = updateUserExpiration(ctx, req, tx)
return err if err != nil {
return dbplugin.UpdateUserResponse{}, err
}
} }
for _, stmt := range renewStmts { if req.Password != nil {
err = updateUserPassword(ctx, req, tx)
if err != nil {
return dbplugin.UpdateUserResponse{}, err
}
}
err = tx.Commit()
return dbplugin.UpdateUserResponse{}, err
}
func updateUserExpiration(ctx context.Context, req dbplugin.UpdateUserRequest, tx *sql.Tx) error {
if req.Username == "" {
return errors.New("must provide a username to update user expiration")
}
renewStmts := req.Expiration.Statements
if len(renewStmts.Commands) == 0 {
renewStmts.Commands = []string{defaultRenewSQL}
}
expirationStr := req.Expiration.NewExpiration.Format("2006-01-02 15:04:05-0700")
for _, stmt := range renewStmts.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
if len(query) == 0 { if len(query) == 0 {
@@ -266,7 +230,8 @@ func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements
} }
m := map[string]string{ m := map[string]string{
"name": username, "name": req.Username,
"username": req.Username,
"expiration": expirationStr, "expiration": expirationStr,
} }
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
@@ -275,39 +240,36 @@ func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements
} }
} }
return tx.Commit() return nil
} }
func (r *RedShift) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { func updateUserPassword(ctx context.Context, req dbplugin.UpdateUserRequest, tx *sql.Tx) error {
// Grab the lock username := req.Username
r.Lock() password := req.Password.NewPassword
defer r.Unlock() if username == "" || password == "" {
return errors.New("must provide both username and a new password to update user password")
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Revocation) == 0 {
return r.defaultRevokeUser(ctx, username)
} }
return r.customRevokeUser(ctx, username, statements.Revocation) // Check if the role exists
} var exists bool
err := tx.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists)
func (r *RedShift) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { if err != nil && err != sql.ErrNoRows {
db, err := r.getConnection(ctx) // Server error
if err != nil {
return err return err
} }
defer db.Close() if err == sql.ErrNoRows || !exists {
// Most likely a user error
tx, err := db.BeginTx(ctx, nil) return fmt.Errorf("cannot update password for username %q because it does not exist", username)
if err != nil {
return err
} }
defer func() {
tx.Rollback()
}()
for _, stmt := range revocationStmts { // Vault requires the database user already exist, and that the credentials
// used to execute the rotation statements has sufficient privileges.
statements := req.Password.Statements.Commands
if len(statements) == 0 {
statements = []string{defaultRotateRootCredentialsSQL}
}
// Execute each query
for _, stmt := range statements {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
if len(query) == 0 { if len(query) == 0 {
@@ -315,7 +277,9 @@ func (r *RedShift) customRevokeUser(ctx context.Context, username string, revoca
} }
m := map[string]string{ m := map[string]string{
"name": username, "name": username,
"username": username,
"password": password,
} }
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return err return err
@@ -323,25 +287,76 @@ func (r *RedShift) customRevokeUser(ctx context.Context, username string, revoca
} }
} }
return tx.Commit() return nil
} }
func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error { // DeleteUser supports both default and custom statements to delete a user.
func (r *RedShift) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
// Grab the lock
r.Lock()
defer r.Unlock()
if len(req.Statements.Commands) == 0 {
return r.defaultDeleteUser(ctx, req)
}
return r.customDeleteUser(ctx, req)
}
func (r *RedShift) customDeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
db, err := r.getConnection(ctx) db, err := r.getConnection(ctx)
if err != nil { if err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
defer db.Close() defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return dbplugin.DeleteUserResponse{}, err
}
defer func() {
tx.Rollback()
}()
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": req.Username,
"username": req.Username,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return dbplugin.DeleteUserResponse{}, err
}
}
}
return dbplugin.DeleteUserResponse{}, tx.Commit()
}
func (r *RedShift) defaultDeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
db, err := r.getConnection(ctx)
if err != nil {
return dbplugin.DeleteUserResponse{}, err
}
defer db.Close()
username := req.Username
// Check if the role exists // Check if the role exists
var exists bool var exists bool
err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists) err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return dbplugin.DeleteUserResponse{}, err
} }
if !exists { if !exists {
return nil // No error as Redshift may have deleted the user via TTL before we got to it.
return dbplugin.DeleteUserResponse{}, nil
} }
// Query for permissions; we need to revoke permissions before we can drop // Query for permissions; we need to revoke permissions before we can drop
@@ -350,13 +365,13 @@ func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error
// we want to remove as much access as possible // we want to remove as much access as possible
stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil { if err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
defer stmt.Close() defer stmt.Close()
rows, err := stmt.QueryContext(ctx, username) rows, err := stmt.QueryContext(ctx, username)
if err != nil { if err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
defer rows.Close() defer rows.Close()
@@ -393,7 +408,7 @@ func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error
// this username // this username
var dbname sql.NullString var dbname sql.NullString
if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil { if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
if dbname.Valid { if dbname.Valid {
@@ -432,78 +447,22 @@ $$;`)
// can't drop if not all privileges are revoked // can't drop if not all privileges are revoked
if rows.Err() != nil { if rows.Err() != nil {
return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err()) return dbplugin.DeleteUserResponse{}, errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err())
} }
if lastStmtError != nil { if lastStmtError != nil {
return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError) return dbplugin.DeleteUserResponse{}, errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError)
} }
// Drop this user // Drop this user
stmt, err = db.PrepareContext(ctx, fmt.Sprintf( stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
`DROP USER IF EXISTS %s;`, pq.QuoteIdentifier(username))) `DROP USER IF EXISTS %s;`, pq.QuoteIdentifier(username)))
if err != nil { if err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
return nil return dbplugin.DeleteUserResponse{}, nil
}
func (r *RedShift) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
r.Lock()
defer r.Unlock()
if len(r.Username) == 0 || len(r.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatements := statements
if len(rotateStatements) == 0 {
rotateStatements = []string{defaultRotateRootCredentialsSQL}
}
db, err := r.getConnection(ctx)
if err != nil {
return nil, err
}
defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := r.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatements {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"username": r.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
r.RawConfig["password"] = password
return r.RawConfig, nil
} }

View File

@@ -3,16 +3,19 @@ package redshift
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"os" "os"
"strings" "reflect"
"regexp"
"testing" "testing"
"time" "time"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/database/dbplugin"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
"github.com/hashicorp/vault/sdk/helper/dbtxn" "github.com/hashicorp/vault/sdk/helper/dbtxn"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/lib/pq" "github.com/lib/pq"
) )
@@ -25,10 +28,6 @@ as environment variables to be used to run these tests. Note that these tests
will create users on your redshift cluster and currently do not clean up after will create users on your redshift cluster and currently do not clean up after
themselves. themselves.
The RotateRoot test is potentially destructive in that it will rotate your root
password on your Redshift cluster to an insecure, cleartext password defined in the
test method. Because of this, you must pass TEST_ROTATE_ROOT=1 to enable it explicitly.
Do not run this test suite against a production Redshift cluster. Do not run this test suite against a production Redshift cluster.
Configuration: Configuration:
@@ -37,7 +36,6 @@ Configuration:
REDSHIFT_USER=my-redshift-admin-user REDSHIFT_USER=my-redshift-admin-user
REDSHIFT_PASSWORD=my-redshift-admin-password REDSHIFT_PASSWORD=my-redshift-admin-password
VAULT_ACC=<unset || 1> # This must be set to run any of the tests in this test suite VAULT_ACC=<unset || 1> # This must be set to run any of the tests in this test suite
TEST_ROTATE_ROOT=<unset || 1> # This must be set to explicitly run the rotate root test
*/ */
var ( var (
@@ -48,281 +46,230 @@ var (
vaultACC = "VAULT_ACC" vaultACC = "VAULT_ACC"
) )
func redshiftEnv() (url string, user string, password string, errEmpty error) { func interpolateConnectionURL(url, user, password string) string {
errEmpty = errors.New("err: empty but required env value") return fmt.Sprintf("postgres://%s:%s@%s", user, password, url)
}
func redshiftEnv() (connURL string, url string, user string, password string, errEmpty error) {
if url = os.Getenv(keyRedshiftURL); url == "" { if url = os.Getenv(keyRedshiftURL); url == "" {
return "", "", "", errEmpty return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftURL)
} }
if user = os.Getenv(keyRedshiftUser); url == "" { if user = os.Getenv(keyRedshiftUser); url == "" {
return "", "", "", errEmpty return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftUser)
} }
if password = os.Getenv(keyRedshiftPassword); url == "" { if password = os.Getenv(keyRedshiftPassword); url == "" {
return "", "", "", errEmpty return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftPassword)
} }
url = fmt.Sprintf("postgres://%s:%s@%s", user, password, url) connURL = interpolateConnectionURL(url, user, password)
return connURL, url, user, password, nil
return url, user, password, nil
} }
func TestPostgreSQL_Initialize(t *testing.T) { func TestRedshift_Initialize(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, _, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
"max_open_connections": 5, "max_open_connections": 73,
} }
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) resp := dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
} })
if !db.Initialized { if !db.Initialized {
t.Fatal("Database should be initialized") t.Fatal("Database should be initialized")
} }
expectedConfig := make(map[string]interface{})
err = db.Close() for k, v := range connectionDetails {
if err != nil { expectedConfig[k] = v
t.Fatalf("err: %s", err) }
} if !reflect.DeepEqual(expectedConfig, resp.Config) {
t.Fatalf("Expected config %+v, but was %v", expectedConfig, resp.Config)
// Test decoding a string value for max_open_connections }
connectionDetails = map[string]interface{}{ if db.MaxOpenConnections != 73 {
"connection_url": url, t.Fatalf("Expected max_open_connections to be set to 73, but was %d", db.MaxOpenConnections)
"max_open_connections": "5",
}
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
} }
dbtesting.AssertClose(t, db)
} }
func TestPostgreSQL_CreateUser(t *testing.T) { func TestRedshift_NewUser(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, url, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
} }
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
} })
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameMetadata{
DisplayName: "test", DisplayName: "test",
RoleName: "test", RoleName: "test",
} }
const password = "SuperSecurePa55w0rd!"
for _, commands := range [][]string{{testRedshiftRole}, {testRedshiftReadOnlyRole}} {
resp := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{
UsernameConfig: usernameConfig,
Password: password,
Statements: dbplugin.Statements{
Commands: commands,
},
Expiration: time.Now().Add(5 * time.Minute),
})
username := resp.Username
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s\n%s:%s", err, username, password)
}
usernameRegex := regexp.MustCompile("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$")
if !usernameRegex.Match([]byte(username)) {
t.Fatalf("Expected username %q to match regex %q", username, usernameRegex.String())
}
}
dbtesting.AssertClose(t, db)
}
func TestRedshift_NewUser_NoCreationStatement_ShouldError(t *testing.T) {
if os.Getenv(vaultACC) != "1" {
t.SkipNow()
}
connURL, _, _, _, err := redshiftEnv()
if err != nil {
t.Fatal(err)
}
connectionDetails := map[string]interface{}{
"connection_url": connURL,
}
db := newRedshift()
dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
})
usernameConfig := dbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
}
const password = "SuperSecurePa55w0rd!"
// Test with no configured Creation Statement // Test with no configured Creation Statement
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) _, err = db.NewUser(context.Background(), dbplugin.NewUserRequest{
UsernameConfig: usernameConfig,
Password: password,
Statements: dbplugin.Statements{
Commands: []string{}, // Empty commands field here should cause error.
},
Expiration: time.Now().Add(5 * time.Minute),
})
if err == nil { if err == nil {
t.Fatal("Expected error when no creation statement is provided") t.Fatal("Expected error when no creation statement is provided")
} }
statements := dbplugin.Statements{ dbtesting.AssertClose(t, db)
Creation: []string{testRedshiftRole},
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s\n%s:%s", err, username, password)
}
statements.Creation = []string{testRedshiftReadOnlyRole}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Sleep to make sure we haven't expired if granularity is only down to the second
time.Sleep(2 * time.Second)
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
} }
func TestPostgreSQL_RenewUser(t *testing.T) { func TestRedshift_UpdateUser_Expiration(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, url, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
} }
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
} })
statements := dbplugin.Statements{ usernameConfig := dbplugin.UsernameMetadata{
Creation: []string{testRedshiftRole},
}
usernameConfig := dbplugin.UsernameConfig{
DisplayName: "test", DisplayName: "test",
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) const password = "SuperSecurePa55w0rd!"
if err != nil { const initialTTL = 2 * time.Second
t.Fatalf("err: %s", err) const longTTL = time.Minute
} for _, commands := range [][]string{{}, {defaultRenewSQL}} {
newResp := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{
if err = testCredsExist(t, url, username, password); err != nil { UsernameConfig: usernameConfig,
t.Fatalf("Could not connect with new credentials: %s", err) Password: password,
} Statements: dbplugin.Statements{Commands: []string{testRedshiftRole}},
Expiration: time.Now().Add(initialTTL),
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) })
if err != nil { username := newResp.Username
t.Fatalf("err: %s", err)
} if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
// Sleep longer than the initial expiration time }
time.Sleep(2 * time.Second)
dbtesting.AssertUpdateUser(t, db, dbplugin.UpdateUserRequest{
if err = testCredsExist(t, url, username, password); err != nil { Username: username,
t.Fatalf("Could not connect with new credentials: %s", err) Expiration: &dbplugin.ChangeExpiration{
} NewExpiration: time.Now().Add(longTTL),
statements.Renewal = []string{defaultRenewSQL} Statements: dbplugin.Statements{Commands: commands},
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) },
if err != nil { })
t.Fatalf("err: %s", err)
} // Sleep longer than the initial expiration time
time.Sleep(initialTTL + time.Second)
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err) if err = testCredsExist(t, url, username, password); err != nil {
} t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Sleep longer than the initial expiration time
time.Sleep(2 * time.Second)
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
} }
dbtesting.AssertClose(t, db)
} }
func TestPostgreSQL_RevokeUser(t *testing.T) { func TestRedshift_UpdateUser_Password(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, url, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
}
db := newRedshift(true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
Creation: []string{testRedshiftRole},
}
usernameConfig := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// Test default revoke statements
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, url, username, password); err == nil {
t.Fatal("Credentials were not revoked")
}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// Test custom revoke statements
statements.Revocation = []string{defaultRedshiftRevocationSQL}
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, url, username, password); err == nil {
t.Fatal("Credentials were not revoked")
}
}
func TestPostgresSQL_SetCredentials(t *testing.T) {
if os.Getenv(vaultACC) != "1" {
t.SkipNow()
}
url, _, _, err := redshiftEnv()
if err != nil {
t.Fatal(err)
}
connectionDetails := map[string]interface{}{
"connection_url": url,
} }
// create the database user // create the database user
@@ -331,121 +278,97 @@ func TestPostgresSQL_SetCredentials(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
dbUser := "vaultstatictest-" + fmt.Sprintf("%s", uid) dbUser := "vaultstatictest-" + fmt.Sprintf("%s", uid)
createTestPGUser(t, url, dbUser, "1Password", testRoleStaticCreate) createTestPGUser(t, connURL, dbUser, "1Password", testRoleStaticCreate)
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
})
const password1 = "MyTemporaryUserPassword1!"
const password2 = "MyTemporaryUserPassword2!"
for _, tc := range []struct {
password string
commands []string
}{
{password1, []string{}},
{password2, []string{testRedshiftStaticRoleRotate}},
} {
dbtesting.AssertUpdateUser(t, db, dbplugin.UpdateUserRequest{
Username: dbUser,
Password: &dbplugin.ChangePassword{
NewPassword: tc.password,
Statements: dbplugin.Statements{Commands: tc.commands},
},
})
if err := testCredsExist(t, url, dbUser, tc.password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
} }
password, err := db.GenerateCredentials(context.Background()) dbtesting.AssertClose(t, db)
if err != nil {
t.Fatal(err)
}
usernameConfig := dbplugin.StaticUserConfig{
Username: dbUser,
Password: password,
}
// Test with no configured Rotation Statement
username, password, err := db.SetCredentials(context.Background(), dbplugin.Statements{}, usernameConfig)
if err == nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
Rotation: []string{testRedshiftStaticRoleRotate},
}
// User should not exist, make sure we can create
username, password, err = db.SetCredentials(context.Background(), statements, usernameConfig)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// call SetCredentials again, password will change
newPassword, _ := db.GenerateCredentials(context.Background())
usernameConfig.Password = newPassword
username, password, err = db.SetCredentials(context.Background(), statements, usernameConfig)
if err != nil {
t.Fatalf("err: %s", err)
}
if password != newPassword {
t.Fatal("passwords should have changed")
}
if err := testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
} }
func TestPostgreSQL_RotateRootCredentials(t *testing.T) { func TestRedshift_DeleteUser(t *testing.T) {
/* if os.Getenv(vaultACC) != "1" {
Extra precaution is taken for rotating root creds because it's assumed that this
test will run against a live redshift cluster. This test must run last because
it is destructive.
To run this test you must pass TEST_ROTATE_ROOT=1
*/
if os.Getenv(vaultACC) != "1" || os.Getenv("TEST_ROTATE_ROOT") != "1" {
t.SkipNow() t.SkipNow()
} }
url, adminUser, adminPassword, err := redshiftEnv() connURL, url, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
"username": adminUser,
"password": adminPassword,
} }
db := newRedshift(true) db := newRedshift()
dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
})
connProducer := db.SQLConnectionProducer usernameConfig := dbplugin.UsernameMetadata{
DisplayName: "test",
_, err = db.Init(context.Background(), connectionDetails, true) RoleName: "test",
if err != nil {
t.Fatalf("err: %s", err)
} }
if !connProducer.Initialized { const password = "SuperSecretPa55word!"
t.Fatal("Database should be initialized") for _, commands := range [][]string{{}, {defaultRedshiftRevocationSQL}} {
newResponse := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{
UsernameConfig: usernameConfig,
Statements: dbplugin.Statements{Commands: []string{testRedshiftRole}},
Password: password,
Expiration: time.Now().Add(2 * time.Second),
})
username := newResponse.Username
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// Intentionally _not_ using dbtesting here as the call almost always takes longer than the 2s default timeout
db.DeleteUser(context.Background(), dbplugin.DeleteUserRequest{
Username: username,
Statements: dbplugin.Statements{Commands: commands},
})
if err := testCredsExist(t, url, username, password); err == nil {
t.Fatal("Credentials were not revoked")
}
} }
newConf, err := db.RotateRootCredentials(context.Background(), nil) dbtesting.AssertClose(t, db)
if err != nil {
t.Fatalf("err: %v", err)
}
fmt.Printf("rotated root credentials, new user/pass:\nusername: %s\npassword: %s\n", newConf["username"], newConf["password"])
if newConf["password"] == adminPassword {
t.Fatal("password was not updated")
}
err = db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
} }
func testCredsExist(t testing.TB, connURL, username, password string) error { func testCredsExist(t testing.TB, url, username, password string) error {
t.Helper() t.Helper()
_, adminUser, adminPassword, err := redshiftEnv()
if err != nil {
return err
}
connURL = strings.Replace(connURL, fmt.Sprintf("%s:%s", adminUser, adminPassword), fmt.Sprintf("%s:%s", username, password), 1) connURL := interpolateConnectionURL(url, username, password)
db, err := sql.Open("postgres", connURL) db, err := sql.Open("postgres", connURL)
if err != nil { if err != nil {
return err return err