Database Root Credential Rotation (#3976)

* redoing connection handling

* a little more cleanup

* empty implementation of rotation

* updating rotate signature

* signature update

* updating interfaces again :(

* changing back to interface

* adding templated url support and rotation for postgres

* adding correct username

* return updates

* updating statements to be a list

* adding error sanitizing middleware

* fixing log sanitizier

* adding postgres rotate test

* removing conf from rotate

* adding rotate command

* adding mysql rotate

* finishing up the endpoint in the db backend for rotate

* no more structs, just store raw config

* fixing tests

* adding db instance lock

* adding support for statement list in cassandra

* wip redoing interface to support BC

* adding falllback for Initialize implementation

* adding backwards compat for statements

* fix tests

* fix more tests

* fixing up tests, switching to new fields in statements

* fixing more tests

* adding mssql and mysql

* wrapping all the things in middleware, implementing templating for mongodb

* wrapping all db servers with error santizer

* fixing test

* store the name with the db instance

* adding rotate to cassandra

* adding compatibility translation to both server and plugin

* reordering a few things

* store the name with the db instance

* reordering

* adding a few more tests

* switch secret values from slice to map

* addressing some feedback

* reinstate execute plugin after resetting connection

* set database connection to closed

* switching secret values func to map[string]interface for potential future uses

* addressing feedback
This commit is contained in:
Chris Hoffman
2018-03-21 15:05:56 -04:00
committed by GitHub
parent 1c443f22fe
commit 44aa151b78
33 changed files with 1974 additions and 777 deletions

View File

@@ -9,13 +9,37 @@ import (
log "github.com/mgutz/logxi/v1"
"github.com/hashicorp/errwrap"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
)
const databaseConfigPath = "database/config/"
type dbPluginInstance struct {
sync.RWMutex
dbplugin.Database
id string
name string
closed bool
}
func (dbi *dbPluginInstance) Close() error {
dbi.Lock()
defer dbi.Unlock()
if dbi.closed {
return nil
}
dbi.closed = true
return dbi.Database.Close()
}
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend(conf)
if err := b.Setup(ctx, conf); err != nil {
@@ -42,6 +66,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
pathRoles(&b),
pathCredsCreate(&b),
pathResetConnection(&b),
pathRotateCredentials(&b),
},
Secrets: []*framework.Secret{
@@ -53,72 +78,22 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
}
b.logger = conf.Logger
b.connections = make(map[string]dbplugin.Database)
b.connections = make(map[string]*dbPluginInstance)
return &b
}
type databaseBackend struct {
connections map[string]dbplugin.Database
connections map[string]*dbPluginInstance
logger log.Logger
*framework.Backend
sync.RWMutex
}
// closeAllDBs closes all connections from all database types
func (b *databaseBackend) closeAllDBs(ctx context.Context) {
b.Lock()
defer b.Unlock()
for _, db := range b.connections {
db.Close()
}
b.connections = make(map[string]dbplugin.Database)
}
// This function is used to retrieve a database object either from the cached
// connection map. The caller of this function needs to hold the backend's read
// lock.
func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
db, ok := b.connections[name]
return db, ok
}
// This function creates a new db object from the stored configuration and
// caches it in the connections map. The caller of this function needs to hold
// the backend's write lock
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
db, ok := b.connections[name]
if ok {
return db, nil
}
config, err := b.DatabaseConfig(ctx, s, name)
if err != nil {
return nil, err
}
db, err = dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
return nil, err
}
err = db.Initialize(ctx, config.ConnectionDetails, true)
if err != nil {
db.Close()
return nil, err
}
b.connections[name] = db
return db, nil
}
func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) {
entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration: %s", err)
return nil, errwrap.Wrapf("failed to read connection configuration: {{err}}", err)
}
if entry == nil {
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
@@ -144,7 +119,7 @@ type upgradeStatements struct {
type upgradeCheck struct {
// This json tag has a typo in it, the new version does not. This
// necessitates this upgrade logic.
Statements upgradeStatements `json:"statments"`
Statements *upgradeStatements `json:"statments,omitempty"`
}
func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) {
@@ -166,46 +141,138 @@ func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName
return nil, err
}
empty := upgradeCheck{}
if upgradeCh != empty {
result.Statements.CreationStatements = upgradeCh.Statements.CreationStatements
result.Statements.RevocationStatements = upgradeCh.Statements.RevocationStatements
result.Statements.RollbackStatements = upgradeCh.Statements.RollbackStatements
result.Statements.RenewStatements = upgradeCh.Statements.RenewStatements
switch {
case upgradeCh.Statements != nil:
var stmts dbplugin.Statements
if upgradeCh.Statements.CreationStatements != "" {
stmts.Creation = []string{upgradeCh.Statements.CreationStatements}
}
if upgradeCh.Statements.RevocationStatements != "" {
stmts.Revocation = []string{upgradeCh.Statements.RevocationStatements}
}
if upgradeCh.Statements.RollbackStatements != "" {
stmts.Rollback = []string{upgradeCh.Statements.RollbackStatements}
}
if upgradeCh.Statements.RenewStatements != "" {
stmts.Renewal = []string{upgradeCh.Statements.RenewStatements}
}
result.Statements = stmts
}
// For backwards compatibility, copy the values back into the string form
// of the fields
result.Statements = dbutil.StatementCompatibilityHelper(result.Statements)
return &result, nil
}
func (b *databaseBackend) invalidate(ctx context.Context, key string) {
b.Lock()
defer b.Unlock()
switch {
case strings.HasPrefix(key, databaseConfigPath):
name := strings.TrimPrefix(key, databaseConfigPath)
b.clearConnection(name)
b.ClearConnection(name)
}
}
// clearConnection closes the database connection and
// removes it from the b.connections map.
func (b *databaseBackend) clearConnection(name string) {
func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage, name string) (*dbPluginInstance, error) {
b.RLock()
unlockFunc := b.RUnlock
defer func() { unlockFunc() }()
db, ok := b.connections[name]
if ok {
return db, nil
}
// Upgrade lock
b.RUnlock()
b.Lock()
unlockFunc = b.Unlock
db, ok = b.connections[name]
if ok {
return db, nil
}
config, err := b.DatabaseConfig(ctx, s, name)
if err != nil {
return nil, err
}
dbp, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
return nil, err
}
_, err = dbp.Init(ctx, config.ConnectionDetails, true)
if err != nil {
dbp.Close()
return nil, err
}
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
db = &dbPluginInstance{
Database: dbp,
name: name,
id: id,
}
b.connections[name] = db
return db, nil
}
// ClearConnection closes the database connection and
// removes it from the b.connections map.
func (b *databaseBackend) ClearConnection(name string) error {
b.Lock()
defer b.Unlock()
return b.clearConnection(name)
}
func (b *databaseBackend) clearConnection(name string) error {
db, ok := b.connections[name]
if ok {
// Ignore error here since the database client is always killed
db.Close()
delete(b.connections, name)
}
return nil
}
func (b *databaseBackend) closeIfShutdown(name string, err error) {
func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) {
// Plugin has shutdown, close it so next call can reconnect.
switch err {
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
// Put this in a goroutine so that requests can run with the read or write lock
// and simply defer the unlock. Since we are attaching the instance and matching
// the id in the conneciton map, we can safely do this.
go func() {
b.Lock()
b.clearConnection(name)
b.Unlock()
defer b.Unlock()
db.Close()
// Ensure we are deleting the correct connection
mapDB, ok := b.connections[db.name]
if ok && db.id == mapDB.id {
delete(b.connections, db.name)
}
}()
}
}
// closeAllDBs closes all connections from all database types
func (b *databaseBackend) closeAllDBs(ctx context.Context) {
b.Lock()
defer b.Unlock()
for _, db := range b.connections {
db.Close()
}
b.connections = make(map[string]*dbPluginInstance)
}
const backendHelp = `

View File

@@ -7,6 +7,7 @@ import (
"log"
"os"
"reflect"
"strings"
"sync"
"testing"
"time"
@@ -27,6 +28,7 @@ var (
)
func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) {
t.Helper()
if os.Getenv("PG_URL") != "" {
return func() {}, os.Getenv("PG_URL")
}
@@ -64,7 +66,7 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac
})
if err != nil || (resp != nil && resp.IsError()) {
// It's likely not up and running yet, so return error and try again
return fmt.Errorf("err:%s resp:%#v\n", err, resp)
return fmt.Errorf("err:%#v resp:%#v", err, resp)
}
if resp == nil {
t.Fatal("expected warning")
@@ -123,13 +125,18 @@ func TestBackend_RoleUpgrade(t *testing.T) {
storage := &logical.InmemStorage{}
backend := &databaseBackend{}
roleEnt := &roleEntry{
roleExpected := &roleEntry{
Statements: dbplugin.Statements{
CreationStatements: "test",
Creation: []string{"test"},
},
}
entry, err := logical.StorageEntryJSON("role/test", roleEnt)
entry, err := logical.StorageEntryJSON("role/test", &roleEntry{
Statements: dbplugin.Statements{
CreationStatements: "test",
},
})
if err != nil {
t.Fatal(err)
}
@@ -142,8 +149,8 @@ func TestBackend_RoleUpgrade(t *testing.T) {
t.Fatal(err)
}
if !reflect.DeepEqual(role, roleEnt) {
t.Fatalf("bad role %#v", role)
if !reflect.DeepEqual(role, roleExpected) {
t.Fatalf("bad role %#v, %#v", role, roleExpected)
}
// Upgrade case
@@ -161,8 +168,8 @@ func TestBackend_RoleUpgrade(t *testing.T) {
t.Fatal(err)
}
if !reflect.DeepEqual(role, roleEnt) {
t.Fatalf("bad role %#v", role)
if !reflect.DeepEqual(role, roleExpected) {
t.Fatalf("bad role %#v, %#v", role, roleExpected)
}
}
@@ -207,6 +214,7 @@ func TestBackend_config_connection(t *testing.T) {
"connection_url": "sample_connection_url",
},
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(context.Background(), configReq)
@@ -233,6 +241,55 @@ func TestBackend_config_connection(t *testing.T) {
}
}
func TestBackend_BadConnectionString(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup(context.Background())
cleanup, _ := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()
respCheck := func(req *logical.Request) {
t.Helper()
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected error, resp:%#v", resp)
}
err = resp.Error()
if strings.Contains(err.Error(), "localhost") {
t.Fatalf("error should not contain connection info")
}
}
// Configure a connection
data := map[string]interface{}{
"connection_url": "postgresql://:pw@[localhost",
"plugin_name": "postgresql-database-plugin",
"allowed_roles": []string{"plugin-role-test"},
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: data,
}
respCheck(req)
time.Sleep(1 * time.Second)
}
func TestBackend_basic(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
@@ -388,7 +445,6 @@ func TestBackend_basic(t *testing.T) {
if testCredsExist(t, credsResp, connURL) {
t.Fatalf("Creds should not exist")
}
}
func TestBackend_connectionCrud(t *testing.T) {
@@ -468,6 +524,7 @@ func TestBackend_connectionCrud(t *testing.T) {
"connection_url": connURL,
},
"allowed_roles": []string{"plugin-role-test"},
"root_credentials_rotate_statements": []string{},
}
req.Operation = logical.ReadOperation
resp, err = b.HandleRequest(context.Background(), req)
@@ -602,15 +659,15 @@ func TestBackend_roleCrud(t *testing.T) {
}
expected := dbplugin.Statements{
CreationStatements: testRole,
RevocationStatements: defaultRevocationSQL,
Creation: []string{strings.TrimSpace(testRole)},
Revocation: []string{strings.TrimSpace(defaultRevocationSQL)},
}
actual := dbplugin.Statements{
CreationStatements: resp.Data["creation_statements"].(string),
RevocationStatements: resp.Data["revocation_statements"].(string),
RollbackStatements: resp.Data["rollback_statements"].(string),
RenewStatements: resp.Data["renew_statements"].(string),
Creation: resp.Data["creation_statements"].([]string),
Revocation: resp.Data["revocation_statements"].([]string),
Rollback: resp.Data["rollback_statements"].([]string),
Renewal: resp.Data["renew_statements"].([]string),
}
if !reflect.DeepEqual(expected, actual) {

View File

@@ -1,5 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// Code generated by protoc-gen-go.
// source: builtin/logical/database/dbplugin/database.proto
// DO NOT EDIT!
/*
Package dbplugin is a generated protocol buffer package.
@@ -9,13 +10,17 @@ It is generated from these files:
It has these top-level messages:
InitializeRequest
InitRequest
CreateUserRequest
RenewUserRequest
RevokeUserRequest
RotateRootCredentialsRequest
Statements
UsernameConfig
InitResponse
CreateUserResponse
TypeResponse
RotateRootCredentialsResponse
Empty
*/
package dbplugin
@@ -65,6 +70,30 @@ func (m *InitializeRequest) GetVerifyConnection() bool {
return false
}
type InitRequest struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
}
func (m *InitRequest) Reset() { *m = InitRequest{} }
func (m *InitRequest) String() string { return proto.CompactTextString(m) }
func (*InitRequest) ProtoMessage() {}
func (*InitRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *InitRequest) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
func (m *InitRequest) GetVerifyConnection() bool {
if m != nil {
return m.VerifyConnection
}
return false
}
type CreateUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
@@ -74,7 +103,7 @@ type CreateUserRequest struct {
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
func (*CreateUserRequest) ProtoMessage() {}
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *CreateUserRequest) GetStatements() *Statements {
if m != nil {
@@ -106,7 +135,7 @@ type RenewUserRequest struct {
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
func (*RenewUserRequest) ProtoMessage() {}
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *RenewUserRequest) GetStatements() *Statements {
if m != nil {
@@ -137,7 +166,7 @@ type RevokeUserRequest struct {
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
func (*RevokeUserRequest) ProtoMessage() {}
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *RevokeUserRequest) GetStatements() *Statements {
if m != nil {
@@ -153,17 +182,41 @@ func (m *RevokeUserRequest) GetUsername() string {
return ""
}
type RotateRootCredentialsRequest struct {
Statements []string `protobuf:"bytes,1,rep,name=statements" json:"statements,omitempty"`
}
func (m *RotateRootCredentialsRequest) Reset() { *m = RotateRootCredentialsRequest{} }
func (m *RotateRootCredentialsRequest) String() string { return proto.CompactTextString(m) }
func (*RotateRootCredentialsRequest) ProtoMessage() {}
func (*RotateRootCredentialsRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (m *RotateRootCredentialsRequest) GetStatements() []string {
if m != nil {
return m.Statements
}
return nil
}
type Statements struct {
// DEPRECATED, will be removed in 0.12
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
// DEPRECATED, will be removed in 0.12
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
// DEPRECATED, will be removed in 0.12
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
// DEPRECATED, will be removed in 0.12
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
Creation []string `protobuf:"bytes,5,rep,name=creation" json:"creation,omitempty"`
Revocation []string `protobuf:"bytes,6,rep,name=revocation" json:"revocation,omitempty"`
Rollback []string `protobuf:"bytes,7,rep,name=rollback" json:"rollback,omitempty"`
Renewal []string `protobuf:"bytes,8,rep,name=renewal" json:"renewal,omitempty"`
}
func (m *Statements) Reset() { *m = Statements{} }
func (m *Statements) String() string { return proto.CompactTextString(m) }
func (*Statements) ProtoMessage() {}
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *Statements) GetCreationStatements() string {
if m != nil {
@@ -193,6 +246,34 @@ func (m *Statements) GetRenewStatements() string {
return ""
}
func (m *Statements) GetCreation() []string {
if m != nil {
return m.Creation
}
return nil
}
func (m *Statements) GetRevocation() []string {
if m != nil {
return m.Revocation
}
return nil
}
func (m *Statements) GetRollback() []string {
if m != nil {
return m.Rollback
}
return nil
}
func (m *Statements) GetRenewal() []string {
if m != nil {
return m.Renewal
}
return nil
}
type UsernameConfig struct {
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
@@ -201,7 +282,7 @@ type UsernameConfig struct {
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
func (*UsernameConfig) ProtoMessage() {}
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *UsernameConfig) GetDisplayName() string {
if m != nil {
@@ -217,6 +298,22 @@ func (m *UsernameConfig) GetRoleName() string {
return ""
}
type InitResponse struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
}
func (m *InitResponse) Reset() { *m = InitResponse{} }
func (m *InitResponse) String() string { return proto.CompactTextString(m) }
func (*InitResponse) ProtoMessage() {}
func (*InitResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
func (m *InitResponse) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
type CreateUserResponse struct {
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
@@ -225,7 +322,7 @@ type CreateUserResponse struct {
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
func (*CreateUserResponse) ProtoMessage() {}
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{9} }
func (m *CreateUserResponse) GetUsername() string {
if m != nil {
@@ -248,7 +345,7 @@ type TypeResponse struct {
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
func (*TypeResponse) ProtoMessage() {}
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{10} }
func (m *TypeResponse) GetType() string {
if m != nil {
@@ -257,23 +354,43 @@ func (m *TypeResponse) GetType() string {
return ""
}
type RotateRootCredentialsResponse struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
}
func (m *RotateRootCredentialsResponse) Reset() { *m = RotateRootCredentialsResponse{} }
func (m *RotateRootCredentialsResponse) String() string { return proto.CompactTextString(m) }
func (*RotateRootCredentialsResponse) ProtoMessage() {}
func (*RotateRootCredentialsResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{11} }
func (m *RotateRootCredentialsResponse) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
type Empty struct {
}
func (m *Empty) Reset() { *m = Empty{} }
func (m *Empty) String() string { return proto.CompactTextString(m) }
func (*Empty) ProtoMessage() {}
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{12} }
func init() {
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
proto.RegisterType((*InitRequest)(nil), "dbplugin.InitRequest")
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
proto.RegisterType((*RotateRootCredentialsRequest)(nil), "dbplugin.RotateRootCredentialsRequest")
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
proto.RegisterType((*InitResponse)(nil), "dbplugin.InitResponse")
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
proto.RegisterType((*RotateRootCredentialsResponse)(nil), "dbplugin.RotateRootCredentialsResponse")
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
}
@@ -292,8 +409,10 @@ type DatabaseClient interface {
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
RotateRootCredentials(ctx context.Context, in *RotateRootCredentialsRequest, opts ...grpc.CallOption) (*RotateRootCredentialsResponse, error)
Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error)
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
}
type databaseClient struct {
@@ -340,9 +459,18 @@ func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest,
return out, nil
}
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
func (c *databaseClient) RotateRootCredentials(ctx context.Context, in *RotateRootCredentialsRequest, opts ...grpc.CallOption) (*RotateRootCredentialsResponse, error) {
out := new(RotateRootCredentialsResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/RotateRootCredentials", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error) {
out := new(InitResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/Init", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
@@ -358,6 +486,15 @@ func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.Call
return out, nil
}
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Database service
type DatabaseServer interface {
@@ -365,8 +502,10 @@ type DatabaseServer interface {
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
Initialize(context.Context, *InitializeRequest) (*Empty, error)
RotateRootCredentials(context.Context, *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error)
Init(context.Context, *InitRequest) (*InitResponse, error)
Close(context.Context, *Empty) (*Empty, error)
Initialize(context.Context, *InitializeRequest) (*Empty, error)
}
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
@@ -445,20 +584,38 @@ func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func
return interceptor(ctx, in, info, handler)
}
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitializeRequest)
func _Database_RotateRootCredentials_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RotateRootCredentialsRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Initialize(ctx, in)
return srv.(DatabaseServer).RotateRootCredentials(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Initialize",
FullMethod: "/dbplugin.Database/RotateRootCredentials",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
return srv.(DatabaseServer).RotateRootCredentials(ctx, req.(*RotateRootCredentialsRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Init_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Init(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Init",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Init(ctx, req.(*InitRequest))
}
return interceptor(ctx, in, info, handler)
}
@@ -481,6 +638,24 @@ func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(inte
return interceptor(ctx, in, info, handler)
}
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitializeRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Initialize(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Initialize",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
}
return interceptor(ctx, in, info, handler)
}
var _Database_serviceDesc = grpc.ServiceDesc{
ServiceName: "dbplugin.Database",
HandlerType: (*DatabaseServer)(nil),
@@ -502,13 +677,21 @@ var _Database_serviceDesc = grpc.ServiceDesc{
Handler: _Database_RevokeUser_Handler,
},
{
MethodName: "Initialize",
Handler: _Database_Initialize_Handler,
MethodName: "RotateRootCredentials",
Handler: _Database_RotateRootCredentials_Handler,
},
{
MethodName: "Init",
Handler: _Database_Init_Handler,
},
{
MethodName: "Close",
Handler: _Database_Close_Handler,
},
{
MethodName: "Initialize",
Handler: _Database_Initialize_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "builtin/logical/database/dbplugin/database.proto",
@@ -517,40 +700,49 @@ var _Database_serviceDesc = grpc.ServiceDesc{
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 548 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
0x94, 0x05, 0x00, 0x00,
// 690 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x55, 0x41, 0x4f, 0xdb, 0x4a,
0x10, 0x96, 0x93, 0x00, 0xc9, 0x80, 0x80, 0xec, 0x03, 0x64, 0xf9, 0xf1, 0xde, 0x43, 0x3e, 0xf0,
0x40, 0x95, 0xe2, 0x0a, 0x5a, 0xb5, 0xe2, 0xd0, 0xaa, 0x0a, 0x55, 0x55, 0xa9, 0xe2, 0xb0, 0xc0,
0xad, 0x12, 0xda, 0x38, 0x43, 0xba, 0xc2, 0xf1, 0xba, 0xde, 0x0d, 0x34, 0xfd, 0x03, 0xed, 0xcf,
0xe8, 0x4f, 0xe9, 0xb1, 0x3f, 0xab, 0xf2, 0xda, 0x6b, 0x6f, 0x62, 0x28, 0x07, 0xda, 0x9b, 0x67,
0xe6, 0xfb, 0x66, 0xbe, 0x9d, 0x9d, 0x59, 0xc3, 0xe3, 0xc1, 0x84, 0x47, 0x8a, 0xc7, 0x41, 0x24,
0x46, 0x3c, 0x64, 0x51, 0x30, 0x64, 0x8a, 0x0d, 0x98, 0xc4, 0x60, 0x38, 0x48, 0xa2, 0xc9, 0x88,
0xc7, 0xa5, 0xa7, 0x97, 0xa4, 0x42, 0x09, 0xd2, 0x36, 0x01, 0xef, 0xbf, 0x91, 0x10, 0xa3, 0x08,
0x03, 0xed, 0x1f, 0x4c, 0x2e, 0x03, 0xc5, 0xc7, 0x28, 0x15, 0x1b, 0x27, 0x39, 0xd4, 0x7f, 0x0f,
0xdd, 0xb7, 0x31, 0x57, 0x9c, 0x45, 0xfc, 0x33, 0x52, 0xfc, 0x38, 0x41, 0xa9, 0xc8, 0x16, 0x2c,
0x86, 0x22, 0xbe, 0xe4, 0x23, 0xd7, 0xd9, 0x71, 0xf6, 0x56, 0x68, 0x61, 0x91, 0x47, 0xd0, 0xbd,
0xc6, 0x94, 0x5f, 0x4e, 0x2f, 0x42, 0x11, 0xc7, 0x18, 0x2a, 0x2e, 0x62, 0xb7, 0xb1, 0xe3, 0xec,
0xb5, 0xe9, 0x7a, 0x1e, 0xe8, 0x97, 0xfe, 0xa3, 0x86, 0xeb, 0xf8, 0x14, 0x96, 0xb3, 0xec, 0xbf,
0x33, 0xaf, 0xff, 0xc3, 0x81, 0x6e, 0x3f, 0x45, 0xa6, 0xf0, 0x5c, 0x62, 0x6a, 0x52, 0x3f, 0x01,
0x90, 0x8a, 0x29, 0x1c, 0x63, 0xac, 0xa4, 0x4e, 0xbf, 0x7c, 0xb0, 0xd1, 0x33, 0x7d, 0xe8, 0x9d,
0x96, 0x31, 0x6a, 0xe1, 0xc8, 0x2b, 0x58, 0x9b, 0x48, 0x4c, 0x63, 0x36, 0xc6, 0x8b, 0x42, 0x59,
0x43, 0x53, 0xdd, 0x8a, 0x7a, 0x5e, 0x00, 0xfa, 0x3a, 0x4e, 0x57, 0x27, 0x33, 0x36, 0x39, 0x02,
0xc0, 0x4f, 0x09, 0x4f, 0x99, 0x16, 0xdd, 0xd4, 0x6c, 0xaf, 0x97, 0xb7, 0xbd, 0x67, 0xda, 0xde,
0x3b, 0x33, 0x6d, 0xa7, 0x16, 0xda, 0xff, 0xe6, 0xc0, 0x3a, 0xc5, 0x18, 0x6f, 0x1e, 0x7e, 0x12,
0x0f, 0xda, 0x46, 0x98, 0x3e, 0x42, 0x87, 0x96, 0xf6, 0x83, 0x24, 0x22, 0x74, 0x29, 0x5e, 0x8b,
0x2b, 0xfc, 0xa3, 0x12, 0xfd, 0x17, 0xb0, 0x4d, 0x45, 0x06, 0xa5, 0x42, 0xa8, 0x7e, 0x8a, 0x43,
0x8c, 0xb3, 0x99, 0x94, 0xa6, 0xe2, 0xbf, 0x73, 0x15, 0x9b, 0x7b, 0x1d, 0x3b, 0xb7, 0xff, 0xbd,
0x01, 0x50, 0x95, 0x25, 0x01, 0xfc, 0x15, 0x66, 0x23, 0xc2, 0x45, 0x7c, 0x31, 0xa7, 0xb4, 0x43,
0x89, 0x09, 0x59, 0x84, 0x43, 0xd8, 0x4c, 0xf1, 0x5a, 0x84, 0x35, 0x4a, 0x2e, 0x74, 0xa3, 0x0a,
0xce, 0x56, 0x49, 0x45, 0x14, 0x0d, 0x58, 0x78, 0x65, 0x53, 0x9a, 0x79, 0x15, 0x13, 0xb2, 0x08,
0xfb, 0xb0, 0x9e, 0x66, 0xd7, 0x6d, 0xa3, 0x5b, 0x1a, 0xbd, 0xa6, 0xfd, 0xa7, 0x33, 0xcd, 0x32,
0x32, 0xdd, 0x05, 0x7d, 0xdc, 0xd2, 0xce, 0x9a, 0x51, 0xe9, 0x71, 0x17, 0xf3, 0x66, 0x54, 0x9e,
0x8c, 0x6b, 0x8a, 0xbb, 0x4b, 0x39, 0xd7, 0xd8, 0xc4, 0x85, 0x25, 0x5d, 0x8a, 0x45, 0x6e, 0x5b,
0x87, 0x8c, 0xe9, 0x9f, 0xc0, 0xea, 0xec, 0xa8, 0x93, 0x1d, 0x58, 0x3e, 0xe6, 0x32, 0x89, 0xd8,
0xf4, 0x24, 0xbb, 0xb3, 0xbc, 0x7b, 0xb6, 0x2b, 0xab, 0x44, 0x45, 0x84, 0x27, 0xd6, 0x95, 0x1a,
0xdb, 0xdf, 0x85, 0x95, 0x7c, 0xf7, 0x65, 0x22, 0x62, 0x89, 0x77, 0x2d, 0xbf, 0xff, 0x0e, 0x88,
0xbd, 0xce, 0x05, 0xda, 0x1e, 0x16, 0x67, 0x6e, 0x9e, 0x3d, 0x68, 0x27, 0x4c, 0xca, 0x1b, 0x91,
0x0e, 0x4d, 0x55, 0x63, 0xfb, 0x3e, 0xac, 0x9c, 0x4d, 0x13, 0x2c, 0xf3, 0x10, 0x68, 0xa9, 0x69,
0x62, 0x72, 0xe8, 0x6f, 0xff, 0x19, 0xfc, 0x73, 0xc7, 0xb0, 0xdd, 0x23, 0x75, 0x09, 0x16, 0x5e,
0x8f, 0x13, 0x35, 0x3d, 0xf8, 0xd2, 0x82, 0xf6, 0x71, 0xf1, 0xe6, 0x92, 0x00, 0x5a, 0x59, 0x49,
0xb2, 0x56, 0x6d, 0x80, 0x46, 0x79, 0x5b, 0x95, 0x63, 0x46, 0xd3, 0x1b, 0x80, 0xea, 0xc4, 0xe4,
0xef, 0x0a, 0x55, 0x7b, 0xd6, 0xbc, 0xed, 0xdb, 0x83, 0x45, 0xa2, 0xe7, 0xd0, 0x29, 0x9f, 0x0f,
0xe2, 0x55, 0xd0, 0xf9, 0x37, 0xc5, 0x9b, 0x97, 0x96, 0x3d, 0x09, 0xd5, 0x5a, 0xdb, 0x12, 0x6a,
0xcb, 0x5e, 0xe7, 0x7e, 0x80, 0xcd, 0x5b, 0xdb, 0x47, 0x76, 0xad, 0x34, 0xbf, 0x58, 0x66, 0xef,
0xff, 0x7b, 0x71, 0xc5, 0xf9, 0x9e, 0x42, 0x2b, 0x1b, 0x21, 0xb2, 0x59, 0x11, 0xac, 0xdf, 0x89,
0xdd, 0xdf, 0x99, 0x49, 0xdb, 0x87, 0x85, 0x7e, 0x24, 0xe4, 0x2d, 0x37, 0x52, 0x3b, 0xcb, 0x4b,
0x80, 0xea, 0xf7, 0x67, 0xf7, 0xa1, 0xf6, 0x53, 0xac, 0x71, 0xfd, 0xe6, 0xd7, 0x86, 0x33, 0x58,
0xd4, 0xef, 0xe7, 0xe1, 0xcf, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa7, 0x13, 0xfe, 0x55, 0xa5, 0x07,
0x00, 0x00,
}

View File

@@ -4,6 +4,12 @@ package dbplugin;
import "google/protobuf/timestamp.proto";
message InitializeRequest {
option deprecated = true;
bytes config = 1;
bool verify_connection = 2;
}
message InitRequest {
bytes config = 1;
bool verify_connection = 2;
}
@@ -25,11 +31,24 @@ message RevokeUserRequest {
string username = 2;
}
message RotateRootCredentialsRequest {
repeated string statements = 1;
}
message Statements {
// DEPRECATED, will be removed in 0.12
string creation_statements = 1;
// DEPRECATED, will be removed in 0.12
string revocation_statements = 2;
// DEPRECATED, will be removed in 0.12
string rollback_statements = 3;
// DEPRECATED, will be removed in 0.12
string renew_statements = 4;
repeated string creation = 5;
repeated string revocation = 6;
repeated string rollback = 7;
repeated string renewal = 8;
}
message UsernameConfig {
@@ -37,6 +56,10 @@ message UsernameConfig {
string RoleName = 2;
}
message InitResponse {
bytes config = 1;
}
message CreateUserResponse {
string username = 1;
string password = 2;
@@ -46,6 +69,10 @@ message TypeResponse {
string type = 1;
}
message RotateRootCredentialsResponse {
bytes config = 1;
}
message Empty {}
service Database {
@@ -53,6 +80,11 @@ service Database {
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
rpc RenewUser(RenewUserRequest) returns (Empty);
rpc RevokeUser(RevokeUserRequest) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty);
rpc RotateRootCredentials(RotateRootCredentialsRequest) returns (RotateRootCredentialsResponse);
rpc Init(InitRequest) returns (InitResponse);
rpc Close(Empty) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty) {
option deprecated = true;
};
}

View File

@@ -2,8 +2,14 @@ package dbplugin
import (
"context"
"errors"
"net/url"
"strings"
"sync"
"time"
"github.com/hashicorp/errwrap"
metrics "github.com/armon/go-metrics"
log "github.com/mgutz/logxi/v1"
)
@@ -51,13 +57,27 @@ func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
func (mw *databaseTracingMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RotateRootCredentials", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "RotateRootCredentials", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.RotateRootCredentials(ctx, statements)
}
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *databaseTracingMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Initialize(ctx, conf, verifyConnection)
return mw.next.Init(ctx, conf, verifyConnection)
}
func (mw *databaseTracingMiddleware) Close() (err error) {
@@ -131,7 +151,28 @@ func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements
return mw.next.RevokeUser(ctx, statements, username)
}
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
func (mw *databaseMetricsMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RotateRootCredentials"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RotateRootCredentials"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "RotateRootCredentials", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "RotateRootCredentials"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials"}, 1)
return mw.next.RotateRootCredentials(ctx, statements)
}
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *databaseMetricsMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Initialize"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
@@ -144,7 +185,7 @@ func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[st
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
return mw.next.Initialize(ctx, conf, verifyConnection)
return mw.next.Init(ctx, conf, verifyConnection)
}
func (mw *databaseMetricsMiddleware) Close() (err error) {
@@ -162,3 +203,76 @@ func (mw *databaseMetricsMiddleware) Close() (err error) {
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
return mw.next.Close()
}
// ---- Error Sanitizer Middleware Domain ----
// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
// sanitizes returned error messages
type DatabaseErrorSanitizerMiddleware struct {
l sync.RWMutex
next Database
secretsFn func() map[string]interface{}
}
func NewDatabaseErrorSanitizerMiddleware(next Database, secretsFn func() map[string]interface{}) *DatabaseErrorSanitizerMiddleware {
return &DatabaseErrorSanitizerMiddleware{
next: next,
secretsFn: secretsFn,
}
}
func (mw *DatabaseErrorSanitizerMiddleware) Type() (string, error) {
dbType, err := mw.next.Type()
return dbType, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
username, password, err = mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
return username, password, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
return mw.sanitize(mw.next.RenewUser(ctx, statements, username, expiration))
}
func (mw *DatabaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
return mw.sanitize(mw.next.RevokeUser(ctx, statements, username))
}
func (mw *DatabaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
conf, err = mw.next.RotateRootCredentials(ctx, statements)
return conf, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *DatabaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
saveConf, err = mw.next.Init(ctx, conf, verifyConnection)
return saveConf, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) Close() (err error) {
return mw.sanitize(mw.next.Close())
}
// sanitize
func (mw *DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
if err == nil {
return nil
}
if errwrap.ContainsType(err, new(url.Error)) {
return errors.New("unable to parse connection url")
}
if mw.secretsFn != nil {
for k, v := range mw.secretsFn() {
if k == "" {
continue
}
err = errors.New(strings.Replace(err.Error(), k, v.(string), -1))
}
}
return err
}

View File

@@ -7,6 +7,8 @@ import (
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/vault/helper/pluginutil"
@@ -61,16 +63,51 @@ func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*E
return &Empty{}, err
}
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
config := map[string]interface{}{}
func (s *gRPCServer) RotateRootCredentials(ctx context.Context, req *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error) {
resp, err := s.impl.RotateRootCredentials(ctx, req.Statements)
if err != nil {
return nil, err
}
respConfig, err := json.Marshal(resp)
if err != nil {
return nil, err
}
return &RotateRootCredentialsResponse{
Config: respConfig,
}, err
}
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
_, err := s.Init(ctx, &InitRequest{
Config: req.Config,
VerifyConnection: req.VerifyConnection,
})
return &Empty{}, err
}
func (s *gRPCServer) Init(ctx context.Context, req *InitRequest) (*InitResponse, error) {
config := map[string]interface{}{}
err := json.Unmarshal(req.Config, &config)
if err != nil {
return nil, err
}
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
return &Empty{}, err
resp, err := s.impl.Init(ctx, config, req.VerifyConnection)
if err != nil {
return nil, err
}
respConfig, err := json.Marshal(resp)
if err != nil {
return nil, err
}
return &InitResponse{
Config: respConfig,
}, err
}
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
@@ -87,7 +124,7 @@ type gRPCClient struct {
doneCtx context.Context
}
func (c gRPCClient) Type() (string, error) {
func (c *gRPCClient) Type() (string, error) {
resp, err := c.client.Type(c.doneCtx, &Empty{})
if err != nil {
return "", err
@@ -96,7 +133,7 @@ func (c gRPCClient) Type() (string, error) {
return resp.Type, err
}
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
func (c *gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
t, err := ptypes.TimestampProto(expiration)
if err != nil {
return "", "", err
@@ -172,30 +209,74 @@ func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, user
return nil
}
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
configRaw, err := json.Marshal(config)
func (c *gRPCClient) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
defer close(quitCh)
defer cancel()
resp, err := c.client.RotateRootCredentials(ctx, &RotateRootCredentialsRequest{
Statements: statements,
})
if err != nil {
if c.doneCtx.Err() != nil {
return nil, ErrPluginShutdown
}
return nil, err
}
if err := json.Unmarshal(resp.Config, &conf); err != nil {
return nil, err
}
return conf, nil
}
func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
func (c *gRPCClient) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
configRaw, err := json.Marshal(conf)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
defer close(quitCh)
defer cancel()
_, err = c.client.Initialize(ctx, &InitializeRequest{
resp, err := c.client.Init(ctx, &InitRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
if err != nil {
// Fall back to old call if not implemented
grpcStatus, ok := status.FromError(err)
if ok && grpcStatus.Code() == codes.Unimplemented {
_, err = c.client.Initialize(ctx, &InitializeRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
if err == nil {
return conf, nil
}
}
if c.doneCtx.Err() != nil {
return ErrPluginShutdown
return nil, ErrPluginShutdown
}
return nil, err
}
return err
if err := json.Unmarshal(resp.Config, &conf); err != nil {
return nil, err
}
return nil
return conf, nil
}
func (c *gRPCClient) Close() error {

View File

@@ -2,8 +2,10 @@ package dbplugin
import (
"context"
"encoding/json"
"fmt"
"net/rpc"
"strings"
"time"
)
@@ -37,8 +39,28 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *str
return err
}
func (ds *databasePluginRPCServer) RotateRootCredentials(args *RotateRootCredentialsRequestRPC, resp *RotateRootCredentialsResponse) error {
config, err := ds.impl.RotateRootCredentials(context.Background(), args.Statements)
if err != nil {
return err
}
resp.Config, err = json.Marshal(config)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
return ds.Init(&InitRequestRPC{
Config: args.Config,
VerifyConnection: args.VerifyConnection,
}, &InitResponse{})
}
func (ds *databasePluginRPCServer) Init(args *InitRequestRPC, resp *InitResponse) error {
config, err := ds.impl.Init(context.Background(), args.Config, args.VerifyConnection)
if err != nil {
return err
}
resp.Config, err = json.Marshal(config)
return err
}
@@ -81,9 +103,7 @@ func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements State
Expiration: expiration,
}
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
return dr.client.Call("Plugin.RenewUser", req, &struct{}{})
}
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
@@ -92,26 +112,55 @@ func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Stat
Username: username,
}
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
return dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
}
return err
func (dr *databasePluginRPCClient) RotateRootCredentials(_ context.Context, statements []string) (saveConf map[string]interface{}, err error) {
req := RotateRootCredentialsRequestRPC{
Statements: statements,
}
var resp RotateRootCredentialsResponse
err = dr.client.Call("Plugin.RotateRootCredentials", req, &resp)
err = json.Unmarshal(resp.Config, &saveConf)
return saveConf, err
}
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := dr.Init(nil, conf, verifyConnection)
return err
}
func (dr *databasePluginRPCClient) Init(_ context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
req := InitRequestRPC{
Config: conf,
VerifyConnection: verifyConnection,
}
var resp InitResponse
err = dr.client.Call("Plugin.Init", req, &resp)
if err != nil {
if strings.Contains(err.Error(), "can't find method Plugin.Init") {
req := InitializeRequestRPC{
Config: conf,
VerifyConnection: verifyConnection,
}
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
err = dr.client.Call("Plugin.Initialize", req, &struct{}{})
if err == nil {
return conf, nil
}
}
return nil, err
}
return err
err = json.Unmarshal(resp.Config, &saveConf)
return saveConf, err
}
func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
return dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
}
// ---- RPC Request Args Domain ----
@@ -121,6 +170,11 @@ type InitializeRequestRPC struct {
VerifyConnection bool
}
type InitRequestRPC struct {
Config map[string]interface{}
VerifyConnection bool
}
type CreateUserRequestRPC struct {
Statements Statements
UsernameConfig UsernameConfig
@@ -137,3 +191,7 @@ type RevokeUserRequestRPC struct {
Statements Statements
Username string
}
type RotateRootCredentialsRequestRPC struct {
Statements []string
}

View File

@@ -8,6 +8,7 @@ import (
"google.golang.org/grpc"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
log "github.com/mgutz/logxi/v1"
@@ -20,8 +21,13 @@ type Database interface {
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
RevokeUser(ctx context.Context, statements Statements, username string) error
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error)
Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error)
Close() error
// DEPRECATED, will be removed in 0.12
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error)
}
// PluginFactory is used to build plugin database types. It wraps the database
@@ -40,7 +46,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
// from the pluginRunner. Then cast it to a Database.
dbRaw, err := pluginRunner.BuiltinFactory()
if err != nil {
return nil, fmt.Errorf("error getting plugin type: %s", err)
return nil, errwrap.Wrapf("error initializing plugin: {{err}}", err)
}
var ok bool
@@ -71,7 +77,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
typeStr, err := db.Type()
if err != nil {
return nil, fmt.Errorf("error getting plugin type: %s", err)
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
}
// Wrap with metrics middleware
@@ -113,7 +119,11 @@ type DatabasePlugin struct {
}
func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) {
return &databasePluginRPCServer{impl: d.impl}, nil
impl := &DatabaseErrorSanitizerMiddleware{
next: d.impl,
}
return &databasePluginRPCServer{impl: impl}, nil
}
func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
@@ -121,7 +131,11 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
}
func (d DatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
impl := &DatabaseErrorSanitizerMiddleware{
next: d.impl,
}
RegisterDatabaseServer(s, &gRPCServer{impl: impl})
return nil
}

View File

@@ -61,6 +61,17 @@ func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statement
delete(m.users, username)
return nil
}
func (m *mockPlugin) RotateRootCredentials(_ context.Context, statements []string) (map[string]interface{}, error) {
return nil, nil
}
func (m *mockPlugin) Init(_ context.Context, conf map[string]interface{}, _ bool) (map[string]interface{}, error) {
err := errors.New("err")
if len(conf) != 1 {
return nil, err
}
return conf, nil
}
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
err := errors.New("err")
if len(conf) != 1 {
@@ -132,7 +143,7 @@ func TestPlugin_NetRPC_Main(t *testing.T) {
plugin.Serve(serveConf)
}
func TestPlugin_Initialize(t *testing.T) {
func TestPlugin_Init(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
@@ -145,7 +156,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1,
}
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
_, err = dbRaw.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -170,7 +181,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -209,7 +220,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -243,7 +254,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -272,7 +283,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
}
// Test the code is still compatible with an old netRPC plugin
func TestPlugin_NetRPC_Initialize(t *testing.T) {
func TestPlugin_NetRPC_Init(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
@@ -285,7 +296,7 @@ func TestPlugin_NetRPC_Initialize(t *testing.T) {
"test": 1,
}
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
_, err = dbRaw.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -310,7 +321,7 @@ func TestPlugin_NetRPC_CreateUser(t *testing.T) {
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -349,7 +360,7 @@ func TestPlugin_NetRPC_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -383,7 +394,7 @@ func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{
"test": 1,
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"github.com/fatih/structs"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@@ -24,6 +25,8 @@ type DatabaseConfig struct {
// by each database type.
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"`
RootCredentialsRotateStatements []string `json:"root_credentials_rotate_statements" structs:"root_credentials_rotate_statements" mapstructure:"root_credentials_rotate_statements"`
}
// pathResetConnection configures a path to reset a plugin.
@@ -55,16 +58,13 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
return logical.ErrorResponse(respErrEmptyName), nil
}
// Grab the mutex lock
b.Lock()
defer b.Unlock()
// Close plugin and delete the entry in the connections cache.
b.clearConnection(name)
if err := b.ClearConnection(name); err != nil {
return nil, err
}
// Execute plugin again, we don't need the object so throw away.
_, err := b.createDBObj(ctx, req.Storage, name)
if err != nil {
if _, err := b.GetConnection(ctx, req.Storage, name); err != nil {
return nil, err
}
@@ -103,6 +103,14 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
allowed to get creds from this database connection. If empty no
roles are allowed. If "*" all roles are allowed.`,
},
"root_rotation_statements": &framework.FieldSchema{
Type: framework.TypeStringSlice,
Description: `Specifies the database statements to be executed
to rotate the root user's credentials. See the plugin's API
page for more information on support and formatting for this
parameter.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -179,18 +187,10 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
return nil, errors.New("failed to delete connection configuration")
}
b.Lock()
defer b.Unlock()
if _, ok := b.connections[name]; ok {
err = b.connections[name].Close()
if err != nil {
if err := b.ClearConnection(name); err != nil {
return nil, err
}
delete(b.connections, name)
}
return nil, nil
}
}
@@ -210,8 +210,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
}
verifyConnection := data.Get("verify_connection").(bool)
allowedRoles := data.Get("allowed_roles").([]string)
rootRotationStatements := data.Get("root_rotation_statements").([]string)
// Remove these entries from the data before we store it keyed under
// ConnectionDetails.
@@ -219,35 +219,45 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
delete(data.Raw, "plugin_name")
delete(data.Raw, "allowed_roles")
delete(data.Raw, "verify_connection")
delete(data.Raw, "root_rotation_statements")
config := &DatabaseConfig{
ConnectionDetails: data.Raw,
PluginName: pluginName,
AllowedRoles: allowedRoles,
}
db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
// Create a database plugin and initialize it. This instance is not
// going to be used and is initialized just to ensure all parameters
// are valid and the connection is verified, if requested.
db, err := dbplugin.PluginFactory(ctx, pluginName, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
}
err = db.Initialize(ctx, config.ConnectionDetails, verifyConnection)
connDetails, err := db.Init(ctx, data.Raw, verifyConnection)
if err != nil {
db.Close()
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
}
// Grab the mutex lock
b.Lock()
defer b.Unlock()
// Close and remove the old connection
b.clearConnection(name)
// Save the new connection
b.connections[name] = db
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
b.connections[name] = &dbPluginInstance{
Database: db,
name: name,
id: id,
}
// Store it
config := &DatabaseConfig{
ConnectionDetails: connDetails,
PluginName: pluginName,
AllowedRoles: allowedRoles,
RootCredentialsRotateStatements: rootRotationStatements,
}
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
if err != nil {
return nil, err

View File

@@ -54,26 +54,15 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
return nil, logical.ErrPermissionDenied
}
// Grab the read lock
b.RLock()
unlockFunc := b.RUnlock
// Get the Database object
db, ok := b.getDBObj(role.DBName)
if !ok {
// Upgrade lock
b.RUnlock()
b.Lock()
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
}
return nil, err
}
db.RLock()
defer db.RUnlock()
ttl := b.System().DefaultLeaseTTL()
if role.DefaultTTL != 0 {
ttl = role.DefaultTTL
@@ -96,8 +85,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
// Create the user
username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration)
if err != nil {
unlockFunc()
b.closeIfShutdown(role.DBName, err)
b.CloseIfShutdown(db, err)
return nil, err
}
@@ -109,8 +97,6 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
"role": name,
})
resp.Secret.TTL = ttl
unlockFunc()
return resp, nil
}
}

