From 03e2bcbc7902128a1e5a0181fe2d5beb1ba5fe8f Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 16:41:06 -0700 Subject: [PATCH] Update Type() to return an error --- builtin/logical/database/backend.go | 2 +- builtin/logical/database/dbplugin/client.go | 10 +++++----- .../logical/database/dbplugin/databasemiddleware.go | 4 ++-- builtin/logical/database/dbplugin/plugin.go | 12 +++++++++--- builtin/logical/database/dbplugin/plugin_test.go | 12 ++++++------ builtin/logical/database/dbplugin/server.go | 5 +++-- 6 files changed, 26 insertions(+), 19 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 618ffac6f8..c8f9ad8541 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -162,5 +162,5 @@ as secret backends, including but not limited to: cassandra, msslq, mysql, postgres After mounting this backend, configure it using the endpoints within -the "database/dbs/" path. +the "database/config/" path. ` diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 5bdc3a01a0..93db86595a 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -52,10 +52,11 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn return nil, err } - // We should have a Greeter now! This feels like a normal interface + // We should have a database type now. This feels like a normal interface // implementation but is in fact over an RPC connection. databaseRPC := raw.(*databasePluginRPCClient) + // Wrap RPC implimentation in DatabasePluginClient return &DatabasePluginClient{ client: client, databasePluginRPCClient: databaseRPC, @@ -70,12 +71,11 @@ type databasePluginRPCClient struct { client *rpc.Client } -func (dr *databasePluginRPCClient) Type() string { +func (dr *databasePluginRPCClient) Type() (string, error) { var dbType string - //TODO: catch error - dr.client.Call("Plugin.Type", struct{}{}, &dbType) + err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) - return fmt.Sprintf("plugin-%s", dbType) + return fmt.Sprintf("plugin-%s", dbType), err } func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 2137cd9c38..e28a8741e4 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -18,7 +18,7 @@ type databaseTracingMiddleware struct { typeStr string } -func (mw *databaseTracingMiddleware) Type() string { +func (mw *databaseTracingMiddleware) Type() (string, error) { return mw.next.Type() } @@ -87,7 +87,7 @@ type databaseMetricsMiddleware struct { typeStr string } -func (mw *databaseMetricsMiddleware) Type() string { +func (mw *databaseMetricsMiddleware) Type() (string, error) { return mw.next.Type() } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index dadb6639ee..5e6ce939be 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -2,6 +2,7 @@ package dbplugin import ( "errors" + "fmt" "net/rpc" "time" @@ -16,7 +17,7 @@ var ( // DatabaseType is the interface that all database objects must implement. type DatabaseType interface { - Type() string + Type() (string, error) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) RenewUser(statements Statements, username string, expiration time.Time) error RevokeUser(statements Statements, username string) error @@ -52,16 +53,21 @@ func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Log return nil, err } + typeStr, err := db.Type() + if err != nil { + return nil, fmt.Errorf("error getting plugin type: %s", err) + } + // Wrap with metrics middleware db = &databaseMetricsMiddleware{ next: db, - typeStr: db.Type(), + typeStr: typeStr, } // Wrap with tracing middleware db = &databaseTracingMiddleware{ next: db, - typeStr: db.Type(), + typeStr: typeStr, logger: logger, } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 7909bbd4e5..1587ba24a5 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -19,7 +19,7 @@ type mockPlugin struct { users map[string][]string } -func (m *mockPlugin) Type() string { return "mock" } +func (m *mockPlugin) Type() (string, error) { return "mock", nil } func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { err = errors.New("err") if usernamePrefix == "" || expiration.IsZero() { @@ -59,7 +59,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) delete(m.users, username) return nil } -func (m *mockPlugin) Initialize(conf map[string]interface{}) error { +func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error { err := errors.New("err") if len(conf) != 1 { return err @@ -108,7 +108,7 @@ func TestPlugin_Initialize(t *testing.T) { "test": 1, } - err = dbRaw.Initialize(connectionDetails) + err = dbRaw.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -133,7 +133,7 @@ func TestPlugin_CreateUser(t *testing.T) { "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -167,7 +167,7 @@ func TestPlugin_RenewUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -196,7 +196,7 @@ func TestPlugin_RevokeUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 326e25103c..3a3e233946 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -42,8 +42,9 @@ type databasePluginRPCServer struct { } func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - *resp = ds.impl.Type() - return nil + var err error + *resp, err = ds.impl.Type() + return err } func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {