Migrate database plugin methods to sdk

This commit is contained in:
Jeff Mitchell
2019-04-15 11:36:10 -04:00
parent 901060d479
commit 05bcacee74
19 changed files with 2587 additions and 1 deletions

89
sdk/dbplugin/client.go Normal file
View File

@@ -0,0 +1,89 @@
package dbplugin
import (
"context"
"errors"
"sync"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
)
// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close
// method to also call Kill() on the plugin.Client.
type DatabasePluginClient struct {
client *plugin.Client
sync.Mutex
Database
}
// This wraps the Close call and ensures we both close the database connection
// and kill the plugin.
func (dc *DatabasePluginClient) Close() error {
err := dc.Database.Close()
dc.client.Kill()
return err
}
// NewPluginClient returns a databaseRPCClient with a connection to a running
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
// plugin is killed on call of Close().
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (Database, error) {
// pluginSets is the map of plugins we can dispense.
pluginSets := map[int]plugin.PluginSet{
// Version 3 used to supports both protocols. We want to keep it around
// since it's possible old plugins built against this version will still
// work with gRPC. There is currently no difference between version 3
// and version 4.
3: plugin.PluginSet{
"database": new(GRPCDatabasePlugin),
},
// Version 4 only supports gRPC
4: plugin.PluginSet{
"database": new(GRPCDatabasePlugin),
},
}
var client *plugin.Client
var err error
if isMetadataMode {
client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSets, handshakeConfig, []string{}, logger)
} else {
client, err = pluginRunner.Run(ctx, sys, pluginSets, handshakeConfig, []string{}, logger)
}
if err != nil {
return nil, err
}
// Connect via RPC
rpcClient, err := client.Client()
if err != nil {
return nil, err
}
// Request the plugin
raw, err := rpcClient.Dispense("database")
if err != nil {
return nil, err
}
// We should have a database type now. This feels like a normal interface
// implementation but is in fact over an RPC connection.
var db Database
switch raw.(type) {
case *gRPCClient:
db = raw.(*gRPCClient)
default:
return nil, errors.New("unsupported client type")
}
// Wrap RPC implementation in DatabasePluginClient
return &DatabasePluginClient{
client: client,
Database: db,
}, nil
}

1062
sdk/dbplugin/database.pb.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,93 @@
syntax = "proto3";
option go_package = "github.com/hashicorp/vault/sdk/dbplugin";
package dbplugin;
import "google/protobuf/timestamp.proto";
message InitializeRequest {
option deprecated = true;
bytes config = 1;
bool verify_connection = 2;
}
message InitRequest {
bytes config = 1;
bool verify_connection = 2;
}
message CreateUserRequest {
Statements statements = 1;
UsernameConfig username_config = 2;
google.protobuf.Timestamp expiration = 3;
}
message RenewUserRequest {
Statements statements = 1;
string username = 2;
google.protobuf.Timestamp expiration = 3;
}
message RevokeUserRequest {
Statements statements = 1;
string username = 2;
}
message RotateRootCredentialsRequest {
repeated string statements = 1;
}
message Statements {
// DEPRECATED, will be removed in 0.12
string creation_statements = 1 [deprecated=true];
// DEPRECATED, will be removed in 0.12
string revocation_statements = 2 [deprecated=true];
// DEPRECATED, will be removed in 0.12
string rollback_statements = 3 [deprecated=true];
// DEPRECATED, will be removed in 0.12
string renew_statements = 4 [deprecated=true];
repeated string creation = 5;
repeated string revocation = 6;
repeated string rollback = 7;
repeated string renewal = 8;
}
message UsernameConfig {
string DisplayName = 1;
string RoleName = 2;
}
message InitResponse {
bytes config = 1;
}
message CreateUserResponse {
string username = 1;
string password = 2;
}
message TypeResponse {
string type = 1;
}
message RotateRootCredentialsResponse {
bytes config = 1;
}
message Empty {}
service Database {
rpc Type(Empty) returns (TypeResponse);
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
rpc RenewUser(RenewUserRequest) returns (Empty);
rpc RevokeUser(RevokeUserRequest) returns (Empty);
rpc RotateRootCredentials(RotateRootCredentialsRequest) returns (RotateRootCredentialsResponse);
rpc Init(InitRequest) returns (InitResponse);
rpc Close(Empty) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty) {
option deprecated = true;
};
}