View File

@@ -36,26 +36,26 @@ func pathRoles(b *databaseBackend) *framework.Path {
Description: "Name of the database this role acts on.",
},
"creation_statements": {
Type: framework.TypeString,
Type: framework.TypeStringSlice,
Description: `Specifies the database statements executed to
create and configure a user. See the plugin's API page for more
information on support and formatting for this parameter.`,
},
"revocation_statements": {
Type: framework.TypeString,
Type: framework.TypeStringSlice,
Description: `Specifies the database statements to be executed
to revoke a user. See the plugin's API page for more information
on support and formatting for this parameter.`,
},
"renew_statements": {
Type: framework.TypeString,
Type: framework.TypeStringSlice,
Description: `Specifies the database statements to be executed
to renew a user. Not every plugin type will support this
functionality. See the plugin's API page for more information on
support and formatting for this parameter. `,
},
"rollback_statements": {
Type: framework.TypeString,
Type: framework.TypeStringSlice,
Description: `Specifies the database statements to be executed
rollback a create operation in the event of an error. Not every
plugin type will support this functionality. See the plugin's
@@ -109,10 +109,10 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
return &logical.Response{
Data: map[string]interface{}{
"db_name": role.DBName,
"creation_statements": role.Statements.CreationStatements,
"revocation_statements": role.Statements.RevocationStatements,
"rollback_statements": role.Statements.RollbackStatements,
"renew_statements": role.Statements.RenewStatements,
"creation_statements": role.Statements.Creation,
"revocation_statements": role.Statements.Revocation,
"rollback_statements": role.Statements.Rollback,
"renew_statements": role.Statements.Renewal,
"default_ttl": role.DefaultTTL.Seconds(),
"max_ttl": role.MaxTTL.Seconds(),
},
@@ -144,10 +144,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
}
// Get statements
creationStmts := data.Get("creation_statements").(string)
revocationStmts := data.Get("revocation_statements").(string)
rollbackStmts := data.Get("rollback_statements").(string)
renewStmts := data.Get("renew_statements").(string)
creationStmts := data.Get("creation_statements").([]string)
revocationStmts := data.Get("revocation_statements").([]string)
rollbackStmts := data.Get("rollback_statements").([]string)
renewStmts := data.Get("renew_statements").([]string)
// Get TTLs
defaultTTLRaw := data.Get("default_ttl").(int)
@@ -156,10 +156,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
maxTTL := time.Duration(maxTTLRaw) * time.Second
statements := dbplugin.Statements{
CreationStatements: creationStmts,
RevocationStatements: revocationStmts,
RollbackStatements: rollbackStmts,
RenewStatements: renewStmts,
Creation: creationStmts,
Revocation: revocationStmts,
Rollback: rollbackStmts,
Renewal: renewStmts,
}
// Store it
@@ -181,10 +181,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
}
type roleEntry struct {
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
Statements dbplugin.Statements `json:"statements" mapstructure:"statements" structs:"statements"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
DBName string `json:"db_name"`
Statements dbplugin.Statements `json:"statements"`
DefaultTTL time.Duration `json:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl"`
}
const pathRoleHelpSyn = `

View File

@@ -0,0 +1,80 @@
package database
import (
"context"
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathRotateCredentials(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: "rotate-root/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of this database connection",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRotateCredentialsUpdate(),
},
HelpSynopsis: pathCredsCreateReadHelpSyn,
HelpDescription: pathCredsCreateReadHelpDesc,
}
}
func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse(respErrEmptyName), nil
}
config, err := b.DatabaseConfig(ctx, req.Storage, name)
if err != nil {
return nil, err
}
db, err := b.GetConnection(ctx, req.Storage, name)
if err != nil {
return nil, err
}
// Take the write lock instead of read since we are updating the
// connection
db.Lock()
defer db.Unlock()
connectionDetails, err := db.RotateRootCredentials(ctx, config.RootCredentialsRotateStatements)
if err != nil {
return nil, err
}
config.ConnectionDetails = connectionDetails
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
if err := b.ClearConnection(name); err != nil {
return nil, err
}
return nil, nil
}
}
const pathRotateCredentialsUpdateHelpSyn = `
Request to rotate the root credentials for a certain database connection.
`
const pathRotateCredentialsUpdateHelpDesc = `
This path attempts to rotate the root credentials for the given database.
`

View File

@@ -48,37 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
return nil, err
}
// Grab the read lock
b.RLock()
unlockFunc := b.RUnlock
// Get the Database object
db, ok := b.getDBObj(role.DBName)
if !ok {
// Upgrade lock
b.RUnlock()
b.Lock()
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
}
return nil, err
}
db.RLock()
defer db.RUnlock()
// Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
err := db.RenewUser(ctx, role.Statements, username, expireTime)
if err != nil {
unlockFunc()
b.closeIfShutdown(role.DBName, err)
b.CloseIfShutdown(db, err)
return nil, err
}
}
unlockFunc()
return resp, nil
}
}
@@ -107,33 +93,19 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"])
}
// Grab the read lock
b.RLock()
unlockFunc := b.RUnlock
// Get our connection
db, ok := b.getDBObj(role.DBName)
if !ok {
// Upgrade lock
b.RUnlock()
b.Lock()
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
}
}
if err := db.RevokeUser(ctx, role.Statements, username); err != nil {
unlockFunc()
b.closeIfShutdown(role.DBName, err)
return nil, err
}
unlockFunc()
db.RLock()
defer db.RUnlock()
if err := db.RevokeUser(ctx, role.Statements, username); err != nil {
b.CloseIfShutdown(db, err)
return nil, err
}
return resp, nil
}
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/plugins"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
)
@@ -19,6 +18,7 @@ import (
const (
defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
defaultUserDeletionCQL = `DROP USER '{{username}}';`
defaultRootCredentialRotationCQL = `ALTER USER {{username}} WITH PASSWORD '{{password}}';`
cassandraTypeName = "cassandra"
)
@@ -26,12 +26,19 @@ var _ dbplugin.Database = &Cassandra{}
// Cassandra is an implementation of Database interface
type Cassandra struct {
connutil.ConnectionProducer
*cassandraConnectionProducer
credsutil.CredentialsProducer
}
// New returns a new Cassandra instance
func New() (interface{}, error) {
db := new()
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}
func new() *Cassandra {
connProducer := &cassandraConnectionProducer{}
connProducer.Type = cassandraTypeName
@@ -42,12 +49,10 @@ func New() (interface{}, error) {
Separator: "_",
}
dbType := &Cassandra{
ConnectionProducer: connProducer,
return &Cassandra{
cassandraConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
return dbType, nil
}
// Run instantiates a Cassandra object, and runs the RPC server for the plugin
@@ -57,7 +62,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}
plugins.Serve(dbType.(*Cassandra), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil
}
@@ -83,19 +88,22 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
c.Lock()
defer c.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection
session, err := c.getConnection(ctx)
if err != nil {
return "", "", err
}
creationCQL := statements.CreationStatements
if creationCQL == "" {
creationCQL = defaultUserCreationCQL
creationCQL := statements.Creation
if len(creationCQL) == 0 {
creationCQL = []string{defaultUserCreationCQL}
}
rollbackCQL := statements.RollbackStatements
if rollbackCQL == "" {
rollbackCQL = defaultUserDeletionCQL
rollbackCQL := statements.Rollback
if len(rollbackCQL) == 0 {
rollbackCQL = []string{defaultUserDeletionCQL}
}
username, err = c.GenerateUsername(usernameConfig)
@@ -112,7 +120,8 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
}
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") {
for _, stmt := range creationCQL {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -123,7 +132,8 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
"password": password,
})).Exec()
if err != nil {
for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") {
for _, stmt := range rollbackCQL {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -133,9 +143,11 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
"username": username,
})).Exec()
}
}
return "", "", err
}
}
}
return username, password, nil
}
@@ -152,18 +164,21 @@ func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statemen
c.Lock()
defer c.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
session, err := c.getConnection(ctx)
if err != nil {
return err
}
revocationCQL := statements.RevocationStatements
if revocationCQL == "" {
revocationCQL = defaultUserDeletionCQL
revocationCQL := statements.Revocation
if len(revocationCQL) == 0 {
revocationCQL = []string{defaultUserDeletionCQL}
}
var result *multierror.Error
for _, query := range strutil.ParseArbitraryStringSlice(revocationCQL, ";") {
for _, stmt := range revocationCQL {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -175,6 +190,53 @@ func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statemen
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}
func (c *Cassandra) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
// Grab the lock
c.Lock()
defer c.Unlock()
session, err := c.getConnection(ctx)
if err != nil {
return nil, err
}
rotateCQL := statements
if len(rotateCQL) == 0 {
rotateCQL = []string{defaultRootCredentialRotationCQL}
}
password, err := c.GeneratePassword()
if err != nil {
return nil, err
}
var result *multierror.Error
for _, stmt := range rotateCQL {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
err := session.Query(dbutil.QueryHelper(query, map[string]string{
"username": c.Username,
"password": password,
})).Exec()
result = multierror.Append(result, err)
}
}
err = result.ErrorOrNil()
if err != nil {
return nil, err
}
c.rawConfig["password"] = password
return c.rawConfig, nil
}

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"github.com/gocql/gocql"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
@@ -60,7 +61,7 @@ func prepareCassandraTestContainer(t *testing.T) (func(), string, int) {
session, err := clusterConfig.CreateSession()
if err != nil {
return fmt.Errorf("error creating session: %s", err)
return errwrap.Wrapf("error creating session: {{err}}", err)
}
defer session.Close()
return nil
@@ -86,16 +87,13 @@ func TestCassandra_Initialize(t *testing.T) {
"protocol_version": 4,
}
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initalized")
}
@@ -113,7 +111,7 @@ func TestCassandra_Initialize(t *testing.T) {
"protocol_version": "4",
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -134,15 +132,14 @@ func TestCassandra_CreateUser(t *testing.T) {
"protocol_version": 4,
}
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testCassandraRole,
Creation: []string{testCassandraRole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -175,15 +172,14 @@ func TestMyCassandra_RenewUser(t *testing.T) {
"protocol_version": 4,
}
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testCassandraRole,
Creation: []string{testCassandraRole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -221,15 +217,14 @@ func TestCassandra_RevokeUser(t *testing.T) {
"protocol_version": 4,
}
dbRaw, _ := New()
db := dbRaw.(*Cassandra)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testCassandraRole,
Creation: []string{testCassandraRole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -268,7 +263,7 @@ func testCredsExist(t testing.TB, address string, port int, username, password s
session, err := clusterConfig.CreateSession()
if err != nil {
return fmt.Errorf("error creating session: %s", err)
return errwrap.Wrapf("error creating session: {{err}}", err)
}
defer session.Close()
return nil

View File

@@ -11,6 +11,7 @@ import (
"github.com/mitchellh/mapstructure"
"github.com/gocql/gocql"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/tlsutil"
@@ -37,6 +38,7 @@ type cassandraConnectionProducer struct {
certificate string
privateKey string
issuingCA string
rawConfig map[string]interface{}
Initialized bool
Type string
@@ -45,12 +47,19 @@ type cassandraConnectionProducer struct {
}
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
c.Lock()
defer c.Unlock()
c.rawConfig = conf
err := mapstructure.WeakDecode(conf, c)
if err != nil {
return err
return nil, err
}
if c.ConnectTimeoutRaw == nil {
@@ -58,16 +67,16 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
}
c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
if err != nil {
return fmt.Errorf("invalid connect_timeout: %s", err)
return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
}
switch {
case len(c.Hosts) == 0:
return fmt.Errorf("hosts cannot be empty")
return nil, fmt.Errorf("hosts cannot be empty")
case len(c.Username) == 0:
return fmt.Errorf("username cannot be empty")
return nil, fmt.Errorf("username cannot be empty")
case len(c.Password) == 0:
return fmt.Errorf("password cannot be empty")
return nil, fmt.Errorf("password cannot be empty")
}
var certBundle *certutil.CertBundle
@@ -76,11 +85,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
case len(c.PemJSON) != 0:
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
if err != nil {
return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err)
return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return fmt.Errorf("Error marshaling PEM information: %s", err)
return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
}
c.certificate = certBundle.Certificate
c.privateKey = certBundle.PrivateKey
@@ -90,11 +99,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
case len(c.PemBundle) != 0:
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
if err != nil {
return fmt.Errorf("Error parsing the given PEM information: %s", err)
return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err)
}
certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil {
return fmt.Errorf("Error marshaling PEM information: %s", err)
return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
}
c.certificate = certBundle.Certificate
c.privateKey = certBundle.PrivateKey
@@ -108,11 +117,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
if verifyConnection {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
}
return nil
return conf, nil
}
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
@@ -186,12 +195,12 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
}
tlsConfig.InsecureSkipVerify = c.InsecureTLS
@@ -215,7 +224,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
session, err := clusterConfig.CreateSession()
if err != nil {
return nil, fmt.Errorf("error creating session: %s", err)
return nil, errwrap.Wrapf("error creating session: {{err}}", err)
}
// Set consistency
@@ -231,8 +240,16 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
// Verify the info
err = session.Query(`LIST ALL`).Exec()
if err != nil {
return nil, fmt.Errorf("error validating connection info: %s", err)
return nil, errwrap.Wrapf("error validating connection info: {{err}}", err)
}
return session, nil
}
func (c *cassandraConnectionProducer) secretValues() map[string]interface{} {
return map[string]interface{}{
c.Password: "[password]",
c.PemBundle: "[pem_bundle]",
c.PemJSON: "[pem_json]",
}
}

View File

@@ -572,7 +572,7 @@ ssl_storage_port: 7001
#
# Setting listen_address to 0.0.0.0 is always wrong.
#
listen_address: 172.17.0.5
listen_address: 172.17.0.2
# Set listen_address OR listen_interface, not both. Interfaces must correspond
# to a single address, IP aliasing is not supported.

View File

@@ -3,6 +3,7 @@ package hana
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
@@ -23,7 +24,7 @@ const (
// HANA is an implementation of Database interface
type HANA struct {
connutil.ConnectionProducer
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}
@@ -31,6 +32,14 @@ var _ dbplugin.Database = &HANA{}
// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
return dbType, nil
}
func new() *HANA {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = hanaTypeName
@@ -41,12 +50,10 @@ func New() (interface{}, error) {
Separator: "_",
}
dbType := &HANA{
ConnectionProducer: connProducer,
return &HANA{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
return dbType, nil
}
// Run instantiates a HANA object, and runs the RPC server for the plugin
@@ -56,7 +63,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}
plugins.Serve(dbType.(*HANA), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil
}
@@ -82,13 +89,15 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
h.Lock()
defer h.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection
db, err := h.getConnection(ctx)
if err != nil {
return "", "", err
}
if statements.CreationStatements == "" {
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
}
@@ -127,7 +136,8 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
for _, stmt := range statements.Creation {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -146,6 +156,7 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
return "", "", err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
@@ -157,6 +168,8 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
// Renewing hana user just means altering user's valid until property
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
statements = dbutil.StatementCompatibilityHelper(statements)
// Get connection
db, err := h.getConnection(ctx)
if err != nil {
@@ -197,8 +210,10 @@ func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, us
// Revoking hana user will deactivate user and try to perform a soft drop
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
statements = dbutil.StatementCompatibilityHelper(statements)
// default revoke will be a soft drop on user
if statements.RevocationStatements == "" {
if len(statements.Revocation) == 0 {
return h.revokeUserDefault(ctx, username)
}
@@ -216,7 +231,8 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") {
for _, stmt := range statements.Revocation {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -233,13 +249,9 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u
return err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}
return nil
return tx.Commit()
}
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
@@ -284,3 +296,8 @@ func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
return nil
}
// RotateRootCredentials is not currently supported on HANA
func (h *HANA) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
return nil, errors.New("root credentaion rotation is not currently implemented in this database secrets engine")
}

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
)
func TestHANA_Initialize(t *testing.T) {
@@ -23,16 +22,13 @@ func TestHANA_Initialize(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initialized")
}
@@ -53,10 +49,8 @@ func TestHANA_CreateUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -73,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
}
statements := dbplugin.Statements{
CreationStatements: testHANARole,
Creation: []string{testHANARole},
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
@@ -96,16 +90,14 @@ func TestHANA_RevokeUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*HANA)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testHANARole,
Creation: []string{testHANARole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -139,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
statements.RevocationStatements = testHANADrop
statements.Revocation = []string{testHANADrop}
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)

View File

@@ -14,7 +14,9 @@ import (
"sync"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
"github.com/mitchellh/mapstructure"
"gopkg.in/mgo.v2"
@@ -25,28 +27,43 @@ import (
type mongoDBConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
Initialized bool
RawConfig map[string]interface{}
Type string
session *mgo.Session
safe *mgo.Safe
sync.Mutex
}
// Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()
err := mapstructure.WeakDecode(conf, c)
if err != nil {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
if len(c.ConnectionURL) == 0 {
return fmt.Errorf("connection_url cannot be empty")
// Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
c.Lock()
defer c.Unlock()
c.RawConfig = conf
err := mapstructure.WeakDecode(conf, c)
if err != nil {
return nil, err
}
if len(c.ConnectionURL) == 0 {
return nil, fmt.Errorf("connection_url cannot be empty")
}
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
"username": c.Username,
"password": c.Password,
})
if c.WriteConcern != "" {
input := c.WriteConcern
@@ -60,13 +77,13 @@ func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[str
concern := &mgo.Safe{}
err = json.Unmarshal([]byte(input), concern)
if err != nil {
return fmt.Errorf("error mashalling write_concern: %s", err)
return nil, errwrap.Wrapf("error mashalling write_concern: {{err}}", err)
}
// Guard against empty, non-nil mgo.Safe object; we don't want to pass that
// into mgo.SetSafe in Connection().
if (mgo.Safe{} == *concern) {
return fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields")
return nil, fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields")
}
c.safe = concern
}
@@ -77,15 +94,15 @@ func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[str
if verifyConnection {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
if err := c.session.Ping(); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
}
return nil
return conf, nil
}
// Connection creates or returns an existing a database connection. If the session fails
@@ -203,3 +220,9 @@ func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
return &info, nil
}
func (c *mongoDBConnectionProducer) secretValues() map[string]interface{} {
return map[string]interface{}{
c.Password: "[password]",
}
}

View File

@@ -2,6 +2,7 @@ package mongodb
import (
"context"
"errors"
"io"
"strings"
"time"
@@ -14,7 +15,6 @@ import (
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
"gopkg.in/mgo.v2"
@@ -24,7 +24,7 @@ const mongoDBTypeName = "mongodb"
// MongoDB is an implementation of Database interface
type MongoDB struct {
connutil.ConnectionProducer
*mongoDBConnectionProducer
credsutil.CredentialsProducer
}
@@ -32,6 +32,12 @@ var _ dbplugin.Database = &MongoDB{}
// New returns a new MongoDB instance
func New() (interface{}, error) {
db := new()
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}
func new() *MongoDB {
connProducer := &mongoDBConnectionProducer{}
connProducer.Type = mongoDBTypeName
@@ -42,11 +48,10 @@ func New() (interface{}, error) {
Separator: "-",
}
dbType := &MongoDB{
ConnectionProducer: connProducer,
return &MongoDB{
mongoDBConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
return dbType, nil
}
// Run instantiates a MongoDB object, and runs the RPC server for the plugin
@@ -88,7 +93,9 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements
m.Lock()
defer m.Unlock()
if statements.CreationStatements == "" {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
}
@@ -109,7 +116,7 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements
// Unmarshal statements.CreationStatements into mongodbRoles
var mongoCS mongoDBStatement
err = json.Unmarshal([]byte(statements.CreationStatements), &mongoCS)
err = json.Unmarshal([]byte(statements.Creation[0]), &mongoCS)
if err != nil {
return "", "", err
}
@@ -158,15 +165,22 @@ func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements,
// RevokeUser drops the specified user from the authentication database. If none is provided
// in the revocation statement, the default "admin" authentication database will be assumed.
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
statements = dbutil.StatementCompatibilityHelper(statements)
session, err := m.getConnection(ctx)
if err != nil {
return err
}
// If no revocation statements provided, pass in empty JSON
revocationStatement := statements.RevocationStatements
if revocationStatement == "" {
var revocationStatement string
switch len(statements.Revocation) {
case 0:
revocationStatement = `{}`
case 1:
revocationStatement = statements.Revocation[0]
default:
return fmt.Errorf("expected 0 or 1 revocation statements, got %d", len(statements.Revocation))
}
// Unmarshal revocation statements into mongodbRoles
@@ -186,7 +200,7 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements
switch {
case err == nil, err == mgo.ErrNotFound:
case err == io.EOF, strings.Contains(err.Error(), "EOF"):
if err := m.ConnectionProducer.Close(); err != nil {
if err := m.Close(); err != nil {
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
}
session, err := m.getConnection(ctx)
@@ -203,3 +217,8 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements
return nil
}
// RotateRootCredentials is not currently supported on MongoDB
func (m *MongoDB) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
return nil, errors.New("root credentaion rotation is not currently implemented in this database secrets engine")
}

View File

@@ -73,19 +73,13 @@ func TestMongoDB_Initialize(t *testing.T) {
"connection_url": connURL,
}
dbRaw, err := New()
if err != nil {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
err = db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initialized")
}
@@ -103,18 +97,14 @@ func TestMongoDB_CreateUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, err := New()
if err != nil {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testMongoDBRole,
Creation: []string{testMongoDBRole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -141,18 +131,14 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
"write_concern": testMongoDBWriteConcern,
}
dbRaw, err := New()
if err != nil {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testMongoDBRole,
Creation: []string{testMongoDBRole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -178,18 +164,14 @@ func TestMongoDB_RevokeUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, err := New()
if err != nil {
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testMongoDBRole,
Creation: []string{testMongoDBRole},
}
usernameConfig := dbplugin.UsernameConfig{

View File

@@ -3,11 +3,13 @@ package mssql
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
_ "github.com/denisenkom/go-mssqldb"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil"
@@ -23,11 +25,19 @@ var _ dbplugin.Database = &MSSQL{}
// MSSQL is an implementation of Database interface
type MSSQL struct {
connutil.ConnectionProducer
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}
func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
return dbType, nil
}
func new() *MSSQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = msSQLTypeName
@@ -38,12 +48,10 @@ func New() (interface{}, error) {
Separator: "-",
}
dbType := &MSSQL{
ConnectionProducer: connProducer,
return &MSSQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
return dbType, nil
}
// Run instantiates a MSSQL object, and runs the RPC server for the plugin
@@ -53,7 +61,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}
plugins.Serve(dbType.(*MSSQL), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil
}
@@ -79,13 +87,15 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
m.Lock()
defer m.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
if statements.CreationStatements == "" {
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
}
@@ -112,7 +122,8 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
for _, stmt := range statements.Creation {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -131,6 +142,7 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
return "", "", err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
@@ -150,7 +162,9 @@ func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, u
// then kill pending connections from that user, and finally drop the user and login from the
// database instance.
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
if statements.RevocationStatements == "" {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Revocation) == 0 {
return m.revokeUserDefault(ctx, username)
}
@@ -168,7 +182,8 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") {
for _, stmt := range statements.Revocation {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -185,6 +200,7 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
return err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
@@ -283,10 +299,10 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
// can't drop if not all database users are dropped
if rows.Err() != nil {
return fmt.Errorf("could not generate sql statements for all rows: %s", rows.Err())
return errwrap.Wrapf("could not generate sql statements for all rows: {{err}}", rows.Err())
}
if lastStmtError != nil {
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
return errwrap.Wrapf("could not perform all sql statements: {{err}}", lastStmtError)
}
// Drop this login
@@ -302,6 +318,70 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
return nil
}
func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
m.Lock()
defer m.Unlock()
if len(m.Username) == 0 || len(m.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{rotateRootCredentialsSQL}
}
db, err := m.getConnection(ctx)
if err != nil {
return nil, err
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := m.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"username": m.Username,
"password": password,
}))
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
if err := db.Close(); err != nil {
return nil, err
}
m.RawConfig["password"] = password
return m.RawConfig, nil
}
const dropUserSQL = `
USE [%s]
IF EXISTS
@@ -322,3 +402,7 @@ BEGIN
DROP LOGIN [%s]
END
`
const rotateRootCredentialsSQL = `
ALTER LOGIN [%s] WITH PASSWORD = '%s'
`

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
)
var (
@@ -28,16 +27,13 @@ func TestMSSQL_Initialize(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initalized")
}
@@ -52,7 +48,7 @@ func TestMSSQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -68,9 +64,8 @@ func TestMSSQL_CreateUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -87,7 +82,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
}
statements := dbplugin.Statements{
CreationStatements: testMSSQLRole,
Creation: []string{testMSSQLRole},
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
@@ -110,15 +105,14 @@ func TestMSSQL_RevokeUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*MSSQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testMSSQLRole,
Creation: []string{testMSSQLRole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -155,7 +149,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
}
// Test custom revoke statement
statements.RevocationStatements = testMSSQLDrop
statements.Revocation = []string{testMSSQLDrop}
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)

View File

@@ -3,6 +3,7 @@ package mysql
import (
"context"
"database/sql"
"errors"
"strings"
"time"
@@ -21,6 +22,11 @@ const (
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
DROP USER '{{name}}'@'%'
`
defaultMySQLRotateRootCredentialsSQL = `
ALTER USER '{{username}}'@'%' IDENTIFIED BY '{{password}}';
`
mySQLTypeName = "mysql"
)
@@ -34,13 +40,22 @@ var (
var _ dbplugin.Database = &MySQL{}
type MySQL struct {
connutil.ConnectionProducer
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}
// New implements builtinplugins.BuiltinFactory
func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, error) {
return func() (interface{}, error) {
db := new(displayNameLen, roleNameLen, usernameLen)
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
return dbType, nil
}
}
func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = mySQLTypeName
@@ -51,13 +66,10 @@ func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, erro
Separator: "-",
}
dbType := &MySQL{
ConnectionProducer: connProducer,
return &MySQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
return dbType, nil
}
}
// Run instantiates a MySQL object, and runs the RPC server for the plugin
@@ -82,7 +94,7 @@ func runCommon(legacy bool, apiTLSConfig *api.TLSConfig) error {
return err
}
plugins.Serve(dbType.(*MySQL), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil
}
@@ -105,13 +117,15 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
m.Lock()
defer m.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection
db, err := m.getConnection(ctx)
if err != nil {
return "", "", err
}
if statements.CreationStatements == "" {
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
}
@@ -138,7 +152,8 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
for _, stmt := range statements.Creation {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -172,6 +187,7 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
return "", "", err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
@@ -191,16 +207,18 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
m.Lock()
defer m.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection
db, err := m.getConnection(ctx)
if err != nil {
return err
}
revocationStmts := statements.RevocationStatements
revocationStmts := statements.Revocation
// Use a default SQL statement for revocation if one cannot be fetched from the role
if revocationStmts == "" {
revocationStmts = defaultMysqlRevocationStmts
if len(revocationStmts) == 0 {
revocationStmts = []string{defaultMysqlRevocationStmts}
}
// Start a transaction
@@ -210,7 +228,8 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
}
defer tx.Rollback()
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
for _, stmt := range revocationStmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -224,7 +243,7 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
if err != nil {
return err
}
}
}
// Commit the transaction
@@ -234,3 +253,67 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
return nil
}
func (m *MySQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
m.Lock()
defer m.Unlock()
if len(m.Username) == 0 || len(m.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{defaultMySQLRotateRootCredentialsSQL}
}
db, err := m.getConnection(ctx)
if err != nil {
return nil, err
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := m.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"username": m.Username,
"password": password,
}))
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
if err := db.Close(); err != nil {
return nil, err
}
m.RawConfig["password"] = password
return m.RawConfig, nil
}

