Update Type() to return an error

This commit is contained in:
Brian Kassouf
2017-04-12 16:41:06 -07:00
parent f2401c0128
commit 03e2bcbc79
6 changed files with 26 additions and 19 deletions

View File

@@ -162,5 +162,5 @@ as secret backends, including but not limited to:
cassandra, msslq, mysql, postgres
After mounting this backend, configure it using the endpoints within
the "database/dbs/" path.
the "database/config/" path.
`

View File

@@ -52,10 +52,11 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn
return nil, err
}
// We should have a Greeter now! This feels like a normal interface
// We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
databaseRPC := raw.(*databasePluginRPCClient)
// Wrap RPC implimentation in DatabasePluginClient
return &DatabasePluginClient{
client: client,
databasePluginRPCClient: databaseRPC,
@@ -70,12 +71,11 @@ type databasePluginRPCClient struct {
client *rpc.Client
}
func (dr *databasePluginRPCClient) Type() string {
func (dr *databasePluginRPCClient) Type() (string, error) {
var dbType string
//TODO: catch error
dr.client.Call("Plugin.Type", struct{}{}, &dbType)
err := dr.client.Call("Plugin.Type", struct{}{}, &dbType)
return fmt.Sprintf("plugin-%s", dbType)
return fmt.Sprintf("plugin-%s", dbType), err
}
func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {

View File

@@ -18,7 +18,7 @@ type databaseTracingMiddleware struct {
typeStr string
}
func (mw *databaseTracingMiddleware) Type() string {
func (mw *databaseTracingMiddleware) Type() (string, error) {
return mw.next.Type()
}
@@ -87,7 +87,7 @@ type databaseMetricsMiddleware struct {
typeStr string
}
func (mw *databaseMetricsMiddleware) Type() string {
func (mw *databaseMetricsMiddleware) Type() (string, error) {
return mw.next.Type()
}

View File

@@ -2,6 +2,7 @@ package dbplugin
import (
"errors"
"fmt"
"net/rpc"
"time"
@@ -16,7 +17,7 @@ var (
// DatabaseType is the interface that all database objects must implement.
type DatabaseType interface {
Type() string
Type() (string, error)
CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error)
RenewUser(statements Statements, username string, expiration time.Time) error
RevokeUser(statements Statements, username string) error
@@ -52,16 +53,21 @@ func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Log
return nil, err
}
typeStr, err := db.Type()
if err != nil {
return nil, fmt.Errorf("error getting plugin type: %s", err)
}
// Wrap with metrics middleware
db = &databaseMetricsMiddleware{
next: db,
typeStr: db.Type(),
typeStr: typeStr,
}
// Wrap with tracing middleware
db = &databaseTracingMiddleware{
next: db,
typeStr: db.Type(),
typeStr: typeStr,
logger: logger,
}

View File

@@ -19,7 +19,7 @@ type mockPlugin struct {
users map[string][]string
}
func (m *mockPlugin) Type() string { return "mock" }
func (m *mockPlugin) Type() (string, error) { return "mock", nil }
func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
err = errors.New("err")
if usernamePrefix == "" || expiration.IsZero() {
@@ -59,7 +59,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string)
delete(m.users, username)
return nil
}
func (m *mockPlugin) Initialize(conf map[string]interface{}) error {
func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error {
err := errors.New("err")
if len(conf) != 1 {
return err
@@ -108,7 +108,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1,
}
err = dbRaw.Initialize(connectionDetails)
err = dbRaw.Initialize(connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -133,7 +133,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1,
}
err = db.Initialize(connectionDetails)
err = db.Initialize(connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -167,7 +167,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(connectionDetails)
err = db.Initialize(connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -196,7 +196,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(connectionDetails)
err = db.Initialize(connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@@ -42,8 +42,9 @@ type databasePluginRPCServer struct {
}
func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error {
*resp = ds.impl.Type()
return nil
var err error
*resp, err = ds.impl.Type()
return err
}
func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error {