View File

@@ -0,0 +1,275 @@
package dbplugin
import (
"context"
"errors"
"net/url"
"strings"
"sync"
"time"
"github.com/hashicorp/errwrap"
metrics "github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
)
// ---- Tracing Middleware Domain ----
// databaseTracingMiddleware wraps a implementation of Database and executes
// trace logging on function call.
type databaseTracingMiddleware struct {
next Database
logger log.Logger
}
func (mw *databaseTracingMiddleware) Type() (string, error) {
return mw.next.Type()
}
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(then time.Time) {
mw.logger.Trace("create user", "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("create user", "status", "started")
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
}
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(then time.Time) {
mw.logger.Trace("renew user", "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("renew user", "status", "started")
return mw.next.RenewUser(ctx, statements, username, expiration)
}
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(then time.Time) {
mw.logger.Trace("revoke user", "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("revoke user", "status", "started")
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseTracingMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
defer func(then time.Time) {
mw.logger.Trace("rotate root credentials", "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("rotate root credentials", "status", "started")
return mw.next.RotateRootCredentials(ctx, statements)
}
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *databaseTracingMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
defer func(then time.Time) {
mw.logger.Trace("initialize", "status", "finished", "verify", verifyConnection, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("initialize", "status", "started")
return mw.next.Init(ctx, conf, verifyConnection)
}
func (mw *databaseTracingMiddleware) Close() (err error) {
defer func(then time.Time) {
mw.logger.Trace("close", "status", "finished", "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("close", "status", "started")
return mw.next.Close()
}
// ---- Metrics Middleware Domain ----
// databaseMetricsMiddleware wraps an implementation of Databases and on
// function call logs metrics about this instance.
type databaseMetricsMiddleware struct {
next Database
typeStr string
}
func (mw *databaseMetricsMiddleware) Type() (string, error) {
return mw.next.Type()
}
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "CreateUser", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
}
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "RenewUser", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
return mw.next.RenewUser(ctx, statements, username, expiration)
}
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "RevokeUser", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseMetricsMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RotateRootCredentials"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RotateRootCredentials"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "RotateRootCredentials", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "RotateRootCredentials"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials"}, 1)
return mw.next.RotateRootCredentials(ctx, statements)
}
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *databaseMetricsMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Initialize"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
return mw.next.Init(ctx, conf, verifyConnection)
}
func (mw *databaseMetricsMiddleware) Close() (err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Close"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "Close", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "Close"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
return mw.next.Close()
}
// ---- Error Sanitizer Middleware Domain ----
// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
// sanitizes returned error messages
type DatabaseErrorSanitizerMiddleware struct {
l sync.RWMutex
next Database
secretsFn func() map[string]interface{}
}
func NewDatabaseErrorSanitizerMiddleware(next Database, secretsFn func() map[string]interface{}) *DatabaseErrorSanitizerMiddleware {
return &DatabaseErrorSanitizerMiddleware{
next: next,
secretsFn: secretsFn,
}
}
func (mw *DatabaseErrorSanitizerMiddleware) Type() (string, error) {
dbType, err := mw.next.Type()
return dbType, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
username, password, err = mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
return username, password, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
return mw.sanitize(mw.next.RenewUser(ctx, statements, username, expiration))
}
func (mw *DatabaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
return mw.sanitize(mw.next.RevokeUser(ctx, statements, username))
}
func (mw *DatabaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
conf, err = mw.next.RotateRootCredentials(ctx, statements)
return conf, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *DatabaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
saveConf, err = mw.next.Init(ctx, conf, verifyConnection)
return saveConf, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) Close() (err error) {
return mw.sanitize(mw.next.Close())
}
// sanitize
func (mw *DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
if err == nil {
return nil
}
if errwrap.ContainsType(err, new(url.Error)) {
return errors.New("unable to parse connection url")
}
if mw.secretsFn != nil {
for k, v := range mw.secretsFn() {
if k == "" {
continue
}
err = errors.New(strings.Replace(err.Error(), k, v.(string), -1))
}
}
return err
}

View File

@@ -0,0 +1,285 @@
package dbplugin
import (
"context"
"encoding/json"
"errors"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
)
var (
ErrPluginShutdown = errors.New("plugin shutdown")
)
// ---- gRPC Server domain ----
type gRPCServer struct {
impl Database
}
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
t, err := s.impl.Type()
if err != nil {
return nil, err
}
return &TypeResponse{
Type: t,
}, nil
}
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
return &CreateUserResponse{
Username: u,
Password: p,
}, err
}
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
e, err := ptypes.Timestamp(req.Expiration)
if err != nil {
return nil, err
}
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
return &Empty{}, err
}
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
return &Empty{}, err
}
func (s *gRPCServer) RotateRootCredentials(ctx context.Context, req *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error) {
resp, err := s.impl.RotateRootCredentials(ctx, req.Statements)
if err != nil {
return nil, err
}
respConfig, err := json.Marshal(resp)
if err != nil {
return nil, err
}
return &RotateRootCredentialsResponse{
Config: respConfig,
}, err
}
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
_, err := s.Init(ctx, &InitRequest{
Config: req.Config,
VerifyConnection: req.VerifyConnection,
})
return &Empty{}, err
}
func (s *gRPCServer) Init(ctx context.Context, req *InitRequest) (*InitResponse, error) {
config := map[string]interface{}{}
err := json.Unmarshal(req.Config, &config)
if err != nil {
return nil, err
}
resp, err := s.impl.Init(ctx, config, req.VerifyConnection)
if err != nil {
return nil, err
}
respConfig, err := json.Marshal(resp)
if err != nil {
return nil, err
}
return &InitResponse{
Config: respConfig,
}, err
}
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
s.impl.Close()
return &Empty{}, nil
}
// ---- gRPC client domain ----
type gRPCClient struct {
client DatabaseClient
clientConn *grpc.ClientConn
doneCtx context.Context
}
func (c *gRPCClient) Type() (string, error) {
resp, err := c.client.Type(c.doneCtx, &Empty{})
if err != nil {
return "", err
}
return resp.Type, err
}
func (c *gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return "", "", err
}
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
defer close(quitCh)
defer cancel()
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
Statements: &statements,
UsernameConfig: &usernameConfig,
Expiration: t,
})
if err != nil {
if c.doneCtx.Err() != nil {
return "", "", ErrPluginShutdown
}
return "", "", err
}
return resp.Username, resp.Password, err
}
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return err
}
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
defer close(quitCh)
defer cancel()
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
Statements: &statements,
Username: username,
Expiration: t,
})
if err != nil {
if c.doneCtx.Err() != nil {
return ErrPluginShutdown
}
return err
}
return nil
}
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
defer close(quitCh)
defer cancel()
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
Statements: &statements,
Username: username,
})
if err != nil {
if c.doneCtx.Err() != nil {
return ErrPluginShutdown
}
return err
}
return nil
}
func (c *gRPCClient) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
defer close(quitCh)
defer cancel()
resp, err := c.client.RotateRootCredentials(ctx, &RotateRootCredentialsRequest{
Statements: statements,
})
if err != nil {
if c.doneCtx.Err() != nil {
return nil, ErrPluginShutdown
}
return nil, err
}
if err := json.Unmarshal(resp.Config, &conf); err != nil {
return nil, err
}
return conf, nil
}
func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
func (c *gRPCClient) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
configRaw, err := json.Marshal(conf)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
defer close(quitCh)
defer cancel()
resp, err := c.client.Init(ctx, &InitRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
if err != nil {
// Fall back to old call if not implemented
grpcStatus, ok := status.FromError(err)
if ok && grpcStatus.Code() == codes.Unimplemented {
_, err = c.client.Initialize(ctx, &InitializeRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
if err == nil {
return conf, nil
}
}
if c.doneCtx.Err() != nil {
return nil, ErrPluginShutdown
}
return nil, err
}
if err := json.Unmarshal(resp.Config, &conf); err != nil {
return nil, err
}
return conf, nil
}
func (c *gRPCClient) Close() error {
_, err := c.client.Close(c.doneCtx, &Empty{})
return err
}

View File

@@ -0,0 +1,25 @@
package connutil
import (
"context"
"errors"
"sync"
)
var (
ErrNotInitialized = errors.New("connection has not been initialized")
)
// ConnectionProducer can be used as an embedded interface in the Database
// definition. It implements the methods dealing with individual database
// connections and is used in all the builtin database types.
type ConnectionProducer interface {
Close() error
Init(context.Context, map[string]interface{}, bool) (map[string]interface{}, error)
Connection(context.Context) (interface{}, error)
sync.Locker
// DEPRECATED, will be removed in 0.12
Initialize(context.Context, map[string]interface{}, bool) error
}

View File

@@ -0,0 +1,164 @@
package connutil
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/sdk/dbplugin/helper/dbutil"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/mitchellh/mapstructure"
)
var _ ConnectionProducer = &SQLConnectionProducer{}
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type SQLConnectionProducer struct {
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
Username string `json:"username" mapstructure:"username" structs:"username"`
Password string `json:"password" mapstructure:"password" structs:"password"`
Type string
RawConfig map[string]interface{}
maxConnectionLifetime time.Duration
Initialized bool
db *sql.DB
sync.Mutex
}
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
c.Lock()
defer c.Unlock()
c.RawConfig = conf
err := mapstructure.WeakDecode(conf, &c)
if err != nil {
return nil, err
}
if len(c.ConnectionURL) == 0 {
return nil, fmt.Errorf("connection_url cannot be empty")
}
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
"username": c.Username,
"password": c.Password,
})
if c.MaxOpenConnections == 0 {
c.MaxOpenConnections = 2
}
if c.MaxIdleConnections == 0 {
c.MaxIdleConnections = c.MaxOpenConnections
}
if c.MaxIdleConnections > c.MaxOpenConnections {
c.MaxIdleConnections = c.MaxOpenConnections
}
if c.MaxConnectionLifetimeRaw == nil {
c.MaxConnectionLifetimeRaw = "0s"
}
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
if err != nil {
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
}
// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true
if verifyConnection {
if _, err := c.Connection(ctx); err != nil {
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
if err := c.db.PingContext(ctx); err != nil {
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
}
return c.RawConfig, nil
}
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
if !c.Initialized {
return nil, ErrNotInitialized
}
// If we already have a DB, test it and return
if c.db != nil {
if err := c.db.PingContext(ctx); err == nil {
return c.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
c.db.Close()
}
// For mssql backend, switch to sqlserver instead
dbType := c.Type
if c.Type == "mssql" {
dbType = "sqlserver"
}
// Otherwise, attempt to make connection
conn := c.ConnectionURL
// Ensure timezone is set to UTC for all the connections
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
if strings.Contains(conn, "?") {
conn += "&timezone=utc"
} else {
conn += "?timezone=utc"
}
}
var err error
c.db, err = sql.Open(dbType, conn)
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
c.db.SetMaxOpenConns(c.MaxOpenConnections)
c.db.SetMaxIdleConns(c.MaxIdleConnections)
c.db.SetConnMaxLifetime(c.maxConnectionLifetime)
return c.db, nil
}
func (c *SQLConnectionProducer) SecretValues() map[string]interface{} {
return map[string]interface{}{
c.Password: "[password]",
}
}
// Close attempts to close the connection
func (c *SQLConnectionProducer) Close() error {
// Grab the write lock
c.Lock()
defer c.Unlock()
if c.db != nil {
c.db.Close()
}
c.db = nil
return nil
}

