From 28f90f1efee64e402713a3f11dbc940eb72b9b9d Mon Sep 17 00:00:00 2001 From: Michael Golowka <72365+pcman312@users.noreply.github.com> Date: Wed, 7 Oct 2020 12:58:11 -0600 Subject: [PATCH] DBPW - Update PostgreSQL to adhere to v5 Database interface (#10061) --- plugins/database/postgresql/postgresql.go | 356 ++++---- .../database/postgresql/postgresql_test.go | 814 ++++++++++-------- 2 files changed, 590 insertions(+), 580 deletions(-) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index d0094b53b5..d8aa0a094d 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -3,35 +3,36 @@ package postgresql import ( "context" "database/sql" - "errors" "fmt" "regexp" "strings" - "time" "github.com/hashicorp/errwrap" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/sdk/database/dbplugin" "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/credsutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil" + "github.com/hashicorp/vault/sdk/database/newdbplugin" "github.com/hashicorp/vault/sdk/helper/dbtxn" "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/lib/pq" ) const ( - postgreSQLTypeName = "postgres" - defaultPostgresRenewSQL = ` + postgreSQLTypeName = "postgres" + defaultExpirationStatement = ` ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; ` - defaultPostgresRotateRootCredentialsSQL = ` + defaultChangePasswordStatement = ` ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; ` + + expirationFormat = "2006-01-02 15:04:05-0700" ) var ( - _ dbplugin.Database = &PostgreSQL{} + _ newdbplugin.Database = &PostgreSQL{} // postgresEndStatement is basically the word "END" but // surrounded by a word boundary to differentiate it from @@ -51,7 +52,7 @@ var ( func New() (interface{}, error) { db := new() // Wrap the plugin with middleware to sanitize errors - dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) + dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) return dbType, nil } @@ -59,16 +60,8 @@ func new() *PostgreSQL { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = postgreSQLTypeName - credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 8, - RoleNameLen: 8, - UsernameLen: 63, - Separator: "-", - } - db := &PostgreSQL{ SQLConnectionProducer: connProducer, - CredentialsProducer: credsProducer, } return db @@ -81,14 +74,24 @@ func Run(apiTLSConfig *api.TLSConfig) error { return err } - dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) + newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) return nil } type PostgreSQL struct { *connutil.SQLConnectionProducer - credsutil.CredentialsProducer +} + +func (p *PostgreSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) { + newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection) + if err != nil { + return newdbplugin.InitializeResponse{}, err + } + resp := newdbplugin.InitializeResponse{ + Config: newConf, + } + return resp, nil } func (p *PostgreSQL) Type() (string, error) { @@ -104,54 +107,58 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { return db.(*sql.DB), nil } -// SetCredentials uses provided information to set/create a user in the -// database. Unlike CreateUser, this method requires a username be provided and -// uses the name given, instead of generating a name. This is used for creating -// and setting the password of static accounts, as well as rolling back -// passwords in the database in the event an updated database fails to save in -// Vault's storage. -func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { - if len(statements.Rotation) == 0 { - statements.Rotation = []string{defaultPostgresRotateRootCredentialsSQL} +func (p *PostgreSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) { + if req.Username == "" { + return newdbplugin.UpdateUserResponse{}, fmt.Errorf("missing username") + } + if req.Password == nil && req.Expiration == nil { + return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested") } - username = staticUser.Username - password = staticUser.Password - if username == "" || password == "" { - return "", "", errors.New("must provide both username and password") + merr := &multierror.Error{} + if req.Password != nil { + err := p.changeUserPassword(ctx, req.Username, req.Password) + merr = multierror.Append(merr, err) + } + if req.Expiration != nil { + err := p.changeUserExpiration(ctx, req.Username, req.Expiration) + merr = multierror.Append(merr, err) + } + return newdbplugin.UpdateUserResponse{}, merr.ErrorOrNil() +} + +func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error { + stmts := changePass.Statements.Commands + if len(stmts) == 0 { + stmts = []string{defaultChangePasswordStatement} + } + + password := changePass.NewPassword + if password == "" { + return fmt.Errorf("missing password") } - // Grab the lock p.Lock() defer p.Unlock() - // Get the connection db, err := p.getConnection(ctx) if err != nil { - return "", "", err + return fmt.Errorf("unable to get connection: %w", err) } // Check if the role exists var exists bool err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) if err != nil && err != sql.ErrNoRows { - return "", "", err + return fmt.Errorf("user does not appear to exist: %w", err) } - // Vault requires the database user already exist, and that the credentials - // used to execute the rotation statements has sufficient privileges. - stmts := statements.Rotation - - // Start a transaction tx, err := db.BeginTx(ctx, nil) if err != nil { - return "", "", err + return fmt.Errorf("unable to start transaction: %w", err) } - defer func() { - _ = tx.Rollback() - }() + defer tx.Rollback() - // Execute each query for _, stmt := range stmts { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { query = strings.TrimSpace(query) @@ -160,117 +167,30 @@ func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Sta } m := map[string]string{ - "name": staticUser.Username, - "username": staticUser.Username, + "name": username, + "username": username, "password": password, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return "", "", err + return fmt.Errorf("failed to execute query: %w", err) } } } - // Commit the transaction if err := tx.Commit(); err != nil { - return "", "", err + return err } - return username, password, nil + return nil } -func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { - statements = dbutil.StatementCompatibilityHelper(statements) - - if len(statements.Creation) == 0 { - return "", "", dbutil.ErrEmptyCreationStatement - } - - // Grab the lock +func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, changeExp *newdbplugin.ChangeExpiration) error { p.Lock() defer p.Unlock() - username, err = p.GenerateUsername(usernameConfig) - if err != nil { - return "", "", err - } - - password, err = p.GeneratePassword() - if err != nil { - return "", "", err - } - - expirationStr, err := p.GenerateExpiration(expiration) - if err != nil { - return "", "", err - } - - // Get the connection - db, err := p.getConnection(ctx) - if err != nil { - return "", "", err - } - - // Start a transaction - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return "", "", err - - } - defer func() { - tx.Rollback() - }() - - // Execute each query - for _, stmt := range statements.Creation { - if containsMultilineStatement(stmt) { - // Execute it as-is. - m := map[string]string{ - "name": username, - "username": username, - "password": password, - "expiration": expirationStr, - } - if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil { - return "", "", err - } - continue - } - // Otherwise, it's fine to split the statements on the semicolon. - for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - m := map[string]string{ - "name": username, - "username": username, - "password": password, - "expiration": expirationStr, - } - if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return "", "", err - } - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return "", "", err - } - - return username, password, nil -} - -func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { - p.Lock() - defer p.Unlock() - - statements = dbutil.StatementCompatibilityHelper(statements) - - renewStmts := statements.Renewal + renewStmts := changeExp.Statements.Commands if len(renewStmts) == 0 { - renewStmts = []string{defaultPostgresRenewSQL} + renewStmts = []string{defaultExpirationStatement} } db, err := p.getConnection(ctx) @@ -286,10 +206,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen tx.Rollback() }() - expirationStr, err := p.GenerateExpiration(expiration) - if err != nil { - return err - } + expirationStr := changeExp.NewExpiration.Format(expirationFormat) for _, stmt := range renewStmts { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { @@ -312,21 +229,93 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen return tx.Commit() } -func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { - // Grab the lock +func (p *PostgreSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) { + if len(req.Statements.Commands) == 0 { + return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement + } + p.Lock() defer p.Unlock() - statements = dbutil.StatementCompatibilityHelper(statements) - - if len(statements.Revocation) == 0 { - return p.defaultRevokeUser(ctx, username) + username, err := credsutil.GenerateUsername( + credsutil.DisplayName(req.UsernameConfig.DisplayName, 8), + credsutil.RoleName(req.UsernameConfig.RoleName, 8), + credsutil.Separator("-"), + credsutil.MaxLength(63), + ) + if err != nil { + return newdbplugin.NewUserResponse{}, err } - return p.customRevokeUser(ctx, username, statements.Revocation) + expirationStr := req.Expiration.Format(expirationFormat) + + db, err := p.getConnection(ctx) + if err != nil { + return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err) + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to start transaction: %w", err) + + } + defer tx.Rollback() + + for _, stmt := range req.Statements.Commands { + if containsMultilineStatement(stmt) { + // Execute it as-is. + m := map[string]string{ + "name": username, + "username": username, + "password": req.Password, + "expiration": expirationStr, + } + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil { + return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err) + } + continue + } + // Otherwise, it's fine to split the statements on the semicolon. + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + m := map[string]string{ + "name": username, + "username": username, + "password": req.Password, + "expiration": expirationStr, + } + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { + return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err) + } + } + } + + if err := tx.Commit(); err != nil { + return newdbplugin.NewUserResponse{}, err + } + + resp := newdbplugin.NewUserResponse{ + Username: username, + } + return resp, nil } -func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { +func (p *PostgreSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) { + p.Lock() + defer p.Unlock() + + if len(req.Statements.Commands) == 0 { + return newdbplugin.DeleteUserResponse{}, p.defaultDeleteUser(ctx, req.Username) + } + + return newdbplugin.DeleteUserResponse{}, p.customDeleteUser(ctx, req.Username, req.Statements.Commands) +} + +func (p *PostgreSQL) customDeleteUser(ctx context.Context, username string, revocationStmts []string) error { db, err := p.getConnection(ctx) if err != nil { return err @@ -360,7 +349,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revo return tx.Commit() } -func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { +func (p *PostgreSQL) defaultDeleteUser(ctx context.Context, username string) error { db, err := p.getConnection(ctx) if err != nil { return err @@ -471,65 +460,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err return nil } -func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { - p.Lock() - defer p.Unlock() - - if len(p.Username) == 0 || len(p.Password) == 0 { - return nil, errors.New("username and password are required to rotate") +func (p *PostgreSQL) secretValues() map[string]string { + return map[string]string{ + p.Password: "[password]", } - - rotateStatements := statements - if len(rotateStatements) == 0 { - rotateStatements = []string{defaultPostgresRotateRootCredentialsSQL} - } - - db, err := p.getConnection(ctx) - if err != nil { - return nil, err - } - - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer func() { - tx.Rollback() - }() - - password, err := p.GeneratePassword() - if err != nil { - return nil, err - } - - for _, stmt := range rotateStatements { - for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - m := map[string]string{ - "name": p.Username, - "username": p.Username, - "password": password, - } - if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return nil, err - } - } - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - // Close the database connection to ensure no new connections come in - if err := db.Close(); err != nil { - return nil, err - } - - p.RawConfig["password"] = password - return p.RawConfig, nil } // containsMultilineStatement is a best effort to determine whether diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 8c099a61a0..fa86af3f3c 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -9,9 +9,8 @@ import ( "time" "github.com/hashicorp/vault/helper/testhelpers/postgresql" - "github.com/hashicorp/vault/sdk/database/dbplugin" - "github.com/hashicorp/vault/sdk/helper/dbtxn" - "github.com/lib/pq" + "github.com/hashicorp/vault/sdk/database/newdbplugin" + dbtesting "github.com/hashicorp/vault/sdk/database/newdbplugin/testing" ) func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) { @@ -24,12 +23,14 @@ func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, f connectionDetails[k] = v } - db := new() - _, err := db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) + req := newdbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, } + db := new() + dbtesting.AssertInitialize(t, db, req) + if !db.Initialized { t.Fatal("Database should be initialized") } @@ -58,90 +59,163 @@ func TestPostgreSQL_InitializeWithStringVals(t *testing.T) { } } -func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) { - db := new() - - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } - - username, password, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) - if err == nil { - t.Fatalf("expected err, got nil") - } - if username != "" { - t.Fatalf("expected empty username, got [%s]", username) - } - if password != "" { - t.Fatalf("expected empty password, got [%s]", password) - } -} - -func TestPostgreSQL_CreateUser(t *testing.T) { +func TestPostgreSQL_NewUser(t *testing.T) { type testCase struct { - createStmts []string - shouldTestCredsExist bool + req newdbplugin.NewUserRequest + expectErr bool + credsAssertion credsAssertion } tests := map[string]testCase{ - "admin name": { - createStmts: []string{` - CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, + "no creation statements": { + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + // No statements + Password: "somesecurepassword", + Expiration: time.Now().Add(1 * time.Minute), }, - shouldTestCredsExist: true, + expectErr: true, + credsAssertion: assertCredsDoNotExist, + }, + "admin name": { + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{` + CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, + }, + }, + Password: "somesecurepassword", + Expiration: time.Now().Add(1 * time.Minute), + }, + expectErr: false, + credsAssertion: assertCredsExist, }, "admin username": { - createStmts: []string{` - CREATE ROLE "{{username}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{` + CREATE ROLE "{{username}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, + }, + }, + Password: "somesecurepassword", + Expiration: time.Now().Add(1 * time.Minute), }, - shouldTestCredsExist: true, + expectErr: false, + credsAssertion: assertCredsExist, }, "read only name": { - createStmts: []string{` - CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; - GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; - GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{` + CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; + GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, + }, + }, + Password: "somesecurepassword", + Expiration: time.Now().Add(1 * time.Minute), }, - shouldTestCredsExist: true, + expectErr: false, + credsAssertion: assertCredsExist, }, "read only username": { - createStmts: []string{` - CREATE ROLE "{{username}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; - GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; - GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{` + CREATE ROLE "{{username}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; + GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, + }, + }, + Password: "somesecurepassword", + Expiration: time.Now().Add(1 * time.Minute), }, - shouldTestCredsExist: true, + expectErr: false, + credsAssertion: assertCredsExist, }, // https://github.com/hashicorp/vault/issues/6098 "reproduce GH-6098": { - createStmts: []string{ - // NOTE: "rolname" in the following line is not a typo. - "DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$", + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{ + // NOTE: "rolname" in the following line is not a typo. + "DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$", + }, + }, + Password: "somesecurepassword", + Expiration: time.Now().Add(1 * time.Minute), }, - // This test statement doesn't generate creds. - shouldTestCredsExist: false, + expectErr: false, + credsAssertion: assertCredsDoNotExist, }, "reproduce issue with template": { - createStmts: []string{ - `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`, + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{ + `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`, + }, + }, + Password: "somesecurepassword", + Expiration: time.Now().Add(1 * time.Minute), }, - // This test statement doesn't generate creds. - shouldTestCredsExist: false, + expectErr: false, + credsAssertion: assertCredsDoNotExist, + }, + "large block statements": { + req: newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: newUserLargeBlockStatements, + }, + Password: "somesecurepassword", + Expiration: time.Now().Add(1 * time.Minute), + }, + expectErr: false, + credsAssertion: assertCredsExist, }, } @@ -151,59 +225,55 @@ func TestPostgreSQL_CreateUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } - - statements := dbplugin.Statements{ - Creation: test.createStmts, - } - // Give a timeout just in case the test decides to be problematic ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) + resp, err := db.NewUser(ctx, test.req) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) } - if !test.shouldTestCredsExist { - // We're done here. - return - } - - if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password) // Ensure that the role doesn't expire immediately time.Sleep(2 * time.Second) - if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password) }) } } -func TestPostgreSQL_RenewUser(t *testing.T) { +func TestUpdateUser_Password(t *testing.T) { type testCase struct { - renewalStmts []string + statements []string + expectErr bool + credsAssertion credsAssertion } tests := map[string]testCase{ - "empty renewal statements": { - renewalStmts: nil, + "default statements": { + statements: nil, + expectErr: false, + credsAssertion: assertCredsExist, }, - "default renewal name": { - renewalStmts: []string{defaultPostgresRenewSQL}, + "explicit default statements": { + statements: []string{defaultChangePasswordStatement}, + expectErr: false, + credsAssertion: assertCredsExist, }, - "default renewal username": { - renewalStmts: []string{` - ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`, - }, + "name instead of username": { + statements: []string{`ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`}, + expectErr: false, + credsAssertion: assertCredsExist, + }, + "bad statements": { + statements: []string{`asdofyas8uf77asoiajv`}, + expectErr: true, + credsAssertion: assertCredsDoNotExist, }, } @@ -213,128 +283,251 @@ func TestPostgreSQL_RenewUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - statements := dbplugin.Statements{ - Creation: []string{createAdminUser}, - Renewal: test.renewalStmts, + initialPass := "myreallysecurepassword" + createReq := newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{createAdminUser}, + }, + Password: initialPass, + Expiration: time.Now().Add(2 * time.Second), + } + createResp := dbtesting.AssertNewUser(t, db, createReq) + + assertCredsExist(t, db.ConnectionURL, createResp.Username, initialPass) + + newPass := "somenewpassword" + updateReq := newdbplugin.UpdateUserRequest{ + Username: createResp.Username, + Password: &newdbplugin.ChangePassword{ + NewPassword: newPass, + Statements: newdbplugin.Statements{ + Commands: test.statements, + }, + }, } - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } - - // Give a timeout just in case the test decides to be problematic - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - - username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) + _, err := db.UpdateUser(ctx, updateReq) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) } - if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - err = db.RenewUser(ctx, statements, username, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Sleep longer than the initial expiration time - time.Sleep(2 * time.Second) - - if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + test.credsAssertion(t, db.ConnectionURL, createResp.Username, newPass) }) } + + t.Run("user does not exist", func(t *testing.T) { + newPass := "somenewpassword" + updateReq := newdbplugin.UpdateUserRequest{ + Username: "missing-user", + Password: &newdbplugin.ChangePassword{ + NewPassword: newPass, + Statements: newdbplugin.Statements{}, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := db.UpdateUser(ctx, updateReq) + if err == nil { + t.Fatalf("err expected, got nil") + } + + assertCredsDoNotExist(t, db.ConnectionURL, updateReq.Username, newPass) + }) } -func TestPostgreSQL_RotateRootCredentials(t *testing.T) { +func TestUpdateUser_Expiration(t *testing.T) { type testCase struct { - statements []string + initialExpiration time.Time + newExpiration time.Time + expectedExpiration time.Time + statements []string + expectErr bool } + now := time.Now() tests := map[string]testCase{ - "empty statements": { - statements: nil, + "no statements": { + initialExpiration: now.Add(1 * time.Minute), + newExpiration: now.Add(5 * time.Minute), + expectedExpiration: now.Add(5 * time.Minute), + statements: nil, + expectErr: false, }, - "default name": { - statements: []string{` - ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, - }, + "default statements with name": { + initialExpiration: now.Add(1 * time.Minute), + newExpiration: now.Add(5 * time.Minute), + expectedExpiration: now.Add(5 * time.Minute), + statements: []string{defaultExpirationStatement}, + expectErr: false, }, - "default username": { - statements: []string{defaultPostgresRotateRootCredentialsSQL}, + "default statements with username": { + initialExpiration: now.Add(1 * time.Minute), + newExpiration: now.Add(5 * time.Minute), + expectedExpiration: now.Add(5 * time.Minute), + statements: []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`}, + expectErr: false, + }, + "bad statements": { + initialExpiration: now.Add(1 * time.Minute), + newExpiration: now.Add(5 * time.Minute), + expectedExpiration: now.Add(1 * time.Minute), + statements: []string{"ladshfouay09sgj"}, + expectErr: true, }, } + // Shared test container for speed - there should not be any overlap between the tests + db, cleanup := getPostgreSQL(t, nil) + defer cleanup() + for name, test := range tests { t.Run(name, func(t *testing.T) { - cleanup, connURL := postgresql.PrepareTestContainer(t, "latest") - defer cleanup() - connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1) + password := "myreallysecurepassword" + initialExpiration := test.initialExpiration.Truncate(time.Second) + createReq := newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{createAdminUser}, + }, + Password: password, + Expiration: initialExpiration, + } + createResp := dbtesting.AssertNewUser(t, db, createReq) - connectionDetails := map[string]interface{}{ - "connection_url": connURL, - "max_open_connections": 5, - "username": "postgres", - "password": "secret", + assertCredsExist(t, db.ConnectionURL, createResp.Username, password) + + actualExpiration := getExpiration(t, db, createResp.Username) + if actualExpiration.IsZero() { + t.Fatalf("Initial expiration is zero but should be set") + } + if !actualExpiration.Equal(initialExpiration) { + t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, initialExpiration) } - db := new() - connProducer := db.SQLConnectionProducer + newExpiration := test.newExpiration.Truncate(time.Second) + updateReq := newdbplugin.UpdateUserRequest{ + Username: createResp.Username, + Expiration: &newdbplugin.ChangeExpiration{ + NewExpiration: newExpiration, + Statements: newdbplugin.Statements{ + Commands: test.statements, + }, + }, + } - // Give a timeout just in case the test decides to be problematic - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - - _, err := db.Init(ctx, connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) + _, err := db.UpdateUser(ctx, updateReq) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) } - if !connProducer.Initialized { - t.Fatal("Database should be initialized") - } - - newConf, err := db.RotateRootCredentials(ctx, test.statements) - if err != nil { - t.Fatalf("err: %v", err) - } - if newConf["password"] == "secret" { - t.Fatal("password was not updated") - } - - err = db.Close() - if err != nil { - t.Fatalf("failed to close: %s", err) + expectedExpiration := test.expectedExpiration.Truncate(time.Second) + actualExpiration = getExpiration(t, db, createResp.Username) + if !actualExpiration.Equal(expectedExpiration) { + t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, expectedExpiration) } }) } } -func TestPostgreSQL_RevokeUser(t *testing.T) { +func getExpiration(t testing.TB, db *PostgreSQL, username string) time.Time { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + query := fmt.Sprintf("select valuntil from pg_catalog.pg_user where usename = '%s'", username) + conn, err := db.getConnection(ctx) + if err != nil { + t.Fatalf("Failed to get connection to database: %s", err) + } + + stmt, err := conn.PrepareContext(ctx, query) + if err != nil { + t.Fatalf("Failed to prepare statement: %s", err) + } + defer stmt.Close() + + rows, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatalf("Failed to execute query to get expiration: %s", err) + } + + if !rows.Next() { + return time.Time{} // No expiration + } + rawExp := "" + err = rows.Scan(&rawExp) + if err != nil { + t.Fatalf("Unable to get raw expiration: %s", err) + } + if rawExp == "" { + return time.Time{} // No expiration + } + exp, err := time.Parse(time.RFC3339, rawExp) + if err != nil { + t.Fatalf("Failed to parse expiration %q: %s", rawExp, err) + } + return exp +} + +func TestDeleteUser(t *testing.T) { type testCase struct { - revokeStmts []string + revokeStmts []string + expectErr bool + credsAssertion credsAssertion } tests := map[string]testCase{ - "empty statements": { + "no statements": { revokeStmts: nil, + expectErr: false, + // Wait for a short time before failing because postgres takes a moment to finish deleting the user + credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), }, - "explicit default name": { - revokeStmts: []string{defaultPostgresRevocationSQL}, + "statements with name": { + revokeStmts: []string{` + REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; + REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; + REVOKE USAGE ON SCHEMA public FROM "{{name}}"; + + DROP ROLE IF EXISTS "{{name}}";`}, + expectErr: false, + // Wait for a short time before failing because postgres takes a moment to finish deleting the user + credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), }, - "explicit default username": { + "statements with username": { revokeStmts: []string{` REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}"; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}"; REVOKE USAGE ON SCHEMA public FROM "{{username}}"; - - DROP ROLE IF EXISTS "{{username}}";`, - }, + + DROP ROLE IF EXISTS "{{username}}";`}, + expectErr: false, + // Wait for a short time before failing because postgres takes a moment to finish deleting the user + credsAssertion: waitUntilCredsDoNotExist(2 * time.Second), + }, + "bad statements": { + revokeStmts: []string{`8a9yhfoiasjff`}, + expectErr: true, + // Wait for a short time before checking because postgres takes a moment to finish deleting the user + credsAssertion: assertCredsExistAfter(100 * time.Millisecond), }, } @@ -344,155 +537,91 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - statements := dbplugin.Statements{ - Creation: []string{createAdminUser}, - Revocation: test.revokeStmts, + password := "myreallysecurepassword" + createReq := newdbplugin.NewUserRequest{ + UsernameConfig: newdbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + }, + Statements: newdbplugin.Statements{ + Commands: []string{createAdminUser}, + }, + Password: password, + Expiration: time.Now().Add(2 * time.Second), + } + createResp := dbtesting.AssertNewUser(t, db, createReq) + + assertCredsExist(t, db.ConnectionURL, createResp.Username, password) + + deleteReq := newdbplugin.DeleteUserRequest{ + Username: createResp.Username, + Statements: newdbplugin.Statements{ + Commands: test.revokeStmts, + }, } - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := db.DeleteUser(ctx, deleteReq) + if test.expectErr && err == nil { + t.Fatalf("err expected, got nil") + } + if !test.expectErr && err != nil { + t.Fatalf("no error expected, got: %s", err) } - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - // Test default revoke statements - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := testCredsExist(t, db.ConnectionURL, username, password); err == nil { - t.Fatal("Credentials were not revoked") - } + test.credsAssertion(t, db.ConnectionURL, createResp.Username, password) }) } } -func TestPostgreSQL_SetCredentials_missingArgs(t *testing.T) { - type testCase struct { - statements dbplugin.Statements - userConfig dbplugin.StaticUserConfig - } +type credsAssertion func(t testing.TB, connURL, username, password string) - tests := map[string]testCase{ - "empty rotation statements": { - statements: dbplugin.Statements{ - Rotation: nil, - }, - userConfig: dbplugin.StaticUserConfig{ - Username: "testuser", - Password: "password", - }, - }, - "empty username": { - statements: dbplugin.Statements{ - Rotation: []string{` - ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, - }, - }, - userConfig: dbplugin.StaticUserConfig{ - Username: "", - Password: "password", - }, - }, - "empty password": { - statements: dbplugin.Statements{ - Rotation: []string{` - ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, - }, - }, - userConfig: dbplugin.StaticUserConfig{ - Username: "testuser", - Password: "", - }, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - db := new() - - username, password, err := db.SetCredentials(context.Background(), test.statements, test.userConfig) - if err == nil { - t.Fatalf("expected err, got nil") - } - if username != "" { - t.Fatalf("expected empty username, got [%s]", username) - } - if password != "" { - t.Fatalf("expected empty password, got [%s]", password) - } - }) +func assertCredsExist(t testing.TB, connURL, username, password string) { + t.Helper() + err := testCredsExist(t, connURL, username, password) + if err != nil { + t.Fatalf("user does not exist: %s", err) } } -func TestPostgresSQL_SetCredentials(t *testing.T) { - type testCase struct { - rotationStmts []string +func assertCredsDoNotExist(t testing.TB, connURL, username, password string) { + t.Helper() + err := testCredsExist(t, connURL, username, password) + if err == nil { + t.Fatalf("user should not exist but does") } +} - tests := map[string]testCase{ - "name rotation": { - rotationStmts: []string{` - ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, - }, - }, - "username rotation": { - rotationStmts: []string{` - ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';`, - }, - }, +func waitUntilCredsDoNotExist(timeout time.Duration) credsAssertion { + return func(t testing.TB, connURL, username, password string) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + t.Fatalf("Timed out waiting for user %s to be deleted", username) + case <-ticker.C: + err := testCredsExist(t, connURL, username, password) + if err != nil { + // Happy path + return + } + } + } } +} - for name, test := range tests { - t.Run(name, func(t *testing.T) { - db, cleanup := getPostgreSQL(t, nil) - defer cleanup() - - // create the database user - dbUser := "vaultstatictest" - initPassword := "password" - createTestPGUser(t, db.ConnectionURL, dbUser, initPassword, testRoleStaticCreate) - - statements := dbplugin.Statements{ - Rotation: test.rotationStmts, - } - - password, err := db.GenerateCredentials(context.Background()) - if err != nil { - t.Fatal(err) - } - - usernameConfig := dbplugin.StaticUserConfig{ - Username: dbUser, - Password: password, - } - - if err := testCredsExist(t, db.ConnectionURL, dbUser, initPassword); err != nil { - t.Fatalf("Could not connect with initial credentials: %s", err) - } - - username, password, err := db.SetCredentials(context.Background(), statements, usernameConfig) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := testCredsExist(t, db.ConnectionURL, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - if err := testCredsExist(t, db.ConnectionURL, username, initPassword); err == nil { - t.Fatalf("Should not be able to connect with initial credentials") - } - }) +func assertCredsExistAfter(timeout time.Duration) credsAssertion { + return func(t testing.TB, connURL, username, password string) { + t.Helper() + time.Sleep(timeout) + assertCredsExist(t, connURL, username, password) } } @@ -516,7 +645,7 @@ CREATE ROLE "{{name}}" WITH GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; ` -var testPostgresBlockStatementRoleSlice = []string{ +var newUserLargeBlockStatements = []string{ ` DO $$ BEGIN @@ -539,59 +668,6 @@ $$ `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, } -const defaultPostgresRevocationSQL = ` -REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; -REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; -REVOKE USAGE ON SCHEMA public FROM "{{name}}"; - -DROP ROLE IF EXISTS "{{name}}"; -` - -const testRoleStaticCreate = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}'; -` - -// This is a copy of a test helper method also found in -// builtin/logical/database/rotation_test.go , and should be moved into a shared -// helper file in the future. -func createTestPGUser(t *testing.T, connURL string, username, password, query string) { - t.Helper() - conn, err := pq.ParseURL(connURL) - if err != nil { - t.Fatal(err) - } - - db, err := sql.Open("postgres", conn) - defer db.Close() - if err != nil { - t.Fatal(err) - } - - // Start a transaction - ctx := context.Background() - tx, err := db.BeginTx(ctx, nil) - if err != nil { - t.Fatal(err) - } - defer func() { - _ = tx.Rollback() - }() - - m := map[string]string{ - "name": username, - "password": password, - } - if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - t.Fatal(err) - } - // Commit the transaction - if err := tx.Commit(); err != nil { - t.Fatal(err) - } -} - func TestContainsMultilineStatement(t *testing.T) { type testCase struct { Input string