Database Root Credential Rotation (#3976)

* redoing connection handling

* a little more cleanup

* empty implementation of rotation

* updating rotate signature

* signature update

* updating interfaces again :(

* changing back to interface

* adding templated url support and rotation for postgres

* adding correct username

* return updates

* updating statements to be a list

* adding error sanitizing middleware

* fixing log sanitizier

* adding postgres rotate test

* removing conf from rotate

* adding rotate command

* adding mysql rotate

* finishing up the endpoint in the db backend for rotate

* no more structs, just store raw config

* fixing tests

* adding db instance lock

* adding support for statement list in cassandra

* wip redoing interface to support BC

* adding falllback for Initialize implementation

* adding backwards compat for statements

* fix tests

* fix more tests

* fixing up tests, switching to new fields in statements

* fixing more tests

* adding mssql and mysql

* wrapping all the things in middleware, implementing templating for mongodb

* wrapping all db servers with error santizer

* fixing test

* store the name with the db instance

* adding rotate to cassandra

* adding compatibility translation to both server and plugin

* reordering a few things

* store the name with the db instance

* reordering

* adding a few more tests

* switch secret values from slice to map

* addressing some feedback

* reinstate execute plugin after resetting connection

* set database connection to closed

* switching secret values func to map[string]interface for potential future uses

* addressing feedback
This commit is contained in:
Chris Hoffman
2018-03-21 15:05:56 -04:00
committed by GitHub
parent 1c443f22fe
commit 44aa151b78
33 changed files with 1974 additions and 777 deletions

View File

@@ -9,13 +9,37 @@ import (
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
"github.com/hashicorp/errwrap"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
) )
const databaseConfigPath = "database/config/" const databaseConfigPath = "database/config/"
type dbPluginInstance struct {
sync.RWMutex
dbplugin.Database
id string
name string
closed bool
}
func (dbi *dbPluginInstance) Close() error {
dbi.Lock()
defer dbi.Unlock()
if dbi.closed {
return nil
}
dbi.closed = true
return dbi.Database.Close()
}
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend(conf) b := Backend(conf)
if err := b.Setup(ctx, conf); err != nil { if err := b.Setup(ctx, conf); err != nil {
@@ -42,6 +66,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
pathRoles(&b), pathRoles(&b),
pathCredsCreate(&b), pathCredsCreate(&b),
pathResetConnection(&b), pathResetConnection(&b),
pathRotateCredentials(&b),
}, },
Secrets: []*framework.Secret{ Secrets: []*framework.Secret{
@@ -53,72 +78,22 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
} }
b.logger = conf.Logger b.logger = conf.Logger
b.connections = make(map[string]dbplugin.Database) b.connections = make(map[string]*dbPluginInstance)
return &b return &b
} }
type databaseBackend struct { type databaseBackend struct {
connections map[string]dbplugin.Database connections map[string]*dbPluginInstance
logger log.Logger logger log.Logger
*framework.Backend *framework.Backend
sync.RWMutex sync.RWMutex
} }
// closeAllDBs closes all connections from all database types
func (b *databaseBackend) closeAllDBs(ctx context.Context) {
b.Lock()
defer b.Unlock()
for _, db := range b.connections {
db.Close()
}
b.connections = make(map[string]dbplugin.Database)
}
// This function is used to retrieve a database object either from the cached
// connection map. The caller of this function needs to hold the backend's read
// lock.
func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
db, ok := b.connections[name]
return db, ok
}
// This function creates a new db object from the stored configuration and
// caches it in the connections map. The caller of this function needs to hold
// the backend's write lock
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
db, ok := b.connections[name]
if ok {
return db, nil
}
config, err := b.DatabaseConfig(ctx, s, name)
if err != nil {
return nil, err
}
db, err = dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
return nil, err
}
err = db.Initialize(ctx, config.ConnectionDetails, true)
if err != nil {
db.Close()
return nil, err
}
b.connections[name] = db
return db, nil
}
func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) { func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) {
entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name)) entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read connection configuration: %s", err) return nil, errwrap.Wrapf("failed to read connection configuration: {{err}}", err)
} }
if entry == nil { if entry == nil {
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
@@ -144,7 +119,7 @@ type upgradeStatements struct {
type upgradeCheck struct { type upgradeCheck struct {
// This json tag has a typo in it, the new version does not. This // This json tag has a typo in it, the new version does not. This
// necessitates this upgrade logic. // necessitates this upgrade logic.
Statements upgradeStatements `json:"statments"` Statements *upgradeStatements `json:"statments,omitempty"`
} }
func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) { func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) {
@@ -166,48 +141,140 @@ func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName
return nil, err return nil, err
} }
empty := upgradeCheck{} switch {
if upgradeCh != empty { case upgradeCh.Statements != nil:
result.Statements.CreationStatements = upgradeCh.Statements.CreationStatements var stmts dbplugin.Statements
result.Statements.RevocationStatements = upgradeCh.Statements.RevocationStatements if upgradeCh.Statements.CreationStatements != "" {
result.Statements.RollbackStatements = upgradeCh.Statements.RollbackStatements stmts.Creation = []string{upgradeCh.Statements.CreationStatements}
result.Statements.RenewStatements = upgradeCh.Statements.RenewStatements }
if upgradeCh.Statements.RevocationStatements != "" {
stmts.Revocation = []string{upgradeCh.Statements.RevocationStatements}
}
if upgradeCh.Statements.RollbackStatements != "" {
stmts.Rollback = []string{upgradeCh.Statements.RollbackStatements}
}
if upgradeCh.Statements.RenewStatements != "" {
stmts.Renewal = []string{upgradeCh.Statements.RenewStatements}
}
result.Statements = stmts
} }
// For backwards compatibility, copy the values back into the string form
// of the fields
result.Statements = dbutil.StatementCompatibilityHelper(result.Statements)
return &result, nil return &result, nil
} }
func (b *databaseBackend) invalidate(ctx context.Context, key string) { func (b *databaseBackend) invalidate(ctx context.Context, key string) {
b.Lock()
defer b.Unlock()
switch { switch {
case strings.HasPrefix(key, databaseConfigPath): case strings.HasPrefix(key, databaseConfigPath):
name := strings.TrimPrefix(key, databaseConfigPath) name := strings.TrimPrefix(key, databaseConfigPath)
b.clearConnection(name) b.ClearConnection(name)
} }
} }
// clearConnection closes the database connection and func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage, name string) (*dbPluginInstance, error) {
// removes it from the b.connections map. b.RLock()
func (b *databaseBackend) clearConnection(name string) { unlockFunc := b.RUnlock
defer func() { unlockFunc() }()
db, ok := b.connections[name] db, ok := b.connections[name]
if ok { if ok {
return db, nil
}
// Upgrade lock
b.RUnlock()
b.Lock()
unlockFunc = b.Unlock
db, ok = b.connections[name]
if ok {
return db, nil
}
config, err := b.DatabaseConfig(ctx, s, name)
if err != nil {
return nil, err
}
dbp, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
if err != nil {
return nil, err
}
_, err = dbp.Init(ctx, config.ConnectionDetails, true)
if err != nil {
dbp.Close()
return nil, err
}
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
db = &dbPluginInstance{
Database: dbp,
name: name,
id: id,
}
b.connections[name] = db
return db, nil
}
// ClearConnection closes the database connection and
// removes it from the b.connections map.
func (b *databaseBackend) ClearConnection(name string) error {
b.Lock()
defer b.Unlock()
return b.clearConnection(name)
}
func (b *databaseBackend) clearConnection(name string) error {
db, ok := b.connections[name]
if ok {
// Ignore error here since the database client is always killed
db.Close() db.Close()
delete(b.connections, name) delete(b.connections, name)
} }
return nil
} }
func (b *databaseBackend) closeIfShutdown(name string, err error) { func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) {
// Plugin has shutdown, close it so next call can reconnect. // Plugin has shutdown, close it so next call can reconnect.
switch err { switch err {
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown: case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
b.Lock() // Put this in a goroutine so that requests can run with the read or write lock
b.clearConnection(name) // and simply defer the unlock. Since we are attaching the instance and matching
b.Unlock() // the id in the conneciton map, we can safely do this.
go func() {
b.Lock()
defer b.Unlock()
db.Close()
// Ensure we are deleting the correct connection
mapDB, ok := b.connections[db.name]
if ok && db.id == mapDB.id {
delete(b.connections, db.name)
}
}()
} }
} }
// closeAllDBs closes all connections from all database types
func (b *databaseBackend) closeAllDBs(ctx context.Context) {
b.Lock()
defer b.Unlock()
for _, db := range b.connections {
db.Close()
}
b.connections = make(map[string]*dbPluginInstance)
}
const backendHelp = ` const backendHelp = `
The database backend supports using many different databases The database backend supports using many different databases
as secret backends, including but not limited to: as secret backends, including but not limited to:

View File