View File

@@ -0,0 +1,46 @@
package credsutil
import (
"time"
"fmt"
"github.com/hashicorp/vault/sdk/dbplugin"
"github.com/hashicorp/vault/sdk/helper/base62"
)
// CredentialsProducer can be used as an embedded interface in the Database
// definition. It implements the methods for generating user information for a
// particular database type and is used in all the builtin database types.
type CredentialsProducer interface {
GenerateUsername(usernameConfig dbplugin.UsernameConfig) (string, error)
GeneratePassword() (string, error)
GenerateExpiration(ttl time.Time) (string, error)
}
const (
reqStr = `A1a-`
minStrLen = 10
)
// RandomAlphaNumeric returns a random string of characters [A-Za-z0-9-]
// of the provided length. The string generated takes up to 4 characters
// of space that are predefined and prepended to ensure password
// character requirements. It also requires a min length of 10 characters.
func RandomAlphaNumeric(length int, prependA1a bool) (string, error) {
if length < minStrLen {
return "", fmt.Errorf("minimum length of %d is required", minStrLen)
}
var prefix string
if prependA1a {
prefix = reqStr
}
randomStr, err := base62.Random(length - len(prefix))
if err != nil {
return "", err
}
return prefix + randomStr, nil
}

View File

@@ -0,0 +1,40 @@
package credsutil
import (
"strings"
"testing"
)
func TestRandomAlphaNumeric(t *testing.T) {
s, err := RandomAlphaNumeric(10, true)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if len(s) != 10 {
t.Fatalf("Unexpected length of string, expected 10, got string: %s", s)
}
s, err = RandomAlphaNumeric(20, true)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if len(s) != 20 {
t.Fatalf("Unexpected length of string, expected 20, got string: %s", s)
}
if !strings.Contains(s, reqStr) {
t.Fatalf("Expected %s to contain %s", s, reqStr)
}
s, err = RandomAlphaNumeric(20, false)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if len(s) != 20 {
t.Fatalf("Unexpected length of string, expected 20, got string: %s", s)
}
if strings.Contains(s, reqStr) {
t.Fatalf("Expected %s not to contain %s", s, reqStr)
}
}

