Begin work on database refactor

This commit is contained in:
Brian Kassouf
2016-12-19 10:15:58 -08:00
committed by Brian Kassouf
parent daf2dd6995
commit 3d77a9a6f4
11 changed files with 2031 additions and 0 deletions

View File

@@ -0,0 +1,104 @@
package database
import (
"strings"
"sync"
log "github.com/mgutz/logxi/v1"
"github.com/hashicorp/vault/builtin/logical/database/dbs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return Backend(conf).Setup(conf)
}
func Backend(conf *logical.BackendConfig) *databaseBackend {
var b databaseBackend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
Paths: []*framework.Path{
pathConfigConnection(&b),
pathConfigLease(&b),
pathListRoles(&b),
pathRoles(&b),
pathRoleCreate(&b),
},
Secrets: []*framework.Secret{
secretCreds(&b),
},
Clean: b.resetAllDBs,
}
b.logger = conf.Logger
b.connections = make(map[string]dbs.DatabaseType)
return &b
}
type databaseBackend struct {
connections map[string]dbs.DatabaseType
logger log.Logger
*framework.Backend
sync.RWMutex
}
// resetAllDBs closes all connections from all database types
func (b *databaseBackend) resetAllDBs() {
b.logger.Trace("postgres/resetdb: enter")
defer b.logger.Trace("postgres/resetdb: exit")
b.Lock()
defer b.Unlock()
for _, db := range b.connections {
db.Close()
}
}
// Lease returns the lease information
func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result configLease
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get("role/" + n)
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
var result roleEntry
if err := entry.DecodeJSON(&result); err != nil {
return nil, err
}
return &result, nil
}
const backendHelp = `
The PostgreSQL backend dynamically generates database users.
After mounting this backend, configure it using the endpoints within
the "config/" path.
`

View File