View File

@@ -9,9 +9,9 @@ import (
"testing"
"time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
@@ -104,17 +104,13 @@ func TestMySQL_Initialize(t *testing.T) {
"connection_url": connURL,
}
f := New(MetadataLen, MetadataLen, UsernameLen)
dbRaw, _ := f()
db := dbRaw.(*MySQL)
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new(MetadataLen, MetadataLen, UsernameLen)
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initalized")
}
@@ -129,7 +125,7 @@ func TestMySQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -143,11 +139,8 @@ func TestMySQL_CreateUser(t *testing.T) {
"connection_url": connURL,
}
f := New(MetadataLen, MetadataLen, UsernameLen)
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new(MetadataLen, MetadataLen, UsernameLen)
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -164,7 +157,7 @@ func TestMySQL_CreateUser(t *testing.T) {
}
statements := dbplugin.Statements{
CreationStatements: testMySQLRoleWildCard,
Creation: []string{testMySQLRoleWildCard},
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
@@ -187,7 +180,7 @@ func TestMySQL_CreateUser(t *testing.T) {
}
// Test with a manually prepare statement
statements.CreationStatements = testMySQLRolePreparedStmt
statements.Creation = []string{testMySQLRolePreparedStmt}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
@@ -208,11 +201,8 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
"connection_url": connURL,
}
f := New(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen)
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen)
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -229,7 +219,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
}
statements := dbplugin.Statements{
CreationStatements: testMySQLRoleWildCard,
Creation: []string{testMySQLRoleWildCard},
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
@@ -252,6 +242,42 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
}
}
func TestMySQL_RotateRootCredentials(t *testing.T) {
cleanup, connURL := prepareMySQLTestContainer(t)
defer cleanup()
connURL = strings.Replace(connURL, "root:secret", `{{username}}:{{password}}`, -1)
connectionDetails := map[string]interface{}{
"connection_url": connURL,
"username": "root",
"password": "secret",
}
db := new(MetadataLen, MetadataLen, UsernameLen)
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !db.Initialized {
t.Fatal("Database should be initalized")
}
newConf, err := db.RotateRootCredentials(context.Background(), nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if newConf["password"] == "secret" {
t.Fatal("password was not updated")
}
err = db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestMySQL_RevokeUser(t *testing.T) {
cleanup, connURL := prepareMySQLTestContainer(t)
defer cleanup()
@@ -260,17 +286,14 @@ func TestMySQL_RevokeUser(t *testing.T) {
"connection_url": connURL,
}
f := New(MetadataLen, MetadataLen, UsernameLen)
dbRaw, _ := f()
db := dbRaw.(*MySQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new(MetadataLen, MetadataLen, UsernameLen)
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testMySQLRoleWildCard,
Creation: []string{testMySQLRoleWildCard},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -297,7 +320,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
t.Fatal("Credentials were not revoked")
}
statements.CreationStatements = testMySQLRoleWildCard
statements.Creation = []string{testMySQLRoleWildCard}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
@@ -308,7 +331,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
}
// Test custom revoke statements
statements.RevocationStatements = testMySQLRevocationSQL
statements.Revocation = []string{testMySQLRevocationSQL}
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)

View File

@@ -3,10 +3,12 @@ package postgresql
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil"
@@ -15,13 +17,15 @@ import (
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
"github.com/lib/pq"
_ "github.com/lib/pq"
)
const (
postgreSQLTypeName string = "postgres"
postgreSQLTypeName = "postgres"
defaultPostgresRenewSQL = `
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
`
defaultPostgresRotateRootCredentialsSQL = `
ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
`
)
@@ -29,6 +33,13 @@ var _ dbplugin.Database = &PostgreSQL{}
// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
return dbType, nil
}
func new() *PostgreSQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = postgreSQLTypeName
@@ -39,12 +50,12 @@ func New() (interface{}, error) {
Separator: "-",
}
dbType := &PostgreSQL{
ConnectionProducer: connProducer,
db := &PostgreSQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
return dbType, nil
return db
}
// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin
@@ -54,13 +65,13 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err
}
plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig)
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil
}
type PostgreSQL struct {
connutil.ConnectionProducer
*connutil.SQLConnectionProducer
credsutil.CredentialsProducer
}
@@ -78,7 +89,9 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
}
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
if statements.CreationStatements == "" {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
}
@@ -105,7 +118,6 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
db, err := p.getConnection(ctx)
if err != nil {
return "", "", err
}
// Start a transaction
@@ -120,7 +132,8 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
// Return the secret
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
for _, stmt := range statements.Creation {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -133,12 +146,11 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
}))
if err != nil {
return "", "", err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
}
}
}
@@ -155,9 +167,11 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
p.Lock()
defer p.Unlock()
renewStmts := statements.RenewStatements
if renewStmts == "" {
renewStmts = defaultPostgresRenewSQL
statements = dbutil.StatementCompatibilityHelper(statements)
renewStmts := statements.Renewal
if len(renewStmts) == 0 {
renewStmts = []string{defaultPostgresRenewSQL}
}
db, err := p.getConnection(ctx)
@@ -178,7 +192,8 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
return err
}
for _, query := range strutil.ParseArbitraryStringSlice(renewStmts, ";") {
for _, stmt := range renewStmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -196,12 +211,9 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
return err
}
}
if err := tx.Commit(); err != nil {
return err
}
return nil
return tx.Commit()
}
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
@@ -209,14 +221,16 @@ func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Stateme
p.Lock()
defer p.Unlock()
if statements.RevocationStatements == "" {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Revocation) == 0 {
return p.defaultRevokeUser(ctx, username)
}
return p.customRevokeUser(ctx, username, statements.RevocationStatements)
return p.customRevokeUser(ctx, username, statements.Revocation)
}
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error {
db, err := p.getConnection(ctx)
if err != nil {
return err
@@ -230,7 +244,8 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationS
tx.Rollback()
}()
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
for _, stmt := range revocationStmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
@@ -248,12 +263,9 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationS
return err
}
}
if err := tx.Commit(); err != nil {
return err
}
return nil
return tx.Commit()
}
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
@@ -354,10 +366,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
// can't drop if not all privileges are revoked
if rows.Err() != nil {
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err())
}
if lastStmtError != nil {
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError)
}
// Drop this user
@@ -373,3 +385,68 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
return nil
}
func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
p.Lock()
defer p.Unlock()
if len(p.Username) == 0 || len(p.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{defaultPostgresRotateRootCredentialsSQL}
}
db, err := p.getConnection(ctx)
if err != nil {
return nil, err
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := p.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"username": p.Username,
"password": password,
}))
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
// Close the database connection to ensure no new connections come in
if err := db.Close(); err != nil {
return nil, err
}
p.RawConfig["password"] = password
return p.RawConfig, nil
}

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
@@ -68,17 +67,13 @@ func TestPostgreSQL_Initialize(t *testing.T) {
"max_open_connections": 5,
}
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.Initialized {
if !db.Initialized {
t.Fatal("Database should be initalized")
}
@@ -93,7 +88,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
"max_open_connections": "5",
}
err = db.Initialize(context.Background(), connectionDetails, true)
_, err = db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -108,9 +103,8 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
@@ -127,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
}
statements := dbplugin.Statements{
CreationStatements: testPostgresRole,
Creation: []string{testPostgresRole},
}
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
@@ -139,7 +133,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err)
}
statements.CreationStatements = testPostgresReadOnlyRole
statements.Creation = []string{testPostgresReadOnlyRole}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
@@ -161,15 +155,14 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testPostgresRole,
Creation: []string{testPostgresRole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -197,7 +190,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
statements.RenewStatements = defaultPostgresRenewSQL
statements.Renewal = []string{defaultPostgresRenewSQL}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
@@ -221,6 +214,46 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
}
func TestPostgreSQL_RotateRootCredentials(t *testing.T) {
cleanup, connURL := preparePostgresTestContainer(t)
defer cleanup()
connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1)
connectionDetails := map[string]interface{}{
"connection_url": connURL,
"max_open_connections": 5,
"username": "postgres",
"password": "secret",
}
db := new()
connProducer := db.SQLConnectionProducer
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.Initialized {
t.Fatal("Database should be initalized")
}
newConf, err := db.RotateRootCredentials(context.Background(), nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if newConf["password"] == "secret" {
t.Fatal("password was not updated")
}
err = db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPostgreSQL_RevokeUser(t *testing.T) {
cleanup, connURL := preparePostgresTestContainer(t)
defer cleanup()
@@ -229,15 +262,14 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
"connection_url": connURL,
}
dbRaw, _ := New()
db := dbRaw.(*PostgreSQL)
err := db.Initialize(context.Background(), connectionDetails, true)
db := new()
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
CreationStatements: testPostgresRole,
Creation: []string{testPostgresRole},
}
usernameConfig := dbplugin.UsernameConfig{
@@ -274,7 +306,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
}
// Test custom revoke statements
statements.RevocationStatements = defaultPostgresRevocationSQL
statements.Revocation = []string{defaultPostgresRevocationSQL}
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
@@ -286,6 +318,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
}
func testCredsExist(t testing.TB, connURL, username, password string) error {
t.Helper()
// Log in with the new creds
connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1)
db, err := sql.Open("postgres", connURL)