View File

@@ -0,0 +1,72 @@
package credsutil
import (
"fmt"
"time"
"github.com/hashicorp/vault/sdk/dbplugin"
)
const (
NoneLength int = -1
)
// SQLCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types.
type SQLCredentialsProducer struct {
DisplayNameLen int
RoleNameLen int
UsernameLen int
Separator string
}
func (scp *SQLCredentialsProducer) GenerateUsername(config dbplugin.UsernameConfig) (string, error) {
username := "v"
displayName := config.DisplayName
if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen {
displayName = displayName[:scp.DisplayNameLen]
} else if scp.DisplayNameLen == NoneLength {
displayName = ""
}
if len(displayName) > 0 {
username = fmt.Sprintf("%s%s%s", username, scp.Separator, displayName)
}
roleName := config.RoleName
if scp.RoleNameLen > 0 && len(roleName) > scp.RoleNameLen {
roleName = roleName[:scp.RoleNameLen]
} else if scp.RoleNameLen == NoneLength {
roleName = ""
}
if len(roleName) > 0 {
username = fmt.Sprintf("%s%s%s", username, scp.Separator, roleName)
}
userUUID, err := RandomAlphaNumeric(20, false)
if err != nil {
return "", err
}
username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID)
username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().Unix()))
if scp.UsernameLen > 0 && len(username) > scp.UsernameLen {
username = username[:scp.UsernameLen]
}
return username, nil
}
func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) {
password, err := RandomAlphaNumeric(20, true)
if err != nil {
return "", err
}
return password, nil
}
func (scp *SQLCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
return ttl.Format("2006-01-02 15:04:05-0700"), nil
}

