mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
Begin work on database refactor
This commit is contained in:
committed by
Brian Kassouf
parent
daf2dd6995
commit
3d77a9a6f4
104
builtin/logical/database/backend.go
Normal file
104
builtin/logical/database/backend.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
return Backend(conf).Setup(conf)
|
||||
}
|
||||
|
||||
func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||
var b databaseBackend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigConnection(&b),
|
||||
pathConfigLease(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
pathRoleCreate(&b),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{
|
||||
secretCreds(&b),
|
||||
},
|
||||
|
||||
Clean: b.resetAllDBs,
|
||||
}
|
||||
|
||||
b.logger = conf.Logger
|
||||
b.connections = make(map[string]dbs.DatabaseType)
|
||||
return &b
|
||||
}
|
||||
|
||||
type databaseBackend struct {
|
||||
connections map[string]dbs.DatabaseType
|
||||
logger log.Logger
|
||||
|
||||
*framework.Backend
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// resetAllDBs closes all connections from all database types
|
||||
func (b *databaseBackend) resetAllDBs() {
|
||||
b.logger.Trace("postgres/resetdb: enter")
|
||||
defer b.logger.Trace("postgres/resetdb: exit")
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
for _, db := range b.connections {
|
||||
db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Lease returns the lease information
|
||||
func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) {
|
||||
entry, err := s.Get("config/lease")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result configLease
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get("role/" + n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result roleEntry
|
||||
if err := entry.DecodeJSON(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
The PostgreSQL backend dynamically generates database users.
|
||||
|
||||
After mounting this backend, configure it using the endpoints within
|
||||
the "config/" path.
|
||||
`
|
||||
620
builtin/logical/database/backend_test.go
Normal file
620
builtin/logical/database/backend_test.go
Normal file
@@ -0,0 +1,620 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
logicaltest "github.com/hashicorp/vault/logical/testing"
|
||||
"github.com/lib/pq"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/ory-am/dockertest"
|
||||
)
|
||||
|
||||
var (
|
||||
testImagePull sync.Once
|
||||
)
|
||||
|
||||
func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) {
|
||||
if os.Getenv("PG_URL") != "" {
|
||||
return "", os.Getenv("PG_URL")
|
||||
}
|
||||
|
||||
// Without this the checks for whether the container has started seem to
|
||||
// never actually pass. There's really no reason to expose the test
|
||||
// containers, so don't.
|
||||
dockertest.BindDockerToLocalhost = "yep"
|
||||
|
||||
testImagePull.Do(func() {
|
||||
dockertest.Pull("postgres")
|
||||
})
|
||||
|
||||
cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool {
|
||||
// This will cause a validation to run
|
||||
resp, err := b.HandleRequest(&logical.Request{
|
||||
Storage: s,
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
})
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
// It's likely not up and running yet, so return false and try again
|
||||
return false
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected warning")
|
||||
}
|
||||
|
||||
retURL = connURL
|
||||
return true
|
||||
})
|
||||
|
||||
if connErr != nil {
|
||||
t.Fatalf("could not connect to database: %v", connErr)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
|
||||
err := cid.KillRemove()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_config_connection(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
configData := map[string]interface{}{
|
||||
"connection_url": "sample_connection_url",
|
||||
"value": "",
|
||||
"max_open_connections": 9,
|
||||
"max_idle_connections": 7,
|
||||
"verify_connection": false,
|
||||
}
|
||||
|
||||
configReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Storage: config.StorageView,
|
||||
Data: configData,
|
||||
}
|
||||
resp, err = b.HandleRequest(configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
configReq.Operation = logical.ReadOperation
|
||||
resp, err = b.HandleRequest(configReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%s resp:%#v\n", err, resp)
|
||||
}
|
||||
|
||||
delete(configData, "verify_connection")
|
||||
if !reflect.DeepEqual(configData, resp.Data) {
|
||||
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cid, connURL := prepareTestContainer(t, config.StorageView, b)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepCreateRole(t, "web", testRole, false),
|
||||
testAccStepReadCreds(t, b, config.StorageView, "web", connURL),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleCrud(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cid, connURL := prepareTestContainer(t, config.StorageView, b)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepCreateRole(t, "web", testRole, false),
|
||||
testAccStepReadRole(t, "web", testRole),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
testAccStepReadRole(t, "web", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_BlockStatements(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cid, connURL := prepareTestContainer(t, config.StorageView, b)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
// This will also validate the query
|
||||
testAccStepCreateRole(t, "web-block", testBlockStatementRole, true),
|
||||
testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleReadOnly(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cid, connURL := prepareTestContainer(t, config.StorageView, b)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepCreateRole(t, "web", testRole, false),
|
||||
testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false),
|
||||
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
|
||||
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
|
||||
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
|
||||
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
|
||||
testAccStepDeleteRole(t, "web-readonly"),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
testAccStepReadRole(t, "web-readonly", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_roleReadOnly_revocationSQL(t *testing.T) {
|
||||
config := logical.TestBackendConfig()
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
b, err := Factory(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cid, connURL := prepareTestContainer(t, config.StorageView, b)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
connData := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
}
|
||||
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepConfig(t, connData, false),
|
||||
testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false),
|
||||
testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false),
|
||||
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
|
||||
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
|
||||
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
|
||||
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
|
||||
testAccStepDeleteRole(t, "web-readonly"),
|
||||
testAccStepDeleteRole(t, "web"),
|
||||
testAccStepReadRole(t, "web-readonly", ""),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/connection",
|
||||
Data: d,
|
||||
ErrorOk: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if expectError {
|
||||
if resp.Data == nil {
|
||||
return fmt.Errorf("data is nil")
|
||||
}
|
||||
var e struct {
|
||||
Error string `mapstructure:"error"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(e.Error) == 0 {
|
||||
return fmt.Errorf("expected error, but write succeeded.")
|
||||
}
|
||||
return nil
|
||||
} else if resp != nil && resp.IsError() {
|
||||
return fmt.Errorf("got an error response: %v", resp.Error())
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: path.Join("roles", name),
|
||||
Data: map[string]interface{}{
|
||||
"sql": sql,
|
||||
},
|
||||
ErrorOk: expectFail,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: path.Join("roles", name),
|
||||
Data: map[string]interface{}{
|
||||
"sql": sql,
|
||||
"revocation_sql": revocationSQL,
|
||||
},
|
||||
ErrorOk: expectFail,
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: path.Join("roles", name),
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: path.Join("creds", name),
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||
conn, err := pq.ParseURL(connURL)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn += " timezone=utc"
|
||||
|
||||
db, err := sql.Open("postgres", conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
returnedRows := func() int {
|
||||
stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');")
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(d.Username)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
i++
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// minNumPermissions is the minimum number of permissions that will always be present.
|
||||
const minNumPermissions = 2
|
||||
|
||||
userRows := returnedRows()
|
||||
if userRows < minNumPermissions {
|
||||
t.Fatalf("did not get expected number of rows, got %d", userRows)
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Operation: logical.RevokeOperation,
|
||||
Storage: s,
|
||||
Secret: &logical.Secret{
|
||||
InternalData: map[string]interface{}{
|
||||
"secret_type": "creds",
|
||||
"username": d.Username,
|
||||
"role": name,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("Error on resp: %#v", *resp)
|
||||
}
|
||||
}
|
||||
|
||||
userRows = returnedRows()
|
||||
// User shouldn't exist so returnedRows() should encounter an error and exit with -1
|
||||
if userRows != -1 {
|
||||
t.Fatalf("did not get expected number of rows, got %d", userRows)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: path.Join("creds", name),
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||
conn, err := pq.ParseURL(connURL)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn += " timezone=utc"
|
||||
|
||||
db, err := sql.Open("postgres", conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Operation: logical.RevokeOperation,
|
||||
Storage: s,
|
||||
Secret: &logical.Secret{
|
||||
InternalData: map[string]interface{}{
|
||||
"secret_type": "creds",
|
||||
"username": d.Username,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("Error on resp: %#v", *resp)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: path.Join("creds", name),
|
||||
Check: func(resp *logical.Response) error {
|
||||
var d struct {
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[TRACE] Generated credentials: %v", d)
|
||||
conn, err := pq.ParseURL(connURL)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn += " timezone=utc"
|
||||
|
||||
db, err := sql.Open("postgres", conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("DROP TABLE test;")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
resp, err = b.HandleRequest(&logical.Request{
|
||||
Operation: logical.RevokeOperation,
|
||||
Storage: s,
|
||||
Secret: &logical.Secret{
|
||||
InternalData: map[string]interface{}{
|
||||
"secret_type": "creds",
|
||||
"username": d.Username,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.IsError() {
|
||||
return fmt.Errorf("Error on resp: %#v", *resp)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep {
|
||||
return logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "roles/" + name,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp == nil {
|
||||
if sql == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
var d struct {
|
||||
SQL string `mapstructure:"sql"`
|
||||
}
|
||||
if err := mapstructure.Decode(resp.Data, &d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.SQL != sql {
|
||||
return fmt.Errorf("bad: %#v", resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const testRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testReadOnlyRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testBlockStatementRole = `
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
|
||||
CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';
|
||||
GRANT "foo-role" TO "{{name}}";
|
||||
ALTER ROLE "{{name}}" SET search_path = foo;
|
||||
GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";
|
||||
`
|
||||
|
||||
var testBlockStatementRoleSlice = []string{
|
||||
`
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
`,
|
||||
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
|
||||
`GRANT "foo-role" TO "{{name}}";`,
|
||||
`ALTER ROLE "{{name}}" SET search_path = foo;`,
|
||||
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
|
||||
}
|
||||
|
||||
const defaultRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
|
||||
REVOKE USAGE ON SCHEMA public FROM {{name}};
|
||||
|
||||
DROP ROLE IF EXISTS {{name}};
|
||||
`
|
||||
194
builtin/logical/database/dbs/cassandra.go
Normal file
194
builtin/logical/database/dbs/cassandra.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
)
|
||||
|
||||
type Cassandra struct {
|
||||
// Session is goroutine safe, however, since we reinitialize
|
||||
// it when connection info changes, we want to make sure we
|
||||
// can close it and use a new connection; hence the lock
|
||||
session *gocql.Session
|
||||
config ConnectionConfig
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *Cassandra) Type() string {
|
||||
return cassandraTypeName
|
||||
}
|
||||
|
||||
func (c *Cassandra) Connection() (*gocql.Session, error) {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if c.session != nil {
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
session, err := createSession(c.config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the session in backend for reuse
|
||||
c.session = session
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Close() {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if p.session != nil {
|
||||
p.session.Close()
|
||||
}
|
||||
|
||||
p.session = nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
p.config = config
|
||||
p.Unlock()
|
||||
|
||||
p.Close()
|
||||
return p.Connection()
|
||||
}
|
||||
|
||||
func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error {
|
||||
// Get the connection
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: This is racey
|
||||
// Grab a read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) RenewUser(username, expiration string) error {
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: This is Racey
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error {
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: this is Racey
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) DefaultRevokeUser(username string) error {
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
db, err := p.Connection()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSession(cfg *ConnectionConfig) (*gocql.Session, error) {
|
||||
clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
}
|
||||
|
||||
clusterConfig.ProtoVersion = cfg.ProtocolVersion
|
||||
if clusterConfig.ProtoVersion == 0 {
|
||||
clusterConfig.ProtoVersion = 2
|
||||
}
|
||||
|
||||
clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second
|
||||
|
||||
if cfg.TLS {
|
||||
var tlsConfig *tls.Config
|
||||
if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 {
|
||||
if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 {
|
||||
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
|
||||
}
|
||||
|
||||
certBundle := &certutil.CertBundle{}
|
||||
if len(cfg.Certificate) > 0 {
|
||||
certBundle.Certificate = cfg.Certificate
|
||||
certBundle.PrivateKey = cfg.PrivateKey
|
||||
}
|
||||
if len(cfg.IssuingCA) > 0 {
|
||||
certBundle.IssuingCA = cfg.IssuingCA
|
||||
}
|
||||
|
||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||
if err != nil || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = cfg.InsecureTLS
|
||||
|
||||
if cfg.TLSMinVersion != "" {
|
||||
var ok bool
|
||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||
}
|
||||
} else {
|
||||
// MinVersion was not being set earlier. Reset it to
|
||||
// zero to gracefully handle upgrades.
|
||||
tlsConfig.MinVersion = 0
|
||||
}
|
||||
}
|
||||
|
||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
||||
Config: *tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error creating session: %s", err)
|
||||
}
|
||||
|
||||
// Verify the info
|
||||
err = session.Query(`LIST USERS`).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error validating connection info: %s", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
56
builtin/logical/database/dbs/db.go
Normal file
56
builtin/logical/database/dbs/db.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
postgreSQLTypeName = "postgres"
|
||||
cassandraTypeName = "cassandra"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
|
||||
)
|
||||
|
||||
func Factory(conf ConnectionConfig) (DatabaseType, error) {
|
||||
switch conf.ConnectionType {
|
||||
case postgreSQLTypeName:
|
||||
return &PostgreSQL{
|
||||
config: conf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, ErrUnsupportedDatabaseType
|
||||
}
|
||||
|
||||
type DatabaseType interface {
|
||||
Type() string
|
||||
Connection() (*sql.DB, error)
|
||||
Close()
|
||||
Reset(ConnectionConfig) (*sql.DB, error)
|
||||
CreateUser(createStmt, username, password, expiration string) error
|
||||
RenewUser(username, expiration string) error
|
||||
CustomRevokeUser(username, revocationSQL string) error
|
||||
DefaultRevokeUser(username string) error
|
||||
}
|
||||
|
||||
type ConnectionConfig struct {
|
||||
ConnectionType string `json:"type" structs:"type" mapstructure:"type"`
|
||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||
ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
}
|
||||
|
||||
// Query templates a query for us.
|
||||
func queryHelper(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
||||
336
builtin/logical/database/dbs/postgresql.go
Normal file
336
builtin/logical/database/dbs/postgresql.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type PostgreSQL struct {
|
||||
db *sql.DB
|
||||
config ConnectionConfig
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Type() string {
|
||||
return postgreSQLTypeName
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Connection() (*sql.DB, error) {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
// If we already have a DB, we got it!
|
||||
if p.db != nil {
|
||||
if err := p.db.Ping(); err == nil {
|
||||
return p.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
p.db.Close()
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
conn := p.config.ConnectionURL
|
||||
|
||||
// Ensure timezone is set to UTC for all the conenctions
|
||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||
if strings.Contains(conn, "?") {
|
||||
conn += "&timezone=utc"
|
||||
} else {
|
||||
conn += "?timezone=utc"
|
||||
}
|
||||
} else {
|
||||
conn += " timezone=utc"
|
||||
}
|
||||
|
||||
var err error
|
||||
p.db, err = sql.Open("postgres", conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
p.db.SetMaxOpenConns(p.config.MaxOpenConnections)
|
||||
p.db.SetMaxIdleConns(p.config.MaxIdleConnections)
|
||||
|
||||
return p.db, nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Close() {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if p.db != nil {
|
||||
p.db.Close()
|
||||
}
|
||||
|
||||
p.db = nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) {
|
||||
// Grab the write lock
|
||||
p.Lock()
|
||||
p.config = config
|
||||
p.Unlock()
|
||||
|
||||
p.Close()
|
||||
return p.Connection()
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error {
|
||||
// Get the connection
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: This is racey
|
||||
// Grab a read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
// Start a transaction
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: starting transaction")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction")
|
||||
tx.Rollback()
|
||||
}()
|
||||
// Return the secret
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: preparing statement")
|
||||
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expiration,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: executing statement")
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
|
||||
// b.logger.Trace("postgres/pathRoleCreateRead: committing transaction")
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(username, expiration string) error {
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: This is Racey
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"ALTER ROLE %s VALID UNTIL '%s';",
|
||||
pq.QuoteIdentifier(username),
|
||||
expiration)
|
||||
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error {
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: this is Racey
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) DefaultRevokeUser(username string) error {
|
||||
// Grab the read lock
|
||||
p.RLock()
|
||||
defer p.RUnlock()
|
||||
|
||||
db, err := p.Connection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the role exists
|
||||
var exists bool
|
||||
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists == false {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query for permissions; we need to revoke permissions before we can drop
|
||||
// the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
const initialNumRevocations = 16
|
||||
revocationStmts := make([]string, 0, initialNumRevocations)
|
||||
for rows.Next() {
|
||||
var schema string
|
||||
err = rows.Scan(&schema)
|
||||
if err != nil {
|
||||
// keep going; remove as many permissions as possible right now
|
||||
continue
|
||||
}
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
|
||||
pq.QuoteIdentifier(schema),
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
|
||||
pq.QuoteIdentifier(schema),
|
||||
pq.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// for good measure, revoke all privileges and usage on schema public
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE USAGE ON SCHEMA public FROM %s;",
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
// get the current database name so we can issue a REVOKE CONNECT for
|
||||
// this username
|
||||
var dbname sql.NullString
|
||||
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if dbname.Valid {
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
|
||||
pq.QuoteIdentifier(dbname.String),
|
||||
pq.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// again, here, we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revocationStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all privileges are revoked
|
||||
if rows.Err() != nil {
|
||||
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this user
|
||||
stmt, err = db.Prepare(fmt.Sprintf(
|
||||
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
188
builtin/logical/database/path_config_connection.go
Normal file
188
builtin/logical/database/path_config_connection.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func pathConfigConnection(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of this DB type",
|
||||
},
|
||||
|
||||
"connection_type": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "DB type (e.g. postgres)",
|
||||
},
|
||||
|
||||
"connection_url": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "DB connection string",
|
||||
},
|
||||
|
||||
"connection_details": &framework.FieldSchema{
|
||||
Type: framework.TypeMap,
|
||||
Description: "Connection details for specified connection type.",
|
||||
},
|
||||
|
||||
"verify_connection": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: `If set, connection_url is verified by actually connecting to the database`,
|
||||
},
|
||||
|
||||
"max_open_connections": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
Description: `Maximum number of open connections to the database;
|
||||
a zero uses the default value of two and a
|
||||
negative value means unlimited`,
|
||||
},
|
||||
|
||||
"max_idle_connections": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
Description: `Maximum number of idle connections to the database;
|
||||
a zero uses the value of max_open_connections
|
||||
and a negative value disables idle connections.
|
||||
If larger than max_open_connections it will be
|
||||
reduced to the same size.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.pathConnectionWrite,
|
||||
logical.ReadOperation: b.pathConnectionRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigConnectionHelpSyn,
|
||||
HelpDescription: pathConfigConnectionHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
// pathConnectionRead reads out the connection configuration
|
||||
func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
|
||||
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration")
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var config dbs.ConnectionConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &logical.Response{
|
||||
Data: structs.New(config).Map(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
connURL := data.Get("connection_url").(string)
|
||||
connType := data.Get("connection_type").(string)
|
||||
|
||||
maxOpenConns := data.Get("max_open_connections").(int)
|
||||
if maxOpenConns == 0 {
|
||||
maxOpenConns = 2
|
||||
}
|
||||
|
||||
maxIdleConns := data.Get("max_idle_connections").(int)
|
||||
if maxIdleConns == 0 {
|
||||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
if maxIdleConns > maxOpenConns {
|
||||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
|
||||
config := dbs.ConnectionConfig{
|
||||
ConnectionType: connType,
|
||||
ConnectionURL: connURL,
|
||||
MaxOpenConnections: maxOpenConns,
|
||||
MaxIdleConnections: maxIdleConns,
|
||||
}
|
||||
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Grab the mutex lock
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
var err error
|
||||
var db dbs.DatabaseType
|
||||
if _, ok := b.connections[name]; ok {
|
||||
|
||||
// Don't allow the connection type to change
|
||||
if b.connections[name].Type() != connType {
|
||||
return logical.ErrorResponse("can not change type of existing connection"), nil
|
||||
}
|
||||
|
||||
db = b.connections[name]
|
||||
} else {
|
||||
db, err = dbs.Factory(config)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
// Don't check the connection_url if verification is disabled
|
||||
verifyConnection := data.Get("verify_connection").(bool)
|
||||
if verifyConnection {
|
||||
// Verify the string
|
||||
db, err := sql.Open("postgres", connURL)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error validating connection info: %s", err)), nil
|
||||
}
|
||||
defer db.Close()
|
||||
if err := db.Ping(); err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Error validating connection info: %s", err)), nil
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the DB connection
|
||||
db.Reset(config)
|
||||
b.connections[name] = db
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathConfigConnectionHelpSyn = `
|
||||
Configure the connection string to talk to PostgreSQL.
|
||||
`
|
||||
|
||||
const pathConfigConnectionHelpDesc = `
|
||||
This path configures the connection string used to connect to PostgreSQL.
|
||||
The value of the string can be a URL, or a PG style string in the
|
||||
format of "user=foo host=bar" etc.
|
||||
|
||||
The URL looks like:
|
||||
"postgresql://user:pass@host:port/dbname"
|
||||
|
||||
When configuring the connection string, the backend will verify its validity.
|
||||
`
|
||||
103
builtin/logical/database/path_config_lease.go
Normal file
103
builtin/logical/database/path_config_lease.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathConfigLease(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "config/lease",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"lease": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Default lease for roles.",
|
||||
},
|
||||
|
||||
"lease_max": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Maximum time a credential is valid for.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathLeaseRead,
|
||||
logical.UpdateOperation: b.pathLeaseWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigLeaseHelpSyn,
|
||||
HelpDescription: pathConfigLeaseHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathLeaseWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
leaseRaw := d.Get("lease").(string)
|
||||
leaseMaxRaw := d.Get("lease_max").(string)
|
||||
|
||||
lease, err := time.ParseDuration(leaseRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid lease: %s", err)), nil
|
||||
}
|
||||
leaseMax, err := time.ParseDuration(leaseMaxRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid lease: %s", err)), nil
|
||||
}
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("config/lease", &configLease{
|
||||
Lease: lease,
|
||||
LeaseMax: leaseMax,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathLeaseRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
lease, err := b.Lease(req.Storage)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"lease": lease.Lease.String(),
|
||||
"lease_max": lease.LeaseMax.String(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type configLease struct {
|
||||
Lease time.Duration
|
||||
LeaseMax time.Duration
|
||||
}
|
||||
|
||||
const pathConfigLeaseHelpSyn = `
|
||||
Configure the default lease information for generated credentials.
|
||||
`
|
||||
|
||||
const pathConfigLeaseHelpDesc = `
|
||||
This configures the default lease information used for credentials
|
||||
generated by this backend. The lease specifies the duration that a
|
||||
credential will be valid for, as well as the maximum session for
|
||||
a set of credentials.
|
||||
|
||||
The format for the lease is "1h" or integer and then unit. The longest
|
||||
unit is hour.
|
||||
`
|
||||
120
builtin/logical/database/path_role_create.go
Normal file
120
builtin/logical/database/path_role_create.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func pathRoleCreate(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "creds/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleCreateRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleCreateReadHelpSyn,
|
||||
HelpDescription: pathRoleCreateReadHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: enter")
|
||||
defer b.logger.Trace("postgres/pathRoleCreateRead: exit")
|
||||
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: getting role")
|
||||
role, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
// Determine if we have a lease
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: getting lease")
|
||||
lease, err := b.Lease(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Unlike some other backends we need a lease here (can't leave as 0 and
|
||||
// let core fill it in) because Postgres also expires users as a safety
|
||||
// measure, so cannot be zero
|
||||
if lease == nil {
|
||||
lease = &configLease{
|
||||
Lease: b.System().DefaultLeaseTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the username, password and expiration. PG limits user to 63 characters
|
||||
displayName := req.DisplayName
|
||||
if len(displayName) > 26 {
|
||||
displayName = displayName[:26]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
||||
if len(username) > 63 {
|
||||
username = username[:63]
|
||||
}
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiration := time.Now().
|
||||
Add(lease.Lease).
|
||||
Format("2006-01-02 15:04:05-0700")
|
||||
|
||||
// Get our handle
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
|
||||
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
db, ok := b.connections[role.DBName]
|
||||
if !ok {
|
||||
// TODO: return a resp error instead?
|
||||
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
|
||||
}
|
||||
|
||||
err = db.CreateUser(role.CreationStatement, username, password, expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: generating secret")
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}, map[string]interface{}{
|
||||
"username": username,
|
||||
"role": name,
|
||||
})
|
||||
resp.Secret.TTL = lease.Lease
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
const pathRoleCreateReadHelpSyn = `
|
||||
Request database credentials for a certain role.
|
||||
`
|
||||
|
||||
const pathRoleCreateReadHelpDesc = `
|
||||
This path reads database credentials for a certain role. The
|
||||
database credentials will be generated on demand and will be automatically
|
||||
revoked when the lease is up.
|
||||
`
|
||||
161
builtin/logical/database/path_roles.go
Normal file
161
builtin/logical/database/path_roles.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathListRoles(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/?$",
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ListOperation: b.pathRoleList,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func pathRoles(b *databaseBackend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "roles/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the role.",
|
||||
},
|
||||
|
||||
"db_name": {
|
||||
Type: framework.TypeString,
|
||||
Description: "Name of the database this role acts on.",
|
||||
},
|
||||
|
||||
"creation_statement": {
|
||||
Type: framework.TypeString,
|
||||
Description: "SQL string to create a user. See help for more info.",
|
||||
},
|
||||
|
||||
"revocation_statement": {
|
||||
Type: framework.TypeString,
|
||||
Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated
|
||||
string, a base64-encoded semicolon-separated string, a serialized JSON string
|
||||
array, or a base64-encoded serialized JSON string array. The '{{name}}' value
|
||||
will be substituted.`,
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathRoleRead,
|
||||
logical.UpdateOperation: b.pathRoleCreate,
|
||||
logical.DeleteOperation: b.pathRoleDelete,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRoleHelpSyn,
|
||||
HelpDescription: pathRoleHelpDesc,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
err := req.Storage.Delete("role/" + data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
role, err := b.Role(req.Storage, data.Get("name").(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"creation_statment": role.CreationStatement,
|
||||
"revocation_statement": role.RevocationStatement,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
entries, err := req.Storage.List("role/")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logical.ListResponse(entries), nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
dbName := data.Get("db_name").(string)
|
||||
creationStmt := data.Get("creation_statement").(string)
|
||||
revocationStmt := data.Get("revocation_statement").(string)
|
||||
|
||||
// TODO: Think about preparing the statments to test.
|
||||
|
||||
// Store it
|
||||
entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{
|
||||
DBName: dbName,
|
||||
CreationStatement: creationStmt,
|
||||
RevocationStatement: revocationStmt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := req.Storage.Put(entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type roleEntry struct {
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"`
|
||||
RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
Manage the roles that can be created with this backend.
|
||||
`
|
||||
|
||||
const pathRoleHelpDesc = `
|
||||
This path lets you manage the roles that can be created with this backend.
|
||||
|
||||
The "sql" parameter customizes the SQL string used to create the role.
|
||||
This can be a sequence of SQL queries. Some substitution will be done to the
|
||||
SQL string for certain keys. The names of the variables must be surrounded
|
||||
by "{{" and "}}" to be replaced.
|
||||
|
||||
* "name" - The random username generated for the DB user.
|
||||
|
||||
* "password" - The random password generated for the DB user.
|
||||
|
||||
* "expiration" - The timestamp when this user will expire.
|
||||
|
||||
Example of a decent SQL query to use:
|
||||
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
|
||||
Note the above user would be able to access everything in schema public.
|
||||
For more complex GRANT clauses, see the PostgreSQL manual.
|
||||
|
||||
The "revocation_sql" parameter customizes the SQL string used to revoke a user.
|
||||
Example of a decent revocation SQL query to use:
|
||||
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
|
||||
REVOKE USAGE ON SCHEMA public FROM {{name}};
|
||||
DROP ROLE IF EXISTS {{name}};
|
||||
`
|
||||
147
builtin/logical/database/secret_creds.go
Normal file
147
builtin/logical/database/secret_creds.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const SecretCredsType = "creds"
|
||||
|
||||
func secretCreds(b *databaseBackend) *framework.Secret {
|
||||
return &framework.Secret{
|
||||
Type: SecretCredsType,
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"username": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Username",
|
||||
},
|
||||
|
||||
"password": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "Password",
|
||||
},
|
||||
},
|
||||
|
||||
Renew: b.secretCredsRenew,
|
||||
Revoke: b.secretCredsRevoke,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
dbName := d.Get("name").(string)
|
||||
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
|
||||
// Get our connection
|
||||
db, ok := b.connections[dbName]
|
||||
if !ok {
|
||||
return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName))
|
||||
}
|
||||
|
||||
// Get the lease information
|
||||
lease, err := b.Lease(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease == nil {
|
||||
lease = &configLease{}
|
||||
}
|
||||
|
||||
f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System())
|
||||
resp, err := f(req, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() {
|
||||
expiration := expireTime.Format("2006-01-02 15:04:05-0700")
|
||||
|
||||
err := db.RenewUser(username, expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
// Get the username from the internal data
|
||||
usernameRaw, ok := req.Secret.InternalData["username"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("secret is missing username internal data")
|
||||
}
|
||||
username, ok := usernameRaw.(string)
|
||||
|
||||
var revocationSQL string
|
||||
var resp *logical.Response
|
||||
|
||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
/* TODO: think about how to handle this case.
|
||||
if !ok {
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
if resp == nil {
|
||||
resp = &logical.Response{}
|
||||
}
|
||||
resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string)))
|
||||
} else {
|
||||
revocationSQL = role.RevocationStatement
|
||||
}
|
||||
}*/
|
||||
|
||||
// Grab the read lock
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
|
||||
// Get our connection
|
||||
db, ok := b.connections[role.DBName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
|
||||
}
|
||||
|
||||
// TODO: Maybe move this down into db package?
|
||||
switch revocationSQL {
|
||||
|
||||
// This is the default revocation logic. If revocation SQL is provided it
|
||||
// is simply executed as-is.
|
||||
case "":
|
||||
err := db.DefaultRevokeUser(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We have revocation SQL, execute directly, within a transaction
|
||||
default:
|
||||
err := db.CustomRevokeUser(username, revocationSQL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/hashicorp/vault/builtin/logical/aws"
|
||||
"github.com/hashicorp/vault/builtin/logical/cassandra"
|
||||
"github.com/hashicorp/vault/builtin/logical/consul"
|
||||
"github.com/hashicorp/vault/builtin/logical/database"
|
||||
"github.com/hashicorp/vault/builtin/logical/mongodb"
|
||||
"github.com/hashicorp/vault/builtin/logical/mssql"
|
||||
"github.com/hashicorp/vault/builtin/logical/mysql"
|
||||
@@ -91,6 +92,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory {
|
||||
"mysql": mysql.Factory,
|
||||
"ssh": ssh.Factory,
|
||||
"rabbitmq": rabbitmq.Factory,
|
||||
"database": database.Factory,
|
||||
},
|
||||
ShutdownCh: command.MakeShutdownCh(),
|
||||
SighupCh: command.MakeSighupCh(),
|
||||
|
||||
Reference in New Issue
Block a user