@@ -0,0 +1,620 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"os"
"path"
"reflect"
"sync"
"testing"
"time"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/lib/pq"
"github.com/mitchellh/mapstructure"
"github.com/ory-am/dockertest"
)
var (
testImagePull sync.Once
)
func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) {
if os.Getenv("PG_URL") != "" {
return "", os.Getenv("PG_URL")
}
// Without this the checks for whether the container has started seem to
// never actually pass. There's really no reason to expose the test
// containers, so don't.
dockertest.BindDockerToLocalhost = "yep"
testImagePull.Do(func() {
dockertest.Pull("postgres")
})
cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool {
// This will cause a validation to run
resp, err := b.HandleRequest(&logical.Request{
Storage: s,
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: map[string]interface{}{
"connection_url": connURL,
},
})
if err != nil || (resp != nil && resp.IsError()) {
// It's likely not up and running yet, so return false and try again
return false
}
if resp == nil {
t.Fatal("expected warning")
}
retURL = connURL
return true
})
if connErr != nil {
t.Fatalf("could not connect to database: %v", connErr)
}
return
}
func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
err := cid.KillRemove()
if err != nil {
t.Fatal(err)
}
}
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
"value": "",
"max_open_connections": 9,
"max_idle_connections": 7,
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/connection",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(configData, "verify_connection")
if !reflect.DeepEqual(configData, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
}
}
func TestBackend_basic(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRole(t, "web", testRole, false),
testAccStepReadCreds(t, b, config.StorageView, "web", connURL),
},
})
}
func TestBackend_roleCrud(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRole(t, "web", testRole, false),
testAccStepReadRole(t, "web", testRole),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web", ""),
},
})
}
func TestBackend_BlockStatements(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
connData := map[string]interface{}{
"connection_url": connURL,
}
jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice)
if err != nil {
t.Fatal(err)
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
// This will also validate the query
testAccStepCreateRole(t, "web-block", testBlockStatementRole, true),
testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false),
},
})
}
func TestBackend_roleReadOnly(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRole(t, "web", testRole, false),
testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false),
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
testAccStepDeleteRole(t, "web-readonly"),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web-readonly", ""),
},
})
}
func TestBackend_roleReadOnly_revocationSQL(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false),
testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false),
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
testAccStepDeleteRole(t, "web-readonly"),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web-readonly", ""),
},
})
}
func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/connection",
Data: d,
ErrorOk: true,
Check: func(resp *logical.Response) error {
if expectError {
if resp.Data == nil {
return fmt.Errorf("data is nil")
}
var e struct {
Error string `mapstructure:"error"`
}
if err := mapstructure.Decode(resp.Data, &e); err != nil {
return err
}
if len(e.Error) == 0 {
return fmt.Errorf("expected error, but write succeeded.")
}
return nil
} else if resp != nil && resp.IsError() {
return fmt.Errorf("got an error response: %v", resp.Error())
}
return nil
},
}
}
func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: path.Join("roles", name),
Data: map[string]interface{}{
"sql": sql,
},
ErrorOk: expectFail,
}
}
func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: path.Join("roles", name),
Data: map[string]interface{}{
"sql": sql,
"revocation_sql": revocationSQL,
},
ErrorOk: expectFail,
}
}
func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: path.Join("roles", name),
}
}
func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
if err != nil {
t.Fatal(err)
}
returnedRows := func() int {
stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');")
if err != nil {
return -1
}
defer stmt.Close()
rows, err := stmt.Query(d.Username)
if err != nil {
return -1
}
defer rows.Close()
i := 0
for rows.Next() {
i++
}
return i
}
// minNumPermissions is the minimum number of permissions that will always be present.
const minNumPermissions = 2
userRows := returnedRows()
if userRows < minNumPermissions {
t.Fatalf("did not get expected number of rows, got %d", userRows)
}
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.RevokeOperation,
Storage: s,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
"role": name,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("Error on resp: %#v", *resp)
}
}
userRows = returnedRows()
// User shouldn't exist so returnedRows() should encounter an error and exit with -1
if userRows != -1 {
t.Fatalf("did not get expected number of rows, got %d", userRows)
}
return nil
},
}
}
func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);")
if err != nil {
t.Fatal(err)
}
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.RevokeOperation,
Storage: s,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("Error on resp: %#v", *resp)
}
}
return nil
},
}
}
func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec("DROP TABLE test;")
if err != nil {
t.Fatal(err)
}
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.RevokeOperation,
Storage: s,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("Error on resp: %#v", *resp)
}
}
return nil
},
}
}
func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "roles/" + name,
Check: func(resp *logical.Response) error {
if resp == nil {
if sql == "" {
return nil
}
return fmt.Errorf("bad: %#v", resp)
}
var d struct {
SQL string `mapstructure:"sql"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.SQL != sql {
return fmt.Errorf("bad: %#v", resp)
}
return nil
},
}
}
const testRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
`
const testReadOnlyRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
`
const testBlockStatementRole = `
DO $$
BEGIN
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
CREATE ROLE "foo-role";
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
ALTER ROLE "foo-role" SET search_path = foo;
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
END IF;
END
$$
CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';
GRANT "foo-role" TO "{{name}}";
ALTER ROLE "{{name}}" SET search_path = foo;
GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";
`
var testBlockStatementRoleSlice = []string{
`
DO $$
BEGIN
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
CREATE ROLE "foo-role";
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
ALTER ROLE "foo-role" SET search_path = foo;
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
END IF;
END
$$
`,
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
`GRANT "foo-role" TO "{{name}}";`,
`ALTER ROLE "{{name}}" SET search_path = foo;`,
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
}
const defaultRevocationSQL = `
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
REVOKE USAGE ON SCHEMA public FROM {{name}};
DROP ROLE IF EXISTS {{name}};
`

View File

@@ -0,0 +1,194 @@
package dbs
import (
"crypto/tls"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil"
)
type Cassandra struct {
// Session is goroutine safe, however, since we reinitialize
// it when connection info changes, we want to make sure we
// can close it and use a new connection; hence the lock
session *gocql.Session
config ConnectionConfig
sync.RWMutex
}
func (c *Cassandra) Type() string {
return cassandraTypeName
}
func (c *Cassandra) Connection() (*gocql.Session, error) {
// Grab the write lock
c.Lock()
defer c.Unlock()
// If we already have a DB, we got it!
if c.session != nil {
return c.session, nil
}
session, err := createSession(c.config)
if err != nil {
return nil, err
}
// Store the session in backend for reuse
c.session = session
return session, nil
}
func (p *Cassandra) Close() {
// Grab the write lock
p.Lock()
defer p.Unlock()
if p.session != nil {
p.session.Close()
}
p.session = nil
}
func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) {
// Grab the write lock
p.Lock()
p.config = config
p.Unlock()
p.Close()
return p.Connection()
}
func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error {
// Get the connection
db, err := p.Connection()
if err != nil {
return err
}
// TODO: This is racey
// Grab a read lock
p.RLock()
defer p.RUnlock()
return nil
}
func (p *Cassandra) RenewUser(username, expiration string) error {
db, err := p.Connection()
if err != nil {
return err
}
// TODO: This is Racey
// Grab the read lock
p.RLock()
defer p.RUnlock()
return nil
}
func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error {
db, err := p.Connection()
if err != nil {
return err
}
// TODO: this is Racey
p.RLock()
defer p.RUnlock()
return nil
}
func (p *Cassandra) DefaultRevokeUser(username string) error {
// Grab the read lock
p.RLock()
defer p.RUnlock()
db, err := p.Connection()
return nil
}
func createSession(cfg *ConnectionConfig) (*gocql.Session, error) {
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: cfg.Username,
Password: cfg.Password,
}
clusterConfig.ProtoVersion = cfg.ProtocolVersion
if clusterConfig.ProtoVersion == 0 {
clusterConfig.ProtoVersion = 2
}
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
if cfg.TLS {
var tlsConfig *tls.Config
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
}
certBundle := &certutil.CertBundle{}
if len(cfg.Certificate) > 0 {
certBundle.Certificate = cfg.Certificate
certBundle.PrivateKey = cfg.PrivateKey
}
if len(cfg.IssuingCA) > 0 {
certBundle.IssuingCA = cfg.IssuingCA
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
}
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
if cfg.TLSMinVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
}
clusterConfig.SslOpts = &gocql.SslOptions{
Config: *tlsConfig,
}
}
session, err := clusterConfig.CreateSession()
if err != nil {
return nil, fmt.Errorf("Error creating session: %s", err)
}
// Verify the info
err = session.Query(`LIST USERS`).Exec()
if err != nil {
return nil, fmt.Errorf("Error validating connection info: %s", err)
}
return session, nil
}

View File

@@ -0,0 +1,56 @@
package dbs
import (
"database/sql"
"errors"
"fmt"
"strings"
)
const (
postgreSQLTypeName = "postgres"
cassandraTypeName = "cassandra"
)
var (
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
)
func Factory(conf ConnectionConfig) (DatabaseType, error) {
switch conf.ConnectionType {
case postgreSQLTypeName:
return &PostgreSQL{
config: conf,
}, nil
}
return nil, ErrUnsupportedDatabaseType
}
type DatabaseType interface {
Type() string
Connection() (*sql.DB, error)
Close()
Reset(ConnectionConfig) (*sql.DB, error)
CreateUser(createStmt, username, password, expiration string) error
RenewUser(username, expiration string) error
CustomRevokeUser(username, revocationSQL string) error
DefaultRevokeUser(username string) error
}
type ConnectionConfig struct {
ConnectionType string `json:"type" structs:"type" mapstructure:"type"`
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
}
// Query templates a query for us.
func queryHelper(tpl string, data map[string]string) string {
for k, v := range data {
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
}
return tpl
}

View File

@@ -0,0 +1,336 @@
package dbs
import (
"database/sql"
"fmt"
"strings"
"sync"
"github.com/hashicorp/vault/helper/strutil"
"github.com/lib/pq"
)
type PostgreSQL struct {
db *sql.DB
config ConnectionConfig
sync.RWMutex
}
func (p *PostgreSQL) Type() string {
return postgreSQLTypeName
}
func (p *PostgreSQL) Connection() (*sql.DB, error) {
// Grab the write lock
p.Lock()
defer p.Unlock()
// If we already have a DB, we got it!
if p.db != nil {
if err := p.db.Ping(); err == nil {
return p.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
p.db.Close()
}
// Otherwise, attempt to make connection
conn := p.config.ConnectionURL
// Ensure timezone is set to UTC for all the conenctions
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
if strings.Contains(conn, "?") {
conn += "&timezone=utc"
} else {
conn += "?timezone=utc"
}
} else {
conn += " timezone=utc"
}
var err error
p.db, err = sql.Open("postgres", conn)
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
p.db.SetMaxOpenConns(p.config.MaxOpenConnections)
p.db.SetMaxIdleConns(p.config.MaxIdleConnections)
return p.db, nil
}
func (p *PostgreSQL) Close() {
// Grab the write lock
p.Lock()
defer p.Unlock()
if p.db != nil {
p.db.Close()
}
p.db = nil
}
func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) {
// Grab the write lock
p.Lock()
p.config = config
p.Unlock()
p.Close()
return p.Connection()
}
func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error {
// Get the connection
db, err := p.Connection()
if err != nil {
return err
}
// TODO: This is racey
// Grab a read lock
p.RLock()
defer p.RUnlock()
// Start a transaction
// b.logger.Trace("postgres/pathRoleCreateRead: starting transaction")
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
// b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction")
tx.Rollback()
}()
// Return the secret
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
// b.logger.Trace("postgres/pathRoleCreateRead: preparing statement")
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expiration,
}))
if err != nil {
return err
}
defer stmt.Close()
// b.logger.Trace("postgres/pathRoleCreateRead: executing statement")
if _, err := stmt.Exec(); err != nil {
return err
}
}
// Commit the transaction
// b.logger.Trace("postgres/pathRoleCreateRead: committing transaction")
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (p *PostgreSQL) RenewUser(username, expiration string) error {
db, err := p.Connection()
if err != nil {
return err
}
// TODO: This is Racey
// Grab the read lock
p.RLock()
defer p.RUnlock()
query := fmt.Sprintf(
"ALTER ROLE %s VALID UNTIL '%s';",
pq.QuoteIdentifier(username),
expiration)
stmt, err := db.Prepare(query)
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
return nil
}
func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
db, err := p.Connection()
if err != nil {
return err
}
// TODO: this is Racey
p.RLock()
defer p.RUnlock()
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (p *PostgreSQL) DefaultRevokeUser(username string) error {
// Grab the read lock
p.RLock()
defer p.RUnlock()
db, err := p.Connection()
if err != nil {
return err
}
// Check if the role exists
var exists bool
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
return err
}
if exists == false {
return nil
}
// Query for permissions; we need to revoke permissions before we can drop
// the role
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil {
return err
}
defer stmt.Close()
rows, err := stmt.Query(username)
if err != nil {
return err
}
defer rows.Close()
const initialNumRevocations = 16
revocationStmts := make([]string, 0, initialNumRevocations)
for rows.Next() {
var schema string
err = rows.Scan(&schema)
if err != nil {
// keep going; remove as many permissions as possible right now
continue
}
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
pq.QuoteIdentifier(schema),
pq.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
pq.QuoteIdentifier(schema),
pq.QuoteIdentifier(username)))
}
// for good measure, revoke all privileges and usage on schema public
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
pq.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
pq.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE USAGE ON SCHEMA public FROM %s;",
pq.QuoteIdentifier(username)))
// get the current database name so we can issue a REVOKE CONNECT for
// this username
var dbname sql.NullString
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
return err
}
if dbname.Valid {
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
pq.QuoteIdentifier(dbname.String),
pq.QuoteIdentifier(username)))
}
// again, here, we do not stop on error, as we want to remove as
// many permissions as possible right now
var lastStmtError error
for _, query := range revocationStmts {
stmt, err := db.Prepare(query)
if err != nil {
lastStmtError = err
continue
}
defer stmt.Close()
_, err = stmt.Exec()
if err != nil {
lastStmtError = err
}
}
// can't drop if not all privileges are revoked
if rows.Err() != nil {
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
}
if lastStmtError != nil {
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
}
// Drop this user
stmt, err = db.Prepare(fmt.Sprintf(
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,188 @@
package database
import (
"fmt"
"github.com/fatih/structs"
"github.com/hashicorp/vault/builtin/logical/database/dbs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
_ "github.com/lib/pq"
)
func pathConfigConnection(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of this DB type",
},
"connection_type": &framework.FieldSchema{
Type: framework.TypeString,
Description: "DB type (e.g. postgres)",
},
"connection_url": &framework.FieldSchema{
Type: framework.TypeString,
Description: "DB connection string",
},
"connection_details": &framework.FieldSchema{
Type: framework.TypeMap,
Description: "Connection details for specified connection type.",
},
"verify_connection": &framework.FieldSchema{
Type: framework.TypeBool,
Default: true,
Description: `If set, connection_url is verified by actually connecting to the database`,
},
"max_open_connections": &framework.FieldSchema{
Type: framework.TypeInt,
Description: `Maximum number of open connections to the database;
a zero uses the default value of two and a
negative value means unlimited`,
},
"max_idle_connections": &framework.FieldSchema{
Type: framework.TypeInt,
Description: `Maximum number of idle connections to the database;
a zero uses the value of max_open_connections
and a negative value disables idle connections.
If larger than max_open_connections it will be
reduced to the same size.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
logical.ReadOperation: b.pathConnectionRead,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
HelpDescription: pathConfigConnectionHelpDesc,
}
}
// pathConnectionRead reads out the connection configuration
func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config dbs.ConnectionConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
return &logical.Response{
Data: structs.New(config).Map(),
}, nil
}
func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connURL := data.Get("connection_url").(string)
connType := data.Get("connection_type").(string)
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
maxOpenConns = 2
}
maxIdleConns := data.Get("max_idle_connections").(int)
if maxIdleConns == 0 {
maxIdleConns = maxOpenConns
}
if maxIdleConns > maxOpenConns {
maxIdleConns = maxOpenConns
}
config := dbs.ConnectionConfig{
ConnectionType: connType,
ConnectionURL: connURL,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
}
name := data.Get("name").(string)
// Grab the mutex lock
b.Lock()
defer b.Unlock()
var err error
var db dbs.DatabaseType
if _, ok := b.connections[name]; ok {
// Don't allow the connection type to change
if b.connections[name].Type() != connType {
return logical.ErrorResponse("can not change type of existing connection"), nil
}
db = b.connections[name]
} else {
db, err = dbs.Factory(config)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
}
/*
// Don't check the connection_url if verification is disabled
verifyConnection := data.Get("verify_connection").(bool)
if verifyConnection {
// Verify the string
db, err := sql.Open("postgres", connURL)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
defer db.Close()
if err := db.Ping(); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Error validating connection info: %s", err)), nil
}
}
*/
// Store it
entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
// Reset the DB connection
db.Reset(config)
b.connections[name] = db
resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
return resp, nil
}
const pathConfigConnectionHelpSyn = `
Configure the connection string to talk to PostgreSQL.
`
const pathConfigConnectionHelpDesc = `
This path configures the connection string used to connect to PostgreSQL.
The value of the string can be a URL, or a PG style string in the
format of "user=foo host=bar" etc.
The URL looks like:
"postgresql://user:pass@host:port/dbname"
When configuring the connection string, the backend will verify its validity.
`

