DBPW - Migrate Redshift database plugin to v5 interface (#10195)

This commit is contained in:
Tom Proctor
2020-10-23 14:10:57 +01:00
committed by GitHub
parent ee09e54d80
commit be0a3d28f9
3 changed files with 405 additions and 523 deletions

View File

@@ -6,11 +6,10 @@ import (
"errors"
"fmt"
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/database/dbplugin"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
@@ -31,37 +30,28 @@ const (
ALTER USER "{{name}}" VALID UNTIL '{{expiration}}';
`
defaultRotateRootCredentialsSQL = `
ALTER USER "{{username}}" WITH PASSWORD '{{password}}';
ALTER USER "{{name}}" WITH PASSWORD '{{password}}';
`
)
// lowercaseUsername is the reason we wrote this plugin. Redshift implements (mostly)
// a postgres 8 interface, and part of that is under the hood, it's lowercasing the
// usernames.
func New(lowercaseUsername bool) func() (interface{}, error) {
return func() (interface{}, error) {
db := newRedshift(lowercaseUsername)
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
return dbType, nil
}
var _ dbplugin.Database = (*RedShift)(nil)
// New implements builtinplugins.BuiltinFactory
// Redshift implements (mostly) a postgres 8 interface, and part of that is
// under the hood, it's lower-casing the usernames.
func New() (interface{}, error) {
db := newRedshift()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}
func newRedshift(lowercaseUsername bool) *RedShift {
func newRedshift() *RedShift {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = sqlTypeName
credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: 8,
RoleNameLen: 8,
UsernameLen: 63,
Separator: "-",
LowercaseUsername: lowercaseUsername,
}
db := &RedShift{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
return db
@@ -69,14 +59,32 @@ func newRedshift(lowercaseUsername bool) *RedShift {
type RedShift struct {
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}
func (r *RedShift) secretValues() map[string]string {
return map[string]string{
r.Password: "[password]",
}
}
func (r *RedShift) Type() (string, error) {
return middlewareTypeName, nil
}
// getConnection accepts a context and retuns a new pointer to a sql.DB object.
// Initialize must be called on each new RedShift struct before use.
// It uses the connutil.SQLConnectionProducer's Init function to do all the lifting.
func (r *RedShift) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
conf, err := r.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("error initializing db: %w", err)
}
return dbplugin.InitializeResponse{
Config: conf,
}, nil
}
// getConnection accepts a context and returns a new pointer to a sql.DB object.
// It's up to the caller to close the connection or handle reuse logic.
func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := r.Connection(ctx)
@@ -86,116 +94,44 @@ func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) {
return db.(*sql.DB), nil
}
// SetCredentials uses provided information to set/create a user in the
// database. Unlike CreateUser, this method requires a username be provided and
// uses the name given, instead of generating a name. This is used for creating
// and setting the password of static accounts, as well as rolling back
// passwords in the database in the event an updated database fails to save in
// Vault's storage.
func (r *RedShift) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
if len(statements.Rotation) == 0 {
statements.Rotation = []string{defaultRotateRootCredentialsSQL}
}
username = staticUser.Username
password = staticUser.Password
if username == "" || password == "" {
return "", "", errors.New("must provide both username and password")
// NewUser creates a new user in the database. There is no default statement for
// creating users, so one must be specified in the plugin config.
// Generated usernames are of the form v-{display-name}-{role-name}-{UUID}-{timestamp}
func (r *RedShift) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
if len(req.Statements.Commands) == 0 {
return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
}
// Grab the lock
r.Lock()
defer r.Unlock()
// Get the connection
db, err := r.getConnection(ctx)
usernameOpts := []credsutil.UsernameOpt{
credsutil.DisplayName(req.UsernameConfig.DisplayName, 8),
credsutil.RoleName(req.UsernameConfig.RoleName, 8),
credsutil.MaxLength(63),
credsutil.Separator("-"),
credsutil.ToLower(),
}
username, err := credsutil.GenerateUsername(usernameOpts...)
if err != nil {
return "", "", err
}
defer db.Close()
// Check if the role exists
var exists bool
err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
return "", "", err
}
// Vault requires the database user already exist, and that the credentials
// used to execute the rotation statements has sufficient privileges.
stmts := statements.Rotation
// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
}
defer func() {
tx.Rollback()
}()
// Execute each query
for _, stmt := range stmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": staticUser.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return "", "", err
}
return username, password, nil
}
func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
}
// Grab the lock
r.Lock()
defer r.Unlock()
username, err = r.GenerateUsername(usernameConfig)
if err != nil {
return "", "", err
}
password, err = r.GeneratePassword()
if err != nil {
return "", "", err
}
expirationStr, err := r.GenerateExpiration(expiration)
if err != nil {
return "", "", err
return dbplugin.NewUserResponse{}, err
}
password := req.Password
expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700")
// Get the connection
db, err := r.getConnection(ctx)
if err != nil {
return "", "", err
return dbplugin.NewUserResponse{}, err
}
defer db.Close()
// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
return dbplugin.NewUserResponse{}, err
}
defer func() {
@@ -203,7 +139,7 @@ func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statement
}()
// Execute each query
for _, stmt := range statements.Creation {
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
@@ -212,53 +148,81 @@ func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statement
m := map[string]string{
"name": username,
"username": username,
"password": password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
return dbplugin.NewUserResponse{}, err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return "", "", err
return dbplugin.NewUserResponse{}, err
}
return username, password, nil
return dbplugin.NewUserResponse{
Username: username,
}, nil
}
func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
// UpdateUser can update the expiration or the password of a user, or both.
// The updates all happen in a single transaction, so they will either all
// succeed or all fail.
// Both updates support both default and custom statements.
func (r *RedShift) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) {
if req.Password == nil && req.Expiration == nil {
return dbplugin.UpdateUserResponse{}, errors.New("no changes requested")
}
r.Lock()
defer r.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
renewStmts := statements.Renewal
if len(renewStmts) == 0 {
renewStmts = []string{defaultRenewSQL}
}
db, err := r.getConnection(ctx)
if err != nil {
return err
return dbplugin.UpdateUserResponse{}, err
}
defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
return dbplugin.UpdateUserResponse{}, err
}
defer func() {
tx.Rollback()
}()
expirationStr, err := r.GenerateExpiration(expiration)
if err != nil {
return err
if req.Expiration != nil {
err = updateUserExpiration(ctx, req, tx)
if err != nil {
return dbplugin.UpdateUserResponse{}, err
}
}
for _, stmt := range renewStmts {
if req.Password != nil {
err = updateUserPassword(ctx, req, tx)
if err != nil {
return dbplugin.UpdateUserResponse{}, err
}
}
err = tx.Commit()
return dbplugin.UpdateUserResponse{}, err
}
func updateUserExpiration(ctx context.Context, req dbplugin.UpdateUserRequest, tx *sql.Tx) error {
if req.Username == "" {
return errors.New("must provide a username to update user expiration")
}
renewStmts := req.Expiration.Statements
if len(renewStmts.Commands) == 0 {
renewStmts.Commands = []string{defaultRenewSQL}
}
expirationStr := req.Expiration.NewExpiration.Format("2006-01-02 15:04:05-0700")
for _, stmt := range renewStmts.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
@@ -266,7 +230,8 @@ func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements
}
m := map[string]string{
"name": username,
"name": req.Username,
"username": req.Username,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
@@ -275,39 +240,36 @@ func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements
}
}
return tx.Commit()
return nil
}
func (r *RedShift) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the lock
r.Lock()
defer r.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Revocation) == 0 {
return r.defaultRevokeUser(ctx, username)
func updateUserPassword(ctx context.Context, req dbplugin.UpdateUserRequest, tx *sql.Tx) error {
username := req.Username
password := req.Password.NewPassword
if username == "" || password == "" {
return errors.New("must provide both username and a new password to update user password")
}
return r.customRevokeUser(ctx, username, statements.Revocation)
}
func (r *RedShift) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error {
db, err := r.getConnection(ctx)
if err != nil {
// Check if the role exists
var exists bool
err := tx.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
// Server error
return err
}
defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
if err == sql.ErrNoRows || !exists {
// Most likely a user error
return fmt.Errorf("cannot update password for username %q because it does not exist", username)
}
defer func() {
tx.Rollback()
}()
for _, stmt := range revocationStmts {
// Vault requires the database user already exist, and that the credentials
// used to execute the rotation statements has sufficient privileges.
statements := req.Password.Statements.Commands
if len(statements) == 0 {
statements = []string{defaultRotateRootCredentialsSQL}
}
// Execute each query
for _, stmt := range statements {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
@@ -315,7 +277,9 @@ func (r *RedShift) customRevokeUser(ctx context.Context, username string, revoca
}
m := map[string]string{
"name": username,
"name": username,
"username": username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return err
@@ -323,25 +287,76 @@ func (r *RedShift) customRevokeUser(ctx context.Context, username string, revoca
}
}
return tx.Commit()
return nil
}
func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error {
// DeleteUser supports both default and custom statements to delete a user.
func (r *RedShift) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
// Grab the lock
r.Lock()
defer r.Unlock()
if len(req.Statements.Commands) == 0 {
return r.defaultDeleteUser(ctx, req)
}
return r.customDeleteUser(ctx, req)
}
func (r *RedShift) customDeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
db, err := r.getConnection(ctx)
if err != nil {
return err
return dbplugin.DeleteUserResponse{}, err
}
defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return dbplugin.DeleteUserResponse{}, err
}
defer func() {
tx.Rollback()
}()
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": req.Username,
"username": req.Username,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return dbplugin.DeleteUserResponse{}, err
}
}
}
return dbplugin.DeleteUserResponse{}, tx.Commit()
}
func (r *RedShift) defaultDeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
db, err := r.getConnection(ctx)
if err != nil {
return dbplugin.DeleteUserResponse{}, err
}
defer db.Close()
username := req.Username
// Check if the role exists
var exists bool
err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
return err
return dbplugin.DeleteUserResponse{}, err
}
if !exists {
return nil
// No error as Redshift may have deleted the user via TTL before we got to it.
return dbplugin.DeleteUserResponse{}, nil
}
// Query for permissions; we need to revoke permissions before we can drop
@@ -350,13 +365,13 @@ func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error
// we want to remove as much access as possible
stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil {
return err
return dbplugin.DeleteUserResponse{}, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx, username)
if err != nil {
return err
return dbplugin.DeleteUserResponse{}, err
}
defer rows.Close()
@@ -393,7 +408,7 @@ func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error
// this username
var dbname sql.NullString
if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
return err
return dbplugin.DeleteUserResponse{}, err
}
if dbname.Valid {
@@ -432,78 +447,22 @@ $$;`)
// can't drop if not all privileges are revoked
if rows.Err() != nil {
return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err())
return dbplugin.DeleteUserResponse{}, errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err())
}
if lastStmtError != nil {
return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError)
return dbplugin.DeleteUserResponse{}, errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError)
}
// Drop this user
stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
`DROP USER IF EXISTS %s;`, pq.QuoteIdentifier(username)))
if err != nil {
return err
return dbplugin.DeleteUserResponse{}, err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return err
return dbplugin.DeleteUserResponse{}, err
}
return nil
}
func (r *RedShift) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
r.Lock()
defer r.Unlock()
if len(r.Username) == 0 || len(r.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatements := statements
if len(rotateStatements) == 0 {
rotateStatements = []string{defaultRotateRootCredentialsSQL}
}
db, err := r.getConnection(ctx)
if err != nil {
return nil, err
}
defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := r.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatements {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"username": r.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
r.RawConfig["password"] = password
return r.RawConfig, nil
return dbplugin.DeleteUserResponse{}, nil
}