@@ -7,6 +7,7 @@ import (
"log" "log"
"os" "os"
"reflect" "reflect"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -27,6 +28,7 @@ var (
) )
func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) { func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) {
t.Helper()
if os.Getenv("PG_URL") != "" { if os.Getenv("PG_URL") != "" {
return func() {}, os.Getenv("PG_URL") return func() {}, os.Getenv("PG_URL")
} }
@@ -64,7 +66,7 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac
}) })
if err != nil || (resp != nil && resp.IsError()) { if err != nil || (resp != nil && resp.IsError()) {
// It's likely not up and running yet, so return error and try again // It's likely not up and running yet, so return error and try again
return fmt.Errorf("err:%s resp:%#v\n", err, resp) return fmt.Errorf("err:%#v resp:%#v", err, resp)
} }
if resp == nil { if resp == nil {
t.Fatal("expected warning") t.Fatal("expected warning")
@@ -123,13 +125,18 @@ func TestBackend_RoleUpgrade(t *testing.T) {
storage := &logical.InmemStorage{} storage := &logical.InmemStorage{}
backend := &databaseBackend{} backend := &databaseBackend{}
roleEnt := &roleEntry{ roleExpected := &roleEntry{
Statements: dbplugin.Statements{ Statements: dbplugin.Statements{
CreationStatements: "test", CreationStatements: "test",
Creation: []string{"test"},
}, },
} }
entry, err := logical.StorageEntryJSON("role/test", roleEnt) entry, err := logical.StorageEntryJSON("role/test", &roleEntry{
Statements: dbplugin.Statements{
CreationStatements: "test",
},
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -142,8 +149,8 @@ func TestBackend_RoleUpgrade(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(role, roleEnt) { if !reflect.DeepEqual(role, roleExpected) {
t.Fatalf("bad role %#v", role) t.Fatalf("bad role %#v, %#v", role, roleExpected)
} }
// Upgrade case // Upgrade case
@@ -161,8 +168,8 @@ func TestBackend_RoleUpgrade(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if !reflect.DeepEqual(role, roleEnt) { if !reflect.DeepEqual(role, roleExpected) {
t.Fatalf("bad role %#v", role) t.Fatalf("bad role %#v, %#v", role, roleExpected)
} }
} }
@@ -206,7 +213,8 @@ func TestBackend_config_connection(t *testing.T) {
"connection_details": map[string]interface{}{ "connection_details": map[string]interface{}{
"connection_url": "sample_connection_url", "connection_url": "sample_connection_url",
}, },
"allowed_roles": []string{"*"}, "allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
} }
configReq.Operation = logical.ReadOperation configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(context.Background(), configReq) resp, err = b.HandleRequest(context.Background(), configReq)
@@ -233,6 +241,55 @@ func TestBackend_config_connection(t *testing.T) {
} }
} }
func TestBackend_BadConnectionString(t *testing.T) {
cluster, sys := getCluster(t)
defer cluster.Cleanup()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup(context.Background())
cleanup, _ := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()
respCheck := func(req *logical.Request) {
t.Helper()
resp, err := b.HandleRequest(context.Background(), req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil || !resp.IsError() {
t.Fatalf("expected error, resp:%#v", resp)
}
err = resp.Error()
if strings.Contains(err.Error(), "localhost") {
t.Fatalf("error should not contain connection info")
}
}
// Configure a connection
data := map[string]interface{}{
"connection_url": "postgresql://:pw@[localhost",
"plugin_name": "postgresql-database-plugin",
"allowed_roles": []string{"plugin-role-test"},
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: data,
}
respCheck(req)
time.Sleep(1 * time.Second)
}
func TestBackend_basic(t *testing.T) { func TestBackend_basic(t *testing.T) {
cluster, sys := getCluster(t) cluster, sys := getCluster(t)
defer cluster.Cleanup() defer cluster.Cleanup()
@@ -388,7 +445,6 @@ func TestBackend_basic(t *testing.T) {
if testCredsExist(t, credsResp, connURL) { if testCredsExist(t, credsResp, connURL) {
t.Fatalf("Creds should not exist") t.Fatalf("Creds should not exist")
} }
} }
func TestBackend_connectionCrud(t *testing.T) { func TestBackend_connectionCrud(t *testing.T) {
@@ -467,7 +523,8 @@ func TestBackend_connectionCrud(t *testing.T) {
"connection_details": map[string]interface{}{ "connection_details": map[string]interface{}{
"connection_url": connURL, "connection_url": connURL,
}, },
"allowed_roles": []string{"plugin-role-test"}, "allowed_roles": []string{"plugin-role-test"},
"root_credentials_rotate_statements": []string{},
} }
req.Operation = logical.ReadOperation req.Operation = logical.ReadOperation
resp, err = b.HandleRequest(context.Background(), req) resp, err = b.HandleRequest(context.Background(), req)
@@ -602,15 +659,15 @@ func TestBackend_roleCrud(t *testing.T) {
} }
expected := dbplugin.Statements{ expected := dbplugin.Statements{
CreationStatements: testRole, Creation: []string{strings.TrimSpace(testRole)},
RevocationStatements: defaultRevocationSQL, Revocation: []string{strings.TrimSpace(defaultRevocationSQL)},
} }
actual := dbplugin.Statements{ actual := dbplugin.Statements{
CreationStatements: resp.Data["creation_statements"].(string), Creation: resp.Data["creation_statements"].([]string),
RevocationStatements: resp.Data["revocation_statements"].(string), Revocation: resp.Data["revocation_statements"].([]string),
RollbackStatements: resp.Data["rollback_statements"].(string), Rollback: resp.Data["rollback_statements"].([]string),
RenewStatements: resp.Data["renew_statements"].(string), Renewal: resp.Data["renew_statements"].([]string),
} }
if !reflect.DeepEqual(expected, actual) { if !reflect.DeepEqual(expected, actual) {

View File

@@ -1,5 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go.
// source: builtin/logical/database/dbplugin/database.proto // source: builtin/logical/database/dbplugin/database.proto
// DO NOT EDIT!
/* /*
Package dbplugin is a generated protocol buffer package. Package dbplugin is a generated protocol buffer package.
@@ -9,13 +10,17 @@ It is generated from these files:
It has these top-level messages: It has these top-level messages:
InitializeRequest InitializeRequest
InitRequest
CreateUserRequest CreateUserRequest
RenewUserRequest RenewUserRequest
RevokeUserRequest RevokeUserRequest
RotateRootCredentialsRequest
Statements Statements
UsernameConfig UsernameConfig
InitResponse
CreateUserResponse CreateUserResponse
TypeResponse TypeResponse
RotateRootCredentialsResponse
Empty Empty
*/ */
package dbplugin package dbplugin
@@ -65,6 +70,30 @@ func (m *InitializeRequest) GetVerifyConnection() bool {
return false return false
} }
type InitRequest struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
}
func (m *InitRequest) Reset() { *m = InitRequest{} }
func (m *InitRequest) String() string { return proto.CompactTextString(m) }
func (*InitRequest) ProtoMessage() {}
func (*InitRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *InitRequest) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
func (m *InitRequest) GetVerifyConnection() bool {
if m != nil {
return m.VerifyConnection
}
return false
}
type CreateUserRequest struct { type CreateUserRequest struct {
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"` Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"` UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
@@ -74,7 +103,7 @@ type CreateUserRequest struct {
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} } func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) } func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
func (*CreateUserRequest) ProtoMessage() {} func (*CreateUserRequest) ProtoMessage() {}
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *CreateUserRequest) GetStatements() *Statements { func (m *CreateUserRequest) GetStatements() *Statements {
if m != nil { if m != nil {
@@ -106,7 +135,7 @@ type RenewUserRequest struct {
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} } func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) } func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
func (*RenewUserRequest) ProtoMessage() {} func (*RenewUserRequest) ProtoMessage() {}
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *RenewUserRequest) GetStatements() *Statements { func (m *RenewUserRequest) GetStatements() *Statements {
if m != nil { if m != nil {
@@ -137,7 +166,7 @@ type RevokeUserRequest struct {
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} } func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) } func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
func (*RevokeUserRequest) ProtoMessage() {} func (*RevokeUserRequest) ProtoMessage() {}
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *RevokeUserRequest) GetStatements() *Statements { func (m *RevokeUserRequest) GetStatements() *Statements {
if m != nil { if m != nil {
@@ -153,17 +182,41 @@ func (m *RevokeUserRequest) GetUsername() string {
return "" return ""
} }
type RotateRootCredentialsRequest struct {
Statements []string `protobuf:"bytes,1,rep,name=statements" json:"statements,omitempty"`
}
func (m *RotateRootCredentialsRequest) Reset() { *m = RotateRootCredentialsRequest{} }
func (m *RotateRootCredentialsRequest) String() string { return proto.CompactTextString(m) }
func (*RotateRootCredentialsRequest) ProtoMessage() {}
func (*RotateRootCredentialsRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (m *RotateRootCredentialsRequest) GetStatements() []string {
if m != nil {
return m.Statements
}
return nil
}
type Statements struct { type Statements struct {
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"` // DEPRECATED, will be removed in 0.12
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
// DEPRECATED, will be removed in 0.12
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"` RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"` // DEPRECATED, will be removed in 0.12
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"` RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
// DEPRECATED, will be removed in 0.12
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
Creation []string `protobuf:"bytes,5,rep,name=creation" json:"creation,omitempty"`
Revocation []string `protobuf:"bytes,6,rep,name=revocation" json:"revocation,omitempty"`
Rollback []string `protobuf:"bytes,7,rep,name=rollback" json:"rollback,omitempty"`
Renewal []string `protobuf:"bytes,8,rep,name=renewal" json:"renewal,omitempty"`
} }
func (m *Statements) Reset() { *m = Statements{} } func (m *Statements) Reset() { *m = Statements{} }
func (m *Statements) String() string { return proto.CompactTextString(m) } func (m *Statements) String() string { return proto.CompactTextString(m) }
func (*Statements) ProtoMessage() {} func (*Statements) ProtoMessage() {}
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *Statements) GetCreationStatements() string { func (m *Statements) GetCreationStatements() string {
if m != nil { if m != nil {
@@ -193,6 +246,34 @@ func (m *Statements) GetRenewStatements() string {
return "" return ""
} }
func (m *Statements) GetCreation() []string {
if m != nil {
return m.Creation
}
return nil
}
func (m *Statements) GetRevocation() []string {
if m != nil {
return m.Revocation
}
return nil
}
func (m *Statements) GetRollback() []string {
if m != nil {
return m.Rollback
}
return nil
}
func (m *Statements) GetRenewal() []string {
if m != nil {
return m.Renewal
}
return nil
}
type UsernameConfig struct { type UsernameConfig struct {
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"` DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"` RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
@@ -201,7 +282,7 @@ type UsernameConfig struct {
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} } func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) } func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
func (*UsernameConfig) ProtoMessage() {} func (*UsernameConfig) ProtoMessage() {}
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *UsernameConfig) GetDisplayName() string { func (m *UsernameConfig) GetDisplayName() string {
if m != nil { if m != nil {
@@ -217,6 +298,22 @@ func (m *UsernameConfig) GetRoleName() string {
return "" return ""
} }
type InitResponse struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
}
func (m *InitResponse) Reset() { *m = InitResponse{} }
func (m *InitResponse) String() string { return proto.CompactTextString(m) }
func (*InitResponse) ProtoMessage() {}
func (*InitResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
func (m *InitResponse) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
type CreateUserResponse struct { type CreateUserResponse struct {
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"` Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"` Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
@@ -225,7 +322,7 @@ type CreateUserResponse struct {
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} } func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) } func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
func (*CreateUserResponse) ProtoMessage() {} func (*CreateUserResponse) ProtoMessage() {}
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{9} }
func (m *CreateUserResponse) GetUsername() string { func (m *CreateUserResponse) GetUsername() string {
if m != nil { if m != nil {
@@ -248,7 +345,7 @@ type TypeResponse struct {
func (m *TypeResponse) Reset() { *m = TypeResponse{} } func (m *TypeResponse) Reset() { *m = TypeResponse{} }
func (m *TypeResponse) String() string { return proto.CompactTextString(m) } func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
func (*TypeResponse) ProtoMessage() {} func (*TypeResponse) ProtoMessage() {}
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{10} }
func (m *TypeResponse) GetType() string { func (m *TypeResponse) GetType() string {
if m != nil { if m != nil {
@@ -257,23 +354,43 @@ func (m *TypeResponse) GetType() string {
return "" return ""
} }
type RotateRootCredentialsResponse struct {
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
}
func (m *RotateRootCredentialsResponse) Reset() { *m = RotateRootCredentialsResponse{} }
func (m *RotateRootCredentialsResponse) String() string { return proto.CompactTextString(m) }
func (*RotateRootCredentialsResponse) ProtoMessage() {}
func (*RotateRootCredentialsResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{11} }
func (m *RotateRootCredentialsResponse) GetConfig() []byte {
if m != nil {
return m.Config
}
return nil
}
type Empty struct { type Empty struct {
} }
func (m *Empty) Reset() { *m = Empty{} } func (m *Empty) Reset() { *m = Empty{} }
func (m *Empty) String() string { return proto.CompactTextString(m) } func (m *Empty) String() string { return proto.CompactTextString(m) }
func (*Empty) ProtoMessage() {} func (*Empty) ProtoMessage() {}
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} } func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{12} }
func init() { func init() {
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest") proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
proto.RegisterType((*InitRequest)(nil), "dbplugin.InitRequest")
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest") proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest") proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest") proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
proto.RegisterType((*RotateRootCredentialsRequest)(nil), "dbplugin.RotateRootCredentialsRequest")
proto.RegisterType((*Statements)(nil), "dbplugin.Statements") proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig") proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
proto.RegisterType((*InitResponse)(nil), "dbplugin.InitResponse")
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse") proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse") proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
proto.RegisterType((*RotateRootCredentialsResponse)(nil), "dbplugin.RotateRootCredentialsResponse")
proto.RegisterType((*Empty)(nil), "dbplugin.Empty") proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
} }
@@ -292,8 +409,10 @@ type DatabaseClient interface {
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error) RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) RotateRootCredentials(ctx context.Context, in *RotateRootCredentialsRequest, opts ...grpc.CallOption) (*RotateRootCredentialsResponse, error)
Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error)
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
} }
type databaseClient struct { type databaseClient struct {
@@ -340,9 +459,18 @@ func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest,
return out, nil return out, nil
} }
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) { func (c *databaseClient) RotateRootCredentials(ctx context.Context, in *RotateRootCredentialsRequest, opts ...grpc.CallOption) (*RotateRootCredentialsResponse, error) {
out := new(Empty) out := new(RotateRootCredentialsResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...) err := grpc.Invoke(ctx, "/dbplugin.Database/RotateRootCredentials", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *databaseClient) Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error) {
out := new(InitResponse)
err := grpc.Invoke(ctx, "/dbplugin.Database/Init", in, out, c.cc, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -358,6 +486,15 @@ func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.Call
return out, nil return out, nil
} }
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for Database service // Server API for Database service
type DatabaseServer interface { type DatabaseServer interface {
@@ -365,8 +502,10 @@ type DatabaseServer interface {
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
RenewUser(context.Context, *RenewUserRequest) (*Empty, error) RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error) RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
Initialize(context.Context, *InitializeRequest) (*Empty, error) RotateRootCredentials(context.Context, *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error)
Init(context.Context, *InitRequest) (*InitResponse, error)
Close(context.Context, *Empty) (*Empty, error) Close(context.Context, *Empty) (*Empty, error)
Initialize(context.Context, *InitializeRequest) (*Empty, error)
} }
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) { func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
@@ -445,20 +584,38 @@ func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _Database_RotateRootCredentials_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitializeRequest) in := new(RotateRootCredentialsRequest)
if err := dec(in); err != nil { if err := dec(in); err != nil {
return nil, err return nil, err
} }
if interceptor == nil { if interceptor == nil {
return srv.(DatabaseServer).Initialize(ctx, in) return srv.(DatabaseServer).RotateRootCredentials(ctx, in)
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: "/dbplugin.Database/Initialize", FullMethod: "/dbplugin.Database/RotateRootCredentials",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest)) return srv.(DatabaseServer).RotateRootCredentials(ctx, req.(*RotateRootCredentialsRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Database_Init_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Init(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Init",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Init(ctx, req.(*InitRequest))
} }
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
@@ -481,6 +638,24 @@ func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(inte
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InitializeRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DatabaseServer).Initialize(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/dbplugin.Database/Initialize",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
}
return interceptor(ctx, in, info, handler)
}
var _Database_serviceDesc = grpc.ServiceDesc{ var _Database_serviceDesc = grpc.ServiceDesc{
ServiceName: "dbplugin.Database", ServiceName: "dbplugin.Database",
HandlerType: (*DatabaseServer)(nil), HandlerType: (*DatabaseServer)(nil),
@@ -502,13 +677,21 @@ var _Database_serviceDesc = grpc.ServiceDesc{
Handler: _Database_RevokeUser_Handler, Handler: _Database_RevokeUser_Handler,
}, },
{ {
MethodName: "Initialize", MethodName: "RotateRootCredentials",
Handler: _Database_Initialize_Handler, Handler: _Database_RotateRootCredentials_Handler,
},
{
MethodName: "Init",
Handler: _Database_Init_Handler,
}, },
{ {
MethodName: "Close", MethodName: "Close",
Handler: _Database_Close_Handler, Handler: _Database_Close_Handler,
}, },
{
MethodName: "Initialize",
Handler: _Database_Initialize_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "builtin/logical/database/dbplugin/database.proto", Metadata: "builtin/logical/database/dbplugin/database.proto",
@@ -517,40 +700,49 @@ var _Database_serviceDesc = grpc.ServiceDesc{
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) } func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{ var fileDescriptor0 = []byte{
// 548 bytes of a gzipped FileDescriptorProto // 690 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x55, 0x41, 0x4f, 0xdb, 0x4a,
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f, 0x10, 0x96, 0x93, 0x00, 0xc9, 0x80, 0x80, 0xec, 0x03, 0x64, 0xf9, 0xf1, 0xde, 0x43, 0x3e, 0xf0,
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e, 0x40, 0x95, 0xe2, 0x0a, 0x5a, 0xb5, 0xe2, 0xd0, 0xaa, 0x0a, 0x55, 0x55, 0xa9, 0xe2, 0xb0, 0xc0,
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2, 0xad, 0x12, 0xda, 0x38, 0x43, 0xba, 0xc2, 0xf1, 0xba, 0xde, 0x0d, 0x34, 0xfd, 0x03, 0xed, 0xcf,
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6, 0xe8, 0x4f, 0xe9, 0xb1, 0x3f, 0xab, 0xf2, 0xda, 0x6b, 0x6f, 0x62, 0x28, 0x07, 0xda, 0x9b, 0x67,
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c, 0xe6, 0xfb, 0x66, 0xbe, 0x9d, 0x9d, 0x59, 0xc3, 0xe3, 0xc1, 0x84, 0x47, 0x8a, 0xc7, 0x41, 0x24,
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e, 0x46, 0x3c, 0x64, 0x51, 0x30, 0x64, 0x8a, 0x0d, 0x98, 0xc4, 0x60, 0x38, 0x48, 0xa2, 0xc9, 0x88,
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3, 0xc7, 0xa5, 0xa7, 0x97, 0xa4, 0x42, 0x09, 0xd2, 0x36, 0x01, 0xef, 0xbf, 0x91, 0x10, 0xa3, 0x08,
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1, 0x03, 0xed, 0x1f, 0x4c, 0x2e, 0x03, 0xc5, 0xc7, 0x28, 0x15, 0x1b, 0x27, 0x39, 0xd4, 0x7f, 0x0f,
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f, 0xdd, 0xb7, 0x31, 0x57, 0x9c, 0x45, 0xfc, 0x33, 0x52, 0xfc, 0x38, 0x41, 0xa9, 0xc8, 0x16, 0x2c,
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49, 0x86, 0x22, 0xbe, 0xe4, 0x23, 0xd7, 0xd9, 0x71, 0xf6, 0x56, 0x68, 0x61, 0x91, 0x47, 0xd0, 0xbd,
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35, 0xc6, 0x94, 0x5f, 0x4e, 0x2f, 0x42, 0x11, 0xc7, 0x18, 0x2a, 0x2e, 0x62, 0xb7, 0xb1, 0xe3, 0xec,
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20, 0xb5, 0xe9, 0x7a, 0x1e, 0xe8, 0x97, 0xfe, 0xa3, 0x86, 0xeb, 0xf8, 0x14, 0x96, 0xb3, 0xec, 0xbf,
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6, 0x33, 0xaf, 0xff, 0xc3, 0x81, 0x6e, 0x3f, 0x45, 0xa6, 0xf0, 0x5c, 0x62, 0x6a, 0x52, 0x3f, 0x01,
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34, 0x90, 0x8a, 0x29, 0x1c, 0x63, 0xac, 0xa4, 0x4e, 0xbf, 0x7c, 0xb0, 0xd1, 0x33, 0x7d, 0xe8, 0x9d,
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0, 0x96, 0x31, 0x6a, 0xe1, 0xc8, 0x2b, 0x58, 0x9b, 0x48, 0x4c, 0x63, 0x36, 0xc6, 0x8b, 0x42, 0x59,
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce, 0x43, 0x53, 0xdd, 0x8a, 0x7a, 0x5e, 0x00, 0xfa, 0x3a, 0x4e, 0x57, 0x27, 0x33, 0x36, 0x39, 0x02,
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68, 0xc0, 0x4f, 0x09, 0x4f, 0x99, 0x16, 0xdd, 0xd4, 0x6c, 0xaf, 0x97, 0xb7, 0xbd, 0x67, 0xda, 0xde,
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0, 0x3b, 0x33, 0x6d, 0xa7, 0x16, 0xda, 0xff, 0xe6, 0xc0, 0x3a, 0xc5, 0x18, 0x6f, 0x1e, 0x7e, 0x12,
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8, 0x0f, 0xda, 0x46, 0x98, 0x3e, 0x42, 0x87, 0x96, 0xf6, 0x83, 0x24, 0x22, 0x74, 0x29, 0x5e, 0x8b,
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1, 0x2b, 0xfc, 0xa3, 0x12, 0xfd, 0x17, 0xb0, 0x4d, 0x45, 0x06, 0xa5, 0x42, 0xa8, 0x7e, 0x8a, 0x43,
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0, 0x8c, 0xb3, 0x99, 0x94, 0xa6, 0xe2, 0xbf, 0x73, 0x15, 0x9b, 0x7b, 0x1d, 0x3b, 0xb7, 0xff, 0xbd,
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a, 0x01, 0x50, 0x95, 0x25, 0x01, 0xfc, 0x15, 0x66, 0x23, 0xc2, 0x45, 0x7c, 0x31, 0xa7, 0xb4, 0x43,
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab, 0x89, 0x09, 0x59, 0x84, 0x43, 0xd8, 0x4c, 0xf1, 0x5a, 0x84, 0x35, 0x4a, 0x2e, 0x74, 0xa3, 0x0a,
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c, 0xce, 0x56, 0x49, 0x45, 0x14, 0x0d, 0x58, 0x78, 0x65, 0x53, 0x9a, 0x79, 0x15, 0x13, 0xb2, 0x08,
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4, 0xfb, 0xb0, 0x9e, 0x66, 0xd7, 0x6d, 0xa3, 0x5b, 0x1a, 0xbd, 0xa6, 0xfd, 0xa7, 0x33, 0xcd, 0x32,
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36, 0x32, 0xdd, 0x05, 0x7d, 0xdc, 0xd2, 0xce, 0x9a, 0x51, 0xe9, 0x71, 0x17, 0xf3, 0x66, 0x54, 0x9e,
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a, 0x8c, 0x6b, 0x8a, 0xbb, 0x4b, 0x39, 0xd7, 0xd8, 0xc4, 0x85, 0x25, 0x5d, 0x8a, 0x45, 0x6e, 0x5b,
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf, 0x87, 0x8c, 0xe9, 0x9f, 0xc0, 0xea, 0xec, 0xa8, 0x93, 0x1d, 0x58, 0x3e, 0xe6, 0x32, 0x89, 0xd8,
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67, 0xf4, 0x24, 0xbb, 0xb3, 0xbc, 0x7b, 0xb6, 0x2b, 0xab, 0x44, 0x45, 0x84, 0x27, 0xd6, 0x95, 0x1a,
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b, 0xdb, 0xdf, 0x85, 0x95, 0x7c, 0xf7, 0x65, 0x22, 0x62, 0x89, 0x77, 0x2d, 0xbf, 0xff, 0x0e, 0x88,
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d, 0xbd, 0xce, 0x05, 0xda, 0x1e, 0x16, 0x67, 0x6e, 0x9e, 0x3d, 0x68, 0x27, 0x4c, 0xca, 0x1b, 0x91,
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6, 0x0e, 0x4d, 0x55, 0x63, 0xfb, 0x3e, 0xac, 0x9c, 0x4d, 0x13, 0x2c, 0xf3, 0x10, 0x68, 0xa9, 0x69,
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56, 0x62, 0x72, 0xe8, 0x6f, 0xff, 0x19, 0xfc, 0x73, 0xc7, 0xb0, 0xdd, 0x23, 0x75, 0x09, 0x16, 0x5e,
0x94, 0x05, 0x00, 0x00, 0x8f, 0x13, 0x35, 0x3d, 0xf8, 0xd2, 0x82, 0xf6, 0x71, 0xf1, 0xe6, 0x92, 0x00, 0x5a, 0x59, 0x49,
0xb2, 0x56, 0x6d, 0x80, 0x46, 0x79, 0x5b, 0x95, 0x63, 0x46, 0xd3, 0x1b, 0x80, 0xea, 0xc4, 0xe4,
0xef, 0x0a, 0x55, 0x7b, 0xd6, 0xbc, 0xed, 0xdb, 0x83, 0x45, 0xa2, 0xe7, 0xd0, 0x29, 0x9f, 0x0f,
0xe2, 0x55, 0xd0, 0xf9, 0x37, 0xc5, 0x9b, 0x97, 0x96, 0x3d, 0x09, 0xd5, 0x5a, 0xdb, 0x12, 0x6a,
0xcb, 0x5e, 0xe7, 0x7e, 0x80, 0xcd, 0x5b, 0xdb, 0x47, 0x76, 0xad, 0x34, 0xbf, 0x58, 0x66, 0xef,
0xff, 0x7b, 0x71, 0xc5, 0xf9, 0x9e, 0x42, 0x2b, 0x1b, 0x21, 0xb2, 0x59, 0x11, 0xac, 0xdf, 0x89,
0xdd, 0xdf, 0x99, 0x49, 0xdb, 0x87, 0x85, 0x7e, 0x24, 0xe4, 0x2d, 0x37, 0x52, 0x3b, 0xcb, 0x4b,
0x80, 0xea, 0xf7, 0x67, 0xf7, 0xa1, 0xf6, 0x53, 0xac, 0x71, 0xfd, 0xe6, 0xd7, 0x86, 0x33, 0x58,
0xd4, 0xef, 0xe7, 0xe1, 0xcf, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa7, 0x13, 0xfe, 0x55, 0xa5, 0x07,
0x00, 0x00,
} }

View File

@@ -4,6 +4,12 @@ package dbplugin;
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
message InitializeRequest { message InitializeRequest {
option deprecated = true;
bytes config = 1;
bool verify_connection = 2;
}
message InitRequest {
bytes config = 1; bytes config = 1;
bool verify_connection = 2; bool verify_connection = 2;
} }
@@ -25,11 +31,24 @@ message RevokeUserRequest {
string username = 2; string username = 2;
} }
message RotateRootCredentialsRequest {
repeated string statements = 1;
}
message Statements { message Statements {
// DEPRECATED, will be removed in 0.12
string creation_statements = 1; string creation_statements = 1;
// DEPRECATED, will be removed in 0.12
string revocation_statements = 2; string revocation_statements = 2;
// DEPRECATED, will be removed in 0.12
string rollback_statements = 3; string rollback_statements = 3;
// DEPRECATED, will be removed in 0.12
string renew_statements = 4; string renew_statements = 4;
repeated string creation = 5;
repeated string revocation = 6;
repeated string rollback = 7;
repeated string renewal = 8;
} }
message UsernameConfig { message UsernameConfig {
@@ -37,22 +56,35 @@ message UsernameConfig {
string RoleName = 2; string RoleName = 2;
} }
message InitResponse {
bytes config = 1;
}
message CreateUserResponse { message CreateUserResponse {
string username = 1; string username = 1;
string password = 2; string password = 2;
} }
message TypeResponse { message TypeResponse {
string type = 1; string type = 1;
}
message RotateRootCredentialsResponse {
bytes config = 1;
} }
message Empty {} message Empty {}
service Database { service Database {
rpc Type(Empty) returns (TypeResponse); rpc Type(Empty) returns (TypeResponse);
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
rpc RenewUser(RenewUserRequest) returns (Empty); rpc RenewUser(RenewUserRequest) returns (Empty);
rpc RevokeUser(RevokeUserRequest) returns (Empty); rpc RevokeUser(RevokeUserRequest) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty); rpc RotateRootCredentials(RotateRootCredentialsRequest) returns (RotateRootCredentialsResponse);
rpc Close(Empty) returns (Empty); rpc Init(InitRequest) returns (InitResponse);
rpc Close(Empty) returns (Empty);
rpc Initialize(InitializeRequest) returns (Empty) {
option deprecated = true;
};
} }

