mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
Migrate database plugin methods to sdk
This commit is contained in:
89
sdk/dbplugin/client.go
Normal file
89
sdk/dbplugin/client.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
)
|
||||
|
||||
// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close
|
||||
// method to also call Kill() on the plugin.Client.
|
||||
type DatabasePluginClient struct {
|
||||
client *plugin.Client
|
||||
sync.Mutex
|
||||
|
||||
Database
|
||||
}
|
||||
|
||||
// This wraps the Close call and ensures we both close the database connection
|
||||
// and kill the plugin.
|
||||
func (dc *DatabasePluginClient) Close() error {
|
||||
err := dc.Database.Close()
|
||||
dc.client.Kill()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// NewPluginClient returns a databaseRPCClient with a connection to a running
|
||||
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
|
||||
// plugin is killed on call of Close().
|
||||
func NewPluginClient(ctx context.Context, sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner, logger log.Logger, isMetadataMode bool) (Database, error) {
|
||||
|
||||
// pluginSets is the map of plugins we can dispense.
|
||||
pluginSets := map[int]plugin.PluginSet{
|
||||
// Version 3 used to supports both protocols. We want to keep it around
|
||||
// since it's possible old plugins built against this version will still
|
||||
// work with gRPC. There is currently no difference between version 3
|
||||
// and version 4.
|
||||
3: plugin.PluginSet{
|
||||
"database": new(GRPCDatabasePlugin),
|
||||
},
|
||||
// Version 4 only supports gRPC
|
||||
4: plugin.PluginSet{
|
||||
"database": new(GRPCDatabasePlugin),
|
||||
},
|
||||
}
|
||||
|
||||
var client *plugin.Client
|
||||
var err error
|
||||
if isMetadataMode {
|
||||
client, err = pluginRunner.RunMetadataMode(ctx, sys, pluginSets, handshakeConfig, []string{}, logger)
|
||||
} else {
|
||||
client, err = pluginRunner.Run(ctx, sys, pluginSets, handshakeConfig, []string{}, logger)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Connect via RPC
|
||||
rpcClient, err := client.Client()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Request the plugin
|
||||
raw, err := rpcClient.Dispense("database")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We should have a database type now. This feels like a normal interface
|
||||
// implementation but is in fact over an RPC connection.
|
||||
var db Database
|
||||
switch raw.(type) {
|
||||
case *gRPCClient:
|
||||
db = raw.(*gRPCClient)
|
||||
default:
|
||||
return nil, errors.New("unsupported client type")
|
||||
}
|
||||
|
||||
// Wrap RPC implementation in DatabasePluginClient
|
||||
return &DatabasePluginClient{
|
||||
client: client,
|
||||
Database: db,
|
||||
}, nil
|
||||
}
|
||||
1062
sdk/dbplugin/database.pb.go
Normal file
1062
sdk/dbplugin/database.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
93
sdk/dbplugin/database.proto
Normal file
93
sdk/dbplugin/database.proto
Normal file
@@ -0,0 +1,93 @@
|
||||
syntax = "proto3";
|
||||
|
||||
option go_package = "github.com/hashicorp/vault/sdk/dbplugin";
|
||||
|
||||
package dbplugin;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message InitializeRequest {
|
||||
option deprecated = true;
|
||||
bytes config = 1;
|
||||
bool verify_connection = 2;
|
||||
}
|
||||
|
||||
message InitRequest {
|
||||
bytes config = 1;
|
||||
bool verify_connection = 2;
|
||||
}
|
||||
|
||||
message CreateUserRequest {
|
||||
Statements statements = 1;
|
||||
UsernameConfig username_config = 2;
|
||||
google.protobuf.Timestamp expiration = 3;
|
||||
}
|
||||
|
||||
message RenewUserRequest {
|
||||
Statements statements = 1;
|
||||
string username = 2;
|
||||
google.protobuf.Timestamp expiration = 3;
|
||||
}
|
||||
|
||||
message RevokeUserRequest {
|
||||
Statements statements = 1;
|
||||
string username = 2;
|
||||
}
|
||||
|
||||
message RotateRootCredentialsRequest {
|
||||
repeated string statements = 1;
|
||||
}
|
||||
|
||||
message Statements {
|
||||
// DEPRECATED, will be removed in 0.12
|
||||
string creation_statements = 1 [deprecated=true];
|
||||
// DEPRECATED, will be removed in 0.12
|
||||
string revocation_statements = 2 [deprecated=true];
|
||||
// DEPRECATED, will be removed in 0.12
|
||||
string rollback_statements = 3 [deprecated=true];
|
||||
// DEPRECATED, will be removed in 0.12
|
||||
string renew_statements = 4 [deprecated=true];
|
||||
|
||||
repeated string creation = 5;
|
||||
repeated string revocation = 6;
|
||||
repeated string rollback = 7;
|
||||
repeated string renewal = 8;
|
||||
}
|
||||
|
||||
message UsernameConfig {
|
||||
string DisplayName = 1;
|
||||
string RoleName = 2;
|
||||
}
|
||||
|
||||
message InitResponse {
|
||||
bytes config = 1;
|
||||
}
|
||||
|
||||
message CreateUserResponse {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
message TypeResponse {
|
||||
string type = 1;
|
||||
}
|
||||
|
||||
message RotateRootCredentialsResponse {
|
||||
bytes config = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
|
||||
service Database {
|
||||
rpc Type(Empty) returns (TypeResponse);
|
||||
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
|
||||
rpc RenewUser(RenewUserRequest) returns (Empty);
|
||||
rpc RevokeUser(RevokeUserRequest) returns (Empty);
|
||||
rpc RotateRootCredentials(RotateRootCredentialsRequest) returns (RotateRootCredentialsResponse);
|
||||
rpc Init(InitRequest) returns (InitResponse);
|
||||
rpc Close(Empty) returns (Empty);
|
||||
|
||||
rpc Initialize(InitializeRequest) returns (Empty) {
|
||||
option deprecated = true;
|
||||
};
|
||||
}
|
||||
275
sdk/dbplugin/databasemiddleware.go
Normal file
275
sdk/dbplugin/databasemiddleware.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
|
||||
metrics "github.com/armon/go-metrics"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
// ---- Tracing Middleware Domain ----
|
||||
|
||||
// databaseTracingMiddleware wraps a implementation of Database and executes
|
||||
// trace logging on function call.
|
||||
type databaseTracingMiddleware struct {
|
||||
next Database
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Type() (string, error) {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("create user", "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("create user", "status", "started")
|
||||
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("renew user", "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("renew user", "status", "started")
|
||||
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("revoke user", "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("revoke user", "status", "started")
|
||||
return mw.next.RevokeUser(ctx, statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("rotate root credentials", "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("rotate root credentials", "status", "started")
|
||||
return mw.next.RotateRootCredentials(ctx, statements)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
_, err := mw.Init(ctx, conf, verifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("initialize", "status", "finished", "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("initialize", "status", "started")
|
||||
return mw.next.Init(ctx, conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseTracingMiddleware) Close() (err error) {
|
||||
defer func(then time.Time) {
|
||||
mw.logger.Trace("close", "status", "finished", "err", err, "took", time.Since(then))
|
||||
}(time.Now())
|
||||
|
||||
mw.logger.Trace("close", "status", "started")
|
||||
return mw.next.Close()
|
||||
}
|
||||
|
||||
// ---- Metrics Middleware Domain ----
|
||||
|
||||
// databaseMetricsMiddleware wraps an implementation of Databases and on
|
||||
// function call logs metrics about this instance.
|
||||
type databaseMetricsMiddleware struct {
|
||||
next Database
|
||||
|
||||
typeStr string
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Type() (string, error) {
|
||||
return mw.next.Type()
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "CreateUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "CreateUser", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "CreateUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1)
|
||||
return mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RenewUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "RenewUser", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "RenewUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1)
|
||||
return mw.next.RenewUser(ctx, statements, username, expiration)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RevokeUser"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "RevokeUser", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "RevokeUser"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1)
|
||||
return mw.next.RevokeUser(ctx, statements, username)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "RotateRootCredentials"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "RotateRootCredentials"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "RotateRootCredentials", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "RotateRootCredentials"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials"}, 1)
|
||||
return mw.next.RotateRootCredentials(ctx, statements)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
_, err := mw.Init(ctx, conf, verifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
||||
return mw.next.Init(ctx, conf, verifyConnection)
|
||||
}
|
||||
|
||||
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
||||
defer func(now time.Time) {
|
||||
metrics.MeasureSince([]string{"database", "Close"}, now)
|
||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now)
|
||||
|
||||
if err != nil {
|
||||
metrics.IncrCounter([]string{"database", "Close", "error"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
metrics.IncrCounter([]string{"database", "Close"}, 1)
|
||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
|
||||
return mw.next.Close()
|
||||
}
|
||||
|
||||
// ---- Error Sanitizer Middleware Domain ----
|
||||
|
||||
// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
|
||||
// sanitizes returned error messages
|
||||
type DatabaseErrorSanitizerMiddleware struct {
|
||||
l sync.RWMutex
|
||||
next Database
|
||||
secretsFn func() map[string]interface{}
|
||||
}
|
||||
|
||||
func NewDatabaseErrorSanitizerMiddleware(next Database, secretsFn func() map[string]interface{}) *DatabaseErrorSanitizerMiddleware {
|
||||
return &DatabaseErrorSanitizerMiddleware{
|
||||
next: next,
|
||||
secretsFn: secretsFn,
|
||||
}
|
||||
}
|
||||
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) Type() (string, error) {
|
||||
dbType, err := mw.next.Type()
|
||||
return dbType, mw.sanitize(err)
|
||||
}
|
||||
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
username, password, err = mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||
return username, password, mw.sanitize(err)
|
||||
}
|
||||
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||
return mw.sanitize(mw.next.RenewUser(ctx, statements, username, expiration))
|
||||
}
|
||||
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||
return mw.sanitize(mw.next.RevokeUser(ctx, statements, username))
|
||||
}
|
||||
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
|
||||
conf, err = mw.next.RotateRootCredentials(ctx, statements)
|
||||
return conf, mw.sanitize(err)
|
||||
}
|
||||
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
_, err := mw.Init(ctx, conf, verifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
|
||||
saveConf, err = mw.next.Init(ctx, conf, verifyConnection)
|
||||
return saveConf, mw.sanitize(err)
|
||||
}
|
||||
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) Close() (err error) {
|
||||
return mw.sanitize(mw.next.Close())
|
||||
}
|
||||
|
||||
// sanitize
|
||||
func (mw *DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errwrap.ContainsType(err, new(url.Error)) {
|
||||
return errors.New("unable to parse connection url")
|
||||
}
|
||||
if mw.secretsFn != nil {
|
||||
for k, v := range mw.secretsFn() {
|
||||
if k == "" {
|
||||
continue
|
||||
}
|
||||
err = errors.New(strings.Replace(err.Error(), k, v.(string), -1))
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
285
sdk/dbplugin/grpc_transport.go
Normal file
285
sdk/dbplugin/grpc_transport.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPluginShutdown = errors.New("plugin shutdown")
|
||||
)
|
||||
|
||||
// ---- gRPC Server domain ----
|
||||
|
||||
type gRPCServer struct {
|
||||
impl Database
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Type(context.Context, *Empty) (*TypeResponse, error) {
|
||||
t, err := s.impl.Type()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TypeResponse{
|
||||
Type: t,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *gRPCServer) CreateUser(ctx context.Context, req *CreateUserRequest) (*CreateUserResponse, error) {
|
||||
e, err := ptypes.Timestamp(req.Expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, p, err := s.impl.CreateUser(ctx, *req.Statements, *req.UsernameConfig, e)
|
||||
|
||||
return &CreateUserResponse{
|
||||
Username: u,
|
||||
Password: p,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RenewUser(ctx context.Context, req *RenewUserRequest) (*Empty, error) {
|
||||
e, err := ptypes.Timestamp(req.Expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = s.impl.RenewUser(ctx, *req.Statements, req.Username, e)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*Empty, error) {
|
||||
err := s.impl.RevokeUser(ctx, *req.Statements, req.Username)
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) RotateRootCredentials(ctx context.Context, req *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error) {
|
||||
|
||||
resp, err := s.impl.RotateRootCredentials(ctx, req.Statements)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
respConfig, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RotateRootCredentialsResponse{
|
||||
Config: respConfig,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
|
||||
_, err := s.Init(ctx, &InitRequest{
|
||||
Config: req.Config,
|
||||
VerifyConnection: req.VerifyConnection,
|
||||
})
|
||||
return &Empty{}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Init(ctx context.Context, req *InitRequest) (*InitResponse, error) {
|
||||
config := map[string]interface{}{}
|
||||
err := json.Unmarshal(req.Config, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := s.impl.Init(ctx, config, req.VerifyConnection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
respConfig, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &InitResponse{
|
||||
Config: respConfig,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
|
||||
s.impl.Close()
|
||||
return &Empty{}, nil
|
||||
}
|
||||
|
||||
// ---- gRPC client domain ----
|
||||
|
||||
type gRPCClient struct {
|
||||
client DatabaseClient
|
||||
clientConn *grpc.ClientConn
|
||||
|
||||
doneCtx context.Context
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Type() (string, error) {
|
||||
resp, err := c.client.Type(c.doneCtx, &Empty{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return resp.Type, err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||
t, err := ptypes.TimestampProto(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
resp, err := c.client.CreateUser(ctx, &CreateUserRequest{
|
||||
Statements: &statements,
|
||||
UsernameConfig: &usernameConfig,
|
||||
Expiration: t,
|
||||
})
|
||||
if err != nil {
|
||||
if c.doneCtx.Err() != nil {
|
||||
return "", "", ErrPluginShutdown
|
||||
}
|
||||
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return resp.Username, resp.Password, err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error {
|
||||
t, err := ptypes.TimestampProto(expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
_, err = c.client.RenewUser(ctx, &RenewUserRequest{
|
||||
Statements: &statements,
|
||||
Username: username,
|
||||
Expiration: t,
|
||||
})
|
||||
if err != nil {
|
||||
if c.doneCtx.Err() != nil {
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, username string) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.client.RevokeUser(ctx, &RevokeUserRequest{
|
||||
Statements: &statements,
|
||||
Username: username,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if c.doneCtx.Err() != nil {
|
||||
return ErrPluginShutdown
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gRPCClient) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
resp, err := c.client.RotateRootCredentials(ctx, &RotateRootCredentialsRequest{
|
||||
Statements: statements,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if c.doneCtx.Err() != nil {
|
||||
return nil, ErrPluginShutdown
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp.Config, &conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
_, err := c.Init(ctx, conf, verifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
|
||||
configRaw, err := json.Marshal(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
|
||||
defer close(quitCh)
|
||||
defer cancel()
|
||||
|
||||
resp, err := c.client.Init(ctx, &InitRequest{
|
||||
Config: configRaw,
|
||||
VerifyConnection: verifyConnection,
|
||||
})
|
||||
if err != nil {
|
||||
// Fall back to old call if not implemented
|
||||
grpcStatus, ok := status.FromError(err)
|
||||
if ok && grpcStatus.Code() == codes.Unimplemented {
|
||||
_, err = c.client.Initialize(ctx, &InitializeRequest{
|
||||
Config: configRaw,
|
||||
VerifyConnection: verifyConnection,
|
||||
})
|
||||
if err == nil {
|
||||
return conf, nil
|
||||
}
|
||||
}
|
||||
|
||||
if c.doneCtx.Err() != nil {
|
||||
return nil, ErrPluginShutdown
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp.Config, &conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (c *gRPCClient) Close() error {
|
||||
_, err := c.client.Close(c.doneCtx, &Empty{})
|
||||
return err
|
||||
}
|
||||
25
sdk/dbplugin/helper/connutil/connutil.go
Normal file
25
sdk/dbplugin/helper/connutil/connutil.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package connutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotInitialized = errors.New("connection has not been initialized")
|
||||
)
|
||||
|
||||
// ConnectionProducer can be used as an embedded interface in the Database
|
||||
// definition. It implements the methods dealing with individual database
|
||||
// connections and is used in all the builtin database types.
|
||||
type ConnectionProducer interface {
|
||||
Close() error
|
||||
Init(context.Context, map[string]interface{}, bool) (map[string]interface{}, error)
|
||||
Connection(context.Context) (interface{}, error)
|
||||
|
||||
sync.Locker
|
||||
|
||||
// DEPRECATED, will be removed in 0.12
|
||||
Initialize(context.Context, map[string]interface{}, bool) error
|
||||
}
|
||||
164
sdk/dbplugin/helper/connutil/sql.go
Normal file
164
sdk/dbplugin/helper/connutil/sql.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package connutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
"github.com/hashicorp/vault/sdk/dbplugin/helper/dbutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
var _ ConnectionProducer = &SQLConnectionProducer{}
|
||||
|
||||
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
||||
type SQLConnectionProducer struct {
|
||||
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
|
||||
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
|
||||
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
|
||||
Username string `json:"username" mapstructure:"username" structs:"username"`
|
||||
Password string `json:"password" mapstructure:"password" structs:"password"`
|
||||
|
||||
Type string
|
||||
RawConfig map[string]interface{}
|
||||
maxConnectionLifetime time.Duration
|
||||
Initialized bool
|
||||
db *sql.DB
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||
_, err := c.Init(ctx, conf, verifyConnection)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
c.RawConfig = conf
|
||||
|
||||
err := mapstructure.WeakDecode(conf, &c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(c.ConnectionURL) == 0 {
|
||||
return nil, fmt.Errorf("connection_url cannot be empty")
|
||||
}
|
||||
|
||||
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
|
||||
"username": c.Username,
|
||||
"password": c.Password,
|
||||
})
|
||||
|
||||
if c.MaxOpenConnections == 0 {
|
||||
c.MaxOpenConnections = 2
|
||||
}
|
||||
|
||||
if c.MaxIdleConnections == 0 {
|
||||
c.MaxIdleConnections = c.MaxOpenConnections
|
||||
}
|
||||
if c.MaxIdleConnections > c.MaxOpenConnections {
|
||||
c.MaxIdleConnections = c.MaxOpenConnections
|
||||
}
|
||||
if c.MaxConnectionLifetimeRaw == nil {
|
||||
c.MaxConnectionLifetimeRaw = "0s"
|
||||
}
|
||||
|
||||
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
|
||||
}
|
||||
|
||||
// Set initialized to true at this point since all fields are set,
|
||||
// and the connection can be established at a later time.
|
||||
c.Initialized = true
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(ctx); err != nil {
|
||||
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||
}
|
||||
|
||||
if err := c.db.PingContext(ctx); err != nil {
|
||||
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||
}
|
||||
}
|
||||
|
||||
return c.RawConfig, nil
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, ErrNotInitialized
|
||||
}
|
||||
|
||||
// If we already have a DB, test it and return
|
||||
if c.db != nil {
|
||||
if err := c.db.PingContext(ctx); err == nil {
|
||||
return c.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
// For mssql backend, switch to sqlserver instead
|
||||
dbType := c.Type
|
||||
if c.Type == "mssql" {
|
||||
dbType = "sqlserver"
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
conn := c.ConnectionURL
|
||||
|
||||
// Ensure timezone is set to UTC for all the connections
|
||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||
if strings.Contains(conn, "?") {
|
||||
conn += "&timezone=utc"
|
||||
} else {
|
||||
conn += "?timezone=utc"
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
c.db, err = sql.Open(dbType, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
c.db.SetMaxOpenConns(c.MaxOpenConnections)
|
||||
c.db.SetMaxIdleConns(c.MaxIdleConnections)
|
||||
c.db.SetConnMaxLifetime(c.maxConnectionLifetime)
|
||||
|
||||
return c.db, nil
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) SecretValues() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
c.Password: "[password]",
|
||||
}
|
||||
}
|
||||
|
||||
// Close attempts to close the connection
|
||||
func (c *SQLConnectionProducer) Close() error {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if c.db != nil {
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
c.db = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
46
sdk/dbplugin/helper/credsutil/credsutil.go
Normal file
46
sdk/dbplugin/helper/credsutil/credsutil.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package credsutil
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/base62"
|
||||
)
|
||||
|
||||
// CredentialsProducer can be used as an embedded interface in the Database
|
||||
// definition. It implements the methods for generating user information for a
|
||||
// particular database type and is used in all the builtin database types.
|
||||
type CredentialsProducer interface {
|
||||
GenerateUsername(usernameConfig dbplugin.UsernameConfig) (string, error)
|
||||
GeneratePassword() (string, error)
|
||||
GenerateExpiration(ttl time.Time) (string, error)
|
||||
}
|
||||
|
||||
const (
|
||||
reqStr = `A1a-`
|
||||
minStrLen = 10
|
||||
)
|
||||
|
||||
// RandomAlphaNumeric returns a random string of characters [A-Za-z0-9-]
|
||||
// of the provided length. The string generated takes up to 4 characters
|
||||
// of space that are predefined and prepended to ensure password
|
||||
// character requirements. It also requires a min length of 10 characters.
|
||||
func RandomAlphaNumeric(length int, prependA1a bool) (string, error) {
|
||||
if length < minStrLen {
|
||||
return "", fmt.Errorf("minimum length of %d is required", minStrLen)
|
||||
}
|
||||
|
||||
var prefix string
|
||||
if prependA1a {
|
||||
prefix = reqStr
|
||||
}
|
||||
|
||||
randomStr, err := base62.Random(length - len(prefix))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return prefix + randomStr, nil
|
||||
}
|
||||
40
sdk/dbplugin/helper/credsutil/credsutil_test.go
Normal file
40
sdk/dbplugin/helper/credsutil/credsutil_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package credsutil
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRandomAlphaNumeric(t *testing.T) {
|
||||
s, err := RandomAlphaNumeric(10, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %s", err)
|
||||
}
|
||||
if len(s) != 10 {
|
||||
t.Fatalf("Unexpected length of string, expected 10, got string: %s", s)
|
||||
}
|
||||
|
||||
s, err = RandomAlphaNumeric(20, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %s", err)
|
||||
}
|
||||
if len(s) != 20 {
|
||||
t.Fatalf("Unexpected length of string, expected 20, got string: %s", s)
|
||||
}
|
||||
|
||||
if !strings.Contains(s, reqStr) {
|
||||
t.Fatalf("Expected %s to contain %s", s, reqStr)
|
||||
}
|
||||
|
||||
s, err = RandomAlphaNumeric(20, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %s", err)
|
||||
}
|
||||
if len(s) != 20 {
|
||||
t.Fatalf("Unexpected length of string, expected 20, got string: %s", s)
|
||||
}
|
||||
|
||||
if strings.Contains(s, reqStr) {
|
||||
t.Fatalf("Expected %s not to contain %s", s, reqStr)
|
||||
}
|
||||
}
|
||||
72
sdk/dbplugin/helper/credsutil/sql.go
Normal file
72
sdk/dbplugin/helper/credsutil/sql.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package credsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/dbplugin"
|
||||
)
|
||||
|
||||
const (
|
||||
NoneLength int = -1
|
||||
)
|
||||
|
||||
// SQLCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types.
|
||||
type SQLCredentialsProducer struct {
|
||||
DisplayNameLen int
|
||||
RoleNameLen int
|
||||
UsernameLen int
|
||||
Separator string
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GenerateUsername(config dbplugin.UsernameConfig) (string, error) {
|
||||
username := "v"
|
||||
|
||||
displayName := config.DisplayName
|
||||
if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen {
|
||||
displayName = displayName[:scp.DisplayNameLen]
|
||||
} else if scp.DisplayNameLen == NoneLength {
|
||||
displayName = ""
|
||||
}
|
||||
|
||||
if len(displayName) > 0 {
|
||||
username = fmt.Sprintf("%s%s%s", username, scp.Separator, displayName)
|
||||
}
|
||||
|
||||
roleName := config.RoleName
|
||||
if scp.RoleNameLen > 0 && len(roleName) > scp.RoleNameLen {
|
||||
roleName = roleName[:scp.RoleNameLen]
|
||||
} else if scp.RoleNameLen == NoneLength {
|
||||
roleName = ""
|
||||
}
|
||||
|
||||
if len(roleName) > 0 {
|
||||
username = fmt.Sprintf("%s%s%s", username, scp.Separator, roleName)
|
||||
}
|
||||
|
||||
userUUID, err := RandomAlphaNumeric(20, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID)
|
||||
username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().Unix()))
|
||||
if scp.UsernameLen > 0 && len(username) > scp.UsernameLen {
|
||||
username = username[:scp.UsernameLen]
|
||||
}
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := RandomAlphaNumeric(20, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
|
||||
return ttl.Format("2006-01-02 15:04:05-0700"), nil
|
||||
}
|
||||
52
sdk/dbplugin/helper/dbutil/dbutil.go
Normal file
52
sdk/dbplugin/helper/dbutil/dbutil.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/dbplugin"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyCreationStatement = errors.New("empty creation statements")
|
||||
)
|
||||
|
||||
// Query templates a query for us.
|
||||
func QueryHelper(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
||||
|
||||
// StatementCompatibilityHelper will populate the statements fields to support
|
||||
// compatibility
|
||||
func StatementCompatibilityHelper(statements dbplugin.Statements) dbplugin.Statements {
|
||||
switch {
|
||||
case len(statements.Creation) > 0 && len(statements.CreationStatements) == 0:
|
||||
statements.CreationStatements = strings.Join(statements.Creation, ";")
|
||||
case len(statements.CreationStatements) > 0:
|
||||
statements.Creation = []string{statements.CreationStatements}
|
||||
}
|
||||
switch {
|
||||
case len(statements.Revocation) > 0 && len(statements.RevocationStatements) == 0:
|
||||
statements.RevocationStatements = strings.Join(statements.Revocation, ";")
|
||||
case len(statements.RevocationStatements) > 0:
|
||||
statements.Revocation = []string{statements.RevocationStatements}
|
||||
}
|
||||
switch {
|
||||
case len(statements.Renewal) > 0 && len(statements.RenewStatements) == 0:
|
||||
statements.RenewStatements = strings.Join(statements.Renewal, ";")
|
||||
case len(statements.RenewStatements) > 0:
|
||||
statements.Renewal = []string{statements.RenewStatements}
|
||||
}
|
||||
switch {
|
||||
case len(statements.Rollback) > 0 && len(statements.RollbackStatements) == 0:
|
||||
statements.RollbackStatements = strings.Join(statements.Rollback, ";")
|
||||
case len(statements.RollbackStatements) > 0:
|
||||
statements.Rollback = []string{statements.RollbackStatements}
|
||||
}
|
||||
return statements
|
||||
}
|
||||
62
sdk/dbplugin/helper/dbutil/dbutil_test.go
Normal file
62
sdk/dbplugin/helper/dbutil/dbutil_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/dbplugin"
|
||||
)
|
||||
|
||||
func TestStatementCompatibilityHelper(t *testing.T) {
|
||||
const (
|
||||
creationStatement = "creation"
|
||||
renewStatement = "renew"
|
||||
revokeStatement = "revoke"
|
||||
rollbackStatement = "rollback"
|
||||
)
|
||||
|
||||
expectedStatements := dbplugin.Statements{
|
||||
Creation: []string{creationStatement},
|
||||
Rollback: []string{rollbackStatement},
|
||||
Revocation: []string{revokeStatement},
|
||||
Renewal: []string{renewStatement},
|
||||
CreationStatements: creationStatement,
|
||||
RenewStatements: renewStatement,
|
||||
RollbackStatements: rollbackStatement,
|
||||
RevocationStatements: revokeStatement,
|
||||
}
|
||||
|
||||
statements1 := dbplugin.Statements{
|
||||
CreationStatements: creationStatement,
|
||||
RenewStatements: renewStatement,
|
||||
RollbackStatements: rollbackStatement,
|
||||
RevocationStatements: revokeStatement,
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements1)) {
|
||||
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements1)
|
||||
}
|
||||
|
||||
statements2 := dbplugin.Statements{
|
||||
Creation: []string{creationStatement},
|
||||
Rollback: []string{rollbackStatement},
|
||||
Revocation: []string{revokeStatement},
|
||||
Renewal: []string{renewStatement},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements2)) {
|
||||
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements2)
|
||||
}
|
||||
|
||||
statements3 := dbplugin.Statements{
|
||||
CreationStatements: creationStatement,
|
||||
}
|
||||
expectedStatements3 := dbplugin.Statements{
|
||||
Creation: []string{creationStatement},
|
||||
CreationStatements: creationStatement,
|
||||
}
|
||||
if !reflect.DeepEqual(expectedStatements3, StatementCompatibilityHelper(statements3)) {
|
||||
t.Fatalf("mismatch: %#v, %#v", expectedStatements3, statements3)
|
||||
}
|
||||
|
||||
}
|
||||
137
sdk/dbplugin/plugin.go
Normal file
137
sdk/dbplugin/plugin.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
)
|
||||
|
||||
// Database is the interface that all database objects must implement.
|
||||
type Database interface {
|
||||
Type() (string, error)
|
||||
CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error)
|
||||
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
|
||||
RevokeUser(ctx context.Context, statements Statements, username string) error
|
||||
|
||||
RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error)
|
||||
|
||||
Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error)
|
||||
Close() error
|
||||
|
||||
// DEPRECATED, will be removed in a future plugin version bump.
|
||||
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error)
|
||||
}
|
||||
|
||||
// PluginFactory is used to build plugin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) {
|
||||
// Look for plugin in the plugin catalog
|
||||
pluginRunner, err := sys.LookupPlugin(ctx, pluginName, consts.PluginTypeDatabase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
namedLogger := logger.Named(pluginName)
|
||||
|
||||
var transport string
|
||||
var db Database
|
||||
if pluginRunner.Builtin {
|
||||
// Plugin is builtin so we can retrieve an instance of the interface
|
||||
// from the pluginRunner. Then cast it to a Database.
|
||||
dbRaw, err := pluginRunner.BuiltinFactory()
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error initializing plugin: {{err}}", err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
db, ok = dbRaw.(Database)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported database type: %q", pluginName)
|
||||
}
|
||||
|
||||
transport = "builtin"
|
||||
|
||||
} else {
|
||||
// create a DatabasePluginClient instance
|
||||
db, err = NewPluginClient(ctx, sys, pluginRunner, namedLogger, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Switch on the underlying database client type to get the transport
|
||||
// method.
|
||||
switch db.(*DatabasePluginClient).Database.(type) {
|
||||
case *gRPCClient:
|
||||
transport = "gRPC"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
typeStr, err := db.Type()
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
|
||||
}
|
||||
|
||||
// Wrap with metrics middleware
|
||||
db = &databaseMetricsMiddleware{
|
||||
next: db,
|
||||
typeStr: typeStr,
|
||||
}
|
||||
|
||||
// Wrap with tracing middleware
|
||||
if namedLogger.IsTrace() {
|
||||
db = &databaseTracingMiddleware{
|
||||
next: db,
|
||||
logger: namedLogger.With("transport", transport),
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// handshakeConfigs are used to just do a basic handshake between
|
||||
// a plugin and host. If the handshake fails, a user friendly error is shown.
|
||||
// This prevents users from executing bad plugins or executing a plugin
|
||||
// directory. It is a UX feature, not a security feature.
|
||||
var handshakeConfig = plugin.HandshakeConfig{
|
||||
ProtocolVersion: 4,
|
||||
MagicCookieKey: "VAULT_DATABASE_PLUGIN",
|
||||
MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb",
|
||||
}
|
||||
|
||||
var _ plugin.Plugin = &GRPCDatabasePlugin{}
|
||||
var _ plugin.GRPCPlugin = &GRPCDatabasePlugin{}
|
||||
|
||||
// GRPCDatabasePlugin is the plugin.Plugin implementation that only supports GRPC
|
||||
// transport
|
||||
type GRPCDatabasePlugin struct {
|
||||
Impl Database
|
||||
|
||||
// Embeding this will disable the netRPC protocol
|
||||
plugin.NetRPCUnsupportedPlugin
|
||||
}
|
||||
|
||||
func (d GRPCDatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
|
||||
impl := &DatabaseErrorSanitizerMiddleware{
|
||||
next: d.Impl,
|
||||
}
|
||||
|
||||
RegisterDatabaseServer(s, &gRPCServer{impl: impl})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (GRPCDatabasePlugin) GRPCClient(doneCtx context.Context, _ *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
|
||||
return &gRPCClient{
|
||||
client: NewDatabaseClient(c),
|
||||
clientConn: c,
|
||||
doneCtx: doneCtx,
|
||||
}, nil
|
||||
}
|
||||
43
sdk/dbplugin/server.go
Normal file
43
sdk/dbplugin/server.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package dbplugin
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
plugin "github.com/hashicorp/go-plugin"
|
||||
)
|
||||
|
||||
// Serve is called from within a plugin and wraps the provided
|
||||
// Database implementation in a databasePluginRPCServer object and starts a
|
||||
// RPC server.
|
||||
func Serve(db Database, tlsProvider func() (*tls.Config, error)) {
|
||||
plugin.Serve(ServeConfig(db, tlsProvider))
|
||||
}
|
||||
|
||||
func ServeConfig(db Database, tlsProvider func() (*tls.Config, error)) *plugin.ServeConfig {
|
||||
// pluginSets is the map of plugins we can dispense.
|
||||
pluginSets := map[int]plugin.PluginSet{
|
||||
// Version 3 used to supports both protocols. We want to keep it around
|
||||
// since it's possible old plugins built against this version will still
|
||||
// work with gRPC. There is currently no difference between version 3
|
||||
// and version 4.
|
||||
3: plugin.PluginSet{
|
||||
"database": &GRPCDatabasePlugin{
|
||||
Impl: db,
|
||||
},
|
||||
},
|
||||
4: plugin.PluginSet{
|
||||
"database": &GRPCDatabasePlugin{
|
||||
Impl: db,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
conf := &plugin.ServeConfig{
|
||||
HandshakeConfig: handshakeConfig,
|
||||
VersionedPlugins: pluginSets,
|
||||
TLSProvider: tlsProvider,
|
||||
GRPCServer: plugin.DefaultGRPCServer,
|
||||
}
|
||||
|
||||
return conf
|
||||
}
|
||||
@@ -3,6 +3,7 @@ module github.com/hashicorp/vault/sdk
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da
|
||||
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310
|
||||
github.com/fatih/structs v1.1.0
|
||||
github.com/go-ldap/ldap v3.0.2+incompatible
|
||||
@@ -24,10 +25,11 @@ require (
|
||||
github.com/mitchellh/copystructure v1.0.0
|
||||
github.com/mitchellh/go-testing-interface v1.0.0
|
||||
github.com/mitchellh/mapstructure v1.1.2
|
||||
github.com/pascaldekloe/goe v0.1.0 // indirect
|
||||
github.com/pierrec/lz4 v2.0.5+incompatible
|
||||
github.com/ryanuber/go-glob v1.0.0
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a // indirect
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e
|
||||
golang.org/x/text v0.3.1-0.20181227161524-e6919f6577db // indirect
|
||||
google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107 // indirect
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da h1:8GUt8eRujhVEGZFFEjBj46YV4rDjvGrNxb0KMWYkL2I=
|
||||
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
|
||||
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310 h1:BUAU3CGlLvorLI26FmByPp2eC2qla6E1Tw+scpcg/to=
|
||||
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
|
||||
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=
|
||||
@@ -75,6 +77,8 @@ github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/I
|
||||
github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||
github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw=
|
||||
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
|
||||
github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY=
|
||||
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
|
||||
github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I=
|
||||
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
||||
41
sdk/helper/base62/base62.go
Normal file
41
sdk/helper/base62/base62.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// Package base62 provides utilities for working with base62 strings.
|
||||
// base62 strings will only contain characters: 0-9, a-z, A-Z
|
||||
package base62
|
||||
|
||||
import (
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
const csLen = byte(len(charset))
|
||||
|
||||
// Random generates a random string using base-62 characters.
|
||||
// Resulting entropy is ~5.95 bits/character.
|
||||
func Random(length int) (string, error) {
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
output := make([]byte, 0, length)
|
||||
|
||||
// Request a bit more than length to reduce the chance
|
||||
// of needing more than one batch of random bytes
|
||||
batchSize := length + length/4
|
||||
|
||||
for {
|
||||
buf, err := uuid.GenerateRandomBytes(batchSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, b := range buf {
|
||||
// Avoid bias by using a value range that's a multiple of 62
|
||||
if b < (csLen * 4) {
|
||||
output = append(output, charset[b%csLen])
|
||||
|
||||
if len(output) == length {
|
||||
return string(output), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
31
sdk/helper/base62/base62_test.go
Normal file
31
sdk/helper/base62/base62_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package base62
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRandom(t *testing.T) {
|
||||
strings := make(map[string]struct{})
|
||||
|
||||
for i := 0; i < 100000; i++ {
|
||||
c, err := Random(16)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := strings[c]; ok {
|
||||
t.Fatalf("Unexpected duplicate string: %s", c)
|
||||
}
|
||||
strings[c] = struct{}{}
|
||||
|
||||
}
|
||||
|
||||
for i := 0; i < 3000; i++ {
|
||||
c, err := Random(i)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(c) != i {
|
||||
t.Fatalf("Expected length %d, got: %d", i, len(c))
|
||||
}
|
||||
}
|
||||
}
|
||||
63
sdk/helper/dbtxn/dbtxn.go
Normal file
63
sdk/helper/dbtxn/dbtxn.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package dbtxn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExecuteDBQuery handles executing one single statement, while properly releasing its resources.
|
||||
// - ctx: Required
|
||||
// - db: Required
|
||||
// - config: Optional, may be nil
|
||||
// - query: Required
|
||||
func ExecuteDBQuery(ctx context.Context, db *sql.DB, params map[string]string, query string) error {
|
||||
|
||||
parsedQuery := parseQuery(params, query)
|
||||
|
||||
stmt, err := db.PrepareContext(ctx, parsedQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
return execute(ctx, stmt)
|
||||
}
|
||||
|
||||
// ExecuteTxQuery handles executing one single statement, while properly releasing its resources.
|
||||
// - ctx: Required
|
||||
// - tx: Required
|
||||
// - config: Optional, may be nil
|
||||
// - query: Required
|
||||
func ExecuteTxQuery(ctx context.Context, tx *sql.Tx, params map[string]string, query string) error {
|
||||
|
||||
parsedQuery := parseQuery(params, query)
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, parsedQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
return execute(ctx, stmt)
|
||||
}
|
||||
|
||||
func execute(ctx context.Context, stmt *sql.Stmt) error {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseQuery(m map[string]string, tpl string) string {
|
||||
|
||||
if m == nil || len(m) <= 0 {
|
||||
return tpl
|
||||
}
|
||||
|
||||
for k, v := range m {
|
||||
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
|
||||
}
|
||||
return tpl
|
||||
}
|
||||
Reference in New Issue
Block a user