View File

@@ -0,0 +1,103 @@
package database
import (
"fmt"
"time"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathConfigLease(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: "config/lease",
Fields: map[string]*framework.FieldSchema{
"lease": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Default lease for roles.",
},
"lease_max": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Maximum time a credential is valid for.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathLeaseRead,
logical.UpdateOperation: b.pathLeaseWrite,
},
HelpSynopsis: pathConfigLeaseHelpSyn,
HelpDescription: pathConfigLeaseHelpDesc,
}
}
func (b *databaseBackend) pathLeaseWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
leaseRaw := d.Get("lease").(string)
leaseMaxRaw := d.Get("lease_max").(string)
lease, err := time.ParseDuration(leaseRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid lease: %s", err)), nil
}
leaseMax, err := time.ParseDuration(leaseMaxRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid lease: %s", err)), nil
}
// Store it
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
Lease: lease,
LeaseMax: leaseMax,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
func (b *databaseBackend) pathLeaseRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"lease": lease.Lease.String(),
"lease_max": lease.LeaseMax.String(),
},
}, nil
}
type configLease struct {
Lease time.Duration
LeaseMax time.Duration
}
const pathConfigLeaseHelpSyn = `
Configure the default lease information for generated credentials.
`
const pathConfigLeaseHelpDesc = `
This configures the default lease information used for credentials
generated by this backend. The lease specifies the duration that a
credential will be valid for, as well as the maximum session for
a set of credentials.
The format for the lease is "1h" or integer and then unit. The longest
unit is hour.
`