View File

@@ -2,8 +2,14 @@ package dbplugin
import ( import (
"context" "context"
"errors"
"net/url"
"strings"
"sync"
"time" "time"
"github.com/hashicorp/errwrap"
metrics "github.com/armon/go-metrics" metrics "github.com/armon/go-metrics"
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
) )
@@ -51,13 +57,27 @@ func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements
return mw.next.RevokeUser(ctx, statements, username) return mw.next.RevokeUser(ctx, statements, username)
} }
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { func (mw *databaseTracingMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
defer func(then time.Time) {
mw.logger.Trace("database", "operation", "RotateRootCredentials", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
}(time.Now())
mw.logger.Trace("database", "operation", "RotateRootCredentials", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.RotateRootCredentials(ctx, statements)
}
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *databaseTracingMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
defer func(then time.Time) { defer func(then time.Time) {
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then)) mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
}(time.Now()) }(time.Now())
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport) mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
return mw.next.Initialize(ctx, conf, verifyConnection) return mw.next.Init(ctx, conf, verifyConnection)
} }
func (mw *databaseTracingMiddleware) Close() (err error) { func (mw *databaseTracingMiddleware) Close() (err error) {
@@ -131,7 +151,28 @@ func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements
return mw.next.RevokeUser(ctx, statements, username) return mw.next.RevokeUser(ctx, statements, username)
} }
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) { func (mw *databaseMetricsMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "RotateRootCredentials"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "RotateRootCredentials"}, now)
if err != nil {
metrics.IncrCounter([]string{"database", "RotateRootCredentials", "error"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials", "error"}, 1)
}
}(time.Now())
metrics.IncrCounter([]string{"database", "RotateRootCredentials"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials"}, 1)
return mw.next.RotateRootCredentials(ctx, statements)
}
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *databaseMetricsMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
defer func(now time.Time) { defer func(now time.Time) {
metrics.MeasureSince([]string{"database", "Initialize"}, now) metrics.MeasureSince([]string{"database", "Initialize"}, now)
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
@@ -144,7 +185,7 @@ func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[st
metrics.IncrCounter([]string{"database", "Initialize"}, 1) metrics.IncrCounter([]string{"database", "Initialize"}, 1)
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
return mw.next.Initialize(ctx, conf, verifyConnection) return mw.next.Init(ctx, conf, verifyConnection)
} }
func (mw *databaseMetricsMiddleware) Close() (err error) { func (mw *databaseMetricsMiddleware) Close() (err error) {
@@ -162,3 +203,76 @@ func (mw *databaseMetricsMiddleware) Close() (err error) {
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
return mw.next.Close() return mw.next.Close()
} }
// ---- Error Sanitizer Middleware Domain ----
// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
// sanitizes returned error messages
type DatabaseErrorSanitizerMiddleware struct {
l sync.RWMutex
next Database
secretsFn func() map[string]interface{}
}
func NewDatabaseErrorSanitizerMiddleware(next Database, secretsFn func() map[string]interface{}) *DatabaseErrorSanitizerMiddleware {
return &DatabaseErrorSanitizerMiddleware{
next: next,
secretsFn: secretsFn,
}
}
func (mw *DatabaseErrorSanitizerMiddleware) Type() (string, error) {
dbType, err := mw.next.Type()
return dbType, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
username, password, err = mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
return username, password, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
return mw.sanitize(mw.next.RenewUser(ctx, statements, username, expiration))
}
func (mw *DatabaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
return mw.sanitize(mw.next.RevokeUser(ctx, statements, username))
}
func (mw *DatabaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
conf, err = mw.next.RotateRootCredentials(ctx, statements)
return conf, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := mw.Init(ctx, conf, verifyConnection)
return err
}
func (mw *DatabaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
saveConf, err = mw.next.Init(ctx, conf, verifyConnection)
return saveConf, mw.sanitize(err)
}
func (mw *DatabaseErrorSanitizerMiddleware) Close() (err error) {
return mw.sanitize(mw.next.Close())
}
// sanitize
func (mw *DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
if err == nil {
return nil
}
if errwrap.ContainsType(err, new(url.Error)) {
return errors.New("unable to parse connection url")
}
if mw.secretsFn != nil {
for k, v := range mw.secretsFn() {
if k == "" {
continue
}
err = errors.New(strings.Replace(err.Error(), k, v.(string), -1))
}
}
return err
}

View File

@@ -7,6 +7,8 @@ import (
"time" "time"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
@@ -61,16 +63,51 @@ func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*E
return &Empty{}, err return &Empty{}, err
} }
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) { func (s *gRPCServer) RotateRootCredentials(ctx context.Context, req *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error) {
config := map[string]interface{}{}
resp, err := s.impl.RotateRootCredentials(ctx, req.Statements)
if err != nil {
return nil, err
}
respConfig, err := json.Marshal(resp)
if err != nil {
return nil, err
}
return &RotateRootCredentialsResponse{
Config: respConfig,
}, err
}
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
_, err := s.Init(ctx, &InitRequest{
Config: req.Config,
VerifyConnection: req.VerifyConnection,
})
return &Empty{}, err
}
func (s *gRPCServer) Init(ctx context.Context, req *InitRequest) (*InitResponse, error) {
config := map[string]interface{}{}
err := json.Unmarshal(req.Config, &config) err := json.Unmarshal(req.Config, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = s.impl.Initialize(ctx, config, req.VerifyConnection) resp, err := s.impl.Init(ctx, config, req.VerifyConnection)
return &Empty{}, err if err != nil {
return nil, err
}
respConfig, err := json.Marshal(resp)
if err != nil {
return nil, err
}
return &InitResponse{
Config: respConfig,
}, err
} }
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) { func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
@@ -87,7 +124,7 @@ type gRPCClient struct {
doneCtx context.Context doneCtx context.Context
} }
func (c gRPCClient) Type() (string, error) { func (c *gRPCClient) Type() (string, error) {
resp, err := c.client.Type(c.doneCtx, &Empty{}) resp, err := c.client.Type(c.doneCtx, &Empty{})
if err != nil { if err != nil {
return "", err return "", err
@@ -96,7 +133,7 @@ func (c gRPCClient) Type() (string, error) {
return resp.Type, err return resp.Type, err
} }
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) { func (c *gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
t, err := ptypes.TimestampProto(expiration) t, err := ptypes.TimestampProto(expiration)
if err != nil { if err != nil {
return "", "", err return "", "", err
@@ -172,10 +209,40 @@ func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, user
return nil return nil
} }
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error { func (c *gRPCClient) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
configRaw, err := json.Marshal(config) ctx, cancel := context.WithCancel(ctx)
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
defer close(quitCh)
defer cancel()
resp, err := c.client.RotateRootCredentials(ctx, &RotateRootCredentialsRequest{
Statements: statements,
})
if err != nil { if err != nil {
return err if c.doneCtx.Err() != nil {
return nil, ErrPluginShutdown
}
return nil, err
}
if err := json.Unmarshal(resp.Config, &conf); err != nil {
return nil, err
}
return conf, nil
}
func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
func (c *gRPCClient) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
configRaw, err := json.Marshal(conf)
if err != nil {
return nil, err
} }
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
@@ -183,19 +250,33 @@ func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface
defer close(quitCh) defer close(quitCh)
defer cancel() defer cancel()
_, err = c.client.Initialize(ctx, &InitializeRequest{ resp, err := c.client.Init(ctx, &InitRequest{
Config: configRaw, Config: configRaw,
VerifyConnection: verifyConnection, VerifyConnection: verifyConnection,
}) })
if err != nil { if err != nil {
if c.doneCtx.Err() != nil { // Fall back to old call if not implemented
return ErrPluginShutdown grpcStatus, ok := status.FromError(err)
if ok && grpcStatus.Code() == codes.Unimplemented {
_, err = c.client.Initialize(ctx, &InitializeRequest{
Config: configRaw,
VerifyConnection: verifyConnection,
})
if err == nil {
return conf, nil
}
} }
return err if c.doneCtx.Err() != nil {
return nil, ErrPluginShutdown
}
return nil, err
} }
return nil if err := json.Unmarshal(resp.Config, &conf); err != nil {
return nil, err
}
return conf, nil
} }
func (c *gRPCClient) Close() error { func (c *gRPCClient) Close() error {

View File

@@ -2,8 +2,10 @@ package dbplugin
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/rpc" "net/rpc"
"strings"
"time" "time"
) )
@@ -37,8 +39,28 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *str
return err return err
} }
func (ds *databasePluginRPCServer) RotateRootCredentials(args *RotateRootCredentialsRequestRPC, resp *RotateRootCredentialsResponse) error {
config, err := ds.impl.RotateRootCredentials(context.Background(), args.Statements)
if err != nil {
return err
}
resp.Config, err = json.Marshal(config)
return err
}
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error { func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection) return ds.Init(&InitRequestRPC{
Config: args.Config,
VerifyConnection: args.VerifyConnection,
}, &InitResponse{})
}
func (ds *databasePluginRPCServer) Init(args *InitRequestRPC, resp *InitResponse) error {
config, err := ds.impl.Init(context.Background(), args.Config, args.VerifyConnection)
if err != nil {
return err
}
resp.Config, err = json.Marshal(config)
return err return err
} }
@@ -81,9 +103,7 @@ func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements State
Expiration: expiration, Expiration: expiration,
} }
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) return dr.client.Call("Plugin.RenewUser", req, &struct{}{})
return err
} }
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error { func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
@@ -92,26 +112,55 @@ func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Stat
Username: username, Username: username,
} }
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) return dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
}
return err func (dr *databasePluginRPCClient) RotateRootCredentials(_ context.Context, statements []string) (saveConf map[string]interface{}, err error) {
req := RotateRootCredentialsRequestRPC{
Statements: statements,
}
var resp RotateRootCredentialsResponse
err = dr.client.Call("Plugin.RotateRootCredentials", req, &resp)
err = json.Unmarshal(resp.Config, &saveConf)
return saveConf, err
} }
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error { func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
req := InitializeRequestRPC{ _, err := dr.Init(nil, conf, verifyConnection)
return err
}
func (dr *databasePluginRPCClient) Init(_ context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
req := InitRequestRPC{
Config: conf, Config: conf,
VerifyConnection: verifyConnection, VerifyConnection: verifyConnection,
} }
err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) var resp InitResponse
err = dr.client.Call("Plugin.Init", req, &resp)
if err != nil {
if strings.Contains(err.Error(), "can't find method Plugin.Init") {
req := InitializeRequestRPC{
Config: conf,
VerifyConnection: verifyConnection,
}
return err err = dr.client.Call("Plugin.Initialize", req, &struct{}{})
if err == nil {
return conf, nil
}
}
return nil, err
}
err = json.Unmarshal(resp.Config, &saveConf)
return saveConf, err
} }
func (dr *databasePluginRPCClient) Close() error { func (dr *databasePluginRPCClient) Close() error {
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) return dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
return err
} }
// ---- RPC Request Args Domain ---- // ---- RPC Request Args Domain ----
@@ -121,6 +170,11 @@ type InitializeRequestRPC struct {
VerifyConnection bool VerifyConnection bool
} }
type InitRequestRPC struct {
Config map[string]interface{}
VerifyConnection bool
}
type CreateUserRequestRPC struct { type CreateUserRequestRPC struct {
Statements Statements Statements Statements
UsernameConfig UsernameConfig UsernameConfig UsernameConfig
@@ -137,3 +191,7 @@ type RevokeUserRequestRPC struct {
Statements Statements Statements Statements
Username string Username string
} }
type RotateRootCredentialsRequestRPC struct {
Statements []string
}

View File

@@ -8,6 +8,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/pluginutil"
log "github.com/mgutz/logxi/v1" log "github.com/mgutz/logxi/v1"
@@ -20,8 +21,13 @@ type Database interface {
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
RevokeUser(ctx context.Context, statements Statements, username string) error RevokeUser(ctx context.Context, statements Statements, username string) error
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error)
Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error)
Close() error Close() error
// DEPRECATED, will be removed in 0.12
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error)
} }
// PluginFactory is used to build plugin database types. It wraps the database // PluginFactory is used to build plugin database types. It wraps the database
@@ -40,7 +46,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
// from the pluginRunner. Then cast it to a Database. // from the pluginRunner. Then cast it to a Database.
dbRaw, err := pluginRunner.BuiltinFactory() dbRaw, err := pluginRunner.BuiltinFactory()
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting plugin type: %s", err) return nil, errwrap.Wrapf("error initializing plugin: {{err}}", err)
} }
var ok bool var ok bool
@@ -71,7 +77,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
typeStr, err := db.Type() typeStr, err := db.Type()
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting plugin type: %s", err) return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
} }
// Wrap with metrics middleware // Wrap with metrics middleware
@@ -113,7 +119,11 @@ type DatabasePlugin struct {
} }
func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) {
return &databasePluginRPCServer{impl: d.impl}, nil impl := &DatabaseErrorSanitizerMiddleware{
next: d.impl,
}
return &databasePluginRPCServer{impl: impl}, nil
} }
func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
@@ -121,7 +131,11 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
} }
func (d DatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error { func (d DatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl}) impl := &DatabaseErrorSanitizerMiddleware{
next: d.impl,
}
RegisterDatabaseServer(s, &gRPCServer{impl: impl})
return nil return nil
} }