View File

@@ -0,0 +1,52 @@
package dbutil
import (
"errors"
"fmt"
"strings"
"github.com/hashicorp/vault/sdk/dbplugin"
)
var (
ErrEmptyCreationStatement = errors.New("empty creation statements")
)
// Query templates a query for us.
func QueryHelper(tpl string, data map[string]string) string {
for k, v := range data {
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
}
return tpl
}
// StatementCompatibilityHelper will populate the statements fields to support
// compatibility
func StatementCompatibilityHelper(statements dbplugin.Statements) dbplugin.Statements {
switch {
case len(statements.Creation) > 0 && len(statements.CreationStatements) == 0:
statements.CreationStatements = strings.Join(statements.Creation, ";")
case len(statements.CreationStatements) > 0:
statements.Creation = []string{statements.CreationStatements}
}
switch {
case len(statements.Revocation) > 0 && len(statements.RevocationStatements) == 0:
statements.RevocationStatements = strings.Join(statements.Revocation, ";")
case len(statements.RevocationStatements) > 0:
statements.Revocation = []string{statements.RevocationStatements}
}
switch {
case len(statements.Renewal) > 0 && len(statements.RenewStatements) == 0:
statements.RenewStatements = strings.Join(statements.Renewal, ";")
case len(statements.RenewStatements) > 0:
statements.Renewal = []string{statements.RenewStatements}
}
switch {
case len(statements.Rollback) > 0 && len(statements.RollbackStatements) == 0:
statements.RollbackStatements = strings.Join(statements.Rollback, ";")
case len(statements.RollbackStatements) > 0:
statements.Rollback = []string{statements.RollbackStatements}
}
return statements
}