View File

@@ -15,8 +15,11 @@ var (
// connections and is used in all the builtin database types.
type ConnectionProducer interface {
Close() error
Initialize(context.Context, map[string]interface{}, bool) error
Init(context.Context, map[string]interface{}, bool) (map[string]interface{}, error)
Connection(context.Context) (interface{}, error)
sync.Locker
// DEPRECATED, will be removed in 0.12
Initialize(context.Context, map[string]interface{}, bool) error
}

View File

@@ -8,18 +8,25 @@ import (
"sync"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
"github.com/mitchellh/mapstructure"
)
var _ ConnectionProducer = &SQLConnectionProducer{}
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type SQLConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
Username string `json:"username" mapstructure:"username" structs:"username"`
Password string `json:"password" mapstructure:"password" structs:"password"`
Type string
RawConfig map[string]interface{}
maxConnectionLifetime time.Duration
Initialized bool
db *sql.DB
@@ -27,18 +34,30 @@ type SQLConnectionProducer struct {
}
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()
err := mapstructure.WeakDecode(conf, c)
if err != nil {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
if len(c.ConnectionURL) == 0 {
return fmt.Errorf("connection_url cannot be empty")
func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
c.Lock()
defer c.Unlock()
c.RawConfig = conf
err := mapstructure.WeakDecode(conf, &c)
if err != nil {
return nil, err
}
if len(c.ConnectionURL) == 0 {
return nil, fmt.Errorf("connection_url cannot be empty")
}
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
"username": c.Username,
"password": c.Password,
})
if c.MaxOpenConnections == 0 {
c.MaxOpenConnections = 2
}
@@ -55,7 +74,7 @@ func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
if err != nil {
return fmt.Errorf("invalid max_connection_lifetime: %s", err)
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
}
// Set initialized to true at this point since all fields are set,
@@ -64,15 +83,15 @@ func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]
if verifyConnection {
if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
if err := c.db.PingContext(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
}
}
return nil
return c.RawConfig, nil
}
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
@@ -123,6 +142,12 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
return c.db, nil
}
func (c *SQLConnectionProducer) SecretValues() map[string]interface{} {
return map[string]interface{}{
c.Password: "[password]",
}
}
// Close attempts to close the connection
func (c *SQLConnectionProducer) Close() error {
// Grab the write lock

View File

@@ -4,6 +4,8 @@ import (
"errors"
"fmt"
"strings"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
)
var (
@@ -18,3 +20,33 @@ func QueryHelper(tpl string, data map[string]string) string {
return tpl
}
// StatementCompatibilityHelper will populate the statements fields to support
// compatibility
func StatementCompatibilityHelper(statements dbplugin.Statements) dbplugin.Statements {
switch {
case len(statements.Creation) > 0 && len(statements.CreationStatements) == 0:
statements.CreationStatements = strings.Join(statements.Creation, ";")
case len(statements.CreationStatements) > 0:
statements.Creation = []string{statements.CreationStatements}
}
switch {
case len(statements.Revocation) > 0 && len(statements.RevocationStatements) == 0:
statements.RevocationStatements = strings.Join(statements.Revocation, ";")
case len(statements.RevocationStatements) > 0:
statements.Revocation = []string{statements.RevocationStatements}
}
switch {
case len(statements.Renewal) > 0 && len(statements.RenewStatements) == 0:
statements.RenewStatements = strings.Join(statements.Renewal, ";")
case len(statements.RenewStatements) > 0:
statements.Renewal = []string{statements.RenewStatements}
}
switch {
case len(statements.Rollback) > 0 && len(statements.RollbackStatements) == 0:
statements.RollbackStatements = strings.Join(statements.Rollback, ";")
case len(statements.RollbackStatements) > 0:
statements.Rollback = []string{statements.RollbackStatements}
}
return statements
}

View File

@@ -0,0 +1,62 @@
package dbutil
import (
"reflect"
"testing"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
)
func TestStatementCompatibilityHelper(t *testing.T) {
const (
creationStatement = "creation"
renewStatement = "renew"
revokeStatement = "revoke"
rollbackStatement = "rollback"
)
expectedStatements := dbplugin.Statements{
Creation: []string{creationStatement},
Rollback: []string{rollbackStatement},
Revocation: []string{revokeStatement},
Renewal: []string{renewStatement},
CreationStatements: creationStatement,
RenewStatements: renewStatement,
RollbackStatements: rollbackStatement,
RevocationStatements: revokeStatement,
}
statements1 := dbplugin.Statements{
CreationStatements: creationStatement,
RenewStatements: renewStatement,
RollbackStatements: rollbackStatement,
RevocationStatements: revokeStatement,
}
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements1)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements1)
}
statements2 := dbplugin.Statements{
Creation: []string{creationStatement},
Rollback: []string{rollbackStatement},
Revocation: []string{revokeStatement},
Renewal: []string{renewStatement},
}
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements2)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements2)
}
statements3 := dbplugin.Statements{
CreationStatements: creationStatement,
}
expectedStatements3 := dbplugin.Statements{
Creation: []string{creationStatement},
CreationStatements: creationStatement,
}
if !reflect.DeepEqual(expectedStatements3, StatementCompatibilityHelper(statements3)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements3, statements3)
}
}