mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	DBPW - Update PostgreSQL to adhere to v5 Database interface (#10061)
This commit is contained in:
		| @@ -3,35 +3,36 @@ package postgresql | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/errwrap" | ||||
| 	"github.com/hashicorp/go-multierror" | ||||
| 	"github.com/hashicorp/vault/api" | ||||
| 	"github.com/hashicorp/vault/sdk/database/dbplugin" | ||||
| 	"github.com/hashicorp/vault/sdk/database/helper/connutil" | ||||
| 	"github.com/hashicorp/vault/sdk/database/helper/credsutil" | ||||
| 	"github.com/hashicorp/vault/sdk/database/helper/dbutil" | ||||
| 	"github.com/hashicorp/vault/sdk/database/newdbplugin" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/dbtxn" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/strutil" | ||||
| 	"github.com/lib/pq" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	postgreSQLTypeName      = "postgres" | ||||
| 	defaultPostgresRenewSQL = ` | ||||
| 	postgreSQLTypeName         = "postgres" | ||||
| 	defaultExpirationStatement = ` | ||||
| ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; | ||||
| ` | ||||
| 	defaultPostgresRotateRootCredentialsSQL = ` | ||||
| 	defaultChangePasswordStatement = ` | ||||
| ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; | ||||
| ` | ||||
|  | ||||
| 	expirationFormat = "2006-01-02 15:04:05-0700" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	_ dbplugin.Database = &PostgreSQL{} | ||||
| 	_ newdbplugin.Database = &PostgreSQL{} | ||||
|  | ||||
| 	// postgresEndStatement is basically the word "END" but | ||||
| 	// surrounded by a word boundary to differentiate it from | ||||
| @@ -51,7 +52,7 @@ var ( | ||||
| func New() (interface{}, error) { | ||||
| 	db := new() | ||||
| 	// Wrap the plugin with middleware to sanitize errors | ||||
| 	dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) | ||||
| 	dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) | ||||
| 	return dbType, nil | ||||
| } | ||||
|  | ||||
| @@ -59,16 +60,8 @@ func new() *PostgreSQL { | ||||
| 	connProducer := &connutil.SQLConnectionProducer{} | ||||
| 	connProducer.Type = postgreSQLTypeName | ||||
|  | ||||
| 	credsProducer := &credsutil.SQLCredentialsProducer{ | ||||
| 		DisplayNameLen: 8, | ||||
| 		RoleNameLen:    8, | ||||
| 		UsernameLen:    63, | ||||
| 		Separator:      "-", | ||||
| 	} | ||||
|  | ||||
| 	db := &PostgreSQL{ | ||||
| 		SQLConnectionProducer: connProducer, | ||||
| 		CredentialsProducer:   credsProducer, | ||||
| 	} | ||||
|  | ||||
| 	return db | ||||
| @@ -81,14 +74,24 @@ func Run(apiTLSConfig *api.TLSConfig) error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) | ||||
| 	newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type PostgreSQL struct { | ||||
| 	*connutil.SQLConnectionProducer | ||||
| 	credsutil.CredentialsProducer | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) { | ||||
| 	newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection) | ||||
| 	if err != nil { | ||||
| 		return newdbplugin.InitializeResponse{}, err | ||||
| 	} | ||||
| 	resp := newdbplugin.InitializeResponse{ | ||||
| 		Config: newConf, | ||||
| 	} | ||||
| 	return resp, nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) Type() (string, error) { | ||||
| @@ -104,54 +107,58 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { | ||||
| 	return db.(*sql.DB), nil | ||||
| } | ||||
|  | ||||
| // SetCredentials uses provided information to set/create a user in the | ||||
| // database. Unlike CreateUser, this method requires a username be provided and | ||||
| // uses the name given, instead of generating a name. This is used for creating | ||||
| // and setting the password of static accounts, as well as rolling back | ||||
| // passwords in the database in the event an updated database fails to save in | ||||
| // Vault's storage. | ||||
| func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { | ||||
| 	if len(statements.Rotation) == 0 { | ||||
| 		statements.Rotation = []string{defaultPostgresRotateRootCredentialsSQL} | ||||
| func (p *PostgreSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) { | ||||
| 	if req.Username == "" { | ||||
| 		return newdbplugin.UpdateUserResponse{}, fmt.Errorf("missing username") | ||||
| 	} | ||||
| 	if req.Password == nil && req.Expiration == nil { | ||||
| 		return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested") | ||||
| 	} | ||||
|  | ||||
| 	username = staticUser.Username | ||||
| 	password = staticUser.Password | ||||
| 	if username == "" || password == "" { | ||||
| 		return "", "", errors.New("must provide both username and password") | ||||
| 	merr := &multierror.Error{} | ||||
| 	if req.Password != nil { | ||||
| 		err := p.changeUserPassword(ctx, req.Username, req.Password) | ||||
| 		merr = multierror.Append(merr, err) | ||||
| 	} | ||||
| 	if req.Expiration != nil { | ||||
| 		err := p.changeUserExpiration(ctx, req.Username, req.Expiration) | ||||
| 		merr = multierror.Append(merr, err) | ||||
| 	} | ||||
| 	return newdbplugin.UpdateUserResponse{}, merr.ErrorOrNil() | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error { | ||||
| 	stmts := changePass.Statements.Commands | ||||
| 	if len(stmts) == 0 { | ||||
| 		stmts = []string{defaultChangePasswordStatement} | ||||
| 	} | ||||
|  | ||||
| 	password := changePass.NewPassword | ||||
| 	if password == "" { | ||||
| 		return fmt.Errorf("missing password") | ||||
| 	} | ||||
|  | ||||
| 	// Grab the lock | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	// Get the connection | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 		return fmt.Errorf("unable to get connection: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Check if the role exists | ||||
| 	var exists bool | ||||
| 	err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) | ||||
| 	if err != nil && err != sql.ErrNoRows { | ||||
| 		return "", "", err | ||||
| 		return fmt.Errorf("user does not appear to exist: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Vault requires the database user already exist, and that the credentials | ||||
| 	// used to execute the rotation statements has sufficient privileges. | ||||
| 	stmts := statements.Rotation | ||||
|  | ||||
| 	// Start a transaction | ||||
| 	tx, err := db.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 		return fmt.Errorf("unable to start transaction: %w", err) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		_ = tx.Rollback() | ||||
| 	}() | ||||
| 	defer tx.Rollback() | ||||
|  | ||||
| 	// Execute each query | ||||
| 	for _, stmt := range stmts { | ||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||
| 			query = strings.TrimSpace(query) | ||||
| @@ -160,117 +167,30 @@ func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Sta | ||||
| 			} | ||||
|  | ||||
| 			m := map[string]string{ | ||||
| 				"name":     staticUser.Username, | ||||
| 				"username": staticUser.Username, | ||||
| 				"name":     username, | ||||
| 				"username": username, | ||||
| 				"password": password, | ||||
| 			} | ||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | ||||
| 				return "", "", err | ||||
| 				return fmt.Errorf("failed to execute query: %w", err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Commit the transaction | ||||
| 	if err := tx.Commit(); err != nil { | ||||
| 		return "", "", err | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return username, password, nil | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	statements = dbutil.StatementCompatibilityHelper(statements) | ||||
|  | ||||
| 	if len(statements.Creation) == 0 { | ||||
| 		return "", "", dbutil.ErrEmptyCreationStatement | ||||
| 	} | ||||
|  | ||||
| 	// Grab the lock | ||||
| func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, changeExp *newdbplugin.ChangeExpiration) error { | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	username, err = p.GenerateUsername(usernameConfig) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	password, err = p.GeneratePassword() | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	expirationStr, err := p.GenerateExpiration(expiration) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	// Get the connection | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	// Start a transaction | ||||
| 	tx, err := db.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
|  | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		tx.Rollback() | ||||
| 	}() | ||||
|  | ||||
| 	// Execute each query | ||||
| 	for _, stmt := range statements.Creation { | ||||
| 		if containsMultilineStatement(stmt) { | ||||
| 			// Execute it as-is. | ||||
| 			m := map[string]string{ | ||||
| 				"name":       username, | ||||
| 				"username":   username, | ||||
| 				"password":   password, | ||||
| 				"expiration": expirationStr, | ||||
| 			} | ||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil { | ||||
| 				return "", "", err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
| 		// Otherwise, it's fine to split the statements on the semicolon. | ||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||
| 			query = strings.TrimSpace(query) | ||||
| 			if len(query) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			m := map[string]string{ | ||||
| 				"name":       username, | ||||
| 				"username":   username, | ||||
| 				"password":   password, | ||||
| 				"expiration": expirationStr, | ||||
| 			} | ||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | ||||
| 				return "", "", err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Commit the transaction | ||||
| 	if err := tx.Commit(); err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	return username, password, nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	statements = dbutil.StatementCompatibilityHelper(statements) | ||||
|  | ||||
| 	renewStmts := statements.Renewal | ||||
| 	renewStmts := changeExp.Statements.Commands | ||||
| 	if len(renewStmts) == 0 { | ||||
| 		renewStmts = []string{defaultPostgresRenewSQL} | ||||
| 		renewStmts = []string{defaultExpirationStatement} | ||||
| 	} | ||||
|  | ||||
| 	db, err := p.getConnection(ctx) | ||||
| @@ -286,10 +206,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen | ||||
| 		tx.Rollback() | ||||
| 	}() | ||||
|  | ||||
| 	expirationStr, err := p.GenerateExpiration(expiration) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	expirationStr := changeExp.NewExpiration.Format(expirationFormat) | ||||
|  | ||||
| 	for _, stmt := range renewStmts { | ||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||
| @@ -312,21 +229,93 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen | ||||
| 	return tx.Commit() | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	// Grab the lock | ||||
| func (p *PostgreSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) { | ||||
| 	if len(req.Statements.Commands) == 0 { | ||||
| 		return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement | ||||
| 	} | ||||
|  | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	statements = dbutil.StatementCompatibilityHelper(statements) | ||||
|  | ||||
| 	if len(statements.Revocation) == 0 { | ||||
| 		return p.defaultRevokeUser(ctx, username) | ||||
| 	username, err := credsutil.GenerateUsername( | ||||
| 		credsutil.DisplayName(req.UsernameConfig.DisplayName, 8), | ||||
| 		credsutil.RoleName(req.UsernameConfig.RoleName, 8), | ||||
| 		credsutil.Separator("-"), | ||||
| 		credsutil.MaxLength(63), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return newdbplugin.NewUserResponse{}, err | ||||
| 	} | ||||
|  | ||||
| 	return p.customRevokeUser(ctx, username, statements.Revocation) | ||||
| 	expirationStr := req.Expiration.Format(expirationFormat) | ||||
|  | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	tx, err := db.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to start transaction: %w", err) | ||||
|  | ||||
| 	} | ||||
| 	defer tx.Rollback() | ||||
|  | ||||
| 	for _, stmt := range req.Statements.Commands { | ||||
| 		if containsMultilineStatement(stmt) { | ||||
| 			// Execute it as-is. | ||||
| 			m := map[string]string{ | ||||
| 				"name":       username, | ||||
| 				"username":   username, | ||||
| 				"password":   req.Password, | ||||
| 				"expiration": expirationStr, | ||||
| 			} | ||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil { | ||||
| 				return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err) | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
| 		// Otherwise, it's fine to split the statements on the semicolon. | ||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||
| 			query = strings.TrimSpace(query) | ||||
| 			if len(query) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			m := map[string]string{ | ||||
| 				"name":       username, | ||||
| 				"username":   username, | ||||
| 				"password":   req.Password, | ||||
| 				"expiration": expirationStr, | ||||
| 			} | ||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | ||||
| 				return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := tx.Commit(); err != nil { | ||||
| 		return newdbplugin.NewUserResponse{}, err | ||||
| 	} | ||||
|  | ||||
| 	resp := newdbplugin.NewUserResponse{ | ||||
| 		Username: username, | ||||
| 	} | ||||
| 	return resp, nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { | ||||
| func (p *PostgreSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) { | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	if len(req.Statements.Commands) == 0 { | ||||
| 		return newdbplugin.DeleteUserResponse{}, p.defaultDeleteUser(ctx, req.Username) | ||||
| 	} | ||||
|  | ||||
| 	return newdbplugin.DeleteUserResponse{}, p.customDeleteUser(ctx, req.Username, req.Statements.Commands) | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) customDeleteUser(ctx context.Context, username string, revocationStmts []string) error { | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| @@ -360,7 +349,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revo | ||||
| 	return tx.Commit() | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { | ||||
| func (p *PostgreSQL) defaultDeleteUser(ctx context.Context, username string) error { | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| @@ -471,65 +460,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	if len(p.Username) == 0 || len(p.Password) == 0 { | ||||
| 		return nil, errors.New("username and password are required to rotate") | ||||
| func (p *PostgreSQL) secretValues() map[string]string { | ||||
| 	return map[string]string{ | ||||
| 		p.Password: "[password]", | ||||
| 	} | ||||
|  | ||||
| 	rotateStatements := statements | ||||
| 	if len(rotateStatements) == 0 { | ||||
| 		rotateStatements = []string{defaultPostgresRotateRootCredentialsSQL} | ||||
| 	} | ||||
|  | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	tx, err := db.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		tx.Rollback() | ||||
| 	}() | ||||
|  | ||||
| 	password, err := p.GeneratePassword() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	for _, stmt := range rotateStatements { | ||||
| 		for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | ||||
| 			query = strings.TrimSpace(query) | ||||
| 			if len(query) == 0 { | ||||
| 				continue | ||||
| 			} | ||||
| 			m := map[string]string{ | ||||
| 				"name":     p.Username, | ||||
| 				"username": p.Username, | ||||
| 				"password": password, | ||||
| 			} | ||||
| 			if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := tx.Commit(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Close the database connection to ensure no new connections come in | ||||
| 	if err := db.Close(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	p.RawConfig["password"] = password | ||||
| 	return p.RawConfig, nil | ||||
| } | ||||
|  | ||||
| // containsMultilineStatement is a best effort to determine whether | ||||
|   | ||||
| @@ -9,9 +9,8 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/helper/testhelpers/postgresql" | ||||
| 	"github.com/hashicorp/vault/sdk/database/dbplugin" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/dbtxn" | ||||
| 	"github.com/lib/pq" | ||||
| 	"github.com/hashicorp/vault/sdk/database/newdbplugin" | ||||
| 	dbtesting "github.com/hashicorp/vault/sdk/database/newdbplugin/testing" | ||||
| ) | ||||
|  | ||||
| func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) { | ||||
| @@ -24,12 +23,14 @@ func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, f | ||||
| 		connectionDetails[k] = v | ||||
| 	} | ||||
|  | ||||
| 	db := new() | ||||
| 	_, err := db.Init(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	req := newdbplugin.InitializeRequest{ | ||||
| 		Config:           connectionDetails, | ||||
| 		VerifyConnection: true, | ||||
| 	} | ||||
|  | ||||
| 	db := new() | ||||
| 	dbtesting.AssertInitialize(t, db, req) | ||||
|  | ||||
| 	if !db.Initialized { | ||||
| 		t.Fatal("Database should be initialized") | ||||
| 	} | ||||
| @@ -58,90 +59,163 @@ func TestPostgreSQL_InitializeWithStringVals(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) { | ||||
| 	db := new() | ||||
|  | ||||
| 	usernameConfig := dbplugin.UsernameConfig{ | ||||
| 		DisplayName: "test", | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err == nil { | ||||
| 		t.Fatalf("expected err, got nil") | ||||
| 	} | ||||
| 	if username != "" { | ||||
| 		t.Fatalf("expected empty username, got [%s]", username) | ||||
| 	} | ||||
| 	if password != "" { | ||||
| 		t.Fatalf("expected empty password, got [%s]", password) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| func TestPostgreSQL_NewUser(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		createStmts          []string | ||||
| 		shouldTestCredsExist bool | ||||
| 		req            newdbplugin.NewUserRequest | ||||
| 		expectErr      bool | ||||
| 		credsAssertion credsAssertion | ||||
| 	} | ||||
|  | ||||
| 	tests := map[string]testCase{ | ||||
| 		"admin name": { | ||||
| 			createStmts: []string{` | ||||
| 				CREATE ROLE "{{name}}" WITH | ||||
| 				  LOGIN | ||||
| 				  PASSWORD '{{password}}' | ||||
| 				  VALID UNTIL '{{expiration}}'; | ||||
| 				GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, | ||||
| 		"no creation statements": { | ||||
| 			req: newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				// No statements | ||||
| 				Password:   "somesecurepassword", | ||||
| 				Expiration: time.Now().Add(1 * time.Minute), | ||||
| 			}, | ||||
| 			shouldTestCredsExist: true, | ||||
| 			expectErr:      true, | ||||
| 			credsAssertion: assertCredsDoNotExist, | ||||
| 		}, | ||||
| 		"admin name": { | ||||
| 			req: newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{` | ||||
| 						CREATE ROLE "{{name}}" WITH | ||||
| 						  LOGIN | ||||
| 						  PASSWORD '{{password}}' | ||||
| 						  VALID UNTIL '{{expiration}}'; | ||||
| 						GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, | ||||
| 					}, | ||||
| 				}, | ||||
| 				Password:   "somesecurepassword", | ||||
| 				Expiration: time.Now().Add(1 * time.Minute), | ||||
| 			}, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsExist, | ||||
| 		}, | ||||
| 		"admin username": { | ||||
| 			createStmts: []string{` | ||||
| 				CREATE ROLE "{{username}}" WITH | ||||
| 				  LOGIN | ||||
| 				  PASSWORD '{{password}}' | ||||
| 				  VALID UNTIL '{{expiration}}'; | ||||
| 				GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, | ||||
| 			req: newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{` | ||||
| 						CREATE ROLE "{{username}}" WITH | ||||
| 						  LOGIN | ||||
| 						  PASSWORD '{{password}}' | ||||
| 						  VALID UNTIL '{{expiration}}'; | ||||
| 						GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, | ||||
| 					}, | ||||
| 				}, | ||||
| 				Password:   "somesecurepassword", | ||||
| 				Expiration: time.Now().Add(1 * time.Minute), | ||||
| 			}, | ||||
| 			shouldTestCredsExist: true, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsExist, | ||||
| 		}, | ||||
| 		"read only name": { | ||||
| 			createStmts: []string{` | ||||
| 				CREATE ROLE "{{name}}" WITH | ||||
| 				  LOGIN | ||||
| 				  PASSWORD '{{password}}' | ||||
| 				  VALID UNTIL '{{expiration}}'; | ||||
| 				GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; | ||||
| 				GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, | ||||
| 			req: newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{` | ||||
| 						CREATE ROLE "{{name}}" WITH | ||||
| 						  LOGIN | ||||
| 						  PASSWORD '{{password}}' | ||||
| 						  VALID UNTIL '{{expiration}}'; | ||||
| 						GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; | ||||
| 						GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, | ||||
| 					}, | ||||
| 				}, | ||||
| 				Password:   "somesecurepassword", | ||||
| 				Expiration: time.Now().Add(1 * time.Minute), | ||||
| 			}, | ||||
| 			shouldTestCredsExist: true, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsExist, | ||||
| 		}, | ||||
| 		"read only username": { | ||||
| 			createStmts: []string{` | ||||
| 				CREATE ROLE "{{username}}" WITH | ||||
| 				  LOGIN | ||||
| 				  PASSWORD '{{password}}' | ||||
| 				  VALID UNTIL '{{expiration}}'; | ||||
| 				GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; | ||||
| 				GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, | ||||
| 			req: newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{` | ||||
| 						CREATE ROLE "{{username}}" WITH | ||||
| 						  LOGIN | ||||
| 						  PASSWORD '{{password}}' | ||||
| 						  VALID UNTIL '{{expiration}}'; | ||||
| 						GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; | ||||
| 						GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, | ||||
| 					}, | ||||
| 				}, | ||||
| 				Password:   "somesecurepassword", | ||||
| 				Expiration: time.Now().Add(1 * time.Minute), | ||||
| 			}, | ||||
| 			shouldTestCredsExist: true, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsExist, | ||||
| 		}, | ||||
| 		// https://github.com/hashicorp/vault/issues/6098 | ||||
| 		"reproduce GH-6098": { | ||||
| 			createStmts: []string{ | ||||
| 				// NOTE: "rolname" in the following line is not a typo. | ||||
| 				"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$", | ||||
| 			req: newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{ | ||||
| 						// NOTE: "rolname" in the following line is not a typo. | ||||
| 						"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$", | ||||
| 					}, | ||||
| 				}, | ||||
| 				Password:   "somesecurepassword", | ||||
| 				Expiration: time.Now().Add(1 * time.Minute), | ||||
| 			}, | ||||
| 			// This test statement doesn't generate creds. | ||||
| 			shouldTestCredsExist: false, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsDoNotExist, | ||||
| 		}, | ||||
| 		"reproduce issue with template": { | ||||
| 			createStmts: []string{ | ||||
| 				`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`, | ||||
| 			req: newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{ | ||||
| 						`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`, | ||||
| 					}, | ||||
| 				}, | ||||
| 				Password:   "somesecurepassword", | ||||
| 				Expiration: time.Now().Add(1 * time.Minute), | ||||
| 			}, | ||||
| 			// This test statement doesn't generate creds. | ||||
| 			shouldTestCredsExist: false, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsDoNotExist, | ||||
| 		}, | ||||
| 		"large block statements": { | ||||
| 			req: newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: newUserLargeBlockStatements, | ||||
| 				}, | ||||
| 				Password:   "somesecurepassword", | ||||
| 				Expiration: time.Now().Add(1 * time.Minute), | ||||
| 			}, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsExist, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -151,59 +225,55 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
|  | ||||
| 	for name, test := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			usernameConfig := dbplugin.UsernameConfig{ | ||||
| 				DisplayName: "test", | ||||
| 				RoleName:    "test", | ||||
| 			} | ||||
|  | ||||
| 			statements := dbplugin.Statements{ | ||||
| 				Creation: test.createStmts, | ||||
| 			} | ||||
|  | ||||
| 			// Give a timeout just in case the test decides to be problematic | ||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) | ||||
| 			defer cancel() | ||||
|  | ||||
| 			username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			resp, err := db.NewUser(ctx, test.req) | ||||
| 			if test.expectErr && err == nil { | ||||
| 				t.Fatalf("err expected, got nil") | ||||
| 			} | ||||
| 			if !test.expectErr && err != nil { | ||||
| 				t.Fatalf("no error expected, got: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			if !test.shouldTestCredsExist { | ||||
| 				// We're done here. | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { | ||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 			} | ||||
| 			test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password) | ||||
|  | ||||
| 			// Ensure that the role doesn't expire immediately | ||||
| 			time.Sleep(2 * time.Second) | ||||
|  | ||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { | ||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 			} | ||||
| 			test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| func TestUpdateUser_Password(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		renewalStmts []string | ||||
| 		statements     []string | ||||
| 		expectErr      bool | ||||
| 		credsAssertion credsAssertion | ||||
| 	} | ||||
|  | ||||
| 	tests := map[string]testCase{ | ||||
| 		"empty renewal statements": { | ||||
| 			renewalStmts: nil, | ||||
| 		"default statements": { | ||||
| 			statements:     nil, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsExist, | ||||
| 		}, | ||||
| 		"default renewal name": { | ||||
| 			renewalStmts: []string{defaultPostgresRenewSQL}, | ||||
| 		"explicit default statements": { | ||||
| 			statements:     []string{defaultChangePasswordStatement}, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsExist, | ||||
| 		}, | ||||
| 		"default renewal username": { | ||||
| 			renewalStmts: []string{` | ||||
| 				ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`, | ||||
| 			}, | ||||
| 		"name instead of username": { | ||||
| 			statements:     []string{`ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`}, | ||||
| 			expectErr:      false, | ||||
| 			credsAssertion: assertCredsExist, | ||||
| 		}, | ||||
| 		"bad statements": { | ||||
| 			statements:     []string{`asdofyas8uf77asoiajv`}, | ||||
| 			expectErr:      true, | ||||
| 			credsAssertion: assertCredsDoNotExist, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -213,128 +283,251 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
|  | ||||
| 	for name, test := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			statements := dbplugin.Statements{ | ||||
| 				Creation: []string{createAdminUser}, | ||||
| 				Renewal:  test.renewalStmts, | ||||
| 			initialPass := "myreallysecurepassword" | ||||
| 			createReq := newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{createAdminUser}, | ||||
| 				}, | ||||
| 				Password:   initialPass, | ||||
| 				Expiration: time.Now().Add(2 * time.Second), | ||||
| 			} | ||||
| 			createResp := dbtesting.AssertNewUser(t, db, createReq) | ||||
|  | ||||
| 			assertCredsExist(t, db.ConnectionURL, createResp.Username, initialPass) | ||||
|  | ||||
| 			newPass := "somenewpassword" | ||||
| 			updateReq := newdbplugin.UpdateUserRequest{ | ||||
| 				Username: createResp.Username, | ||||
| 				Password: &newdbplugin.ChangePassword{ | ||||
| 					NewPassword: newPass, | ||||
| 					Statements: newdbplugin.Statements{ | ||||
| 						Commands: test.statements, | ||||
| 					}, | ||||
| 				}, | ||||
| 			} | ||||
|  | ||||
| 			usernameConfig := dbplugin.UsernameConfig{ | ||||
| 				DisplayName: "test", | ||||
| 				RoleName:    "test", | ||||
| 			} | ||||
|  | ||||
| 			// Give a timeout just in case the test decides to be problematic | ||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) | ||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 			defer cancel() | ||||
|  | ||||
| 			username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			_, err := db.UpdateUser(ctx, updateReq) | ||||
| 			if test.expectErr && err == nil { | ||||
| 				t.Fatalf("err expected, got nil") | ||||
| 			} | ||||
| 			if !test.expectErr && err != nil { | ||||
| 				t.Fatalf("no error expected, got: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { | ||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			err = db.RenewUser(ctx, statements, username, time.Now().Add(time.Minute)) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			// Sleep longer than the initial expiration time | ||||
| 			time.Sleep(2 * time.Second) | ||||
|  | ||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { | ||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 			} | ||||
| 			test.credsAssertion(t, db.ConnectionURL, createResp.Username, newPass) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	t.Run("user does not exist", func(t *testing.T) { | ||||
| 		newPass := "somenewpassword" | ||||
| 		updateReq := newdbplugin.UpdateUserRequest{ | ||||
| 			Username: "missing-user", | ||||
| 			Password: &newdbplugin.ChangePassword{ | ||||
| 				NewPassword: newPass, | ||||
| 				Statements:  newdbplugin.Statements{}, | ||||
| 			}, | ||||
| 		} | ||||
|  | ||||
| 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 		defer cancel() | ||||
| 		_, err := db.UpdateUser(ctx, updateReq) | ||||
| 		if err == nil { | ||||
| 			t.Fatalf("err expected, got nil") | ||||
| 		} | ||||
|  | ||||
| 		assertCredsDoNotExist(t, db.ConnectionURL, updateReq.Username, newPass) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestPostgreSQL_RotateRootCredentials(t *testing.T) { | ||||
| func TestUpdateUser_Expiration(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		statements []string | ||||
| 		initialExpiration  time.Time | ||||
| 		newExpiration      time.Time | ||||
| 		expectedExpiration time.Time | ||||
| 		statements         []string | ||||
| 		expectErr          bool | ||||
| 	} | ||||
|  | ||||
| 	now := time.Now() | ||||
| 	tests := map[string]testCase{ | ||||
| 		"empty statements": { | ||||
| 			statements: nil, | ||||
| 		"no statements": { | ||||
| 			initialExpiration:  now.Add(1 * time.Minute), | ||||
| 			newExpiration:      now.Add(5 * time.Minute), | ||||
| 			expectedExpiration: now.Add(5 * time.Minute), | ||||
| 			statements:         nil, | ||||
| 			expectErr:          false, | ||||
| 		}, | ||||
| 		"default name": { | ||||
| 			statements: []string{` | ||||
| 				ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, | ||||
| 			}, | ||||
| 		"default statements with name": { | ||||
| 			initialExpiration:  now.Add(1 * time.Minute), | ||||
| 			newExpiration:      now.Add(5 * time.Minute), | ||||
| 			expectedExpiration: now.Add(5 * time.Minute), | ||||
| 			statements:         []string{defaultExpirationStatement}, | ||||
| 			expectErr:          false, | ||||
| 		}, | ||||
| 		"default username": { | ||||
| 			statements: []string{defaultPostgresRotateRootCredentialsSQL}, | ||||
| 		"default statements with username": { | ||||
| 			initialExpiration:  now.Add(1 * time.Minute), | ||||
| 			newExpiration:      now.Add(5 * time.Minute), | ||||
| 			expectedExpiration: now.Add(5 * time.Minute), | ||||
| 			statements:         []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`}, | ||||
| 			expectErr:          false, | ||||
| 		}, | ||||
| 		"bad statements": { | ||||
| 			initialExpiration:  now.Add(1 * time.Minute), | ||||
| 			newExpiration:      now.Add(5 * time.Minute), | ||||
| 			expectedExpiration: now.Add(1 * time.Minute), | ||||
| 			statements:         []string{"ladshfouay09sgj"}, | ||||
| 			expectErr:          true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// Shared test container for speed - there should not be any overlap between the tests | ||||
| 	db, cleanup := getPostgreSQL(t, nil) | ||||
| 	defer cleanup() | ||||
|  | ||||
| 	for name, test := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			cleanup, connURL := postgresql.PrepareTestContainer(t, "latest") | ||||
| 			defer cleanup() | ||||
| 			connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1) | ||||
| 			password := "myreallysecurepassword" | ||||
| 			initialExpiration := test.initialExpiration.Truncate(time.Second) | ||||
| 			createReq := newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{createAdminUser}, | ||||
| 				}, | ||||
| 				Password:   password, | ||||
| 				Expiration: initialExpiration, | ||||
| 			} | ||||
| 			createResp := dbtesting.AssertNewUser(t, db, createReq) | ||||
|  | ||||
| 			connectionDetails := map[string]interface{}{ | ||||
| 				"connection_url":       connURL, | ||||
| 				"max_open_connections": 5, | ||||
| 				"username":             "postgres", | ||||
| 				"password":             "secret", | ||||
| 			assertCredsExist(t, db.ConnectionURL, createResp.Username, password) | ||||
|  | ||||
| 			actualExpiration := getExpiration(t, db, createResp.Username) | ||||
| 			if actualExpiration.IsZero() { | ||||
| 				t.Fatalf("Initial expiration is zero but should be set") | ||||
| 			} | ||||
| 			if !actualExpiration.Equal(initialExpiration) { | ||||
| 				t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, initialExpiration) | ||||
| 			} | ||||
|  | ||||
| 			db := new() | ||||
| 			connProducer := db.SQLConnectionProducer | ||||
| 			newExpiration := test.newExpiration.Truncate(time.Second) | ||||
| 			updateReq := newdbplugin.UpdateUserRequest{ | ||||
| 				Username: createResp.Username, | ||||
| 				Expiration: &newdbplugin.ChangeExpiration{ | ||||
| 					NewExpiration: newExpiration, | ||||
| 					Statements: newdbplugin.Statements{ | ||||
| 						Commands: test.statements, | ||||
| 					}, | ||||
| 				}, | ||||
| 			} | ||||
|  | ||||
| 			// Give a timeout just in case the test decides to be problematic | ||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) | ||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 			defer cancel() | ||||
|  | ||||
| 			_, err := db.Init(ctx, connectionDetails, true) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			_, err := db.UpdateUser(ctx, updateReq) | ||||
| 			if test.expectErr && err == nil { | ||||
| 				t.Fatalf("err expected, got nil") | ||||
| 			} | ||||
| 			if !test.expectErr && err != nil { | ||||
| 				t.Fatalf("no error expected, got: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			if !connProducer.Initialized { | ||||
| 				t.Fatal("Database should be initialized") | ||||
| 			} | ||||
|  | ||||
| 			newConf, err := db.RotateRootCredentials(ctx, test.statements) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("err: %v", err) | ||||
| 			} | ||||
| 			if newConf["password"] == "secret" { | ||||
| 				t.Fatal("password was not updated") | ||||
| 			} | ||||
|  | ||||
| 			err = db.Close() | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("failed to close: %s", err) | ||||
| 			expectedExpiration := test.expectedExpiration.Truncate(time.Second) | ||||
| 			actualExpiration = getExpiration(t, db, createResp.Username) | ||||
| 			if !actualExpiration.Equal(expectedExpiration) { | ||||
| 				t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, expectedExpiration) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
| func getExpiration(t testing.TB, db *PostgreSQL, username string) time.Time { | ||||
| 	t.Helper() | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	query := fmt.Sprintf("select valuntil from pg_catalog.pg_user where usename = '%s'", username) | ||||
| 	conn, err := db.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to get connection to database: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	stmt, err := conn.PrepareContext(ctx, query) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to prepare statement: %s", err) | ||||
| 	} | ||||
| 	defer stmt.Close() | ||||
|  | ||||
| 	rows, err := stmt.QueryContext(ctx) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to execute query to get expiration: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if !rows.Next() { | ||||
| 		return time.Time{} // No expiration | ||||
| 	} | ||||
| 	rawExp := "" | ||||
| 	err = rows.Scan(&rawExp) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Unable to get raw expiration: %s", err) | ||||
| 	} | ||||
| 	if rawExp == "" { | ||||
| 		return time.Time{} // No expiration | ||||
| 	} | ||||
| 	exp, err := time.Parse(time.RFC3339, rawExp) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to parse expiration %q: %s", rawExp, err) | ||||
| 	} | ||||
| 	return exp | ||||
| } | ||||
|  | ||||
| func TestDeleteUser(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		revokeStmts []string | ||||
| 		revokeStmts    []string | ||||
| 		expectErr      bool | ||||
| 		credsAssertion credsAssertion | ||||
| 	} | ||||
|  | ||||
| 	tests := map[string]testCase{ | ||||
| 		"empty statements": { | ||||
| 		"no statements": { | ||||
| 			revokeStmts: nil, | ||||
| 			expectErr:   false, | ||||
| 			// Wait for a short time before failing because postgres takes a moment to finish deleting the user | ||||
| 			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), | ||||
| 		}, | ||||
| 		"explicit default name": { | ||||
| 			revokeStmts: []string{defaultPostgresRevocationSQL}, | ||||
| 		"statements with name": { | ||||
| 			revokeStmts: []string{` | ||||
| 				REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; | ||||
| 				REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; | ||||
| 				REVOKE USAGE ON SCHEMA public FROM "{{name}}"; | ||||
| 		 | ||||
| 				DROP ROLE IF EXISTS "{{name}}";`}, | ||||
| 			expectErr: false, | ||||
| 			// Wait for a short time before failing because postgres takes a moment to finish deleting the user | ||||
| 			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), | ||||
| 		}, | ||||
| 		"explicit default username": { | ||||
| 		"statements with username": { | ||||
| 			revokeStmts: []string{` | ||||
| 				REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}"; | ||||
| 				REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}"; | ||||
| 				REVOKE USAGE ON SCHEMA public FROM "{{username}}"; | ||||
| 				 | ||||
| 				DROP ROLE IF EXISTS "{{username}}";`, | ||||
| 			}, | ||||
| 		 | ||||
| 				DROP ROLE IF EXISTS "{{username}}";`}, | ||||
| 			expectErr: false, | ||||
| 			// Wait for a short time before failing because postgres takes a moment to finish deleting the user | ||||
| 			credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), | ||||
| 		}, | ||||
| 		"bad statements": { | ||||
| 			revokeStmts: []string{`8a9yhfoiasjff`}, | ||||
| 			expectErr:   true, | ||||
| 			// Wait for a short time before checking because postgres takes a moment to finish deleting the user | ||||
| 			credsAssertion: assertCredsExistAfter(100 * time.Millisecond), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -344,155 +537,91 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	for name, test := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			statements := dbplugin.Statements{ | ||||
| 				Creation:   []string{createAdminUser}, | ||||
| 				Revocation: test.revokeStmts, | ||||
| 			password := "myreallysecurepassword" | ||||
| 			createReq := newdbplugin.NewUserRequest{ | ||||
| 				UsernameConfig: newdbplugin.UsernameMetadata{ | ||||
| 					DisplayName: "test", | ||||
| 					RoleName:    "test", | ||||
| 				}, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: []string{createAdminUser}, | ||||
| 				}, | ||||
| 				Password:   password, | ||||
| 				Expiration: time.Now().Add(2 * time.Second), | ||||
| 			} | ||||
| 			createResp := dbtesting.AssertNewUser(t, db, createReq) | ||||
|  | ||||
| 			assertCredsExist(t, db.ConnectionURL, createResp.Username, password) | ||||
|  | ||||
| 			deleteReq := newdbplugin.DeleteUserRequest{ | ||||
| 				Username: createResp.Username, | ||||
| 				Statements: newdbplugin.Statements{ | ||||
| 					Commands: test.revokeStmts, | ||||
| 				}, | ||||
| 			} | ||||
|  | ||||
| 			usernameConfig := dbplugin.UsernameConfig{ | ||||
| 				DisplayName: "test", | ||||
| 				RoleName:    "test", | ||||
| 			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 			defer cancel() | ||||
|  | ||||
| 			_, err := db.DeleteUser(ctx, deleteReq) | ||||
| 			if test.expectErr && err == nil { | ||||
| 				t.Fatalf("err expected, got nil") | ||||
| 			} | ||||
| 			if !test.expectErr && err != nil { | ||||
| 				t.Fatalf("no error expected, got: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { | ||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			// Test default revoke statements | ||||
| 			err = db.RevokeUser(context.Background(), statements, username) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			if err := testCredsExist(t, db.ConnectionURL, username, password); err == nil { | ||||
| 				t.Fatal("Credentials were not revoked") | ||||
| 			} | ||||
| 			test.credsAssertion(t, db.ConnectionURL, createResp.Username, password) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPostgreSQL_SetCredentials_missingArgs(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		statements dbplugin.Statements | ||||
| 		userConfig dbplugin.StaticUserConfig | ||||
| 	} | ||||
| type credsAssertion func(t testing.TB, connURL, username, password string) | ||||
|  | ||||
| 	tests := map[string]testCase{ | ||||
| 		"empty rotation statements": { | ||||
| 			statements: dbplugin.Statements{ | ||||
| 				Rotation: nil, | ||||
| 			}, | ||||
| 			userConfig: dbplugin.StaticUserConfig{ | ||||
| 				Username: "testuser", | ||||
| 				Password: "password", | ||||
| 			}, | ||||
| 		}, | ||||
| 		"empty username": { | ||||
| 			statements: dbplugin.Statements{ | ||||
| 				Rotation: []string{` | ||||
| 					ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, | ||||
| 				}, | ||||
| 			}, | ||||
| 			userConfig: dbplugin.StaticUserConfig{ | ||||
| 				Username: "", | ||||
| 				Password: "password", | ||||
| 			}, | ||||
| 		}, | ||||
| 		"empty password": { | ||||
| 			statements: dbplugin.Statements{ | ||||
| 				Rotation: []string{` | ||||
| 					ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, | ||||
| 				}, | ||||
| 			}, | ||||
| 			userConfig: dbplugin.StaticUserConfig{ | ||||
| 				Username: "testuser", | ||||
| 				Password: "", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for name, test := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			db := new() | ||||
|  | ||||
| 			username, password, err := db.SetCredentials(context.Background(), test.statements, test.userConfig) | ||||
| 			if err == nil { | ||||
| 				t.Fatalf("expected err, got nil") | ||||
| 			} | ||||
| 			if username != "" { | ||||
| 				t.Fatalf("expected empty username, got [%s]", username) | ||||
| 			} | ||||
| 			if password != "" { | ||||
| 				t.Fatalf("expected empty password, got [%s]", password) | ||||
| 			} | ||||
| 		}) | ||||
| func assertCredsExist(t testing.TB, connURL, username, password string) { | ||||
| 	t.Helper() | ||||
| 	err := testCredsExist(t, connURL, username, password) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("user does not exist: %s", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPostgresSQL_SetCredentials(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		rotationStmts []string | ||||
| func assertCredsDoNotExist(t testing.TB, connURL, username, password string) { | ||||
| 	t.Helper() | ||||
| 	err := testCredsExist(t, connURL, username, password) | ||||
| 	if err == nil { | ||||
| 		t.Fatalf("user should not exist but does") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| 	tests := map[string]testCase{ | ||||
| 		"name rotation": { | ||||
| 			rotationStmts: []string{` | ||||
| 				ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, | ||||
| 			}, | ||||
| 		}, | ||||
| 		"username rotation": { | ||||
| 			rotationStmts: []string{` | ||||
| 				ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';`, | ||||
| 			}, | ||||
| 		}, | ||||
| func waitUntilCredsDoNotExist(timeout time.Duration) credsAssertion { | ||||
| 	return func(t testing.TB, connURL, username, password string) { | ||||
| 		t.Helper() | ||||
| 		ctx, cancel := context.WithTimeout(context.Background(), timeout) | ||||
| 		defer cancel() | ||||
|  | ||||
| 		ticker := time.NewTicker(10 * time.Millisecond) | ||||
| 		defer ticker.Stop() | ||||
| 		for { | ||||
| 			select { | ||||
| 			case <-ctx.Done(): | ||||
| 				t.Fatalf("Timed out waiting for user %s to be deleted", username) | ||||
| 			case <-ticker.C: | ||||
| 				err := testCredsExist(t, connURL, username, password) | ||||
| 				if err != nil { | ||||
| 					// Happy path | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| 	for name, test := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			db, cleanup := getPostgreSQL(t, nil) | ||||
| 			defer cleanup() | ||||
|  | ||||
| 			// create the database user | ||||
| 			dbUser := "vaultstatictest" | ||||
| 			initPassword := "password" | ||||
| 			createTestPGUser(t, db.ConnectionURL, dbUser, initPassword, testRoleStaticCreate) | ||||
|  | ||||
| 			statements := dbplugin.Statements{ | ||||
| 				Rotation: test.rotationStmts, | ||||
| 			} | ||||
|  | ||||
| 			password, err := db.GenerateCredentials(context.Background()) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
|  | ||||
| 			usernameConfig := dbplugin.StaticUserConfig{ | ||||
| 				Username: dbUser, | ||||
| 				Password: password, | ||||
| 			} | ||||
|  | ||||
| 			if err := testCredsExist(t, db.ConnectionURL, dbUser, initPassword); err != nil { | ||||
| 				t.Fatalf("Could not connect with initial credentials: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			username, password, err := db.SetCredentials(context.Background(), statements, usernameConfig) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("err: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			if err := testCredsExist(t, db.ConnectionURL, username, password); err != nil { | ||||
| 				t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 			} | ||||
|  | ||||
| 			if err := testCredsExist(t, db.ConnectionURL, username, initPassword); err == nil { | ||||
| 				t.Fatalf("Should not be able to connect with initial credentials") | ||||
| 			} | ||||
| 		}) | ||||
| func assertCredsExistAfter(timeout time.Duration) credsAssertion { | ||||
| 	return func(t testing.TB, connURL, username, password string) { | ||||
| 		t.Helper() | ||||
| 		time.Sleep(timeout) | ||||
| 		assertCredsExist(t, connURL, username, password) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -516,7 +645,7 @@ CREATE ROLE "{{name}}" WITH | ||||
| GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; | ||||
| ` | ||||
|  | ||||
| var testPostgresBlockStatementRoleSlice = []string{ | ||||
| var newUserLargeBlockStatements = []string{ | ||||
| 	` | ||||
| DO $$ | ||||
| BEGIN | ||||
| @@ -539,59 +668,6 @@ $$ | ||||
| 	`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, | ||||
| } | ||||
|  | ||||
| const defaultPostgresRevocationSQL = ` | ||||
| REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; | ||||
| REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; | ||||
| REVOKE USAGE ON SCHEMA public FROM "{{name}}"; | ||||
|  | ||||
| DROP ROLE IF EXISTS "{{name}}"; | ||||
| ` | ||||
|  | ||||
| const testRoleStaticCreate = ` | ||||
| CREATE ROLE "{{name}}" WITH | ||||
|   LOGIN | ||||
|   PASSWORD '{{password}}'; | ||||
| ` | ||||
|  | ||||
| // This is a copy of a test helper method also found in | ||||
| // builtin/logical/database/rotation_test.go , and should be moved into a shared | ||||
| // helper file in the future. | ||||
| func createTestPGUser(t *testing.T, connURL string, username, password, query string) { | ||||
| 	t.Helper() | ||||
| 	conn, err := pq.ParseURL(connURL) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	db, err := sql.Open("postgres", conn) | ||||
| 	defer db.Close() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Start a transaction | ||||
| 	ctx := context.Background() | ||||
| 	tx, err := db.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		_ = tx.Rollback() | ||||
| 	}() | ||||
|  | ||||
| 	m := map[string]string{ | ||||
| 		"name":     username, | ||||
| 		"password": password, | ||||
| 	} | ||||
| 	if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	// Commit the transaction | ||||
| 	if err := tx.Commit(); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestContainsMultilineStatement(t *testing.T) { | ||||
| 	type testCase struct { | ||||
| 		Input    string | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Michael Golowka
					Michael Golowka