View File

@@ -61,6 +61,17 @@ func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statement
delete(m.users, username) delete(m.users, username)
return nil return nil
} }
func (m *mockPlugin) RotateRootCredentials(_ context.Context, statements []string) (map[string]interface{}, error) {
return nil, nil
}
func (m *mockPlugin) Init(_ context.Context, conf map[string]interface{}, _ bool) (map[string]interface{}, error) {
err := errors.New("err")
if len(conf) != 1 {
return nil, err
}
return conf, nil
}
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error { func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
err := errors.New("err") err := errors.New("err")
if len(conf) != 1 { if len(conf) != 1 {
@@ -132,7 +143,7 @@ func TestPlugin_NetRPC_Main(t *testing.T) {
plugin.Serve(serveConf) plugin.Serve(serveConf)
} }
func TestPlugin_Initialize(t *testing.T) { func TestPlugin_Init(t *testing.T) {
cluster, sys := getCluster(t) cluster, sys := getCluster(t)
defer cluster.Cleanup() defer cluster.Cleanup()
@@ -145,7 +156,7 @@ func TestPlugin_Initialize(t *testing.T) {
"test": 1, "test": 1,
} }
err = dbRaw.Initialize(context.Background(), connectionDetails, true) _, err = dbRaw.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -170,7 +181,7 @@ func TestPlugin_CreateUser(t *testing.T) {
"test": 1, "test": 1,
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -209,7 +220,7 @@ func TestPlugin_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"test": 1, "test": 1,
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -243,7 +254,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"test": 1, "test": 1,
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -272,7 +283,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
} }
// Test the code is still compatible with an old netRPC plugin // Test the code is still compatible with an old netRPC plugin
func TestPlugin_NetRPC_Initialize(t *testing.T) { func TestPlugin_NetRPC_Init(t *testing.T) {
cluster, sys := getCluster(t) cluster, sys := getCluster(t)
defer cluster.Cleanup() defer cluster.Cleanup()
@@ -285,7 +296,7 @@ func TestPlugin_NetRPC_Initialize(t *testing.T) {
"test": 1, "test": 1,
} }
err = dbRaw.Initialize(context.Background(), connectionDetails, true) _, err = dbRaw.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -310,7 +321,7 @@ func TestPlugin_NetRPC_CreateUser(t *testing.T) {
"test": 1, "test": 1,
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -349,7 +360,7 @@ func TestPlugin_NetRPC_RenewUser(t *testing.T) {
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"test": 1, "test": 1,
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -383,7 +394,7 @@ func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"test": 1, "test": 1,
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/fatih/structs" "github.com/fatih/structs"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/logical/framework"
@@ -24,6 +25,8 @@ type DatabaseConfig struct {
// by each database type. // by each database type.
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"` AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"`
RootCredentialsRotateStatements []string `json:"root_credentials_rotate_statements" structs:"root_credentials_rotate_statements" mapstructure:"root_credentials_rotate_statements"`
} }
// pathResetConnection configures a path to reset a plugin. // pathResetConnection configures a path to reset a plugin.
@@ -55,16 +58,13 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
return logical.ErrorResponse(respErrEmptyName), nil return logical.ErrorResponse(respErrEmptyName), nil
} }
// Grab the mutex lock
b.Lock()
defer b.Unlock()
// Close plugin and delete the entry in the connections cache. // Close plugin and delete the entry in the connections cache.
b.clearConnection(name) if err := b.ClearConnection(name); err != nil {
return nil, err
}
// Execute plugin again, we don't need the object so throw away. // Execute plugin again, we don't need the object so throw away.
_, err := b.createDBObj(ctx, req.Storage, name) if _, err := b.GetConnection(ctx, req.Storage, name); err != nil {
if err != nil {
return nil, err return nil, err
} }
@@ -103,6 +103,14 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
allowed to get creds from this database connection. If empty no allowed to get creds from this database connection. If empty no
roles are allowed. If "*" all roles are allowed.`, roles are allowed. If "*" all roles are allowed.`,
}, },
"root_rotation_statements": &framework.FieldSchema{
Type: framework.TypeStringSlice,
Description: `Specifies the database statements to be executed
to rotate the root user's credentials. See the plugin's API
page for more information on support and formatting for this
parameter.`,
},
}, },
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -179,16 +187,8 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
return nil, errors.New("failed to delete connection configuration") return nil, errors.New("failed to delete connection configuration")
} }
b.Lock() if err := b.ClearConnection(name); err != nil {
defer b.Unlock() return nil, err
if _, ok := b.connections[name]; ok {
err = b.connections[name].Close()
if err != nil {
return nil, err
}
delete(b.connections, name)
} }
return nil, nil return nil, nil
@@ -210,8 +210,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
} }
verifyConnection := data.Get("verify_connection").(bool) verifyConnection := data.Get("verify_connection").(bool)
allowedRoles := data.Get("allowed_roles").([]string) allowedRoles := data.Get("allowed_roles").([]string)
rootRotationStatements := data.Get("root_rotation_statements").([]string)
// Remove these entries from the data before we store it keyed under // Remove these entries from the data before we store it keyed under
// ConnectionDetails. // ConnectionDetails.
@@ -219,35 +219,45 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
delete(data.Raw, "plugin_name") delete(data.Raw, "plugin_name")
delete(data.Raw, "allowed_roles") delete(data.Raw, "allowed_roles")
delete(data.Raw, "verify_connection") delete(data.Raw, "verify_connection")
delete(data.Raw, "root_rotation_statements")
config := &DatabaseConfig{ // Create a database plugin and initialize it. This instance is not
ConnectionDetails: data.Raw, // going to be used and is initialized just to ensure all parameters
PluginName: pluginName, // are valid and the connection is verified, if requested.
AllowedRoles: allowedRoles, db, err := dbplugin.PluginFactory(ctx, pluginName, b.System(), b.logger)
}
db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
if err != nil { if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
} }
connDetails, err := db.Init(ctx, data.Raw, verifyConnection)
err = db.Initialize(ctx, config.ConnectionDetails, verifyConnection)
if err != nil { if err != nil {
db.Close() db.Close()
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
} }
// Grab the mutex lock
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
// Close and remove the old connection // Close and remove the old connection
b.clearConnection(name) b.clearConnection(name)
// Save the new connection id, err := uuid.GenerateUUID()
b.connections[name] = db if err != nil {
return nil, err
}
b.connections[name] = &dbPluginInstance{
Database: db,
name: name,
id: id,
}
// Store it // Store it
config := &DatabaseConfig{
ConnectionDetails: connDetails,
PluginName: pluginName,
AllowedRoles: allowedRoles,
RootCredentialsRotateStatements: rootRotationStatements,
}
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config) entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -54,26 +54,15 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
return nil, logical.ErrPermissionDenied return nil, logical.ErrPermissionDenied
} }
// Grab the read lock
b.RLock()
unlockFunc := b.RUnlock
// Get the Database object // Get the Database object
db, ok := b.getDBObj(role.DBName) db, err := b.GetConnection(ctx, req.Storage, role.DBName)
if !ok { if err != nil {
// Upgrade lock return nil, err
b.RUnlock()
b.Lock()
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
}
} }
db.RLock()
defer db.RUnlock()
ttl := b.System().DefaultLeaseTTL() ttl := b.System().DefaultLeaseTTL()
if role.DefaultTTL != 0 { if role.DefaultTTL != 0 {
ttl = role.DefaultTTL ttl = role.DefaultTTL
@@ -96,8 +85,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
// Create the user // Create the user
username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration) username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration)
if err != nil { if err != nil {
unlockFunc() b.CloseIfShutdown(db, err)
b.closeIfShutdown(role.DBName, err)
return nil, err return nil, err
} }
@@ -109,8 +97,6 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
"role": name, "role": name,
}) })
resp.Secret.TTL = ttl resp.Secret.TTL = ttl
unlockFunc()
return resp, nil return resp, nil
} }
} }

View File

@@ -36,26 +36,26 @@ func pathRoles(b *databaseBackend) *framework.Path {
Description: "Name of the database this role acts on.", Description: "Name of the database this role acts on.",
}, },
"creation_statements": { "creation_statements": {
Type: framework.TypeString, Type: framework.TypeStringSlice,
Description: `Specifies the database statements executed to Description: `Specifies the database statements executed to
create and configure a user. See the plugin's API page for more create and configure a user. See the plugin's API page for more
information on support and formatting for this parameter.`, information on support and formatting for this parameter.`,
}, },
"revocation_statements": { "revocation_statements": {
Type: framework.TypeString, Type: framework.TypeStringSlice,
Description: `Specifies the database statements to be executed Description: `Specifies the database statements to be executed
to revoke a user. See the plugin's API page for more information to revoke a user. See the plugin's API page for more information
on support and formatting for this parameter.`, on support and formatting for this parameter.`,
}, },
"renew_statements": { "renew_statements": {
Type: framework.TypeString, Type: framework.TypeStringSlice,
Description: `Specifies the database statements to be executed Description: `Specifies the database statements to be executed
to renew a user. Not every plugin type will support this to renew a user. Not every plugin type will support this
functionality. See the plugin's API page for more information on functionality. See the plugin's API page for more information on
support and formatting for this parameter. `, support and formatting for this parameter. `,
}, },
"rollback_statements": { "rollback_statements": {
Type: framework.TypeString, Type: framework.TypeStringSlice,
Description: `Specifies the database statements to be executed Description: `Specifies the database statements to be executed
rollback a create operation in the event of an error. Not every rollback a create operation in the event of an error. Not every
plugin type will support this functionality. See the plugin's plugin type will support this functionality. See the plugin's
@@ -109,10 +109,10 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
return &logical.Response{ return &logical.Response{
Data: map[string]interface{}{ Data: map[string]interface{}{
"db_name": role.DBName, "db_name": role.DBName,
"creation_statements": role.Statements.CreationStatements, "creation_statements": role.Statements.Creation,
"revocation_statements": role.Statements.RevocationStatements, "revocation_statements": role.Statements.Revocation,
"rollback_statements": role.Statements.RollbackStatements, "rollback_statements": role.Statements.Rollback,
"renew_statements": role.Statements.RenewStatements, "renew_statements": role.Statements.Renewal,
"default_ttl": role.DefaultTTL.Seconds(), "default_ttl": role.DefaultTTL.Seconds(),
"max_ttl": role.MaxTTL.Seconds(), "max_ttl": role.MaxTTL.Seconds(),
}, },
@@ -144,10 +144,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
} }
// Get statements // Get statements
creationStmts := data.Get("creation_statements").(string) creationStmts := data.Get("creation_statements").([]string)
revocationStmts := data.Get("revocation_statements").(string) revocationStmts := data.Get("revocation_statements").([]string)
rollbackStmts := data.Get("rollback_statements").(string) rollbackStmts := data.Get("rollback_statements").([]string)
renewStmts := data.Get("renew_statements").(string) renewStmts := data.Get("renew_statements").([]string)
// Get TTLs // Get TTLs
defaultTTLRaw := data.Get("default_ttl").(int) defaultTTLRaw := data.Get("default_ttl").(int)
@@ -156,10 +156,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
maxTTL := time.Duration(maxTTLRaw) * time.Second maxTTL := time.Duration(maxTTLRaw) * time.Second
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: creationStmts, Creation: creationStmts,
RevocationStatements: revocationStmts, Revocation: revocationStmts,
RollbackStatements: rollbackStmts, Rollback: rollbackStmts,
RenewStatements: renewStmts, Renewal: renewStmts,
} }
// Store it // Store it
@@ -181,10 +181,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
} }
type roleEntry struct { type roleEntry struct {
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` DBName string `json:"db_name"`
Statements dbplugin.Statements `json:"statements" mapstructure:"statements" structs:"statements"` Statements dbplugin.Statements `json:"statements"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` DefaultTTL time.Duration `json:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` MaxTTL time.Duration `json:"max_ttl"`
} }
const pathRoleHelpSyn = ` const pathRoleHelpSyn = `

View File

@@ -0,0 +1,80 @@
package database
import (
"context"
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathRotateCredentials(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: "rotate-root/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of this database connection",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRotateCredentialsUpdate(),
},
HelpSynopsis: pathCredsCreateReadHelpSyn,
HelpDescription: pathCredsCreateReadHelpDesc,
}
}
func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse(respErrEmptyName), nil
}
config, err := b.DatabaseConfig(ctx, req.Storage, name)
if err != nil {
return nil, err
}
db, err := b.GetConnection(ctx, req.Storage, name)
if err != nil {
return nil, err
}
// Take the write lock instead of read since we are updating the
// connection
db.Lock()
defer db.Unlock()
connectionDetails, err := db.RotateRootCredentials(ctx, config.RootCredentialsRotateStatements)
if err != nil {
return nil, err
}
config.ConnectionDetails = connectionDetails
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
if err := b.ClearConnection(name); err != nil {
return nil, err
}
return nil, nil
}
}
const pathRotateCredentialsUpdateHelpSyn = `
Request to rotate the root credentials for a certain database connection.
`
const pathRotateCredentialsUpdateHelpDesc = `
This path attempts to rotate the root credentials for the given database.
`