View File

@@ -0,0 +1,62 @@
package dbutil
import (
"reflect"
"testing"
"github.com/hashicorp/vault/sdk/dbplugin"
)
func TestStatementCompatibilityHelper(t *testing.T) {
const (
creationStatement = "creation"
renewStatement = "renew"
revokeStatement = "revoke"
rollbackStatement = "rollback"
)
expectedStatements := dbplugin.Statements{
Creation: []string{creationStatement},
Rollback: []string{rollbackStatement},
Revocation: []string{revokeStatement},
Renewal: []string{renewStatement},
CreationStatements: creationStatement,
RenewStatements: renewStatement,
RollbackStatements: rollbackStatement,
RevocationStatements: revokeStatement,
}
statements1 := dbplugin.Statements{
CreationStatements: creationStatement,
RenewStatements: renewStatement,
RollbackStatements: rollbackStatement,
RevocationStatements: revokeStatement,
}
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements1)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements1)
}
statements2 := dbplugin.Statements{
Creation: []string{creationStatement},
Rollback: []string{rollbackStatement},
Revocation: []string{revokeStatement},
Renewal: []string{renewStatement},
}
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements2)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements2)
}
statements3 := dbplugin.Statements{
CreationStatements: creationStatement,
}
expectedStatements3 := dbplugin.Statements{
Creation: []string{creationStatement},
CreationStatements: creationStatement,
}
if !reflect.DeepEqual(expectedStatements3, StatementCompatibilityHelper(statements3)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements3, statements3)
}
}

137
sdk/dbplugin/plugin.go Normal file
View File

