mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 02:57:59 +00:00
Database Root Credential Rotation (#3976)
* redoing connection handling * a little more cleanup * empty implementation of rotation * updating rotate signature * signature update * updating interfaces again :( * changing back to interface * adding templated url support and rotation for postgres * adding correct username * return updates * updating statements to be a list * adding error sanitizing middleware * fixing log sanitizier * adding postgres rotate test * removing conf from rotate * adding rotate command * adding mysql rotate * finishing up the endpoint in the db backend for rotate * no more structs, just store raw config * fixing tests * adding db instance lock * adding support for statement list in cassandra * wip redoing interface to support BC * adding falllback for Initialize implementation * adding backwards compat for statements * fix tests * fix more tests * fixing up tests, switching to new fields in statements * fixing more tests * adding mssql and mysql * wrapping all the things in middleware, implementing templating for mongodb * wrapping all db servers with error santizer * fixing test * store the name with the db instance * adding rotate to cassandra * adding compatibility translation to both server and plugin * reordering a few things * store the name with the db instance * reordering * adding a few more tests * switch secret values from slice to map * addressing some feedback * reinstate execute plugin after resetting connection * set database connection to closed * switching secret values func to map[string]interface for potential future uses * addressing feedback
This commit is contained in:
@@ -9,13 +9,37 @@ import (
|
|||||||
|
|
||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
|
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
|
uuid "github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
|
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const databaseConfigPath = "database/config/"
|
const databaseConfigPath = "database/config/"
|
||||||
|
|
||||||
|
type dbPluginInstance struct {
|
||||||
|
sync.RWMutex
|
||||||
|
dbplugin.Database
|
||||||
|
|
||||||
|
id string
|
||||||
|
name string
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dbi *dbPluginInstance) Close() error {
|
||||||
|
dbi.Lock()
|
||||||
|
defer dbi.Unlock()
|
||||||
|
|
||||||
|
if dbi.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dbi.closed = true
|
||||||
|
|
||||||
|
return dbi.Database.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||||
b := Backend(conf)
|
b := Backend(conf)
|
||||||
if err := b.Setup(ctx, conf); err != nil {
|
if err := b.Setup(ctx, conf); err != nil {
|
||||||
@@ -42,6 +66,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
|||||||
pathRoles(&b),
|
pathRoles(&b),
|
||||||
pathCredsCreate(&b),
|
pathCredsCreate(&b),
|
||||||
pathResetConnection(&b),
|
pathResetConnection(&b),
|
||||||
|
pathRotateCredentials(&b),
|
||||||
},
|
},
|
||||||
|
|
||||||
Secrets: []*framework.Secret{
|
Secrets: []*framework.Secret{
|
||||||
@@ -53,72 +78,22 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b.logger = conf.Logger
|
b.logger = conf.Logger
|
||||||
b.connections = make(map[string]dbplugin.Database)
|
b.connections = make(map[string]*dbPluginInstance)
|
||||||
return &b
|
return &b
|
||||||
}
|
}
|
||||||
|
|
||||||
type databaseBackend struct {
|
type databaseBackend struct {
|
||||||
connections map[string]dbplugin.Database
|
connections map[string]*dbPluginInstance
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
*framework.Backend
|
*framework.Backend
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeAllDBs closes all connections from all database types
|
|
||||||
func (b *databaseBackend) closeAllDBs(ctx context.Context) {
|
|
||||||
b.Lock()
|
|
||||||
defer b.Unlock()
|
|
||||||
|
|
||||||
for _, db := range b.connections {
|
|
||||||
db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
b.connections = make(map[string]dbplugin.Database)
|
|
||||||
}
|
|
||||||
|
|
||||||
// This function is used to retrieve a database object either from the cached
|
|
||||||
// connection map. The caller of this function needs to hold the backend's read
|
|
||||||
// lock.
|
|
||||||
func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
|
|
||||||
db, ok := b.connections[name]
|
|
||||||
return db, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// This function creates a new db object from the stored configuration and
|
|
||||||
// caches it in the connections map. The caller of this function needs to hold
|
|
||||||
// the backend's write lock
|
|
||||||
func (b *databaseBackend) createDBObj(ctx context.Context, s logical.Storage, name string) (dbplugin.Database, error) {
|
|
||||||
db, ok := b.connections[name]
|
|
||||||
if ok {
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
config, err := b.DatabaseConfig(ctx, s, name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err = dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = db.Initialize(ctx, config.ConnectionDetails, true)
|
|
||||||
if err != nil {
|
|
||||||
db.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b.connections[name] = db
|
|
||||||
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) {
|
func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) {
|
||||||
entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name))
|
entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read connection configuration: %s", err)
|
return nil, errwrap.Wrapf("failed to read connection configuration: {{err}}", err)
|
||||||
}
|
}
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
|
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
|
||||||
@@ -144,7 +119,7 @@ type upgradeStatements struct {
|
|||||||
type upgradeCheck struct {
|
type upgradeCheck struct {
|
||||||
// This json tag has a typo in it, the new version does not. This
|
// This json tag has a typo in it, the new version does not. This
|
||||||
// necessitates this upgrade logic.
|
// necessitates this upgrade logic.
|
||||||
Statements upgradeStatements `json:"statments"`
|
Statements *upgradeStatements `json:"statments,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) {
|
func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName string) (*roleEntry, error) {
|
||||||
@@ -166,48 +141,140 @@ func (b *databaseBackend) Role(ctx context.Context, s logical.Storage, roleName
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
empty := upgradeCheck{}
|
switch {
|
||||||
if upgradeCh != empty {
|
case upgradeCh.Statements != nil:
|
||||||
result.Statements.CreationStatements = upgradeCh.Statements.CreationStatements
|
var stmts dbplugin.Statements
|
||||||
result.Statements.RevocationStatements = upgradeCh.Statements.RevocationStatements
|
if upgradeCh.Statements.CreationStatements != "" {
|
||||||
result.Statements.RollbackStatements = upgradeCh.Statements.RollbackStatements
|
stmts.Creation = []string{upgradeCh.Statements.CreationStatements}
|
||||||
result.Statements.RenewStatements = upgradeCh.Statements.RenewStatements
|
}
|
||||||
|
if upgradeCh.Statements.RevocationStatements != "" {
|
||||||
|
stmts.Revocation = []string{upgradeCh.Statements.RevocationStatements}
|
||||||
|
}
|
||||||
|
if upgradeCh.Statements.RollbackStatements != "" {
|
||||||
|
stmts.Rollback = []string{upgradeCh.Statements.RollbackStatements}
|
||||||
|
}
|
||||||
|
if upgradeCh.Statements.RenewStatements != "" {
|
||||||
|
stmts.Renewal = []string{upgradeCh.Statements.RenewStatements}
|
||||||
|
}
|
||||||
|
result.Statements = stmts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For backwards compatibility, copy the values back into the string form
|
||||||
|
// of the fields
|
||||||
|
result.Statements = dbutil.StatementCompatibilityHelper(result.Statements)
|
||||||
|
|
||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *databaseBackend) invalidate(ctx context.Context, key string) {
|
func (b *databaseBackend) invalidate(ctx context.Context, key string) {
|
||||||
b.Lock()
|
|
||||||
defer b.Unlock()
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(key, databaseConfigPath):
|
case strings.HasPrefix(key, databaseConfigPath):
|
||||||
name := strings.TrimPrefix(key, databaseConfigPath)
|
name := strings.TrimPrefix(key, databaseConfigPath)
|
||||||
b.clearConnection(name)
|
b.ClearConnection(name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// clearConnection closes the database connection and
|
func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage, name string) (*dbPluginInstance, error) {
|
||||||
// removes it from the b.connections map.
|
b.RLock()
|
||||||
func (b *databaseBackend) clearConnection(name string) {
|
unlockFunc := b.RUnlock
|
||||||
|
defer func() { unlockFunc() }()
|
||||||
|
|
||||||
db, ok := b.connections[name]
|
db, ok := b.connections[name]
|
||||||
if ok {
|
if ok {
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade lock
|
||||||
|
b.RUnlock()
|
||||||
|
b.Lock()
|
||||||
|
unlockFunc = b.Unlock
|
||||||
|
|
||||||
|
db, ok = b.connections[name]
|
||||||
|
if ok {
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := b.DatabaseConfig(ctx, s, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dbp, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dbp.Init(ctx, config.ConnectionDetails, true)
|
||||||
|
if err != nil {
|
||||||
|
dbp.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := uuid.GenerateUUID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
db = &dbPluginInstance{
|
||||||
|
Database: dbp,
|
||||||
|
name: name,
|
||||||
|
id: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
b.connections[name] = db
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearConnection closes the database connection and
|
||||||
|
// removes it from the b.connections map.
|
||||||
|
func (b *databaseBackend) ClearConnection(name string) error {
|
||||||
|
b.Lock()
|
||||||
|
defer b.Unlock()
|
||||||
|
return b.clearConnection(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *databaseBackend) clearConnection(name string) error {
|
||||||
|
db, ok := b.connections[name]
|
||||||
|
if ok {
|
||||||
|
// Ignore error here since the database client is always killed
|
||||||
db.Close()
|
db.Close()
|
||||||
delete(b.connections, name)
|
delete(b.connections, name)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *databaseBackend) closeIfShutdown(name string, err error) {
|
func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) {
|
||||||
// Plugin has shutdown, close it so next call can reconnect.
|
// Plugin has shutdown, close it so next call can reconnect.
|
||||||
switch err {
|
switch err {
|
||||||
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
|
case rpc.ErrShutdown, dbplugin.ErrPluginShutdown:
|
||||||
b.Lock()
|
// Put this in a goroutine so that requests can run with the read or write lock
|
||||||
b.clearConnection(name)
|
// and simply defer the unlock. Since we are attaching the instance and matching
|
||||||
b.Unlock()
|
// the id in the conneciton map, we can safely do this.
|
||||||
|
go func() {
|
||||||
|
b.Lock()
|
||||||
|
defer b.Unlock()
|
||||||
|
db.Close()
|
||||||
|
|
||||||
|
// Ensure we are deleting the correct connection
|
||||||
|
mapDB, ok := b.connections[db.name]
|
||||||
|
if ok && db.id == mapDB.id {
|
||||||
|
delete(b.connections, db.name)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// closeAllDBs closes all connections from all database types
|
||||||
|
func (b *databaseBackend) closeAllDBs(ctx context.Context) {
|
||||||
|
b.Lock()
|
||||||
|
defer b.Unlock()
|
||||||
|
|
||||||
|
for _, db := range b.connections {
|
||||||
|
db.Close()
|
||||||
|
}
|
||||||
|
b.connections = make(map[string]*dbPluginInstance)
|
||||||
|
}
|
||||||
|
|
||||||
const backendHelp = `
|
const backendHelp = `
|
||||||
The database backend supports using many different databases
|
The database backend supports using many different databases
|
||||||
as secret backends, including but not limited to:
|
as secret backends, including but not limited to:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -27,6 +28,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) {
|
func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) {
|
||||||
|
t.Helper()
|
||||||
if os.Getenv("PG_URL") != "" {
|
if os.Getenv("PG_URL") != "" {
|
||||||
return func() {}, os.Getenv("PG_URL")
|
return func() {}, os.Getenv("PG_URL")
|
||||||
}
|
}
|
||||||
@@ -64,7 +66,7 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac
|
|||||||
})
|
})
|
||||||
if err != nil || (resp != nil && resp.IsError()) {
|
if err != nil || (resp != nil && resp.IsError()) {
|
||||||
// It's likely not up and running yet, so return error and try again
|
// It's likely not up and running yet, so return error and try again
|
||||||
return fmt.Errorf("err:%s resp:%#v\n", err, resp)
|
return fmt.Errorf("err:%#v resp:%#v", err, resp)
|
||||||
}
|
}
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
t.Fatal("expected warning")
|
t.Fatal("expected warning")
|
||||||
@@ -123,13 +125,18 @@ func TestBackend_RoleUpgrade(t *testing.T) {
|
|||||||
storage := &logical.InmemStorage{}
|
storage := &logical.InmemStorage{}
|
||||||
backend := &databaseBackend{}
|
backend := &databaseBackend{}
|
||||||
|
|
||||||
roleEnt := &roleEntry{
|
roleExpected := &roleEntry{
|
||||||
Statements: dbplugin.Statements{
|
Statements: dbplugin.Statements{
|
||||||
CreationStatements: "test",
|
CreationStatements: "test",
|
||||||
|
Creation: []string{"test"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, err := logical.StorageEntryJSON("role/test", roleEnt)
|
entry, err := logical.StorageEntryJSON("role/test", &roleEntry{
|
||||||
|
Statements: dbplugin.Statements{
|
||||||
|
CreationStatements: "test",
|
||||||
|
},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -142,8 +149,8 @@ func TestBackend_RoleUpgrade(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(role, roleEnt) {
|
if !reflect.DeepEqual(role, roleExpected) {
|
||||||
t.Fatalf("bad role %#v", role)
|
t.Fatalf("bad role %#v, %#v", role, roleExpected)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upgrade case
|
// Upgrade case
|
||||||
@@ -161,8 +168,8 @@ func TestBackend_RoleUpgrade(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(role, roleEnt) {
|
if !reflect.DeepEqual(role, roleExpected) {
|
||||||
t.Fatalf("bad role %#v", role)
|
t.Fatalf("bad role %#v, %#v", role, roleExpected)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -206,7 +213,8 @@ func TestBackend_config_connection(t *testing.T) {
|
|||||||
"connection_details": map[string]interface{}{
|
"connection_details": map[string]interface{}{
|
||||||
"connection_url": "sample_connection_url",
|
"connection_url": "sample_connection_url",
|
||||||
},
|
},
|
||||||
"allowed_roles": []string{"*"},
|
"allowed_roles": []string{"*"},
|
||||||
|
"root_credentials_rotate_statements": []string{},
|
||||||
}
|
}
|
||||||
configReq.Operation = logical.ReadOperation
|
configReq.Operation = logical.ReadOperation
|
||||||
resp, err = b.HandleRequest(context.Background(), configReq)
|
resp, err = b.HandleRequest(context.Background(), configReq)
|
||||||
@@ -233,6 +241,55 @@ func TestBackend_config_connection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBackend_BadConnectionString(t *testing.T) {
|
||||||
|
cluster, sys := getCluster(t)
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
config := logical.TestBackendConfig()
|
||||||
|
config.StorageView = &logical.InmemStorage{}
|
||||||
|
config.System = sys
|
||||||
|
|
||||||
|
b, err := Factory(context.Background(), config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer b.Cleanup(context.Background())
|
||||||
|
|
||||||
|
cleanup, _ := preparePostgresTestContainer(t, config.StorageView, b)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
respCheck := func(req *logical.Request) {
|
||||||
|
t.Helper()
|
||||||
|
resp, err := b.HandleRequest(context.Background(), req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if resp == nil || !resp.IsError() {
|
||||||
|
t.Fatalf("expected error, resp:%#v", resp)
|
||||||
|
}
|
||||||
|
err = resp.Error()
|
||||||
|
if strings.Contains(err.Error(), "localhost") {
|
||||||
|
t.Fatalf("error should not contain connection info")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure a connection
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"connection_url": "postgresql://:pw@[localhost",
|
||||||
|
"plugin_name": "postgresql-database-plugin",
|
||||||
|
"allowed_roles": []string{"plugin-role-test"},
|
||||||
|
}
|
||||||
|
req := &logical.Request{
|
||||||
|
Operation: logical.UpdateOperation,
|
||||||
|
Path: "config/plugin-test",
|
||||||
|
Storage: config.StorageView,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
respCheck(req)
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
func TestBackend_basic(t *testing.T) {
|
func TestBackend_basic(t *testing.T) {
|
||||||
cluster, sys := getCluster(t)
|
cluster, sys := getCluster(t)
|
||||||
defer cluster.Cleanup()
|
defer cluster.Cleanup()
|
||||||
@@ -388,7 +445,6 @@ func TestBackend_basic(t *testing.T) {
|
|||||||
if testCredsExist(t, credsResp, connURL) {
|
if testCredsExist(t, credsResp, connURL) {
|
||||||
t.Fatalf("Creds should not exist")
|
t.Fatalf("Creds should not exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBackend_connectionCrud(t *testing.T) {
|
func TestBackend_connectionCrud(t *testing.T) {
|
||||||
@@ -467,7 +523,8 @@ func TestBackend_connectionCrud(t *testing.T) {
|
|||||||
"connection_details": map[string]interface{}{
|
"connection_details": map[string]interface{}{
|
||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
},
|
},
|
||||||
"allowed_roles": []string{"plugin-role-test"},
|
"allowed_roles": []string{"plugin-role-test"},
|
||||||
|
"root_credentials_rotate_statements": []string{},
|
||||||
}
|
}
|
||||||
req.Operation = logical.ReadOperation
|
req.Operation = logical.ReadOperation
|
||||||
resp, err = b.HandleRequest(context.Background(), req)
|
resp, err = b.HandleRequest(context.Background(), req)
|
||||||
@@ -602,15 +659,15 @@ func TestBackend_roleCrud(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
expected := dbplugin.Statements{
|
expected := dbplugin.Statements{
|
||||||
CreationStatements: testRole,
|
Creation: []string{strings.TrimSpace(testRole)},
|
||||||
RevocationStatements: defaultRevocationSQL,
|
Revocation: []string{strings.TrimSpace(defaultRevocationSQL)},
|
||||||
}
|
}
|
||||||
|
|
||||||
actual := dbplugin.Statements{
|
actual := dbplugin.Statements{
|
||||||
CreationStatements: resp.Data["creation_statements"].(string),
|
Creation: resp.Data["creation_statements"].([]string),
|
||||||
RevocationStatements: resp.Data["revocation_statements"].(string),
|
Revocation: resp.Data["revocation_statements"].([]string),
|
||||||
RollbackStatements: resp.Data["rollback_statements"].(string),
|
Rollback: resp.Data["rollback_statements"].([]string),
|
||||||
RenewStatements: resp.Data["renew_statements"].(string),
|
Renewal: resp.Data["renew_statements"].([]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, actual) {
|
if !reflect.DeepEqual(expected, actual) {
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go.
|
||||||
// source: builtin/logical/database/dbplugin/database.proto
|
// source: builtin/logical/database/dbplugin/database.proto
|
||||||
|
// DO NOT EDIT!
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Package dbplugin is a generated protocol buffer package.
|
Package dbplugin is a generated protocol buffer package.
|
||||||
@@ -9,13 +10,17 @@ It is generated from these files:
|
|||||||
|
|
||||||
It has these top-level messages:
|
It has these top-level messages:
|
||||||
InitializeRequest
|
InitializeRequest
|
||||||
|
InitRequest
|
||||||
CreateUserRequest
|
CreateUserRequest
|
||||||
RenewUserRequest
|
RenewUserRequest
|
||||||
RevokeUserRequest
|
RevokeUserRequest
|
||||||
|
RotateRootCredentialsRequest
|
||||||
Statements
|
Statements
|
||||||
UsernameConfig
|
UsernameConfig
|
||||||
|
InitResponse
|
||||||
CreateUserResponse
|
CreateUserResponse
|
||||||
TypeResponse
|
TypeResponse
|
||||||
|
RotateRootCredentialsResponse
|
||||||
Empty
|
Empty
|
||||||
*/
|
*/
|
||||||
package dbplugin
|
package dbplugin
|
||||||
@@ -65,6 +70,30 @@ func (m *InitializeRequest) GetVerifyConnection() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InitRequest struct {
|
||||||
|
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
|
||||||
|
VerifyConnection bool `protobuf:"varint,2,opt,name=verify_connection,json=verifyConnection" json:"verify_connection,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *InitRequest) Reset() { *m = InitRequest{} }
|
||||||
|
func (m *InitRequest) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*InitRequest) ProtoMessage() {}
|
||||||
|
func (*InitRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
|
||||||
|
|
||||||
|
func (m *InitRequest) GetConfig() []byte {
|
||||||
|
if m != nil {
|
||||||
|
return m.Config
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *InitRequest) GetVerifyConnection() bool {
|
||||||
|
if m != nil {
|
||||||
|
return m.VerifyConnection
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type CreateUserRequest struct {
|
type CreateUserRequest struct {
|
||||||
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
Statements *Statements `protobuf:"bytes,1,opt,name=statements" json:"statements,omitempty"`
|
||||||
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
|
UsernameConfig *UsernameConfig `protobuf:"bytes,2,opt,name=username_config,json=usernameConfig" json:"username_config,omitempty"`
|
||||||
@@ -74,7 +103,7 @@ type CreateUserRequest struct {
|
|||||||
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
|
func (m *CreateUserRequest) Reset() { *m = CreateUserRequest{} }
|
||||||
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
|
func (m *CreateUserRequest) String() string { return proto.CompactTextString(m) }
|
||||||
func (*CreateUserRequest) ProtoMessage() {}
|
func (*CreateUserRequest) ProtoMessage() {}
|
||||||
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
|
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
||||||
|
|
||||||
func (m *CreateUserRequest) GetStatements() *Statements {
|
func (m *CreateUserRequest) GetStatements() *Statements {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
@@ -106,7 +135,7 @@ type RenewUserRequest struct {
|
|||||||
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
|
func (m *RenewUserRequest) Reset() { *m = RenewUserRequest{} }
|
||||||
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
|
func (m *RenewUserRequest) String() string { return proto.CompactTextString(m) }
|
||||||
func (*RenewUserRequest) ProtoMessage() {}
|
func (*RenewUserRequest) ProtoMessage() {}
|
||||||
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
func (*RenewUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
|
||||||
|
|
||||||
func (m *RenewUserRequest) GetStatements() *Statements {
|
func (m *RenewUserRequest) GetStatements() *Statements {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
@@ -137,7 +166,7 @@ type RevokeUserRequest struct {
|
|||||||
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
|
func (m *RevokeUserRequest) Reset() { *m = RevokeUserRequest{} }
|
||||||
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
|
func (m *RevokeUserRequest) String() string { return proto.CompactTextString(m) }
|
||||||
func (*RevokeUserRequest) ProtoMessage() {}
|
func (*RevokeUserRequest) ProtoMessage() {}
|
||||||
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
|
func (*RevokeUserRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
|
||||||
|
|
||||||
func (m *RevokeUserRequest) GetStatements() *Statements {
|
func (m *RevokeUserRequest) GetStatements() *Statements {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
@@ -153,17 +182,41 @@ func (m *RevokeUserRequest) GetUsername() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RotateRootCredentialsRequest struct {
|
||||||
|
Statements []string `protobuf:"bytes,1,rep,name=statements" json:"statements,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RotateRootCredentialsRequest) Reset() { *m = RotateRootCredentialsRequest{} }
|
||||||
|
func (m *RotateRootCredentialsRequest) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*RotateRootCredentialsRequest) ProtoMessage() {}
|
||||||
|
func (*RotateRootCredentialsRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
|
||||||
|
|
||||||
|
func (m *RotateRootCredentialsRequest) GetStatements() []string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Statements
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Statements struct {
|
type Statements struct {
|
||||||
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
|
// DEPRECATED, will be removed in 0.12
|
||||||
|
CreationStatements string `protobuf:"bytes,1,opt,name=creation_statements,json=creationStatements" json:"creation_statements,omitempty"`
|
||||||
|
// DEPRECATED, will be removed in 0.12
|
||||||
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
|
RevocationStatements string `protobuf:"bytes,2,opt,name=revocation_statements,json=revocationStatements" json:"revocation_statements,omitempty"`
|
||||||
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
|
// DEPRECATED, will be removed in 0.12
|
||||||
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
|
RollbackStatements string `protobuf:"bytes,3,opt,name=rollback_statements,json=rollbackStatements" json:"rollback_statements,omitempty"`
|
||||||
|
// DEPRECATED, will be removed in 0.12
|
||||||
|
RenewStatements string `protobuf:"bytes,4,opt,name=renew_statements,json=renewStatements" json:"renew_statements,omitempty"`
|
||||||
|
Creation []string `protobuf:"bytes,5,rep,name=creation" json:"creation,omitempty"`
|
||||||
|
Revocation []string `protobuf:"bytes,6,rep,name=revocation" json:"revocation,omitempty"`
|
||||||
|
Rollback []string `protobuf:"bytes,7,rep,name=rollback" json:"rollback,omitempty"`
|
||||||
|
Renewal []string `protobuf:"bytes,8,rep,name=renewal" json:"renewal,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Statements) Reset() { *m = Statements{} }
|
func (m *Statements) Reset() { *m = Statements{} }
|
||||||
func (m *Statements) String() string { return proto.CompactTextString(m) }
|
func (m *Statements) String() string { return proto.CompactTextString(m) }
|
||||||
func (*Statements) ProtoMessage() {}
|
func (*Statements) ProtoMessage() {}
|
||||||
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
|
func (*Statements) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
|
||||||
|
|
||||||
func (m *Statements) GetCreationStatements() string {
|
func (m *Statements) GetCreationStatements() string {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
@@ -193,6 +246,34 @@ func (m *Statements) GetRenewStatements() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Statements) GetCreation() []string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Creation
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Statements) GetRevocation() []string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Revocation
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Statements) GetRollback() []string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Rollback
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Statements) GetRenewal() []string {
|
||||||
|
if m != nil {
|
||||||
|
return m.Renewal
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type UsernameConfig struct {
|
type UsernameConfig struct {
|
||||||
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
|
DisplayName string `protobuf:"bytes,1,opt,name=DisplayName" json:"DisplayName,omitempty"`
|
||||||
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
|
RoleName string `protobuf:"bytes,2,opt,name=RoleName" json:"RoleName,omitempty"`
|
||||||
@@ -201,7 +282,7 @@ type UsernameConfig struct {
|
|||||||
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
|
func (m *UsernameConfig) Reset() { *m = UsernameConfig{} }
|
||||||
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
|
func (m *UsernameConfig) String() string { return proto.CompactTextString(m) }
|
||||||
func (*UsernameConfig) ProtoMessage() {}
|
func (*UsernameConfig) ProtoMessage() {}
|
||||||
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
|
func (*UsernameConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
|
||||||
|
|
||||||
func (m *UsernameConfig) GetDisplayName() string {
|
func (m *UsernameConfig) GetDisplayName() string {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
@@ -217,6 +298,22 @@ func (m *UsernameConfig) GetRoleName() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InitResponse struct {
|
||||||
|
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *InitResponse) Reset() { *m = InitResponse{} }
|
||||||
|
func (m *InitResponse) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*InitResponse) ProtoMessage() {}
|
||||||
|
func (*InitResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
|
||||||
|
|
||||||
|
func (m *InitResponse) GetConfig() []byte {
|
||||||
|
if m != nil {
|
||||||
|
return m.Config
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type CreateUserResponse struct {
|
type CreateUserResponse struct {
|
||||||
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
|
Username string `protobuf:"bytes,1,opt,name=username" json:"username,omitempty"`
|
||||||
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
|
Password string `protobuf:"bytes,2,opt,name=password" json:"password,omitempty"`
|
||||||
@@ -225,7 +322,7 @@ type CreateUserResponse struct {
|
|||||||
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
|
func (m *CreateUserResponse) Reset() { *m = CreateUserResponse{} }
|
||||||
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
|
func (m *CreateUserResponse) String() string { return proto.CompactTextString(m) }
|
||||||
func (*CreateUserResponse) ProtoMessage() {}
|
func (*CreateUserResponse) ProtoMessage() {}
|
||||||
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
|
func (*CreateUserResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{9} }
|
||||||
|
|
||||||
func (m *CreateUserResponse) GetUsername() string {
|
func (m *CreateUserResponse) GetUsername() string {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
@@ -248,7 +345,7 @@ type TypeResponse struct {
|
|||||||
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
|
func (m *TypeResponse) Reset() { *m = TypeResponse{} }
|
||||||
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
|
func (m *TypeResponse) String() string { return proto.CompactTextString(m) }
|
||||||
func (*TypeResponse) ProtoMessage() {}
|
func (*TypeResponse) ProtoMessage() {}
|
||||||
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
|
func (*TypeResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{10} }
|
||||||
|
|
||||||
func (m *TypeResponse) GetType() string {
|
func (m *TypeResponse) GetType() string {
|
||||||
if m != nil {
|
if m != nil {
|
||||||
@@ -257,23 +354,43 @@ func (m *TypeResponse) GetType() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RotateRootCredentialsResponse struct {
|
||||||
|
Config []byte `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RotateRootCredentialsResponse) Reset() { *m = RotateRootCredentialsResponse{} }
|
||||||
|
func (m *RotateRootCredentialsResponse) String() string { return proto.CompactTextString(m) }
|
||||||
|
func (*RotateRootCredentialsResponse) ProtoMessage() {}
|
||||||
|
func (*RotateRootCredentialsResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{11} }
|
||||||
|
|
||||||
|
func (m *RotateRootCredentialsResponse) GetConfig() []byte {
|
||||||
|
if m != nil {
|
||||||
|
return m.Config
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Empty struct {
|
type Empty struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Empty) Reset() { *m = Empty{} }
|
func (m *Empty) Reset() { *m = Empty{} }
|
||||||
func (m *Empty) String() string { return proto.CompactTextString(m) }
|
func (m *Empty) String() string { return proto.CompactTextString(m) }
|
||||||
func (*Empty) ProtoMessage() {}
|
func (*Empty) ProtoMessage() {}
|
||||||
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{8} }
|
func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{12} }
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
|
proto.RegisterType((*InitializeRequest)(nil), "dbplugin.InitializeRequest")
|
||||||
|
proto.RegisterType((*InitRequest)(nil), "dbplugin.InitRequest")
|
||||||
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
|
proto.RegisterType((*CreateUserRequest)(nil), "dbplugin.CreateUserRequest")
|
||||||
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
|
proto.RegisterType((*RenewUserRequest)(nil), "dbplugin.RenewUserRequest")
|
||||||
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
|
proto.RegisterType((*RevokeUserRequest)(nil), "dbplugin.RevokeUserRequest")
|
||||||
|
proto.RegisterType((*RotateRootCredentialsRequest)(nil), "dbplugin.RotateRootCredentialsRequest")
|
||||||
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
|
proto.RegisterType((*Statements)(nil), "dbplugin.Statements")
|
||||||
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
|
proto.RegisterType((*UsernameConfig)(nil), "dbplugin.UsernameConfig")
|
||||||
|
proto.RegisterType((*InitResponse)(nil), "dbplugin.InitResponse")
|
||||||
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
|
proto.RegisterType((*CreateUserResponse)(nil), "dbplugin.CreateUserResponse")
|
||||||
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
|
proto.RegisterType((*TypeResponse)(nil), "dbplugin.TypeResponse")
|
||||||
|
proto.RegisterType((*RotateRootCredentialsResponse)(nil), "dbplugin.RotateRootCredentialsResponse")
|
||||||
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
|
proto.RegisterType((*Empty)(nil), "dbplugin.Empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,8 +409,10 @@ type DatabaseClient interface {
|
|||||||
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
|
CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
|
||||||
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
RenewUser(ctx context.Context, in *RenewUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||||
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||||
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
|
RotateRootCredentials(ctx context.Context, in *RotateRootCredentialsRequest, opts ...grpc.CallOption) (*RotateRootCredentialsResponse, error)
|
||||||
|
Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error)
|
||||||
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
|
Close(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
|
||||||
|
Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type databaseClient struct {
|
type databaseClient struct {
|
||||||
@@ -340,9 +459,18 @@ func (c *databaseClient) RevokeUser(ctx context.Context, in *RevokeUserRequest,
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
|
func (c *databaseClient) RotateRootCredentials(ctx context.Context, in *RotateRootCredentialsRequest, opts ...grpc.CallOption) (*RotateRootCredentialsResponse, error) {
|
||||||
out := new(Empty)
|
out := new(RotateRootCredentialsResponse)
|
||||||
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
|
err := grpc.Invoke(ctx, "/dbplugin.Database/RotateRootCredentials", in, out, c.cc, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *databaseClient) Init(ctx context.Context, in *InitRequest, opts ...grpc.CallOption) (*InitResponse, error) {
|
||||||
|
out := new(InitResponse)
|
||||||
|
err := grpc.Invoke(ctx, "/dbplugin.Database/Init", in, out, c.cc, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -358,6 +486,15 @@ func (c *databaseClient) Close(ctx context.Context, in *Empty, opts ...grpc.Call
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *databaseClient) Initialize(ctx context.Context, in *InitializeRequest, opts ...grpc.CallOption) (*Empty, error) {
|
||||||
|
out := new(Empty)
|
||||||
|
err := grpc.Invoke(ctx, "/dbplugin.Database/Initialize", in, out, c.cc, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Server API for Database service
|
// Server API for Database service
|
||||||
|
|
||||||
type DatabaseServer interface {
|
type DatabaseServer interface {
|
||||||
@@ -365,8 +502,10 @@ type DatabaseServer interface {
|
|||||||
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
|
CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error)
|
||||||
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
|
RenewUser(context.Context, *RenewUserRequest) (*Empty, error)
|
||||||
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
|
RevokeUser(context.Context, *RevokeUserRequest) (*Empty, error)
|
||||||
Initialize(context.Context, *InitializeRequest) (*Empty, error)
|
RotateRootCredentials(context.Context, *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error)
|
||||||
|
Init(context.Context, *InitRequest) (*InitResponse, error)
|
||||||
Close(context.Context, *Empty) (*Empty, error)
|
Close(context.Context, *Empty) (*Empty, error)
|
||||||
|
Initialize(context.Context, *InitializeRequest) (*Empty, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
|
func RegisterDatabaseServer(s *grpc.Server, srv DatabaseServer) {
|
||||||
@@ -445,20 +584,38 @@ func _Database_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _Database_RotateRootCredentials_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(InitializeRequest)
|
in := new(RotateRootCredentialsRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if interceptor == nil {
|
if interceptor == nil {
|
||||||
return srv.(DatabaseServer).Initialize(ctx, in)
|
return srv.(DatabaseServer).RotateRootCredentials(ctx, in)
|
||||||
}
|
}
|
||||||
info := &grpc.UnaryServerInfo{
|
info := &grpc.UnaryServerInfo{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
FullMethod: "/dbplugin.Database/Initialize",
|
FullMethod: "/dbplugin.Database/RotateRootCredentials",
|
||||||
}
|
}
|
||||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
|
return srv.(DatabaseServer).RotateRootCredentials(ctx, req.(*RotateRootCredentialsRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _Database_Init_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(InitRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DatabaseServer).Init(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/dbplugin.Database/Init",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DatabaseServer).Init(ctx, req.(*InitRequest))
|
||||||
}
|
}
|
||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
@@ -481,6 +638,24 @@ func _Database_Close_Handler(srv interface{}, ctx context.Context, dec func(inte
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _Database_Initialize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(InitializeRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DatabaseServer).Initialize(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/dbplugin.Database/Initialize",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DatabaseServer).Initialize(ctx, req.(*InitializeRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
var _Database_serviceDesc = grpc.ServiceDesc{
|
var _Database_serviceDesc = grpc.ServiceDesc{
|
||||||
ServiceName: "dbplugin.Database",
|
ServiceName: "dbplugin.Database",
|
||||||
HandlerType: (*DatabaseServer)(nil),
|
HandlerType: (*DatabaseServer)(nil),
|
||||||
@@ -502,13 +677,21 @@ var _Database_serviceDesc = grpc.ServiceDesc{
|
|||||||
Handler: _Database_RevokeUser_Handler,
|
Handler: _Database_RevokeUser_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "Initialize",
|
MethodName: "RotateRootCredentials",
|
||||||
Handler: _Database_Initialize_Handler,
|
Handler: _Database_RotateRootCredentials_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "Init",
|
||||||
|
Handler: _Database_Init_Handler,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
MethodName: "Close",
|
MethodName: "Close",
|
||||||
Handler: _Database_Close_Handler,
|
Handler: _Database_Close_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "Initialize",
|
||||||
|
Handler: _Database_Initialize_Handler,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{},
|
Streams: []grpc.StreamDesc{},
|
||||||
Metadata: "builtin/logical/database/dbplugin/database.proto",
|
Metadata: "builtin/logical/database/dbplugin/database.proto",
|
||||||
@@ -517,40 +700,49 @@ var _Database_serviceDesc = grpc.ServiceDesc{
|
|||||||
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
|
func init() { proto.RegisterFile("builtin/logical/database/dbplugin/database.proto", fileDescriptor0) }
|
||||||
|
|
||||||
var fileDescriptor0 = []byte{
|
var fileDescriptor0 = []byte{
|
||||||
// 548 bytes of a gzipped FileDescriptorProto
|
// 690 bytes of a gzipped FileDescriptorProto
|
||||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x54, 0xcf, 0x6e, 0xd3, 0x4e,
|
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0x55, 0x41, 0x4f, 0xdb, 0x4a,
|
||||||
0x10, 0x96, 0xdb, 0xb4, 0xbf, 0x64, 0x5a, 0x35, 0xc9, 0xfe, 0x4a, 0x15, 0x19, 0x24, 0x22, 0x9f,
|
0x10, 0x96, 0x93, 0x00, 0xc9, 0x80, 0x80, 0xec, 0x03, 0x64, 0xf9, 0xf1, 0xde, 0x43, 0x3e, 0xf0,
|
||||||
0x5a, 0x21, 0xd9, 0xa8, 0xe5, 0x80, 0xb8, 0xa1, 0x14, 0x21, 0x24, 0x94, 0x83, 0x69, 0x25, 0x6e,
|
0x40, 0x95, 0xe2, 0x0a, 0x5a, 0xb5, 0xe2, 0xd0, 0xaa, 0x0a, 0x55, 0x55, 0xa9, 0xe2, 0xb0, 0xc0,
|
||||||
0xd1, 0xda, 0x99, 0x44, 0xab, 0x3a, 0xbb, 0xc6, 0xbb, 0x4e, 0x09, 0x4f, 0xc3, 0xe3, 0x70, 0xe2,
|
0xad, 0x12, 0xda, 0x38, 0x43, 0xba, 0xc2, 0xf1, 0xba, 0xde, 0x0d, 0x34, 0xfd, 0x03, 0xed, 0xcf,
|
||||||
0x1d, 0x78, 0x13, 0xe4, 0x75, 0xd6, 0xbb, 0xf9, 0x73, 0xab, 0xb8, 0x79, 0x66, 0xbe, 0x6f, 0xf6,
|
0xe8, 0x4f, 0xe9, 0xb1, 0x3f, 0xab, 0xf2, 0xda, 0x6b, 0x6f, 0x62, 0x28, 0x07, 0xda, 0x9b, 0x67,
|
||||||
0xf3, 0xb7, 0x33, 0x0b, 0xaf, 0x93, 0x92, 0x65, 0x8a, 0xf1, 0x28, 0x13, 0x73, 0x96, 0xd2, 0x2c,
|
0xe6, 0xfb, 0x66, 0xbe, 0x9d, 0x9d, 0x59, 0xc3, 0xe3, 0xc1, 0x84, 0x47, 0x8a, 0xc7, 0x41, 0x24,
|
||||||
0x9a, 0x52, 0x45, 0x13, 0x2a, 0x31, 0x9a, 0x26, 0x79, 0x56, 0xce, 0x19, 0x6f, 0x32, 0x61, 0x5e,
|
0x46, 0x3c, 0x64, 0x51, 0x30, 0x64, 0x8a, 0x0d, 0x98, 0xc4, 0x60, 0x38, 0x48, 0xa2, 0xc9, 0x88,
|
||||||
0x08, 0x25, 0x48, 0xdb, 0x14, 0xfc, 0x97, 0x73, 0x21, 0xe6, 0x19, 0x46, 0x3a, 0x9f, 0x94, 0xb3,
|
0xc7, 0xa5, 0xa7, 0x97, 0xa4, 0x42, 0x09, 0xd2, 0x36, 0x01, 0xef, 0xbf, 0x91, 0x10, 0xa3, 0x08,
|
||||||
0x48, 0xb1, 0x05, 0x4a, 0x45, 0x17, 0x79, 0x0d, 0x0d, 0xbe, 0x42, 0xff, 0x13, 0x67, 0x8a, 0xd1,
|
0x03, 0xed, 0x1f, 0x4c, 0x2e, 0x03, 0xc5, 0xc7, 0x28, 0x15, 0x1b, 0x27, 0x39, 0xd4, 0x7f, 0x0f,
|
||||||
0x8c, 0xfd, 0xc0, 0x18, 0xbf, 0x95, 0x28, 0x15, 0xb9, 0x80, 0xe3, 0x54, 0xf0, 0x19, 0x9b, 0x0f,
|
0xdd, 0xb7, 0x31, 0x57, 0x9c, 0x45, 0xfc, 0x33, 0x52, 0xfc, 0x38, 0x41, 0xa9, 0xc8, 0x16, 0x2c,
|
||||||
0xbc, 0xa1, 0x77, 0x79, 0x1a, 0xaf, 0x23, 0xf2, 0x0a, 0xfa, 0x4b, 0x2c, 0xd8, 0x6c, 0x35, 0x49,
|
0x86, 0x22, 0xbe, 0xe4, 0x23, 0xd7, 0xd9, 0x71, 0xf6, 0x56, 0x68, 0x61, 0x91, 0x47, 0xd0, 0xbd,
|
||||||
0x05, 0xe7, 0x98, 0x2a, 0x26, 0xf8, 0xe0, 0x60, 0xe8, 0x5d, 0xb6, 0xe3, 0x5e, 0x5d, 0x18, 0x35,
|
0xc6, 0x94, 0x5f, 0x4e, 0x2f, 0x42, 0x11, 0xc7, 0x18, 0x2a, 0x2e, 0x62, 0xb7, 0xb1, 0xe3, 0xec,
|
||||||
0xf9, 0xe0, 0x97, 0x07, 0xfd, 0x51, 0x81, 0x54, 0xe1, 0xbd, 0xc4, 0xc2, 0xb4, 0x7e, 0x03, 0x20,
|
0xb5, 0xe9, 0x7a, 0x1e, 0xe8, 0x97, 0xfe, 0xa3, 0x86, 0xeb, 0xf8, 0x14, 0x96, 0xb3, 0xec, 0xbf,
|
||||||
0x15, 0x55, 0xb8, 0x40, 0xae, 0xa4, 0x6e, 0x7f, 0x72, 0x7d, 0x1e, 0x1a, 0xbd, 0xe1, 0x97, 0xa6,
|
0x33, 0xaf, 0xff, 0xc3, 0x81, 0x6e, 0x3f, 0x45, 0xa6, 0xf0, 0x5c, 0x62, 0x6a, 0x52, 0x3f, 0x01,
|
||||||
0x16, 0x3b, 0x38, 0xf2, 0x1e, 0xba, 0xa5, 0xc4, 0x82, 0xd3, 0x05, 0x4e, 0xd6, 0xca, 0x0e, 0x34,
|
0x90, 0x8a, 0x29, 0x1c, 0x63, 0xac, 0xa4, 0x4e, 0xbf, 0x7c, 0xb0, 0xd1, 0x33, 0x7d, 0xe8, 0x9d,
|
||||||
0x75, 0x60, 0xa9, 0xf7, 0x6b, 0xc0, 0x48, 0xd7, 0xe3, 0xb3, 0x72, 0x23, 0x26, 0xef, 0x00, 0xf0,
|
0x96, 0x31, 0x6a, 0xe1, 0xc8, 0x2b, 0x58, 0x9b, 0x48, 0x4c, 0x63, 0x36, 0xc6, 0x8b, 0x42, 0x59,
|
||||||
0x7b, 0xce, 0x0a, 0xaa, 0x45, 0x1f, 0x6a, 0xb6, 0x1f, 0xd6, 0xf6, 0x84, 0xc6, 0x9e, 0xf0, 0xce,
|
0x43, 0x53, 0xdd, 0x8a, 0x7a, 0x5e, 0x00, 0xfa, 0x3a, 0x4e, 0x57, 0x27, 0x33, 0x36, 0x39, 0x02,
|
||||||
0xd8, 0x13, 0x3b, 0xe8, 0xe0, 0xa7, 0x07, 0xbd, 0x18, 0x39, 0x3e, 0x3e, 0xfd, 0x4f, 0x7c, 0x68,
|
0xc0, 0x4f, 0x09, 0x4f, 0x99, 0x16, 0xdd, 0xd4, 0x6c, 0xaf, 0x97, 0xb7, 0xbd, 0x67, 0xda, 0xde,
|
||||||
0x1b, 0x61, 0xfa, 0x17, 0x3a, 0x71, 0x13, 0x3f, 0x49, 0x22, 0x42, 0x3f, 0xc6, 0xa5, 0x78, 0xc0,
|
0x3b, 0x33, 0x6d, 0xa7, 0x16, 0xda, 0xff, 0xe6, 0xc0, 0x3a, 0xc5, 0x18, 0x6f, 0x1e, 0x7e, 0x12,
|
||||||
0x7f, 0x2a, 0x31, 0xf8, 0xed, 0x01, 0x58, 0x1a, 0x89, 0xe0, 0xff, 0xb4, 0xba, 0x62, 0x26, 0xf8,
|
0x0f, 0xda, 0x46, 0x98, 0x3e, 0x42, 0x87, 0x96, 0xf6, 0x83, 0x24, 0x22, 0x74, 0x29, 0x5e, 0x8b,
|
||||||
0x64, 0xeb, 0xa4, 0x4e, 0x4c, 0x4c, 0xc9, 0x21, 0xdc, 0xc0, 0xb3, 0x02, 0x97, 0x22, 0xdd, 0xa1,
|
0x2b, 0xfc, 0xa3, 0x12, 0xfd, 0x17, 0xb0, 0x4d, 0x45, 0x06, 0xa5, 0x42, 0xa8, 0x7e, 0x8a, 0x43,
|
||||||
0xd4, 0x07, 0x9d, 0xdb, 0xe2, 0xe6, 0x29, 0x85, 0xc8, 0xb2, 0x84, 0xa6, 0x0f, 0x2e, 0xe5, 0xb0,
|
0x8c, 0xb3, 0x99, 0x94, 0xa6, 0xe2, 0xbf, 0x73, 0x15, 0x9b, 0x7b, 0x1d, 0x3b, 0xb7, 0xff, 0xbd,
|
||||||
0x3e, 0xc5, 0x94, 0x1c, 0xc2, 0x15, 0xf4, 0x8a, 0xea, 0xba, 0x5c, 0x74, 0x4b, 0xa3, 0xbb, 0x3a,
|
0x01, 0x50, 0x95, 0x25, 0x01, 0xfc, 0x15, 0x66, 0x23, 0xc2, 0x45, 0x7c, 0x31, 0xa7, 0xb4, 0x43,
|
||||||
0x6f, 0xa1, 0xc1, 0x18, 0xce, 0x36, 0x07, 0x87, 0x0c, 0xe1, 0xe4, 0x96, 0xc9, 0x3c, 0xa3, 0xab,
|
0x89, 0x09, 0x59, 0x84, 0x43, 0xd8, 0x4c, 0xf1, 0x5a, 0x84, 0x35, 0x4a, 0x2e, 0x74, 0xa3, 0x0a,
|
||||||
0x71, 0xe5, 0x40, 0xfd, 0x2f, 0x6e, 0xaa, 0x32, 0x28, 0x16, 0x19, 0x8e, 0x1d, 0x83, 0x4c, 0x1c,
|
0xce, 0x56, 0x49, 0x45, 0x14, 0x0d, 0x58, 0x78, 0x65, 0x53, 0x9a, 0x79, 0x15, 0x13, 0xb2, 0x08,
|
||||||
0x7c, 0x06, 0xe2, 0x0e, 0xbd, 0xcc, 0x05, 0x97, 0xb8, 0x61, 0xa9, 0xb7, 0x75, 0xeb, 0x3e, 0xb4,
|
0xfb, 0xb0, 0x9e, 0x66, 0xd7, 0x6d, 0xa3, 0x5b, 0x1a, 0xbd, 0xa6, 0xfd, 0xa7, 0x33, 0xcd, 0x32,
|
||||||
0x73, 0x2a, 0xe5, 0xa3, 0x28, 0xa6, 0xa6, 0x9b, 0x89, 0x83, 0x00, 0x4e, 0xef, 0x56, 0x39, 0x36,
|
0x32, 0xdd, 0x05, 0x7d, 0xdc, 0xd2, 0xce, 0x9a, 0x51, 0xe9, 0x71, 0x17, 0xf3, 0x66, 0x54, 0x9e,
|
||||||
0x7d, 0x08, 0xb4, 0xd4, 0x2a, 0x37, 0x3d, 0xf4, 0x77, 0xf0, 0x1f, 0x1c, 0x7d, 0x58, 0xe4, 0x6a,
|
0x8c, 0x6b, 0x8a, 0xbb, 0x4b, 0x39, 0xd7, 0xd8, 0xc4, 0x85, 0x25, 0x5d, 0x8a, 0x45, 0x6e, 0x5b,
|
||||||
0x75, 0xfd, 0xe7, 0x00, 0xda, 0xb7, 0xeb, 0x87, 0x80, 0x44, 0xd0, 0xaa, 0x98, 0xa4, 0x6b, 0xaf,
|
0x87, 0x8c, 0xe9, 0x9f, 0xc0, 0xea, 0xec, 0xa8, 0x93, 0x1d, 0x58, 0x3e, 0xe6, 0x32, 0x89, 0xd8,
|
||||||
0x5b, 0xa3, 0xfc, 0x0b, 0x9b, 0xd8, 0x68, 0xfd, 0x11, 0xc0, 0x0a, 0x27, 0xcf, 0x2d, 0x6a, 0x67,
|
0xf4, 0x24, 0xbb, 0xb3, 0xbc, 0x7b, 0xb6, 0x2b, 0xab, 0x44, 0x45, 0x84, 0x27, 0xd6, 0x95, 0x1a,
|
||||||
0x87, 0xfd, 0x17, 0xfb, 0x8b, 0xeb, 0x46, 0x6f, 0xa1, 0xd3, 0xec, 0x0a, 0xf1, 0x2d, 0x74, 0x7b,
|
0xdb, 0xdf, 0x85, 0x95, 0x7c, 0xf7, 0x65, 0x22, 0x62, 0x89, 0x77, 0x2d, 0xbf, 0xff, 0x0e, 0x88,
|
||||||
0x81, 0xfc, 0x6d, 0x69, 0xd5, 0xfc, 0xdb, 0x19, 0x76, 0x25, 0xec, 0x4c, 0xf6, 0x5e, 0xae, 0x7d,
|
0xbd, 0xce, 0x05, 0xda, 0x1e, 0x16, 0x67, 0x6e, 0x9e, 0x3d, 0x68, 0x27, 0x4c, 0xca, 0x1b, 0x91,
|
||||||
0xc7, 0x5c, 0xee, 0xce, 0xeb, 0xb6, 0xcb, 0xbd, 0x82, 0xa3, 0x51, 0x26, 0xe4, 0x1e, 0xb3, 0xb6,
|
0x0e, 0x4d, 0x55, 0x63, 0xfb, 0x3e, 0xac, 0x9c, 0x4d, 0x13, 0x2c, 0xf3, 0x10, 0x68, 0xa9, 0x69,
|
||||||
0x13, 0xc9, 0xb1, 0x5e, 0xc3, 0x9b, 0xbf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x55, 0x84, 0x56,
|
0x62, 0x72, 0xe8, 0x6f, 0xff, 0x19, 0xfc, 0x73, 0xc7, 0xb0, 0xdd, 0x23, 0x75, 0x09, 0x16, 0x5e,
|
||||||
0x94, 0x05, 0x00, 0x00,
|
0x8f, 0x13, 0x35, 0x3d, 0xf8, 0xd2, 0x82, 0xf6, 0x71, 0xf1, 0xe6, 0x92, 0x00, 0x5a, 0x59, 0x49,
|
||||||
|
0xb2, 0x56, 0x6d, 0x80, 0x46, 0x79, 0x5b, 0x95, 0x63, 0x46, 0xd3, 0x1b, 0x80, 0xea, 0xc4, 0xe4,
|
||||||
|
0xef, 0x0a, 0x55, 0x7b, 0xd6, 0xbc, 0xed, 0xdb, 0x83, 0x45, 0xa2, 0xe7, 0xd0, 0x29, 0x9f, 0x0f,
|
||||||
|
0xe2, 0x55, 0xd0, 0xf9, 0x37, 0xc5, 0x9b, 0x97, 0x96, 0x3d, 0x09, 0xd5, 0x5a, 0xdb, 0x12, 0x6a,
|
||||||
|
0xcb, 0x5e, 0xe7, 0x7e, 0x80, 0xcd, 0x5b, 0xdb, 0x47, 0x76, 0xad, 0x34, 0xbf, 0x58, 0x66, 0xef,
|
||||||
|
0xff, 0x7b, 0x71, 0xc5, 0xf9, 0x9e, 0x42, 0x2b, 0x1b, 0x21, 0xb2, 0x59, 0x11, 0xac, 0xdf, 0x89,
|
||||||
|
0xdd, 0xdf, 0x99, 0x49, 0xdb, 0x87, 0x85, 0x7e, 0x24, 0xe4, 0x2d, 0x37, 0x52, 0x3b, 0xcb, 0x4b,
|
||||||
|
0x80, 0xea, 0xf7, 0x67, 0xf7, 0xa1, 0xf6, 0x53, 0xac, 0x71, 0xfd, 0xe6, 0xd7, 0x86, 0x33, 0x58,
|
||||||
|
0xd4, 0xef, 0xe7, 0xe1, 0xcf, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa7, 0x13, 0xfe, 0x55, 0xa5, 0x07,
|
||||||
|
0x00, 0x00,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,12 @@ package dbplugin;
|
|||||||
import "google/protobuf/timestamp.proto";
|
import "google/protobuf/timestamp.proto";
|
||||||
|
|
||||||
message InitializeRequest {
|
message InitializeRequest {
|
||||||
|
option deprecated = true;
|
||||||
|
bytes config = 1;
|
||||||
|
bool verify_connection = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message InitRequest {
|
||||||
bytes config = 1;
|
bytes config = 1;
|
||||||
bool verify_connection = 2;
|
bool verify_connection = 2;
|
||||||
}
|
}
|
||||||
@@ -25,11 +31,24 @@ message RevokeUserRequest {
|
|||||||
string username = 2;
|
string username = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message RotateRootCredentialsRequest {
|
||||||
|
repeated string statements = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message Statements {
|
message Statements {
|
||||||
|
// DEPRECATED, will be removed in 0.12
|
||||||
string creation_statements = 1;
|
string creation_statements = 1;
|
||||||
|
// DEPRECATED, will be removed in 0.12
|
||||||
string revocation_statements = 2;
|
string revocation_statements = 2;
|
||||||
|
// DEPRECATED, will be removed in 0.12
|
||||||
string rollback_statements = 3;
|
string rollback_statements = 3;
|
||||||
|
// DEPRECATED, will be removed in 0.12
|
||||||
string renew_statements = 4;
|
string renew_statements = 4;
|
||||||
|
|
||||||
|
repeated string creation = 5;
|
||||||
|
repeated string revocation = 6;
|
||||||
|
repeated string rollback = 7;
|
||||||
|
repeated string renewal = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
message UsernameConfig {
|
message UsernameConfig {
|
||||||
@@ -37,22 +56,35 @@ message UsernameConfig {
|
|||||||
string RoleName = 2;
|
string RoleName = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message InitResponse {
|
||||||
|
bytes config = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message CreateUserResponse {
|
message CreateUserResponse {
|
||||||
string username = 1;
|
string username = 1;
|
||||||
string password = 2;
|
string password = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TypeResponse {
|
message TypeResponse {
|
||||||
string type = 1;
|
string type = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RotateRootCredentialsResponse {
|
||||||
|
bytes config = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Empty {}
|
message Empty {}
|
||||||
|
|
||||||
service Database {
|
service Database {
|
||||||
rpc Type(Empty) returns (TypeResponse);
|
rpc Type(Empty) returns (TypeResponse);
|
||||||
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
|
rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
|
||||||
rpc RenewUser(RenewUserRequest) returns (Empty);
|
rpc RenewUser(RenewUserRequest) returns (Empty);
|
||||||
rpc RevokeUser(RevokeUserRequest) returns (Empty);
|
rpc RevokeUser(RevokeUserRequest) returns (Empty);
|
||||||
rpc Initialize(InitializeRequest) returns (Empty);
|
rpc RotateRootCredentials(RotateRootCredentialsRequest) returns (RotateRootCredentialsResponse);
|
||||||
rpc Close(Empty) returns (Empty);
|
rpc Init(InitRequest) returns (InitResponse);
|
||||||
|
rpc Close(Empty) returns (Empty);
|
||||||
|
|
||||||
|
rpc Initialize(InitializeRequest) returns (Empty) {
|
||||||
|
option deprecated = true;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,14 @@ package dbplugin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
|
|
||||||
metrics "github.com/armon/go-metrics"
|
metrics "github.com/armon/go-metrics"
|
||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
)
|
)
|
||||||
@@ -51,13 +57,27 @@ func (mw *databaseTracingMiddleware) RevokeUser(ctx context.Context, statements
|
|||||||
return mw.next.RevokeUser(ctx, statements, username)
|
return mw.next.RevokeUser(ctx, statements, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
func (mw *databaseTracingMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
|
||||||
|
defer func(then time.Time) {
|
||||||
|
mw.logger.Trace("database", "operation", "RotateRootCredentials", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "err", err, "took", time.Since(then))
|
||||||
|
}(time.Now())
|
||||||
|
|
||||||
|
mw.logger.Trace("database", "operation", "RotateRootCredentials", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||||
|
return mw.next.RotateRootCredentials(ctx, statements)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *databaseTracingMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
_, err := mw.Init(ctx, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *databaseTracingMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
|
||||||
defer func(then time.Time) {
|
defer func(then time.Time) {
|
||||||
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "transport", mw.transport, "verify", verifyConnection, "err", err, "took", time.Since(then))
|
||||||
}(time.Now())
|
}(time.Now())
|
||||||
|
|
||||||
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr, "transport", mw.transport)
|
||||||
return mw.next.Initialize(ctx, conf, verifyConnection)
|
return mw.next.Init(ctx, conf, verifyConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseTracingMiddleware) Close() (err error) {
|
func (mw *databaseTracingMiddleware) Close() (err error) {
|
||||||
@@ -131,7 +151,28 @@ func (mw *databaseMetricsMiddleware) RevokeUser(ctx context.Context, statements
|
|||||||
return mw.next.RevokeUser(ctx, statements, username)
|
return mw.next.RevokeUser(ctx, statements, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (err error) {
|
func (mw *databaseMetricsMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
|
||||||
|
defer func(now time.Time) {
|
||||||
|
metrics.MeasureSince([]string{"database", "RotateRootCredentials"}, now)
|
||||||
|
metrics.MeasureSince([]string{"database", mw.typeStr, "RotateRootCredentials"}, now)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
metrics.IncrCounter([]string{"database", "RotateRootCredentials", "error"}, 1)
|
||||||
|
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials", "error"}, 1)
|
||||||
|
}
|
||||||
|
}(time.Now())
|
||||||
|
|
||||||
|
metrics.IncrCounter([]string{"database", "RotateRootCredentials"}, 1)
|
||||||
|
metrics.IncrCounter([]string{"database", mw.typeStr, "RotateRootCredentials"}, 1)
|
||||||
|
return mw.next.RotateRootCredentials(ctx, statements)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
_, err := mw.Init(ctx, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *databaseMetricsMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
|
||||||
defer func(now time.Time) {
|
defer func(now time.Time) {
|
||||||
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
metrics.MeasureSince([]string{"database", "Initialize"}, now)
|
||||||
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now)
|
||||||
@@ -144,7 +185,7 @@ func (mw *databaseMetricsMiddleware) Initialize(ctx context.Context, conf map[st
|
|||||||
|
|
||||||
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
metrics.IncrCounter([]string{"database", "Initialize"}, 1)
|
||||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1)
|
||||||
return mw.next.Initialize(ctx, conf, verifyConnection)
|
return mw.next.Init(ctx, conf, verifyConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
func (mw *databaseMetricsMiddleware) Close() (err error) {
|
||||||
@@ -162,3 +203,76 @@ func (mw *databaseMetricsMiddleware) Close() (err error) {
|
|||||||
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
|
metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1)
|
||||||
return mw.next.Close()
|
return mw.next.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---- Error Sanitizer Middleware Domain ----
|
||||||
|
|
||||||
|
// DatabaseErrorSanitizerMiddleware wraps an implementation of Databases and
|
||||||
|
// sanitizes returned error messages
|
||||||
|
type DatabaseErrorSanitizerMiddleware struct {
|
||||||
|
l sync.RWMutex
|
||||||
|
next Database
|
||||||
|
secretsFn func() map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDatabaseErrorSanitizerMiddleware(next Database, secretsFn func() map[string]interface{}) *DatabaseErrorSanitizerMiddleware {
|
||||||
|
return &DatabaseErrorSanitizerMiddleware{
|
||||||
|
next: next,
|
||||||
|
secretsFn: secretsFn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) Type() (string, error) {
|
||||||
|
dbType, err := mw.next.Type()
|
||||||
|
return dbType, mw.sanitize(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
|
username, password, err = mw.next.CreateUser(ctx, statements, usernameConfig, expiration)
|
||||||
|
return username, password, mw.sanitize(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) (err error) {
|
||||||
|
return mw.sanitize(mw.next.RenewUser(ctx, statements, username, expiration))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) RevokeUser(ctx context.Context, statements Statements, username string) (err error) {
|
||||||
|
return mw.sanitize(mw.next.RevokeUser(ctx, statements, username))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
|
||||||
|
conf, err = mw.next.RotateRootCredentials(ctx, statements)
|
||||||
|
return conf, mw.sanitize(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
_, err := mw.Init(ctx, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
|
||||||
|
saveConf, err = mw.next.Init(ctx, conf, verifyConnection)
|
||||||
|
return saveConf, mw.sanitize(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) Close() (err error) {
|
||||||
|
return mw.sanitize(mw.next.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitize
|
||||||
|
func (mw *DatabaseErrorSanitizerMiddleware) sanitize(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errwrap.ContainsType(err, new(url.Error)) {
|
||||||
|
return errors.New("unable to parse connection url")
|
||||||
|
}
|
||||||
|
if mw.secretsFn != nil {
|
||||||
|
for k, v := range mw.secretsFn() {
|
||||||
|
if k == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errors.New(strings.Replace(err.Error(), k, v.(string), -1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/golang/protobuf/ptypes"
|
"github.com/golang/protobuf/ptypes"
|
||||||
"github.com/hashicorp/vault/helper/pluginutil"
|
"github.com/hashicorp/vault/helper/pluginutil"
|
||||||
@@ -61,16 +63,51 @@ func (s *gRPCServer) RevokeUser(ctx context.Context, req *RevokeUserRequest) (*E
|
|||||||
return &Empty{}, err
|
return &Empty{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
|
func (s *gRPCServer) RotateRootCredentials(ctx context.Context, req *RotateRootCredentialsRequest) (*RotateRootCredentialsResponse, error) {
|
||||||
config := map[string]interface{}{}
|
|
||||||
|
|
||||||
|
resp, err := s.impl.RotateRootCredentials(ctx, req.Statements)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
respConfig, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RotateRootCredentialsResponse{
|
||||||
|
Config: respConfig,
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *gRPCServer) Initialize(ctx context.Context, req *InitializeRequest) (*Empty, error) {
|
||||||
|
_, err := s.Init(ctx, &InitRequest{
|
||||||
|
Config: req.Config,
|
||||||
|
VerifyConnection: req.VerifyConnection,
|
||||||
|
})
|
||||||
|
return &Empty{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *gRPCServer) Init(ctx context.Context, req *InitRequest) (*InitResponse, error) {
|
||||||
|
config := map[string]interface{}{}
|
||||||
err := json.Unmarshal(req.Config, &config)
|
err := json.Unmarshal(req.Config, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.impl.Initialize(ctx, config, req.VerifyConnection)
|
resp, err := s.impl.Init(ctx, config, req.VerifyConnection)
|
||||||
return &Empty{}, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
respConfig, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &InitResponse{
|
||||||
|
Config: respConfig,
|
||||||
|
}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
|
func (s *gRPCServer) Close(_ context.Context, _ *Empty) (*Empty, error) {
|
||||||
@@ -87,7 +124,7 @@ type gRPCClient struct {
|
|||||||
doneCtx context.Context
|
doneCtx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c gRPCClient) Type() (string, error) {
|
func (c *gRPCClient) Type() (string, error) {
|
||||||
resp, err := c.client.Type(c.doneCtx, &Empty{})
|
resp, err := c.client.Type(c.doneCtx, &Empty{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -96,7 +133,7 @@ func (c gRPCClient) Type() (string, error) {
|
|||||||
return resp.Type, err
|
return resp.Type, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (c *gRPCClient) CreateUser(ctx context.Context, statements Statements, usernameConfig UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
t, err := ptypes.TimestampProto(expiration)
|
t, err := ptypes.TimestampProto(expiration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
@@ -172,10 +209,40 @@ func (c *gRPCClient) RevokeUser(ctx context.Context, statements Statements, user
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error {
|
func (c *gRPCClient) RotateRootCredentials(ctx context.Context, statements []string) (conf map[string]interface{}, err error) {
|
||||||
configRaw, err := json.Marshal(config)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
quitCh := pluginutil.CtxCancelIfCanceled(cancel, c.doneCtx)
|
||||||
|
defer close(quitCh)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := c.client.RotateRootCredentials(ctx, &RotateRootCredentialsRequest{
|
||||||
|
Statements: statements,
|
||||||
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if c.doneCtx.Err() != nil {
|
||||||
|
return nil, ErrPluginShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(resp.Config, &conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gRPCClient) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
_, err := c.Init(ctx, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gRPCClient) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
|
||||||
|
configRaw, err := json.Marshal(conf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
@@ -183,19 +250,33 @@ func (c *gRPCClient) Initialize(ctx context.Context, config map[string]interface
|
|||||||
defer close(quitCh)
|
defer close(quitCh)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
_, err = c.client.Initialize(ctx, &InitializeRequest{
|
resp, err := c.client.Init(ctx, &InitRequest{
|
||||||
Config: configRaw,
|
Config: configRaw,
|
||||||
VerifyConnection: verifyConnection,
|
VerifyConnection: verifyConnection,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if c.doneCtx.Err() != nil {
|
// Fall back to old call if not implemented
|
||||||
return ErrPluginShutdown
|
grpcStatus, ok := status.FromError(err)
|
||||||
|
if ok && grpcStatus.Code() == codes.Unimplemented {
|
||||||
|
_, err = c.client.Initialize(ctx, &InitializeRequest{
|
||||||
|
Config: configRaw,
|
||||||
|
VerifyConnection: verifyConnection,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
return conf, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
if c.doneCtx.Err() != nil {
|
||||||
|
return nil, ErrPluginShutdown
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
if err := json.Unmarshal(resp.Config, &conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *gRPCClient) Close() error {
|
func (c *gRPCClient) Close() error {
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package dbplugin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/rpc"
|
"net/rpc"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,8 +39,28 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequestRPC, _ *str
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) RotateRootCredentials(args *RotateRootCredentialsRequestRPC, resp *RotateRootCredentialsResponse) error {
|
||||||
|
config, err := ds.impl.RotateRootCredentials(context.Background(), args.Statements)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp.Config, err = json.Marshal(config)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
|
func (ds *databasePluginRPCServer) Initialize(args *InitializeRequestRPC, _ *struct{}) error {
|
||||||
err := ds.impl.Initialize(context.Background(), args.Config, args.VerifyConnection)
|
return ds.Init(&InitRequestRPC{
|
||||||
|
Config: args.Config,
|
||||||
|
VerifyConnection: args.VerifyConnection,
|
||||||
|
}, &InitResponse{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ds *databasePluginRPCServer) Init(args *InitRequestRPC, resp *InitResponse) error {
|
||||||
|
config, err := ds.impl.Init(context.Background(), args.Config, args.VerifyConnection)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp.Config, err = json.Marshal(config)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,9 +103,7 @@ func (dr *databasePluginRPCClient) RenewUser(_ context.Context, statements State
|
|||||||
Expiration: expiration,
|
Expiration: expiration,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
return dr.client.Call("Plugin.RenewUser", req, &struct{}{})
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
|
func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Statements, username string) error {
|
||||||
@@ -92,26 +112,55 @@ func (dr *databasePluginRPCClient) RevokeUser(_ context.Context, statements Stat
|
|||||||
Username: username,
|
Username: username,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
return dr.client.Call("Plugin.RevokeUser", req, &struct{}{})
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
func (dr *databasePluginRPCClient) RotateRootCredentials(_ context.Context, statements []string) (saveConf map[string]interface{}, err error) {
|
||||||
|
req := RotateRootCredentialsRequestRPC{
|
||||||
|
Statements: statements,
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp RotateRootCredentialsResponse
|
||||||
|
err = dr.client.Call("Plugin.RotateRootCredentials", req, &resp)
|
||||||
|
|
||||||
|
err = json.Unmarshal(resp.Config, &saveConf)
|
||||||
|
return saveConf, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
func (dr *databasePluginRPCClient) Initialize(_ context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
req := InitializeRequestRPC{
|
_, err := dr.Init(nil, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *databasePluginRPCClient) Init(_ context.Context, conf map[string]interface{}, verifyConnection bool) (saveConf map[string]interface{}, err error) {
|
||||||
|
req := InitRequestRPC{
|
||||||
Config: conf,
|
Config: conf,
|
||||||
VerifyConnection: verifyConnection,
|
VerifyConnection: verifyConnection,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
var resp InitResponse
|
||||||
|
err = dr.client.Call("Plugin.Init", req, &resp)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "can't find method Plugin.Init") {
|
||||||
|
req := InitializeRequestRPC{
|
||||||
|
Config: conf,
|
||||||
|
VerifyConnection: verifyConnection,
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
err = dr.client.Call("Plugin.Initialize", req, &struct{}{})
|
||||||
|
if err == nil {
|
||||||
|
return conf, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(resp.Config, &saveConf)
|
||||||
|
return saveConf, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dr *databasePluginRPCClient) Close() error {
|
func (dr *databasePluginRPCClient) Close() error {
|
||||||
err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
return dr.client.Call("Plugin.Close", struct{}{}, &struct{}{})
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---- RPC Request Args Domain ----
|
// ---- RPC Request Args Domain ----
|
||||||
@@ -121,6 +170,11 @@ type InitializeRequestRPC struct {
|
|||||||
VerifyConnection bool
|
VerifyConnection bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InitRequestRPC struct {
|
||||||
|
Config map[string]interface{}
|
||||||
|
VerifyConnection bool
|
||||||
|
}
|
||||||
|
|
||||||
type CreateUserRequestRPC struct {
|
type CreateUserRequestRPC struct {
|
||||||
Statements Statements
|
Statements Statements
|
||||||
UsernameConfig UsernameConfig
|
UsernameConfig UsernameConfig
|
||||||
@@ -137,3 +191,7 @@ type RevokeUserRequestRPC struct {
|
|||||||
Statements Statements
|
Statements Statements
|
||||||
Username string
|
Username string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RotateRootCredentialsRequestRPC struct {
|
||||||
|
Statements []string
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/go-plugin"
|
"github.com/hashicorp/go-plugin"
|
||||||
"github.com/hashicorp/vault/helper/pluginutil"
|
"github.com/hashicorp/vault/helper/pluginutil"
|
||||||
log "github.com/mgutz/logxi/v1"
|
log "github.com/mgutz/logxi/v1"
|
||||||
@@ -20,8 +21,13 @@ type Database interface {
|
|||||||
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
|
RenewUser(ctx context.Context, statements Statements, username string, expiration time.Time) error
|
||||||
RevokeUser(ctx context.Context, statements Statements, username string) error
|
RevokeUser(ctx context.Context, statements Statements, username string) error
|
||||||
|
|
||||||
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) error
|
RotateRootCredentials(ctx context.Context, statements []string) (config map[string]interface{}, err error)
|
||||||
|
|
||||||
|
Init(ctx context.Context, config map[string]interface{}, verifyConnection bool) (saveConfig map[string]interface{}, err error)
|
||||||
Close() error
|
Close() error
|
||||||
|
|
||||||
|
// DEPRECATED, will be removed in 0.12
|
||||||
|
Initialize(ctx context.Context, config map[string]interface{}, verifyConnection bool) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PluginFactory is used to build plugin database types. It wraps the database
|
// PluginFactory is used to build plugin database types. It wraps the database
|
||||||
@@ -40,7 +46,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
|
|||||||
// from the pluginRunner. Then cast it to a Database.
|
// from the pluginRunner. Then cast it to a Database.
|
||||||
dbRaw, err := pluginRunner.BuiltinFactory()
|
dbRaw, err := pluginRunner.BuiltinFactory()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting plugin type: %s", err)
|
return nil, errwrap.Wrapf("error initializing plugin: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ok bool
|
var ok bool
|
||||||
@@ -71,7 +77,7 @@ func PluginFactory(ctx context.Context, pluginName string, sys pluginutil.LookRu
|
|||||||
|
|
||||||
typeStr, err := db.Type()
|
typeStr, err := db.Type()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting plugin type: %s", err)
|
return nil, errwrap.Wrapf("error getting plugin type: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap with metrics middleware
|
// Wrap with metrics middleware
|
||||||
@@ -113,7 +119,11 @@ type DatabasePlugin struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) {
|
func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) {
|
||||||
return &databasePluginRPCServer{impl: d.impl}, nil
|
impl := &DatabaseErrorSanitizerMiddleware{
|
||||||
|
next: d.impl,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &databasePluginRPCServer{impl: impl}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
|
func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
|
||||||
@@ -121,7 +131,11 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d DatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
|
func (d DatabasePlugin) GRPCServer(_ *plugin.GRPCBroker, s *grpc.Server) error {
|
||||||
RegisterDatabaseServer(s, &gRPCServer{impl: d.impl})
|
impl := &DatabaseErrorSanitizerMiddleware{
|
||||||
|
next: d.impl,
|
||||||
|
}
|
||||||
|
|
||||||
|
RegisterDatabaseServer(s, &gRPCServer{impl: impl})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,6 +61,17 @@ func (m *mockPlugin) RevokeUser(_ context.Context, statements dbplugin.Statement
|
|||||||
delete(m.users, username)
|
delete(m.users, username)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (m *mockPlugin) RotateRootCredentials(_ context.Context, statements []string) (map[string]interface{}, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (m *mockPlugin) Init(_ context.Context, conf map[string]interface{}, _ bool) (map[string]interface{}, error) {
|
||||||
|
err := errors.New("err")
|
||||||
|
if len(conf) != 1 {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conf, nil
|
||||||
|
}
|
||||||
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
|
func (m *mockPlugin) Initialize(_ context.Context, conf map[string]interface{}, _ bool) error {
|
||||||
err := errors.New("err")
|
err := errors.New("err")
|
||||||
if len(conf) != 1 {
|
if len(conf) != 1 {
|
||||||
@@ -132,7 +143,7 @@ func TestPlugin_NetRPC_Main(t *testing.T) {
|
|||||||
plugin.Serve(serveConf)
|
plugin.Serve(serveConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPlugin_Initialize(t *testing.T) {
|
func TestPlugin_Init(t *testing.T) {
|
||||||
cluster, sys := getCluster(t)
|
cluster, sys := getCluster(t)
|
||||||
defer cluster.Cleanup()
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
@@ -145,7 +156,7 @@ func TestPlugin_Initialize(t *testing.T) {
|
|||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
_, err = dbRaw.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -170,7 +181,7 @@ func TestPlugin_CreateUser(t *testing.T) {
|
|||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -209,7 +220,7 @@ func TestPlugin_RenewUser(t *testing.T) {
|
|||||||
connectionDetails := map[string]interface{}{
|
connectionDetails := map[string]interface{}{
|
||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -243,7 +254,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
|||||||
connectionDetails := map[string]interface{}{
|
connectionDetails := map[string]interface{}{
|
||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -272,7 +283,7 @@ func TestPlugin_RevokeUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test the code is still compatible with an old netRPC plugin
|
// Test the code is still compatible with an old netRPC plugin
|
||||||
func TestPlugin_NetRPC_Initialize(t *testing.T) {
|
func TestPlugin_NetRPC_Init(t *testing.T) {
|
||||||
cluster, sys := getCluster(t)
|
cluster, sys := getCluster(t)
|
||||||
defer cluster.Cleanup()
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
@@ -285,7 +296,7 @@ func TestPlugin_NetRPC_Initialize(t *testing.T) {
|
|||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dbRaw.Initialize(context.Background(), connectionDetails, true)
|
_, err = dbRaw.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -310,7 +321,7 @@ func TestPlugin_NetRPC_CreateUser(t *testing.T) {
|
|||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -349,7 +360,7 @@ func TestPlugin_NetRPC_RenewUser(t *testing.T) {
|
|||||||
connectionDetails := map[string]interface{}{
|
connectionDetails := map[string]interface{}{
|
||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -383,7 +394,7 @@ func TestPlugin_NetRPC_RevokeUser(t *testing.T) {
|
|||||||
connectionDetails := map[string]interface{}{
|
connectionDetails := map[string]interface{}{
|
||||||
"test": 1,
|
"test": 1,
|
||||||
}
|
}
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/fatih/structs"
|
"github.com/fatih/structs"
|
||||||
|
uuid "github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/logical"
|
"github.com/hashicorp/vault/logical"
|
||||||
"github.com/hashicorp/vault/logical/framework"
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
@@ -24,6 +25,8 @@ type DatabaseConfig struct {
|
|||||||
// by each database type.
|
// by each database type.
|
||||||
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
||||||
AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"`
|
AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"`
|
||||||
|
|
||||||
|
RootCredentialsRotateStatements []string `json:"root_credentials_rotate_statements" structs:"root_credentials_rotate_statements" mapstructure:"root_credentials_rotate_statements"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// pathResetConnection configures a path to reset a plugin.
|
// pathResetConnection configures a path to reset a plugin.
|
||||||
@@ -55,16 +58,13 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
|
|||||||
return logical.ErrorResponse(respErrEmptyName), nil
|
return logical.ErrorResponse(respErrEmptyName), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grab the mutex lock
|
|
||||||
b.Lock()
|
|
||||||
defer b.Unlock()
|
|
||||||
|
|
||||||
// Close plugin and delete the entry in the connections cache.
|
// Close plugin and delete the entry in the connections cache.
|
||||||
b.clearConnection(name)
|
if err := b.ClearConnection(name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Execute plugin again, we don't need the object so throw away.
|
// Execute plugin again, we don't need the object so throw away.
|
||||||
_, err := b.createDBObj(ctx, req.Storage, name)
|
if _, err := b.GetConnection(ctx, req.Storage, name); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,6 +103,14 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
|
|||||||
allowed to get creds from this database connection. If empty no
|
allowed to get creds from this database connection. If empty no
|
||||||
roles are allowed. If "*" all roles are allowed.`,
|
roles are allowed. If "*" all roles are allowed.`,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
"root_rotation_statements": &framework.FieldSchema{
|
||||||
|
Type: framework.TypeStringSlice,
|
||||||
|
Description: `Specifies the database statements to be executed
|
||||||
|
to rotate the root user's credentials. See the plugin's API
|
||||||
|
page for more information on support and formatting for this
|
||||||
|
parameter.`,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||||
@@ -179,16 +187,8 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
|
|||||||
return nil, errors.New("failed to delete connection configuration")
|
return nil, errors.New("failed to delete connection configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
b.Lock()
|
if err := b.ClearConnection(name); err != nil {
|
||||||
defer b.Unlock()
|
return nil, err
|
||||||
|
|
||||||
if _, ok := b.connections[name]; ok {
|
|
||||||
err = b.connections[name].Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(b.connections, name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -210,8 +210,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
verifyConnection := data.Get("verify_connection").(bool)
|
verifyConnection := data.Get("verify_connection").(bool)
|
||||||
|
|
||||||
allowedRoles := data.Get("allowed_roles").([]string)
|
allowedRoles := data.Get("allowed_roles").([]string)
|
||||||
|
rootRotationStatements := data.Get("root_rotation_statements").([]string)
|
||||||
|
|
||||||
// Remove these entries from the data before we store it keyed under
|
// Remove these entries from the data before we store it keyed under
|
||||||
// ConnectionDetails.
|
// ConnectionDetails.
|
||||||
@@ -219,35 +219,45 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
|||||||
delete(data.Raw, "plugin_name")
|
delete(data.Raw, "plugin_name")
|
||||||
delete(data.Raw, "allowed_roles")
|
delete(data.Raw, "allowed_roles")
|
||||||
delete(data.Raw, "verify_connection")
|
delete(data.Raw, "verify_connection")
|
||||||
|
delete(data.Raw, "root_rotation_statements")
|
||||||
|
|
||||||
config := &DatabaseConfig{
|
// Create a database plugin and initialize it. This instance is not
|
||||||
ConnectionDetails: data.Raw,
|
// going to be used and is initialized just to ensure all parameters
|
||||||
PluginName: pluginName,
|
// are valid and the connection is verified, if requested.
|
||||||
AllowedRoles: allowedRoles,
|
db, err := dbplugin.PluginFactory(ctx, pluginName, b.System(), b.logger)
|
||||||
}
|
|
||||||
|
|
||||||
db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||||
}
|
}
|
||||||
|
connDetails, err := db.Init(ctx, data.Raw, verifyConnection)
|
||||||
err = db.Initialize(ctx, config.ConnectionDetails, verifyConnection)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
db.Close()
|
db.Close()
|
||||||
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grab the mutex lock
|
|
||||||
b.Lock()
|
b.Lock()
|
||||||
defer b.Unlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
// Close and remove the old connection
|
// Close and remove the old connection
|
||||||
b.clearConnection(name)
|
b.clearConnection(name)
|
||||||
|
|
||||||
// Save the new connection
|
id, err := uuid.GenerateUUID()
|
||||||
b.connections[name] = db
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.connections[name] = &dbPluginInstance{
|
||||||
|
Database: db,
|
||||||
|
name: name,
|
||||||
|
id: id,
|
||||||
|
}
|
||||||
|
|
||||||
// Store it
|
// Store it
|
||||||
|
config := &DatabaseConfig{
|
||||||
|
ConnectionDetails: connDetails,
|
||||||
|
PluginName: pluginName,
|
||||||
|
AllowedRoles: allowedRoles,
|
||||||
|
RootCredentialsRotateStatements: rootRotationStatements,
|
||||||
|
}
|
||||||
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
|
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -54,26 +54,15 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
|||||||
return nil, logical.ErrPermissionDenied
|
return nil, logical.ErrPermissionDenied
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grab the read lock
|
|
||||||
b.RLock()
|
|
||||||
unlockFunc := b.RUnlock
|
|
||||||
|
|
||||||
// Get the Database object
|
// Get the Database object
|
||||||
db, ok := b.getDBObj(role.DBName)
|
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
|
||||||
if !ok {
|
if err != nil {
|
||||||
// Upgrade lock
|
return nil, err
|
||||||
b.RUnlock()
|
|
||||||
b.Lock()
|
|
||||||
unlockFunc = b.Unlock
|
|
||||||
|
|
||||||
// Create a new DB object
|
|
||||||
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
|
|
||||||
if err != nil {
|
|
||||||
unlockFunc()
|
|
||||||
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.RLock()
|
||||||
|
defer db.RUnlock()
|
||||||
|
|
||||||
ttl := b.System().DefaultLeaseTTL()
|
ttl := b.System().DefaultLeaseTTL()
|
||||||
if role.DefaultTTL != 0 {
|
if role.DefaultTTL != 0 {
|
||||||
ttl = role.DefaultTTL
|
ttl = role.DefaultTTL
|
||||||
@@ -96,8 +85,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
|||||||
// Create the user
|
// Create the user
|
||||||
username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration)
|
username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unlockFunc()
|
b.CloseIfShutdown(db, err)
|
||||||
b.closeIfShutdown(role.DBName, err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,8 +97,6 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
|||||||
"role": name,
|
"role": name,
|
||||||
})
|
})
|
||||||
resp.Secret.TTL = ttl
|
resp.Secret.TTL = ttl
|
||||||
|
|
||||||
unlockFunc()
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,26 +36,26 @@ func pathRoles(b *databaseBackend) *framework.Path {
|
|||||||
Description: "Name of the database this role acts on.",
|
Description: "Name of the database this role acts on.",
|
||||||
},
|
},
|
||||||
"creation_statements": {
|
"creation_statements": {
|
||||||
Type: framework.TypeString,
|
Type: framework.TypeStringSlice,
|
||||||
Description: `Specifies the database statements executed to
|
Description: `Specifies the database statements executed to
|
||||||
create and configure a user. See the plugin's API page for more
|
create and configure a user. See the plugin's API page for more
|
||||||
information on support and formatting for this parameter.`,
|
information on support and formatting for this parameter.`,
|
||||||
},
|
},
|
||||||
"revocation_statements": {
|
"revocation_statements": {
|
||||||
Type: framework.TypeString,
|
Type: framework.TypeStringSlice,
|
||||||
Description: `Specifies the database statements to be executed
|
Description: `Specifies the database statements to be executed
|
||||||
to revoke a user. See the plugin's API page for more information
|
to revoke a user. See the plugin's API page for more information
|
||||||
on support and formatting for this parameter.`,
|
on support and formatting for this parameter.`,
|
||||||
},
|
},
|
||||||
"renew_statements": {
|
"renew_statements": {
|
||||||
Type: framework.TypeString,
|
Type: framework.TypeStringSlice,
|
||||||
Description: `Specifies the database statements to be executed
|
Description: `Specifies the database statements to be executed
|
||||||
to renew a user. Not every plugin type will support this
|
to renew a user. Not every plugin type will support this
|
||||||
functionality. See the plugin's API page for more information on
|
functionality. See the plugin's API page for more information on
|
||||||
support and formatting for this parameter. `,
|
support and formatting for this parameter. `,
|
||||||
},
|
},
|
||||||
"rollback_statements": {
|
"rollback_statements": {
|
||||||
Type: framework.TypeString,
|
Type: framework.TypeStringSlice,
|
||||||
Description: `Specifies the database statements to be executed
|
Description: `Specifies the database statements to be executed
|
||||||
rollback a create operation in the event of an error. Not every
|
rollback a create operation in the event of an error. Not every
|
||||||
plugin type will support this functionality. See the plugin's
|
plugin type will support this functionality. See the plugin's
|
||||||
@@ -109,10 +109,10 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc {
|
|||||||
return &logical.Response{
|
return &logical.Response{
|
||||||
Data: map[string]interface{}{
|
Data: map[string]interface{}{
|
||||||
"db_name": role.DBName,
|
"db_name": role.DBName,
|
||||||
"creation_statements": role.Statements.CreationStatements,
|
"creation_statements": role.Statements.Creation,
|
||||||
"revocation_statements": role.Statements.RevocationStatements,
|
"revocation_statements": role.Statements.Revocation,
|
||||||
"rollback_statements": role.Statements.RollbackStatements,
|
"rollback_statements": role.Statements.Rollback,
|
||||||
"renew_statements": role.Statements.RenewStatements,
|
"renew_statements": role.Statements.Renewal,
|
||||||
"default_ttl": role.DefaultTTL.Seconds(),
|
"default_ttl": role.DefaultTTL.Seconds(),
|
||||||
"max_ttl": role.MaxTTL.Seconds(),
|
"max_ttl": role.MaxTTL.Seconds(),
|
||||||
},
|
},
|
||||||
@@ -144,10 +144,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get statements
|
// Get statements
|
||||||
creationStmts := data.Get("creation_statements").(string)
|
creationStmts := data.Get("creation_statements").([]string)
|
||||||
revocationStmts := data.Get("revocation_statements").(string)
|
revocationStmts := data.Get("revocation_statements").([]string)
|
||||||
rollbackStmts := data.Get("rollback_statements").(string)
|
rollbackStmts := data.Get("rollback_statements").([]string)
|
||||||
renewStmts := data.Get("renew_statements").(string)
|
renewStmts := data.Get("renew_statements").([]string)
|
||||||
|
|
||||||
// Get TTLs
|
// Get TTLs
|
||||||
defaultTTLRaw := data.Get("default_ttl").(int)
|
defaultTTLRaw := data.Get("default_ttl").(int)
|
||||||
@@ -156,10 +156,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
|
|||||||
maxTTL := time.Duration(maxTTLRaw) * time.Second
|
maxTTL := time.Duration(maxTTLRaw) * time.Second
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: creationStmts,
|
Creation: creationStmts,
|
||||||
RevocationStatements: revocationStmts,
|
Revocation: revocationStmts,
|
||||||
RollbackStatements: rollbackStmts,
|
Rollback: rollbackStmts,
|
||||||
RenewStatements: renewStmts,
|
Renewal: renewStmts,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store it
|
// Store it
|
||||||
@@ -181,10 +181,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type roleEntry struct {
|
type roleEntry struct {
|
||||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
DBName string `json:"db_name"`
|
||||||
Statements dbplugin.Statements `json:"statements" mapstructure:"statements" structs:"statements"`
|
Statements dbplugin.Statements `json:"statements"`
|
||||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
DefaultTTL time.Duration `json:"default_ttl"`
|
||||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
MaxTTL time.Duration `json:"max_ttl"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const pathRoleHelpSyn = `
|
const pathRoleHelpSyn = `
|
||||||
|
|||||||
80
builtin/logical/database/path_rotate_credentials.go
Normal file
80
builtin/logical/database/path_rotate_credentials.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/logical"
|
||||||
|
"github.com/hashicorp/vault/logical/framework"
|
||||||
|
)
|
||||||
|
|
||||||
|
func pathRotateCredentials(b *databaseBackend) *framework.Path {
|
||||||
|
return &framework.Path{
|
||||||
|
Pattern: "rotate-root/" + framework.GenericNameRegex("name"),
|
||||||
|
Fields: map[string]*framework.FieldSchema{
|
||||||
|
"name": &framework.FieldSchema{
|
||||||
|
Type: framework.TypeString,
|
||||||
|
Description: "Name of this database connection",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||||
|
logical.ReadOperation: b.pathRotateCredentialsUpdate(),
|
||||||
|
},
|
||||||
|
|
||||||
|
HelpSynopsis: pathCredsCreateReadHelpSyn,
|
||||||
|
HelpDescription: pathCredsCreateReadHelpDesc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *databaseBackend) pathRotateCredentialsUpdate() framework.OperationFunc {
|
||||||
|
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||||
|
name := data.Get("name").(string)
|
||||||
|
if name == "" {
|
||||||
|
return logical.ErrorResponse(respErrEmptyName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := b.DatabaseConfig(ctx, req.Storage, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := b.GetConnection(ctx, req.Storage, name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take the write lock instead of read since we are updating the
|
||||||
|
// connection
|
||||||
|
db.Lock()
|
||||||
|
defer db.Unlock()
|
||||||
|
|
||||||
|
connectionDetails, err := db.RotateRootCredentials(ctx, config.RootCredentialsRotateStatements)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
config.ConnectionDetails = connectionDetails
|
||||||
|
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := req.Storage.Put(ctx, entry); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := b.ClearConnection(name); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const pathRotateCredentialsUpdateHelpSyn = `
|
||||||
|
Request to rotate the root credentials for a certain database connection.
|
||||||
|
`
|
||||||
|
|
||||||
|
const pathRotateCredentialsUpdateHelpDesc = `
|
||||||
|
This path attempts to rotate the root credentials for the given database.
|
||||||
|
`
|
||||||
@@ -48,37 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grab the read lock
|
|
||||||
b.RLock()
|
|
||||||
unlockFunc := b.RUnlock
|
|
||||||
|
|
||||||
// Get the Database object
|
// Get the Database object
|
||||||
db, ok := b.getDBObj(role.DBName)
|
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
|
||||||
if !ok {
|
if err != nil {
|
||||||
// Upgrade lock
|
return nil, err
|
||||||
b.RUnlock()
|
|
||||||
b.Lock()
|
|
||||||
unlockFunc = b.Unlock
|
|
||||||
|
|
||||||
// Create a new DB object
|
|
||||||
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
|
|
||||||
if err != nil {
|
|
||||||
unlockFunc()
|
|
||||||
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.RLock()
|
||||||
|
defer db.RUnlock()
|
||||||
|
|
||||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||||
err := db.RenewUser(ctx, role.Statements, username, expireTime)
|
err := db.RenewUser(ctx, role.Statements, username, expireTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
unlockFunc()
|
b.CloseIfShutdown(db, err)
|
||||||
b.closeIfShutdown(role.DBName, err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unlockFunc()
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -107,33 +93,19 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
|||||||
return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"])
|
return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grab the read lock
|
|
||||||
b.RLock()
|
|
||||||
unlockFunc := b.RUnlock
|
|
||||||
|
|
||||||
// Get our connection
|
// Get our connection
|
||||||
db, ok := b.getDBObj(role.DBName)
|
db, err := b.GetConnection(ctx, req.Storage, role.DBName)
|
||||||
if !ok {
|
if err != nil {
|
||||||
// Upgrade lock
|
|
||||||
b.RUnlock()
|
|
||||||
b.Lock()
|
|
||||||
unlockFunc = b.Unlock
|
|
||||||
|
|
||||||
// Create a new DB object
|
|
||||||
db, err = b.createDBObj(ctx, req.Storage, role.DBName)
|
|
||||||
if err != nil {
|
|
||||||
unlockFunc()
|
|
||||||
return nil, fmt.Errorf("could not retrieve db with name: %s, got error: %s", role.DBName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.RevokeUser(ctx, role.Statements, username); err != nil {
|
|
||||||
unlockFunc()
|
|
||||||
b.closeIfShutdown(role.DBName, err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
unlockFunc()
|
db.RLock()
|
||||||
|
defer db.RUnlock()
|
||||||
|
|
||||||
|
if err := db.RevokeUser(ctx, role.Statements, username); err != nil {
|
||||||
|
b.CloseIfShutdown(db, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,27 +11,34 @@ import (
|
|||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/helper/strutil"
|
"github.com/hashicorp/vault/helper/strutil"
|
||||||
"github.com/hashicorp/vault/plugins"
|
"github.com/hashicorp/vault/plugins"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
|
defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
|
||||||
defaultUserDeletionCQL = `DROP USER '{{username}}';`
|
defaultUserDeletionCQL = `DROP USER '{{username}}';`
|
||||||
cassandraTypeName = "cassandra"
|
defaultRootCredentialRotationCQL = `ALTER USER {{username}} WITH PASSWORD '{{password}}';`
|
||||||
|
cassandraTypeName = "cassandra"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ dbplugin.Database = &Cassandra{}
|
var _ dbplugin.Database = &Cassandra{}
|
||||||
|
|
||||||
// Cassandra is an implementation of Database interface
|
// Cassandra is an implementation of Database interface
|
||||||
type Cassandra struct {
|
type Cassandra struct {
|
||||||
connutil.ConnectionProducer
|
*cassandraConnectionProducer
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new Cassandra instance
|
// New returns a new Cassandra instance
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
|
db := new()
|
||||||
|
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
|
||||||
|
|
||||||
|
return dbType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func new() *Cassandra {
|
||||||
connProducer := &cassandraConnectionProducer{}
|
connProducer := &cassandraConnectionProducer{}
|
||||||
connProducer.Type = cassandraTypeName
|
connProducer.Type = cassandraTypeName
|
||||||
|
|
||||||
@@ -42,12 +49,10 @@ func New() (interface{}, error) {
|
|||||||
Separator: "_",
|
Separator: "_",
|
||||||
}
|
}
|
||||||
|
|
||||||
dbType := &Cassandra{
|
return &Cassandra{
|
||||||
ConnectionProducer: connProducer,
|
cassandraConnectionProducer: connProducer,
|
||||||
CredentialsProducer: credsProducer,
|
CredentialsProducer: credsProducer,
|
||||||
}
|
}
|
||||||
|
|
||||||
return dbType, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run instantiates a Cassandra object, and runs the RPC server for the plugin
|
// Run instantiates a Cassandra object, and runs the RPC server for the plugin
|
||||||
@@ -57,7 +62,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
plugins.Serve(dbType.(*Cassandra), apiTLSConfig)
|
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -83,19 +88,22 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
|
|||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
session, err := c.getConnection(ctx)
|
session, err := c.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
creationCQL := statements.CreationStatements
|
creationCQL := statements.Creation
|
||||||
if creationCQL == "" {
|
if len(creationCQL) == 0 {
|
||||||
creationCQL = defaultUserCreationCQL
|
creationCQL = []string{defaultUserCreationCQL}
|
||||||
}
|
}
|
||||||
rollbackCQL := statements.RollbackStatements
|
|
||||||
if rollbackCQL == "" {
|
rollbackCQL := statements.Rollback
|
||||||
rollbackCQL = defaultUserDeletionCQL
|
if len(rollbackCQL) == 0 {
|
||||||
|
rollbackCQL = []string{defaultUserDeletionCQL}
|
||||||
}
|
}
|
||||||
|
|
||||||
username, err = c.GenerateUsername(usernameConfig)
|
username, err = c.GenerateUsername(usernameConfig)
|
||||||
@@ -112,28 +120,32 @@ func (c *Cassandra) CreateUser(ctx context.Context, statements dbplugin.Statemen
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Execute each query
|
// Execute each query
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") {
|
for _, stmt := range creationCQL {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
err = session.Query(dbutil.QueryHelper(query, map[string]string{
|
|
||||||
"username": username,
|
err = session.Query(dbutil.QueryHelper(query, map[string]string{
|
||||||
"password": password,
|
"username": username,
|
||||||
})).Exec()
|
"password": password,
|
||||||
if err != nil {
|
})).Exec()
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") {
|
if err != nil {
|
||||||
query = strings.TrimSpace(query)
|
for _, stmt := range rollbackCQL {
|
||||||
if len(query) == 0 {
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
continue
|
query = strings.TrimSpace(query)
|
||||||
}
|
if len(query) == 0 {
|
||||||
|
continue
|
||||||
session.Query(dbutil.QueryHelper(query, map[string]string{
|
}
|
||||||
"username": username,
|
|
||||||
})).Exec()
|
session.Query(dbutil.QueryHelper(query, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
})).Exec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", "", err
|
||||||
}
|
}
|
||||||
return "", "", err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,29 +164,79 @@ func (c *Cassandra) RevokeUser(ctx context.Context, statements dbplugin.Statemen
|
|||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
session, err := c.getConnection(ctx)
|
session, err := c.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
revocationCQL := statements.RevocationStatements
|
revocationCQL := statements.Revocation
|
||||||
if revocationCQL == "" {
|
if len(revocationCQL) == 0 {
|
||||||
revocationCQL = defaultUserDeletionCQL
|
revocationCQL = []string{defaultUserDeletionCQL}
|
||||||
}
|
}
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationCQL, ";") {
|
for _, stmt := range revocationCQL {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := session.Query(dbutil.QueryHelper(query, map[string]string{
|
||||||
|
"username": username,
|
||||||
|
})).Exec()
|
||||||
|
|
||||||
|
result = multierror.Append(result, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := session.Query(dbutil.QueryHelper(query, map[string]string{
|
|
||||||
"username": username,
|
|
||||||
})).Exec()
|
|
||||||
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Cassandra) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
|
||||||
|
// Grab the lock
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
|
session, err := c.getConnection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rotateCQL := statements
|
||||||
|
if len(rotateCQL) == 0 {
|
||||||
|
rotateCQL = []string{defaultRootCredentialRotationCQL}
|
||||||
|
}
|
||||||
|
|
||||||
|
password, err := c.GeneratePassword()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *multierror.Error
|
||||||
|
for _, stmt := range rotateCQL {
|
||||||
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if len(query) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := session.Query(dbutil.QueryHelper(query, map[string]string{
|
||||||
|
"username": c.Username,
|
||||||
|
"password": password,
|
||||||
|
})).Exec()
|
||||||
|
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = result.ErrorOrNil()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.rawConfig["password"] = password
|
||||||
|
return c.rawConfig, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
"github.com/gocql/gocql"
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||||
)
|
)
|
||||||
@@ -60,7 +61,7 @@ func prepareCassandraTestContainer(t *testing.T) (func(), string, int) {
|
|||||||
|
|
||||||
session, err := clusterConfig.CreateSession()
|
session, err := clusterConfig.CreateSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating session: %s", err)
|
return errwrap.Wrapf("error creating session: {{err}}", err)
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
return nil
|
return nil
|
||||||
@@ -86,16 +87,13 @@ func TestCassandra_Initialize(t *testing.T) {
|
|||||||
"protocol_version": 4,
|
"protocol_version": 4,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*Cassandra)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)
|
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !connProducer.Initialized {
|
if !db.Initialized {
|
||||||
t.Fatal("Database should be initalized")
|
t.Fatal("Database should be initalized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,7 +111,7 @@ func TestCassandra_Initialize(t *testing.T) {
|
|||||||
"protocol_version": "4",
|
"protocol_version": "4",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -134,15 +132,14 @@ func TestCassandra_CreateUser(t *testing.T) {
|
|||||||
"protocol_version": 4,
|
"protocol_version": 4,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*Cassandra)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testCassandraRole,
|
Creation: []string{testCassandraRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -175,15 +172,14 @@ func TestMyCassandra_RenewUser(t *testing.T) {
|
|||||||
"protocol_version": 4,
|
"protocol_version": 4,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*Cassandra)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testCassandraRole,
|
Creation: []string{testCassandraRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -221,15 +217,14 @@ func TestCassandra_RevokeUser(t *testing.T) {
|
|||||||
"protocol_version": 4,
|
"protocol_version": 4,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*Cassandra)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testCassandraRole,
|
Creation: []string{testCassandraRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -268,7 +263,7 @@ func testCredsExist(t testing.TB, address string, port int, username, password s
|
|||||||
|
|
||||||
session, err := clusterConfig.CreateSession()
|
session, err := clusterConfig.CreateSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating session: %s", err)
|
return errwrap.Wrapf("error creating session: {{err}}", err)
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
"github.com/gocql/gocql"
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/vault/helper/certutil"
|
"github.com/hashicorp/vault/helper/certutil"
|
||||||
"github.com/hashicorp/vault/helper/parseutil"
|
"github.com/hashicorp/vault/helper/parseutil"
|
||||||
"github.com/hashicorp/vault/helper/tlsutil"
|
"github.com/hashicorp/vault/helper/tlsutil"
|
||||||
@@ -37,6 +38,7 @@ type cassandraConnectionProducer struct {
|
|||||||
certificate string
|
certificate string
|
||||||
privateKey string
|
privateKey string
|
||||||
issuingCA string
|
issuingCA string
|
||||||
|
rawConfig map[string]interface{}
|
||||||
|
|
||||||
Initialized bool
|
Initialized bool
|
||||||
Type string
|
Type string
|
||||||
@@ -45,12 +47,19 @@ type cassandraConnectionProducer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
_, err := c.Init(ctx, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cassandraConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
|
c.rawConfig = conf
|
||||||
|
|
||||||
err := mapstructure.WeakDecode(conf, c)
|
err := mapstructure.WeakDecode(conf, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.ConnectTimeoutRaw == nil {
|
if c.ConnectTimeoutRaw == nil {
|
||||||
@@ -58,16 +67,16 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
|
|||||||
}
|
}
|
||||||
c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
|
c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid connect_timeout: %s", err)
|
return nil, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case len(c.Hosts) == 0:
|
case len(c.Hosts) == 0:
|
||||||
return fmt.Errorf("hosts cannot be empty")
|
return nil, fmt.Errorf("hosts cannot be empty")
|
||||||
case len(c.Username) == 0:
|
case len(c.Username) == 0:
|
||||||
return fmt.Errorf("username cannot be empty")
|
return nil, fmt.Errorf("username cannot be empty")
|
||||||
case len(c.Password) == 0:
|
case len(c.Password) == 0:
|
||||||
return fmt.Errorf("password cannot be empty")
|
return nil, fmt.Errorf("password cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
var certBundle *certutil.CertBundle
|
var certBundle *certutil.CertBundle
|
||||||
@@ -76,11 +85,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
|
|||||||
case len(c.PemJSON) != 0:
|
case len(c.PemJSON) != 0:
|
||||||
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
|
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err)
|
return nil, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err)
|
||||||
}
|
}
|
||||||
certBundle, err = parsedCertBundle.ToCertBundle()
|
certBundle, err = parsedCertBundle.ToCertBundle()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error marshaling PEM information: %s", err)
|
return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
|
||||||
}
|
}
|
||||||
c.certificate = certBundle.Certificate
|
c.certificate = certBundle.Certificate
|
||||||
c.privateKey = certBundle.PrivateKey
|
c.privateKey = certBundle.PrivateKey
|
||||||
@@ -90,11 +99,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
|
|||||||
case len(c.PemBundle) != 0:
|
case len(c.PemBundle) != 0:
|
||||||
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
|
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error parsing the given PEM information: %s", err)
|
return nil, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err)
|
||||||
}
|
}
|
||||||
certBundle, err = parsedCertBundle.ToCertBundle()
|
certBundle, err = parsedCertBundle.ToCertBundle()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error marshaling PEM information: %s", err)
|
return nil, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
|
||||||
}
|
}
|
||||||
c.certificate = certBundle.Certificate
|
c.certificate = certBundle.Certificate
|
||||||
c.privateKey = certBundle.PrivateKey
|
c.privateKey = certBundle.PrivateKey
|
||||||
@@ -108,11 +117,11 @@ func (c *cassandraConnectionProducer) Initialize(ctx context.Context, conf map[s
|
|||||||
|
|
||||||
if verifyConnection {
|
if verifyConnection {
|
||||||
if _, err := c.Connection(ctx); err != nil {
|
if _, err := c.Connection(ctx); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
func (c *cassandraConnectionProducer) Connection(_ context.Context) (interface{}, error) {
|
||||||
@@ -186,12 +195,12 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
|||||||
|
|
||||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||||
if err != nil || tlsConfig == nil {
|
if err != nil || tlsConfig == nil {
|
||||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
|
||||||
}
|
}
|
||||||
tlsConfig.InsecureSkipVerify = c.InsecureTLS
|
tlsConfig.InsecureSkipVerify = c.InsecureTLS
|
||||||
|
|
||||||
@@ -215,7 +224,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
|||||||
|
|
||||||
session, err := clusterConfig.CreateSession()
|
session, err := clusterConfig.CreateSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating session: %s", err)
|
return nil, errwrap.Wrapf("error creating session: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set consistency
|
// Set consistency
|
||||||
@@ -231,8 +240,16 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
|||||||
// Verify the info
|
// Verify the info
|
||||||
err = session.Query(`LIST ALL`).Exec()
|
err = session.Query(`LIST ALL`).Exec()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error validating connection info: %s", err)
|
return nil, errwrap.Wrapf("error validating connection info: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *cassandraConnectionProducer) secretValues() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
c.Password: "[password]",
|
||||||
|
c.PemBundle: "[pem_bundle]",
|
||||||
|
c.PemJSON: "[pem_json]",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -572,7 +572,7 @@ ssl_storage_port: 7001
|
|||||||
#
|
#
|
||||||
# Setting listen_address to 0.0.0.0 is always wrong.
|
# Setting listen_address to 0.0.0.0 is always wrong.
|
||||||
#
|
#
|
||||||
listen_address: 172.17.0.5
|
listen_address: 172.17.0.2
|
||||||
|
|
||||||
# Set listen_address OR listen_interface, not both. Interfaces must correspond
|
# Set listen_address OR listen_interface, not both. Interfaces must correspond
|
||||||
# to a single address, IP aliasing is not supported.
|
# to a single address, IP aliasing is not supported.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package hana
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,7 +24,7 @@ const (
|
|||||||
|
|
||||||
// HANA is an implementation of Database interface
|
// HANA is an implementation of Database interface
|
||||||
type HANA struct {
|
type HANA struct {
|
||||||
connutil.ConnectionProducer
|
*connutil.SQLConnectionProducer
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,6 +32,14 @@ var _ dbplugin.Database = &HANA{}
|
|||||||
|
|
||||||
// New implements builtinplugins.BuiltinFactory
|
// New implements builtinplugins.BuiltinFactory
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
|
db := new()
|
||||||
|
// Wrap the plugin with middleware to sanitize errors
|
||||||
|
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
|
||||||
|
|
||||||
|
return dbType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func new() *HANA {
|
||||||
connProducer := &connutil.SQLConnectionProducer{}
|
connProducer := &connutil.SQLConnectionProducer{}
|
||||||
connProducer.Type = hanaTypeName
|
connProducer.Type = hanaTypeName
|
||||||
|
|
||||||
@@ -41,12 +50,10 @@ func New() (interface{}, error) {
|
|||||||
Separator: "_",
|
Separator: "_",
|
||||||
}
|
}
|
||||||
|
|
||||||
dbType := &HANA{
|
return &HANA{
|
||||||
ConnectionProducer: connProducer,
|
SQLConnectionProducer: connProducer,
|
||||||
CredentialsProducer: credsProducer,
|
CredentialsProducer: credsProducer,
|
||||||
}
|
}
|
||||||
|
|
||||||
return dbType, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run instantiates a HANA object, and runs the RPC server for the plugin
|
// Run instantiates a HANA object, and runs the RPC server for the plugin
|
||||||
@@ -56,7 +63,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
plugins.Serve(dbType.(*HANA), apiTLSConfig)
|
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -82,13 +89,15 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
|
|||||||
h.Lock()
|
h.Lock()
|
||||||
defer h.Unlock()
|
defer h.Unlock()
|
||||||
|
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := h.getConnection(ctx)
|
db, err := h.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if statements.CreationStatements == "" {
|
if len(statements.Creation) == 0 {
|
||||||
return "", "", dbutil.ErrEmptyCreationStatement
|
return "", "", dbutil.ErrEmptyCreationStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,23 +136,25 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
|
|||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
// Execute each query
|
// Execute each query
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
for _, stmt := range statements.Creation {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
"password": password,
|
"password": password,
|
||||||
"expiration": expirationStr,
|
"expiration": expirationStr,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,6 +168,8 @@ func (h *HANA) CreateUser(ctx context.Context, statements dbplugin.Statements, u
|
|||||||
|
|
||||||
// Renewing hana user just means altering user's valid until property
|
// Renewing hana user just means altering user's valid until property
|
||||||
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
// Get connection
|
// Get connection
|
||||||
db, err := h.getConnection(ctx)
|
db, err := h.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -197,8 +210,10 @@ func (h *HANA) RenewUser(ctx context.Context, statements dbplugin.Statements, us
|
|||||||
|
|
||||||
// Revoking hana user will deactivate user and try to perform a soft drop
|
// Revoking hana user will deactivate user and try to perform a soft drop
|
||||||
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
// default revoke will be a soft drop on user
|
// default revoke will be a soft drop on user
|
||||||
if statements.RevocationStatements == "" {
|
if len(statements.Revocation) == 0 {
|
||||||
return h.revokeUserDefault(ctx, username)
|
return h.revokeUserDefault(ctx, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,30 +231,27 @@ func (h *HANA) RevokeUser(ctx context.Context, statements dbplugin.Statements, u
|
|||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
// Execute each query
|
// Execute each query
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") {
|
for _, stmt := range statements.Revocation {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit the transaction
|
return tx.Commit()
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
|
func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
|
||||||
@@ -284,3 +296,8 @@ func (h *HANA) revokeUserDefault(ctx context.Context, username string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RotateRootCredentials is not currently supported on HANA
|
||||||
|
func (h *HANA) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
|
||||||
|
return nil, errors.New("root credentaion rotation is not currently implemented in this database secrets engine")
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHANA_Initialize(t *testing.T) {
|
func TestHANA_Initialize(t *testing.T) {
|
||||||
@@ -23,16 +22,13 @@ func TestHANA_Initialize(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*HANA)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
if !db.Initialized {
|
||||||
if !connProducer.Initialized {
|
|
||||||
t.Fatal("Database should be initialized")
|
t.Fatal("Database should be initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,10 +49,8 @@ func TestHANA_CreateUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*HANA)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -73,7 +67,7 @@ func TestHANA_CreateUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testHANARole,
|
Creation: []string{testHANARole},
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Hour))
|
||||||
@@ -96,16 +90,14 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*HANA)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testHANARole,
|
Creation: []string{testHANARole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -139,7 +131,7 @@ func TestHANA_RevokeUser(t *testing.T) {
|
|||||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements.RevocationStatements = testHANADrop
|
statements.Revocation = []string{testHANADrop}
|
||||||
err = db.RevokeUser(context.Background(), statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||||
|
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
|
|
||||||
"gopkg.in/mgo.v2"
|
"gopkg.in/mgo.v2"
|
||||||
@@ -25,28 +27,43 @@ import (
|
|||||||
type mongoDBConnectionProducer struct {
|
type mongoDBConnectionProducer struct {
|
||||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||||
WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"`
|
WriteConcern string `json:"write_concern" structs:"write_concern" mapstructure:"write_concern"`
|
||||||
|
Username string `json:"username" structs:"username" mapstructure:"username"`
|
||||||
|
Password string `json:"password" structs:"password" mapstructure:"password"`
|
||||||
|
|
||||||
Initialized bool
|
Initialized bool
|
||||||
|
RawConfig map[string]interface{}
|
||||||
Type string
|
Type string
|
||||||
session *mgo.Session
|
session *mgo.Session
|
||||||
safe *mgo.Safe
|
safe *mgo.Safe
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize parses connection configuration.
|
|
||||||
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
_, err := c.Init(ctx, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize parses connection configuration.
|
||||||
|
func (c *mongoDBConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
|
c.RawConfig = conf
|
||||||
|
|
||||||
err := mapstructure.WeakDecode(conf, c)
|
err := mapstructure.WeakDecode(conf, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.ConnectionURL) == 0 {
|
if len(c.ConnectionURL) == 0 {
|
||||||
return fmt.Errorf("connection_url cannot be empty")
|
return nil, fmt.Errorf("connection_url cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
|
||||||
|
"username": c.Username,
|
||||||
|
"password": c.Password,
|
||||||
|
})
|
||||||
|
|
||||||
if c.WriteConcern != "" {
|
if c.WriteConcern != "" {
|
||||||
input := c.WriteConcern
|
input := c.WriteConcern
|
||||||
|
|
||||||
@@ -60,13 +77,13 @@ func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[str
|
|||||||
concern := &mgo.Safe{}
|
concern := &mgo.Safe{}
|
||||||
err = json.Unmarshal([]byte(input), concern)
|
err = json.Unmarshal([]byte(input), concern)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error mashalling write_concern: %s", err)
|
return nil, errwrap.Wrapf("error mashalling write_concern: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Guard against empty, non-nil mgo.Safe object; we don't want to pass that
|
// Guard against empty, non-nil mgo.Safe object; we don't want to pass that
|
||||||
// into mgo.SetSafe in Connection().
|
// into mgo.SetSafe in Connection().
|
||||||
if (mgo.Safe{} == *concern) {
|
if (mgo.Safe{} == *concern) {
|
||||||
return fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields")
|
return nil, fmt.Errorf("provided write_concern values did not map to any mgo.Safe fields")
|
||||||
}
|
}
|
||||||
c.safe = concern
|
c.safe = concern
|
||||||
}
|
}
|
||||||
@@ -77,15 +94,15 @@ func (c *mongoDBConnectionProducer) Initialize(ctx context.Context, conf map[str
|
|||||||
|
|
||||||
if verifyConnection {
|
if verifyConnection {
|
||||||
if _, err := c.Connection(ctx); err != nil {
|
if _, err := c.Connection(ctx); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.session.Ping(); err != nil {
|
if err := c.session.Ping(); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connection creates or returns an existing a database connection. If the session fails
|
// Connection creates or returns an existing a database connection. If the session fails
|
||||||
@@ -203,3 +220,9 @@ func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
|
|||||||
|
|
||||||
return &info, nil
|
return &info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *mongoDBConnectionProducer) secretValues() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
c.Password: "[password]",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package mongodb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -14,7 +15,6 @@ import (
|
|||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/plugins"
|
"github.com/hashicorp/vault/plugins"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||||
"gopkg.in/mgo.v2"
|
"gopkg.in/mgo.v2"
|
||||||
@@ -24,7 +24,7 @@ const mongoDBTypeName = "mongodb"
|
|||||||
|
|
||||||
// MongoDB is an implementation of Database interface
|
// MongoDB is an implementation of Database interface
|
||||||
type MongoDB struct {
|
type MongoDB struct {
|
||||||
connutil.ConnectionProducer
|
*mongoDBConnectionProducer
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,6 +32,12 @@ var _ dbplugin.Database = &MongoDB{}
|
|||||||
|
|
||||||
// New returns a new MongoDB instance
|
// New returns a new MongoDB instance
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
|
db := new()
|
||||||
|
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
|
||||||
|
return dbType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func new() *MongoDB {
|
||||||
connProducer := &mongoDBConnectionProducer{}
|
connProducer := &mongoDBConnectionProducer{}
|
||||||
connProducer.Type = mongoDBTypeName
|
connProducer.Type = mongoDBTypeName
|
||||||
|
|
||||||
@@ -42,11 +48,10 @@ func New() (interface{}, error) {
|
|||||||
Separator: "-",
|
Separator: "-",
|
||||||
}
|
}
|
||||||
|
|
||||||
dbType := &MongoDB{
|
return &MongoDB{
|
||||||
ConnectionProducer: connProducer,
|
mongoDBConnectionProducer: connProducer,
|
||||||
CredentialsProducer: credsProducer,
|
CredentialsProducer: credsProducer,
|
||||||
}
|
}
|
||||||
return dbType, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run instantiates a MongoDB object, and runs the RPC server for the plugin
|
// Run instantiates a MongoDB object, and runs the RPC server for the plugin
|
||||||
@@ -88,7 +93,9 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements
|
|||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
if statements.CreationStatements == "" {
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
|
if len(statements.Creation) == 0 {
|
||||||
return "", "", dbutil.ErrEmptyCreationStatement
|
return "", "", dbutil.ErrEmptyCreationStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,7 +116,7 @@ func (m *MongoDB) CreateUser(ctx context.Context, statements dbplugin.Statements
|
|||||||
|
|
||||||
// Unmarshal statements.CreationStatements into mongodbRoles
|
// Unmarshal statements.CreationStatements into mongodbRoles
|
||||||
var mongoCS mongoDBStatement
|
var mongoCS mongoDBStatement
|
||||||
err = json.Unmarshal([]byte(statements.CreationStatements), &mongoCS)
|
err = json.Unmarshal([]byte(statements.Creation[0]), &mongoCS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
@@ -158,15 +165,22 @@ func (m *MongoDB) RenewUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
// RevokeUser drops the specified user from the authentication database. If none is provided
|
// RevokeUser drops the specified user from the authentication database. If none is provided
|
||||||
// in the revocation statement, the default "admin" authentication database will be assumed.
|
// in the revocation statement, the default "admin" authentication database will be assumed.
|
||||||
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
session, err := m.getConnection(ctx)
|
session, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no revocation statements provided, pass in empty JSON
|
// If no revocation statements provided, pass in empty JSON
|
||||||
revocationStatement := statements.RevocationStatements
|
var revocationStatement string
|
||||||
if revocationStatement == "" {
|
switch len(statements.Revocation) {
|
||||||
|
case 0:
|
||||||
revocationStatement = `{}`
|
revocationStatement = `{}`
|
||||||
|
case 1:
|
||||||
|
revocationStatement = statements.Revocation[0]
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("expected 0 or 1 revocation statements, got %d", len(statements.Revocation))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unmarshal revocation statements into mongodbRoles
|
// Unmarshal revocation statements into mongodbRoles
|
||||||
@@ -186,7 +200,7 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements
|
|||||||
switch {
|
switch {
|
||||||
case err == nil, err == mgo.ErrNotFound:
|
case err == nil, err == mgo.ErrNotFound:
|
||||||
case err == io.EOF, strings.Contains(err.Error(), "EOF"):
|
case err == io.EOF, strings.Contains(err.Error(), "EOF"):
|
||||||
if err := m.ConnectionProducer.Close(); err != nil {
|
if err := m.Close(); err != nil {
|
||||||
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
return errwrap.Wrapf("error closing EOF'd mongo connection: {{err}}", err)
|
||||||
}
|
}
|
||||||
session, err := m.getConnection(ctx)
|
session, err := m.getConnection(ctx)
|
||||||
@@ -203,3 +217,8 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RotateRootCredentials is not currently supported on MongoDB
|
||||||
|
func (m *MongoDB) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
|
||||||
|
return nil, errors.New("root credentaion rotation is not currently implemented in this database secrets engine")
|
||||||
|
}
|
||||||
|
|||||||
@@ -73,19 +73,13 @@ func TestMongoDB_Initialize(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, err := New()
|
db := new()
|
||||||
if err != nil {
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
db := dbRaw.(*MongoDB)
|
|
||||||
connProducer := db.ConnectionProducer.(*mongoDBConnectionProducer)
|
|
||||||
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !connProducer.Initialized {
|
if !db.Initialized {
|
||||||
t.Fatal("Database should be initialized")
|
t.Fatal("Database should be initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,18 +97,14 @@ func TestMongoDB_CreateUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, err := New()
|
db := new()
|
||||||
if err != nil {
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
db := dbRaw.(*MongoDB)
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testMongoDBRole,
|
Creation: []string{testMongoDBRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -141,18 +131,14 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
|
|||||||
"write_concern": testMongoDBWriteConcern,
|
"write_concern": testMongoDBWriteConcern,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, err := New()
|
db := new()
|
||||||
if err != nil {
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
db := dbRaw.(*MongoDB)
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testMongoDBRole,
|
Creation: []string{testMongoDBRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -178,18 +164,14 @@ func TestMongoDB_RevokeUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, err := New()
|
db := new()
|
||||||
if err != nil {
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
t.Fatalf("err: %s", err)
|
|
||||||
}
|
|
||||||
db := dbRaw.(*MongoDB)
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testMongoDBRole,
|
Creation: []string{testMongoDBRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
|
|||||||
@@ -3,11 +3,13 @@ package mssql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/denisenkom/go-mssqldb"
|
_ "github.com/denisenkom/go-mssqldb"
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/helper/strutil"
|
"github.com/hashicorp/vault/helper/strutil"
|
||||||
@@ -23,11 +25,19 @@ var _ dbplugin.Database = &MSSQL{}
|
|||||||
|
|
||||||
// MSSQL is an implementation of Database interface
|
// MSSQL is an implementation of Database interface
|
||||||
type MSSQL struct {
|
type MSSQL struct {
|
||||||
connutil.ConnectionProducer
|
*connutil.SQLConnectionProducer
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
|
db := new()
|
||||||
|
// Wrap the plugin with middleware to sanitize errors
|
||||||
|
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
|
||||||
|
|
||||||
|
return dbType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func new() *MSSQL {
|
||||||
connProducer := &connutil.SQLConnectionProducer{}
|
connProducer := &connutil.SQLConnectionProducer{}
|
||||||
connProducer.Type = msSQLTypeName
|
connProducer.Type = msSQLTypeName
|
||||||
|
|
||||||
@@ -38,12 +48,10 @@ func New() (interface{}, error) {
|
|||||||
Separator: "-",
|
Separator: "-",
|
||||||
}
|
}
|
||||||
|
|
||||||
dbType := &MSSQL{
|
return &MSSQL{
|
||||||
ConnectionProducer: connProducer,
|
SQLConnectionProducer: connProducer,
|
||||||
CredentialsProducer: credsProducer,
|
CredentialsProducer: credsProducer,
|
||||||
}
|
}
|
||||||
|
|
||||||
return dbType, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run instantiates a MSSQL object, and runs the RPC server for the plugin
|
// Run instantiates a MSSQL object, and runs the RPC server for the plugin
|
||||||
@@ -53,7 +61,7 @@ func Run(apiTLSConfig *api.TLSConfig) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
plugins.Serve(dbType.(*MSSQL), apiTLSConfig)
|
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -79,13 +87,15 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := m.getConnection(ctx)
|
db, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if statements.CreationStatements == "" {
|
if len(statements.Creation) == 0 {
|
||||||
return "", "", dbutil.ErrEmptyCreationStatement
|
return "", "", dbutil.ErrEmptyCreationStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,23 +122,25 @@ func (m *MSSQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
// Execute each query
|
// Execute each query
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
for _, stmt := range statements.Creation {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
"password": password,
|
"password": password,
|
||||||
"expiration": expirationStr,
|
"expiration": expirationStr,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,7 +162,9 @@ func (m *MSSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, u
|
|||||||
// then kill pending connections from that user, and finally drop the user and login from the
|
// then kill pending connections from that user, and finally drop the user and login from the
|
||||||
// database instance.
|
// database instance.
|
||||||
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
if statements.RevocationStatements == "" {
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
|
if len(statements.Revocation) == 0 {
|
||||||
return m.revokeUserDefault(ctx, username)
|
return m.revokeUserDefault(ctx, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,21 +182,23 @@ func (m *MSSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
// Execute each query
|
// Execute each query
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") {
|
for _, stmt := range statements.Revocation {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,10 +299,10 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
|
|||||||
|
|
||||||
// can't drop if not all database users are dropped
|
// can't drop if not all database users are dropped
|
||||||
if rows.Err() != nil {
|
if rows.Err() != nil {
|
||||||
return fmt.Errorf("could not generate sql statements for all rows: %s", rows.Err())
|
return errwrap.Wrapf("could not generate sql statements for all rows: {{err}}", rows.Err())
|
||||||
}
|
}
|
||||||
if lastStmtError != nil {
|
if lastStmtError != nil {
|
||||||
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
|
return errwrap.Wrapf("could not perform all sql statements: {{err}}", lastStmtError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop this login
|
// Drop this login
|
||||||
@@ -302,6 +318,70 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MSSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
|
||||||
|
if len(m.Username) == 0 || len(m.Password) == 0 {
|
||||||
|
return nil, errors.New("username and password are required to rotate")
|
||||||
|
}
|
||||||
|
|
||||||
|
rotateStatents := statements
|
||||||
|
if len(rotateStatents) == 0 {
|
||||||
|
rotateStatents = []string{rotateRootCredentialsSQL}
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := m.getConnection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
|
password, err := m.GeneratePassword()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stmt := range rotateStatents {
|
||||||
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if len(query) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
|
"username": m.Username,
|
||||||
|
"password": password,
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer stmt.Close()
|
||||||
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RawConfig["password"] = password
|
||||||
|
return m.RawConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
const dropUserSQL = `
|
const dropUserSQL = `
|
||||||
USE [%s]
|
USE [%s]
|
||||||
IF EXISTS
|
IF EXISTS
|
||||||
@@ -322,3 +402,7 @@ BEGIN
|
|||||||
DROP LOGIN [%s]
|
DROP LOGIN [%s]
|
||||||
END
|
END
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const rotateRootCredentialsSQL = `
|
||||||
|
ALTER LOGIN [%s] WITH PASSWORD = '%s'
|
||||||
|
`
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -28,16 +27,13 @@ func TestMSSQL_Initialize(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*MSSQL)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
if !db.Initialized {
|
||||||
if !connProducer.Initialized {
|
|
||||||
t.Fatal("Database should be initalized")
|
t.Fatal("Database should be initalized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +48,7 @@ func TestMSSQL_Initialize(t *testing.T) {
|
|||||||
"max_open_connections": "5",
|
"max_open_connections": "5",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -68,9 +64,8 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*MSSQL)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -87,7 +82,7 @@ func TestMSSQL_CreateUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testMSSQLRole,
|
Creation: []string{testMSSQLRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
@@ -110,15 +105,14 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*MSSQL)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testMSSQLRole,
|
Creation: []string{testMSSQLRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -155,7 +149,7 @@ func TestMSSQL_RevokeUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test custom revoke statement
|
// Test custom revoke statement
|
||||||
statements.RevocationStatements = testMSSQLDrop
|
statements.Revocation = []string{testMSSQLDrop}
|
||||||
err = db.RevokeUser(context.Background(), statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package mysql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,6 +22,11 @@ const (
|
|||||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||||
DROP USER '{{name}}'@'%'
|
DROP USER '{{name}}'@'%'
|
||||||
`
|
`
|
||||||
|
|
||||||
|
defaultMySQLRotateRootCredentialsSQL = `
|
||||||
|
ALTER USER '{{username}}'@'%' IDENTIFIED BY '{{password}}';
|
||||||
|
`
|
||||||
|
|
||||||
mySQLTypeName = "mysql"
|
mySQLTypeName = "mysql"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,32 +40,38 @@ var (
|
|||||||
var _ dbplugin.Database = &MySQL{}
|
var _ dbplugin.Database = &MySQL{}
|
||||||
|
|
||||||
type MySQL struct {
|
type MySQL struct {
|
||||||
connutil.ConnectionProducer
|
*connutil.SQLConnectionProducer
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
// New implements builtinplugins.BuiltinFactory
|
// New implements builtinplugins.BuiltinFactory
|
||||||
func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, error) {
|
func New(displayNameLen, roleNameLen, usernameLen int) func() (interface{}, error) {
|
||||||
return func() (interface{}, error) {
|
return func() (interface{}, error) {
|
||||||
connProducer := &connutil.SQLConnectionProducer{}
|
db := new(displayNameLen, roleNameLen, usernameLen)
|
||||||
connProducer.Type = mySQLTypeName
|
// Wrap the plugin with middleware to sanitize errors
|
||||||
|
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
|
||||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
|
||||||
DisplayNameLen: displayNameLen,
|
|
||||||
RoleNameLen: roleNameLen,
|
|
||||||
UsernameLen: usernameLen,
|
|
||||||
Separator: "-",
|
|
||||||
}
|
|
||||||
|
|
||||||
dbType := &MySQL{
|
|
||||||
ConnectionProducer: connProducer,
|
|
||||||
CredentialsProducer: credsProducer,
|
|
||||||
}
|
|
||||||
|
|
||||||
return dbType, nil
|
return dbType, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func new(displayNameLen, roleNameLen, usernameLen int) *MySQL {
|
||||||
|
connProducer := &connutil.SQLConnectionProducer{}
|
||||||
|
connProducer.Type = mySQLTypeName
|
||||||
|
|
||||||
|
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||||
|
DisplayNameLen: displayNameLen,
|
||||||
|
RoleNameLen: roleNameLen,
|
||||||
|
UsernameLen: usernameLen,
|
||||||
|
Separator: "-",
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MySQL{
|
||||||
|
SQLConnectionProducer: connProducer,
|
||||||
|
CredentialsProducer: credsProducer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Run instantiates a MySQL object, and runs the RPC server for the plugin
|
// Run instantiates a MySQL object, and runs the RPC server for the plugin
|
||||||
func Run(apiTLSConfig *api.TLSConfig) error {
|
func Run(apiTLSConfig *api.TLSConfig) error {
|
||||||
return runCommon(false, apiTLSConfig)
|
return runCommon(false, apiTLSConfig)
|
||||||
@@ -82,7 +94,7 @@ func runCommon(legacy bool, apiTLSConfig *api.TLSConfig) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
plugins.Serve(dbType.(*MySQL), apiTLSConfig)
|
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -105,13 +117,15 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := m.getConnection(ctx)
|
db, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if statements.CreationStatements == "" {
|
if len(statements.Creation) == 0 {
|
||||||
return "", "", dbutil.ErrEmptyCreationStatement
|
return "", "", dbutil.ErrEmptyCreationStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,38 +152,40 @@ func (m *MySQL) CreateUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
// Execute each query
|
// Execute each query
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
for _, stmt := range statements.Creation {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
|
||||||
query = dbutil.QueryHelper(query, map[string]string{
|
|
||||||
"name": username,
|
|
||||||
"password": password,
|
|
||||||
"expiration": expirationStr,
|
|
||||||
})
|
|
||||||
|
|
||||||
stmt, err := tx.PrepareContext(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
// If the error code we get back is Error 1295: This command is not
|
|
||||||
// supported in the prepared statement protocol yet, we will execute
|
|
||||||
// the statement without preparing it. This allows the caller to
|
|
||||||
// manually prepare statements, as well as run other not yet
|
|
||||||
// prepare supported commands. If there is no error when running we
|
|
||||||
// will continue to the next statement.
|
|
||||||
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
|
|
||||||
_, err = tx.ExecContext(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
query = dbutil.QueryHelper(query, map[string]string{
|
||||||
|
"name": username,
|
||||||
|
"password": password,
|
||||||
|
"expiration": expirationStr,
|
||||||
|
})
|
||||||
|
|
||||||
return "", "", err
|
stmt, err := tx.PrepareContext(ctx, query)
|
||||||
}
|
if err != nil {
|
||||||
defer stmt.Close()
|
// If the error code we get back is Error 1295: This command is not
|
||||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
// supported in the prepared statement protocol yet, we will execute
|
||||||
return "", "", err
|
// the statement without preparing it. This allows the caller to
|
||||||
|
// manually prepare statements, as well as run other not yet
|
||||||
|
// prepare supported commands. If there is no error when running we
|
||||||
|
// will continue to the next statement.
|
||||||
|
if e, ok := err.(*stdmysql.MySQLError); ok && e.Number == 1295 {
|
||||||
|
_, err = tx.ExecContext(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,16 +207,18 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
// Get the connection
|
// Get the connection
|
||||||
db, err := m.getConnection(ctx)
|
db, err := m.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
revocationStmts := statements.RevocationStatements
|
revocationStmts := statements.Revocation
|
||||||
// Use a default SQL statement for revocation if one cannot be fetched from the role
|
// Use a default SQL statement for revocation if one cannot be fetched from the role
|
||||||
if revocationStmts == "" {
|
if len(revocationStmts) == 0 {
|
||||||
revocationStmts = defaultMysqlRevocationStmts
|
revocationStmts = []string{defaultMysqlRevocationStmts}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
@@ -210,21 +228,22 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
for _, stmt := range revocationStmts {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// This is not a prepared statement because not all commands are supported
|
// This is not a prepared statement because not all commands are supported
|
||||||
// 1295: This command is not supported in the prepared statement protocol yet
|
// 1295: This command is not supported in the prepared statement protocol yet
|
||||||
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
||||||
query = strings.Replace(query, "{{name}}", username, -1)
|
query = strings.Replace(query, "{{name}}", username, -1)
|
||||||
_, err = tx.ExecContext(ctx, query)
|
_, err = tx.ExecContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit the transaction
|
// Commit the transaction
|
||||||
@@ -234,3 +253,67 @@ func (m *MySQL) RevokeUser(ctx context.Context, statements dbplugin.Statements,
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MySQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
|
||||||
|
if len(m.Username) == 0 || len(m.Password) == 0 {
|
||||||
|
return nil, errors.New("username and password are required to rotate")
|
||||||
|
}
|
||||||
|
|
||||||
|
rotateStatents := statements
|
||||||
|
if len(rotateStatents) == 0 {
|
||||||
|
rotateStatents = []string{defaultMySQLRotateRootCredentialsSQL}
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := m.getConnection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
|
password, err := m.GeneratePassword()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stmt := range rotateStatents {
|
||||||
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if len(query) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
|
"username": m.Username,
|
||||||
|
"password": password,
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer stmt.Close()
|
||||||
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RawConfig["password"] = password
|
||||||
|
return m.RawConfig, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -104,17 +104,13 @@ func TestMySQL_Initialize(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
f := New(MetadataLen, MetadataLen, UsernameLen)
|
db := new(MetadataLen, MetadataLen, UsernameLen)
|
||||||
dbRaw, _ := f()
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
db := dbRaw.(*MySQL)
|
|
||||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !connProducer.Initialized {
|
if !db.Initialized {
|
||||||
t.Fatal("Database should be initalized")
|
t.Fatal("Database should be initalized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +125,7 @@ func TestMySQL_Initialize(t *testing.T) {
|
|||||||
"max_open_connections": "5",
|
"max_open_connections": "5",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -143,11 +139,8 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
f := New(MetadataLen, MetadataLen, UsernameLen)
|
db := new(MetadataLen, MetadataLen, UsernameLen)
|
||||||
dbRaw, _ := f()
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
db := dbRaw.(*MySQL)
|
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -164,7 +157,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testMySQLRoleWildCard,
|
Creation: []string{testMySQLRoleWildCard},
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
@@ -187,7 +180,7 @@ func TestMySQL_CreateUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test with a manually prepare statement
|
// Test with a manually prepare statement
|
||||||
statements.CreationStatements = testMySQLRolePreparedStmt
|
statements.Creation = []string{testMySQLRolePreparedStmt}
|
||||||
|
|
||||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -208,11 +201,8 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
f := New(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen)
|
db := new(credsutil.NoneLength, LegacyMetadataLen, LegacyUsernameLen)
|
||||||
dbRaw, _ := f()
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
db := dbRaw.(*MySQL)
|
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -229,7 +219,7 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testMySQLRoleWildCard,
|
Creation: []string{testMySQLRoleWildCard},
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
@@ -252,6 +242,42 @@ func TestMySQL_CreateUser_Legacy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMySQL_RotateRootCredentials(t *testing.T) {
|
||||||
|
cleanup, connURL := prepareMySQLTestContainer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
connURL = strings.Replace(connURL, "root:secret", `{{username}}:{{password}}`, -1)
|
||||||
|
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"connection_url": connURL,
|
||||||
|
"username": "root",
|
||||||
|
"password": "secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
db := new(MetadataLen, MetadataLen, UsernameLen)
|
||||||
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !db.Initialized {
|
||||||
|
t.Fatal("Database should be initalized")
|
||||||
|
}
|
||||||
|
|
||||||
|
newConf, err := db.RotateRootCredentials(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if newConf["password"] == "secret" {
|
||||||
|
t.Fatal("password was not updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMySQL_RevokeUser(t *testing.T) {
|
func TestMySQL_RevokeUser(t *testing.T) {
|
||||||
cleanup, connURL := prepareMySQLTestContainer(t)
|
cleanup, connURL := prepareMySQLTestContainer(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -260,17 +286,14 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
f := New(MetadataLen, MetadataLen, UsernameLen)
|
db := new(MetadataLen, MetadataLen, UsernameLen)
|
||||||
dbRaw, _ := f()
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
db := dbRaw.(*MySQL)
|
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testMySQLRoleWildCard,
|
Creation: []string{testMySQLRoleWildCard},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -297,7 +320,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||||||
t.Fatal("Credentials were not revoked")
|
t.Fatal("Credentials were not revoked")
|
||||||
}
|
}
|
||||||
|
|
||||||
statements.CreationStatements = testMySQLRoleWildCard
|
statements.Creation = []string{testMySQLRoleWildCard}
|
||||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
@@ -308,7 +331,7 @@ func TestMySQL_RevokeUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test custom revoke statements
|
// Test custom revoke statements
|
||||||
statements.RevocationStatements = testMySQLRevocationSQL
|
statements.Revocation = []string{testMySQLRevocationSQL}
|
||||||
err = db.RevokeUser(context.Background(), statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ package postgresql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/helper/strutil"
|
"github.com/hashicorp/vault/helper/strutil"
|
||||||
@@ -15,13 +17,15 @@ import (
|
|||||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
_ "github.com/lib/pq"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
postgreSQLTypeName string = "postgres"
|
postgreSQLTypeName = "postgres"
|
||||||
defaultPostgresRenewSQL = `
|
defaultPostgresRenewSQL = `
|
||||||
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
|
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
|
||||||
|
`
|
||||||
|
defaultPostgresRotateRootCredentialsSQL = `
|
||||||
|
ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,6 +33,13 @@ var _ dbplugin.Database = &PostgreSQL{}
|
|||||||
|
|
||||||
// New implements builtinplugins.BuiltinFactory
|
// New implements builtinplugins.BuiltinFactory
|
||||||
func New() (interface{}, error) {
|
func New() (interface{}, error) {
|
||||||
|
db := new()
|
||||||
|
// Wrap the plugin with middleware to sanitize errors
|
||||||
|
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues)
|
||||||
|
return dbType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func new() *PostgreSQL {
|
||||||
connProducer := &connutil.SQLConnectionProducer{}
|
connProducer := &connutil.SQLConnectionProducer{}
|
||||||
connProducer.Type = postgreSQLTypeName
|
connProducer.Type = postgreSQLTypeName
|
||||||
|
|
||||||
@@ -39,12 +50,12 @@ func New() (interface{}, error) {
|
|||||||
Separator: "-",
|
Separator: "-",
|
||||||
}
|
}
|
||||||
|
|
||||||
dbType := &PostgreSQL{
|
db := &PostgreSQL{
|
||||||
ConnectionProducer: connProducer,
|
SQLConnectionProducer: connProducer,
|
||||||
CredentialsProducer: credsProducer,
|
CredentialsProducer: credsProducer,
|
||||||
}
|
}
|
||||||
|
|
||||||
return dbType, nil
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin
|
// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin
|
||||||
@@ -54,13 +65,13 @@ func Run(apiTLSConfig *api.TLSConfig) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig)
|
plugins.Serve(dbType.(dbplugin.Database), apiTLSConfig)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type PostgreSQL struct {
|
type PostgreSQL struct {
|
||||||
connutil.ConnectionProducer
|
*connutil.SQLConnectionProducer
|
||||||
credsutil.CredentialsProducer
|
credsutil.CredentialsProducer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,7 +89,9 @@ func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
|
||||||
if statements.CreationStatements == "" {
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
|
if len(statements.Creation) == 0 {
|
||||||
return "", "", dbutil.ErrEmptyCreationStatement
|
return "", "", dbutil.ErrEmptyCreationStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +118,6 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
|
|||||||
db, err := p.getConnection(ctx)
|
db, err := p.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a transaction
|
// Start a transaction
|
||||||
@@ -120,25 +132,25 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme
|
|||||||
// Return the secret
|
// Return the secret
|
||||||
|
|
||||||
// Execute each query
|
// Execute each query
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
for _, stmt := range statements.Creation {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
|
||||||
"name": username,
|
|
||||||
"password": password,
|
|
||||||
"expiration": expirationStr,
|
|
||||||
}))
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
|
||||||
return "", "", err
|
|
||||||
|
|
||||||
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
|
"name": username,
|
||||||
|
"password": password,
|
||||||
|
"expiration": expirationStr,
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,9 +167,11 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
|
|||||||
p.Lock()
|
p.Lock()
|
||||||
defer p.Unlock()
|
defer p.Unlock()
|
||||||
|
|
||||||
renewStmts := statements.RenewStatements
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
if renewStmts == "" {
|
|
||||||
renewStmts = defaultPostgresRenewSQL
|
renewStmts := statements.Renewal
|
||||||
|
if len(renewStmts) == 0 {
|
||||||
|
renewStmts = []string{defaultPostgresRenewSQL}
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := p.getConnection(ctx)
|
db, err := p.getConnection(ctx)
|
||||||
@@ -178,30 +192,28 @@ func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statemen
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(renewStmts, ";") {
|
for _, stmt := range renewStmts {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
}
|
||||||
"name": username,
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"expiration": expirationStr,
|
"name": username,
|
||||||
}))
|
"expiration": expirationStr,
|
||||||
if err != nil {
|
}))
|
||||||
return err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
return tx.Commit()
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
|
||||||
@@ -209,14 +221,16 @@ func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Stateme
|
|||||||
p.Lock()
|
p.Lock()
|
||||||
defer p.Unlock()
|
defer p.Unlock()
|
||||||
|
|
||||||
if statements.RevocationStatements == "" {
|
statements = dbutil.StatementCompatibilityHelper(statements)
|
||||||
|
|
||||||
|
if len(statements.Revocation) == 0 {
|
||||||
return p.defaultRevokeUser(ctx, username)
|
return p.defaultRevokeUser(ctx, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.customRevokeUser(ctx, username, statements.RevocationStatements)
|
return p.customRevokeUser(ctx, username, statements.Revocation)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationStmts string) error {
|
func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error {
|
||||||
db, err := p.getConnection(ctx)
|
db, err := p.getConnection(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -230,30 +244,28 @@ func (p *PostgreSQL) customRevokeUser(ctx context.Context, username, revocationS
|
|||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
for _, stmt := range revocationStmts {
|
||||||
query = strings.TrimSpace(query)
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
if len(query) == 0 {
|
query = strings.TrimSpace(query)
|
||||||
continue
|
if len(query) == 0 {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
"name": username,
|
"name": username,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
|
|
||||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
return tx.Commit()
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
|
func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error {
|
||||||
@@ -354,10 +366,10 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
|
|||||||
|
|
||||||
// can't drop if not all privileges are revoked
|
// can't drop if not all privileges are revoked
|
||||||
if rows.Err() != nil {
|
if rows.Err() != nil {
|
||||||
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
|
return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err())
|
||||||
}
|
}
|
||||||
if lastStmtError != nil {
|
if lastStmtError != nil {
|
||||||
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
|
return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop this user
|
// Drop this user
|
||||||
@@ -373,3 +385,68 @@ func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) err
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
|
||||||
|
p.Lock()
|
||||||
|
defer p.Unlock()
|
||||||
|
|
||||||
|
if len(p.Username) == 0 || len(p.Password) == 0 {
|
||||||
|
return nil, errors.New("username and password are required to rotate")
|
||||||
|
}
|
||||||
|
|
||||||
|
rotateStatents := statements
|
||||||
|
if len(rotateStatents) == 0 {
|
||||||
|
rotateStatents = []string{defaultPostgresRotateRootCredentialsSQL}
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := p.getConnection(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
|
password, err := p.GeneratePassword()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stmt := range rotateStatents {
|
||||||
|
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if len(query) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stmt, err := tx.PrepareContext(ctx, dbutil.QueryHelper(query, map[string]string{
|
||||||
|
"username": p.Username,
|
||||||
|
"password": password,
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer stmt.Close()
|
||||||
|
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the database connection to ensure no new connections come in
|
||||||
|
if err := db.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.RawConfig["password"] = password
|
||||||
|
return p.RawConfig, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
|
||||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -68,17 +67,13 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
|||||||
"max_open_connections": 5,
|
"max_open_connections": 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*PostgreSQL)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
|
|
||||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
|
||||||
|
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !connProducer.Initialized {
|
if !db.Initialized {
|
||||||
t.Fatal("Database should be initalized")
|
t.Fatal("Database should be initalized")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,7 +88,7 @@ func TestPostgreSQL_Initialize(t *testing.T) {
|
|||||||
"max_open_connections": "5",
|
"max_open_connections": "5",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.Initialize(context.Background(), connectionDetails, true)
|
_, err = db.Init(context.Background(), connectionDetails, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -108,9 +103,8 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*PostgreSQL)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
@@ -127,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testPostgresRole,
|
Creation: []string{testPostgresRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
@@ -139,7 +133,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
|
|||||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements.CreationStatements = testPostgresReadOnlyRole
|
statements.Creation = []string{testPostgresReadOnlyRole}
|
||||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
@@ -161,15 +155,14 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*PostgreSQL)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testPostgresRole,
|
Creation: []string{testPostgresRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -197,7 +190,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||||
}
|
}
|
||||||
statements.RenewStatements = defaultPostgresRenewSQL
|
statements.Renewal = []string{defaultPostgresRenewSQL}
|
||||||
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
@@ -221,6 +214,46 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPostgreSQL_RotateRootCredentials(t *testing.T) {
|
||||||
|
cleanup, connURL := preparePostgresTestContainer(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
connURL = strings.Replace(connURL, "postgres:secret", `{{username}}:{{password}}`, -1)
|
||||||
|
|
||||||
|
connectionDetails := map[string]interface{}{
|
||||||
|
"connection_url": connURL,
|
||||||
|
"max_open_connections": 5,
|
||||||
|
"username": "postgres",
|
||||||
|
"password": "secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
db := new()
|
||||||
|
|
||||||
|
connProducer := db.SQLConnectionProducer
|
||||||
|
|
||||||
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !connProducer.Initialized {
|
||||||
|
t.Fatal("Database should be initalized")
|
||||||
|
}
|
||||||
|
|
||||||
|
newConf, err := db.RotateRootCredentials(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if newConf["password"] == "secret" {
|
||||||
|
t.Fatal("password was not updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPostgreSQL_RevokeUser(t *testing.T) {
|
func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||||
cleanup, connURL := preparePostgresTestContainer(t)
|
cleanup, connURL := preparePostgresTestContainer(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -229,15 +262,14 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||||||
"connection_url": connURL,
|
"connection_url": connURL,
|
||||||
}
|
}
|
||||||
|
|
||||||
dbRaw, _ := New()
|
db := new()
|
||||||
db := dbRaw.(*PostgreSQL)
|
_, err := db.Init(context.Background(), connectionDetails, true)
|
||||||
err := db.Initialize(context.Background(), connectionDetails, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statements := dbplugin.Statements{
|
statements := dbplugin.Statements{
|
||||||
CreationStatements: testPostgresRole,
|
Creation: []string{testPostgresRole},
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameConfig := dbplugin.UsernameConfig{
|
usernameConfig := dbplugin.UsernameConfig{
|
||||||
@@ -274,7 +306,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test custom revoke statements
|
// Test custom revoke statements
|
||||||
statements.RevocationStatements = defaultPostgresRevocationSQL
|
statements.Revocation = []string{defaultPostgresRevocationSQL}
|
||||||
err = db.RevokeUser(context.Background(), statements, username)
|
err = db.RevokeUser(context.Background(), statements, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("err: %s", err)
|
t.Fatalf("err: %s", err)
|
||||||
@@ -286,6 +318,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||||
|
t.Helper()
|
||||||
// Log in with the new creds
|
// Log in with the new creds
|
||||||
connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1)
|
connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1)
|
||||||
db, err := sql.Open("postgres", connURL)
|
db, err := sql.Open("postgres", connURL)
|
||||||
|
|||||||
@@ -15,8 +15,11 @@ var (
|
|||||||
// connections and is used in all the builtin database types.
|
// connections and is used in all the builtin database types.
|
||||||
type ConnectionProducer interface {
|
type ConnectionProducer interface {
|
||||||
Close() error
|
Close() error
|
||||||
Initialize(context.Context, map[string]interface{}, bool) error
|
Init(context.Context, map[string]interface{}, bool) (map[string]interface{}, error)
|
||||||
Connection(context.Context) (interface{}, error)
|
Connection(context.Context) (interface{}, error)
|
||||||
|
|
||||||
sync.Locker
|
sync.Locker
|
||||||
|
|
||||||
|
// DEPRECATED, will be removed in 0.12
|
||||||
|
Initialize(context.Context, map[string]interface{}, bool) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,18 +8,25 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/errwrap"
|
||||||
"github.com/hashicorp/vault/helper/parseutil"
|
"github.com/hashicorp/vault/helper/parseutil"
|
||||||
|
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ ConnectionProducer = &SQLConnectionProducer{}
|
||||||
|
|
||||||
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
||||||
type SQLConnectionProducer struct {
|
type SQLConnectionProducer struct {
|
||||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
|
||||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
|
||||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
|
||||||
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
|
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
|
||||||
|
Username string `json:"username" mapstructure:"username" structs:"username"`
|
||||||
|
Password string `json:"password" mapstructure:"password" structs:"password"`
|
||||||
|
|
||||||
Type string
|
Type string
|
||||||
|
RawConfig map[string]interface{}
|
||||||
maxConnectionLifetime time.Duration
|
maxConnectionLifetime time.Duration
|
||||||
Initialized bool
|
Initialized bool
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
@@ -27,18 +34,30 @@ type SQLConnectionProducer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
||||||
|
_, err := c.Init(ctx, conf, verifyConnection)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
|
||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
err := mapstructure.WeakDecode(conf, c)
|
c.RawConfig = conf
|
||||||
|
|
||||||
|
err := mapstructure.WeakDecode(conf, &c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.ConnectionURL) == 0 {
|
if len(c.ConnectionURL) == 0 {
|
||||||
return fmt.Errorf("connection_url cannot be empty")
|
return nil, fmt.Errorf("connection_url cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
|
||||||
|
"username": c.Username,
|
||||||
|
"password": c.Password,
|
||||||
|
})
|
||||||
|
|
||||||
if c.MaxOpenConnections == 0 {
|
if c.MaxOpenConnections == 0 {
|
||||||
c.MaxOpenConnections = 2
|
c.MaxOpenConnections = 2
|
||||||
}
|
}
|
||||||
@@ -55,7 +74,7 @@ func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]
|
|||||||
|
|
||||||
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
|
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid max_connection_lifetime: %s", err)
|
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set initialized to true at this point since all fields are set,
|
// Set initialized to true at this point since all fields are set,
|
||||||
@@ -64,15 +83,15 @@ func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]
|
|||||||
|
|
||||||
if verifyConnection {
|
if verifyConnection {
|
||||||
if _, err := c.Connection(ctx); err != nil {
|
if _, err := c.Connection(ctx); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.db.PingContext(ctx); err != nil {
|
if err := c.db.PingContext(ctx); err != nil {
|
||||||
return fmt.Errorf("error verifying connection: %s", err)
|
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return c.RawConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
||||||
@@ -123,6 +142,12 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
|
|||||||
return c.db, nil
|
return c.db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *SQLConnectionProducer) SecretValues() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
c.Password: "[password]",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Close attempts to close the connection
|
// Close attempts to close the connection
|
||||||
func (c *SQLConnectionProducer) Close() error {
|
func (c *SQLConnectionProducer) Close() error {
|
||||||
// Grab the write lock
|
// Grab the write lock
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -18,3 +20,33 @@ func QueryHelper(tpl string, data map[string]string) string {
|
|||||||
|
|
||||||
return tpl
|
return tpl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StatementCompatibilityHelper will populate the statements fields to support
|
||||||
|
// compatibility
|
||||||
|
func StatementCompatibilityHelper(statements dbplugin.Statements) dbplugin.Statements {
|
||||||
|
switch {
|
||||||
|
case len(statements.Creation) > 0 && len(statements.CreationStatements) == 0:
|
||||||
|
statements.CreationStatements = strings.Join(statements.Creation, ";")
|
||||||
|
case len(statements.CreationStatements) > 0:
|
||||||
|
statements.Creation = []string{statements.CreationStatements}
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case len(statements.Revocation) > 0 && len(statements.RevocationStatements) == 0:
|
||||||
|
statements.RevocationStatements = strings.Join(statements.Revocation, ";")
|
||||||
|
case len(statements.RevocationStatements) > 0:
|
||||||
|
statements.Revocation = []string{statements.RevocationStatements}
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case len(statements.Renewal) > 0 && len(statements.RenewStatements) == 0:
|
||||||
|
statements.RenewStatements = strings.Join(statements.Renewal, ";")
|
||||||
|
case len(statements.RenewStatements) > 0:
|
||||||
|
statements.Renewal = []string{statements.RenewStatements}
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case len(statements.Rollback) > 0 && len(statements.RollbackStatements) == 0:
|
||||||
|
statements.RollbackStatements = strings.Join(statements.Rollback, ";")
|
||||||
|
case len(statements.RollbackStatements) > 0:
|
||||||
|
statements.Rollback = []string{statements.RollbackStatements}
|
||||||
|
}
|
||||||
|
return statements
|
||||||
|
}
|
||||||
|
|||||||
62
plugins/helper/database/dbutil/dbutil_test.go
Normal file
62
plugins/helper/database/dbutil/dbutil_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package dbutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStatementCompatibilityHelper(t *testing.T) {
|
||||||
|
const (
|
||||||
|
creationStatement = "creation"
|
||||||
|
renewStatement = "renew"
|
||||||
|
revokeStatement = "revoke"
|
||||||
|
rollbackStatement = "rollback"
|
||||||
|
)
|
||||||
|
|
||||||
|
expectedStatements := dbplugin.Statements{
|
||||||
|
Creation: []string{creationStatement},
|
||||||
|
Rollback: []string{rollbackStatement},
|
||||||
|
Revocation: []string{revokeStatement},
|
||||||
|
Renewal: []string{renewStatement},
|
||||||
|
CreationStatements: creationStatement,
|
||||||
|
RenewStatements: renewStatement,
|
||||||
|
RollbackStatements: rollbackStatement,
|
||||||
|
RevocationStatements: revokeStatement,
|
||||||
|
}
|
||||||
|
|
||||||
|
statements1 := dbplugin.Statements{
|
||||||
|
CreationStatements: creationStatement,
|
||||||
|
RenewStatements: renewStatement,
|
||||||
|
RollbackStatements: rollbackStatement,
|
||||||
|
RevocationStatements: revokeStatement,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements1)) {
|
||||||
|
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements1)
|
||||||
|
}
|
||||||
|
|
||||||
|
statements2 := dbplugin.Statements{
|
||||||
|
Creation: []string{creationStatement},
|
||||||
|
Rollback: []string{rollbackStatement},
|
||||||
|
Revocation: []string{revokeStatement},
|
||||||
|
Renewal: []string{renewStatement},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(expectedStatements, StatementCompatibilityHelper(statements2)) {
|
||||||
|
t.Fatalf("mismatch: %#v, %#v", expectedStatements, statements2)
|
||||||
|
}
|
||||||
|
|
||||||
|
statements3 := dbplugin.Statements{
|
||||||
|
CreationStatements: creationStatement,
|
||||||
|
}
|
||||||
|
expectedStatements3 := dbplugin.Statements{
|
||||||
|
Creation: []string{creationStatement},
|
||||||
|
CreationStatements: creationStatement,
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(expectedStatements3, StatementCompatibilityHelper(statements3)) {
|
||||||
|
t.Fatalf("mismatch: %#v, %#v", expectedStatements3, statements3)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user