View File

@@ -48,37 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
return nil, err return nil, err
} }
// Grab the read lock
b.RLock()
unlockFunc := b.RUnlock
// Get the Database object // Get the Database object
db, ok := b.getDBObj(role.DBName) db, err := b.GetConnection(ctx, req.Storage, role.DBName)
if !ok { if err != nil {
// Upgrade lock return nil, err
b.RUnlock()
b.Lock()
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
}
} }
db.RLock()
defer db.RUnlock()
// Make sure we increase the VALID UNTIL endpoint for this user. // Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
err := db.RenewUser(ctx, role.Statements, username, expireTime) err := db.RenewUser(ctx, role.Statements, username, expireTime)
if err != nil { if err != nil {
unlockFunc() b.CloseIfShutdown(db, err)
b.closeIfShutdown(role.DBName, err)
return nil, err return nil, err
} }
} }
unlockFunc()
return resp, nil return resp, nil
} }
} }
@@ -107,33 +93,19 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"]) return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"])
} }
// Grab the read lock
b.RLock()
unlockFunc := b.RUnlock
// Get our connection // Get our connection
db, ok := b.getDBObj(role.DBName) db, err := b.GetConnection(ctx, req.Storage, role.DBName)
if !ok { if err != nil {
// Upgrade lock
b.RUnlock()
b.Lock()
unlockFunc = b.Unlock
// Create a new DB object
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
if err != nil {
unlockFunc()
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
}
}
if err := db.RevokeUser(ctx, role.Statements, username); err != nil {
unlockFunc()
b.closeIfShutdown(role.DBName, err)
return nil, err return nil, err
} }
unlockFunc() db.RLock()
defer db.RUnlock()
if err := db.RevokeUser(ctx, role.Statements, username); err != nil {
b.CloseIfShutdown(db, err)
return nil, err
}
return resp, nil return resp, nil
} }
} }

View File

@@ -11,27 +11,34 @@ import (
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil"
) )
const ( const (
defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
defaultUserDeletionCQL = `DROP USER '{{username}}';` defaultUserDeletionCQL = `DROP USER '{{username}}';`
cassandraTypeName = "cassandra" defaultRootCredentialRotationCQL = `ALTER USER {{username}} WITH PASSWORD '{{password}}';`
cassandraTypeName = "cassandra"
) )
var _ dbplugin.Database = &Cassandra{} var _ dbplugin.Database = &Cassandra{}
// Cassandra is an implementation of Database interface // Cassandra is an implementation of Database interface
type Cassandra struct { type Cassandra struct {
connutil.ConnectionProducer *cassandraConnectionProducer
credsutil.CredentialsProducer credsutil.CredentialsProducer
} }
// New returns a new Cassandra instance // New returns a new Cassandra instance
func New() (interface{}, error) { func New() (interface{}, error) {
db := new()
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}
func new() *Cassandra {
connProducer := &cassandraConnectionProducer{} connProducer := &cassandraConnectionProducer{}
connProducer.Type = cassandraTypeName connProducer.Type = cassandraTypeName
@@ -42,12 +49,10 @@ func New() (interface{}, error) {
Separator: "_", Separator: "_",
} }
dbType := &Cassandra{ return &Cassandra{
ConnectionProducer: connProducer, cassandraConnectionProducer: connProducer,
CredentialsProducer: credsProducer, CredentialsProducer: credsProducer,
} }
return dbType, nil
} }
// Run instantiates a Cassandra object, and runs the RPC server for the plugin // Run instantiates a Cassandra object, and runs the RPC server for the plugin
@@ -57,7 +62,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err return err
} }
plugins.Serve(dbType.(*Cassandra), apiTLSConfig) plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil return nil
} }
@@ -83,19 +88,22 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection // Get the connection
session, err := c.getConnection(ctx) session, err := c.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
creationCQL := statements.CreationStatements creationCQL := statements.Creation
if creationCQL == "" { if len(creationCQL) == 0 {
creationCQL = defaultUserCreationCQL creationCQL = []string{defaultUserCreationCQL}
} }
rollbackCQL := statements.RollbackStatements
if rollbackCQL == "" { rollbackCQL := statements.Rollback
rollbackCQL = defaultUserDeletionCQL if len(rollbackCQL) == 0 {
rollbackCQL = []string{defaultUserDeletionCQL}
} }
username, err = c.GenerateUsername(usernameConfig) username, err = c.GenerateUsername(usernameConfig)
@@ -112,28 +120,32 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
} }
// Execute each query // Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { for _, stmt := range creationCQL {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
}
err = session.Query(dbutil.QueryHelper(query, map[string]string{
"username": username, err = session.Query(dbutil.QueryHelper(query, map[string]string{
"password": password, "username": username,
})).Exec() "password": password,
if err != nil { })).Exec()
for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { if err != nil {
query = strings.TrimSpace(query) for _, stmt := range rollbackCQL {
if len(query) == 0 { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
continue query = strings.TrimSpace(query)
} if len(query) == 0 {
continue
session.Query(dbutil.QueryHelper(query, map[string]string{ }
"username": username,
})).Exec() session.Query(dbutil.QueryHelper(query, map[string]string{
"username": username,
})).Exec()
}
}
return "", "", err
} }
return "", "", err
} }
} }
@@ -152,29 +164,79 @@ func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statemen
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
session, err := c.getConnection(ctx) session, err := c.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
revocationCQL := statements.RevocationStatements revocationCQL := statements.Revocation
if revocationCQL == "" { if len(revocationCQL) == 0 {
revocationCQL = defaultUserDeletionCQL revocationCQL = []string{defaultUserDeletionCQL}
} }
var result *multierror.Error var result *multierror.Error
for _, query := range strutil.ParseArbitraryStringSlice(revocationCQL, ";") { for _, stmt := range revocationCQL {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
continue
}
err := session.Query(dbutil.QueryHelper(query, map[string]string{
"username": username,
})).Exec()
result = multierror.Append(result, err)
} }
err := session.Query(dbutil.QueryHelper(query, map[string]string{
"username": username,
})).Exec()
result = multierror.Append(result, err)
} }
return result.ErrorOrNil() return result.ErrorOrNil()
} }
func (c *Cassandra) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
// Grab the lock
c.Lock()
defer c.Unlock()
session, err := c.getConnection(ctx)
if err != nil {
return nil, err
}
rotateCQL := statements
if len(rotateCQL) == 0 {
rotateCQL = []string{defaultRootCredentialRotationCQL}
}
password, err := c.GeneratePassword()
if err != nil {
return nil, err
}
var result *multierror.Error
for _, stmt := range rotateCQL {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
err := session.Query(dbutil.QueryHelper(query, map[string]string{
"username": c.Username,
"password": password,
})).Exec()
result = multierror.Append(result, err)
}
}
err = result.ErrorOrNil()
if err != nil {
return nil, err
}
c.rawConfig["password"] = password
return c.rawConfig, nil
}

View File