@@ -0,0 +1,137 @@
package dbplugin
import (
"context"
"fmt"
"time"
"google.golang.org/grpc"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
)
// Database is the interface that all database objects must implement.
type Database interface {
Type() (string, error)
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
RevokeUser(ctx context.Context, statements Statements, username string) error
RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error)
Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error)
Close() error
// DEPRECATED, will be removed in a future plugin version bump.
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error)
}
// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
// Look for plugin in the plugin catalog
pluginRunner, err := sys.LookupPlugin(ctx, pluginName, consts.PluginTypeDatabase)
if err != nil {
return nil, err
}
namedLogger := logger.Named(pluginName)
var transport string
var db Database
if pluginRunner.Builtin {
// Plugin is builtin so we can retrieve an instance of the interface
// from the pluginRunner. Then cast it to a Database.
dbRaw, err := pluginRunner.BuiltinFactory()
if err != nil {
return nil, errwrap.Wrapf("error initializing plugin: {{err}}", err)
}
var ok bool
db, ok = dbRaw.(Database)
if !ok {
return nil, fmt.Errorf("unsupported database type: %q", pluginName)
}
transport = "builtin"
} else {
// create a DatabasePluginClient instance
db, err = NewPluginClient(ctx, sys, pluginRunner, namedLogger, false)
if err != nil {
return nil, err
}
// Switch on the underlying database client type to get the transport
// method.
switch db.(*DatabasePluginClient).Database.(type) {
case *gRPCClient:
transport = "gRPC"
}
}
typeStr, err := db.Type()
if err != nil {
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
}
// Wrap with metrics middleware
db = &databaseMetricsMiddleware{
next: db,
typeStr: typeStr,
}
// Wrap with tracing middleware
if namedLogger.IsTrace() {
db = &databaseTracingMiddleware{
next: db,
logger: namedLogger.With("transport", transport),
}
}
return db, nil
}
// handshakeConfigs are used to just do a basic handshake between
// a plugin and host. If the handshake fails, a user friendly error is shown.
// This prevents users from executing bad plugins or executing a plugin
// directory. It is a UX feature, not a security feature.
var handshakeConfig = plugin.HandshakeConfig{
ProtocolVersion: 4,
MagicCookieKey: "VAULT_DATABASE_PLUGIN",
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
}
var _ plugin.Plugin = &GRPCDatabasePlugin{}
var _ plugin.GRPCPlugin = &GRPCDatabasePlugin{}
// GRPCDatabasePlugin is the plugin.Plugin implementation that only supports GRPC
// transport
type GRPCDatabasePlugin struct {
Impl Database
// Embeding this will disable the netRPC protocol
plugin.NetRPCUnsupportedPlugin
}
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
impl := &DatabaseErrorSanitizerMiddleware{
next: d.Impl,
}
RegisterDatabaseServer(s, &gRPCServer{impl: impl})
return nil
}
func (GRPCDatabasePlugin) GRPCClient(doneCtx context.Context, _ *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &gRPCClient{
client: NewDatabaseClient(c),
clientConn: c,
doneCtx: doneCtx,
}, nil
}

43
sdk/dbplugin/server.go Normal file
View File

@@ -0,0 +1,43 @@
package dbplugin
import (
"crypto/tls"
plugin "github.com/hashicorp/go-plugin"
)
// Serve is called from within a plugin and wraps the provided
// Database implementation in a databasePluginRPCServer object and starts a
// RPC server.
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
plugin.Serve(ServeConfig(db, tlsProvider))
}
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
// pluginSets is the map of plugins we can dispense.
pluginSets := map[int]plugin.PluginSet{
// Version 3 used to supports both protocols. We want to keep it around
// since it's possible old plugins built against this version will still
// work with gRPC. There is currently no difference between version 3
// and version 4.
3: plugin.PluginSet{
"database": &GRPCDatabasePlugin{
Impl: db,
},
},
4: plugin.PluginSet{
"database": &GRPCDatabasePlugin{
Impl: db,
},
},
}
conf := &plugin.ServeConfig{
HandshakeConfig: handshakeConfig,
VersionedPlugins: pluginSets,
TLSProvider: tlsProvider,
GRPCServer: plugin.DefaultGRPCServer,
}
return conf
}

View File

