Support custom renewal statements in Postgres (#2788)

* Support custom renewal statements in Postgres

* Refactored out default/custom renew methods
This commit is contained in:
Andrew Paulin
2017-06-01 16:18:16 -04:00
committed by Brian Kassouf
parent ed9ff085c4
commit d004ad75db
2 changed files with 59 additions and 11 deletions

View File

@@ -16,7 +16,12 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
const postgreSQLTypeName string = "postgres" const (
postgreSQLTypeName string = "postgres"
defaultPostgresRenewSQL = `
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
`
)
// New implements builtinplugins.BuiltinFactory // New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) { func New() (interface{}, error) {
@@ -141,31 +146,52 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix s
} }
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
// Grab the lock
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
renewStmts := statements.RenewStatements
if renewStmts == "" {
renewStmts = defaultPostgresRenewSQL
}
db, err := p.getConnection() db, err := p.getConnection()
if err != nil { if err != nil {
return err return err
} }
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
expirationStr, err := p.GenerateExpiration(expiration) expirationStr, err := p.GenerateExpiration(expiration)
if err != nil { if err != nil {
return err return err
} }
query := fmt.Sprintf( for _, query := range strutil.ParseArbitraryStringSlice(renewStmts, ";") {
"ALTER ROLE %s VALID UNTIL '%s';", query = strings.TrimSpace(query)
pq.QuoteIdentifier(username), if len(query) == 0 {
expirationStr) continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
"name": username,
"expiration": expirationStr,
}))
if err != nil {
return err
}
stmt, err := db.Prepare(query) defer stmt.Close()
if err != nil { if _, err := stmt.Exec(); err != nil {
return err return err
}
} }
defer stmt.Close()
if _, err := stmt.Exec(); err != nil { if err := tx.Commit(); err != nil {
return err return err
} }

View File

@@ -170,6 +170,28 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
if err = testCredsExist(t, connURL, username, password); err != nil { if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
statements.RenewStatements = defaultPostgresRenewSQL
username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Sleep longer than the inital expiration time
time.Sleep(2 * time.Second)
if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
} }
func TestPostgreSQL_RevokeUser(t *testing.T) { func TestPostgreSQL_RevokeUser(t *testing.T) {