mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
DBPW - Update PostgreSQL to adhere to v5 Database interface (#10061)
This commit is contained in:
@@ -3,35 +3,36 @@ package postgresql
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||
"github.com/hashicorp/vault/sdk/helper/strutil"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
postgreSQLTypeName = "postgres"
|
||||
defaultPostgresRenewSQL = `
|
||||
postgreSQLTypeName = "postgres"
|
||||
defaultExpirationStatement = `
|
||||
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
|
||||
`
|
||||
defaultPostgresRotateRootCredentialsSQL = `
|
||||
defaultChangePasswordStatement = `
|
||||
ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
|
||||
`
|
||||
|
||||
expirationFormat = "2006-01-02 15:04:05-0700"
|
||||
)
|
||||
|
||||
var (
|
||||
_ dbplugin.Database = &PostgreSQL{}
|
||||
_ newdbplugin.Database = &PostgreSQL{}
|
||||
|
||||
// postgresEndStatement is basically the word "END" but
|
||||
// surrounded by a word boundary to differentiate it from
|
||||
@@ -51,7 +52,7 @@ var (
|
||||
func New() (interface{}, error) {
|
||||
db := new()
|
||||
// Wrap the plugin with middleware to sanitize errors
|
||||
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
|
||||
dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
|
||||
return dbType, nil
|
||||
}
|
||||
|
||||
@@ -59,16 +60,8 @@ func new() *PostgreSQL {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
connProducer.Type = postgreSQLTypeName
|
||||
|
||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||
DisplayNameLen: 8,
|
||||
RoleNameLen: 8,
|
||||
UsernameLen: 63,
|
||||
Separator: "-",
|
||||
}
|
||||
|
||||
db := &PostgreSQL{
|
||||
SQLConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
return db
|
||||
@@ -81,14 +74,24 @@ func Run(apiTLSConfig *api.TLSConfig) error {
|
||||
return err
|
||||
}
|
||||
|
||||
dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
|
||||
newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type PostgreSQL struct {
|
||||
*connutil.SQLConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
|
||||
newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
|
||||
if err != nil {
|
||||
return newdbplugin.InitializeResponse{}, err
|
||||
}
|
||||
resp := newdbplugin.InitializeResponse{
|
||||
Config: newConf,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Type() (string, error) {
|
||||
@@ -104,54 +107,58 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
// SetCredentials uses provided information to set/create a user in the
|
||||
// database. Unlike CreateUser, this method requires a username be provided and
|
||||
// uses the name given, instead of generating a name. This is used for creating
|
||||
// and setting the password of static accounts, as well as rolling back
|
||||
// passwords in the database in the event an updated database fails to save in
|
||||
// Vault's storage.
|
||||
func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
|
||||
if len(statements.Rotation) == 0 {
|
||||
statements.Rotation = []string{defaultPostgresRotateRootCredentialsSQL}
|
||||
func (p *PostgreSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
|
||||
if req.Username == "" {
|
||||
return newdbplugin.UpdateUserResponse{}, fmt.Errorf("missing username")
|
||||
}
|
||||
if req.Password == nil && req.Expiration == nil {
|
||||
return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
|
||||
}
|
||||
|
||||
username = staticUser.Username
|
||||
password = staticUser.Password
|
||||
if username == "" || password == "" {
|
||||
return "", "", errors.New("must provide both username and password")
|
||||
merr := &multierror.Error{}
|
||||
if req.Password != nil {
|
||||
err := p.changeUserPassword(ctx, req.Username, req.Password)
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
if req.Expiration != nil {
|
||||
err := p.changeUserExpiration(ctx, req.Username, req.Expiration)
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
return newdbplugin.UpdateUserResponse{}, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error {
|
||||
stmts := changePass.Statements.Commands
|
||||
if len(stmts) == 0 {
|
||||
stmts = []string{defaultChangePasswordStatement}
|
||||
}
|
||||
|
||||
password := changePass.NewPassword
|
||||
if password == "" {
|
||||
return fmt.Errorf("missing password")
|
||||
}
|
||||
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return fmt.Errorf("unable to get connection: %w", err)
|
||||
}
|
||||
|
||||
// Check if the role exists
|
||||
var exists bool
|
||||
err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", "", err
|
||||
return fmt.Errorf("user does not appear to exist: %w", err)
|
||||
}
|
||||
|
||||
// Vault requires the database user already exist, and that the credentials
|
||||
// used to execute the rotation statements has sufficient privileges.
|
||||
stmts := statements.Rotation
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return fmt.Errorf("unable to start transaction: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, stmt := range stmts {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
@@ -160,117 +167,30 @@ func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Sta
|
||||
}
|
||||
|
||||
m := map[string]string{
|
||||
"name": staticUser.Username,
|
||||
"username": staticUser.Username,
|
||||
"name": username,
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
|
||||
return "", "", err
|
||||
return fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", err
|
||||
return err
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||
|
||||
if len(statements.Creation) == 0 {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
// Grab the lock
|
||||
func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, changeExp *newdbplugin.ChangeExpiration) error {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
username, err = p.GenerateUsername(usernameConfig)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
password, err = p.GeneratePassword()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
expirationStr, err := p.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Get the connection
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
// Execute each query
|
||||
for _, stmt := range statements.Creation {
|
||||
if containsMultilineStatement(stmt) {
|
||||
// Execute it as-is.
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Otherwise, it's fine to split the statements on the semicolon.
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||
|
||||
renewStmts := statements.Renewal
|
||||
renewStmts := changeExp.Statements.Commands
|
||||
if len(renewStmts) == 0 {
|
||||
renewStmts = []string{defaultPostgresRenewSQL}
|
||||
renewStmts = []string{defaultExpirationStatement}
|
||||
}
|
||||
|
||||
db, err := p.getConnection(ctx)
|
||||
@@ -286,10 +206,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
expirationStr, err := p.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
expirationStr := changeExp.NewExpiration.Format(expirationFormat)
|
||||
|
||||
for _, stmt := range renewStmts {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||
@@ -312,21 +229,93 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
func (p *PostgreSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
|
||||
if len(req.Statements.Commands) == 0 {
|
||||
return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||
|
||||
if len(statements.Revocation) == 0 {
|
||||
return p.defaultRevokeUser(ctx, username)
|
||||
username, err := credsutil.GenerateUsername(
|
||||
credsutil.DisplayName(req.UsernameConfig.DisplayName, 8),
|
||||
credsutil.RoleName(req.UsernameConfig.RoleName, 8),
|
||||
credsutil.Separator("-"),
|
||||
credsutil.MaxLength(63),
|
||||
)
|
||||
if err != nil {
|
||||
return newdbplugin.NewUserResponse{}, err
|
||||
}
|
||||
|
||||
return p.customRevokeUser(ctx, username, statements.Revocation)
|
||||
expirationStr := req.Expiration.Format(expirationFormat)
|
||||
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
|
||||
}
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to start transaction: %w", err)
|
||||
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, stmt := range req.Statements.Commands {
|
||||
if containsMultilineStatement(stmt) {
|
||||
// Execute it as-is.
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
"username": username,
|
||||
"password": req.Password,
|
||||
"expiration": expirationStr,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil {
|
||||
return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Otherwise, it's fine to split the statements on the semicolon.
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
"username": username,
|
||||
"password": req.Password,
|
||||
"expiration": expirationStr,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
|
||||
return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return newdbplugin.NewUserResponse{}, err
|
||||
}
|
||||
|
||||
resp := newdbplugin.NewUserResponse{
|
||||
Username: username,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error {
|
||||
func (p *PostgreSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if len(req.Statements.Commands) == 0 {
|
||||
return newdbplugin.DeleteUserResponse{}, p.defaultDeleteUser(ctx, req.Username)
|
||||
}
|
||||
|
||||
return newdbplugin.DeleteUserResponse{}, p.customDeleteUser(ctx, req.Username, req.Statements.Commands)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customDeleteUser(ctx context.Context, username string, revocationStmts []string) error {
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -360,7 +349,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revo
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
|
||||
func (p *PostgreSQL) defaultDeleteUser(ctx context.Context, username string) error {
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -471,65 +460,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if len(p.Username) == 0 || len(p.Password) == 0 {
|
||||
return nil, errors.New("username and password are required to rotate")
|
||||
func (p *PostgreSQL) secretValues() map[string]string {
|
||||
return map[string]string{
|
||||
p.Password: "[password]",
|
||||
}
|
||||
|
||||
rotateStatements := statements
|
||||
if len(rotateStatements) == 0 {
|
||||
rotateStatements = []string{defaultPostgresRotateRootCredentialsSQL}
|
||||
}
|
||||
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
password, err := p.GeneratePassword()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, stmt := range rotateStatements {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
m := map[string]string{
|
||||
"name": p.Username,
|
||||
"username": p.Username,
|
||||
"password": password,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Close the database connection to ensure no new connections come in
|
||||
if err := db.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.RawConfig["password"] = password
|
||||
return p.RawConfig, nil
|
||||
}
|
||||
|
||||
// containsMultilineStatement is a best effort to determine whether
|
||||
|
||||
@@ -9,9 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/dbtxn"
|
||||
"github.com/lib/pq"
|
||||
"github.com/hashicorp/vault/sdk/database/newdbplugin"
|
||||
dbtesting "github.com/hashicorp/vault/sdk/database/newdbplugin/testing"
|
||||
)
|
||||
|
||||
func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) {
|
||||
@@ -24,12 +23,14 @@ func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, f
|
||||
connectionDetails[k] = v
|
||||
}
|
||||
|
||||
db := new()
|
||||
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
req := newdbplugin.InitializeRequest{
|
||||
Config: connectionDetails,
|
||||
VerifyConnection: true,
|
||||
}
|
||||
|
||||
db := new()
|
||||
dbtesting.AssertInitialize(t, db, req)
|
||||
|
||||
if !db.Initialized {
|
||||
t.Fatal("Database should be initialized")
|
||||
}
|
||||
@@ -58,90 +59,163 @@ func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) {
|
||||
db := new()
|
||||
|
||||
usernameConfig := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatalf("expected err, got nil")
|
||||
}
|
||||
if username != "" {
|
||||
t.Fatalf("expected empty username, got [%s]", username)
|
||||
}
|
||||
if password != "" {
|
||||
t.Fatalf("expected empty password, got [%s]", password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
func TestPostgreSQL_NewUser(t *testing.T) {
|
||||
type testCase struct {
|
||||
createStmts []string
|
||||
shouldTestCredsExist bool
|
||||
req newdbplugin.NewUserRequest
|
||||
expectErr bool
|
||||
credsAssertion credsAssertion
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"admin name": {
|
||||
createStmts: []string{`
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
|
||||
"no creation statements": {
|
||||
req: newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
// No statements
|
||||
Password: "somesecurepassword",
|
||||
Expiration: time.Now().Add(1 * time.Minute),
|
||||
},
|
||||
shouldTestCredsExist: true,
|
||||
expectErr: true,
|
||||
credsAssertion: assertCredsDoNotExist,
|
||||
},
|
||||
"admin name": {
|
||||
req: newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{`
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
|
||||
},
|
||||
},
|
||||
Password: "somesecurepassword",
|
||||
Expiration: time.Now().Add(1 * time.Minute),
|
||||
},
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsExist,
|
||||
},
|
||||
"admin username": {
|
||||
createStmts: []string{`
|
||||
CREATE ROLE "{{username}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`,
|
||||
req: newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{`
|
||||
CREATE ROLE "{{username}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`,
|
||||
},
|
||||
},
|
||||
Password: "somesecurepassword",
|
||||
Expiration: time.Now().Add(1 * time.Minute),
|
||||
},
|
||||
shouldTestCredsExist: true,
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsExist,
|
||||
},
|
||||
"read only name": {
|
||||
createStmts: []string{`
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`,
|
||||
req: newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{`
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`,
|
||||
},
|
||||
},
|
||||
Password: "somesecurepassword",
|
||||
Expiration: time.Now().Add(1 * time.Minute),
|
||||
},
|
||||
shouldTestCredsExist: true,
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsExist,
|
||||
},
|
||||
"read only username": {
|
||||
createStmts: []string{`
|
||||
CREATE ROLE "{{username}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
|
||||
req: newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{`
|
||||
CREATE ROLE "{{username}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
|
||||
},
|
||||
},
|
||||
Password: "somesecurepassword",
|
||||
Expiration: time.Now().Add(1 * time.Minute),
|
||||
},
|
||||
shouldTestCredsExist: true,
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsExist,
|
||||
},
|
||||
// https://github.com/hashicorp/vault/issues/6098
|
||||
"reproduce GH-6098": {
|
||||
createStmts: []string{
|
||||
// NOTE: "rolname" in the following line is not a typo.
|
||||
"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
|
||||
req: newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{
|
||||
// NOTE: "rolname" in the following line is not a typo.
|
||||
"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
|
||||
},
|
||||
},
|
||||
Password: "somesecurepassword",
|
||||
Expiration: time.Now().Add(1 * time.Minute),
|
||||
},
|
||||
// This test statement doesn't generate creds.
|
||||
shouldTestCredsExist: false,
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsDoNotExist,
|
||||
},
|
||||
"reproduce issue with template": {
|
||||
createStmts: []string{
|
||||
`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`,
|
||||
req: newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{
|
||||
`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`,
|
||||
},
|
||||
},
|
||||
Password: "somesecurepassword",
|
||||
Expiration: time.Now().Add(1 * time.Minute),
|
||||
},
|
||||
// This test statement doesn't generate creds.
|
||||
shouldTestCredsExist: false,
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsDoNotExist,
|
||||
},
|
||||
"large block statements": {
|
||||
req: newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: newUserLargeBlockStatements,
|
||||
},
|
||||
Password: "somesecurepassword",
|
||||
Expiration: time.Now().Add(1 * time.Minute),
|
||||
},
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsExist,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -151,59 +225,55 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
usernameConfig := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
Creation: test.createStmts,
|
||||
}
|
||||
|
||||
// Give a timeout just in case the test decides to be problematic
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
resp, err := db.NewUser(ctx, test.req)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !test.shouldTestCredsExist {
|
||||
// We're done here.
|
||||
return
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password)
|
||||
|
||||
// Ensure that the role doesn't expire immediately
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
func TestUpdateUser_Password(t *testing.T) {
|
||||
type testCase struct {
|
||||
renewalStmts []string
|
||||
statements []string
|
||||
expectErr bool
|
||||
credsAssertion credsAssertion
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"empty renewal statements": {
|
||||
renewalStmts: nil,
|
||||
"default statements": {
|
||||
statements: nil,
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsExist,
|
||||
},
|
||||
"default renewal name": {
|
||||
renewalStmts: []string{defaultPostgresRenewSQL},
|
||||
"explicit default statements": {
|
||||
statements: []string{defaultChangePasswordStatement},
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsExist,
|
||||
},
|
||||
"default renewal username": {
|
||||
renewalStmts: []string{`
|
||||
ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`,
|
||||
},
|
||||
"name instead of username": {
|
||||
statements: []string{`ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`},
|
||||
expectErr: false,
|
||||
credsAssertion: assertCredsExist,
|
||||
},
|
||||
"bad statements": {
|
||||
statements: []string{`asdofyas8uf77asoiajv`},
|
||||
expectErr: true,
|
||||
credsAssertion: assertCredsDoNotExist,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -213,128 +283,251 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
statements := dbplugin.Statements{
|
||||
Creation: []string{createAdminUser},
|
||||
Renewal: test.renewalStmts,
|
||||
initialPass := "myreallysecurepassword"
|
||||
createReq := newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{createAdminUser},
|
||||
},
|
||||
Password: initialPass,
|
||||
Expiration: time.Now().Add(2 * time.Second),
|
||||
}
|
||||
createResp := dbtesting.AssertNewUser(t, db, createReq)
|
||||
|
||||
assertCredsExist(t, db.ConnectionURL, createResp.Username, initialPass)
|
||||
|
||||
newPass := "somenewpassword"
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: createResp.Username,
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: newPass,
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: test.statements,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
usernameConfig := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
// Give a timeout just in case the test decides to be problematic
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
_, err := db.UpdateUser(ctx, updateReq)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(ctx, statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Sleep longer than the initial expiration time
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
test.credsAssertion(t, db.ConnectionURL, createResp.Username, newPass)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("user does not exist", func(t *testing.T) {
|
||||
newPass := "somenewpassword"
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: "missing-user",
|
||||
Password: &newdbplugin.ChangePassword{
|
||||
NewPassword: newPass,
|
||||
Statements: newdbplugin.Statements{},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, err := db.UpdateUser(ctx, updateReq)
|
||||
if err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
|
||||
assertCredsDoNotExist(t, db.ConnectionURL, updateReq.Username, newPass)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RotateRootCredentials(t *testing.T) {
|
||||
func TestUpdateUser_Expiration(t *testing.T) {
|
||||
type testCase struct {
|
||||
statements []string
|
||||
initialExpiration time.Time
|
||||
newExpiration time.Time
|
||||
expectedExpiration time.Time
|
||||
statements []string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
tests := map[string]testCase{
|
||||
"empty statements": {
|
||||
statements: nil,
|
||||
"no statements": {
|
||||
initialExpiration: now.Add(1 * time.Minute),
|
||||
newExpiration: now.Add(5 * time.Minute),
|
||||
expectedExpiration: now.Add(5 * time.Minute),
|
||||
statements: nil,
|
||||
expectErr: false,
|
||||
},
|
||||
"default name": {
|
||||
statements: []string{`
|
||||
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`,
|
||||
},
|
||||
"default statements with name": {
|
||||
initialExpiration: now.Add(1 * time.Minute),
|
||||
newExpiration: now.Add(5 * time.Minute),
|
||||
expectedExpiration: now.Add(5 * time.Minute),
|
||||
statements: []string{defaultExpirationStatement},
|
||||
expectErr: false,
|
||||
},
|
||||
"default username": {
|
||||
statements: []string{defaultPostgresRotateRootCredentialsSQL},
|
||||
"default statements with username": {
|
||||
initialExpiration: now.Add(1 * time.Minute),
|
||||
newExpiration: now.Add(5 * time.Minute),
|
||||
expectedExpiration: now.Add(5 * time.Minute),
|
||||
statements: []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`},
|
||||
expectErr: false,
|
||||
},
|
||||
"bad statements": {
|
||||
initialExpiration: now.Add(1 * time.Minute),
|
||||
newExpiration: now.Add(5 * time.Minute),
|
||||
expectedExpiration: now.Add(1 * time.Minute),
|
||||
statements: []string{"ladshfouay09sgj"},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Shared test container for speed - there should not be any overlap between the tests
|
||||
db, cleanup := getPostgreSQL(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
cleanup, connURL := postgresql.PrepareTestContainer(t, "latest")
|
||||
defer cleanup()
|
||||
connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1)
|
||||
password := "myreallysecurepassword"
|
||||
initialExpiration := test.initialExpiration.Truncate(time.Second)
|
||||
createReq := newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{createAdminUser},
|
||||
},
|
||||
Password: password,
|
||||
Expiration: initialExpiration,
|
||||
}
|
||||
createResp := dbtesting.AssertNewUser(t, db, createReq)
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"max_open_connections": 5,
|
||||
"username": "postgres",
|
||||
"password": "secret",
|
||||
assertCredsExist(t, db.ConnectionURL, createResp.Username, password)
|
||||
|
||||
actualExpiration := getExpiration(t, db, createResp.Username)
|
||||
if actualExpiration.IsZero() {
|
||||
t.Fatalf("Initial expiration is zero but should be set")
|
||||
}
|
||||
if !actualExpiration.Equal(initialExpiration) {
|
||||
t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, initialExpiration)
|
||||
}
|
||||
|
||||
db := new()
|
||||
connProducer := db.SQLConnectionProducer
|
||||
newExpiration := test.newExpiration.Truncate(time.Second)
|
||||
updateReq := newdbplugin.UpdateUserRequest{
|
||||
Username: createResp.Username,
|
||||
Expiration: &newdbplugin.ChangeExpiration{
|
||||
NewExpiration: newExpiration,
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: test.statements,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Give a timeout just in case the test decides to be problematic
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.Init(ctx, connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
_, err := db.UpdateUser(ctx, updateReq)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.Initialized {
|
||||
t.Fatal("Database should be initialized")
|
||||
}
|
||||
|
||||
newConf, err := db.RotateRootCredentials(ctx, test.statements)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if newConf["password"] == "secret" {
|
||||
t.Fatal("password was not updated")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close: %s", err)
|
||||
expectedExpiration := test.expectedExpiration.Truncate(time.Second)
|
||||
actualExpiration = getExpiration(t, db, createResp.Username)
|
||||
if !actualExpiration.Equal(expectedExpiration) {
|
||||
t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, expectedExpiration)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
func getExpiration(t testing.TB, db *PostgreSQL, username string) time.Time {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
query := fmt.Sprintf("select valuntil from pg_catalog.pg_user where usename = '%s'", username)
|
||||
conn, err := db.getConnection(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get connection to database: %s", err)
|
||||
}
|
||||
|
||||
stmt, err := conn.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to prepare statement: %s", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query to get expiration: %s", err)
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
return time.Time{} // No expiration
|
||||
}
|
||||
rawExp := ""
|
||||
err = rows.Scan(&rawExp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to get raw expiration: %s", err)
|
||||
}
|
||||
if rawExp == "" {
|
||||
return time.Time{} // No expiration
|
||||
}
|
||||
exp, err := time.Parse(time.RFC3339, rawExp)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse expiration %q: %s", rawExp, err)
|
||||
}
|
||||
return exp
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
type testCase struct {
|
||||
revokeStmts []string
|
||||
revokeStmts []string
|
||||
expectErr bool
|
||||
credsAssertion credsAssertion
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"empty statements": {
|
||||
"no statements": {
|
||||
revokeStmts: nil,
|
||||
expectErr: false,
|
||||
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
|
||||
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
|
||||
},
|
||||
"explicit default name": {
|
||||
revokeStmts: []string{defaultPostgresRevocationSQL},
|
||||
"statements with name": {
|
||||
revokeStmts: []string{`
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
|
||||
|
||||
DROP ROLE IF EXISTS "{{name}}";`},
|
||||
expectErr: false,
|
||||
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
|
||||
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
|
||||
},
|
||||
"explicit default username": {
|
||||
"statements with username": {
|
||||
revokeStmts: []string{`
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}";
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}";
|
||||
REVOKE USAGE ON SCHEMA public FROM "{{username}}";
|
||||
|
||||
DROP ROLE IF EXISTS "{{username}}";`,
|
||||
},
|
||||
|
||||
DROP ROLE IF EXISTS "{{username}}";`},
|
||||
expectErr: false,
|
||||
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
|
||||
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
|
||||
},
|
||||
"bad statements": {
|
||||
revokeStmts: []string{`8a9yhfoiasjff`},
|
||||
expectErr: true,
|
||||
// Wait for a short time before checking because postgres takes a moment to finish deleting the user
|
||||
credsAssertion: assertCredsExistAfter(100 * time.Millisecond),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -344,155 +537,91 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
statements := dbplugin.Statements{
|
||||
Creation: []string{createAdminUser},
|
||||
Revocation: test.revokeStmts,
|
||||
password := "myreallysecurepassword"
|
||||
createReq := newdbplugin.NewUserRequest{
|
||||
UsernameConfig: newdbplugin.UsernameMetadata{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
},
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: []string{createAdminUser},
|
||||
},
|
||||
Password: password,
|
||||
Expiration: time.Now().Add(2 * time.Second),
|
||||
}
|
||||
createResp := dbtesting.AssertNewUser(t, db, createReq)
|
||||
|
||||
assertCredsExist(t, db.ConnectionURL, createResp.Username, password)
|
||||
|
||||
deleteReq := newdbplugin.DeleteUserRequest{
|
||||
Username: createResp.Username,
|
||||
Statements: newdbplugin.Statements{
|
||||
Commands: test.revokeStmts,
|
||||
},
|
||||
}
|
||||
|
||||
usernameConfig := dbplugin.UsernameConfig{
|
||||
DisplayName: "test",
|
||||
RoleName: "test",
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := db.DeleteUser(ctx, deleteReq)
|
||||
if test.expectErr && err == nil {
|
||||
t.Fatalf("err expected, got nil")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Fatalf("no error expected, got: %s", err)
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statements
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, db.ConnectionURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
test.credsAssertion(t, db.ConnectionURL, createResp.Username, password)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_SetCredentials_missingArgs(t *testing.T) {
|
||||
type testCase struct {
|
||||
statements dbplugin.Statements
|
||||
userConfig dbplugin.StaticUserConfig
|
||||
}
|
||||
type credsAssertion func(t testing.TB, connURL, username, password string)
|
||||
|
||||
tests := map[string]testCase{
|
||||
"empty rotation statements": {
|
||||
statements: dbplugin.Statements{
|
||||
Rotation: nil,
|
||||
},
|
||||
userConfig: dbplugin.StaticUserConfig{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
},
|
||||
},
|
||||
"empty username": {
|
||||
statements: dbplugin.Statements{
|
||||
Rotation: []string{`
|
||||
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`,
|
||||
},
|
||||
},
|
||||
userConfig: dbplugin.StaticUserConfig{
|
||||
Username: "",
|
||||
Password: "password",
|
||||
},
|
||||
},
|
||||
"empty password": {
|
||||
statements: dbplugin.Statements{
|
||||
Rotation: []string{`
|
||||
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`,
|
||||
},
|
||||
},
|
||||
userConfig: dbplugin.StaticUserConfig{
|
||||
Username: "testuser",
|
||||
Password: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db := new()
|
||||
|
||||
username, password, err := db.SetCredentials(context.Background(), test.statements, test.userConfig)
|
||||
if err == nil {
|
||||
t.Fatalf("expected err, got nil")
|
||||
}
|
||||
if username != "" {
|
||||
t.Fatalf("expected empty username, got [%s]", username)
|
||||
}
|
||||
if password != "" {
|
||||
t.Fatalf("expected empty password, got [%s]", password)
|
||||
}
|
||||
})
|
||||
func assertCredsExist(t testing.TB, connURL, username, password string) {
|
||||
t.Helper()
|
||||
err := testCredsExist(t, connURL, username, password)
|
||||
if err != nil {
|
||||
t.Fatalf("user does not exist: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresSQL_SetCredentials(t *testing.T) {
|
||||
type testCase struct {
|
||||
rotationStmts []string
|
||||
func assertCredsDoNotExist(t testing.TB, connURL, username, password string) {
|
||||
t.Helper()
|
||||
err := testCredsExist(t, connURL, username, password)
|
||||
if err == nil {
|
||||
t.Fatalf("user should not exist but does")
|
||||
}
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"name rotation": {
|
||||
rotationStmts: []string{`
|
||||
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`,
|
||||
},
|
||||
},
|
||||
"username rotation": {
|
||||
rotationStmts: []string{`
|
||||
ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';`,
|
||||
},
|
||||
},
|
||||
func waitUntilCredsDoNotExist(timeout time.Duration) credsAssertion {
|
||||
return func(t testing.TB, connURL, username, password string) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(10 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("Timed out waiting for user %s to be deleted", username)
|
||||
case <-ticker.C:
|
||||
err := testCredsExist(t, connURL, username, password)
|
||||
if err != nil {
|
||||
// Happy path
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
db, cleanup := getPostgreSQL(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
// create the database user
|
||||
dbUser := "vaultstatictest"
|
||||
initPassword := "password"
|
||||
createTestPGUser(t, db.ConnectionURL, dbUser, initPassword, testRoleStaticCreate)
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
Rotation: test.rotationStmts,
|
||||
}
|
||||
|
||||
password, err := db.GenerateCredentials(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
usernameConfig := dbplugin.StaticUserConfig{
|
||||
Username: dbUser,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, db.ConnectionURL, dbUser, initPassword); err != nil {
|
||||
t.Fatalf("Could not connect with initial credentials: %s", err)
|
||||
}
|
||||
|
||||
username, password, err := db.SetCredentials(context.Background(), statements, usernameConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, db.ConnectionURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, db.ConnectionURL, username, initPassword); err == nil {
|
||||
t.Fatalf("Should not be able to connect with initial credentials")
|
||||
}
|
||||
})
|
||||
func assertCredsExistAfter(timeout time.Duration) credsAssertion {
|
||||
return func(t testing.TB, connURL, username, password string) {
|
||||
t.Helper()
|
||||
time.Sleep(timeout)
|
||||
assertCredsExist(t, connURL, username, password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -516,7 +645,7 @@ CREATE ROLE "{{name}}" WITH
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
var testPostgresBlockStatementRoleSlice = []string{
|
||||
var newUserLargeBlockStatements = []string{
|
||||
`
|
||||
DO $$
|
||||
BEGIN
|
||||
@@ -539,59 +668,6 @@ $$
|
||||
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
|
||||
}
|
||||
|
||||
const defaultPostgresRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
|
||||
|
||||
DROP ROLE IF EXISTS "{{name}}";
|
||||
`
|
||||
|
||||
const testRoleStaticCreate = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}';
|
||||
`
|
||||
|
||||
// This is a copy of a test helper method also found in
|
||||
// builtin/logical/database/rotation_test.go , and should be moved into a shared
|
||||
// helper file in the future.
|
||||
func createTestPGUser(t *testing.T, connURL string, username, password, query string) {
|
||||
t.Helper()
|
||||
conn, err := pq.ParseURL(connURL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", conn)
|
||||
defer db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
ctx := context.Background()
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
m := map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
}
|
||||
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsMultilineStatement(t *testing.T) {
|
||||
type testCase struct {
|
||||
Input string
|
||||
|
||||
Reference in New Issue
Block a user