mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-12-23 22:07:06 +00:00
Database gRPC plugins (#3666)
* Start work on context aware backends * Start work on moving the database plugins to gRPC in order to pass context * Add context to builtin database plugins * use byte slice instead of string * Context all the things * Move proto messages to the dbplugin package * Add a grpc mechanism for running backend plugins * Serve the GRPC plugin * Add backwards compatibility to the database plugins * Remove backend plugin changes * Remove backend plugin changes * Cleanup the transport implementations * If grpc connection is in an unexpected state restart the plugin * Fix tests * Fix tests * Remove context from the request object, replace it with context.TODO * Add a test to verify netRPC plugins still work * Remove unused mapstructure call * Code review fixes * Code review fixes * Code review fixes
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
|
||||
`
|
||||
)
|
||||
|
||||
var _ dbplugin.Database = &PostgreSQL{}
|
||||
|
||||
// New implements builtinplugins.BuiltinFactory
|
||||
func New() (interface{}, error) {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
@@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) {
|
||||
return postgreSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := p.Connection()
|
||||
func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
||||
db, err := p.Connection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
@@ -99,7 +102,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
||||
}
|
||||
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
@@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
@@ -157,7 +160,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
||||
renewStmts = defaultPostgresRenewSQL
|
||||
}
|
||||
|
||||
db, err := p.getConnection()
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -201,20 +204,20 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if statements.RevocationStatements == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
return p.defaultRevokeUser(ctx, username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, statements.RevocationStatements)
|
||||
return p.customRevokeUser(ctx, username, statements.RevocationStatements)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
db, err := p.getConnection()
|
||||
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -253,8 +256,8 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||
db, err := p.getConnection()
|
||||
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
|
||||
db, err := p.getConnection(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
||||
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
||||
"max_open_connections": "5",
|
||||
}
|
||||
|
||||
err = db.Initialize(connectionDetails, true)
|
||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
@@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
}
|
||||
|
||||
statements.CreationStatements = testPostgresReadOnlyRole
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
statements.RenewStatements = defaultPostgresRenewSQL
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
|
||||
dbRaw, _ := New()
|
||||
db := dbRaw.(*PostgreSQL)
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
RoleName: "test",
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
@@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = defaultPostgresRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
err = db.RevokeUser(context.Background(), statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user