mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 10:37:56 +00:00 
			
		
		
		
	DBPW - Update PostgreSQL to adhere to v5 Database interface (#10061)
This commit is contained in:
		| @@ -3,35 +3,36 @@ package postgresql | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"errors" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/errwrap" | 	"github.com/hashicorp/errwrap" | ||||||
|  | 	"github.com/hashicorp/go-multierror" | ||||||
| 	"github.com/hashicorp/vault/api" | 	"github.com/hashicorp/vault/api" | ||||||
| 	"github.com/hashicorp/vault/sdk/database/dbplugin" |  | ||||||
| 	"github.com/hashicorp/vault/sdk/database/helper/connutil" | 	"github.com/hashicorp/vault/sdk/database/helper/connutil" | ||||||
| 	"github.com/hashicorp/vault/sdk/database/helper/credsutil" | 	"github.com/hashicorp/vault/sdk/database/helper/credsutil" | ||||||
| 	"github.com/hashicorp/vault/sdk/database/helper/dbutil" | 	"github.com/hashicorp/vault/sdk/database/helper/dbutil" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/database/newdbplugin" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/dbtxn" | 	"github.com/hashicorp/vault/sdk/helper/dbtxn" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/strutil" | 	"github.com/hashicorp/vault/sdk/helper/strutil" | ||||||
| 	"github.com/lib/pq" | 	"github.com/lib/pq" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	postgreSQLTypeName      = "postgres" | 	postgreSQLTypeName         = "postgres" | ||||||
| 	defaultPostgresRenewSQL = ` | 	defaultExpirationStatement = ` | ||||||
| ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; | ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; | ||||||
| ` | ` | ||||||
| 	defaultPostgresRotateRootCredentialsSQL = ` | 	defaultChangePasswordStatement = ` | ||||||
| ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; | ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; | ||||||
| ` | ` | ||||||
|  |  | ||||||
|  | 	expirationFormat = "2006-01-02 15:04:05-0700" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	_ dbplugin.Database = &PostgreSQL{} | 	_ newdbplugin.Database = &PostgreSQL{} | ||||||
|  |  | ||||||
| 	// postgresEndStatement is basically the word "END" but | 	// postgresEndStatement is basically the word "END" but | ||||||
| 	// surrounded by a word boundary to differentiate it from | 	// surrounded by a word boundary to differentiate it from | ||||||
| @@ -51,7 +52,7 @@ var ( | |||||||
| func New() (interface{}, error) { | func New() (interface{}, error) { | ||||||
| 	db := new() | 	db := new() | ||||||
| 	// Wrap the plugin with middleware to sanitize errors | 	// Wrap the plugin with middleware to sanitize errors | ||||||
| 	dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) | 	dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) | ||||||
| 	return dbType, nil | 	return dbType, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -59,16 +60,8 @@ func new() *PostgreSQL { | |||||||
| 	connProducer := &connutil.SQLConnectionProducer{} | 	connProducer := &connutil.SQLConnectionProducer{} | ||||||
| 	connProducer.Type = postgreSQLTypeName | 	connProducer.Type = postgreSQLTypeName | ||||||
|  |  | ||||||
| 	credsProducer := &credsutil.SQLCredentialsProducer{ |  | ||||||
| 		DisplayNameLen: 8, |  | ||||||
| 		RoleNameLen:    8, |  | ||||||
| 		UsernameLen:    63, |  | ||||||
| 		Separator:      "-", |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	db := &PostgreSQL{ | 	db := &PostgreSQL{ | ||||||
| 		SQLConnectionProducer: connProducer, | 		SQLConnectionProducer: connProducer, | ||||||
| 		CredentialsProducer:   credsProducer, |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return db | 	return db | ||||||
| @@ -81,14 +74,24 @@ func Run(apiTLSConfig *api.TLSConfig) error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) | 	newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type PostgreSQL struct { | type PostgreSQL struct { | ||||||
| 	*connutil.SQLConnectionProducer | 	*connutil.SQLConnectionProducer | ||||||
| 	credsutil.CredentialsProducer | } | ||||||
|  |  | ||||||
|  | func (p *PostgreSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) { | ||||||
|  | 	newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return newdbplugin.InitializeResponse{}, err | ||||||
|  | 	} | ||||||
|  | 	resp := newdbplugin.InitializeResponse{ | ||||||
|  | 		Config: newConf, | ||||||
|  | 	} | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *PostgreSQL) Type() (string, error) { | func (p *PostgreSQL) Type() (string, error) { | ||||||
| @@ -104,54 +107,58 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { | |||||||
| 	return db.(*sql.DB), nil | 	return db.(*sql.DB), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // SetCredentials uses provided information to set/create a user in the | func (p *PostgreSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) { | ||||||
| // database. Unlike CreateUser, this method requires a username be provided and | 	if req.Username == "" { | ||||||
| // uses the name given, instead of generating a name. This is used for creating | 		return newdbplugin.UpdateUserResponse{}, fmt.Errorf("missing username") | ||||||
| // and setting the password of static accounts, as well as rolling back | 	} | ||||||
| // passwords in the database in the event an updated database fails to save in | 	if req.Password == nil && req.Expiration == nil { | ||||||
| // Vault's storage. | 		return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested") | ||||||
| func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { |  | ||||||
| 	if len(statements.Rotation) == 0 { |  | ||||||
| 		statements.Rotation = []string{defaultPostgresRotateRootCredentialsSQL} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	username = staticUser.Username | 	merr := &multierror.Error{} | ||||||
| 	password = staticUser.Password | 	if req.Password != nil { | ||||||
| 	if username == "" || password == "" { | 		err := p.changeUserPassword(ctx, req.Username, req.Password) | ||||||
| 		return "", "", errors.New("must provide both username and password") | 		merr = multierror.Append(merr, err) | ||||||
|  | 	} | ||||||
|  | 	if req.Expiration != nil { | ||||||
|  | 		err := p.changeUserExpiration(ctx, req.Username, req.Expiration) | ||||||
|  | 		merr = multierror.Append(merr, err) | ||||||
|  | 	} | ||||||
|  | 	return newdbplugin.UpdateUserResponse{}, merr.ErrorOrNil() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error { | ||||||
|  | 	stmts := changePass.Statements.Commands | ||||||
|  | 	if len(stmts) == 0 { | ||||||
|  | 		stmts = []string{defaultChangePasswordStatement} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	password := changePass.NewPassword | ||||||
|  | 	if password == "" { | ||||||
|  | 		return fmt.Errorf("missing password") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Grab the lock |  | ||||||
| 	p.Lock() | 	p.Lock() | ||||||
| 	defer p.Unlock() | 	defer p.Unlock() | ||||||
|  |  | ||||||
| 	// Get the connection |  | ||||||
| 	db, err := p.getConnection(ctx) | 	db, err := p.getConnection(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", "", err | 		return fmt.Errorf("unable to get connection: %w", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check if the role exists | 	// Check if the role exists | ||||||
| 	var exists bool | 	var exists bool | ||||||
| 	err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) | 	err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) | ||||||
| 	if err != nil && err != sql.ErrNoRows { | 	if err != nil && err != sql.ErrNoRows { | ||||||
| 		return "", "", err | 		return fmt.Errorf("user does not appear to exist: %w", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Vault requires the database user already exist, and that the credentials |  | ||||||
| 	// used to execute the rotation statements has sufficient privileges. |  | ||||||
| 	stmts := statements.Rotation |  | ||||||
|  |  | ||||||
| 	// Start a transaction |  | ||||||
| 	tx, err := db.BeginTx(ctx, nil) | 	tx, err := db.BeginTx(ctx, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", "", err | 		return fmt.Errorf("unable to start transaction: %w", err) | ||||||
| 	} | 	} | ||||||
| 	defer func() { | 	defer tx.Rollback() | ||||||
| 		_ = tx.Rollback() |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	// Execute each query |  | ||||||
| 	for _, stmt := range stmts { | 	for _, stmt := range stmts { | ||||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||||
| 			query = strings.TrimSpace(query) | 			query = strings.TrimSpace(query) | ||||||
| @@ -160,117 +167,30 @@ func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Sta | |||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			m := map[string]string{ | 			m := map[string]string{ | ||||||
| 				"name":     staticUser.Username, | 				"name":     username, | ||||||
| 				"username": staticUser.Username, | 				"username": username, | ||||||
| 				"password": password, | 				"password": password, | ||||||
| 			} | 			} | ||||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | ||||||
| 				return "", "", err | 				return fmt.Errorf("failed to execute query: %w", err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Commit the transaction |  | ||||||
| 	if err := tx.Commit(); err != nil { | 	if err := tx.Commit(); err != nil { | ||||||
| 		return "", "", err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return username, password, nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, changeExp *newdbplugin.ChangeExpiration) error { | ||||||
| 	statements = dbutil.StatementCompatibilityHelper(statements) |  | ||||||
|  |  | ||||||
| 	if len(statements.Creation) == 0 { |  | ||||||
| 		return "", "", dbutil.ErrEmptyCreationStatement |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Grab the lock |  | ||||||
| 	p.Lock() | 	p.Lock() | ||||||
| 	defer p.Unlock() | 	defer p.Unlock() | ||||||
|  |  | ||||||
| 	username, err = p.GenerateUsername(usernameConfig) | 	renewStmts := changeExp.Statements.Commands | ||||||
| 	if err != nil { |  | ||||||
| 		return "", "", err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	password, err = p.GeneratePassword() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", "", err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	expirationStr, err := p.GenerateExpiration(expiration) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", "", err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Get the connection |  | ||||||
| 	db, err := p.getConnection(ctx) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", "", err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Start a transaction |  | ||||||
| 	tx, err := db.BeginTx(ctx, nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", "", err |  | ||||||
|  |  | ||||||
| 	} |  | ||||||
| 	defer func() { |  | ||||||
| 		tx.Rollback() |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	// Execute each query |  | ||||||
| 	for _, stmt := range statements.Creation { |  | ||||||
| 		if containsMultilineStatement(stmt) { |  | ||||||
| 			// Execute it as-is. |  | ||||||
| 			m := map[string]string{ |  | ||||||
| 				"name":       username, |  | ||||||
| 				"username":   username, |  | ||||||
| 				"password":   password, |  | ||||||
| 				"expiration": expirationStr, |  | ||||||
| 			} |  | ||||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil { |  | ||||||
| 				return "", "", err |  | ||||||
| 			} |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		// Otherwise, it's fine to split the statements on the semicolon. |  | ||||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { |  | ||||||
| 			query = strings.TrimSpace(query) |  | ||||||
| 			if len(query) == 0 { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			m := map[string]string{ |  | ||||||
| 				"name":       username, |  | ||||||
| 				"username":   username, |  | ||||||
| 				"password":   password, |  | ||||||
| 				"expiration": expirationStr, |  | ||||||
| 			} |  | ||||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { |  | ||||||
| 				return "", "", err |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Commit the transaction |  | ||||||
| 	if err := tx.Commit(); err != nil { |  | ||||||
| 		return "", "", err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return username, password, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { |  | ||||||
| 	p.Lock() |  | ||||||
| 	defer p.Unlock() |  | ||||||
|  |  | ||||||
| 	statements = dbutil.StatementCompatibilityHelper(statements) |  | ||||||
|  |  | ||||||
| 	renewStmts := statements.Renewal |  | ||||||
| 	if len(renewStmts) == 0 { | 	if len(renewStmts) == 0 { | ||||||
| 		renewStmts = []string{defaultPostgresRenewSQL} | 		renewStmts = []string{defaultExpirationStatement} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	db, err := p.getConnection(ctx) | 	db, err := p.getConnection(ctx) | ||||||
| @@ -286,10 +206,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen | |||||||
| 		tx.Rollback() | 		tx.Rollback() | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	expirationStr, err := p.GenerateExpiration(expiration) | 	expirationStr := changeExp.NewExpiration.Format(expirationFormat) | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, stmt := range renewStmts { | 	for _, stmt := range renewStmts { | ||||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||||
| @@ -312,21 +229,93 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen | |||||||
| 	return tx.Commit() | 	return tx.Commit() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | func (p *PostgreSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) { | ||||||
| 	// Grab the lock | 	if len(req.Statements.Commands) == 0 { | ||||||
|  | 		return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	p.Lock() | 	p.Lock() | ||||||
| 	defer p.Unlock() | 	defer p.Unlock() | ||||||
|  |  | ||||||
| 	statements = dbutil.StatementCompatibilityHelper(statements) | 	username, err := credsutil.GenerateUsername( | ||||||
|  | 		credsutil.DisplayName(req.UsernameConfig.DisplayName, 8), | ||||||
| 	if len(statements.Revocation) == 0 { | 		credsutil.RoleName(req.UsernameConfig.RoleName, 8), | ||||||
| 		return p.defaultRevokeUser(ctx, username) | 		credsutil.Separator("-"), | ||||||
|  | 		credsutil.MaxLength(63), | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return newdbplugin.NewUserResponse{}, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return p.customRevokeUser(ctx, username, statements.Revocation) | 	expirationStr := req.Expiration.Format(expirationFormat) | ||||||
|  |  | ||||||
|  | 	db, err := p.getConnection(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	tx, err := db.BeginTx(ctx, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to start transaction: %w", err) | ||||||
|  |  | ||||||
|  | 	} | ||||||
|  | 	defer tx.Rollback() | ||||||
|  |  | ||||||
|  | 	for _, stmt := range req.Statements.Commands { | ||||||
|  | 		if containsMultilineStatement(stmt) { | ||||||
|  | 			// Execute it as-is. | ||||||
|  | 			m := map[string]string{ | ||||||
|  | 				"name":       username, | ||||||
|  | 				"username":   username, | ||||||
|  | 				"password":   req.Password, | ||||||
|  | 				"expiration": expirationStr, | ||||||
|  | 			} | ||||||
|  | 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil { | ||||||
|  | 				return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		// Otherwise, it's fine to split the statements on the semicolon. | ||||||
|  | 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||||
|  | 			query = strings.TrimSpace(query) | ||||||
|  | 			if len(query) == 0 { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			m := map[string]string{ | ||||||
|  | 				"name":       username, | ||||||
|  | 				"username":   username, | ||||||
|  | 				"password":   req.Password, | ||||||
|  | 				"expiration": expirationStr, | ||||||
|  | 			} | ||||||
|  | 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | ||||||
|  | 				return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := tx.Commit(); err != nil { | ||||||
|  | 		return newdbplugin.NewUserResponse{}, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp := newdbplugin.NewUserResponse{ | ||||||
|  | 		Username: username, | ||||||
|  | 	} | ||||||
|  | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { | func (p *PostgreSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) { | ||||||
|  | 	p.Lock() | ||||||
|  | 	defer p.Unlock() | ||||||
|  |  | ||||||
|  | 	if len(req.Statements.Commands) == 0 { | ||||||
|  | 		return newdbplugin.DeleteUserResponse{}, p.defaultDeleteUser(ctx, req.Username) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return newdbplugin.DeleteUserResponse{}, p.customDeleteUser(ctx, req.Username, req.Statements.Commands) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *PostgreSQL) customDeleteUser(ctx context.Context, username string, revocationStmts []string) error { | ||||||
| 	db, err := p.getConnection(ctx) | 	db, err := p.getConnection(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| @@ -360,7 +349,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revo | |||||||
| 	return tx.Commit() | 	return tx.Commit() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { | func (p *PostgreSQL) defaultDeleteUser(ctx context.Context, username string) error { | ||||||
| 	db, err := p.getConnection(ctx) | 	db, err := p.getConnection(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| @@ -471,65 +460,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { | func (p *PostgreSQL) secretValues() map[string]string { | ||||||
| 	p.Lock() | 	return map[string]string{ | ||||||
| 	defer p.Unlock() | 		p.Password: "[password]", | ||||||
|  |  | ||||||
| 	if len(p.Username) == 0 || len(p.Password) == 0 { |  | ||||||
| 		return nil, errors.New("username and password are required to rotate") |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	rotateStatements := statements |  | ||||||
| 	if len(rotateStatements) == 0 { |  | ||||||
| 		rotateStatements = []string{defaultPostgresRotateRootCredentialsSQL} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	db, err := p.getConnection(ctx) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	tx, err := db.BeginTx(ctx, nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	defer func() { |  | ||||||
| 		tx.Rollback() |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	password, err := p.GeneratePassword() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, stmt := range rotateStatements { |  | ||||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { |  | ||||||
| 			query = strings.TrimSpace(query) |  | ||||||
| 			if len(query) == 0 { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			m := map[string]string{ |  | ||||||
| 				"name":     p.Username, |  | ||||||
| 				"username": p.Username, |  | ||||||
| 				"password": password, |  | ||||||
| 			} |  | ||||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { |  | ||||||
| 				return nil, err |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if err := tx.Commit(); err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Close the database connection to ensure no new connections come in |  | ||||||
| 	if err := db.Close(); err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	p.RawConfig["password"] = password |  | ||||||
| 	return p.RawConfig, nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // containsMultilineStatement is a best effort to determine whether | // containsMultilineStatement is a best effort to determine whether | ||||||
|   | |||||||
| @@ -9,9 +9,8 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/helper/testhelpers/postgresql" | 	"github.com/hashicorp/vault/helper/testhelpers/postgresql" | ||||||
| 	"github.com/hashicorp/vault/sdk/database/dbplugin" | 	"github.com/hashicorp/vault/sdk/database/newdbplugin" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/dbtxn" | 	dbtesting "github.com/hashicorp/vault/sdk/database/newdbplugin/testing" | ||||||
| 	"github.com/lib/pq" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) { | func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) { | ||||||
| @@ -24,12 +23,14 @@ func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, f | |||||||
| 		connectionDetails[k] = v | 		connectionDetails[k] = v | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	db := new() | 	req := newdbplugin.InitializeRequest{ | ||||||
| 	_, err := db.Init(context.Background(), connectionDetails, true) | 		Config:           connectionDetails, | ||||||
| 	if err != nil { | 		VerifyConnection: true, | ||||||
| 		t.Fatalf("err: %s", err) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	db := new() | ||||||
|  | 	dbtesting.AssertInitialize(t, db, req) | ||||||
|  |  | ||||||
| 	if !db.Initialized { | 	if !db.Initialized { | ||||||
| 		t.Fatal("Database should be initialized") | 		t.Fatal("Database should be initialized") | ||||||
| 	} | 	} | ||||||
| @@ -58,90 +59,163 @@ func TestPostgreSQL_InitializeWithStringVals(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) { | func TestPostgreSQL_NewUser(t *testing.T) { | ||||||
| 	db := new() |  | ||||||
|  |  | ||||||
| 	usernameConfig := dbplugin.UsernameConfig{ |  | ||||||
| 		DisplayName: "test", |  | ||||||
| 		RoleName:    "test", |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	username, password, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) |  | ||||||
| 	if err == nil { |  | ||||||
| 		t.Fatalf("expected err, got nil") |  | ||||||
| 	} |  | ||||||
| 	if username != "" { |  | ||||||
| 		t.Fatalf("expected empty username, got [%s]", username) |  | ||||||
| 	} |  | ||||||
| 	if password != "" { |  | ||||||
| 		t.Fatalf("expected empty password, got [%s]", password) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestPostgreSQL_CreateUser(t *testing.T) { |  | ||||||
| 	type testCase struct { | 	type testCase struct { | ||||||
| 		createStmts          []string | 		req            newdbplugin.NewUserRequest | ||||||
| 		shouldTestCredsExist bool | 		expectErr      bool | ||||||
|  | 		credsAssertion credsAssertion | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	tests := map[string]testCase{ | 	tests := map[string]testCase{ | ||||||
| 		"admin name": { | 		"no creation statements": { | ||||||
| 			createStmts: []string{` | 			req: newdbplugin.NewUserRequest{ | ||||||
| 				CREATE ROLE "{{name}}" WITH | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
| 				  LOGIN | 					DisplayName: "test", | ||||||
| 				  PASSWORD '{{password}}' | 					RoleName:    "test", | ||||||
| 				  VALID UNTIL '{{expiration}}'; | 				}, | ||||||
| 				GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, | 				// No statements | ||||||
|  | 				Password:   "somesecurepassword", | ||||||
|  | 				Expiration: time.Now().Add(1 * time.Minute), | ||||||
| 			}, | 			}, | ||||||
| 			shouldTestCredsExist: true, | 			expectErr:      true, | ||||||
|  | 			credsAssertion: assertCredsDoNotExist, | ||||||
|  | 		}, | ||||||
|  | 		"admin name": { | ||||||
|  | 			req: newdbplugin.NewUserRequest{ | ||||||
|  | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
|  | 					DisplayName: "test", | ||||||
|  | 					RoleName:    "test", | ||||||
|  | 				}, | ||||||
|  | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: []string{` | ||||||
|  | 						CREATE ROLE "{{name}}" WITH | ||||||
|  | 						  LOGIN | ||||||
|  | 						  PASSWORD '{{password}}' | ||||||
|  | 						  VALID UNTIL '{{expiration}}'; | ||||||
|  | 						GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   "somesecurepassword", | ||||||
|  | 				Expiration: time.Now().Add(1 * time.Minute), | ||||||
|  | 			}, | ||||||
|  | 			expectErr:      false, | ||||||
|  | 			credsAssertion: assertCredsExist, | ||||||
| 		}, | 		}, | ||||||
| 		"admin username": { | 		"admin username": { | ||||||
| 			createStmts: []string{` | 			req: newdbplugin.NewUserRequest{ | ||||||
| 				CREATE ROLE "{{username}}" WITH | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
| 				  LOGIN | 					DisplayName: "test", | ||||||
| 				  PASSWORD '{{password}}' | 					RoleName:    "test", | ||||||
| 				  VALID UNTIL '{{expiration}}'; | 				}, | ||||||
| 				GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: []string{` | ||||||
|  | 						CREATE ROLE "{{username}}" WITH | ||||||
|  | 						  LOGIN | ||||||
|  | 						  PASSWORD '{{password}}' | ||||||
|  | 						  VALID UNTIL '{{expiration}}'; | ||||||
|  | 						GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   "somesecurepassword", | ||||||
|  | 				Expiration: time.Now().Add(1 * time.Minute), | ||||||
| 			}, | 			}, | ||||||
| 			shouldTestCredsExist: true, | 			expectErr:      false, | ||||||
|  | 			credsAssertion: assertCredsExist, | ||||||
| 		}, | 		}, | ||||||
| 		"read only name": { | 		"read only name": { | ||||||
| 			createStmts: []string{` | 			req: newdbplugin.NewUserRequest{ | ||||||
| 				CREATE ROLE "{{name}}" WITH | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
| 				  LOGIN | 					DisplayName: "test", | ||||||
| 				  PASSWORD '{{password}}' | 					RoleName:    "test", | ||||||
| 				  VALID UNTIL '{{expiration}}'; | 				}, | ||||||
| 				GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; | 				Statements: newdbplugin.Statements{ | ||||||
| 				GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, | 					Commands: []string{` | ||||||
|  | 						CREATE ROLE "{{name}}" WITH | ||||||
|  | 						  LOGIN | ||||||
|  | 						  PASSWORD '{{password}}' | ||||||
|  | 						  VALID UNTIL '{{expiration}}'; | ||||||
|  | 						GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; | ||||||
|  | 						GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   "somesecurepassword", | ||||||
|  | 				Expiration: time.Now().Add(1 * time.Minute), | ||||||
| 			}, | 			}, | ||||||
| 			shouldTestCredsExist: true, | 			expectErr:      false, | ||||||
|  | 			credsAssertion: assertCredsExist, | ||||||
| 		}, | 		}, | ||||||
| 		"read only username": { | 		"read only username": { | ||||||
| 			createStmts: []string{` | 			req: newdbplugin.NewUserRequest{ | ||||||
| 				CREATE ROLE "{{username}}" WITH | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
| 				  LOGIN | 					DisplayName: "test", | ||||||
| 				  PASSWORD '{{password}}' | 					RoleName:    "test", | ||||||
| 				  VALID UNTIL '{{expiration}}'; | 				}, | ||||||
| 				GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; | 				Statements: newdbplugin.Statements{ | ||||||
| 				GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, | 					Commands: []string{` | ||||||
|  | 						CREATE ROLE "{{username}}" WITH | ||||||
|  | 						  LOGIN | ||||||
|  | 						  PASSWORD '{{password}}' | ||||||
|  | 						  VALID UNTIL '{{expiration}}'; | ||||||
|  | 						GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; | ||||||
|  | 						GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   "somesecurepassword", | ||||||
|  | 				Expiration: time.Now().Add(1 * time.Minute), | ||||||
| 			}, | 			}, | ||||||
| 			shouldTestCredsExist: true, | 			expectErr:      false, | ||||||
|  | 			credsAssertion: assertCredsExist, | ||||||
| 		}, | 		}, | ||||||
| 		// https://github.com/hashicorp/vault/issues/6098 | 		// https://github.com/hashicorp/vault/issues/6098 | ||||||
| 		"reproduce GH-6098": { | 		"reproduce GH-6098": { | ||||||
| 			createStmts: []string{ | 			req: newdbplugin.NewUserRequest{ | ||||||
| 				// NOTE: "rolname" in the following line is not a typo. | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
| 				"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$", | 					DisplayName: "test", | ||||||
|  | 					RoleName:    "test", | ||||||
|  | 				}, | ||||||
|  | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: []string{ | ||||||
|  | 						// NOTE: "rolname" in the following line is not a typo. | ||||||
|  | 						"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   "somesecurepassword", | ||||||
|  | 				Expiration: time.Now().Add(1 * time.Minute), | ||||||
| 			}, | 			}, | ||||||
| 			// This test statement doesn't generate creds. | 			expectErr:      false, | ||||||
| 			shouldTestCredsExist: false, | 			credsAssertion: assertCredsDoNotExist, | ||||||
| 		}, | 		}, | ||||||
| 		"reproduce issue with template": { | 		"reproduce issue with template": { | ||||||
| 			createStmts: []string{ | 			req: newdbplugin.NewUserRequest{ | ||||||
| 				`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`, | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
|  | 					DisplayName: "test", | ||||||
|  | 					RoleName:    "test", | ||||||
|  | 				}, | ||||||
|  | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: []string{ | ||||||
|  | 						`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   "somesecurepassword", | ||||||
|  | 				Expiration: time.Now().Add(1 * time.Minute), | ||||||
| 			}, | 			}, | ||||||
| 			// This test statement doesn't generate creds. | 			expectErr:      false, | ||||||
| 			shouldTestCredsExist: false, | 			credsAssertion: assertCredsDoNotExist, | ||||||
|  | 		}, | ||||||
|  | 		"large block statements": { | ||||||
|  | 			req: newdbplugin.NewUserRequest{ | ||||||
|  | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
|  | 					DisplayName: "test", | ||||||
|  | 					RoleName:    "test", | ||||||
|  | 				}, | ||||||
|  | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: newUserLargeBlockStatements, | ||||||
|  | 				}, | ||||||
|  | 				Password:   "somesecurepassword", | ||||||
|  | 				Expiration: time.Now().Add(1 * time.Minute), | ||||||
|  | 			}, | ||||||
|  | 			expectErr:      false, | ||||||
|  | 			credsAssertion: assertCredsExist, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -151,59 +225,55 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | |||||||
|  |  | ||||||
| 	for name, test := range tests { | 	for name, test := range tests { | ||||||
| 		t.Run(name, func(t *testing.T) { | 		t.Run(name, func(t *testing.T) { | ||||||
| 			usernameConfig := dbplugin.UsernameConfig{ |  | ||||||
| 				DisplayName: "test", |  | ||||||
| 				RoleName:    "test", |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			statements := dbplugin.Statements{ |  | ||||||
| 				Creation: test.createStmts, |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Give a timeout just in case the test decides to be problematic | 			// Give a timeout just in case the test decides to be problematic | ||||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) | 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) | ||||||
| 			defer cancel() | 			defer cancel() | ||||||
|  |  | ||||||
| 			username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(time.Minute)) | 			resp, err := db.NewUser(ctx, test.req) | ||||||
| 			if err != nil { | 			if test.expectErr && err == nil { | ||||||
| 				t.Fatalf("err: %s", err) | 				t.Fatalf("err expected, got nil") | ||||||
|  | 			} | ||||||
|  | 			if !test.expectErr && err != nil { | ||||||
|  | 				t.Fatalf("no error expected, got: %s", err) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if !test.shouldTestCredsExist { | 			test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password) | ||||||
| 				// We're done here. |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { |  | ||||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Ensure that the role doesn't expire immediately | 			// Ensure that the role doesn't expire immediately | ||||||
| 			time.Sleep(2 * time.Second) | 			time.Sleep(2 * time.Second) | ||||||
|  |  | ||||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { | 			test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password) | ||||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) |  | ||||||
| 			} |  | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestPostgreSQL_RenewUser(t *testing.T) { | func TestUpdateUser_Password(t *testing.T) { | ||||||
| 	type testCase struct { | 	type testCase struct { | ||||||
| 		renewalStmts []string | 		statements     []string | ||||||
|  | 		expectErr      bool | ||||||
|  | 		credsAssertion credsAssertion | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	tests := map[string]testCase{ | 	tests := map[string]testCase{ | ||||||
| 		"empty renewal statements": { | 		"default statements": { | ||||||
| 			renewalStmts: nil, | 			statements:     nil, | ||||||
|  | 			expectErr:      false, | ||||||
|  | 			credsAssertion: assertCredsExist, | ||||||
| 		}, | 		}, | ||||||
| 		"default renewal name": { | 		"explicit default statements": { | ||||||
| 			renewalStmts: []string{defaultPostgresRenewSQL}, | 			statements:     []string{defaultChangePasswordStatement}, | ||||||
|  | 			expectErr:      false, | ||||||
|  | 			credsAssertion: assertCredsExist, | ||||||
| 		}, | 		}, | ||||||
| 		"default renewal username": { | 		"name instead of username": { | ||||||
| 			renewalStmts: []string{` | 			statements:     []string{`ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`}, | ||||||
| 				ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`, | 			expectErr:      false, | ||||||
| 			}, | 			credsAssertion: assertCredsExist, | ||||||
|  | 		}, | ||||||
|  | 		"bad statements": { | ||||||
|  | 			statements:     []string{`asdofyas8uf77asoiajv`}, | ||||||
|  | 			expectErr:      true, | ||||||
|  | 			credsAssertion: assertCredsDoNotExist, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -213,128 +283,251 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | |||||||
|  |  | ||||||
| 	for name, test := range tests { | 	for name, test := range tests { | ||||||
| 		t.Run(name, func(t *testing.T) { | 		t.Run(name, func(t *testing.T) { | ||||||
| 			statements := dbplugin.Statements{ | 			initialPass := "myreallysecurepassword" | ||||||
| 				Creation: []string{createAdminUser}, | 			createReq := newdbplugin.NewUserRequest{ | ||||||
| 				Renewal:  test.renewalStmts, | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
|  | 					DisplayName: "test", | ||||||
|  | 					RoleName:    "test", | ||||||
|  | 				}, | ||||||
|  | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: []string{createAdminUser}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   initialPass, | ||||||
|  | 				Expiration: time.Now().Add(2 * time.Second), | ||||||
|  | 			} | ||||||
|  | 			createResp := dbtesting.AssertNewUser(t, db, createReq) | ||||||
|  |  | ||||||
|  | 			assertCredsExist(t, db.ConnectionURL, createResp.Username, initialPass) | ||||||
|  |  | ||||||
|  | 			newPass := "somenewpassword" | ||||||
|  | 			updateReq := newdbplugin.UpdateUserRequest{ | ||||||
|  | 				Username: createResp.Username, | ||||||
|  | 				Password: &newdbplugin.ChangePassword{ | ||||||
|  | 					NewPassword: newPass, | ||||||
|  | 					Statements: newdbplugin.Statements{ | ||||||
|  | 						Commands: test.statements, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			usernameConfig := dbplugin.UsernameConfig{ | 			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||||
| 				DisplayName: "test", |  | ||||||
| 				RoleName:    "test", |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Give a timeout just in case the test decides to be problematic |  | ||||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |  | ||||||
| 			defer cancel() | 			defer cancel() | ||||||
|  | 			_, err := db.UpdateUser(ctx, updateReq) | ||||||
| 			username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(2*time.Second)) | 			if test.expectErr && err == nil { | ||||||
| 			if err != nil { | 				t.Fatalf("err expected, got nil") | ||||||
| 				t.Fatalf("err: %s", err) | 			} | ||||||
|  | 			if !test.expectErr && err != nil { | ||||||
|  | 				t.Fatalf("no error expected, got: %s", err) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { | 			test.credsAssertion(t, db.ConnectionURL, createResp.Username, newPass) | ||||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = db.RenewUser(ctx, statements, username, time.Now().Add(time.Minute)) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatalf("err: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Sleep longer than the initial expiration time |  | ||||||
| 			time.Sleep(2 * time.Second) |  | ||||||
|  |  | ||||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { |  | ||||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) |  | ||||||
| 			} |  | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	t.Run("user does not exist", func(t *testing.T) { | ||||||
|  | 		newPass := "somenewpassword" | ||||||
|  | 		updateReq := newdbplugin.UpdateUserRequest{ | ||||||
|  | 			Username: "missing-user", | ||||||
|  | 			Password: &newdbplugin.ChangePassword{ | ||||||
|  | 				NewPassword: newPass, | ||||||
|  | 				Statements:  newdbplugin.Statements{}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||||
|  | 		defer cancel() | ||||||
|  | 		_, err := db.UpdateUser(ctx, updateReq) | ||||||
|  | 		if err == nil { | ||||||
|  | 			t.Fatalf("err expected, got nil") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		assertCredsDoNotExist(t, db.ConnectionURL, updateReq.Username, newPass) | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestPostgreSQL_RotateRootCredentials(t *testing.T) { | func TestUpdateUser_Expiration(t *testing.T) { | ||||||
| 	type testCase struct { | 	type testCase struct { | ||||||
| 		statements []string | 		initialExpiration  time.Time | ||||||
|  | 		newExpiration      time.Time | ||||||
|  | 		expectedExpiration time.Time | ||||||
|  | 		statements         []string | ||||||
|  | 		expectErr          bool | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	now := time.Now() | ||||||
| 	tests := map[string]testCase{ | 	tests := map[string]testCase{ | ||||||
| 		"empty statements": { | 		"no statements": { | ||||||
| 			statements: nil, | 			initialExpiration:  now.Add(1 * time.Minute), | ||||||
|  | 			newExpiration:      now.Add(5 * time.Minute), | ||||||
|  | 			expectedExpiration: now.Add(5 * time.Minute), | ||||||
|  | 			statements:         nil, | ||||||
|  | 			expectErr:          false, | ||||||
| 		}, | 		}, | ||||||
| 		"default name": { | 		"default statements with name": { | ||||||
| 			statements: []string{` | 			initialExpiration:  now.Add(1 * time.Minute), | ||||||
| 				ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, | 			newExpiration:      now.Add(5 * time.Minute), | ||||||
| 			}, | 			expectedExpiration: now.Add(5 * time.Minute), | ||||||
|  | 			statements:         []string{defaultExpirationStatement}, | ||||||
|  | 			expectErr:          false, | ||||||
| 		}, | 		}, | ||||||
| 		"default username": { | 		"default statements with username": { | ||||||
| 			statements: []string{defaultPostgresRotateRootCredentialsSQL}, | 			initialExpiration:  now.Add(1 * time.Minute), | ||||||
|  | 			newExpiration:      now.Add(5 * time.Minute), | ||||||
|  | 			expectedExpiration: now.Add(5 * time.Minute), | ||||||
|  | 			statements:         []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`}, | ||||||
|  | 			expectErr:          false, | ||||||
|  | 		}, | ||||||
|  | 		"bad statements": { | ||||||
|  | 			initialExpiration:  now.Add(1 * time.Minute), | ||||||
|  | 			newExpiration:      now.Add(5 * time.Minute), | ||||||
|  | 			expectedExpiration: now.Add(1 * time.Minute), | ||||||
|  | 			statements:         []string{"ladshfouay09sgj"}, | ||||||
|  | 			expectErr:          true, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Shared test container for speed - there should not be any overlap between the tests | ||||||
|  | 	db, cleanup := getPostgreSQL(t, nil) | ||||||
|  | 	defer cleanup() | ||||||
|  |  | ||||||
| 	for name, test := range tests { | 	for name, test := range tests { | ||||||
| 		t.Run(name, func(t *testing.T) { | 		t.Run(name, func(t *testing.T) { | ||||||
| 			cleanup, connURL := postgresql.PrepareTestContainer(t, "latest") | 			password := "myreallysecurepassword" | ||||||
| 			defer cleanup() | 			initialExpiration := test.initialExpiration.Truncate(time.Second) | ||||||
| 			connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1) | 			createReq := newdbplugin.NewUserRequest{ | ||||||
|  | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
|  | 					DisplayName: "test", | ||||||
|  | 					RoleName:    "test", | ||||||
|  | 				}, | ||||||
|  | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: []string{createAdminUser}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   password, | ||||||
|  | 				Expiration: initialExpiration, | ||||||
|  | 			} | ||||||
|  | 			createResp := dbtesting.AssertNewUser(t, db, createReq) | ||||||
|  |  | ||||||
| 			connectionDetails := map[string]interface{}{ | 			assertCredsExist(t, db.ConnectionURL, createResp.Username, password) | ||||||
| 				"connection_url":       connURL, |  | ||||||
| 				"max_open_connections": 5, | 			actualExpiration := getExpiration(t, db, createResp.Username) | ||||||
| 				"username":             "postgres", | 			if actualExpiration.IsZero() { | ||||||
| 				"password":             "secret", | 				t.Fatalf("Initial expiration is zero but should be set") | ||||||
|  | 			} | ||||||
|  | 			if !actualExpiration.Equal(initialExpiration) { | ||||||
|  | 				t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, initialExpiration) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			db := new() | 			newExpiration := test.newExpiration.Truncate(time.Second) | ||||||
| 			connProducer := db.SQLConnectionProducer | 			updateReq := newdbplugin.UpdateUserRequest{ | ||||||
|  | 				Username: createResp.Username, | ||||||
|  | 				Expiration: &newdbplugin.ChangeExpiration{ | ||||||
|  | 					NewExpiration: newExpiration, | ||||||
|  | 					Statements: newdbplugin.Statements{ | ||||||
|  | 						Commands: test.statements, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			// Give a timeout just in case the test decides to be problematic | 			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |  | ||||||
| 			defer cancel() | 			defer cancel() | ||||||
|  | 			_, err := db.UpdateUser(ctx, updateReq) | ||||||
| 			_, err := db.Init(ctx, connectionDetails, true) | 			if test.expectErr && err == nil { | ||||||
| 			if err != nil { | 				t.Fatalf("err expected, got nil") | ||||||
| 				t.Fatalf("err: %s", err) | 			} | ||||||
|  | 			if !test.expectErr && err != nil { | ||||||
|  | 				t.Fatalf("no error expected, got: %s", err) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if !connProducer.Initialized { | 			expectedExpiration := test.expectedExpiration.Truncate(time.Second) | ||||||
| 				t.Fatal("Database should be initialized") | 			actualExpiration = getExpiration(t, db, createResp.Username) | ||||||
| 			} | 			if !actualExpiration.Equal(expectedExpiration) { | ||||||
|  | 				t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, expectedExpiration) | ||||||
| 			newConf, err := db.RotateRootCredentials(ctx, test.statements) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatalf("err: %v", err) |  | ||||||
| 			} |  | ||||||
| 			if newConf["password"] == "secret" { |  | ||||||
| 				t.Fatal("password was not updated") |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = db.Close() |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatalf("failed to close: %s", err) |  | ||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestPostgreSQL_RevokeUser(t *testing.T) { | func getExpiration(t testing.TB, db *PostgreSQL, username string) time.Time { | ||||||
|  | 	t.Helper() | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	query := fmt.Sprintf("select valuntil from pg_catalog.pg_user where usename = '%s'", username) | ||||||
|  | 	conn, err := db.getConnection(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to get connection to database: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	stmt, err := conn.PrepareContext(ctx, query) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to prepare statement: %s", err) | ||||||
|  | 	} | ||||||
|  | 	defer stmt.Close() | ||||||
|  |  | ||||||
|  | 	rows, err := stmt.QueryContext(ctx) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to execute query to get expiration: %s", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !rows.Next() { | ||||||
|  | 		return time.Time{} // No expiration | ||||||
|  | 	} | ||||||
|  | 	rawExp := "" | ||||||
|  | 	err = rows.Scan(&rawExp) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Unable to get raw expiration: %s", err) | ||||||
|  | 	} | ||||||
|  | 	if rawExp == "" { | ||||||
|  | 		return time.Time{} // No expiration | ||||||
|  | 	} | ||||||
|  | 	exp, err := time.Parse(time.RFC3339, rawExp) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Failed to parse expiration %q: %s", rawExp, err) | ||||||
|  | 	} | ||||||
|  | 	return exp | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDeleteUser(t *testing.T) { | ||||||
| 	type testCase struct { | 	type testCase struct { | ||||||
| 		revokeStmts []string | 		revokeStmts    []string | ||||||
|  | 		expectErr      bool | ||||||
|  | 		credsAssertion credsAssertion | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	tests := map[string]testCase{ | 	tests := map[string]testCase{ | ||||||
| 		"empty statements": { | 		"no statements": { | ||||||
| 			revokeStmts: nil, | 			revokeStmts: nil, | ||||||
|  | 			expectErr:   false, | ||||||
|  | 			// Wait for a short time before failing because postgres takes a moment to finish deleting the user | ||||||
|  | 			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), | ||||||
| 		}, | 		}, | ||||||
| 		"explicit default name": { | 		"statements with name": { | ||||||
| 			revokeStmts: []string{defaultPostgresRevocationSQL}, | 			revokeStmts: []string{` | ||||||
|  | 				REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; | ||||||
|  | 				REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; | ||||||
|  | 				REVOKE USAGE ON SCHEMA public FROM "{{name}}"; | ||||||
|  | 		 | ||||||
|  | 				DROP ROLE IF EXISTS "{{name}}";`}, | ||||||
|  | 			expectErr: false, | ||||||
|  | 			// Wait for a short time before failing because postgres takes a moment to finish deleting the user | ||||||
|  | 			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), | ||||||
| 		}, | 		}, | ||||||
| 		"explicit default username": { | 		"statements with username": { | ||||||
| 			revokeStmts: []string{` | 			revokeStmts: []string{` | ||||||
| 				REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}"; | 				REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}"; | ||||||
| 				REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}"; | 				REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}"; | ||||||
| 				REVOKE USAGE ON SCHEMA public FROM "{{username}}"; | 				REVOKE USAGE ON SCHEMA public FROM "{{username}}"; | ||||||
| 				 | 		 | ||||||
| 				DROP ROLE IF EXISTS "{{username}}";`, | 				DROP ROLE IF EXISTS "{{username}}";`}, | ||||||
| 			}, | 			expectErr: false, | ||||||
|  | 			// Wait for a short time before failing because postgres takes a moment to finish deleting the user | ||||||
|  | 			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), | ||||||
|  | 		}, | ||||||
|  | 		"bad statements": { | ||||||
|  | 			revokeStmts: []string{`8a9yhfoiasjff`}, | ||||||
|  | 			expectErr:   true, | ||||||
|  | 			// Wait for a short time before checking because postgres takes a moment to finish deleting the user | ||||||
|  | 			credsAssertion: assertCredsExistAfter(100 * time.Millisecond), | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -344,155 +537,91 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | |||||||
|  |  | ||||||
| 	for name, test := range tests { | 	for name, test := range tests { | ||||||
| 		t.Run(name, func(t *testing.T) { | 		t.Run(name, func(t *testing.T) { | ||||||
| 			statements := dbplugin.Statements{ | 			password := "myreallysecurepassword" | ||||||
| 				Creation:   []string{createAdminUser}, | 			createReq := newdbplugin.NewUserRequest{ | ||||||
| 				Revocation: test.revokeStmts, | 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||||
|  | 					DisplayName: "test", | ||||||
|  | 					RoleName:    "test", | ||||||
|  | 				}, | ||||||
|  | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: []string{createAdminUser}, | ||||||
|  | 				}, | ||||||
|  | 				Password:   password, | ||||||
|  | 				Expiration: time.Now().Add(2 * time.Second), | ||||||
|  | 			} | ||||||
|  | 			createResp := dbtesting.AssertNewUser(t, db, createReq) | ||||||
|  |  | ||||||
|  | 			assertCredsExist(t, db.ConnectionURL, createResp.Username, password) | ||||||
|  |  | ||||||
|  | 			deleteReq := newdbplugin.DeleteUserRequest{ | ||||||
|  | 				Username: createResp.Username, | ||||||
|  | 				Statements: newdbplugin.Statements{ | ||||||
|  | 					Commands: test.revokeStmts, | ||||||
|  | 				}, | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			usernameConfig := dbplugin.UsernameConfig{ | 			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||||
| 				DisplayName: "test", | 			defer cancel() | ||||||
| 				RoleName:    "test", |  | ||||||
|  | 			_, err := db.DeleteUser(ctx, deleteReq) | ||||||
|  | 			if test.expectErr && err == nil { | ||||||
|  | 				t.Fatalf("err expected, got nil") | ||||||
|  | 			} | ||||||
|  | 			if !test.expectErr && err != nil { | ||||||
|  | 				t.Fatalf("no error expected, got: %s", err) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | 			test.credsAssertion(t, db.ConnectionURL, createResp.Username, password) | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatalf("err: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { |  | ||||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Test default revoke statements |  | ||||||
| 			err = db.RevokeUser(context.Background(), statements, username) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatalf("err: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if err := testCredsExist(t, db.ConnectionURL, username, password); err == nil { |  | ||||||
| 				t.Fatal("Credentials were not revoked") |  | ||||||
| 			} |  | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestPostgreSQL_SetCredentials_missingArgs(t *testing.T) { | type credsAssertion func(t testing.TB, connURL, username, password string) | ||||||
| 	type testCase struct { |  | ||||||
| 		statements dbplugin.Statements |  | ||||||
| 		userConfig dbplugin.StaticUserConfig |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	tests := map[string]testCase{ | func assertCredsExist(t testing.TB, connURL, username, password string) { | ||||||
| 		"empty rotation statements": { | 	t.Helper() | ||||||
| 			statements: dbplugin.Statements{ | 	err := testCredsExist(t, connURL, username, password) | ||||||
| 				Rotation: nil, | 	if err != nil { | ||||||
| 			}, | 		t.Fatalf("user does not exist: %s", err) | ||||||
| 			userConfig: dbplugin.StaticUserConfig{ |  | ||||||
| 				Username: "testuser", |  | ||||||
| 				Password: "password", |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"empty username": { |  | ||||||
| 			statements: dbplugin.Statements{ |  | ||||||
| 				Rotation: []string{` |  | ||||||
| 					ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, |  | ||||||
| 				}, |  | ||||||
| 			}, |  | ||||||
| 			userConfig: dbplugin.StaticUserConfig{ |  | ||||||
| 				Username: "", |  | ||||||
| 				Password: "password", |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 		"empty password": { |  | ||||||
| 			statements: dbplugin.Statements{ |  | ||||||
| 				Rotation: []string{` |  | ||||||
| 					ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, |  | ||||||
| 				}, |  | ||||||
| 			}, |  | ||||||
| 			userConfig: dbplugin.StaticUserConfig{ |  | ||||||
| 				Username: "testuser", |  | ||||||
| 				Password: "", |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for name, test := range tests { |  | ||||||
| 		t.Run(name, func(t *testing.T) { |  | ||||||
| 			db := new() |  | ||||||
|  |  | ||||||
| 			username, password, err := db.SetCredentials(context.Background(), test.statements, test.userConfig) |  | ||||||
| 			if err == nil { |  | ||||||
| 				t.Fatalf("expected err, got nil") |  | ||||||
| 			} |  | ||||||
| 			if username != "" { |  | ||||||
| 				t.Fatalf("expected empty username, got [%s]", username) |  | ||||||
| 			} |  | ||||||
| 			if password != "" { |  | ||||||
| 				t.Fatalf("expected empty password, got [%s]", password) |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestPostgresSQL_SetCredentials(t *testing.T) { | func assertCredsDoNotExist(t testing.TB, connURL, username, password string) { | ||||||
| 	type testCase struct { | 	t.Helper() | ||||||
| 		rotationStmts []string | 	err := testCredsExist(t, connURL, username, password) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatalf("user should not exist but does") | ||||||
| 	} | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| 	tests := map[string]testCase{ | func waitUntilCredsDoNotExist(timeout time.Duration) credsAssertion { | ||||||
| 		"name rotation": { | 	return func(t testing.TB, connURL, username, password string) { | ||||||
| 			rotationStmts: []string{` | 		t.Helper() | ||||||
| 				ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, | 		ctx, cancel := context.WithTimeout(context.Background(), timeout) | ||||||
| 			}, | 		defer cancel() | ||||||
| 		}, |  | ||||||
| 		"username rotation": { | 		ticker := time.NewTicker(10 * time.Millisecond) | ||||||
| 			rotationStmts: []string{` | 		defer ticker.Stop() | ||||||
| 				ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';`, | 		for { | ||||||
| 			}, | 			select { | ||||||
| 		}, | 			case <-ctx.Done(): | ||||||
|  | 				t.Fatalf("Timed out waiting for user %s to be deleted", username) | ||||||
|  | 			case <-ticker.C: | ||||||
|  | 				err := testCredsExist(t, connURL, username, password) | ||||||
|  | 				if err != nil { | ||||||
|  | 					// Happy path | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| 	for name, test := range tests { | func assertCredsExistAfter(timeout time.Duration) credsAssertion { | ||||||
| 		t.Run(name, func(t *testing.T) { | 	return func(t testing.TB, connURL, username, password string) { | ||||||
| 			db, cleanup := getPostgreSQL(t, nil) | 		t.Helper() | ||||||
| 			defer cleanup() | 		time.Sleep(timeout) | ||||||
|  | 		assertCredsExist(t, connURL, username, password) | ||||||
| 			// create the database user |  | ||||||
| 			dbUser := "vaultstatictest" |  | ||||||
| 			initPassword := "password" |  | ||||||
| 			createTestPGUser(t, db.ConnectionURL, dbUser, initPassword, testRoleStaticCreate) |  | ||||||
|  |  | ||||||
| 			statements := dbplugin.Statements{ |  | ||||||
| 				Rotation: test.rotationStmts, |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			password, err := db.GenerateCredentials(context.Background()) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatal(err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			usernameConfig := dbplugin.StaticUserConfig{ |  | ||||||
| 				Username: dbUser, |  | ||||||
| 				Password: password, |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if err := testCredsExist(t, db.ConnectionURL, dbUser, initPassword); err != nil { |  | ||||||
| 				t.Fatalf("Could not connect with initial credentials: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			username, password, err := db.SetCredentials(context.Background(), statements, usernameConfig) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatalf("err: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if err := testCredsExist(t, db.ConnectionURL, username, password); err != nil { |  | ||||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if err := testCredsExist(t, db.ConnectionURL, username, initPassword); err == nil { |  | ||||||
| 				t.Fatalf("Should not be able to connect with initial credentials") |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -516,7 +645,7 @@ CREATE ROLE "{{name}}" WITH | |||||||
| GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; | GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; | ||||||
| ` | ` | ||||||
|  |  | ||||||
| var testPostgresBlockStatementRoleSlice = []string{ | var newUserLargeBlockStatements = []string{ | ||||||
| 	` | 	` | ||||||
| DO $$ | DO $$ | ||||||
| BEGIN | BEGIN | ||||||
| @@ -539,59 +668,6 @@ $$ | |||||||
| 	`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, | 	`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, | ||||||
| } | } | ||||||
|  |  | ||||||
| const defaultPostgresRevocationSQL = ` |  | ||||||
| REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; |  | ||||||
| REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; |  | ||||||
| REVOKE USAGE ON SCHEMA public FROM "{{name}}"; |  | ||||||
|  |  | ||||||
| DROP ROLE IF EXISTS "{{name}}"; |  | ||||||
| ` |  | ||||||
|  |  | ||||||
| const testRoleStaticCreate = ` |  | ||||||
| CREATE ROLE "{{name}}" WITH |  | ||||||
|   LOGIN |  | ||||||
|   PASSWORD '{{password}}'; |  | ||||||
| ` |  | ||||||
|  |  | ||||||
| // This is a copy of a test helper method also found in |  | ||||||
| // builtin/logical/database/rotation_test.go , and should be moved into a shared |  | ||||||
| // helper file in the future. |  | ||||||
| func createTestPGUser(t *testing.T, connURL string, username, password, query string) { |  | ||||||
| 	t.Helper() |  | ||||||
| 	conn, err := pq.ParseURL(connURL) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	db, err := sql.Open("postgres", conn) |  | ||||||
| 	defer db.Close() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Start a transaction |  | ||||||
| 	ctx := context.Background() |  | ||||||
| 	tx, err := db.BeginTx(ctx, nil) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer func() { |  | ||||||
| 		_ = tx.Rollback() |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	m := map[string]string{ |  | ||||||
| 		"name":     username, |  | ||||||
| 		"password": password, |  | ||||||
| 	} |  | ||||||
| 	if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	// Commit the transaction |  | ||||||
| 	if err := tx.Commit(); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestContainsMultilineStatement(t *testing.T) { | func TestContainsMultilineStatement(t *testing.T) { | ||||||
| 	type testCase struct { | 	type testCase struct { | ||||||
| 		Input    string | 		Input    string | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Michael Golowka
					Michael Golowka