mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	Database gRPC plugins (#3666)
* Start work on context aware backends * Start work on moving the database plugins to gRPC in order to pass context * Add context to builtin database plugins * use byte slice instead of string * Context all the things * Move proto messages to the dbplugin package * Add a grpc mechanism for running backend plugins * Serve the GRPC plugin * Add backwards compatibility to the database plugins * Remove backend plugin changes * Remove backend plugin changes * Cleanup the transport implementations * If grpc connection is in an unexpected state restart the plugin * Fix tests * Fix tests * Remove context from the request object, replace it with context.TODO * Add a test to verify netRPC plugins still work * Remove unused mapstructure call * Code review fixes * Code review fixes * Code review fixes
This commit is contained in:
		
							
								
								
									
										1
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								Makefile
									
									
									
									
									
								
							| @@ -84,6 +84,7 @@ proto: | ||||
| 	protoc -I helper/forwarding -I vault -I ../../.. helper/forwarding/types.proto --go_out=plugins=grpc:helper/forwarding | ||||
| 	protoc -I physical physical/types.proto --go_out=plugins=grpc:physical | ||||
| 	protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity | ||||
| 	protoc  builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:. | ||||
| 	sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go | ||||
| 	sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package database | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/rpc" | ||||
| 	"strings" | ||||
| @@ -87,7 +88,7 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { | ||||
| // This function creates a new db object from the stored configuration and | ||||
| // caches it in the connections map. The caller of this function needs to hold | ||||
| // the backend's write lock | ||||
| func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { | ||||
| func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) { | ||||
| 	db, ok := b.connections[name] | ||||
| 	if ok { | ||||
| 		return db, nil | ||||
| @@ -103,7 +104,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin. | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = db.Initialize(config.ConnectionDetails, true) | ||||
| 	err = db.Initialize(ctx, config.ConnectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -170,7 +171,8 @@ func (b *databaseBackend) clearConnection(name string) { | ||||
|  | ||||
| func (b *databaseBackend) closeIfShutdown(name string, err error) { | ||||
| 	// Plugin has shutdown, close it so next call can reconnect. | ||||
| 	if err == rpc.ErrShutdown { | ||||
| 	switch err { | ||||
| 	case rpc.ErrShutdown, dbplugin.ErrPluginShutdown: | ||||
| 		b.Lock() | ||||
| 		b.clearConnection(name) | ||||
| 		b.Unlock() | ||||
|   | ||||
| @@ -488,9 +488,11 @@ func TestBackend_roleCrud(t *testing.T) { | ||||
| 		RevocationStatements: defaultRevocationSQL, | ||||
| 	} | ||||
|  | ||||
| 	var actual dbplugin.Statements | ||||
| 	if err := mapstructure.Decode(resp.Data, &actual); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	actual := dbplugin.Statements{ | ||||
| 		CreationStatements:   resp.Data["creation_statements"].(string), | ||||
| 		RevocationStatements: resp.Data["revocation_statements"].(string), | ||||
| 		RollbackStatements:   resp.Data["rollback_statements"].(string), | ||||
| 		RenewStatements:      resp.Data["renew_statements"].(string), | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(expected, actual) { | ||||
|   | ||||
| @@ -1,10 +1,8 @@ | ||||
| package dbplugin | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/rpc" | ||||
| 	"errors" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/go-plugin" | ||||
| 	"github.com/hashicorp/vault/helper/pluginutil" | ||||
| @@ -17,11 +15,11 @@ type DatabasePluginClient struct { | ||||
| 	client *plugin.Client | ||||
| 	sync.Mutex | ||||
|  | ||||
| 	*databasePluginRPCClient | ||||
| 	Database | ||||
| } | ||||
|  | ||||
| func (dc *DatabasePluginClient) Close() error { | ||||
| 	err := dc.databasePluginRPCClient.Close() | ||||
| 	err := dc.Database.Close() | ||||
| 	dc.client.Kill() | ||||
|  | ||||
| 	return err | ||||
| @@ -55,79 +53,20 @@ func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginR | ||||
|  | ||||
| 	// We should have a database type now. This feels like a normal interface | ||||
| 	// implementation but is in fact over an RPC connection. | ||||
| 	databaseRPC := raw.(*databasePluginRPCClient) | ||||
| 	var db Database | ||||
| 	switch raw.(type) { | ||||
| 	case *gRPCClient: | ||||
| 		db = raw.(*gRPCClient) | ||||
| 	case *databasePluginRPCClient: | ||||
| 		logger.Warn("database: plugin is using deprecated net RPC transport, recompile plugin to upgrade to gRPC", "plugin", pluginRunner.Name) | ||||
| 		db = raw.(*databasePluginRPCClient) | ||||
| 	default: | ||||
| 		return nil, errors.New("unsupported client type") | ||||
| 	} | ||||
|  | ||||
| 	// Wrap RPC implimentation in DatabasePluginClient | ||||
| 	return &DatabasePluginClient{ | ||||
| 		client:   client, | ||||
| 		databasePluginRPCClient: databaseRPC, | ||||
| 		Database: db, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // ---- RPC client domain ---- | ||||
|  | ||||
| // databasePluginRPCClient implements Database and is used on the client to | ||||
| // make RPC calls to a plugin. | ||||
| type databasePluginRPCClient struct { | ||||
| 	client *rpc.Client | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) Type() (string, error) { | ||||
| 	var dbType string | ||||
| 	err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) | ||||
|  | ||||
| 	return fmt.Sprintf("plugin-%s", dbType), err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	req := CreateUserRequest{ | ||||
| 		Statements:     statements, | ||||
| 		UsernameConfig: usernameConfig, | ||||
| 		Expiration:     expiration, | ||||
| 	} | ||||
|  | ||||
| 	var resp CreateUserResponse | ||||
| 	err = dr.client.Call("Plugin.CreateUser", req, &resp) | ||||
|  | ||||
| 	return resp.Username, resp.Password, err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error { | ||||
| 	req := RenewUserRequest{ | ||||
| 		Statements: statements, | ||||
| 		Username:   username, | ||||
| 		Expiration: expiration, | ||||
| 	} | ||||
|  | ||||
| 	err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { | ||||
| 	req := RevokeUserRequest{ | ||||
| 		Statements: statements, | ||||
| 		Username:   username, | ||||
| 	} | ||||
|  | ||||
| 	err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error { | ||||
| 	req := InitializeRequest{ | ||||
| 		Config:           conf, | ||||
| 		VerifyConnection: verifyConnection, | ||||
| 	} | ||||
|  | ||||
| 	err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) Close() error { | ||||
| 	err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|   | ||||
							
								
								
									
										556
									
								
								builtin/logical/database/dbplugin/database.pb.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										556
									
								
								builtin/logical/database/dbplugin/database.pb.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,556 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // source: builtin/logical/database/dbplugin/database.proto | ||||
|  | ||||
| /* | ||||
| Package dbplugin is a generated protocol buffer package. | ||||
|  | ||||
| It is generated from these files: | ||||
| 	builtin/logical/database/dbplugin/database.proto | ||||
|  | ||||
| It has these top-level messages: | ||||
| 	InitializeRequest | ||||
| 	CreateUserRequest | ||||
| 	RenewUserRequest | ||||
| 	RevokeUserRequest | ||||
| 	Statements | ||||
| 	UsernameConfig | ||||
| 	CreateUserResponse | ||||
| 	TypeResponse | ||||
| 	Empty | ||||
| */ | ||||
| package dbplugin | ||||
|  | ||||
| import proto "github.com/golang/protobuf/proto" | ||||
| import fmt "fmt" | ||||
| import math "math" | ||||
| import google_protobuf "github.com/golang/protobuf/ptypes/timestamp" | ||||
|  | ||||
| import ( | ||||
| 	context "golang.org/x/net/context" | ||||
| 	grpc "google.golang.org/grpc" | ||||
| ) | ||||
|  | ||||
| // Reference imports to suppress errors if they are not otherwise used. | ||||
| var _ = proto.Marshal | ||||
| var _ = fmt.Errorf | ||||
| var _ = math.Inf | ||||
|  | ||||
| // This is a compile-time assertion to ensure that this generated file | ||||
| // is compatible with the proto package it is being compiled against. | ||||
| // A compilation error at this line likely means your copy of the | ||||
| // proto package needs to be updated. | ||||
| const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package | ||||
|  | ||||
| type InitializeRequest struct { | ||||
| 	Config           []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` | ||||
| 	VerifyConnection bool   `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m *InitializeRequest) Reset()                    { *m = InitializeRequest{} } | ||||
| func (m *InitializeRequest) String() string            { return proto.CompactTextString(m) } | ||||
| func (*InitializeRequest) ProtoMessage()               {} | ||||
| func (*InitializeRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } | ||||
|  | ||||
| func (m *InitializeRequest) GetConfig() []byte { | ||||
| 	if m != nil { | ||||
| 		return m.Config | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (m *InitializeRequest) GetVerifyConnection() bool { | ||||
| 	if m != nil { | ||||
| 		return m.VerifyConnection | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| type CreateUserRequest struct { | ||||
| 	Statements     *Statements                `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` | ||||
| 	UsernameConfig *UsernameConfig            `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"` | ||||
| 	Expiration     *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m *CreateUserRequest) Reset()                    { *m = CreateUserRequest{} } | ||||
| func (m *CreateUserRequest) String() string            { return proto.CompactTextString(m) } | ||||
| func (*CreateUserRequest) ProtoMessage()               {} | ||||
| func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } | ||||
|  | ||||
| func (m *CreateUserRequest) GetStatements() *Statements { | ||||
| 	if m != nil { | ||||
| 		return m.Statements | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (m *CreateUserRequest) GetUsernameConfig() *UsernameConfig { | ||||
| 	if m != nil { | ||||
| 		return m.UsernameConfig | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (m *CreateUserRequest) GetExpiration() *google_protobuf.Timestamp { | ||||
| 	if m != nil { | ||||
| 		return m.Expiration | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type RenewUserRequest struct { | ||||
| 	Statements *Statements                `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` | ||||
| 	Username   string                     `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` | ||||
| 	Expiration *google_protobuf.Timestamp `protobuf:"bytes,3,opt,name=expiration" json:"expiration,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m *RenewUserRequest) Reset()                    { *m = RenewUserRequest{} } | ||||
| func (m *RenewUserRequest) String() string            { return proto.CompactTextString(m) } | ||||
| func (*RenewUserRequest) ProtoMessage()               {} | ||||
| func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } | ||||
|  | ||||
| func (m *RenewUserRequest) GetStatements() *Statements { | ||||
| 	if m != nil { | ||||
| 		return m.Statements | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (m *RenewUserRequest) GetUsername() string { | ||||
| 	if m != nil { | ||||
| 		return m.Username | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (m *RenewUserRequest) GetExpiration() *google_protobuf.Timestamp { | ||||
| 	if m != nil { | ||||
| 		return m.Expiration | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type RevokeUserRequest struct { | ||||
| 	Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` | ||||
| 	Username   string      `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m *RevokeUserRequest) Reset()                    { *m = RevokeUserRequest{} } | ||||
| func (m *RevokeUserRequest) String() string            { return proto.CompactTextString(m) } | ||||
| func (*RevokeUserRequest) ProtoMessage()               {} | ||||
| func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } | ||||
|  | ||||
| func (m *RevokeUserRequest) GetStatements() *Statements { | ||||
| 	if m != nil { | ||||
| 		return m.Statements | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (m *RevokeUserRequest) GetUsername() string { | ||||
| 	if m != nil { | ||||
| 		return m.Username | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| type Statements struct { | ||||
| 	CreationStatements   string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"` | ||||
| 	RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"` | ||||
| 	RollbackStatements   string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"` | ||||
| 	RenewStatements      string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m *Statements) Reset()                    { *m = Statements{} } | ||||
| func (m *Statements) String() string            { return proto.CompactTextString(m) } | ||||
| func (*Statements) ProtoMessage()               {} | ||||
| func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } | ||||
|  | ||||
| func (m *Statements) GetCreationStatements() string { | ||||
| 	if m != nil { | ||||
| 		return m.CreationStatements | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (m *Statements) GetRevocationStatements() string { | ||||
| 	if m != nil { | ||||
| 		return m.RevocationStatements | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (m *Statements) GetRollbackStatements() string { | ||||
| 	if m != nil { | ||||
| 		return m.RollbackStatements | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (m *Statements) GetRenewStatements() string { | ||||
| 	if m != nil { | ||||
| 		return m.RenewStatements | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| type UsernameConfig struct { | ||||
| 	DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"` | ||||
| 	RoleName    string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m *UsernameConfig) Reset()                    { *m = UsernameConfig{} } | ||||
| func (m *UsernameConfig) String() string            { return proto.CompactTextString(m) } | ||||
| func (*UsernameConfig) ProtoMessage()               {} | ||||
| func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } | ||||
|  | ||||
| func (m *UsernameConfig) GetDisplayName() string { | ||||
| 	if m != nil { | ||||
| 		return m.DisplayName | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (m *UsernameConfig) GetRoleName() string { | ||||
| 	if m != nil { | ||||
| 		return m.RoleName | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| type CreateUserResponse struct { | ||||
| 	Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"` | ||||
| 	Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m *CreateUserResponse) Reset()                    { *m = CreateUserResponse{} } | ||||
| func (m *CreateUserResponse) String() string            { return proto.CompactTextString(m) } | ||||
| func (*CreateUserResponse) ProtoMessage()               {} | ||||
| func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } | ||||
|  | ||||
| func (m *CreateUserResponse) GetUsername() string { | ||||
| 	if m != nil { | ||||
| 		return m.Username | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (m *CreateUserResponse) GetPassword() string { | ||||
| 	if m != nil { | ||||
| 		return m.Password | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| type TypeResponse struct { | ||||
| 	Type string `protobuf:"bytes,1,opt,name=type" json:"type,omitempty"` | ||||
| } | ||||
|  | ||||
| func (m *TypeResponse) Reset()                    { *m = TypeResponse{} } | ||||
| func (m *TypeResponse) String() string            { return proto.CompactTextString(m) } | ||||
| func (*TypeResponse) ProtoMessage()               {} | ||||
| func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } | ||||
|  | ||||
| func (m *TypeResponse) GetType() string { | ||||
| 	if m != nil { | ||||
| 		return m.Type | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| type Empty struct { | ||||
| } | ||||
|  | ||||
| func (m *Empty) Reset()                    { *m = Empty{} } | ||||
| func (m *Empty) String() string            { return proto.CompactTextString(m) } | ||||
| func (*Empty) ProtoMessage()               {} | ||||
| func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} } | ||||
|  | ||||
| func init() { | ||||
| 	proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest") | ||||
| 	proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest") | ||||
| 	proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest") | ||||
| 	proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest") | ||||
| 	proto.RegisterType((*Statements)(nil), "dbplugin.Statements") | ||||
| 	proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig") | ||||
| 	proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse") | ||||
| 	proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse") | ||||
| 	proto.RegisterType((*Empty)(nil), "dbplugin.Empty") | ||||
| } | ||||
|  | ||||
| // Reference imports to suppress errors if they are not otherwise used. | ||||
| var _ context.Context | ||||
| var _ grpc.ClientConn | ||||
|  | ||||
| // This is a compile-time assertion to ensure that this generated file | ||||
| // is compatible with the grpc package it is being compiled against. | ||||
| const _ = grpc.SupportPackageIsVersion4 | ||||
|  | ||||
| // Client API for Database service | ||||
|  | ||||
| type DatabaseClient interface { | ||||
| 	Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) | ||||
| 	CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) | ||||
| 	RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) | ||||
| 	RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) | ||||
| 	Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) | ||||
| 	Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) | ||||
| } | ||||
|  | ||||
| type databaseClient struct { | ||||
| 	cc *grpc.ClientConn | ||||
| } | ||||
|  | ||||
| func NewDatabaseClient(cc *grpc.ClientConn) DatabaseClient { | ||||
| 	return &databaseClient{cc} | ||||
| } | ||||
|  | ||||
| func (c *databaseClient) Type(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*TypeResponse, error) { | ||||
| 	out := new(TypeResponse) | ||||
| 	err := grpc.Invoke(ctx, "/dbplugin.Database/Type", in, out, c.cc, opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| func (c *databaseClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) { | ||||
| 	out := new(CreateUserResponse) | ||||
| 	err := grpc.Invoke(ctx, "/dbplugin.Database/CreateUser", in, out, c.cc, opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| func (c *databaseClient) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) { | ||||
| 	out := new(Empty) | ||||
| 	err := grpc.Invoke(ctx, "/dbplugin.Database/RenewUser", in, out, c.cc, opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) { | ||||
| 	out := new(Empty) | ||||
| 	err := grpc.Invoke(ctx, "/dbplugin.Database/RevokeUser", in, out, c.cc, opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) { | ||||
| 	out := new(Empty) | ||||
| 	err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) { | ||||
| 	out := new(Empty) | ||||
| 	err := grpc.Invoke(ctx, "/dbplugin.Database/Close", in, out, c.cc, opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return out, nil | ||||
| } | ||||
|  | ||||
| // Server API for Database service | ||||
|  | ||||
| type DatabaseServer interface { | ||||
| 	Type(context.Context, *Empty) (*TypeResponse, error) | ||||
| 	CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) | ||||
| 	RenewUser(context.Context, *RenewUserRequest) (*Empty, error) | ||||
| 	RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error) | ||||
| 	Initialize(context.Context, *InitializeRequest) (*Empty, error) | ||||
| 	Close(context.Context, *Empty) (*Empty, error) | ||||
| } | ||||
|  | ||||
| func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) { | ||||
| 	s.RegisterService(&_Database_serviceDesc, srv) | ||||
| } | ||||
|  | ||||
| func _Database_Type_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||
| 	in := new(Empty) | ||||
| 	if err := dec(in); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if interceptor == nil { | ||||
| 		return srv.(DatabaseServer).Type(ctx, in) | ||||
| 	} | ||||
| 	info := &grpc.UnaryServerInfo{ | ||||
| 		Server:     srv, | ||||
| 		FullMethod: "/dbplugin.Database/Type", | ||||
| 	} | ||||
| 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||
| 		return srv.(DatabaseServer).Type(ctx, req.(*Empty)) | ||||
| 	} | ||||
| 	return interceptor(ctx, in, info, handler) | ||||
| } | ||||
|  | ||||
| func _Database_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||
| 	in := new(CreateUserRequest) | ||||
| 	if err := dec(in); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if interceptor == nil { | ||||
| 		return srv.(DatabaseServer).CreateUser(ctx, in) | ||||
| 	} | ||||
| 	info := &grpc.UnaryServerInfo{ | ||||
| 		Server:     srv, | ||||
| 		FullMethod: "/dbplugin.Database/CreateUser", | ||||
| 	} | ||||
| 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||
| 		return srv.(DatabaseServer).CreateUser(ctx, req.(*CreateUserRequest)) | ||||
| 	} | ||||
| 	return interceptor(ctx, in, info, handler) | ||||
| } | ||||
|  | ||||
| func _Database_RenewUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||
| 	in := new(RenewUserRequest) | ||||
| 	if err := dec(in); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if interceptor == nil { | ||||
| 		return srv.(DatabaseServer).RenewUser(ctx, in) | ||||
| 	} | ||||
| 	info := &grpc.UnaryServerInfo{ | ||||
| 		Server:     srv, | ||||
| 		FullMethod: "/dbplugin.Database/RenewUser", | ||||
| 	} | ||||
| 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||
| 		return srv.(DatabaseServer).RenewUser(ctx, req.(*RenewUserRequest)) | ||||
| 	} | ||||
| 	return interceptor(ctx, in, info, handler) | ||||
| } | ||||
|  | ||||
| func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||
| 	in := new(RevokeUserRequest) | ||||
| 	if err := dec(in); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if interceptor == nil { | ||||
| 		return srv.(DatabaseServer).RevokeUser(ctx, in) | ||||
| 	} | ||||
| 	info := &grpc.UnaryServerInfo{ | ||||
| 		Server:     srv, | ||||
| 		FullMethod: "/dbplugin.Database/RevokeUser", | ||||
| 	} | ||||
| 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||
| 		return srv.(DatabaseServer).RevokeUser(ctx, req.(*RevokeUserRequest)) | ||||
| 	} | ||||
| 	return interceptor(ctx, in, info, handler) | ||||
| } | ||||
|  | ||||
| func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||
| 	in := new(InitializeRequest) | ||||
| 	if err := dec(in); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if interceptor == nil { | ||||
| 		return srv.(DatabaseServer).Initialize(ctx, in) | ||||
| 	} | ||||
| 	info := &grpc.UnaryServerInfo{ | ||||
| 		Server:     srv, | ||||
| 		FullMethod: "/dbplugin.Database/Initialize", | ||||
| 	} | ||||
| 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||
| 		return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest)) | ||||
| 	} | ||||
| 	return interceptor(ctx, in, info, handler) | ||||
| } | ||||
|  | ||||
| func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { | ||||
| 	in := new(Empty) | ||||
| 	if err := dec(in); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if interceptor == nil { | ||||
| 		return srv.(DatabaseServer).Close(ctx, in) | ||||
| 	} | ||||
| 	info := &grpc.UnaryServerInfo{ | ||||
| 		Server:     srv, | ||||
| 		FullMethod: "/dbplugin.Database/Close", | ||||
| 	} | ||||
| 	handler := func(ctx context.Context, req interface{}) (interface{}, error) { | ||||
| 		return srv.(DatabaseServer).Close(ctx, req.(*Empty)) | ||||
| 	} | ||||
| 	return interceptor(ctx, in, info, handler) | ||||
| } | ||||
|  | ||||
| var _Database_serviceDesc = grpc.ServiceDesc{ | ||||
| 	ServiceName: "dbplugin.Database", | ||||
| 	HandlerType: (*DatabaseServer)(nil), | ||||
| 	Methods: []grpc.MethodDesc{ | ||||
| 		{ | ||||
| 			MethodName: "Type", | ||||
| 			Handler:    _Database_Type_Handler, | ||||
| 		}, | ||||
| 		{ | ||||
| 			MethodName: "CreateUser", | ||||
| 			Handler:    _Database_CreateUser_Handler, | ||||
| 		}, | ||||
| 		{ | ||||
| 			MethodName: "RenewUser", | ||||
| 			Handler:    _Database_RenewUser_Handler, | ||||
| 		}, | ||||
| 		{ | ||||
| 			MethodName: "RevokeUser", | ||||
| 			Handler:    _Database_RevokeUser_Handler, | ||||
| 		}, | ||||
| 		{ | ||||
| 			MethodName: "Initialize", | ||||
| 			Handler:    _Database_Initialize_Handler, | ||||
| 		}, | ||||
| 		{ | ||||
| 			MethodName: "Close", | ||||
| 			Handler:    _Database_Close_Handler, | ||||
| 		}, | ||||
| 	}, | ||||
| 	Streams:  []grpc.StreamDesc{}, | ||||
| 	Metadata: "builtin/logical/database/dbplugin/database.proto", | ||||
| } | ||||
|  | ||||
| func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) } | ||||
|  | ||||
| var fileDescriptor0 = []byte{ | ||||
| 	// 548 bytes of a gzipped FileDescriptorProto | ||||
| 	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e, | ||||
| 	0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f, | ||||
| 	0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e, | ||||
| 	0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2, | ||||
| 	0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6, | ||||
| 	0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c, | ||||
| 	0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e, | ||||
| 	0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3, | ||||
| 	0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1, | ||||
| 	0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f, | ||||
| 	0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49, | ||||
| 	0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35, | ||||
| 	0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20, | ||||
| 	0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6, | ||||
| 	0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34, | ||||
| 	0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0, | ||||
| 	0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce, | ||||
| 	0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68, | ||||
| 	0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0, | ||||
| 	0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8, | ||||
| 	0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1, | ||||
| 	0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0, | ||||
| 	0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a, | ||||
| 	0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab, | ||||
| 	0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c, | ||||
| 	0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4, | ||||
| 	0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36, | ||||
| 	0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a, | ||||
| 	0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf, | ||||
| 	0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67, | ||||
| 	0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b, | ||||
| 	0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d, | ||||
| 	0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6, | ||||
| 	0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56, | ||||
| 	0x94, 0x05, 0x00, 0x00, | ||||
| } | ||||
							
								
								
									
										58
									
								
								builtin/logical/database/dbplugin/database.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								builtin/logical/database/dbplugin/database.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| syntax = "proto3"; | ||||
| package dbplugin; | ||||
|  | ||||
| import "google/protobuf/timestamp.proto"; | ||||
|  | ||||
| message InitializeRequest { | ||||
| 	bytes config = 1; | ||||
| 	bool verify_connection = 2; | ||||
| } | ||||
|  | ||||
| message CreateUserRequest { | ||||
| 	Statements     statements = 1; | ||||
| 	UsernameConfig username_config = 2; | ||||
| 	google.protobuf.Timestamp expiration = 3; | ||||
| } | ||||
|  | ||||
| message RenewUserRequest { | ||||
| 	Statements statements = 1; | ||||
| 	string username = 2; | ||||
| 	google.protobuf.Timestamp expiration = 3; | ||||
| } | ||||
|  | ||||
| message RevokeUserRequest { | ||||
| 	Statements statements = 1; | ||||
| 	string username = 2; | ||||
| } | ||||
|  | ||||
| message Statements { | ||||
| 	string creation_statements = 1; | ||||
| 	string revocation_statements = 2; | ||||
| 	string rollback_statements  = 3; | ||||
| 	string renew_statements = 4; | ||||
| } | ||||
|  | ||||
| message UsernameConfig { | ||||
| 	string DisplayName = 1; | ||||
| 	string RoleName = 2; | ||||
| } | ||||
|  | ||||
| message CreateUserResponse { | ||||
| 	string username = 1; | ||||
| 	string password = 2; | ||||
| } | ||||
|  | ||||
| message TypeResponse { | ||||
|     string type = 1; | ||||
| } | ||||
|  | ||||
| message Empty {} | ||||
|  | ||||
| service Database { | ||||
|     rpc Type(Empty) returns (TypeResponse); | ||||
|     rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); | ||||
|     rpc RenewUser(RenewUserRequest) returns (Empty); | ||||
|     rpc RevokeUser(RevokeUserRequest) returns (Empty); | ||||
|     rpc Initialize(InitializeRequest) returns (Empty); | ||||
|     rpc Close(Empty) returns (Empty); | ||||
| } | ||||
| @@ -1,6 +1,7 @@ | ||||
| package dbplugin | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"time" | ||||
|  | ||||
| 	metrics "github.com/armon/go-metrics" | ||||
| @@ -16,54 +17,55 @@ type databaseTracingMiddleware struct { | ||||
| 	logger log.Logger | ||||
|  | ||||
| 	typeStr   string | ||||
| 	transport string | ||||
| } | ||||
|  | ||||
| func (mw *databaseTracingMiddleware) Type() (string, error) { | ||||
| 	return mw.next.Type() | ||||
| } | ||||
|  | ||||
| func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	defer func(then time.Time) { | ||||
| 		mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) | ||||
| 		mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) | ||||
| 	}(time.Now()) | ||||
|  | ||||
| 	mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) | ||||
| 	return mw.next.CreateUser(statements, usernameConfig, expiration) | ||||
| 	mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr, "transport", mw.transport) | ||||
| 	return mw.next.CreateUser(ctx, statements, usernameConfig, expiration) | ||||
| } | ||||
|  | ||||
| func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { | ||||
| func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { | ||||
| 	defer func(then time.Time) { | ||||
| 		mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) | ||||
| 		mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) | ||||
| 	}(time.Now()) | ||||
|  | ||||
| 	mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) | ||||
| 	return mw.next.RenewUser(statements, username, expiration) | ||||
| 	mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr, "transport", mw.transport) | ||||
| 	return mw.next.RenewUser(ctx, statements, username, expiration) | ||||
| } | ||||
|  | ||||
| func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { | ||||
| func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { | ||||
| 	defer func(then time.Time) { | ||||
| 		mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) | ||||
| 		mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) | ||||
| 	}(time.Now()) | ||||
|  | ||||
| 	mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) | ||||
| 	return mw.next.RevokeUser(statements, username) | ||||
| 	mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr, "transport", mw.transport) | ||||
| 	return mw.next.RevokeUser(ctx, statements, username) | ||||
| } | ||||
|  | ||||
| func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { | ||||
| func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { | ||||
| 	defer func(then time.Time) { | ||||
| 		mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) | ||||
| 		mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then)) | ||||
| 	}(time.Now()) | ||||
|  | ||||
| 	mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) | ||||
| 	return mw.next.Initialize(conf, verifyConnection) | ||||
| 	mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport) | ||||
| 	return mw.next.Initialize(ctx, conf, verifyConnection) | ||||
| } | ||||
|  | ||||
| func (mw *databaseTracingMiddleware) Close() (err error) { | ||||
| 	defer func(then time.Time) { | ||||
| 		mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) | ||||
| 		mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then)) | ||||
| 	}(time.Now()) | ||||
|  | ||||
| 	mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) | ||||
| 	mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr, "transport", mw.transport) | ||||
| 	return mw.next.Close() | ||||
| } | ||||
|  | ||||
| @@ -81,7 +83,7 @@ func (mw *databaseMetricsMiddleware) Type() (string, error) { | ||||
| 	return mw.next.Type() | ||||
| } | ||||
|  | ||||
| func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	defer func(now time.Time) { | ||||
| 		metrics.MeasureSince([]string{"database", "CreateUser"}, now) | ||||
| 		metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) | ||||
| @@ -94,10 +96,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernameC | ||||
|  | ||||
| 	metrics.IncrCounter([]string{"database", "CreateUser"}, 1) | ||||
| 	metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) | ||||
| 	return mw.next.CreateUser(statements, usernameConfig, expiration) | ||||
| 	return mw.next.CreateUser(ctx, statements, usernameConfig, expiration) | ||||
| } | ||||
|  | ||||
| func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { | ||||
| func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) { | ||||
| 	defer func(now time.Time) { | ||||
| 		metrics.MeasureSince([]string{"database", "RenewUser"}, now) | ||||
| 		metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) | ||||
| @@ -110,10 +112,10 @@ func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username s | ||||
|  | ||||
| 	metrics.IncrCounter([]string{"database", "RenewUser"}, 1) | ||||
| 	metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1) | ||||
| 	return mw.next.RenewUser(statements, username, expiration) | ||||
| 	return mw.next.RenewUser(ctx, statements, username, expiration) | ||||
| } | ||||
|  | ||||
| func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) { | ||||
| func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) { | ||||
| 	defer func(now time.Time) { | ||||
| 		metrics.MeasureSince([]string{"database", "RevokeUser"}, now) | ||||
| 		metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now) | ||||
| @@ -126,10 +128,10 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username | ||||
|  | ||||
| 	metrics.IncrCounter([]string{"database", "RevokeUser"}, 1) | ||||
| 	metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1) | ||||
| 	return mw.next.RevokeUser(statements, username) | ||||
| 	return mw.next.RevokeUser(ctx, statements, username) | ||||
| } | ||||
|  | ||||
| func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { | ||||
| func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { | ||||
| 	defer func(now time.Time) { | ||||
| 		metrics.MeasureSince([]string{"database", "Initialize"}, now) | ||||
| 		metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) | ||||
| @@ -142,7 +144,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, ver | ||||
|  | ||||
| 	metrics.IncrCounter([]string{"database", "Initialize"}, 1) | ||||
| 	metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) | ||||
| 	return mw.next.Initialize(conf, verifyConnection) | ||||
| 	return mw.next.Initialize(ctx, conf, verifyConnection) | ||||
| } | ||||
|  | ||||
| func (mw *databaseMetricsMiddleware) Close() (err error) { | ||||
|   | ||||
							
								
								
									
										198
									
								
								builtin/logical/database/dbplugin/grpc_transport.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										198
									
								
								builtin/logical/database/dbplugin/grpc_transport.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,198 @@ | ||||
| package dbplugin | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"time" | ||||
|  | ||||
| 	"google.golang.org/grpc" | ||||
| 	"google.golang.org/grpc/connectivity" | ||||
|  | ||||
| 	"github.com/golang/protobuf/ptypes" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	ErrPluginShutdown = errors.New("plugin shutdown") | ||||
| ) | ||||
|  | ||||
| // ---- gRPC Server domain ---- | ||||
|  | ||||
| type gRPCServer struct { | ||||
| 	impl Database | ||||
| } | ||||
|  | ||||
| func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) { | ||||
| 	t, err := s.impl.Type() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &TypeResponse{ | ||||
| 		Type: t, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) { | ||||
| 	e, err := ptypes.Timestamp(req.Expiration) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e) | ||||
|  | ||||
| 	return &CreateUserResponse{ | ||||
| 		Username: u, | ||||
| 		Password: p, | ||||
| 	}, err | ||||
| } | ||||
|  | ||||
| func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) { | ||||
| 	e, err := ptypes.Timestamp(req.Expiration) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e) | ||||
| 	return &Empty{}, err | ||||
| } | ||||
|  | ||||
| func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) { | ||||
| 	err := s.impl.RevokeUser(ctx, *req.Statements, req.Username) | ||||
| 	return &Empty{}, err | ||||
| } | ||||
|  | ||||
| func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) { | ||||
| 	config := map[string]interface{}{} | ||||
|  | ||||
| 	err := json.Unmarshal(req.Config, &config) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = s.impl.Initialize(ctx, config, req.VerifyConnection) | ||||
| 	return &Empty{}, err | ||||
| } | ||||
|  | ||||
| func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) { | ||||
| 	s.impl.Close() | ||||
| 	return &Empty{}, nil | ||||
| } | ||||
|  | ||||
| // ---- gRPC client domain ---- | ||||
|  | ||||
| type gRPCClient struct { | ||||
| 	client     DatabaseClient | ||||
| 	clientConn *grpc.ClientConn | ||||
| } | ||||
|  | ||||
| func (c gRPCClient) Type() (string, error) { | ||||
| 	// If the plugin has already shutdown, this will hang forever so we give it | ||||
| 	// a one second timeout. | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	switch c.clientConn.GetState() { | ||||
| 	case connectivity.Ready, connectivity.Idle: | ||||
| 	default: | ||||
| 		return "", ErrPluginShutdown | ||||
| 	} | ||||
| 	resp, err := c.client.Type(ctx, &Empty{}) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	return resp.Type, err | ||||
| } | ||||
|  | ||||
| func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	t, err := ptypes.TimestampProto(expiration) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	switch c.clientConn.GetState() { | ||||
| 	case connectivity.Ready, connectivity.Idle: | ||||
| 	default: | ||||
| 		return "", "", ErrPluginShutdown | ||||
| 	} | ||||
|  | ||||
| 	resp, err := c.client.CreateUser(ctx, &CreateUserRequest{ | ||||
| 		Statements:     &statements, | ||||
| 		UsernameConfig: &usernameConfig, | ||||
| 		Expiration:     t, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	return resp.Username, resp.Password, err | ||||
| } | ||||
|  | ||||
| func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error { | ||||
| 	t, err := ptypes.TimestampProto(expiration) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	switch c.clientConn.GetState() { | ||||
| 	case connectivity.Ready, connectivity.Idle: | ||||
| 	default: | ||||
| 		return ErrPluginShutdown | ||||
| 	} | ||||
|  | ||||
| 	_, err = c.client.RenewUser(ctx, &RenewUserRequest{ | ||||
| 		Statements: &statements, | ||||
| 		Username:   username, | ||||
| 		Expiration: t, | ||||
| 	}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error { | ||||
| 	switch c.clientConn.GetState() { | ||||
| 	case connectivity.Ready, connectivity.Idle: | ||||
| 	default: | ||||
| 		return ErrPluginShutdown | ||||
| 	} | ||||
| 	_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{ | ||||
| 		Statements: &statements, | ||||
| 		Username:   username, | ||||
| 	}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error { | ||||
| 	configRaw, err := json.Marshal(config) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	switch c.clientConn.GetState() { | ||||
| 	case connectivity.Ready, connectivity.Idle: | ||||
| 	default: | ||||
| 		return ErrPluginShutdown | ||||
| 	} | ||||
|  | ||||
| 	_, err = c.client.Initialize(ctx, &InitializeRequest{ | ||||
| 		Config:           configRaw, | ||||
| 		VerifyConnection: verifyConnection, | ||||
| 	}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (c *gRPCClient) Close() error { | ||||
| 	// If the plugin has already shutdown, this will hang forever so we give it | ||||
| 	// a one second timeout. | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), time.Second) | ||||
| 	defer cancel() | ||||
| 	switch c.clientConn.GetState() { | ||||
| 	case connectivity.Ready, connectivity.Idle: | ||||
| 		_, err := c.client.Close(ctx, &Empty{}) | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										139
									
								
								builtin/logical/database/dbplugin/netrpc_transport.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								builtin/logical/database/dbplugin/netrpc_transport.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,139 @@ | ||||
| package dbplugin | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/rpc" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // ---- RPC server domain ---- | ||||
|  | ||||
| // databasePluginRPCServer implements an RPC version of Database and is run | ||||
| // inside a plugin. It wraps an underlying implementation of Database. | ||||
| type databasePluginRPCServer struct { | ||||
| 	impl Database | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { | ||||
| 	var err error | ||||
| 	*resp, err = ds.impl.Type() | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequestRPC, resp *CreateUserResponse) error { | ||||
| 	var err error | ||||
| 	resp.Username, resp.Password, err = ds.impl.CreateUser(context.Background(), args.Statements, args.UsernameConfig, args.Expiration) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequestRPC, _ *struct{}) error { | ||||
| 	err := ds.impl.RenewUser(context.Background(), args.Statements, args.Username, args.Expiration) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *struct{}) error { | ||||
| 	err := ds.impl.RevokeUser(context.Background(), args.Statements, args.Username) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error { | ||||
| 	err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { | ||||
| 	ds.impl.Close() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ---- RPC client domain ---- | ||||
| // databasePluginRPCClient implements Database and is used on the client to | ||||
| // make RPC calls to a plugin. | ||||
| type databasePluginRPCClient struct { | ||||
| 	client *rpc.Client | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) Type() (string, error) { | ||||
| 	var dbType string | ||||
| 	err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) | ||||
|  | ||||
| 	return fmt.Sprintf("plugin-%s", dbType), err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) CreateUser(_ context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	req := CreateUserRequestRPC{ | ||||
| 		Statements:     statements, | ||||
| 		UsernameConfig: usernameConfig, | ||||
| 		Expiration:     expiration, | ||||
| 	} | ||||
|  | ||||
| 	var resp CreateUserResponse | ||||
| 	err = dr.client.Call("Plugin.CreateUser", req, &resp) | ||||
|  | ||||
| 	return resp.Username, resp.Password, err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements Statements, username string, expiration time.Time) error { | ||||
| 	req := RenewUserRequestRPC{ | ||||
| 		Statements: statements, | ||||
| 		Username:   username, | ||||
| 		Expiration: expiration, | ||||
| 	} | ||||
|  | ||||
| 	err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error { | ||||
| 	req := RevokeUserRequestRPC{ | ||||
| 		Statements: statements, | ||||
| 		Username:   username, | ||||
| 	} | ||||
|  | ||||
| 	err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error { | ||||
| 	req := InitializeRequestRPC{ | ||||
| 		Config:           conf, | ||||
| 		VerifyConnection: verifyConnection, | ||||
| 	} | ||||
|  | ||||
| 	err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (dr *databasePluginRPCClient) Close() error { | ||||
| 	err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // ---- RPC Request Args Domain ---- | ||||
|  | ||||
| type InitializeRequestRPC struct { | ||||
| 	Config           map[string]interface{} | ||||
| 	VerifyConnection bool | ||||
| } | ||||
|  | ||||
| type CreateUserRequestRPC struct { | ||||
| 	Statements     Statements | ||||
| 	UsernameConfig UsernameConfig | ||||
| 	Expiration     time.Time | ||||
| } | ||||
|  | ||||
| type RenewUserRequestRPC struct { | ||||
| 	Statements Statements | ||||
| 	Username   string | ||||
| 	Expiration time.Time | ||||
| } | ||||
|  | ||||
| type RevokeUserRequestRPC struct { | ||||
| 	Statements Statements | ||||
| 	Username   string | ||||
| } | ||||
| @@ -1,10 +1,13 @@ | ||||
| package dbplugin | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/rpc" | ||||
| 	"time" | ||||
|  | ||||
| 	"google.golang.org/grpc" | ||||
|  | ||||
| 	"github.com/hashicorp/go-plugin" | ||||
| 	"github.com/hashicorp/vault/helper/pluginutil" | ||||
| 	log "github.com/mgutz/logxi/v1" | ||||
| @@ -13,29 +16,14 @@ import ( | ||||
| // Database is the interface that all database objects must implement. | ||||
| type Database interface { | ||||
| 	Type() (string, error) | ||||
| 	CreateUser(statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) | ||||
| 	RenewUser(statements Statements, username string, expiration time.Time) error | ||||
| 	RevokeUser(statements Statements, username string) error | ||||
| 	CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) | ||||
| 	RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error | ||||
| 	RevokeUser(ctx context.Context, statements Statements, username string) error | ||||
|  | ||||
| 	Initialize(config map[string]interface{}, verifyConnection bool) error | ||||
| 	Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error | ||||
| 	Close() error | ||||
| } | ||||
|  | ||||
| // Statements set in role creation and passed into the database type's functions. | ||||
| type Statements struct { | ||||
| 	CreationStatements   string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` | ||||
| 	RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` | ||||
| 	RollbackStatements   string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` | ||||
| 	RenewStatements      string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` | ||||
| } | ||||
|  | ||||
| // UsernameConfig is used to configure prefixes for the username to be | ||||
| // generated. | ||||
| type UsernameConfig struct { | ||||
| 	DisplayName string | ||||
| 	RoleName    string | ||||
| } | ||||
|  | ||||
| // PluginFactory is used to build plugin database types. It wraps the database | ||||
| // object in a logging and metrics middleware. | ||||
| func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { | ||||
| @@ -45,6 +33,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var transport string | ||||
| 	var db Database | ||||
| 	if pluginRunner.Builtin { | ||||
| 		// Plugin is builtin so we can retrieve an instance of the interface | ||||
| @@ -60,12 +49,24 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. | ||||
| 			return nil, fmt.Errorf("unsuported database type: %s", pluginName) | ||||
| 		} | ||||
|  | ||||
| 		transport = "builtin" | ||||
|  | ||||
| 	} else { | ||||
| 		// create a DatabasePluginClient instance | ||||
| 		db, err = newPluginClient(sys, pluginRunner, logger) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		// Switch on the underlying database client type to get the transport | ||||
| 		// method. | ||||
| 		switch db.(*DatabasePluginClient).Database.(type) { | ||||
| 		case *gRPCClient: | ||||
| 			transport = "gRPC" | ||||
| 		case *databasePluginRPCClient: | ||||
| 			transport = "netRPC" | ||||
| 		} | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	typeStr, err := db.Type() | ||||
| @@ -82,6 +83,7 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. | ||||
| 	// Wrap with tracing middleware | ||||
| 	if logger.IsTrace() { | ||||
| 		db = &databaseTracingMiddleware{ | ||||
| 			transport: transport, | ||||
| 			next:      db, | ||||
| 			typeStr:   typeStr, | ||||
| 			logger:    logger, | ||||
| @@ -115,33 +117,14 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e | ||||
| 	return &databasePluginRPCClient{client: c}, nil | ||||
| } | ||||
|  | ||||
| // ---- RPC Request Args Domain ---- | ||||
|  | ||||
| type InitializeRequest struct { | ||||
| 	Config           map[string]interface{} | ||||
| 	VerifyConnection bool | ||||
| func (d DatabasePlugin) GRPCServer(s *grpc.Server) error { | ||||
| 	RegisterDatabaseServer(s, &gRPCServer{impl: d.impl}) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type CreateUserRequest struct { | ||||
| 	Statements     Statements | ||||
| 	UsernameConfig UsernameConfig | ||||
| 	Expiration     time.Time | ||||
| } | ||||
|  | ||||
| type RenewUserRequest struct { | ||||
| 	Statements Statements | ||||
| 	Username   string | ||||
| 	Expiration time.Time | ||||
| } | ||||
|  | ||||
| type RevokeUserRequest struct { | ||||
| 	Statements Statements | ||||
| 	Username   string | ||||
| } | ||||
|  | ||||
| // ---- RPC Response Args Domain ---- | ||||
|  | ||||
| type CreateUserResponse struct { | ||||
| 	Username string | ||||
| 	Password string | ||||
| func (DatabasePlugin) GRPCClient(c *grpc.ClientConn) (interface{}, error) { | ||||
| 	return &gRPCClient{ | ||||
| 		client:     NewDatabaseClient(c), | ||||
| 		clientConn: c, | ||||
| 	}, nil | ||||
| } | ||||
|   | ||||
| @@ -1,11 +1,13 @@ | ||||
| package dbplugin_test | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	plugin "github.com/hashicorp/go-plugin" | ||||
| 	"github.com/hashicorp/vault/builtin/logical/database/dbplugin" | ||||
| 	"github.com/hashicorp/vault/helper/pluginutil" | ||||
| 	vaulthttp "github.com/hashicorp/vault/http" | ||||
| @@ -20,7 +22,7 @@ type mockPlugin struct { | ||||
| } | ||||
|  | ||||
| func (m *mockPlugin) Type() (string, error) { return "mock", nil } | ||||
| func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (m *mockPlugin) CreateUser(_ context.Context, statements dbplugin.Statements, usernameConf dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	err = errors.New("err") | ||||
| 	if usernameConf.DisplayName == "" || expiration.IsZero() { | ||||
| 		return "", "", err | ||||
| @@ -34,7 +36,7 @@ func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernameConf dbp | ||||
|  | ||||
| 	return usernameConf.DisplayName, "test", nil | ||||
| } | ||||
| func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| func (m *mockPlugin) RenewUser(_ context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	err := errors.New("err") | ||||
| 	if username == "" || expiration.IsZero() { | ||||
| 		return err | ||||
| @@ -46,7 +48,7 @@ func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error { | ||||
| func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	err := errors.New("err") | ||||
| 	if username == "" { | ||||
| 		return err | ||||
| @@ -59,7 +61,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) | ||||
| 	delete(m.users, username) | ||||
| 	return nil | ||||
| } | ||||
| func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error { | ||||
| func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error { | ||||
| 	err := errors.New("err") | ||||
| 	if len(conf) != 1 { | ||||
| 		return err | ||||
| @@ -80,14 +82,15 @@ func getCluster(t *testing.T) (*vault.TestCluster, logical.SystemView) { | ||||
| 	cores := cluster.Cores | ||||
|  | ||||
| 	sys := vault.TestDynamicSystemView(cores[0].Core) | ||||
| 	vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_Main") | ||||
| 	vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin", "TestPlugin_GRPC_Main") | ||||
| 	vault.TestAddTestPlugin(t, cores[0].Core, "test-plugin-netRPC", "TestPlugin_NetRPC_Main") | ||||
|  | ||||
| 	return cluster, sys | ||||
| } | ||||
|  | ||||
| // This is not an actual test case, it's a helper function that will be executed | ||||
| // by the go-plugin client via an exec call. | ||||
| func TestPlugin_Main(t *testing.T) { | ||||
| func TestPlugin_GRPC_Main(t *testing.T) { | ||||
| 	if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { | ||||
| 		return | ||||
| 	} | ||||
| @@ -105,6 +108,30 @@ func TestPlugin_Main(t *testing.T) { | ||||
| 	plugins.Serve(plugin, apiClientMeta.GetTLSConfig()) | ||||
| } | ||||
|  | ||||
| // This is not an actual test case, it's a helper function that will be executed | ||||
| // by the go-plugin client via an exec call. | ||||
| func TestPlugin_NetRPC_Main(t *testing.T) { | ||||
| 	if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	p := &mockPlugin{ | ||||
| 		users: make(map[string][]string), | ||||
| 	} | ||||
|  | ||||
| 	args := []string{"--tls-skip-verify=true"} | ||||
|  | ||||
| 	apiClientMeta := &pluginutil.APIClientMeta{} | ||||
| 	flags := apiClientMeta.FlagSet() | ||||
| 	flags.Parse(args) | ||||
|  | ||||
| 	tlsProvider := pluginutil.VaultPluginTLSProvider(apiClientMeta.GetTLSConfig()) | ||||
| 	serveConf := dbplugin.ServeConfig(p, tlsProvider) | ||||
| 	serveConf.GRPCServer = nil | ||||
|  | ||||
| 	plugin.Serve(serveConf) | ||||
| } | ||||
|  | ||||
| func TestPlugin_Initialize(t *testing.T) { | ||||
| 	cluster, sys := getCluster(t) | ||||
| 	defer cluster.Cleanup() | ||||
| @@ -118,7 +145,7 @@ func TestPlugin_Initialize(t *testing.T) { | ||||
| 		"test": 1, | ||||
| 	} | ||||
|  | ||||
| 	err = dbRaw.Initialize(connectionDetails, true) | ||||
| 	err = dbRaw.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -143,7 +170,7 @@ func TestPlugin_CreateUser(t *testing.T) { | ||||
| 		"test": 1, | ||||
| 	} | ||||
|  | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -153,7 +180,7 @@ func TestPlugin_CreateUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	us, pw, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -163,7 +190,7 @@ func TestPlugin_CreateUser(t *testing.T) { | ||||
|  | ||||
| 	// try and save the same user again to verify it saved the first time, this | ||||
| 	// should return an error | ||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("expected an error, user wasn't created correctly") | ||||
| 	} | ||||
| @@ -182,7 +209,7 @@ func TestPlugin_RenewUser(t *testing.T) { | ||||
| 	connectionDetails := map[string]interface{}{ | ||||
| 		"test": 1, | ||||
| 	} | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -192,12 +219,12 @@ func TestPlugin_RenewUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute)) | ||||
| 	err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -216,7 +243,7 @@ func TestPlugin_RevokeUser(t *testing.T) { | ||||
| 	connectionDetails := map[string]interface{}{ | ||||
| 		"test": 1, | ||||
| 	} | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -226,19 +253,159 @@ func TestPlugin_RevokeUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	us, _, err := db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	// Test default revoke statememts | ||||
| 	err = db.RevokeUser(dbplugin.Statements{}, us) | ||||
| 	err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	// Try adding the same username back so we can verify it was removed | ||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Test the code is still compatible with an old netRPC plugin | ||||
| func TestPlugin_NetRPC_Initialize(t *testing.T) { | ||||
| 	cluster, sys := getCluster(t) | ||||
| 	defer cluster.Cleanup() | ||||
|  | ||||
| 	dbRaw, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	connectionDetails := map[string]interface{}{ | ||||
| 		"test": 1, | ||||
| 	} | ||||
|  | ||||
| 	err = dbRaw.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = dbRaw.Close() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPlugin_NetRPC_CreateUser(t *testing.T) { | ||||
| 	cluster, sys := getCluster(t) | ||||
| 	defer cluster.Cleanup() | ||||
|  | ||||
| 	db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	connectionDetails := map[string]interface{}{ | ||||
| 		"test": 1, | ||||
| 	} | ||||
|  | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	usernameConf := dbplugin.UsernameConfig{ | ||||
| 		DisplayName: "test", | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	us, pw, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| 	if us != "test" || pw != "test" { | ||||
| 		t.Fatal("expected username and password to be 'test'") | ||||
| 	} | ||||
|  | ||||
| 	// try and save the same user again to verify it saved the first time, this | ||||
| 	// should return an error | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("expected an error, user wasn't created correctly") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPlugin_NetRPC_RenewUser(t *testing.T) { | ||||
| 	cluster, sys := getCluster(t) | ||||
| 	defer cluster.Cleanup() | ||||
|  | ||||
| 	db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	connectionDetails := map[string]interface{}{ | ||||
| 		"test": 1, | ||||
| 	} | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	usernameConf := dbplugin.UsernameConfig{ | ||||
| 		DisplayName: "test", | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.RenewUser(context.Background(), dbplugin.Statements{}, us, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPlugin_NetRPC_RevokeUser(t *testing.T) { | ||||
| 	cluster, sys := getCluster(t) | ||||
| 	defer cluster.Cleanup() | ||||
|  | ||||
| 	db, err := dbplugin.PluginFactory("test-plugin-netRPC", sys, &log.NullLogger{}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	connectionDetails := map[string]interface{}{ | ||||
| 		"test": 1, | ||||
| 	} | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	usernameConf := dbplugin.UsernameConfig{ | ||||
| 		DisplayName: "test", | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	us, _, err := db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	// Test default revoke statememts | ||||
| 	err = db.RevokeUser(context.Background(), dbplugin.Statements{}, us) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	// Try adding the same username back so we can verify it was removed | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConf, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
| @@ -10,6 +10,10 @@ import ( | ||||
| // Database implementation in a databasePluginRPCServer object and starts a | ||||
| // RPC server. | ||||
| func Serve(db Database, tlsProvider func() (*tls.Config, error)) { | ||||
| 	plugin.Serve(ServeConfig(db, tlsProvider)) | ||||
| } | ||||
|  | ||||
| func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig { | ||||
| 	dbPlugin := &DatabasePlugin{ | ||||
| 		impl: db, | ||||
| 	} | ||||
| @@ -19,53 +23,10 @@ func Serve(db Database, tlsProvider func() (*tls.Config, error)) { | ||||
| 		"database": dbPlugin, | ||||
| 	} | ||||
|  | ||||
| 	plugin.Serve(&plugin.ServeConfig{ | ||||
| 	return &plugin.ServeConfig{ | ||||
| 		HandshakeConfig: handshakeConfig, | ||||
| 		Plugins:         pluginMap, | ||||
| 		TLSProvider:     tlsProvider, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // ---- RPC server domain ---- | ||||
|  | ||||
| // databasePluginRPCServer implements an RPC version of Database and is run | ||||
| // inside a plugin. It wraps an underlying implementation of Database. | ||||
| type databasePluginRPCServer struct { | ||||
| 	impl Database | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { | ||||
| 	var err error | ||||
| 	*resp, err = ds.impl.Type() | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { | ||||
| 	var err error | ||||
| 	resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernameConfig, args.Expiration) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { | ||||
| 	err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { | ||||
| 	err := ds.impl.RevokeUser(args.Statements, args.Username) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error { | ||||
| 	err := ds.impl.Initialize(args.Config, args.VerifyConnection) | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { | ||||
| 	ds.impl.Close() | ||||
| 	return nil | ||||
| 		GRPCServer:      plugin.DefaultGRPCServer, | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package database | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
|  | ||||
| @@ -62,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { | ||||
| 		b.clearConnection(name) | ||||
|  | ||||
| 		// Execute plugin again, we don't need the object so throw away. | ||||
| 		_, err := b.createDBObj(req.Storage, name) | ||||
| 		_, err := b.createDBObj(context.TODO(), req.Storage, name) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -230,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { | ||||
| 			return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil | ||||
| 		} | ||||
|  | ||||
| 		err = db.Initialize(config.ConnectionDetails, verifyConnection) | ||||
| 		err = db.Initialize(context.TODO(), config.ConnectionDetails, verifyConnection) | ||||
| 		if err != nil { | ||||
| 			db.Close() | ||||
| 			return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package database | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| @@ -66,7 +67,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { | ||||
| 			unlockFunc = b.Unlock | ||||
|  | ||||
| 			// Create a new DB object | ||||
| 			db, err = b.createDBObj(req.Storage, role.DBName) | ||||
| 			db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) | ||||
| 			if err != nil { | ||||
| 				unlockFunc() | ||||
| 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | ||||
| @@ -81,7 +82,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { | ||||
| 		} | ||||
|  | ||||
| 		// Create the user | ||||
| 		username, password, err := db.CreateUser(role.Statements, usernameConfig, expiration) | ||||
| 		username, password, err := db.CreateUser(context.TODO(), role.Statements, usernameConfig, expiration) | ||||
| 		// Unlock | ||||
| 		unlockFunc() | ||||
| 		if err != nil { | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package database | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/logical" | ||||
| @@ -60,7 +61,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { | ||||
| 			unlockFunc = b.Unlock | ||||
|  | ||||
| 			// Create a new DB object | ||||
| 			db, err = b.createDBObj(req.Storage, role.DBName) | ||||
| 			db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) | ||||
| 			if err != nil { | ||||
| 				unlockFunc() | ||||
| 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | ||||
| @@ -69,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { | ||||
|  | ||||
| 		// Make sure we increase the VALID UNTIL endpoint for this user. | ||||
| 		if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { | ||||
| 			err := db.RenewUser(role.Statements, username, expireTime) | ||||
| 			err := db.RenewUser(context.TODO(), role.Statements, username, expireTime) | ||||
| 			// Unlock | ||||
| 			unlockFunc() | ||||
| 			if err != nil { | ||||
| @@ -119,14 +120,14 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { | ||||
| 			unlockFunc = b.Unlock | ||||
|  | ||||
| 			// Create a new DB object | ||||
| 			db, err = b.createDBObj(req.Storage, role.DBName) | ||||
| 			db, err = b.createDBObj(context.TODO(), req.Storage, role.DBName) | ||||
| 			if err != nil { | ||||
| 				unlockFunc() | ||||
| 				return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		err = db.RevokeUser(role.Statements, username) | ||||
| 		err = db.RevokeUser(context.TODO(), role.Statements, username) | ||||
| 		// Unlock | ||||
| 		unlockFunc() | ||||
| 		if err != nil { | ||||
|   | ||||
| @@ -119,6 +119,10 @@ func (r *PluginRunner) runCommon(wrapper RunnerUtil, pluginMap map[string]plugin | ||||
| 		SecureConfig:    secureConfig, | ||||
| 		TLSConfig:       clientTLSConfig, | ||||
| 		Logger:          namedLogger, | ||||
| 		AllowedProtocols: []plugin.Protocol{ | ||||
| 			plugin.ProtocolNetRPC, | ||||
| 			plugin.ProtocolGRPC, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	client := plugin.NewClient(clientConfig) | ||||
|   | ||||
| @@ -192,8 +192,7 @@ func TestBackendHandleRequest_helpRoot(t *testing.T) { | ||||
| func TestBackendHandleRequest_renewAuth(t *testing.T) { | ||||
| 	b := &Backend{} | ||||
|  | ||||
| 	resp, err := b.HandleRequest(logical.RenewAuthRequest( | ||||
| 		"/foo", &logical.Auth{}, nil)) | ||||
| 	resp, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -213,8 +212,7 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) { | ||||
| 		AuthRenew: callback, | ||||
| 	} | ||||
|  | ||||
| 	_, err := b.HandleRequest(logical.RenewAuthRequest( | ||||
| 		"/foo", &logical.Auth{}, nil)) | ||||
| 	_, err := b.HandleRequest(logical.RenewAuthRequest("/foo", &logical.Auth{}, nil)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -237,8 +235,7 @@ func TestBackendHandleRequest_renew(t *testing.T) { | ||||
| 		Secrets: []*Secret{secret}, | ||||
| 	} | ||||
|  | ||||
| 	_, err := b.HandleRequest(logical.RenewRequest( | ||||
| 		"/foo", secret.Response(nil, nil).Secret, nil)) | ||||
| 	_, err := b.HandleRequest(logical.RenewRequest("/foo", secret.Response(nil, nil).Secret, nil)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -293,8 +290,7 @@ func TestBackendHandleRequest_revoke(t *testing.T) { | ||||
| 		Secrets: []*Secret{secret}, | ||||
| 	} | ||||
|  | ||||
| 	_, err := b.HandleRequest(logical.RevokeRequest( | ||||
| 		"/foo", secret.Response(nil, nil).Secret, nil)) | ||||
| 	_, err := b.HandleRequest(logical.RevokeRequest("/foo", secret.Response(nil, nil).Secret, nil)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
| @@ -200,8 +200,7 @@ func (r *Request) SetLastRemoteWAL(last uint64) { | ||||
| } | ||||
|  | ||||
| // RenewRequest creates the structure of the renew request. | ||||
| func RenewRequest( | ||||
| 	path string, secret *Secret, data map[string]interface{}) *Request { | ||||
| func RenewRequest(path string, secret *Secret, data map[string]interface{}) *Request { | ||||
| 	return &Request{ | ||||
| 		Operation: RenewOperation, | ||||
| 		Path:      path, | ||||
| @@ -211,8 +210,7 @@ func RenewRequest( | ||||
| } | ||||
|  | ||||
| // RenewAuthRequest creates the structure of the renew request for an auth. | ||||
| func RenewAuthRequest( | ||||
| 	path string, auth *Auth, data map[string]interface{}) *Request { | ||||
| func RenewAuthRequest(path string, auth *Auth, data map[string]interface{}) *Request { | ||||
| 	return &Request{ | ||||
| 		Operation: RenewOperation, | ||||
| 		Path:      path, | ||||
| @@ -222,8 +220,7 @@ func RenewAuthRequest( | ||||
| } | ||||
|  | ||||
| // RevokeRequest creates the structure of the revoke request. | ||||
| func RevokeRequest( | ||||
| 	path string, secret *Secret, data map[string]interface{}) *Request { | ||||
| func RevokeRequest(path string, secret *Secret, data map[string]interface{}) *Request { | ||||
| 	return &Request{ | ||||
| 		Operation: RevokeOperation, | ||||
| 		Path:      path, | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package cassandra | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| @@ -21,6 +22,8 @@ const ( | ||||
| 	cassandraTypeName      = "cassandra" | ||||
| ) | ||||
|  | ||||
| var _ dbplugin.Database = &Cassandra{} | ||||
|  | ||||
| // Cassandra is an implementation of Database interface | ||||
| type Cassandra struct { | ||||
| 	connutil.ConnectionProducer | ||||
| @@ -64,8 +67,8 @@ func (c *Cassandra) Type() (string, error) { | ||||
| 	return cassandraTypeName, nil | ||||
| } | ||||
|  | ||||
| func (c *Cassandra) getConnection() (*gocql.Session, error) { | ||||
| 	session, err := c.Connection() | ||||
| func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) { | ||||
| 	session, err := c.Connection(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -75,13 +78,13 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { | ||||
|  | ||||
| // CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by | ||||
| // the CreationStatement provided. | ||||
| func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	// Grab the lock | ||||
| 	c.Lock() | ||||
| 	defer c.Unlock() | ||||
|  | ||||
| 	// Get the connection | ||||
| 	session, err := c.getConnection() | ||||
| 	session, err := c.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
| @@ -138,18 +141,18 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernameConfig db | ||||
| } | ||||
|  | ||||
| // RenewUser is not supported on Cassandra, so this is a no-op. | ||||
| func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| func (c *Cassandra) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	// NOOP | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // RevokeUser attempts to drop the specified user. | ||||
| func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { | ||||
| func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	// Grab the lock | ||||
| 	c.Lock() | ||||
| 	defer c.Unlock() | ||||
|  | ||||
| 	session, err := c.getConnection() | ||||
| 	session, err := c.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package cassandra | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"os" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| @@ -89,7 +90,7 @@ func TestCassandra_Initialize(t *testing.T) { | ||||
| 	db := dbRaw.(*Cassandra) | ||||
| 	connProducer := db.ConnectionProducer.(*cassandraConnectionProducer) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -112,7 +113,7 @@ func TestCassandra_Initialize(t *testing.T) { | ||||
| 		"protocol_version": "4", | ||||
| 	} | ||||
|  | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -135,7 +136,7 @@ func TestCassandra_CreateUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*Cassandra) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -149,7 +150,7 @@ func TestCassandra_CreateUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -176,7 +177,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*Cassandra) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -190,7 +191,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -199,7 +200,7 @@ func TestMyCassandra_RenewUser(t *testing.T) { | ||||
| 		t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) | ||||
| 	err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -222,7 +223,7 @@ func TestCassandra_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*Cassandra) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -236,7 +237,7 @@ func TestCassandra_RevokeUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -246,7 +247,7 @@ func TestCassandra_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test default revoke statememts | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package cassandra | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| @@ -43,7 +44,7 @@ type cassandraConnectionProducer struct { | ||||
| 	sync.Mutex | ||||
| } | ||||
|  | ||||
| func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { | ||||
| func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { | ||||
| 	c.Lock() | ||||
| 	defer c.Unlock() | ||||
|  | ||||
| @@ -106,7 +107,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve | ||||
| 	c.Initialized = true | ||||
|  | ||||
| 	if verifyConnection { | ||||
| 		if _, err := c.Connection(); err != nil { | ||||
| 		if _, err := c.Connection(ctx); err != nil { | ||||
| 			return fmt.Errorf("error verifying connection: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| @@ -114,7 +115,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, ve | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *cassandraConnectionProducer) Connection() (interface{}, error) { | ||||
| func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) { | ||||
| 	if !c.Initialized { | ||||
| 		return nil, connutil.ErrNotInitialized | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package hana | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| @@ -26,6 +27,8 @@ type HANA struct { | ||||
| 	credsutil.CredentialsProducer | ||||
| } | ||||
|  | ||||
| var _ dbplugin.Database = &HANA{} | ||||
|  | ||||
| // New implements builtinplugins.BuiltinFactory | ||||
| func New() (interface{}, error) { | ||||
| 	connProducer := &connutil.SQLConnectionProducer{} | ||||
| @@ -63,8 +66,8 @@ func (h *HANA) Type() (string, error) { | ||||
| 	return hanaTypeName, nil | ||||
| } | ||||
|  | ||||
| func (h *HANA) getConnection() (*sql.DB, error) { | ||||
| 	db, err := h.Connection() | ||||
| func (h *HANA) getConnection(ctx context.Context) (*sql.DB, error) { | ||||
| 	db, err := h.Connection(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -74,13 +77,13 @@ func (h *HANA) getConnection() (*sql.DB, error) { | ||||
|  | ||||
| // CreateUser generates the username/password on the underlying HANA secret backend | ||||
| // as instructed by the CreationStatement provided. | ||||
| func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	// Grab the lock | ||||
| 	h.Lock() | ||||
| 	defer h.Unlock() | ||||
|  | ||||
| 	// Get the connection | ||||
| 	db, err := h.getConnection() | ||||
| 	db, err := h.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
| @@ -153,9 +156,9 @@ func (h *HANA) CreateUser(statements dbplugin.Statements, usernameConfig dbplugi | ||||
| } | ||||
|  | ||||
| // Renewing hana user just means altering user's valid until property | ||||
| func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	// Get connection | ||||
| 	db, err := h.getConnection() | ||||
| 	db, err := h.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -193,14 +196,14 @@ func (h *HANA) RenewUser(statements dbplugin.Statements, username string, expira | ||||
| } | ||||
|  | ||||
| // Revoking hana user will deactivate user and try to perform a soft drop | ||||
| func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error { | ||||
| func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	// default revoke will be a soft drop on user | ||||
| 	if statements.RevocationStatements == "" { | ||||
| 		return h.revokeUserDefault(username) | ||||
| 		return h.revokeUserDefault(ctx, username) | ||||
| 	} | ||||
|  | ||||
| 	// Get connection | ||||
| 	db, err := h.getConnection() | ||||
| 	db, err := h.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -239,9 +242,9 @@ func (h *HANA) RevokeUser(statements dbplugin.Statements, username string) error | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (h *HANA) revokeUserDefault(username string) error { | ||||
| func (h *HANA) revokeUserDefault(ctx context.Context, username string) error { | ||||
| 	// Get connection | ||||
| 	db, err := h.getConnection() | ||||
| 	db, err := h.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package hana | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| @@ -25,7 +26,7 @@ func TestHANA_Initialize(t *testing.T) { | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*HANA) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -55,7 +56,7 @@ func TestHANA_CreateUser(t *testing.T) { | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*HANA) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -66,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test with no configured Creation Statememt | ||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour)) | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Hour)) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("Expected error when no creation statement is provided") | ||||
| 	} | ||||
| @@ -75,7 +76,7 @@ func TestHANA_CreateUser(t *testing.T) { | ||||
| 		CreationStatements: testHANARole, | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -98,7 +99,7 @@ func TestHANA_RevokeUser(t *testing.T) { | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*HANA) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -113,7 +114,7 @@ func TestHANA_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test default revoke statememts | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -121,7 +122,7 @@ func TestHANA_RevokeUser(t *testing.T) { | ||||
| 		t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -130,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test custom revoke statememt | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Hour)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -139,7 +140,7 @@ func TestHANA_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	statements.RevocationStatements = testHANADrop | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package mongodb | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| @@ -33,7 +34,7 @@ type mongoDBConnectionProducer struct { | ||||
| } | ||||
|  | ||||
| // Initialize parses connection configuration. | ||||
| func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { | ||||
| func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { | ||||
| 	c.Lock() | ||||
| 	defer c.Unlock() | ||||
|  | ||||
| @@ -75,7 +76,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri | ||||
| 	c.Initialized = true | ||||
|  | ||||
| 	if verifyConnection { | ||||
| 		if _, err := c.Connection(); err != nil { | ||||
| 		if _, err := c.Connection(ctx); err != nil { | ||||
| 			return fmt.Errorf("error verifying connection: %s", err) | ||||
| 		} | ||||
|  | ||||
| @@ -88,7 +89,7 @@ func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, veri | ||||
| } | ||||
|  | ||||
| // Connection creates a database connection. | ||||
| func (c *mongoDBConnectionProducer) Connection() (interface{}, error) { | ||||
| func (c *mongoDBConnectionProducer) Connection(_ context.Context) (interface{}, error) { | ||||
| 	if !c.Initialized { | ||||
| 		return nil, connutil.ErrNotInitialized | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package mongodb | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -27,6 +28,8 @@ type MongoDB struct { | ||||
| 	credsutil.CredentialsProducer | ||||
| } | ||||
|  | ||||
| var _ dbplugin.Database = &MongoDB{} | ||||
|  | ||||
| // New returns a new MongoDB instance | ||||
| func New() (interface{}, error) { | ||||
| 	connProducer := &mongoDBConnectionProducer{} | ||||
| @@ -63,8 +66,8 @@ func (m *MongoDB) Type() (string, error) { | ||||
| 	return mongoDBTypeName, nil | ||||
| } | ||||
|  | ||||
| func (m *MongoDB) getConnection() (*mgo.Session, error) { | ||||
| 	session, err := m.Connection() | ||||
| func (m *MongoDB) getConnection(ctx context.Context) (*mgo.Session, error) { | ||||
| 	session, err := m.Connection(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -80,7 +83,7 @@ func (m *MongoDB) getConnection() (*mgo.Session, error) { | ||||
| // | ||||
| // JSON Example: | ||||
| //  { "db": "admin", "roles": [{ "role": "readWrite" }, {"role": "read", "db": "foo"}] } | ||||
| func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	// Grab the lock | ||||
| 	m.Lock() | ||||
| 	defer m.Unlock() | ||||
| @@ -89,7 +92,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl | ||||
| 		return "", "", dbutil.ErrEmptyCreationStatement | ||||
| 	} | ||||
|  | ||||
| 	session, err := m.getConnection() | ||||
| 	session, err := m.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
| @@ -133,7 +136,7 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl | ||||
| 		if err := m.ConnectionProducer.Close(); err != nil { | ||||
| 			return "", "", errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) | ||||
| 		} | ||||
| 		session, err := m.getConnection() | ||||
| 		session, err := m.getConnection(ctx) | ||||
| 		if err != nil { | ||||
| 			return "", "", err | ||||
| 		} | ||||
| @@ -149,15 +152,15 @@ func (m *MongoDB) CreateUser(statements dbplugin.Statements, usernameConfig dbpl | ||||
| } | ||||
|  | ||||
| // RenewUser is not supported on MongoDB, so this is a no-op. | ||||
| func (m *MongoDB) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	// NOOP | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // RevokeUser drops the specified user from the authentication databse. If none is provided | ||||
| // in the revocation statement, the default "admin" authentication database will be assumed. | ||||
| func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) error { | ||||
| 	session, err := m.getConnection() | ||||
| func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	session, err := m.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -188,7 +191,7 @@ func (m *MongoDB) RevokeUser(statements dbplugin.Statements, username string) er | ||||
| 		if err := m.ConnectionProducer.Close(); err != nil { | ||||
| 			return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) | ||||
| 		} | ||||
| 		session, err := m.getConnection() | ||||
| 		session, err := m.getConnection(ctx) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package mongodb | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| @@ -79,7 +80,7 @@ func TestMongoDB_Initialize(t *testing.T) { | ||||
| 	db := dbRaw.(*MongoDB) | ||||
| 	connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer) | ||||
|  | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -107,7 +108,7 @@ func TestMongoDB_CreateUser(t *testing.T) { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| 	db := dbRaw.(*MongoDB) | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -121,7 +122,7 @@ func TestMongoDB_CreateUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -145,7 +146,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| 	db := dbRaw.(*MongoDB) | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -159,7 +160,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -182,7 +183,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| 	db := dbRaw.(*MongoDB) | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -196,7 +197,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -206,7 +207,7 @@ func TestMongoDB_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test default revocation statememt | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package mssql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| @@ -18,6 +19,8 @@ import ( | ||||
|  | ||||
| const msSQLTypeName = "mssql" | ||||
|  | ||||
| var _ dbplugin.Database = &MSSQL{} | ||||
|  | ||||
| // MSSQL is an implementation of Database interface | ||||
| type MSSQL struct { | ||||
| 	connutil.ConnectionProducer | ||||
| @@ -60,8 +63,8 @@ func (m *MSSQL) Type() (string, error) { | ||||
| 	return msSQLTypeName, nil | ||||
| } | ||||
|  | ||||
| func (m *MSSQL) getConnection() (*sql.DB, error) { | ||||
| 	db, err := m.Connection() | ||||
| func (m *MSSQL) getConnection(ctx context.Context) (*sql.DB, error) { | ||||
| 	db, err := m.Connection(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -71,13 +74,13 @@ func (m *MSSQL) getConnection() (*sql.DB, error) { | ||||
|  | ||||
| // CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by | ||||
| // the CreationStatement provided. | ||||
| func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	// Grab the lock | ||||
| 	m.Lock() | ||||
| 	defer m.Unlock() | ||||
|  | ||||
| 	// Get the connection | ||||
| 	db, err := m.getConnection() | ||||
| 	db, err := m.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
| @@ -138,7 +141,7 @@ func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug | ||||
| } | ||||
|  | ||||
| // RenewUser is not supported on MSSQL, so this is a no-op. | ||||
| func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	// NOOP | ||||
| 	return nil | ||||
| } | ||||
| @@ -146,13 +149,13 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir | ||||
| // RevokeUser attempts to drop the specified user. It will first attempt to disable login, | ||||
| // then kill pending connections from that user, and finally drop the user and login from the | ||||
| // database instance. | ||||
| func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { | ||||
| func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	if statements.RevocationStatements == "" { | ||||
| 		return m.revokeUserDefault(username) | ||||
| 		return m.revokeUserDefault(ctx, username) | ||||
| 	} | ||||
|  | ||||
| 	// Get connection | ||||
| 	db, err := m.getConnection() | ||||
| 	db, err := m.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -191,9 +194,9 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (m *MSSQL) revokeUserDefault(username string) error { | ||||
| func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error { | ||||
| 	// Get connection | ||||
| 	db, err := m.getConnection() | ||||
| 	db, err := m.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package mssql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| @@ -30,7 +31,7 @@ func TestMSSQL_Initialize(t *testing.T) { | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*MSSQL) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -51,7 +52,7 @@ func TestMSSQL_Initialize(t *testing.T) { | ||||
| 		"max_open_connections": "5", | ||||
| 	} | ||||
|  | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -69,7 +70,7 @@ func TestMSSQL_CreateUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*MSSQL) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -80,7 +81,7 @@ func TestMSSQL_CreateUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test with no configured Creation Statememt | ||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("Expected error when no creation statement is provided") | ||||
| 	} | ||||
| @@ -89,7 +90,7 @@ func TestMSSQL_CreateUser(t *testing.T) { | ||||
| 		CreationStatements: testMSSQLRole, | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -111,7 +112,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*MSSQL) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -125,7 +126,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -135,7 +136,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test default revoke statememts | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -144,7 +145,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { | ||||
| 		t.Fatal("Credentials were not revoked") | ||||
| 	} | ||||
|  | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -155,7 +156,7 @@ func TestMSSQL_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	// Test custom revoke statememt | ||||
| 	statements.RevocationStatements = testMSSQLDrop | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package mysql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -30,6 +31,8 @@ var ( | ||||
| 	LegacyUsernameLen int = 16 | ||||
| ) | ||||
|  | ||||
| var _ dbplugin.Database = &MySQL{} | ||||
|  | ||||
| type MySQL struct { | ||||
| 	connutil.ConnectionProducer | ||||
| 	credsutil.CredentialsProducer | ||||
| @@ -88,8 +91,8 @@ func (m *MySQL) Type() (string, error) { | ||||
| 	return mySQLTypeName, nil | ||||
| } | ||||
|  | ||||
| func (m *MySQL) getConnection() (*sql.DB, error) { | ||||
| 	db, err := m.Connection() | ||||
| func (m *MySQL) getConnection(ctx context.Context) (*sql.DB, error) { | ||||
| 	db, err := m.Connection(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -97,13 +100,13 @@ func (m *MySQL) getConnection() (*sql.DB, error) { | ||||
| 	return db.(*sql.DB), nil | ||||
| } | ||||
|  | ||||
| func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	// Grab the lock | ||||
| 	m.Lock() | ||||
| 	defer m.Unlock() | ||||
|  | ||||
| 	// Get the connection | ||||
| 	db, err := m.getConnection() | ||||
| 	db, err := m.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
| @@ -128,7 +131,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug | ||||
| 	} | ||||
|  | ||||
| 	// Start a transaction | ||||
| 	tx, err := db.Begin() | ||||
| 	tx, err := db.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
| @@ -146,7 +149,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug | ||||
| 			"expiration": expirationStr, | ||||
| 		}) | ||||
|  | ||||
| 		stmt, err := tx.Prepare(query) | ||||
| 		stmt, err := tx.PrepareContext(ctx, query) | ||||
| 		if err != nil { | ||||
| 			// If the error code we get back is Error 1295: This command is not | ||||
| 			// supported in the prepared statement protocol yet, we will execute | ||||
| @@ -155,7 +158,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug | ||||
| 			// prepare supported commands. If there is no error when running we | ||||
| 			// will continue to the next statement. | ||||
| 			if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 { | ||||
| 				_, err = tx.Exec(query) | ||||
| 				_, err = tx.ExecContext(ctx, query) | ||||
| 				if err != nil { | ||||
| 					return "", "", err | ||||
| 				} | ||||
| @@ -165,7 +168,7 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug | ||||
| 			return "", "", err | ||||
| 		} | ||||
| 		defer stmt.Close() | ||||
| 		if _, err := stmt.Exec(); err != nil { | ||||
| 		if _, err := stmt.ExecContext(ctx); err != nil { | ||||
| 			return "", "", err | ||||
| 		} | ||||
| 	} | ||||
| @@ -179,17 +182,17 @@ func (m *MySQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplug | ||||
| } | ||||
|  | ||||
| // NOOP | ||||
| func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| func (m *MySQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error { | ||||
| func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	// Grab the read lock | ||||
| 	m.Lock() | ||||
| 	defer m.Unlock() | ||||
|  | ||||
| 	// Get the connection | ||||
| 	db, err := m.getConnection() | ||||
| 	db, err := m.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -201,7 +204,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro | ||||
| 	} | ||||
|  | ||||
| 	// Start a transaction | ||||
| 	tx, err := db.Begin() | ||||
| 	tx, err := db.BeginTx(ctx, nil) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -217,7 +220,7 @@ func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) erro | ||||
| 		// 1295: This command is not supported in the prepared statement protocol yet | ||||
| 		// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ | ||||
| 		query = strings.Replace(query, "{{name}}", username, -1) | ||||
| 		_, err = tx.Exec(query) | ||||
| 		_, err = tx.ExecContext(ctx, query) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package mysql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| @@ -108,7 +109,7 @@ func TestMySQL_Initialize(t *testing.T) { | ||||
| 	db := dbRaw.(*MySQL) | ||||
| 	connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -128,7 +129,7 @@ func TestMySQL_Initialize(t *testing.T) { | ||||
| 		"max_open_connections": "5", | ||||
| 	} | ||||
|  | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -146,7 +147,7 @@ func TestMySQL_CreateUser(t *testing.T) { | ||||
| 	dbRaw, _ := f() | ||||
| 	db := dbRaw.(*MySQL) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -157,7 +158,7 @@ func TestMySQL_CreateUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test with no configured Creation Statememt | ||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("Expected error when no creation statement is provided") | ||||
| 	} | ||||
| @@ -166,7 +167,7 @@ func TestMySQL_CreateUser(t *testing.T) { | ||||
| 		CreationStatements: testMySQLRoleWildCard, | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -176,7 +177,7 @@ func TestMySQL_CreateUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test a second time to make sure usernames don't collide | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -188,7 +189,7 @@ func TestMySQL_CreateUser(t *testing.T) { | ||||
| 	// Test with a manualy prepare statement | ||||
| 	statements.CreationStatements = testMySQLRolePreparedStmt | ||||
|  | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -211,7 +212,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { | ||||
| 	dbRaw, _ := f() | ||||
| 	db := dbRaw.(*MySQL) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -222,7 +223,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test with no configured Creation Statememt | ||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("Expected error when no creation statement is provided") | ||||
| 	} | ||||
| @@ -231,7 +232,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { | ||||
| 		CreationStatements: testMySQLRoleWildCard, | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -241,7 +242,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test a second time to make sure usernames don't collide | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -263,7 +264,7 @@ func TestMySQL_RevokeUser(t *testing.T) { | ||||
| 	dbRaw, _ := f() | ||||
| 	db := dbRaw.(*MySQL) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -277,7 +278,7 @@ func TestMySQL_RevokeUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -287,7 +288,7 @@ func TestMySQL_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test default revoke statememts | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -297,7 +298,7 @@ func TestMySQL_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	statements.CreationStatements = testMySQLRoleWildCard | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -308,7 +309,7 @@ func TestMySQL_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	// Test custom revoke statements | ||||
| 	statements.RevocationStatements = testMySQLRevocationSQL | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package postgresql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| @@ -24,6 +25,8 @@ ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; | ||||
| ` | ||||
| ) | ||||
|  | ||||
| var _ dbplugin.Database = &PostgreSQL{} | ||||
|  | ||||
| // New implements builtinplugins.BuiltinFactory | ||||
| func New() (interface{}, error) { | ||||
| 	connProducer := &connutil.SQLConnectionProducer{} | ||||
| @@ -65,8 +68,8 @@ func (p *PostgreSQL) Type() (string, error) { | ||||
| 	return postgreSQLTypeName, nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) getConnection() (*sql.DB, error) { | ||||
| 	db, err := p.Connection() | ||||
| func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { | ||||
| 	db, err := p.Connection(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -74,7 +77,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { | ||||
| 	return db.(*sql.DB), nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | ||||
| 	if statements.CreationStatements == "" { | ||||
| 		return "", "", dbutil.ErrEmptyCreationStatement | ||||
| 	} | ||||
| @@ -99,7 +102,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d | ||||
| 	} | ||||
|  | ||||
| 	// Get the connection | ||||
| 	db, err := p.getConnection() | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
|  | ||||
| @@ -148,7 +151,7 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernameConfig d | ||||
| 	return username, password, nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| @@ -157,7 +160,7 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, | ||||
| 		renewStmts = defaultPostgresRenewSQL | ||||
| 	} | ||||
|  | ||||
| 	db, err := p.getConnection() | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -201,20 +204,20 @@ func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error { | ||||
| func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | ||||
| 	// Grab the lock | ||||
| 	p.Lock() | ||||
| 	defer p.Unlock() | ||||
|  | ||||
| 	if statements.RevocationStatements == "" { | ||||
| 		return p.defaultRevokeUser(username) | ||||
| 		return p.defaultRevokeUser(ctx, username) | ||||
| 	} | ||||
|  | ||||
| 	return p.customRevokeUser(username, statements.RevocationStatements) | ||||
| 	return p.customRevokeUser(ctx, username, statements.RevocationStatements) | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { | ||||
| 	db, err := p.getConnection() | ||||
| func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error { | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -253,8 +256,8 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (p *PostgreSQL) defaultRevokeUser(username string) error { | ||||
| 	db, err := p.getConnection() | ||||
| func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { | ||||
| 	db, err := p.getConnection(ctx) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package postgresql | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| @@ -72,7 +73,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { | ||||
|  | ||||
| 	connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) | ||||
|  | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -92,7 +93,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { | ||||
| 		"max_open_connections": "5", | ||||
| 	} | ||||
|  | ||||
| 	err = db.Initialize(connectionDetails, true) | ||||
| 	err = db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -109,7 +110,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*PostgreSQL) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -120,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test with no configured Creation Statememt | ||||
| 	_, _, err = db.CreateUser(dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("Expected error when no creation statement is provided") | ||||
| 	} | ||||
| @@ -129,7 +130,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 		CreationStatements: testPostgresRole, | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -139,7 +140,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	statements.CreationStatements = testPostgresReadOnlyRole | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -162,7 +163,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*PostgreSQL) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -176,7 +177,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -185,7 +186,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| 		t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) | ||||
| 	err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -197,7 +198,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| 		t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 	} | ||||
| 	statements.RenewStatements = defaultPostgresRenewSQL | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -206,7 +207,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { | ||||
| 		t.Fatalf("Could not connect with new credentials: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) | ||||
| 	err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -230,7 +231,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	dbRaw, _ := New() | ||||
| 	db := dbRaw.(*PostgreSQL) | ||||
| 	err := db.Initialize(connectionDetails, true) | ||||
| 	err := db.Initialize(context.Background(), connectionDetails, true) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -244,7 +245,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
| 		RoleName:    "test", | ||||
| 	} | ||||
|  | ||||
| 	username, password, err := db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -254,7 +255,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// Test default revoke statememts | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -263,7 +264,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
| 		t.Fatal("Credentials were not revoked") | ||||
| 	} | ||||
|  | ||||
| 	username, password, err = db.CreateUser(statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| @@ -274,7 +275,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { | ||||
|  | ||||
| 	// Test custom revoke statements | ||||
| 	statements.RevocationStatements = defaultPostgresRevocationSQL | ||||
| 	err = db.RevokeUser(statements, username) | ||||
| 	err = db.RevokeUser(context.Background(), statements, username) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package connutil | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"sync" | ||||
| ) | ||||
| @@ -14,8 +15,8 @@ var ( | ||||
| // connections and is used in all the builtin database types. | ||||
| type ConnectionProducer interface { | ||||
| 	Close() error | ||||
| 	Initialize(map[string]interface{}, bool) error | ||||
| 	Connection() (interface{}, error) | ||||
| 	Initialize(context.Context, map[string]interface{}, bool) error | ||||
| 	Connection(context.Context) (interface{}, error) | ||||
|  | ||||
| 	sync.Locker | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package connutil | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| @@ -25,7 +26,7 @@ type SQLConnectionProducer struct { | ||||
| 	sync.Mutex | ||||
| } | ||||
|  | ||||
| func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { | ||||
| func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { | ||||
| 	c.Lock() | ||||
| 	defer c.Unlock() | ||||
|  | ||||
| @@ -62,11 +63,11 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo | ||||
| 	c.Initialized = true | ||||
|  | ||||
| 	if verifyConnection { | ||||
| 		if _, err := c.Connection(); err != nil { | ||||
| 		if _, err := c.Connection(ctx); err != nil { | ||||
| 			return fmt.Errorf("error verifying connection: %s", err) | ||||
| 		} | ||||
|  | ||||
| 		if err := c.db.Ping(); err != nil { | ||||
| 		if err := c.db.PingContext(ctx); err != nil { | ||||
| 			return fmt.Errorf("error verifying connection: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| @@ -74,14 +75,14 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *SQLConnectionProducer) Connection() (interface{}, error) { | ||||
| func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) { | ||||
| 	if !c.Initialized { | ||||
| 		return nil, ErrNotInitialized | ||||
| 	} | ||||
|  | ||||
| 	// If we already have a DB, test it and return | ||||
| 	if c.db != nil { | ||||
| 		if err := c.db.Ping(); err == nil { | ||||
| 		if err := c.db.PingContext(ctx); err == nil { | ||||
| 			return c.db, nil | ||||
| 		} | ||||
| 		// If the ping was unsuccessful, close it and ignore errors as we'll be | ||||
|   | ||||
| @@ -1017,8 +1017,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error { | ||||
| 	} | ||||
|  | ||||
| 	// Handle standard revocation via backends | ||||
| 	resp, err := m.router.Route(logical.RevokeRequest( | ||||
| 		le.Path, le.Secret, le.Data)) | ||||
| 	resp, err := m.router.Route(logical.RevokeRequest(le.Path, le.Secret, le.Data)) | ||||
| 	if err != nil || (resp != nil && resp.IsError()) { | ||||
| 		return fmt.Errorf("failed to revoke entry: resp:%#v err:%s", resp, err) | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Brian Kassouf
					Brian Kassouf