@@ -3,6 +3,7 @@ module github.com/hashicorp/vault/sdk
go 1.12
require (
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310
github.com/fatih/structs v1.1.0
github.com/go-ldap/ldap v3.0.2+incompatible
@@ -24,10 +25,11 @@ require (
github.com/mitchellh/copystructure v1.0.0
github.com/mitchellh/go-testing-interface v1.0.0
github.com/mitchellh/mapstructure v1.1.2
github.com/pascaldekloe/goe v0.1.0 // indirect
github.com/pierrec/lz4 v2.0.5+incompatible
github.com/ryanuber/go-glob v1.0.0
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
golang.org/x/net v0.0.0-20190311183353-d8887717615a
golang.org/x/net v0.0.0-20190311183353-d8887717615a // indirect
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e
golang.org/x/text v0.3.1-0.20181227161524-e6919f6577db // indirect
google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107 // indirect

View File

@@ -1,5 +1,7 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da h1:8GUt8eRujhVEGZFFEjBj46YV4rDjvGrNxb0KMWYkL2I=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310 h1:BUAU3CGlLvorLI26FmByPp2eC2qla6E1Tw+scpcg/to=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=
@@ -75,6 +77,8 @@ github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/I
github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw=
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@@ -0,0 +1,41 @@
// Package base62 provides utilities for working with base62 strings.
// base62 strings will only contain characters: 0-9, a-z, A-Z
package base62
import (
uuid "github.com/hashicorp/go-uuid"
)
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
const csLen = byte(len(charset))
// Random generates a random string using base-62 characters.
// Resulting entropy is ~5.95 bits/character.
func Random(length int) (string, error) {
if length == 0 {
return "", nil
}
output := make([]byte, 0, length)
// Request a bit more than length to reduce the chance
// of needing more than one batch of random bytes
batchSize := length + length/4
for {
buf, err := uuid.GenerateRandomBytes(batchSize)
if err != nil {
return "", err
}
for _, b := range buf {
// Avoid bias by using a value range that's a multiple of 62
if b < (csLen * 4) {
output = append(output, charset[b%csLen])
if len(output) == length {
return string(output), nil
}
}
}
}
}

View File

@@ -0,0 +1,31 @@
package base62
import (
"testing"
)
func TestRandom(t *testing.T) {
strings := make(map[string]struct{})
for i := 0; i < 100000; i++ {
c, err := Random(16)
if err != nil {
t.Fatal(err)
}
if _, ok := strings[c]; ok {
t.Fatalf("Unexpected duplicate string: %s", c)
}
strings[c] = struct{}{}
}
for i := 0; i < 3000; i++ {
c, err := Random(i)
if err != nil {
t.Fatal(err)
}
if len(c) != i {
t.Fatalf("Expected length %d, got: %d", i, len(c))
}
}
}

63
sdk/helper/dbtxn/dbtxn.go Normal file
View File

@@ -0,0 +1,63 @@
package dbtxn
import (
"context"
"database/sql"
"fmt"
"strings"
)
// ExecuteDBQuery handles executing one single statement, while properly releasing its resources.
// - ctx: Required
// - db: Required
// - config: Optional, may be nil
// - query: Required
func ExecuteDBQuery(ctx context.Context, db *sql.DB, params map[string]string, query string) error {
parsedQuery := parseQuery(params, query)
stmt, err := db.PrepareContext(ctx, parsedQuery)
if err != nil {
return err
}
defer stmt.Close()
return execute(ctx, stmt)
}
// ExecuteTxQuery handles executing one single statement, while properly releasing its resources.
// - ctx: Required
// - tx: Required
// - config: Optional, may be nil
// - query: Required
func ExecuteTxQuery(ctx context.Context, tx *sql.Tx, params map[string]string, query string) error {
parsedQuery := parseQuery(params, query)
stmt, err := tx.PrepareContext(ctx, parsedQuery)
if err != nil {
return err
}
defer stmt.Close()
return execute(ctx, stmt)
}
func execute(ctx context.Context, stmt *sql.Stmt) error {
if _, err := stmt.ExecContext(ctx); err != nil {
return err
}
return nil
}
func parseQuery(m map[string]string, tpl string) string {
if m == nil || len(m) <= 0 {
return tpl
}
for k, v := range m {
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
}
return tpl
}