mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	Database gRPC plugins (#3666)
* Start work on context aware backends * Start work on moving the database plugins to gRPC in order to pass context * Add context to builtin database plugins * use byte slice instead of string * Context all the things * Move proto messages to the dbplugin package * Add a grpc mechanism for running backend plugins * Serve the GRPC plugin * Add backwards compatibility to the database plugins * Remove backend plugin changes * Remove backend plugin changes * Cleanup the transport implementations * If grpc connection is in an unexpected state restart the plugin * Fix tests * Fix tests * Remove context from the request object, replace it with context.TODO * Add a test to verify netRPC plugins still work * Remove unused mapstructure call * Code review fixes * Code review fixes * Code review fixes
This commit is contained in:
		| @@ -1,6 +1,7 @@ | ||||
| package postgresql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| @@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; | ||||
| ` | ||||
| ) | ||||
|  | ||||
| var _ dbplugin.Database = &PostgreSQL{} | ||||
|  | ||||
| // New implements builtinplugins.BuiltinFactory | ||||
| func New() (interface{}, error) { | ||||
| 	connProducer := &connutil.SQLConnectionProducer{} | ||||
| @@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) { | ||||
| 	return postgreSQLTypeName, nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) getConnection() (*sql.DB, error) { | ||||
| 	db, err := p.Connection() | ||||
| func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { | ||||
| 	db, err := p.Connection(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { | ||||
| 	return db.(*sql.DB), nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	if statements.CreationStatements == "" { | ||||
| 		return "", "", dbutil.ErrEmptyCreationStatement | ||||
| 	} | ||||
| @@ -99,7 +102,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d | ||||
| 	} | ||||
|  | ||||
| 	// Get the connection | ||||
| 	db, err := p.getConnection() | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
|  | ||||
| @@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d | ||||
| 	return username, password, nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| @@ -157,7 +160,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, | ||||
| 		renewStmts = defaultPostgresRenewSQL | ||||
| 	} | ||||
|  | ||||
| 	db, err := p.getConnection() | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -201,20 +204,20 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error { | ||||
| func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	// Grab the lock | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	if statements.RevocationStatements == "" { | ||||
| 		return p.defaultRevokeUser(username) | ||||
| 		return p.defaultRevokeUser(ctx, username) | ||||
| 	} | ||||
|  | ||||
| 	return p.customRevokeUser(username, statements.RevocationStatements) | ||||
| 	return p.customRevokeUser(ctx, username, statements.RevocationStatements) | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { | ||||
| 	db, err := p.getConnection() | ||||
| func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error { | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -253,8 +256,8 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) defaultRevokeUser(username string) error { | ||||
| 	db, err := p.getConnection() | ||||
| func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package postgresql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| @@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { | ||||
|  | ||||
| 	connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { | ||||
| 		"max_open_connections": "5", | ||||
| 	} | ||||
|  | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*PostgreSQL) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test with no configured Creation Statememt | ||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("Expected error when no creation statement is provided") | ||||
| 	} | ||||
| @@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 		CreationStatements: testPostgresRole, | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	statements.CreationStatements = testPostgresReadOnlyRole | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*PostgreSQL) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| 		t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) | ||||
| 	err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| 		t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 	} | ||||
| 	statements.RenewStatements = defaultPostgresRenewSQL | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| 		t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) | ||||
| 	err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*PostgreSQL) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test default revoke statememts | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
| 		t.Fatal("Credentials were not revoked") | ||||
| 	} | ||||
|  | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	// Test custom revoke statements | ||||
| 	statements.RevocationStatements = defaultPostgresRevocationSQL | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Brian Kassouf
					Brian Kassouf