@@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"github.com/gocql/gocql" "github.com/gocql/gocql"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
dockertest "gopkg.in/ory-am/dockertest.v3" dockertest "gopkg.in/ory-am/dockertest.v3"
) )
@@ -60,7 +61,7 @@ func prepareCassandraTestContainer(t *testing.T) (func(), string, int) {
session, err := clusterConfig.CreateSession() session, err := clusterConfig.CreateSession()
if err != nil { if err != nil {
return fmt.Errorf("error creating session: %s", err) return errwrap.Wrapf("error creating session: {{err}}", err)
} }
defer session.Close() defer session.Close()
return nil return nil
@@ -86,16 +87,13 @@ func TestCassandra_Initialize(t *testing.T) {
"protocol_version": 4, "protocol_version": 4,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*Cassandra) _, err := db.Init(context.Background(), connectionDetails, true)
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
if !connProducer.Initialized { if !db.Initialized {
t.Fatal("Database should be initalized") t.Fatal("Database should be initalized")
} }
@@ -113,7 +111,7 @@ func TestCassandra_Initialize(t *testing.T) {
"protocol_version": "4", "protocol_version": "4",
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -134,15 +132,14 @@ func TestCassandra_CreateUser(t *testing.T) {
"protocol_version": 4, "protocol_version": 4,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*Cassandra) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testCassandraRole, Creation: []string{testCassandraRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -175,15 +172,14 @@ func TestMyCassandra_RenewUser(t *testing.T) {
"protocol_version": 4, "protocol_version": 4,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*Cassandra) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testCassandraRole, Creation: []string{testCassandraRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -221,15 +217,14 @@ func TestCassandra_RevokeUser(t *testing.T) {
"protocol_version": 4, "protocol_version": 4,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*Cassandra) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testCassandraRole, Creation: []string{testCassandraRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -268,7 +263,7 @@ func testCredsExist(t testing.TB, address string, port int, username, password s
session, err := clusterConfig.CreateSession() session, err := clusterConfig.CreateSession()
if err != nil { if err != nil {
return fmt.Errorf("error creating session: %s", err) return errwrap.Wrapf("error creating session: {{err}}", err)
} }
defer session.Close() defer session.Close()
return nil return nil

View File

@@ -11,6 +11,7 @@ import (
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/gocql/gocql" "github.com/gocql/gocql"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/tlsutil" "github.com/hashicorp/vault/helper/tlsutil"
@@ -37,6 +38,7 @@ type cassandraConnectionProducer struct {
certificate string certificate string
privateKey string privateKey string
issuingCA string issuingCA string
rawConfig map[string]interface{}
Initialized bool Initialized bool
Type string Type string
@@ -45,12 +47,19 @@ type cassandraConnectionProducer struct {
} }
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
c.rawConfig = conf
err := mapstructure.WeakDecode(conf, c) err := mapstructure.WeakDecode(conf, c)
if err != nil { if err != nil {
return err return nil, err
} }
if c.ConnectTimeoutRaw == nil { if c.ConnectTimeoutRaw == nil {
@@ -58,16 +67,16 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
} }
c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
if err != nil { if err != nil {
return fmt.Errorf("invalid connect_timeout: %s", err) return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
} }
switch { switch {
case len(c.Hosts) == 0: case len(c.Hosts) == 0:
return fmt.Errorf("hosts cannot be empty") return nil, fmt.Errorf("hosts cannot be empty")
case len(c.Username) == 0: case len(c.Username) == 0:
return fmt.Errorf("username cannot be empty") return nil, fmt.Errorf("username cannot be empty")
case len(c.Password) == 0: case len(c.Password) == 0:
return fmt.Errorf("password cannot be empty") return nil, fmt.Errorf("password cannot be empty")
} }
var certBundle *certutil.CertBundle var certBundle *certutil.CertBundle
@@ -76,11 +85,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
case len(c.PemJSON) != 0: case len(c.PemJSON) != 0:
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON)) parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
if err != nil { if err != nil {
return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err) return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err)
} }
certBundle, err = parsedCertBundle.ToCertBundle() certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil { if err != nil {
return fmt.Errorf("Error marshaling PEM information: %s", err) return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
} }
c.certificate = certBundle.Certificate c.certificate = certBundle.Certificate
c.privateKey = certBundle.PrivateKey c.privateKey = certBundle.PrivateKey
@@ -90,11 +99,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
case len(c.PemBundle) != 0: case len(c.PemBundle) != 0:
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle) parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
if err != nil { if err != nil {
return fmt.Errorf("Error parsing the given PEM information: %s", err) return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err)
} }
certBundle, err = parsedCertBundle.ToCertBundle() certBundle, err = parsedCertBundle.ToCertBundle()
if err != nil { if err != nil {
return fmt.Errorf("Error marshaling PEM information: %s", err) return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
} }
c.certificate = certBundle.Certificate c.certificate = certBundle.Certificate
c.privateKey = certBundle.PrivateKey c.privateKey = certBundle.PrivateKey
@@ -108,11 +117,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
if verifyConnection { if verifyConnection {
if _, err := c.Connection(ctx); err != nil { if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
} }
} }
return nil return conf, nil
} }
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) { func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
@@ -186,12 +195,12 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
parsedCertBundle, err := certBundle.ToParsedCertBundle() parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
} }
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil { if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
} }
tlsConfig.InsecureSkipVerify = c.InsecureTLS tlsConfig.InsecureSkipVerify = c.InsecureTLS
@@ -215,7 +224,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
session, err := clusterConfig.CreateSession() session, err := clusterConfig.CreateSession()
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating session: %s", err) return nil, errwrap.Wrapf("error creating session: {{err}}", err)
} }
// Set consistency // Set consistency
@@ -231,8 +240,16 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
// Verify the info // Verify the info
err = session.Query(`LIST ALL`).Exec() err = session.Query(`LIST ALL`).Exec()
if err != nil { if err != nil {
return nil, fmt.Errorf("error validating connection info: %s", err) return nil, errwrap.Wrapf("error validating connection info: {{err}}", err)
} }
return session, nil return session, nil
} }
func (c *cassandraConnectionProducer) secretValues() map[string]interface{} {
return map[string]interface{}{
c.Password: "[password]",
c.PemBundle: "[pem_bundle]",
c.PemJSON: "[pem_json]",
}
}

View File

@@ -572,7 +572,7 @@ ssl_storage_port: 7001
# #
# Setting listen_address to 0.0.0.0 is always wrong. # Setting listen_address to 0.0.0.0 is always wrong.
# #
listen_address: 172.17.0.5 listen_address: 172.17.0.2
# Set listen_address OR listen_interface, not both. Interfaces must correspond # Set listen_address OR listen_interface, not both. Interfaces must correspond
# to a single address, IP aliasing is not supported. # to a single address, IP aliasing is not supported.

View File

@@ -3,6 +3,7 @@ package hana
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -23,7 +24,7 @@ const (
// HANA is an implementation of Database interface // HANA is an implementation of Database interface
type HANA struct { type HANA struct {
connutil.ConnectionProducer *connutil.SQLConnectionProducer
credsutil.CredentialsProducer credsutil.CredentialsProducer
} }
@@ -31,6 +32,14 @@ var _ dbplugin.Database = &HANA{}
// New implements builtinplugins.BuiltinFactory // New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) { func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
return dbType, nil
}
func new() *HANA {
connProducer := &connutil.SQLConnectionProducer{} connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = hanaTypeName connProducer.Type = hanaTypeName
@@ -41,12 +50,10 @@ func New() (interface{}, error) {
Separator: "_", Separator: "_",
} }
dbType := &HANA{ return &HANA{
ConnectionProducer: connProducer, SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer, CredentialsProducer: credsProducer,
} }
return dbType, nil
} }
// Run instantiates a HANA object, and runs the RPC server for the plugin // Run instantiates a HANA object, and runs the RPC server for the plugin
@@ -56,7 +63,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err return err
} }
plugins.Serve(dbType.(*HANA), apiTLSConfig) plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil return nil
} }
@@ -82,13 +89,15 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection // Get the connection
db, err := h.getConnection(ctx) db, err := h.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
if statements.CreationStatements == "" { if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement return "", "", dbutil.ErrEmptyCreationStatement
} }
@@ -127,23 +136,25 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
defer tx.Rollback() defer tx.Rollback()
// Execute each query // Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { for _, stmt := range statements.Creation {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
})) }))
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
}
} }
} }
@@ -157,6 +168,8 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
// Renewing hana user just means altering user's valid until property // Renewing hana user just means altering user's valid until property
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
statements = dbutil.StatementCompatibilityHelper(statements)
// Get connection // Get connection
db, err := h.getConnection(ctx) db, err := h.getConnection(ctx)
if err != nil { if err != nil {
@@ -197,8 +210,10 @@ func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, us
// Revoking hana user will deactivate user and try to perform a soft drop // Revoking hana user will deactivate user and try to perform a soft drop
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
statements = dbutil.StatementCompatibilityHelper(statements)
// default revoke will be a soft drop on user // default revoke will be a soft drop on user
if statements.RevocationStatements == "" { if len(statements.Revocation) == 0 {
return h.revokeUserDefault(ctx, username) return h.revokeUserDefault(ctx, username)
} }
@@ -216,30 +231,27 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u
defer tx.Rollback() defer tx.Rollback()
// Execute each query // Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") { for _, stmt := range statements.Revocation {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
}
} }
} }
// Commit the transaction return tx.Commit()
if err := tx.Commit(); err != nil {
return err
}
return nil
} }
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error { func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
@@ -284,3 +296,8 @@ func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
return nil return nil
} }
// RotateRootCredentials is not currently supported on HANA
func (h *HANA) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
return nil, errors.New("root credentaion rotation is not currently implemented in this database secrets engine")
}

View File

@@ -10,7 +10,6 @@ import (
"time" "time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
) )
func TestHANA_Initialize(t *testing.T) { func TestHANA_Initialize(t *testing.T) {
@@ -23,16 +22,13 @@ func TestHANA_Initialize(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*HANA) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) if !db.Initialized {
if !connProducer.Initialized {
t.Fatal("Database should be initialized") t.Fatal("Database should be initialized")
} }
@@ -53,10 +49,8 @@ func TestHANA_CreateUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*HANA) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -73,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testHANARole, Creation: []string{testHANARole},
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
@@ -96,16 +90,14 @@ func TestHANA_RevokeUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*HANA) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testHANARole, Creation: []string{testHANARole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -139,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
statements.RevocationStatements = testHANADrop statements.Revocation = []string{testHANADrop}
err = db.RevokeUser(context.Background(), statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)

View File

@@ -14,7 +14,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"gopkg.in/mgo.v2" "gopkg.in/mgo.v2"
@@ -25,28 +27,43 @@ import (
type mongoDBConnectionProducer struct { type mongoDBConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"` WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
Initialized bool Initialized bool
RawConfig map[string]interface{}
Type string Type string
session *mgo.Session session *mgo.Session
safe *mgo.Safe safe *mgo.Safe
sync.Mutex sync.Mutex
} }
// Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
// Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
c.RawConfig = conf
err := mapstructure.WeakDecode(conf, c) err := mapstructure.WeakDecode(conf, c)
if err != nil { if err != nil {
return err return nil, err
} }
if len(c.ConnectionURL) == 0 { if len(c.ConnectionURL) == 0 {
return fmt.Errorf("connection_url cannot be empty") return nil, fmt.Errorf("connection_url cannot be empty")
} }
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
"username": c.Username,
"password": c.Password,
})
if c.WriteConcern != "" { if c.WriteConcern != "" {
input := c.WriteConcern input := c.WriteConcern
@@ -60,13 +77,13 @@ func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[str
concern := &mgo.Safe{} concern := &mgo.Safe{}
err = json.Unmarshal([]byte(input), concern) err = json.Unmarshal([]byte(input), concern)
if err != nil { if err != nil {
return fmt.Errorf("error mashalling write_concern: %s", err) return nil, errwrap.Wrapf("error mashalling write_concern: {{err}}", err)
} }
// Guard against empty, non-nil mgo.Safe object; we don't want to pass that // Guard against empty, non-nil mgo.Safe object; we don't want to pass that
// into mgo.SetSafe in Connection(). // into mgo.SetSafe in Connection().
if (mgo.Safe{} == *concern) { if (mgo.Safe{} == *concern) {
return fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields") return nil, fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields")
} }
c.safe = concern c.safe = concern
} }
@@ -77,15 +94,15 @@ func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[str
if verifyConnection { if verifyConnection {
if _, err := c.Connection(ctx); err != nil { if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
} }
if err := c.session.Ping(); err != nil { if err := c.session.Ping(); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
} }
} }
return nil return conf, nil
} }
// Connection creates or returns an existing a database connection. If the session fails // Connection creates or returns an existing a database connection. If the session fails
@@ -203,3 +220,9 @@ func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
return &info, nil return &info, nil
} }
func (c *mongoDBConnectionProducer) secretValues() map[string]interface{} {
return map[string]interface{}{
c.Password: "[password]",
}
}

View File

@@ -2,6 +2,7 @@ package mongodb
import ( import (
"context" "context"
"errors"
"io" "io"
"strings" "strings"
"time" "time"
@@ -14,7 +15,6 @@ import (
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil"
"gopkg.in/mgo.v2" "gopkg.in/mgo.v2"
@@ -24,7 +24,7 @@ const mongoDBTypeName = "mongodb"
// MongoDB is an implementation of Database interface // MongoDB is an implementation of Database interface
type MongoDB struct { type MongoDB struct {
connutil.ConnectionProducer *mongoDBConnectionProducer
credsutil.CredentialsProducer credsutil.CredentialsProducer
} }
@@ -32,6 +32,12 @@ var _ dbplugin.Database = &MongoDB{}
// New returns a new MongoDB instance // New returns a new MongoDB instance
func New() (interface{}, error) { func New() (interface{}, error) {
db := new()
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil
}
func new() *MongoDB {
connProducer := &mongoDBConnectionProducer{} connProducer := &mongoDBConnectionProducer{}
connProducer.Type = mongoDBTypeName connProducer.Type = mongoDBTypeName
@@ -42,11 +48,10 @@ func New() (interface{}, error) {
Separator: "-", Separator: "-",
} }
dbType := &MongoDB{ return &MongoDB{
ConnectionProducer: connProducer, mongoDBConnectionProducer: connProducer,
CredentialsProducer: credsProducer, CredentialsProducer: credsProducer,
} }
return dbType, nil
} }
// Run instantiates a MongoDB object, and runs the RPC server for the plugin // Run instantiates a MongoDB object, and runs the RPC server for the plugin
@@ -88,7 +93,9 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
if statements.CreationStatements == "" { statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement return "", "", dbutil.ErrEmptyCreationStatement
} }
@@ -109,7 +116,7 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements
// Unmarshal statements.CreationStatements into mongodbRoles // Unmarshal statements.CreationStatements into mongodbRoles
var mongoCS mongoDBStatement var mongoCS mongoDBStatement
err = json.Unmarshal([]byte(statements.CreationStatements), &mongoCS) err = json.Unmarshal([]byte(statements.Creation[0]), &mongoCS)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@@ -158,15 +165,22 @@ func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements,
// RevokeUser drops the specified user from the authentication database. If none is provided // RevokeUser drops the specified user from the authentication database. If none is provided
// in the revocation statement, the default "admin" authentication database will be assumed. // in the revocation statement, the default "admin" authentication database will be assumed.
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
statements = dbutil.StatementCompatibilityHelper(statements)
session, err := m.getConnection(ctx) session, err := m.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
// If no revocation statements provided, pass in empty JSON // If no revocation statements provided, pass in empty JSON
revocationStatement := statements.RevocationStatements var revocationStatement string
if revocationStatement == "" { switch len(statements.Revocation) {
case 0:
revocationStatement = `{}` revocationStatement = `{}`
case 1:
revocationStatement = statements.Revocation[0]
default:
return fmt.Errorf("expected 0 or 1 revocation statements, got %d", len(statements.Revocation))
} }
// Unmarshal revocation statements into mongodbRoles // Unmarshal revocation statements into mongodbRoles
@@ -186,7 +200,7 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements
switch { switch {
case err == nil, err == mgo.ErrNotFound: case err == nil, err == mgo.ErrNotFound:
case err == io.EOF, strings.Contains(err.Error(), "EOF"): case err == io.EOF, strings.Contains(err.Error(), "EOF"):
if err := m.ConnectionProducer.Close(); err != nil { if err := m.Close(); err != nil {
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err) return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
} }
session, err := m.getConnection(ctx) session, err := m.getConnection(ctx)
@@ -203,3 +217,8 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements
return nil return nil
} }
// RotateRootCredentials is not currently supported on MongoDB
func (m *MongoDB) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
return nil, errors.New("root credentaion rotation is not currently implemented in this database secrets engine")
}

