mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	Update Type() to return an error
This commit is contained in:
		@@ -162,5 +162,5 @@ as secret backends, including but not limited to:
 | 
			
		||||
cassandra, msslq, mysql, postgres
 | 
			
		||||
 | 
			
		||||
After mounting this backend, configure it using the endpoints within
 | 
			
		||||
the "database/dbs/" path.
 | 
			
		||||
the "database/config/" path.
 | 
			
		||||
`
 | 
			
		||||
 
 | 
			
		||||
@@ -52,10 +52,11 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// We should have a Greeter now! This feels like a normal interface
 | 
			
		||||
	// We should have a database type now. This feels like a normal interface
 | 
			
		||||
	// implementation but is in fact over an RPC connection.
 | 
			
		||||
	databaseRPC := raw.(*databasePluginRPCClient)
 | 
			
		||||
 | 
			
		||||
	// Wrap RPC implimentation in DatabasePluginClient
 | 
			
		||||
	return &DatabasePluginClient{
 | 
			
		||||
		client:                  client,
 | 
			
		||||
		databasePluginRPCClient: databaseRPC,
 | 
			
		||||
@@ -70,12 +71,11 @@ type databasePluginRPCClient struct {
 | 
			
		||||
	client *rpc.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dr *databasePluginRPCClient) Type() string {
 | 
			
		||||
func (dr *databasePluginRPCClient) Type() (string, error) {
 | 
			
		||||
	var dbType string
 | 
			
		||||
	//TODO: catch error
 | 
			
		||||
	dr.client.Call("Plugin.Type", struct{}{}, &dbType)
 | 
			
		||||
	err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("plugin-%s", dbType)
 | 
			
		||||
	return fmt.Sprintf("plugin-%s", dbType), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
 | 
			
		||||
 
 | 
			
		||||
@@ -18,7 +18,7 @@ type databaseTracingMiddleware struct {
 | 
			
		||||
	typeStr string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mw *databaseTracingMiddleware) Type() string {
 | 
			
		||||
func (mw *databaseTracingMiddleware) Type() (string, error) {
 | 
			
		||||
	return mw.next.Type()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -87,7 +87,7 @@ type databaseMetricsMiddleware struct {
 | 
			
		||||
	typeStr string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mw *databaseMetricsMiddleware) Type() string {
 | 
			
		||||
func (mw *databaseMetricsMiddleware) Type() (string, error) {
 | 
			
		||||
	return mw.next.Type()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package dbplugin
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/rpc"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
@@ -16,7 +17,7 @@ var (
 | 
			
		||||
 | 
			
		||||
// DatabaseType is the interface that all database objects must implement.
 | 
			
		||||
type DatabaseType interface {
 | 
			
		||||
	Type() string
 | 
			
		||||
	Type() (string, error)
 | 
			
		||||
	CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error)
 | 
			
		||||
	RenewUser(statements Statements, username string, expiration time.Time) error
 | 
			
		||||
	RevokeUser(statements Statements, username string) error
 | 
			
		||||
@@ -52,16 +53,21 @@ func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Log
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	typeStr, err := db.Type()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error getting plugin type: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Wrap with metrics middleware
 | 
			
		||||
	db = &databaseMetricsMiddleware{
 | 
			
		||||
		next:    db,
 | 
			
		||||
		typeStr: db.Type(),
 | 
			
		||||
		typeStr: typeStr,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Wrap with tracing middleware
 | 
			
		||||
	db = &databaseTracingMiddleware{
 | 
			
		||||
		next:    db,
 | 
			
		||||
		typeStr: db.Type(),
 | 
			
		||||
		typeStr: typeStr,
 | 
			
		||||
		logger:  logger,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,7 @@ type mockPlugin struct {
 | 
			
		||||
	users map[string][]string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mockPlugin) Type() string { return "mock" }
 | 
			
		||||
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
 | 
			
		||||
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
 | 
			
		||||
	err = errors.New("err")
 | 
			
		||||
	if usernamePrefix == "" || expiration.IsZero() {
 | 
			
		||||
@@ -59,7 +59,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
 | 
			
		||||
	delete(m.users, username)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
func (m *mockPlugin) Initialize(conf map[string]interface{}) error {
 | 
			
		||||
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
 | 
			
		||||
	err := errors.New("err")
 | 
			
		||||
	if len(conf) != 1 {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -108,7 +108,7 @@ func TestPlugin_Initialize(t *testing.T) {
 | 
			
		||||
		"test": 1,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = dbRaw.Initialize(connectionDetails)
 | 
			
		||||
	err = dbRaw.Initialize(connectionDetails, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("err: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -133,7 +133,7 @@ func TestPlugin_CreateUser(t *testing.T) {
 | 
			
		||||
		"test": 1,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = db.Initialize(connectionDetails)
 | 
			
		||||
	err = db.Initialize(connectionDetails, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("err: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -167,7 +167,7 @@ func TestPlugin_RenewUser(t *testing.T) {
 | 
			
		||||
	connectionDetails := map[string]interface{}{
 | 
			
		||||
		"test": 1,
 | 
			
		||||
	}
 | 
			
		||||
	err = db.Initialize(connectionDetails)
 | 
			
		||||
	err = db.Initialize(connectionDetails, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("err: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -196,7 +196,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
 | 
			
		||||
	connectionDetails := map[string]interface{}{
 | 
			
		||||
		"test": 1,
 | 
			
		||||
	}
 | 
			
		||||
	err = db.Initialize(connectionDetails)
 | 
			
		||||
	err = db.Initialize(connectionDetails, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("err: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -42,8 +42,9 @@ type databasePluginRPCServer struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
 | 
			
		||||
	*resp = ds.impl.Type()
 | 
			
		||||
	return nil
 | 
			
		||||
	var err error
 | 
			
		||||
	*resp, err = ds.impl.Type()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user