DBPW - Update PostgreSQL to adhere to v5 Database interface (#10061)

This commit is contained in:
Michael Golowka
2020-10-07 12:58:11 -06:00
committed by GitHub
parent 9620db17d1
commit 28f90f1efe
2 changed files with 590 additions and 580 deletions

View File

@@ -3,35 +3,36 @@ package postgresql
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/dbplugin"
"github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil" "github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/helper/dbtxn" "github.com/hashicorp/vault/sdk/helper/dbtxn"
"github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/lib/pq" "github.com/lib/pq"
) )
const ( const (
postgreSQLTypeName = "postgres" postgreSQLTypeName = "postgres"
defaultPostgresRenewSQL = ` defaultExpirationStatement = `
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
` `
defaultPostgresRotateRootCredentialsSQL = ` defaultChangePasswordStatement = `
ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
` `
expirationFormat = "2006-01-02 15:04:05-0700"
) )
var ( var (
_ dbplugin.Database = &PostgreSQL{} _ newdbplugin.Database = &PostgreSQL{}
// postgresEndStatement is basically the word "END" but // postgresEndStatement is basically the word "END" but
// surrounded by a word boundary to differentiate it from // surrounded by a word boundary to differentiate it from
@@ -51,7 +52,7 @@ var (
func New() (interface{}, error) { func New() (interface{}, error) {
db := new() db := new()
// Wrap the plugin with middleware to sanitize errors // Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) dbType := newdbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil return dbType, nil
} }
@@ -59,16 +60,8 @@ func new() *PostgreSQL {
connProducer := &connutil.SQLConnectionProducer{} connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = postgreSQLTypeName connProducer.Type = postgreSQLTypeName
credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: 8,
RoleNameLen: 8,
UsernameLen: 63,
Separator: "-",
}
db := &PostgreSQL{ db := &PostgreSQL{
SQLConnectionProducer: connProducer, SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
} }
return db return db
@@ -81,14 +74,24 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err return err
} }
dbplugin.Serve(dbType.(dbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig)) newdbplugin.Serve(dbType.(newdbplugin.Database), api.VaultPluginTLSProvider(apiTLSConfig))
return nil return nil
} }
type PostgreSQL struct { type PostgreSQL struct {
*connutil.SQLConnectionProducer *connutil.SQLConnectionProducer
credsutil.CredentialsProducer }
func (p *PostgreSQL) Initialize(ctx context.Context, req newdbplugin.InitializeRequest) (newdbplugin.InitializeResponse, error) {
newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return newdbplugin.InitializeResponse{}, err
}
resp := newdbplugin.InitializeResponse{
Config: newConf,
}
return resp, nil
} }
func (p *PostgreSQL) Type() (string, error) { func (p *PostgreSQL) Type() (string, error) {
@@ -104,54 +107,58 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
return db.(*sql.DB), nil return db.(*sql.DB), nil
} }
// SetCredentials uses provided information to set/create a user in the func (p *PostgreSQL) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
// database. Unlike CreateUser, this method requires a username be provided and if req.Username == "" {
// uses the name given, instead of generating a name. This is used for creating return newdbplugin.UpdateUserResponse{}, fmt.Errorf("missing username")
// and setting the password of static accounts, as well as rolling back }
// passwords in the database in the event an updated database fails to save in if req.Password == nil && req.Expiration == nil {
// Vault's storage. return newdbplugin.UpdateUserResponse{}, fmt.Errorf("no changes requested")
func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
if len(statements.Rotation) == 0 {
statements.Rotation = []string{defaultPostgresRotateRootCredentialsSQL}
} }
username = staticUser.Username merr := &multierror.Error{}
password = staticUser.Password if req.Password != nil {
if username == "" || password == "" { err := p.changeUserPassword(ctx, req.Username, req.Password)
return "", "", errors.New("must provide both username and password") merr = multierror.Append(merr, err)
}
if req.Expiration != nil {
err := p.changeUserExpiration(ctx, req.Username, req.Expiration)
merr = multierror.Append(merr, err)
}
return newdbplugin.UpdateUserResponse{}, merr.ErrorOrNil()
}
func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error {
stmts := changePass.Statements.Commands
if len(stmts) == 0 {
stmts = []string{defaultChangePasswordStatement}
}
password := changePass.NewPassword
if password == "" {
return fmt.Errorf("missing password")
} }
// Grab the lock
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
// Get the connection
db, err := p.getConnection(ctx) db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return fmt.Errorf("unable to get connection: %w", err)
} }
// Check if the role exists // Check if the role exists
var exists bool var exists bool
err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return "", "", err return fmt.Errorf("user does not appear to exist: %w", err)
} }
// Vault requires the database user already exist, and that the credentials
// used to execute the rotation statements has sufficient privileges.
stmts := statements.Rotation
// Start a transaction
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return fmt.Errorf("unable to start transaction: %w", err)
} }
defer func() { defer tx.Rollback()
_ = tx.Rollback()
}()
// Execute each query
for _, stmt := range stmts { for _, stmt := range stmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
@@ -160,117 +167,30 @@ func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Sta
} }
m := map[string]string{ m := map[string]string{
"name": staticUser.Username, "name": username,
"username": staticUser.Username, "username": username,
"password": password, "password": password,
} }
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err return fmt.Errorf("failed to execute query: %w", err)
} }
} }
} }
// Commit the transaction
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return "", "", err return err
} }
return username, password, nil return nil
} }
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, changeExp *newdbplugin.ChangeExpiration) error {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
}
// Grab the lock
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
username, err = p.GenerateUsername(usernameConfig) renewStmts := changeExp.Statements.Commands
if err != nil {
return "", "", err
}
password, err = p.GeneratePassword()
if err != nil {
return "", "", err
}
expirationStr, err := p.GenerateExpiration(expiration)
if err != nil {
return "", "", err
}
// Get the connection
db, err := p.getConnection(ctx)
if err != nil {
return "", "", err
}
// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
}
defer func() {
tx.Rollback()
}()
// Execute each query
for _, stmt := range statements.Creation {
if containsMultilineStatement(stmt) {
// Execute it as-is.
m := map[string]string{
"name": username,
"username": username,
"password": password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil {
return "", "", err
}
continue
}
// Otherwise, it's fine to split the statements on the semicolon.
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
"username": username,
"password": password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return "", "", err
}
return username, password, nil
}
func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
p.Lock()
defer p.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
renewStmts := statements.Renewal
if len(renewStmts) == 0 { if len(renewStmts) == 0 {
renewStmts = []string{defaultPostgresRenewSQL} renewStmts = []string{defaultExpirationStatement}
} }
db, err := p.getConnection(ctx) db, err := p.getConnection(ctx)
@@ -286,10 +206,7 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
tx.Rollback() tx.Rollback()
}() }()
expirationStr, err := p.GenerateExpiration(expiration) expirationStr := changeExp.NewExpiration.Format(expirationFormat)
if err != nil {
return err
}
for _, stmt := range renewStmts { for _, stmt := range renewStmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
@@ -312,21 +229,93 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
return tx.Commit() return tx.Commit()
} }
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { func (p *PostgreSQL) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
// Grab the lock if len(req.Statements.Commands) == 0 {
return newdbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
}
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements) username, err := credsutil.GenerateUsername(
credsutil.DisplayName(req.UsernameConfig.DisplayName, 8),
if len(statements.Revocation) == 0 { credsutil.RoleName(req.UsernameConfig.RoleName, 8),
return p.defaultRevokeUser(ctx, username) credsutil.Separator("-"),
credsutil.MaxLength(63),
)
if err != nil {
return newdbplugin.NewUserResponse{}, err
} }
return p.customRevokeUser(ctx, username, statements.Revocation) expirationStr := req.Expiration.Format(expirationFormat)
db, err := p.getConnection(ctx)
if err != nil {
return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to get connection: %w", err)
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return newdbplugin.NewUserResponse{}, fmt.Errorf("unable to start transaction: %w", err)
}
defer tx.Rollback()
for _, stmt := range req.Statements.Commands {
if containsMultilineStatement(stmt) {
// Execute it as-is.
m := map[string]string{
"name": username,
"username": username,
"password": req.Password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil {
return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err)
}
continue
}
// Otherwise, it's fine to split the statements on the semicolon.
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
"username": username,
"password": req.Password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return newdbplugin.NewUserResponse{}, fmt.Errorf("failed to execute query: %w", err)
}
}
}
if err := tx.Commit(); err != nil {
return newdbplugin.NewUserResponse{}, err
}
resp := newdbplugin.NewUserResponse{
Username: username,
}
return resp, nil
} }
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { func (p *PostgreSQL) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
p.Lock()
defer p.Unlock()
if len(req.Statements.Commands) == 0 {
return newdbplugin.DeleteUserResponse{}, p.defaultDeleteUser(ctx, req.Username)
}
return newdbplugin.DeleteUserResponse{}, p.customDeleteUser(ctx, req.Username, req.Statements.Commands)
}
func (p *PostgreSQL) customDeleteUser(ctx context.Context, username string, revocationStmts []string) error {
db, err := p.getConnection(ctx) db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return err return err
@@ -360,7 +349,7 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revo
return tx.Commit() return tx.Commit()
} }
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { func (p *PostgreSQL) defaultDeleteUser(ctx context.Context, username string) error {
db, err := p.getConnection(ctx) db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return err return err
@@ -471,65 +460,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
return nil return nil
} }
func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { func (p *PostgreSQL) secretValues() map[string]string {
p.Lock() return map[string]string{
defer p.Unlock() p.Password: "[password]",
if len(p.Username) == 0 || len(p.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
} }
rotateStatements := statements
if len(rotateStatements) == 0 {
rotateStatements = []string{defaultPostgresRotateRootCredentialsSQL}
}
db, err := p.getConnection(ctx)
if err != nil {
return nil, err
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := p.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatements {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": p.Username,
"username": p.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
// Close the database connection to ensure no new connections come in
if err := db.Close(); err != nil {
return nil, err
}
p.RawConfig["password"] = password
return p.RawConfig, nil
} }
// containsMultilineStatement is a best effort to determine whether // containsMultilineStatement is a best effort to determine whether

View File

@@ -9,9 +9,8 @@ import (
"time" "time"
"github.com/hashicorp/vault/helper/testhelpers/postgresql" "github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/database/dbplugin" "github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/helper/dbtxn" dbtesting "github.com/hashicorp/vault/sdk/database/newdbplugin/testing"
"github.com/lib/pq"
) )
func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) { func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, func()) {
@@ -24,12 +23,14 @@ func getPostgreSQL(t *testing.T, options map[string]interface{}) (*PostgreSQL, f
connectionDetails[k] = v connectionDetails[k] = v
} }
db := new() req := newdbplugin.InitializeRequest{
_, err := db.Init(context.Background(), connectionDetails, true) Config: connectionDetails,
if err != nil { VerifyConnection: true,
t.Fatalf("err: %s", err)
} }
db := new()
dbtesting.AssertInitialize(t, db, req)
if !db.Initialized { if !db.Initialized {
t.Fatal("Database should be initialized") t.Fatal("Database should be initialized")
} }
@@ -58,90 +59,163 @@ func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
} }
} }
func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) { func TestPostgreSQL_NewUser(t *testing.T) {
db := new()
usernameConfig := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
username, password, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
if err == nil {
t.Fatalf("expected err, got nil")
}
if username != "" {
t.Fatalf("expected empty username, got [%s]", username)
}
if password != "" {
t.Fatalf("expected empty password, got [%s]", password)
}
}
func TestPostgreSQL_CreateUser(t *testing.T) {
type testCase struct { type testCase struct {
createStmts []string req newdbplugin.NewUserRequest
shouldTestCredsExist bool expectErr bool
credsAssertion credsAssertion
} }
tests := map[string]testCase{ tests := map[string]testCase{
"admin name": { "no creation statements": {
createStmts: []string{` req: newdbplugin.NewUserRequest{
CREATE ROLE "{{name}}" WITH UsernameConfig: newdbplugin.UsernameMetadata{
LOGIN DisplayName: "test",
PASSWORD '{{password}}' RoleName: "test",
VALID UNTIL '{{expiration}}'; },
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`, // No statements
Password: "somesecurepassword",
Expiration: time.Now().Add(1 * time.Minute),
}, },
shouldTestCredsExist: true, expectErr: true,
credsAssertion: assertCredsDoNotExist,
},
"admin name": {
req: newdbplugin.NewUserRequest{
UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{
Commands: []string{`
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
},
},
Password: "somesecurepassword",
Expiration: time.Now().Add(1 * time.Minute),
},
expectErr: false,
credsAssertion: assertCredsExist,
}, },
"admin username": { "admin username": {
createStmts: []string{` req: newdbplugin.NewUserRequest{
CREATE ROLE "{{username}}" WITH UsernameConfig: newdbplugin.UsernameMetadata{
LOGIN DisplayName: "test",
PASSWORD '{{password}}' RoleName: "test",
VALID UNTIL '{{expiration}}'; },
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`, Statements: newdbplugin.Statements{
Commands: []string{`
CREATE ROLE "{{username}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`,
},
},
Password: "somesecurepassword",
Expiration: time.Now().Add(1 * time.Minute),
}, },
shouldTestCredsExist: true, expectErr: false,
credsAssertion: assertCredsExist,
}, },
"read only name": { "read only name": {
createStmts: []string{` req: newdbplugin.NewUserRequest{
CREATE ROLE "{{name}}" WITH UsernameConfig: newdbplugin.UsernameMetadata{
LOGIN DisplayName: "test",
PASSWORD '{{password}}' RoleName: "test",
VALID UNTIL '{{expiration}}'; },
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; Statements: newdbplugin.Statements{
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`, Commands: []string{`
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`,
},
},
Password: "somesecurepassword",
Expiration: time.Now().Add(1 * time.Minute),
}, },
shouldTestCredsExist: true, expectErr: false,
credsAssertion: assertCredsExist,
}, },
"read only username": { "read only username": {
createStmts: []string{` req: newdbplugin.NewUserRequest{
CREATE ROLE "{{username}}" WITH UsernameConfig: newdbplugin.UsernameMetadata{
LOGIN DisplayName: "test",
PASSWORD '{{password}}' RoleName: "test",
VALID UNTIL '{{expiration}}'; },
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}"; Statements: newdbplugin.Statements{
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`, Commands: []string{`
CREATE ROLE "{{username}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
},
},
Password: "somesecurepassword",
Expiration: time.Now().Add(1 * time.Minute),
}, },
shouldTestCredsExist: true, expectErr: false,
credsAssertion: assertCredsExist,
}, },
// https://github.com/hashicorp/vault/issues/6098 // https://github.com/hashicorp/vault/issues/6098
"reproduce GH-6098": { "reproduce GH-6098": {
createStmts: []string{ req: newdbplugin.NewUserRequest{
// NOTE: "rolname" in the following line is not a typo. UsernameConfig: newdbplugin.UsernameMetadata{
"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$", DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{
Commands: []string{
// NOTE: "rolname" in the following line is not a typo.
"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
},
},
Password: "somesecurepassword",
Expiration: time.Now().Add(1 * time.Minute),
}, },
// This test statement doesn't generate creds. expectErr: false,
shouldTestCredsExist: false, credsAssertion: assertCredsDoNotExist,
}, },
"reproduce issue with template": { "reproduce issue with template": {
createStmts: []string{ req: newdbplugin.NewUserRequest{
`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`, UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{
Commands: []string{
`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`,
},
},
Password: "somesecurepassword",
Expiration: time.Now().Add(1 * time.Minute),
}, },
// This test statement doesn't generate creds. expectErr: false,
shouldTestCredsExist: false, credsAssertion: assertCredsDoNotExist,
},
"large block statements": {
req: newdbplugin.NewUserRequest{
UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{
Commands: newUserLargeBlockStatements,
},
Password: "somesecurepassword",
Expiration: time.Now().Add(1 * time.Minute),
},
expectErr: false,
credsAssertion: assertCredsExist,
}, },
} }
@@ -151,59 +225,55 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
usernameConfig := dbplugin.UsernameConfig{
DisplayName: "test",
RoleName: "test",
}
statements := dbplugin.Statements{
Creation: test.createStmts,
}
// Give a timeout just in case the test decides to be problematic // Give a timeout just in case the test decides to be problematic
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(time.Minute)) resp, err := db.NewUser(ctx, test.req)
if err != nil { if test.expectErr && err == nil {
t.Fatalf("err: %s", err) t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
} }
if !test.shouldTestCredsExist { test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password)
// We're done here.
return
}
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// Ensure that the role doesn't expire immediately // Ensure that the role doesn't expire immediately
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { test.credsAssertion(t, db.ConnectionURL, resp.Username, test.req.Password)
t.Fatalf("Could not connect with new credentials: %s", err)
}
}) })
} }
} }
func TestPostgreSQL_RenewUser(t *testing.T) { func TestUpdateUser_Password(t *testing.T) {
type testCase struct { type testCase struct {
renewalStmts []string statements []string
expectErr bool
credsAssertion credsAssertion
} }
tests := map[string]testCase{ tests := map[string]testCase{
"empty renewal statements": { "default statements": {
renewalStmts: nil, statements: nil,
expectErr: false,
credsAssertion: assertCredsExist,
}, },
"default renewal name": { "explicit default statements": {
renewalStmts: []string{defaultPostgresRenewSQL}, statements: []string{defaultChangePasswordStatement},
expectErr: false,
credsAssertion: assertCredsExist,
}, },
"default renewal username": { "name instead of username": {
renewalStmts: []string{` statements: []string{`ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`},
ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`, expectErr: false,
}, credsAssertion: assertCredsExist,
},
"bad statements": {
statements: []string{`asdofyas8uf77asoiajv`},
expectErr: true,
credsAssertion: assertCredsDoNotExist,
}, },
} }
@@ -213,128 +283,251 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
statements := dbplugin.Statements{ initialPass := "myreallysecurepassword"
Creation: []string{createAdminUser}, createReq := newdbplugin.NewUserRequest{
Renewal: test.renewalStmts, UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{
Commands: []string{createAdminUser},
},
Password: initialPass,
Expiration: time.Now().Add(2 * time.Second),
}
createResp := dbtesting.AssertNewUser(t, db, createReq)
assertCredsExist(t, db.ConnectionURL, createResp.Username, initialPass)
newPass := "somenewpassword"
updateReq := newdbplugin.UpdateUserRequest{
Username: createResp.Username,
Password: &newdbplugin.ChangePassword{
NewPassword: newPass,
Statements: newdbplugin.Statements{
Commands: test.statements,
},
},
} }
usernameConfig := dbplugin.UsernameConfig{ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
DisplayName: "test",
RoleName: "test",
}
// Give a timeout just in case the test decides to be problematic
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
_, err := db.UpdateUser(ctx, updateReq)
username, password, err := db.CreateUser(ctx, statements, usernameConfig, time.Now().Add(2*time.Second)) if test.expectErr && err == nil {
if err != nil { t.Fatalf("err expected, got nil")
t.Fatalf("err: %s", err) }
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
} }
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil { test.credsAssertion(t, db.ConnectionURL, createResp.Username, newPass)
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(ctx, statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Sleep longer than the initial expiration time
time.Sleep(2 * time.Second)
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
}) })
} }
t.Run("user does not exist", func(t *testing.T) {
newPass := "somenewpassword"
updateReq := newdbplugin.UpdateUserRequest{
Username: "missing-user",
Password: &newdbplugin.ChangePassword{
NewPassword: newPass,
Statements: newdbplugin.Statements{},
},
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := db.UpdateUser(ctx, updateReq)
if err == nil {
t.Fatalf("err expected, got nil")
}
assertCredsDoNotExist(t, db.ConnectionURL, updateReq.Username, newPass)
})
} }
func TestPostgreSQL_RotateRootCredentials(t *testing.T) { func TestUpdateUser_Expiration(t *testing.T) {
type testCase struct { type testCase struct {
statements []string initialExpiration time.Time
newExpiration time.Time
expectedExpiration time.Time
statements []string
expectErr bool
} }
now := time.Now()
tests := map[string]testCase{ tests := map[string]testCase{
"empty statements": { "no statements": {
statements: nil, initialExpiration: now.Add(1 * time.Minute),
newExpiration: now.Add(5 * time.Minute),
expectedExpiration: now.Add(5 * time.Minute),
statements: nil,
expectErr: false,
}, },
"default name": { "default statements with name": {
statements: []string{` initialExpiration: now.Add(1 * time.Minute),
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, newExpiration: now.Add(5 * time.Minute),
}, expectedExpiration: now.Add(5 * time.Minute),
statements: []string{defaultExpirationStatement},
expectErr: false,
}, },
"default username": { "default statements with username": {
statements: []string{defaultPostgresRotateRootCredentialsSQL}, initialExpiration: now.Add(1 * time.Minute),
newExpiration: now.Add(5 * time.Minute),
expectedExpiration: now.Add(5 * time.Minute),
statements: []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`},
expectErr: false,
},
"bad statements": {
initialExpiration: now.Add(1 * time.Minute),
newExpiration: now.Add(5 * time.Minute),
expectedExpiration: now.Add(1 * time.Minute),
statements: []string{"ladshfouay09sgj"},
expectErr: true,
}, },
} }
// Shared test container for speed - there should not be any overlap between the tests
db, cleanup := getPostgreSQL(t, nil)
defer cleanup()
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
cleanup, connURL := postgresql.PrepareTestContainer(t, "latest") password := "myreallysecurepassword"
defer cleanup() initialExpiration := test.initialExpiration.Truncate(time.Second)
connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1) createReq := newdbplugin.NewUserRequest{
UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{
Commands: []string{createAdminUser},
},
Password: password,
Expiration: initialExpiration,
}
createResp := dbtesting.AssertNewUser(t, db, createReq)
connectionDetails := map[string]interface{}{ assertCredsExist(t, db.ConnectionURL, createResp.Username, password)
"connection_url": connURL,
"max_open_connections": 5, actualExpiration := getExpiration(t, db, createResp.Username)
"username": "postgres", if actualExpiration.IsZero() {
"password": "secret", t.Fatalf("Initial expiration is zero but should be set")
}
if !actualExpiration.Equal(initialExpiration) {
t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, initialExpiration)
} }
db := new() newExpiration := test.newExpiration.Truncate(time.Second)
connProducer := db.SQLConnectionProducer updateReq := newdbplugin.UpdateUserRequest{
Username: createResp.Username,
Expiration: &newdbplugin.ChangeExpiration{
NewExpiration: newExpiration,
Statements: newdbplugin.Statements{
Commands: test.statements,
},
},
}
// Give a timeout just in case the test decides to be problematic ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
_, err := db.UpdateUser(ctx, updateReq)
_, err := db.Init(ctx, connectionDetails, true) if test.expectErr && err == nil {
if err != nil { t.Fatalf("err expected, got nil")
t.Fatalf("err: %s", err) }
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
} }
if !connProducer.Initialized { expectedExpiration := test.expectedExpiration.Truncate(time.Second)
t.Fatal("Database should be initialized") actualExpiration = getExpiration(t, db, createResp.Username)
} if !actualExpiration.Equal(expectedExpiration) {
t.Fatalf("Actual expiration: %s Expected expiration: %s", actualExpiration, expectedExpiration)
newConf, err := db.RotateRootCredentials(ctx, test.statements)
if err != nil {
t.Fatalf("err: %v", err)
}
if newConf["password"] == "secret" {
t.Fatal("password was not updated")
}
err = db.Close()
if err != nil {
t.Fatalf("failed to close: %s", err)
} }
}) })
} }
} }
func TestPostgreSQL_RevokeUser(t *testing.T) { func getExpiration(t testing.TB, db *PostgreSQL, username string) time.Time {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
query := fmt.Sprintf("select valuntil from pg_catalog.pg_user where usename = '%s'", username)
conn, err := db.getConnection(ctx)
if err != nil {
t.Fatalf("Failed to get connection to database: %s", err)
}
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
t.Fatalf("Failed to prepare statement: %s", err)
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx)
if err != nil {
t.Fatalf("Failed to execute query to get expiration: %s", err)
}
if !rows.Next() {
return time.Time{} // No expiration
}
rawExp := ""
err = rows.Scan(&rawExp)
if err != nil {
t.Fatalf("Unable to get raw expiration: %s", err)
}
if rawExp == "" {
return time.Time{} // No expiration
}
exp, err := time.Parse(time.RFC3339, rawExp)
if err != nil {
t.Fatalf("Failed to parse expiration %q: %s", rawExp, err)
}
return exp
}
func TestDeleteUser(t *testing.T) {
type testCase struct { type testCase struct {
revokeStmts []string revokeStmts []string
expectErr bool
credsAssertion credsAssertion
} }
tests := map[string]testCase{ tests := map[string]testCase{
"empty statements": { "no statements": {
revokeStmts: nil, revokeStmts: nil,
expectErr: false,
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
}, },
"explicit default name": { "statements with name": {
revokeStmts: []string{defaultPostgresRevocationSQL}, revokeStmts: []string{`
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
DROP ROLE IF EXISTS "{{name}}";`},
expectErr: false,
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
}, },
"explicit default username": { "statements with username": {
revokeStmts: []string{` revokeStmts: []string{`
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}"; REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{username}}";
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}"; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{username}}";
REVOKE USAGE ON SCHEMA public FROM "{{username}}"; REVOKE USAGE ON SCHEMA public FROM "{{username}}";
DROP ROLE IF EXISTS "{{username}}";`, DROP ROLE IF EXISTS "{{username}}";`},
}, expectErr: false,
// Wait for a short time before failing because postgres takes a moment to finish deleting the user
credsAssertion: waitUntilCredsDoNotExist(2 * time.Second),
},
"bad statements": {
revokeStmts: []string{`8a9yhfoiasjff`},
expectErr: true,
// Wait for a short time before checking because postgres takes a moment to finish deleting the user
credsAssertion: assertCredsExistAfter(100 * time.Millisecond),
}, },
} }
@@ -344,155 +537,91 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
statements := dbplugin.Statements{ password := "myreallysecurepassword"
Creation: []string{createAdminUser}, createReq := newdbplugin.NewUserRequest{
Revocation: test.revokeStmts, UsernameConfig: newdbplugin.UsernameMetadata{
DisplayName: "test",
RoleName: "test",
},
Statements: newdbplugin.Statements{
Commands: []string{createAdminUser},
},
Password: password,
Expiration: time.Now().Add(2 * time.Second),
}
createResp := dbtesting.AssertNewUser(t, db, createReq)
assertCredsExist(t, db.ConnectionURL, createResp.Username, password)
deleteReq := newdbplugin.DeleteUserRequest{
Username: createResp.Username,
Statements: newdbplugin.Statements{
Commands: test.revokeStmts,
},
} }
usernameConfig := dbplugin.UsernameConfig{ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
DisplayName: "test", defer cancel()
RoleName: "test",
_, err := db.DeleteUser(ctx, deleteReq)
if test.expectErr && err == nil {
t.Fatalf("err expected, got nil")
}
if !test.expectErr && err != nil {
t.Fatalf("no error expected, got: %s", err)
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) test.credsAssertion(t, db.ConnectionURL, createResp.Username, password)
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, db.ConnectionURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// Test default revoke statements
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, db.ConnectionURL, username, password); err == nil {
t.Fatal("Credentials were not revoked")
}
}) })
} }
} }
func TestPostgreSQL_SetCredentials_missingArgs(t *testing.T) { type credsAssertion func(t testing.TB, connURL, username, password string)
type testCase struct {
statements dbplugin.Statements
userConfig dbplugin.StaticUserConfig
}
tests := map[string]testCase{ func assertCredsExist(t testing.TB, connURL, username, password string) {
"empty rotation statements": { t.Helper()
statements: dbplugin.Statements{ err := testCredsExist(t, connURL, username, password)
Rotation: nil, if err != nil {
}, t.Fatalf("user does not exist: %s", err)
userConfig: dbplugin.StaticUserConfig{
Username: "testuser",
Password: "password",
},
},
"empty username": {
statements: dbplugin.Statements{
Rotation: []string{`
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`,
},
},
userConfig: dbplugin.StaticUserConfig{
Username: "",
Password: "password",
},
},
"empty password": {
statements: dbplugin.Statements{
Rotation: []string{`
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`,
},
},
userConfig: dbplugin.StaticUserConfig{
Username: "testuser",
Password: "",
},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
db := new()
username, password, err := db.SetCredentials(context.Background(), test.statements, test.userConfig)
if err == nil {
t.Fatalf("expected err, got nil")
}
if username != "" {
t.Fatalf("expected empty username, got [%s]", username)
}
if password != "" {
t.Fatalf("expected empty password, got [%s]", password)
}
})
} }
} }
func TestPostgresSQL_SetCredentials(t *testing.T) { func assertCredsDoNotExist(t testing.TB, connURL, username, password string) {
type testCase struct { t.Helper()
rotationStmts []string err := testCredsExist(t, connURL, username, password)
if err == nil {
t.Fatalf("user should not exist but does")
} }
}
tests := map[string]testCase{ func waitUntilCredsDoNotExist(timeout time.Duration) credsAssertion {
"name rotation": { return func(t testing.TB, connURL, username, password string) {
rotationStmts: []string{` t.Helper()
ALTER ROLE "{{name}}" WITH PASSWORD '{{password}}';`, ctx, cancel := context.WithTimeout(context.Background(), timeout)
}, defer cancel()
},
"username rotation": { ticker := time.NewTicker(10 * time.Millisecond)
rotationStmts: []string{` defer ticker.Stop()
ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';`, for {
}, select {
}, case <-ctx.Done():
t.Fatalf("Timed out waiting for user %s to be deleted", username)
case <-ticker.C:
err := testCredsExist(t, connURL, username, password)
if err != nil {
// Happy path
return
}
}
}
} }
}
for name, test := range tests { func assertCredsExistAfter(timeout time.Duration) credsAssertion {
t.Run(name, func(t *testing.T) { return func(t testing.TB, connURL, username, password string) {
db, cleanup := getPostgreSQL(t, nil) t.Helper()
defer cleanup() time.Sleep(timeout)
assertCredsExist(t, connURL, username, password)
// create the database user
dbUser := "vaultstatictest"
initPassword := "password"
createTestPGUser(t, db.ConnectionURL, dbUser, initPassword, testRoleStaticCreate)
statements := dbplugin.Statements{
Rotation: test.rotationStmts,
}
password, err := db.GenerateCredentials(context.Background())
if err != nil {
t.Fatal(err)
}
usernameConfig := dbplugin.StaticUserConfig{
Username: dbUser,
Password: password,
}
if err := testCredsExist(t, db.ConnectionURL, dbUser, initPassword); err != nil {
t.Fatalf("Could not connect with initial credentials: %s", err)
}
username, password, err := db.SetCredentials(context.Background(), statements, usernameConfig)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, db.ConnectionURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
if err := testCredsExist(t, db.ConnectionURL, username, initPassword); err == nil {
t.Fatalf("Should not be able to connect with initial credentials")
}
})
} }
} }
@@ -516,7 +645,7 @@ CREATE ROLE "{{name}}" WITH
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
` `
var testPostgresBlockStatementRoleSlice = []string{ var newUserLargeBlockStatements = []string{
` `
DO $$ DO $$
BEGIN BEGIN
@@ -539,59 +668,6 @@ $$
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
} }
const defaultPostgresRevocationSQL = `
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
DROP ROLE IF EXISTS "{{name}}";
`
const testRoleStaticCreate = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}';
`
// This is a copy of a test helper method also found in
// builtin/logical/database/rotation_test.go , and should be moved into a shared
// helper file in the future.
func createTestPGUser(t *testing.T, connURL string, username, password, query string) {
t.Helper()
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
db, err := sql.Open("postgres", conn)
defer db.Close()
if err != nil {
t.Fatal(err)
}
// Start a transaction
ctx := context.Background()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = tx.Rollback()
}()
m := map[string]string{
"name": username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
t.Fatal(err)
}
// Commit the transaction
if err := tx.Commit(); err != nil {
t.Fatal(err)
}
}
func TestContainsMultilineStatement(t *testing.T) { func TestContainsMultilineStatement(t *testing.T) {
type testCase struct { type testCase struct {
Input string Input string