View File

@@ -73,19 +73,13 @@ func TestMongoDB_Initialize(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, err := New() db := new()
if err != nil { _, err := db.Init(context.Background(), connectionDetails, true)
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
if !connProducer.Initialized { if !db.Initialized {
t.Fatal("Database should be initialized") t.Fatal("Database should be initialized")
} }
@@ -103,18 +97,14 @@ func TestMongoDB_CreateUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, err := New() db := new()
if err != nil { _, err := db.Init(context.Background(), connectionDetails, true)
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testMongoDBRole, Creation: []string{testMongoDBRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -141,18 +131,14 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
"write_concern": testMongoDBWriteConcern, "write_concern": testMongoDBWriteConcern,
} }
dbRaw, err := New() db := new()
if err != nil { _, err := db.Init(context.Background(), connectionDetails, true)
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testMongoDBRole, Creation: []string{testMongoDBRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -178,18 +164,14 @@ func TestMongoDB_RevokeUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, err := New() db := new()
if err != nil { _, err := db.Init(context.Background(), connectionDetails, true)
t.Fatalf("err: %s", err)
}
db := dbRaw.(*MongoDB)
err = db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testMongoDBRole, Creation: []string{testMongoDBRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{

View File

@@ -3,11 +3,13 @@ package mssql
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
_ "github.com/denisenkom/go-mssqldb" _ "github.com/denisenkom/go-mssqldb"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/helper/strutil"
@@ -23,11 +25,19 @@ var _ dbplugin.Database = &MSSQL{}
// MSSQL is an implementation of Database interface // MSSQL is an implementation of Database interface
type MSSQL struct { type MSSQL struct {
connutil.ConnectionProducer *connutil.SQLConnectionProducer
credsutil.CredentialsProducer credsutil.CredentialsProducer
} }
func New() (interface{}, error) { func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
return dbType, nil
}
func new() *MSSQL {
connProducer := &connutil.SQLConnectionProducer{} connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = msSQLTypeName connProducer.Type = msSQLTypeName
@@ -38,12 +48,10 @@ func New() (interface{}, error) {
Separator: "-", Separator: "-",
} }
dbType := &MSSQL{ return &MSSQL{
ConnectionProducer: connProducer, SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer, CredentialsProducer: credsProducer,
} }
return dbType, nil
} }
// Run instantiates a MSSQL object, and runs the RPC server for the plugin // Run instantiates a MSSQL object, and runs the RPC server for the plugin
@@ -53,7 +61,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err return err
} }
plugins.Serve(dbType.(*MSSQL), apiTLSConfig) plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil return nil
} }
@@ -79,13 +87,15 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection // Get the connection
db, err := m.getConnection(ctx) db, err := m.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
if statements.CreationStatements == "" { if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement return "", "", dbutil.ErrEmptyCreationStatement
} }
@@ -112,23 +122,25 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
defer tx.Rollback() defer tx.Rollback()
// Execute each query // Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { for _, stmt := range statements.Creation {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
"password": password, "password": password,
"expiration": expirationStr, "expiration": expirationStr,
})) }))
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err return "", "", err
}
} }
} }
@@ -150,7 +162,9 @@ func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, u
// then kill pending connections from that user, and finally drop the user and login from the // then kill pending connections from that user, and finally drop the user and login from the
// database instance. // database instance.
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
if statements.RevocationStatements == "" { statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Revocation) == 0 {
return m.revokeUserDefault(ctx, username) return m.revokeUserDefault(ctx, username)
} }
@@ -168,21 +182,23 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
defer tx.Rollback() defer tx.Rollback()
// Execute each query // Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") { for _, stmt := range statements.Revocation {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
}
} }
} }
@@ -283,10 +299,10 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
// can't drop if not all database users are dropped // can't drop if not all database users are dropped
if rows.Err() != nil { if rows.Err() != nil {
return fmt.Errorf("could not generate sql statements for all rows: %s", rows.Err()) return errwrap.Wrapf("could not generate sql statements for all rows: {{err}}", rows.Err())
} }
if lastStmtError != nil { if lastStmtError != nil {
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) return errwrap.Wrapf("could not perform all sql statements: {{err}}", lastStmtError)
} }
// Drop this login // Drop this login
@@ -302,6 +318,70 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
return nil return nil
} }
func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
m.Lock()
defer m.Unlock()
if len(m.Username) == 0 || len(m.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{rotateRootCredentialsSQL}
}
db, err := m.getConnection(ctx)
if err != nil {
return nil, err
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := m.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"username": m.Username,
"password": password,
}))
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
if err := db.Close(); err != nil {
return nil, err
}
m.RawConfig["password"] = password
return m.RawConfig, nil
}
const dropUserSQL = ` const dropUserSQL = `
USE [%s] USE [%s]
IF EXISTS IF EXISTS
@@ -322,3 +402,7 @@ BEGIN
DROP LOGIN [%s] DROP LOGIN [%s]
END END
` `
const rotateRootCredentialsSQL = `
ALTER LOGIN [%s] WITH PASSWORD = '%s'
`

View File

@@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
) )
var ( var (
@@ -28,16 +27,13 @@ func TestMSSQL_Initialize(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*MSSQL) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) if !db.Initialized {
if !connProducer.Initialized {
t.Fatal("Database should be initalized") t.Fatal("Database should be initalized")
} }
@@ -52,7 +48,7 @@ func TestMSSQL_Initialize(t *testing.T) {
"max_open_connections": "5", "max_open_connections": "5",
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -68,9 +64,8 @@ func TestMSSQL_CreateUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*MSSQL) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -87,7 +82,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testMSSQLRole, Creation: []string{testMSSQLRole},
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
@@ -110,15 +105,14 @@ func TestMSSQL_RevokeUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*MSSQL) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testMSSQLRole, Creation: []string{testMSSQLRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -155,7 +149,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
} }
// Test custom revoke statement // Test custom revoke statement
statements.RevocationStatements = testMSSQLDrop statements.Revocation = []string{testMSSQLDrop}
err = db.RevokeUser(context.Background(), statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)

View File

@@ -3,6 +3,7 @@ package mysql
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"strings" "strings"
"time" "time"
@@ -21,6 +22,11 @@ const (
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
DROP USER '{{name}}'@'%' DROP USER '{{name}}'@'%'
` `
defaultMySQLRotateRootCredentialsSQL = `
ALTER USER '{{username}}'@'%' IDENTIFIED BY '{{password}}';
`
mySQLTypeName = "mysql" mySQLTypeName = "mysql"
) )
@@ -34,32 +40,38 @@ var (
var _ dbplugin.Database = &MySQL{} var _ dbplugin.Database = &MySQL{}
type MySQL struct { type MySQL struct {
connutil.ConnectionProducer *connutil.SQLConnectionProducer
credsutil.CredentialsProducer credsutil.CredentialsProducer
} }
// New implements builtinplugins.BuiltinFactory // New implements builtinplugins.BuiltinFactory
func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, error) { func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, error) {
return func() (interface{}, error) { return func() (interface{}, error) {
connProducer := &connutil.SQLConnectionProducer{} db := new(displayNameLen, roleNameLen, usernameLen)
connProducer.Type = mySQLTypeName // Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: displayNameLen,
RoleNameLen: roleNameLen,
UsernameLen: usernameLen,
Separator: "-",
}
dbType := &MySQL{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
return dbType, nil return dbType, nil
} }
} }
func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = mySQLTypeName
credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: displayNameLen,
RoleNameLen: roleNameLen,
UsernameLen: usernameLen,
Separator: "-",
}
return &MySQL{
SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
}
// Run instantiates a MySQL object, and runs the RPC server for the plugin // Run instantiates a MySQL object, and runs the RPC server for the plugin
func Run(apiTLSConfig *api.TLSConfig) error { func Run(apiTLSConfig *api.TLSConfig) error {
return runCommon(false, apiTLSConfig) return runCommon(false, apiTLSConfig)
@@ -82,7 +94,7 @@ func runCommon(legacy bool, apiTLSConfig *api.TLSConfig) error {
return err return err
} }
plugins.Serve(dbType.(*MySQL), apiTLSConfig) plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil return nil
} }
@@ -105,13 +117,15 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection // Get the connection
db, err := m.getConnection(ctx) db, err := m.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
if statements.CreationStatements == "" { if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement return "", "", dbutil.ErrEmptyCreationStatement
} }
@@ -138,38 +152,40 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
defer tx.Rollback() defer tx.Rollback()
// Execute each query // Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { for _, stmt := range statements.Creation {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
}
query = dbutil.QueryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expirationStr,
})
stmt, err := tx.PrepareContext(ctx, query)
if err != nil {
// If the error code we get back is Error 1295: This command is not
// supported in the prepared statement protocol yet, we will execute
// the statement without preparing it. This allows the caller to
// manually prepare statements, as well as run other not yet
// prepare supported commands. If there is no error when running we
// will continue to the next statement.
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
_, err = tx.ExecContext(ctx, query)
if err != nil {
return "", "", err
}
continue continue
} }
query = dbutil.QueryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expirationStr,
})
return "", "", err stmt, err := tx.PrepareContext(ctx, query)
} if err != nil {
defer stmt.Close() // If the error code we get back is Error 1295: This command is not
if _, err := stmt.ExecContext(ctx); err != nil { // supported in the prepared statement protocol yet, we will execute
return "", "", err // the statement without preparing it. This allows the caller to
// manually prepare statements, as well as run other not yet
// prepare supported commands. If there is no error when running we
// will continue to the next statement.
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
_, err = tx.ExecContext(ctx, query)
if err != nil {
return "", "", err
}
continue
}
return "", "", err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
}
} }
} }
@@ -191,16 +207,18 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
// Get the connection // Get the connection
db, err := m.getConnection(ctx) db, err := m.getConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
revocationStmts := statements.RevocationStatements revocationStmts := statements.Revocation
// Use a default SQL statement for revocation if one cannot be fetched from the role // Use a default SQL statement for revocation if one cannot be fetched from the role
if revocationStmts == "" { if len(revocationStmts) == 0 {
revocationStmts = defaultMysqlRevocationStmts revocationStmts = []string{defaultMysqlRevocationStmts}
} }
// Start a transaction // Start a transaction
@@ -210,21 +228,22 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
} }
defer tx.Rollback() defer tx.Rollback()
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { for _, stmt := range revocationStmts {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
}
// This is not a prepared statement because not all commands are supported // This is not a prepared statement because not all commands are supported
// 1295: This command is not supported in the prepared statement protocol yet // 1295: This command is not supported in the prepared statement protocol yet
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
query = strings.Replace(query, "{{name}}", username, -1) query = strings.Replace(query, "{{name}}", username, -1)
_, err = tx.ExecContext(ctx, query) _, err = tx.ExecContext(ctx, query)
if err != nil { if err != nil {
return err return err
}
} }
} }
// Commit the transaction // Commit the transaction
@@ -234,3 +253,67 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
return nil return nil
} }
func (m *MySQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
m.Lock()
defer m.Unlock()
if len(m.Username) == 0 || len(m.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{defaultMySQLRotateRootCredentialsSQL}
}
db, err := m.getConnection(ctx)
if err != nil {
return nil, err
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := m.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"username": m.Username,
"password": password,
}))
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
if err := db.Close(); err != nil {
return nil, err
}
m.RawConfig["password"] = password
return m.RawConfig, nil
}

View File

@@ -9,9 +9,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
dockertest "gopkg.in/ory-am/dockertest.v3" dockertest "gopkg.in/ory-am/dockertest.v3"
) )
@@ -104,17 +104,13 @@ func TestMySQL_Initialize(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
f := New(MetadataLen, MetadataLen, UsernameLen) db := new(MetadataLen, MetadataLen, UsernameLen)
dbRaw, _ := f() _, err := db.Init(context.Background(), connectionDetails, true)
db := dbRaw.(*MySQL)
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
if !connProducer.Initialized { if !db.Initialized {
t.Fatal("Database should be initalized") t.Fatal("Database should be initalized")
} }
@@ -129,7 +125,7 @@ func TestMySQL_Initialize(t *testing.T) {
"max_open_connections": "5", "max_open_connections": "5",
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -143,11 +139,8 @@ func TestMySQL_CreateUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
f := New(MetadataLen, MetadataLen, UsernameLen) db := new(MetadataLen, MetadataLen, UsernameLen)
dbRaw, _ := f() _, err := db.Init(context.Background(), connectionDetails, true)
db := dbRaw.(*MySQL)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -164,7 +157,7 @@ func TestMySQL_CreateUser(t *testing.T) {
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testMySQLRoleWildCard, Creation: []string{testMySQLRoleWildCard},
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
@@ -187,7 +180,7 @@ func TestMySQL_CreateUser(t *testing.T) {
} }
// Test with a manually prepare statement // Test with a manually prepare statement
statements.CreationStatements = testMySQLRolePreparedStmt statements.Creation = []string{testMySQLRolePreparedStmt}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
@@ -208,11 +201,8 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
f := New(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen) db := new(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen)
dbRaw, _ := f() _, err := db.Init(context.Background(), connectionDetails, true)
db := dbRaw.(*MySQL)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -229,7 +219,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testMySQLRoleWildCard, Creation: []string{testMySQLRoleWildCard},
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
@@ -252,6 +242,42 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
} }
} }
func TestMySQL_RotateRootCredentials(t *testing.T) {
cleanup, connURL := prepareMySQLTestContainer(t)
defer cleanup()
connURL = strings.Replace(connURL, "root:secret", `{{username}}:{{password}}`, -1)
connectionDetails := map[string]interface{}{
"connection_url": connURL,
"username": "root",
"password": "secret",
}
db := new(MetadataLen, MetadataLen, UsernameLen)
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !db.Initialized {
t.Fatal("Database should be initalized")
}
newConf, err := db.RotateRootCredentials(context.Background(), nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if newConf["password"] == "secret" {
t.Fatal("password was not updated")
}
err = db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestMySQL_RevokeUser(t *testing.T) { func TestMySQL_RevokeUser(t *testing.T) {
cleanup, connURL := prepareMySQLTestContainer(t) cleanup, connURL := prepareMySQLTestContainer(t)
defer cleanup() defer cleanup()
@@ -260,17 +286,14 @@ func TestMySQL_RevokeUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
f := New(MetadataLen, MetadataLen, UsernameLen) db := new(MetadataLen, MetadataLen, UsernameLen)
dbRaw, _ := f() _, err := db.Init(context.Background(), connectionDetails, true)
db := dbRaw.(*MySQL)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testMySQLRoleWildCard, Creation: []string{testMySQLRoleWildCard},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -297,7 +320,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
t.Fatal("Credentials were not revoked") t.Fatal("Credentials were not revoked")
} }
statements.CreationStatements = testMySQLRoleWildCard statements.Creation = []string{testMySQLRoleWildCard}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
@@ -308,7 +331,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
} }
// Test custom revoke statements // Test custom revoke statements
statements.RevocationStatements = testMySQLRevocationSQL statements.Revocation = []string{testMySQLRevocationSQL}
err = db.RevokeUser(context.Background(), statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)