View File

@@ -0,0 +1,120 @@
package database
import (
"fmt"
"time"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
_ "github.com/lib/pq"
)
func pathRoleCreate(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: "creds/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the role.",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleCreateRead,
},
HelpSynopsis: pathRoleCreateReadHelpSyn,
HelpDescription: pathRoleCreateReadHelpDesc,
}
}
func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.logger.Trace("postgres/pathRoleCreateRead: enter")
defer b.logger.Trace("postgres/pathRoleCreateRead: exit")
name := data.Get("name").(string)
// Get the role
b.logger.Trace("postgres/pathRoleCreateRead: getting role")
role, err := b.Role(req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
}
// Determine if we have a lease
b.logger.Trace("postgres/pathRoleCreateRead: getting lease")
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
// Unlike some other backends we need a lease here (can't leave as 0 and
// let core fill it in) because Postgres also expires users as a safety
// measure, so cannot be zero
if lease == nil {
lease = &configLease{
Lease: b.System().DefaultLeaseTTL(),
}
}
// Generate the username, password and expiration. PG limits user to 63 characters
displayName := req.DisplayName
if len(displayName) > 26 {
displayName = displayName[:26]
}
userUUID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
username := fmt.Sprintf("%s-%s", displayName, userUUID)
if len(username) > 63 {
username = username[:63]
}
password, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
expiration := time.Now().
Add(lease.Lease).
Format("2006-01-02 15:04:05-0700")
// Get our handle
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
b.RLock()
defer b.RUnlock()
db, ok := b.connections[role.DBName]
if !ok {
// TODO: return a resp error instead?
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
}
err = db.CreateUser(role.CreationStatement, username, password, expiration)
if err != nil {
return nil, err
}
b.logger.Trace("postgres/pathRoleCreateRead: generating secret")
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username,
"password": password,
}, map[string]interface{}{
"username": username,
"role": name,
})
resp.Secret.TTL = lease.Lease
return resp, nil
}
const pathRoleCreateReadHelpSyn = `
Request database credentials for a certain role.
`
const pathRoleCreateReadHelpDesc = `
This path reads database credentials for a certain role. The
database credentials will be generated on demand and will be automatically
revoked when the lease is up.
`

