mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Support custom renewal statements in Postgres (#2788)
* Support custom renewal statements in Postgres * Refactored out default/custom renew methods
This commit is contained in:
committed by
Brian Kassouf
parent
ed9ff085c4
commit
d004ad75db
@@ -16,7 +16,12 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const postgreSQLTypeName string = "postgres"
|
||||
const (
|
||||
postgreSQLTypeName string = "postgres"
|
||||
defaultPostgresRenewSQL = `
|
||||
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
|
||||
`
|
||||
)
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New() (interface{}, error) {
|
||||
@@ -141,33 +146,54 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix s
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
renewStmts := statements.RenewStatements
|
||||
if renewStmts == "" {
|
||||
renewStmts = defaultPostgresRenewSQL
|
||||
}
|
||||
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
expirationStr, err := p.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"ALTER ROLE %s VALID UNTIL '%s';",
|
||||
pq.QuoteIdentifier(username),
|
||||
expirationStr)
|
||||
|
||||
stmt, err := db.Prepare(query)
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(renewStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"expiration": expirationStr,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -170,6 +170,28 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
statements.RenewStatements = defaultPostgresRenewSQL
|
||||
username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Sleep longer than the inital expiration time
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user