View File

@@ -3,10 +3,12 @@ package postgresql
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/helper/strutil"
@@ -15,13 +17,15 @@ import (
"github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil"
"github.com/lib/pq" "github.com/lib/pq"
_ "github.com/lib/pq"
) )
const ( const (
postgreSQLTypeName string = "postgres" postgreSQLTypeName = "postgres"
defaultPostgresRenewSQL = ` defaultPostgresRenewSQL = `
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
`
defaultPostgresRotateRootCredentialsSQL = `
ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
` `
) )
@@ -29,6 +33,13 @@ var _ dbplugin.Database = &PostgreSQL{}
// New implements builtinplugins.BuiltinFactory // New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) { func New() (interface{}, error) {
db := new()
// Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
return dbType, nil
}
func new() *PostgreSQL {
connProducer := &connutil.SQLConnectionProducer{} connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = postgreSQLTypeName connProducer.Type = postgreSQLTypeName
@@ -39,12 +50,12 @@ func New() (interface{}, error) {
Separator: "-", Separator: "-",
} }
dbType := &PostgreSQL{ db := &PostgreSQL{
ConnectionProducer: connProducer, SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer, CredentialsProducer: credsProducer,
} }
return dbType, nil return db
} }
// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin // Run instantiates a PostgreSQL object, and runs the RPC server for the plugin
@@ -54,13 +65,13 @@ func Run(apiTLSConfig *api.TLSConfig) error {
return err return err
} }
plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig) plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
return nil return nil
} }
type PostgreSQL struct { type PostgreSQL struct {
connutil.ConnectionProducer *connutil.SQLConnectionProducer
credsutil.CredentialsProducer credsutil.CredentialsProducer
} }
@@ -78,7 +89,9 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
} }
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
if statements.CreationStatements == "" { statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement return "", "", dbutil.ErrEmptyCreationStatement
} }
@@ -105,7 +118,6 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
db, err := p.getConnection(ctx) db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
// Start a transaction // Start a transaction
@@ -120,25 +132,25 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
// Return the secret // Return the secret
// Execute each query // Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { for _, stmt := range statements.Creation {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expirationStr,
}))
if err != nil {
return "", "", err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expirationStr,
}))
if err != nil {
return "", "", err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return "", "", err
}
} }
} }
@@ -155,9 +167,11 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
renewStmts := statements.RenewStatements statements = dbutil.StatementCompatibilityHelper(statements)
if renewStmts == "" {
renewStmts = defaultPostgresRenewSQL renewStmts := statements.Renewal
if len(renewStmts) == 0 {
renewStmts = []string{defaultPostgresRenewSQL}
} }
db, err := p.getConnection(ctx) db, err := p.getConnection(ctx)
@@ -178,30 +192,28 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
return err return err
} }
for _, query := range strutil.ParseArbitraryStringSlice(renewStmts, ";") { for _, stmt := range renewStmts {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ }
"name": username, stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"expiration": expirationStr, "name": username,
})) "expiration": expirationStr,
if err != nil { }))
return err if err != nil {
} return err
}
defer stmt.Close() defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
}
} }
} }
if err := tx.Commit(); err != nil { return tx.Commit()
return err
}
return nil
} }
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
@@ -209,14 +221,16 @@ func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Stateme
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if statements.RevocationStatements == "" { statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Revocation) == 0 {
return p.defaultRevokeUser(ctx, username) return p.defaultRevokeUser(ctx, username)
} }
return p.customRevokeUser(ctx, username, statements.RevocationStatements) return p.customRevokeUser(ctx, username, statements.Revocation)
} }
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error { func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error {
db, err := p.getConnection(ctx) db, err := p.getConnection(ctx)
if err != nil { if err != nil {
return err return err
@@ -230,30 +244,28 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationS
tx.Rollback() tx.Rollback()
}() }()
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { for _, stmt := range revocationStmts {
query = strings.TrimSpace(query) for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
if len(query) == 0 { query = strings.TrimSpace(query)
continue if len(query) == 0 {
} continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{ stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"name": username, "name": username,
})) }))
if err != nil { if err != nil {
return err return err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return err
}
} }
} }
if err := tx.Commit(); err != nil { return tx.Commit()
return err
}
return nil
} }
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
@@ -354,10 +366,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
// can't drop if not all privileges are revoked // can't drop if not all privileges are revoked
if rows.Err() != nil { if rows.Err() != nil {
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err())
} }
if lastStmtError != nil { if lastStmtError != nil {
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError)
} }
// Drop this user // Drop this user
@@ -373,3 +385,68 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
return nil return nil
} }
func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
p.Lock()
defer p.Unlock()
if len(p.Username) == 0 || len(p.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatents := statements
if len(rotateStatents) == 0 {
rotateStatents = []string{defaultPostgresRotateRootCredentialsSQL}
}
db, err := p.getConnection(ctx)
if err != nil {
return nil, err
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := p.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatents {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
"username": p.Username,
"password": password,
}))
if err != nil {
return nil, err
}
defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
// Close the database connection to ensure no new connections come in
if err := db.Close(); err != nil {
return nil, err
}
p.RawConfig["password"] = password
return p.RawConfig, nil
}

View File

@@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
dockertest "gopkg.in/ory-am/dockertest.v3" dockertest "gopkg.in/ory-am/dockertest.v3"
) )
@@ -68,17 +67,13 @@ func TestPostgreSQL_Initialize(t *testing.T) {
"max_open_connections": 5, "max_open_connections": 5,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*PostgreSQL) _, err := db.Init(context.Background(), connectionDetails, true)
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
if !connProducer.Initialized { if !db.Initialized {
t.Fatal("Database should be initalized") t.Fatal("Database should be initalized")
} }
@@ -93,7 +88,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
"max_open_connections": "5", "max_open_connections": "5",
} }
err = db.Initialize(context.Background(), connectionDetails, true) _, err = db.Init(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -108,9 +103,8 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*PostgreSQL) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
@@ -127,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testPostgresRole, Creation: []string{testPostgresRole},
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
@@ -139,7 +133,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
statements.CreationStatements = testPostgresReadOnlyRole statements.Creation = []string{testPostgresReadOnlyRole}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
@@ -161,15 +155,14 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*PostgreSQL) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testPostgresRole, Creation: []string{testPostgresRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -197,7 +190,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
if err = testCredsExist(t, connURL, username, password); err != nil { if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
statements.RenewStatements = defaultPostgresRenewSQL statements.Renewal = []string{defaultPostgresRenewSQL}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
@@ -221,6 +214,46 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
} }
func TestPostgreSQL_RotateRootCredentials(t *testing.T) {
cleanup, connURL := preparePostgresTestContainer(t)
defer cleanup()
connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1)
connectionDetails := map[string]interface{}{
"connection_url": connURL,
"max_open_connections": 5,
"username": "postgres",
"password": "secret",
}
db := new()
connProducer := db.SQLConnectionProducer
_, err := db.Init(context.Background(), connectionDetails, true)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.Initialized {
t.Fatal("Database should be initalized")
}
newConf, err := db.RotateRootCredentials(context.Background(), nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if newConf["password"] == "secret" {
t.Fatal("password was not updated")
}
err = db.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPostgreSQL_RevokeUser(t *testing.T) { func TestPostgreSQL_RevokeUser(t *testing.T) {
cleanup, connURL := preparePostgresTestContainer(t) cleanup, connURL := preparePostgresTestContainer(t)
defer cleanup() defer cleanup()
@@ -229,15 +262,14 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
"connection_url": connURL, "connection_url": connURL,
} }
dbRaw, _ := New() db := new()
db := dbRaw.(*PostgreSQL) _, err := db.Init(context.Background(), connectionDetails, true)
err := db.Initialize(context.Background(), connectionDetails, true)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }
statements := dbplugin.Statements{ statements := dbplugin.Statements{
CreationStatements: testPostgresRole, Creation: []string{testPostgresRole},
} }
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameConfig{
@@ -274,7 +306,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
} }
// Test custom revoke statements // Test custom revoke statements
statements.RevocationStatements = defaultPostgresRevocationSQL statements.Revocation = []string{defaultPostgresRevocationSQL}
err = db.RevokeUser(context.Background(), statements, username) err = db.RevokeUser(context.Background(), statements, username)
if err != nil { if err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
@@ -286,6 +318,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
} }
func testCredsExist(t testing.TB, connURL, username, password string) error { func testCredsExist(t testing.TB, connURL, username, password string) error {
t.Helper()
// Log in with the new creds // Log in with the new creds
connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1) connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1)
db, err := sql.Open("postgres", connURL) db, err := sql.Open("postgres", connURL)

View File

@@ -15,8 +15,11 @@ var (
// connections and is used in all the builtin database types. // connections and is used in all the builtin database types.
type ConnectionProducer interface { type ConnectionProducer interface {
Close() error Close() error
Initialize(context.Context, map[string]interface{}, bool) error Init(context.Context, map[string]interface{}, bool) (map[string]interface{}, error)
Connection(context.Context) (interface{}, error) Connection(context.Context) (interface{}, error)
sync.Locker sync.Locker
// DEPRECATED, will be removed in 0.12
Initialize(context.Context, map[string]interface{}, bool) error
} }

View File

@@ -8,18 +8,25 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
) )
var _ ConnectionProducer = &SQLConnectionProducer{}
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases // SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type SQLConnectionProducer struct { type SQLConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
Username string `json:"username" mapstructure:"username" structs:"username"`
Password string `json:"password" mapstructure:"password" structs:"password"`
Type string Type string
RawConfig map[string]interface{}
maxConnectionLifetime time.Duration maxConnectionLifetime time.Duration
Initialized bool Initialized bool
db *sql.DB db *sql.DB
@@ -27,18 +34,30 @@ type SQLConnectionProducer struct {
} }
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
_, err := c.Init(ctx, conf, verifyConnection)
return err
}
func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
err := mapstructure.WeakDecode(conf, c) c.RawConfig = conf
err := mapstructure.WeakDecode(conf, &c)
if err != nil { if err != nil {
return err return nil, err
} }
if len(c.ConnectionURL) == 0 { if len(c.ConnectionURL) == 0 {
return fmt.Errorf("connection_url cannot be empty") return nil, fmt.Errorf("connection_url cannot be empty")
} }
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
"username": c.Username,
"password": c.Password,
})
if c.MaxOpenConnections == 0 { if c.MaxOpenConnections == 0 {
c.MaxOpenConnections = 2 c.MaxOpenConnections = 2
} }
@@ -55,7 +74,7 @@ func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw) c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
if err != nil { if err != nil {
return fmt.Errorf("invalid max_connection_lifetime: %s", err) return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
} }
// Set initialized to true at this point since all fields are set, // Set initialized to true at this point since all fields are set,
@@ -64,15 +83,15 @@ func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]
if verifyConnection { if verifyConnection {
if _, err := c.Connection(ctx); err != nil { if _, err := c.Connection(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
} }
if err := c.db.PingContext(ctx); err != nil { if err := c.db.PingContext(ctx); err != nil {
return fmt.Errorf("error verifying connection: %s", err) return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
} }
} }
return nil return c.RawConfig, nil
} }
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) { func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
@@ -123,6 +142,12 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
return c.db, nil return c.db, nil
} }
func (c *SQLConnectionProducer) SecretValues() map[string]interface{} {
return map[string]interface{}{
c.Password: "[password]",
}
}
// Close attempts to close the connection // Close attempts to close the connection
func (c *SQLConnectionProducer) Close() error { func (c *SQLConnectionProducer) Close() error {
// Grab the write lock // Grab the write lock

View File

@@ -4,6 +4,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
) )
var ( var (
@@ -18,3 +20,33 @@ func QueryHelper(tpl string, data map[string]string) string {
return tpl return tpl
} }
// StatementCompatibilityHelper will populate the statements fields to support
// compatibility
func StatementCompatibilityHelper(statements dbplugin.Statements) dbplugin.Statements {
switch {
case len(statements.Creation) > 0 && len(statements.CreationStatements) == 0:
statements.CreationStatements = strings.Join(statements.Creation, ";")
case len(statements.CreationStatements) > 0:
statements.Creation = []string{statements.CreationStatements}
}
switch {
case len(statements.Revocation) > 0 && len(statements.RevocationStatements) == 0:
statements.RevocationStatements = strings.Join(statements.Revocation, ";")
case len(statements.RevocationStatements) > 0:
statements.Revocation = []string{statements.RevocationStatements}
}
switch {
case len(statements.Renewal) > 0 && len(statements.RenewStatements) == 0:
statements.RenewStatements = strings.Join(statements.Renewal, ";")
case len(statements.RenewStatements) > 0:
statements.Renewal = []string{statements.RenewStatements}
}
switch {
case len(statements.Rollback) > 0 && len(statements.RollbackStatements) == 0:
statements.RollbackStatements = strings.Join(statements.Rollback, ";")
case len(statements.RollbackStatements) > 0:
statements.Rollback = []string{statements.RollbackStatements}
}
return statements
}

View File

@@ -0,0 +1,62 @@
package dbutil
import (
"reflect"
"testing"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
)
func TestStatementCompatibilityHelper(t *testing.T) {
const (
creationStatement = "creation"
renewStatement = "renew"
revokeStatement = "revoke"
rollbackStatement = "rollback"
)
expectedStatements := dbplugin.Statements{
Creation: []string{creationStatement},
Rollback: []string{rollbackStatement},
Revocation: []string{revokeStatement},
Renewal: []string{renewStatement},
CreationStatements: creationStatement,
RenewStatements: renewStatement,
RollbackStatements: rollbackStatement,
RevocationStatements: revokeStatement,
}
statements1 := dbplugin.Statements{
CreationStatements: creationStatement,
RenewStatements: renewStatement,
RollbackStatements: rollbackStatement,
RevocationStatements: revokeStatement,
}
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements1)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements1)
}
statements2 := dbplugin.Statements{
Creation: []string{creationStatement},
Rollback: []string{rollbackStatement},
Revocation: []string{revokeStatement},
Renewal: []string{renewStatement},
}
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements2)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements2)
}
statements3 := dbplugin.Statements{
CreationStatements: creationStatement,
}
expectedStatements3 := dbplugin.Statements{
Creation: []string{creationStatement},
CreationStatements: creationStatement,
}
if !reflect.DeepEqual(expectedStatements3, StatementCompatibilityHelper(statements3)) {
t.Fatalf("mismatch: %#v, %#v", expectedStatements3, statements3)
}
}