View File

@@ -0,0 +1,161 @@
package database
import (
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func pathListRoles(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: "roles/?$",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ListOperation: b.pathRoleList,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func pathRoles(b *databaseBackend) *framework.Path {
return &framework.Path{
Pattern: "roles/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the role.",
},
"db_name": {
Type: framework.TypeString,
Description: "Name of the database this role acts on.",
},
"creation_statement": {
Type: framework.TypeString,
Description: "SQL string to create a user. See help for more info.",
},
"revocation_statement": {
Type: framework.TypeString,
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
string, a base64-encoded semicolon-separated string, a serialized JSON string
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
will be substituted.`,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRoleRead,
logical.UpdateOperation: b.pathRoleCreate,
logical.DeleteOperation: b.pathRoleDelete,
},
HelpSynopsis: pathRoleHelpSyn,
HelpDescription: pathRoleHelpDesc,
}
}
func (b *databaseBackend) pathRoleDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("role/" + data.Get("name").(string))
if err != nil {
return nil, err
}
return nil, nil
}
func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, err := b.Role(req.Storage, data.Get("name").(string))
if err != nil {
return nil, err
}
if role == nil {
return nil, nil
}
return &logical.Response{
Data: map[string]interface{}{
"creation_statment": role.CreationStatement,
"revocation_statement": role.RevocationStatement,
},
}, nil
}
func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("role/")
if err != nil {
return nil, err
}
return logical.ListResponse(entries), nil
}
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
dbName := data.Get("db_name").(string)
creationStmt := data.Get("creation_statement").(string)
revocationStmt := data.Get("revocation_statement").(string)
// TODO: Think about preparing the statments to test.
// Store it
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
DBName: dbName,
CreationStatement: creationStmt,
RevocationStatement: revocationStmt,
})
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
return nil, err
}
return nil, nil
}
type roleEntry struct {
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
}
const pathRoleHelpSyn = `
Manage the roles that can be created with this backend.
`
const pathRoleHelpDesc = `
This path lets you manage the roles that can be created with this backend.
The "sql" parameter customizes the SQL string used to create the role.
This can be a sequence of SQL queries. Some substitution will be done to the
SQL string for certain keys. The names of the variables must be surrounded
by "{{" and "}}" to be replaced.
* "name" - The random username generated for the DB user.
* "password" - The random password generated for the DB user.
* "expiration" - The timestamp when this user will expire.
Example of a decent SQL query to use:
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
Note the above user would be able to access everything in schema public.
For more complex GRANT clauses, see the PostgreSQL manual.
The "revocation_sql" parameter customizes the SQL string used to revoke a user.
Example of a decent revocation SQL query to use:
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
REVOKE USAGE ON SCHEMA public FROM {{name}};
DROP ROLE IF EXISTS {{name}};
`

View File

@@ -0,0 +1,147 @@
package database
import (
"errors"
"fmt"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
const SecretCredsType = "creds"
func secretCreds(b *databaseBackend) *framework.Secret {
return &framework.Secret{
Type: SecretCredsType,
Fields: map[string]*framework.FieldSchema{
"username": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Username",
},
"password": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Password",
},
},
Renew: b.secretCredsRenew,
Revoke: b.secretCredsRevoke,
}
}
func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
dbName := d.Get("name").(string)
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
// Get our connection
db, ok := b.connections[dbName]
if !ok {
return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName))
}
// Get the lease information
lease, err := b.Lease(req.Storage)
if err != nil {
return nil, err
}
if lease == nil {
lease = &configLease{}
}
f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System())
resp, err := f(req, d)
if err != nil {
return nil, err
}
// Make sure we increase the VALID UNTIL endpoint for this user.
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
err := db.RenewUser(username, expiration)
if err != nil {
return nil, err
}
}
return resp, nil
}
func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the username from the internal data
usernameRaw, ok := req.Secret.InternalData["username"]
if !ok {
return nil, fmt.Errorf("secret is missing username internal data")
}
username, ok := usernameRaw.(string)
var revocationSQL string
var resp *logical.Response
roleNameRaw, ok := req.Secret.InternalData["role"]
if !ok {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
}
role, err := b.Role(req.Storage, roleNameRaw.(string))
if err != nil {
return nil, err
}
if role == nil {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
}
/* TODO: think about how to handle this case.
if !ok {
role, err := b.Role(req.Storage, roleNameRaw.(string))
if err != nil {
return nil, err
}
if role == nil {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string)))
} else {
revocationSQL = role.RevocationStatement
}
}*/
// Grab the read lock
b.RLock()
defer b.RUnlock()
// Get our connection
db, ok := b.connections[role.DBName]
if !ok {
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
}
// TODO: Maybe move this down into db package?
switch revocationSQL {
// This is the default revocation logic. If revocation SQL is provided it
// is simply executed as-is.
case "":
err := db.DefaultRevokeUser(username)
if err != nil {
return nil, err
}
// We have revocation SQL, execute directly, within a transaction
default:
err := db.CustomRevokeUser(username, revocationSQL)
if err != nil {
return nil, err
}
}
return resp, nil
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/hashicorp/vault/builtin/logical/aws"
"github.com/hashicorp/vault/builtin/logical/cassandra"
"github.com/hashicorp/vault/builtin/logical/consul"
"github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/builtin/logical/mongodb"
"github.com/hashicorp/vault/builtin/logical/mssql"
"github.com/hashicorp/vault/builtin/logical/mysql"
@@ -91,6 +92,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
"mysql": mysql.Factory,
"ssh": ssh.Factory,
"rabbitmq": rabbitmq.Factory,
"database": database.Factory,
},
ShutdownCh: command.MakeShutdownCh(),
SighupCh: command.MakeSighupCh(),