mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
Move plugins into main vault repo
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
package builtinplugins
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/vault-plugins/database/mysql"
|
||||
"github.com/hashicorp/vault-plugins/database/postgresql"
|
||||
"github.com/hashicorp/vault/plugins/database/mysql"
|
||||
"github.com/hashicorp/vault/plugins/database/postgresql"
|
||||
)
|
||||
|
||||
var BuiltinPlugins *builtinPlugins = &builtinPlugins{
|
||||
|
||||
16
plugins/database/mssql/mssql-database-plugin/main.go
Normal file
16
plugins/database/mssql/mssql-database-plugin/main.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/plugins/database/mssql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := mssql.Run()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
268
plugins/database/mssql/mssql.go
Normal file
268
plugins/database/mssql/mssql.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
)
|
||||
|
||||
const msSQLTypeName = "mssql"
|
||||
|
||||
// MSSQL is an implementation of DatabaseType interface
|
||||
type MSSQL struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
func New() *MSSQL {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
connProducer.Type = msSQLTypeName
|
||||
|
||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||
DisplayNameLen: 4,
|
||||
UsernameLen: 16,
|
||||
}
|
||||
|
||||
dbType := &MSSQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
return dbType
|
||||
}
|
||||
|
||||
// Run instantiates a MSSQL object, and runs the RPC server for the plugin
|
||||
func Run() error {
|
||||
dbType := New()
|
||||
|
||||
dbplugin.NewPluginServer(dbType)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Type returns the TypeName for this backend
|
||||
func (m *MSSQL) Type() (string, error) {
|
||||
return msSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MSSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
username, err = m.GenerateUsername(usernamePrefix)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
password, err = m.GeneratePassword()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
expirationStr, err := m.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
}))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
// RenewUser is not supported on MSSQL, so this is a no-op.
|
||||
func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
|
||||
// then kill pending connections from that user, and finally drop the user and login from the
|
||||
// database instance.
|
||||
func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// First disable server login
|
||||
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer disableStmt.Close()
|
||||
if _, err := disableStmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Query for sessions for the login so that we can kill any outstanding
|
||||
// sessions. There cannot be any active sessions before we drop the logins
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
sessionStmt, err := db.Prepare(fmt.Sprintf(
|
||||
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionStmt.Close()
|
||||
|
||||
sessionRows, err := sessionStmt.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionRows.Close()
|
||||
|
||||
var revokeStmts []string
|
||||
for sessionRows.Next() {
|
||||
var sessionID int
|
||||
err = sessionRows.Scan(&sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
|
||||
}
|
||||
|
||||
// Query for database users using undocumented stored procedure for now since
|
||||
// it is the easiest way to get this information;
|
||||
// we need to drop the database users before we can drop the login and the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var loginName, dbName, qUsername string
|
||||
var aliasName sql.NullString
|
||||
err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username))
|
||||
}
|
||||
|
||||
// we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revokeStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all database users are dropped
|
||||
if rows.Err() != nil {
|
||||
return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this login
|
||||
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const dropUserSQL = `
|
||||
USE [%s]
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM sys.database_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP USER [%s]
|
||||
END
|
||||
`
|
||||
|
||||
const dropLoginSQL = `
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM master.sys.server_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP LOGIN [%s]
|
||||
END
|
||||
`
|
||||
173
plugins/database/mssql/mssql_test.go
Normal file
173
plugins/database/mssql/mssql_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testMSQLImagePull sync.Once
|
||||
)
|
||||
|
||||
func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) {
|
||||
if os.Getenv("MSSQL_URL") != "" {
|
||||
return func() {}, os.Getenv("MSSQL_URL")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local MSSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp"))
|
||||
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
var err error
|
||||
var db *sql.DB
|
||||
db, err = sql.Open("mssql", retURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Ping()
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to MSSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestMSSQL_Initialize(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
if !connProducer.Initialized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSQL_CreateUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSQL_RevokeUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||
// Log in with the new creds
|
||||
connURL = strings.Replace(connURL, "sa:yourStrong(!)Password", fmt.Sprintf("%s:%s", username, password), 1)
|
||||
db, err := sql.Open("mssql", connURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
const testMSSQLRole = `
|
||||
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
|
||||
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
|
||||
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`
|
||||
16
plugins/database/mysql/mysql-database-plugin/main.go
Normal file
16
plugins/database/mysql/mysql-database-plugin/main.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/plugins/database/mysql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := mysql.Run()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
183
plugins/database/mysql/mysql.go
Normal file
183
plugins/database/mysql/mysql.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
)
|
||||
|
||||
const defaultMysqlRevocationStmts = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||
DROP USER '{{name}}'@'%'
|
||||
`
|
||||
const mySQLTypeName = "mysql"
|
||||
|
||||
type MySQL struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
func New() *MySQL {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
connProducer.Type = mySQLTypeName
|
||||
|
||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||
DisplayNameLen: 4,
|
||||
UsernameLen: 16,
|
||||
}
|
||||
|
||||
dbType := &MySQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
return dbType
|
||||
}
|
||||
|
||||
// Run instantiates a MySQL object, and runs the RPC server for the plugin
|
||||
func Run() error {
|
||||
dbType := New()
|
||||
|
||||
dbplugin.NewPluginServer(dbType)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) Type() (string, error) {
|
||||
return mySQLTypeName, nil
|
||||
}
|
||||
|
||||
func (m *MySQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (m *MySQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
username, err = m.GenerateUsername(usernamePrefix)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
password, err = m.GeneratePassword()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
expirationStr, err := m.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
}))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
// NOOP
|
||||
func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
// Grab the read lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
revocationStmts := statements.RevocationStatements
|
||||
// Use a default SQL statement for revocation if one cannot be fetched from the role
|
||||
if revocationStmts == "" {
|
||||
revocationStmts = defaultMysqlRevocationStmts
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// This is not a prepared statement because not all commands are supported
|
||||
// 1295: This command is not supported in the prepared statement protocol yet
|
||||
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
||||
query = strings.Replace(query, "{{name}}", username, -1)
|
||||
_, err = tx.Exec(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
200
plugins/database/mysql/mysql_test.go
Normal file
200
plugins/database/mysql/mysql_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testMySQLImagePull sync.Once
|
||||
)
|
||||
|
||||
func prepareMySQLTestContainer(t *testing.T) (cleanup func(), retURL string) {
|
||||
if os.Getenv("MYSQL_URL") != "" {
|
||||
return func() {}, os.Getenv("MYSQL_URL")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("mysql", "latest", []string{"MYSQL_ROOT_PASSWORD=secret"})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local MySQL docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("root:secret@(localhost:%s)/mysql?parseTime=true", resource.GetPort("3306/tcp"))
|
||||
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
var err error
|
||||
var db *sql.DB
|
||||
db, err = sql.Open("mysql", retURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Ping()
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to MySQL docker container: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestMySQL_Initialize(t *testing.T) {
|
||||
cleanup, connURL := prepareMySQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.Initialized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQL_CreateUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMySQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQL_RevokeUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMySQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
statements.CreationStatements = testMySQLRoleWildCard
|
||||
username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = testMySQLRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||
// Log in with the new creds
|
||||
connURL = strings.Replace(connURL, "root:secret", fmt.Sprintf("%s:%s", username, password), 1)
|
||||
db, err := sql.Open("mysql", connURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
const testMySQLRoleWildCard = `
|
||||
CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';
|
||||
GRANT SELECT ON *.* TO '{{name}}'@'%';
|
||||
`
|
||||
const testMySQLRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||
DROP USER '{{name}}'@'%';
|
||||
`
|
||||
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/vault/plugins/database/postgresql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
err := postgresql.Run()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
337
plugins/database/postgresql/postgresql.go
Normal file
337
plugins/database/postgresql/postgresql.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/credsutil"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/dbutil"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const postgreSQLTypeName string = "postgres"
|
||||
|
||||
func New() *PostgreSQL {
|
||||
connProducer := &connutil.SQLConnectionProducer{}
|
||||
connProducer.Type = postgreSQLTypeName
|
||||
|
||||
credsProducer := &credsutil.SQLCredentialsProducer{
|
||||
DisplayNameLen: 4,
|
||||
UsernameLen: 16,
|
||||
}
|
||||
|
||||
dbType := &PostgreSQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
return dbType
|
||||
}
|
||||
|
||||
// Run instatiates a PostgreSQL object, and runs the RPC server for the plugin
|
||||
func Run() error {
|
||||
dbType := New()
|
||||
|
||||
dbplugin.NewPluginServer(dbType)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type PostgreSQL struct {
|
||||
connutil.ConnectionProducer
|
||||
credsutil.CredentialsProducer
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Type() (string, error) {
|
||||
return postgreSQLTypeName, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) {
|
||||
if statements.CreationStatements == "" {
|
||||
return "", "", dbutil.ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
username, err = p.GenerateUsername(usernamePrefix)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
password, err = p.GeneratePassword()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
expirationStr, err := p.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
// Return the secret
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expirationStr,
|
||||
}))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", err
|
||||
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expirationStr, err := p.GenerateExpiration(expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"ALTER ROLE %s VALID UNTIL '%s';",
|
||||
pq.QuoteIdentifier(username),
|
||||
expirationStr)
|
||||
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if statements.RevocationStatements == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, statements.RevocationStatements)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the role exists
|
||||
var exists bool
|
||||
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists == false {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query for permissions; we need to revoke permissions before we can drop
|
||||
// the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
const initialNumRevocations = 16
|
||||
revocationStmts := make([]string, 0, initialNumRevocations)
|
||||
for rows.Next() {
|
||||
var schema string
|
||||
err = rows.Scan(&schema)
|
||||
if err != nil {
|
||||
// keep going; remove as many permissions as possible right now
|
||||
continue
|
||||
}
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
|
||||
pq.QuoteIdentifier(schema),
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
|
||||
pq.QuoteIdentifier(schema),
|
||||
pq.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// for good measure, revoke all privileges and usage on schema public
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE USAGE ON SCHEMA public FROM %s;",
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
// get the current database name so we can issue a REVOKE CONNECT for
|
||||
// this username
|
||||
var dbname sql.NullString
|
||||
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if dbname.Valid {
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
|
||||
pq.QuoteIdentifier(dbname.String),
|
||||
pq.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// again, here, we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revocationStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all privileges are revoked
|
||||
if rows.Err() != nil {
|
||||
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this user
|
||||
stmt, err = db.Prepare(fmt.Sprintf(
|
||||
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
308
plugins/database/postgresql/postgresql_test.go
Normal file
308
plugins/database/postgresql/postgresql_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
|
||||
"github.com/hashicorp/vault/plugins/helper/database/connutil"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testPostgresImagePull sync.Once
|
||||
)
|
||||
|
||||
func preparePostgresTestContainer(t *testing.T) (cleanup func(), retURL string) {
|
||||
if os.Getenv("PG_URL") != "" {
|
||||
return func() {}, os.Getenv("PG_URL")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local PostgreSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp"))
|
||||
|
||||
// exponential backoff-retry
|
||||
if err = pool.Retry(func() error {
|
||||
var err error
|
||||
var db *sql.DB
|
||||
db, err = sql.Open("postgres", retURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Ping()
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to PostgreSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestPostgreSQL_Initialize(t *testing.T) {
|
||||
cleanup, connURL := preparePostgresTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer)
|
||||
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.Initialized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
cleanup, connURL := preparePostgresTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
_, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute))
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
statements.CreationStatements = testPostgresReadOnlyRole
|
||||
username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
cleanup, connURL := preparePostgresTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Sleep longer than the inital expiration time
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
cleanup, connURL := preparePostgresTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
db := New()
|
||||
err := db.Initialize(connectionDetails, true)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := dbplugin.Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
|
||||
username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err = testCredsExist(t, connURL, username, password); err != nil {
|
||||
t.Fatalf("Could not connect with new credentials: %s", err)
|
||||
}
|
||||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = defaultPostgresRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if err := testCredsExist(t, connURL, username, password); err == nil {
|
||||
t.Fatal("Credentials were not revoked")
|
||||
}
|
||||
}
|
||||
|
||||
func testCredsExist(t testing.TB, connURL, username, password string) error {
|
||||
// Log in with the new creds
|
||||
connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1)
|
||||
db, err := sql.Open("postgres", connURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
const testPostgresRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testPostgresReadOnlyRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testPostgresBlockStatementRole = `
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
|
||||
CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';
|
||||
GRANT "foo-role" TO "{{name}}";
|
||||
ALTER ROLE "{{name}}" SET search_path = foo;
|
||||
GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";
|
||||
`
|
||||
|
||||
var testPostgresBlockStatementRoleSlice = []string{
|
||||
`
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
`,
|
||||
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
|
||||
`GRANT "foo-role" TO "{{name}}";`,
|
||||
`ALTER ROLE "{{name}}" SET search_path = foo;`,
|
||||
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
|
||||
}
|
||||
|
||||
const defaultPostgresRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
|
||||
|
||||
DROP ROLE IF EXISTS "{{name}}";
|
||||
`
|
||||
172
plugins/helper/database/connutil/cassandra.go
Normal file
172
plugins/helper/database/connutil/cassandra.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package connutil
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
)
|
||||
|
||||
// CassandraConnectionProducer implements ConnectionProducer and provides an
|
||||
// interface for cassandra databases to make connections.
|
||||
type CassandraConnectionProducer struct {
|
||||
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
|
||||
Username string `json:"username" structs:"username" mapstructure:"username"`
|
||||
Password string `json:"password" structs:"password" mapstructure:"password"`
|
||||
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
|
||||
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
|
||||
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
|
||||
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
|
||||
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
|
||||
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
|
||||
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
|
||||
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
|
||||
|
||||
Initialized bool
|
||||
session *gocql.Session
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
err := mapstructure.Decode(conf, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Initialized = true
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.connection(); err != nil {
|
||||
return fmt.Errorf("error Initalizing Connection: %s", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CassandraConnectionProducer) connection() (interface{}, error) {
|
||||
if !c.Initialized {
|
||||
return nil, errNotInitialized
|
||||
}
|
||||
|
||||
// If we already have a DB, return it
|
||||
if c.session != nil {
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
session, err := c.createSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the session in backend for reuse
|
||||
c.session = session
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *CassandraConnectionProducer) Close() error {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if c.session != nil {
|
||||
c.session.Close()
|
||||
}
|
||||
|
||||
c.session = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
||||
clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: c.Username,
|
||||
Password: c.Password,
|
||||
}
|
||||
|
||||
clusterConfig.ProtoVersion = c.ProtocolVersion
|
||||
if clusterConfig.ProtoVersion == 0 {
|
||||
clusterConfig.ProtoVersion = 2
|
||||
}
|
||||
|
||||
clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second
|
||||
|
||||
if c.TLS {
|
||||
var tlsConfig *tls.Config
|
||||
if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 {
|
||||
if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 {
|
||||
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
|
||||
}
|
||||
|
||||
certBundle := &certutil.CertBundle{}
|
||||
if len(c.Certificate) > 0 {
|
||||
certBundle.Certificate = c.Certificate
|
||||
certBundle.PrivateKey = c.PrivateKey
|
||||
}
|
||||
if len(c.IssuingCA) > 0 {
|
||||
certBundle.IssuingCA = c.IssuingCA
|
||||
}
|
||||
|
||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||
if err != nil || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = c.InsecureTLS
|
||||
|
||||
if c.TLSMinVersion != "" {
|
||||
var ok bool
|
||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||
}
|
||||
} else {
|
||||
// MinVersion was not being set earlier. Reset it to
|
||||
// zero to gracefully handle upgrades.
|
||||
tlsConfig.MinVersion = 0
|
||||
}
|
||||
}
|
||||
|
||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
||||
Config: *tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating session: %s", err)
|
||||
}
|
||||
|
||||
// Set consistency
|
||||
if c.Consistency != "" {
|
||||
consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.SetConsistency(consistencyValue)
|
||||
}
|
||||
|
||||
// Verify the info
|
||||
err = session.Query(`LIST USERS`).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error validating connection info: %s", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
21
plugins/helper/database/connutil/connutil.go
Normal file
21
plugins/helper/database/connutil/connutil.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package connutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
errNotInitialized = errors.New("connection has not been initalized")
|
||||
)
|
||||
|
||||
// ConnectionProducer can be used as an embeded interface in the DatabaseType
|
||||
// definition. It implements the methods dealing with individual database
|
||||
// connections and is used in all the builtin database types.
|
||||
type ConnectionProducer interface {
|
||||
Close() error
|
||||
Initialize(map[string]interface{}, bool) error
|
||||
Connection() (interface{}, error)
|
||||
|
||||
sync.Locker
|
||||
}
|
||||
131
plugins/helper/database/connutil/sql.go
Normal file
131
plugins/helper/database/connutil/sql.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package connutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
// Import sql drivers
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
||||
type SQLConnectionProducer struct {
|
||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
MaxConnectionLifetimeRaw string `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
|
||||
|
||||
Type string
|
||||
MaxConnectionLifetime time.Duration
|
||||
Initialized bool
|
||||
db *sql.DB
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
err := mapstructure.Decode(conf, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.MaxOpenConnections == 0 {
|
||||
c.MaxOpenConnections = 2
|
||||
}
|
||||
|
||||
if c.MaxIdleConnections == 0 {
|
||||
c.MaxIdleConnections = c.MaxOpenConnections
|
||||
}
|
||||
if c.MaxIdleConnections > c.MaxOpenConnections {
|
||||
c.MaxIdleConnections = c.MaxOpenConnections
|
||||
}
|
||||
if c.MaxConnectionLifetimeRaw == "" {
|
||||
c.MaxConnectionLifetimeRaw = "0s"
|
||||
}
|
||||
|
||||
c.MaxConnectionLifetime, err = time.ParseDuration(c.MaxConnectionLifetimeRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid max_connection_lifetime: %s", err)
|
||||
}
|
||||
|
||||
if verifyConnection {
|
||||
if _, err := c.Connection(); err != nil {
|
||||
return fmt.Errorf("error initalizing connection: %s", err)
|
||||
}
|
||||
|
||||
if err := c.db.Ping(); err != nil {
|
||||
return fmt.Errorf("error initalizing connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.Initialized = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) Connection() (interface{}, error) {
|
||||
// If we already have a DB, test it and return
|
||||
if c.db != nil {
|
||||
if err := c.db.Ping(); err == nil {
|
||||
return c.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
// For mssql backend, switch to sqlserver instead
|
||||
dbType := c.Type
|
||||
if c.Type == "mssql" {
|
||||
dbType = "sqlserver"
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
conn := c.ConnectionURL
|
||||
|
||||
// Ensure timezone is set to UTC for all the conenctions
|
||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||
if strings.Contains(conn, "?") {
|
||||
conn += "&timezone=utc"
|
||||
} else {
|
||||
conn += "?timezone=utc"
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
c.db, err = sql.Open(dbType, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
c.db.SetMaxOpenConns(c.MaxOpenConnections)
|
||||
c.db.SetMaxIdleConns(c.MaxIdleConnections)
|
||||
c.db.SetConnMaxLifetime(c.MaxConnectionLifetime)
|
||||
|
||||
return c.db, nil
|
||||
}
|
||||
|
||||
// Close attempts to close the connection
|
||||
func (c *SQLConnectionProducer) Close() error {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if c.db != nil {
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
c.db = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
37
plugins/helper/database/credsutil/cassandra.go
Normal file
37
plugins/helper/database/credsutil/cassandra.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package credsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
// CassandraCredentialsProducer implements CredentialsProducer and provides an
|
||||
// interface for cassandra databases to generate user information.
|
||||
type CassandraCredentialsProducer struct{}
|
||||
|
||||
func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix())
|
||||
username = strings.Replace(username, "-", "_", -1)
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (ccp *CassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
12
plugins/helper/database/credsutil/credsutil.go
Normal file
12
plugins/helper/database/credsutil/credsutil.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package credsutil
|
||||
|
||||
import "time"
|
||||
|
||||
// CredentialsProducer can be used as an embeded interface in the DatabaseType
|
||||
// definition. It implements the methods for generating user information for a
|
||||
// particular database type and is used in all the builtin database types.
|
||||
type CredentialsProducer interface {
|
||||
GenerateUsername(displayName string) (string, error)
|
||||
GeneratePassword() (string, error)
|
||||
GenerateExpiration(ttl time.Time) (string, error)
|
||||
}
|
||||
43
plugins/helper/database/credsutil/sql.go
Normal file
43
plugins/helper/database/credsutil/sql.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package credsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
// SQLCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types.
|
||||
type SQLCredentialsProducer struct {
|
||||
DisplayNameLen int
|
||||
UsernameLen int
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||
if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen {
|
||||
displayName = displayName[:scp.DisplayNameLen]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
||||
if scp.UsernameLen > 0 && len(username) > scp.UsernameLen {
|
||||
username = username[:scp.UsernameLen]
|
||||
}
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (scp *SQLCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
|
||||
return ttl.Format("2006-01-02 15:04:05-0700"), nil
|
||||
}
|
||||
20
plugins/helper/database/dbutil/dbutil.go
Normal file
20
plugins/helper/database/dbutil/dbutil.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyCreationStatement = errors.New("empty creation statements")
|
||||
)
|
||||
|
||||
// Query templates a query for us.
|
||||
func QueryHelper(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
||||
Reference in New Issue
Block a user