diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go new file mode 100644 index 0000000000..8b7fa36700 --- /dev/null +++ b/builtin/logical/database/backend.go @@ -0,0 +1,104 @@ +package database + +import ( + "strings" + "sync" + + log "github.com/mgutz/logxi/v1" + + "github.com/hashicorp/vault/builtin/logical/database/dbs" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func Factory(conf *logical.BackendConfig) (logical.Backend, error) { + return Backend(conf).Setup(conf) +} + +func Backend(conf *logical.BackendConfig) *databaseBackend { + var b databaseBackend + b.Backend = &framework.Backend{ + Help: strings.TrimSpace(backendHelp), + + Paths: []*framework.Path{ + pathConfigConnection(&b), + pathConfigLease(&b), + pathListRoles(&b), + pathRoles(&b), + pathRoleCreate(&b), + }, + + Secrets: []*framework.Secret{ + secretCreds(&b), + }, + + Clean: b.resetAllDBs, + } + + b.logger = conf.Logger + b.connections = make(map[string]dbs.DatabaseType) + return &b +} + +type databaseBackend struct { + connections map[string]dbs.DatabaseType + logger log.Logger + + *framework.Backend + sync.RWMutex +} + +// resetAllDBs closes all connections from all database types +func (b *databaseBackend) resetAllDBs() { + b.logger.Trace("postgres/resetdb: enter") + defer b.logger.Trace("postgres/resetdb: exit") + + b.Lock() + defer b.Unlock() + + for _, db := range b.connections { + db.Close() + } +} + +// Lease returns the lease information +func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) { + entry, err := s.Get("config/lease") + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result configLease + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { + entry, err := s.Get("role/" + n) + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result roleEntry + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +const backendHelp = ` +The PostgreSQL backend dynamically generates database users. + +After mounting this backend, configure it using the endpoints within +the "config/" path. +` diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go new file mode 100644 index 0000000000..a203c9b191 --- /dev/null +++ b/builtin/logical/database/backend_test.go @@ -0,0 +1,620 @@ +package database + +import ( + "database/sql" + "encoding/json" + "fmt" + "log" + "os" + "path" + "reflect" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/logical" + logicaltest "github.com/hashicorp/vault/logical/testing" + "github.com/lib/pq" + "github.com/mitchellh/mapstructure" + "github.com/ory-am/dockertest" +) + +var ( + testImagePull sync.Once +) + +func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) { + if os.Getenv("PG_URL") != "" { + return "", os.Getenv("PG_URL") + } + + // Without this the checks for whether the container has started seem to + // never actually pass. There's really no reason to expose the test + // containers, so don't. + dockertest.BindDockerToLocalhost = "yep" + + testImagePull.Do(func() { + dockertest.Pull("postgres") + }) + + cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { + // This will cause a validation to run + resp, err := b.HandleRequest(&logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: "config/connection", + Data: map[string]interface{}{ + "connection_url": connURL, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + // It's likely not up and running yet, so return false and try again + return false + } + if resp == nil { + t.Fatal("expected warning") + } + + retURL = connURL + return true + }) + + if connErr != nil { + t.Fatalf("could not connect to database: %v", connErr) + } + + return +} + +func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { + err := cid.KillRemove() + if err != nil { + t.Fatal(err) + } +} + +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "value": "", + "max_open_connections": 9, + "max_idle_connections": 7, + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(configData, "verify_connection") + if !reflect.DeepEqual(configData, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data) + } +} + +func TestBackend_basic(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepReadCreds(t, b, config.StorageView, "web", connURL), + }, + }) +} + +func TestBackend_roleCrud(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepReadRole(t, "web", testRole), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web", ""), + }, + }) +} + +func TestBackend_BlockStatements(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice) + if err != nil { + t.Fatal(err) + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + // This will also validate the query + testAccStepCreateRole(t, "web-block", testBlockStatementRole, true), + testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false), + }, + }) +} + +func TestBackend_roleReadOnly(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false), + testAccStepReadRole(t, "web-readonly", testReadOnlyRole), + testAccStepCreateTable(t, b, config.StorageView, "web", connURL), + testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), + testAccStepDropTable(t, b, config.StorageView, "web", connURL), + testAccStepDeleteRole(t, "web-readonly"), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web-readonly", ""), + }, + }) +} + +func TestBackend_roleReadOnly_revocationSQL(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false), + testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false), + testAccStepReadRole(t, "web-readonly", testReadOnlyRole), + testAccStepCreateTable(t, b, config.StorageView, "web", connURL), + testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), + testAccStepDropTable(t, b, config.StorageView, "web", connURL), + testAccStepDeleteRole(t, "web-readonly"), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web-readonly", ""), + }, + }) +} + +func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Data: d, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if expectError { + if resp.Data == nil { + return fmt.Errorf("data is nil") + } + var e struct { + Error string `mapstructure:"error"` + } + if err := mapstructure.Decode(resp.Data, &e); err != nil { + return err + } + if len(e.Error) == 0 { + return fmt.Errorf("expected error, but write succeeded.") + } + return nil + } else if resp != nil && resp.IsError() { + return fmt.Errorf("got an error response: %v", resp.Error()) + } + return nil + }, + } +} + +func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: path.Join("roles", name), + Data: map[string]interface{}{ + "sql": sql, + }, + ErrorOk: expectFail, + } +} + +func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: path.Join("roles", name), + Data: map[string]interface{}{ + "sql": sql, + "revocation_sql": revocationSQL, + }, + ErrorOk: expectFail, + } +} + +func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.DeleteOperation, + Path: path.Join("roles", name), + } +} + +func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + returnedRows := func() int { + stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');") + if err != nil { + return -1 + } + defer stmt.Close() + + rows, err := stmt.Query(d.Username) + if err != nil { + return -1 + } + defer rows.Close() + + i := 0 + for rows.Next() { + i++ + } + return i + } + + // minNumPermissions is the minimum number of permissions that will always be present. + const minNumPermissions = 2 + + userRows := returnedRows() + if userRows < minNumPermissions { + t.Fatalf("did not get expected number of rows, got %d", userRows) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + "role": name, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + userRows = returnedRows() + // User shouldn't exist so returnedRows() should encounter an error and exit with -1 + if userRows != -1 { + t.Fatalf("did not get expected number of rows, got %d", userRows) + } + + return nil + }, + } +} + +func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);") + if err != nil { + t.Fatal(err) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + return nil + }, + } +} + +func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("DROP TABLE test;") + if err != nil { + t.Fatal(err) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + return nil + }, + } +} + +func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "roles/" + name, + Check: func(resp *logical.Response) error { + if resp == nil { + if sql == "" { + return nil + } + + return fmt.Errorf("bad: %#v", resp) + } + + var d struct { + SQL string `mapstructure:"sql"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + if d.SQL != sql { + return fmt.Errorf("bad: %#v", resp) + } + + return nil + }, + } +} + +const testRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const testBlockStatementRole = ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ + +CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; +GRANT "foo-role" TO "{{name}}"; +ALTER ROLE "{{name}}" SET search_path = foo; +GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; +` + +var testBlockStatementRoleSlice = []string{ + ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ +`, + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, + `GRANT "foo-role" TO "{{name}}";`, + `ALTER ROLE "{{name}}" SET search_path = foo;`, + `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, +} + +const defaultRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; +REVOKE USAGE ON SCHEMA public FROM {{name}}; + +DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go new file mode 100644 index 0000000000..8c7a068bec --- /dev/null +++ b/builtin/logical/database/dbs/cassandra.go @@ -0,0 +1,194 @@ +package dbs + +import ( + "crypto/tls" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" +) + +type Cassandra struct { + // Session is goroutine safe, however, since we reinitialize + // it when connection info changes, we want to make sure we + // can close it and use a new connection; hence the lock + session *gocql.Session + config ConnectionConfig + + sync.RWMutex +} + +func (c *Cassandra) Type() string { + return cassandraTypeName +} + +func (c *Cassandra) Connection() (*gocql.Session, error) { + // Grab the write lock + c.Lock() + defer c.Unlock() + + // If we already have a DB, we got it! + if c.session != nil { + return c.session, nil + } + + session, err := createSession(c.config) + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + c.session = session + + return session, nil +} + +func (p *Cassandra) Close() { + // Grab the write lock + p.Lock() + defer p.Unlock() + + if p.session != nil { + p.session.Close() + } + + p.session = nil +} + +func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) { + // Grab the write lock + p.Lock() + p.config = config + p.Unlock() + + p.Close() + return p.Connection() +} + +func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error { + // Get the connection + db, err := p.Connection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) RenewUser(username, expiration string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: This is Racey + // Grab the read lock + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: this is Racey + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) DefaultRevokeUser(username string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + db, err := p.Connection() + + return nil +} + +func createSession(cfg *ConnectionConfig) (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: cfg.Username, + Password: cfg.Password, + } + + clusterConfig.ProtoVersion = cfg.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + + if cfg.TLS { + var tlsConfig *tls.Config + if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { + if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(cfg.Certificate) > 0 { + certBundle.Certificate = cfg.Certificate + certBundle.PrivateKey = cfg.PrivateKey + } + if len(cfg.IssuingCA) > 0 { + certBundle.IssuingCA = cfg.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + + if cfg.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("Error creating session: %s", err) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("Error validating connection info: %s", err) + } + + return session, nil +} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go new file mode 100644 index 0000000000..ee7b15b64d --- /dev/null +++ b/builtin/logical/database/dbs/db.go @@ -0,0 +1,56 @@ +package dbs + +import ( + "database/sql" + "errors" + "fmt" + "strings" +) + +const ( + postgreSQLTypeName = "postgres" + cassandraTypeName = "cassandra" +) + +var ( + ErrUnsupportedDatabaseType = errors.New("Unsupported database type") +) + +func Factory(conf ConnectionConfig) (DatabaseType, error) { + switch conf.ConnectionType { + case postgreSQLTypeName: + return &PostgreSQL{ + config: conf, + }, nil + } + + return nil, ErrUnsupportedDatabaseType +} + +type DatabaseType interface { + Type() string + Connection() (*sql.DB, error) + Close() + Reset(ConnectionConfig) (*sql.DB, error) + CreateUser(createStmt, username, password, expiration string) error + RenewUser(username, expiration string) error + CustomRevokeUser(username, revocationSQL string) error + DefaultRevokeUser(username string) error +} + +type ConnectionConfig struct { + ConnectionType string `json:"type" structs:"type" mapstructure:"type"` + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` +} + +// Query templates a query for us. +func queryHelper(tpl string, data map[string]string) string { + for k, v := range data { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + + return tpl +} diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go new file mode 100644 index 0000000000..ea7d08f8ac --- /dev/null +++ b/builtin/logical/database/dbs/postgresql.go @@ -0,0 +1,336 @@ +package dbs + +import ( + "database/sql" + "fmt" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/strutil" + "github.com/lib/pq" +) + +type PostgreSQL struct { + db *sql.DB + config ConnectionConfig + + sync.RWMutex +} + +func (p *PostgreSQL) Type() string { + return postgreSQLTypeName +} + +func (p *PostgreSQL) Connection() (*sql.DB, error) { + // Grab the write lock + p.Lock() + defer p.Unlock() + + // If we already have a DB, we got it! + if p.db != nil { + if err := p.db.Ping(); err == nil { + return p.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + p.db.Close() + } + + // Otherwise, attempt to make connection + conn := p.config.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } else { + conn += " timezone=utc" + } + + var err error + p.db, err = sql.Open("postgres", conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + p.db.SetMaxOpenConns(p.config.MaxOpenConnections) + p.db.SetMaxIdleConns(p.config.MaxIdleConnections) + + return p.db, nil +} + +func (p *PostgreSQL) Close() { + // Grab the write lock + p.Lock() + defer p.Unlock() + + if p.db != nil { + p.db.Close() + } + + p.db = nil +} + +func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) { + // Grab the write lock + p.Lock() + p.config = config + p.Unlock() + + p.Close() + return p.Connection() +} + +func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error { + // Get the connection + db, err := p.Connection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + // Start a transaction + // b.logger.Trace("postgres/pathRoleCreateRead: starting transaction") + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + // b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction") + tx.Rollback() + }() + // Return the secret + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // b.logger.Trace("postgres/pathRoleCreateRead: preparing statement") + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expiration, + })) + if err != nil { + return err + } + defer stmt.Close() + // b.logger.Trace("postgres/pathRoleCreateRead: executing statement") + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + + // b.logger.Trace("postgres/pathRoleCreateRead: committing transaction") + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) RenewUser(username, expiration string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: This is Racey + // Grab the read lock + p.RLock() + defer p.RUnlock() + + query := fmt.Sprintf( + "ALTER ROLE %s VALID UNTIL '%s';", + pq.QuoteIdentifier(username), + expiration) + + stmt, err := db.Prepare(query) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: this is Racey + p.RLock() + defer p.RUnlock() + + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + tx.Rollback() + }() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + + if _, err := stmt.Exec(); err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) DefaultRevokeUser(username string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + db, err := p.Connection() + if err != nil { + return err + } + + // Check if the role exists + var exists bool + err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return err + } + + if exists == false { + return nil + } + + // Query for permissions; we need to revoke permissions before we can drop + // the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query(username) + if err != nil { + return err + } + defer rows.Close() + + const initialNumRevocations = 16 + revocationStmts := make([]string, 0, initialNumRevocations) + for rows.Next() { + var schema string + err = rows.Scan(&schema) + if err != nil { + // keep going; remove as many permissions as possible right now + continue + } + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE USAGE ON SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + } + + // for good measure, revoke all privileges and usage on schema public + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE USAGE ON SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + // get the current database name so we can issue a REVOKE CONNECT for + // this username + var dbname sql.NullString + if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { + return err + } + + if dbname.Valid { + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE CONNECT ON DATABASE %s FROM %s;`, + pq.QuoteIdentifier(dbname.String), + pq.QuoteIdentifier(username))) + } + + // again, here, we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revocationStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all privileges are revoked + if rows.Err() != nil { + return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) + } + + // Drop this user + stmt, err = db.Prepare(fmt.Sprintf( + `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go new file mode 100644 index 0000000000..be017ea35c --- /dev/null +++ b/builtin/logical/database/path_config_connection.go @@ -0,0 +1,188 @@ +package database + +import ( + "fmt" + + "github.com/fatih/structs" + "github.com/hashicorp/vault/builtin/logical/database/dbs" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + _ "github.com/lib/pq" +) + +func pathConfigConnection(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of this DB type", + }, + + "connection_type": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "DB type (e.g. postgres)", + }, + + "connection_url": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "DB connection string", + }, + + "connection_details": &framework.FieldSchema{ + Type: framework.TypeMap, + Description: "Connection details for specified connection type.", + }, + + "verify_connection": &framework.FieldSchema{ + Type: framework.TypeBool, + Default: true, + Description: `If set, connection_url is verified by actually connecting to the database`, + }, + + "max_open_connections": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Maximum number of open connections to the database; +a zero uses the default value of two and a +negative value means unlimited`, + }, + + "max_idle_connections": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Maximum number of idle connections to the database; +a zero uses the value of max_open_connections +and a negative value disables idle connections. +If larger than max_open_connections it will be +reduced to the same size.`, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.pathConnectionWrite, + logical.ReadOperation: b.pathConnectionRead, + }, + + HelpSynopsis: pathConfigConnectionHelpSyn, + HelpDescription: pathConfigConnectionHelpDesc, + } +} + +// pathConnectionRead reads out the connection configuration +func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config dbs.ConnectionConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil +} + +func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + connURL := data.Get("connection_url").(string) + connType := data.Get("connection_type").(string) + + maxOpenConns := data.Get("max_open_connections").(int) + if maxOpenConns == 0 { + maxOpenConns = 2 + } + + maxIdleConns := data.Get("max_idle_connections").(int) + if maxIdleConns == 0 { + maxIdleConns = maxOpenConns + } + if maxIdleConns > maxOpenConns { + maxIdleConns = maxOpenConns + } + + config := dbs.ConnectionConfig{ + ConnectionType: connType, + ConnectionURL: connURL, + MaxOpenConnections: maxOpenConns, + MaxIdleConnections: maxIdleConns, + } + + name := data.Get("name").(string) + + // Grab the mutex lock + b.Lock() + defer b.Unlock() + + var err error + var db dbs.DatabaseType + if _, ok := b.connections[name]; ok { + + // Don't allow the connection type to change + if b.connections[name].Type() != connType { + return logical.ErrorResponse("can not change type of existing connection"), nil + } + + db = b.connections[name] + } else { + db, err = dbs.Factory(config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + } + + /* + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + } + */ + + // Store it + entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + // Reset the DB connection + db.Reset(config) + b.connections[name] = db + + resp := &logical.Response{} + resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + + return resp, nil +} + +const pathConfigConnectionHelpSyn = ` +Configure the connection string to talk to PostgreSQL. +` + +const pathConfigConnectionHelpDesc = ` +This path configures the connection string used to connect to PostgreSQL. +The value of the string can be a URL, or a PG style string in the +format of "user=foo host=bar" etc. + +The URL looks like: +"postgresql://user:pass@host:port/dbname" + +When configuring the connection string, the backend will verify its validity. +` diff --git a/builtin/logical/database/path_config_lease.go b/builtin/logical/database/path_config_lease.go new file mode 100644 index 0000000000..5cc40a056e --- /dev/null +++ b/builtin/logical/database/path_config_lease.go @@ -0,0 +1,103 @@ +package database + +import ( + "fmt" + "time" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathConfigLease(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "config/lease", + Fields: map[string]*framework.FieldSchema{ + "lease": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Default lease for roles.", + }, + + "lease_max": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Maximum time a credential is valid for.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathLeaseRead, + logical.UpdateOperation: b.pathLeaseWrite, + }, + + HelpSynopsis: pathConfigLeaseHelpSyn, + HelpDescription: pathConfigLeaseHelpDesc, + } +} + +func (b *databaseBackend) pathLeaseWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + leaseRaw := d.Get("lease").(string) + leaseMaxRaw := d.Get("lease_max").(string) + + lease, err := time.ParseDuration(leaseRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid lease: %s", err)), nil + } + leaseMax, err := time.ParseDuration(leaseMaxRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid lease: %s", err)), nil + } + + // Store it + entry, err := logical.StorageEntryJSON("config/lease", &configLease{ + Lease: lease, + LeaseMax: leaseMax, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +func (b *databaseBackend) pathLeaseRead( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + lease, err := b.Lease(req.Storage) + + if err != nil { + return nil, err + } + if lease == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "lease": lease.Lease.String(), + "lease_max": lease.LeaseMax.String(), + }, + }, nil +} + +type configLease struct { + Lease time.Duration + LeaseMax time.Duration +} + +const pathConfigLeaseHelpSyn = ` +Configure the default lease information for generated credentials. +` + +const pathConfigLeaseHelpDesc = ` +This configures the default lease information used for credentials +generated by this backend. The lease specifies the duration that a +credential will be valid for, as well as the maximum session for +a set of credentials. + +The format for the lease is "1h" or integer and then unit. The longest +unit is hour. +` diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go new file mode 100644 index 0000000000..2a2386d012 --- /dev/null +++ b/builtin/logical/database/path_role_create.go @@ -0,0 +1,120 @@ +package database + +import ( + "fmt" + "time" + + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + _ "github.com/lib/pq" +) + +func pathRoleCreate(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "creds/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the role.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRoleCreateRead, + }, + + HelpSynopsis: pathRoleCreateReadHelpSyn, + HelpDescription: pathRoleCreateReadHelpDesc, + } +} + +func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.logger.Trace("postgres/pathRoleCreateRead: enter") + defer b.logger.Trace("postgres/pathRoleCreateRead: exit") + + name := data.Get("name").(string) + + // Get the role + b.logger.Trace("postgres/pathRoleCreateRead: getting role") + role, err := b.Role(req.Storage, name) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil + } + + // Determine if we have a lease + b.logger.Trace("postgres/pathRoleCreateRead: getting lease") + lease, err := b.Lease(req.Storage) + if err != nil { + return nil, err + } + // Unlike some other backends we need a lease here (can't leave as 0 and + // let core fill it in) because Postgres also expires users as a safety + // measure, so cannot be zero + if lease == nil { + lease = &configLease{ + Lease: b.System().DefaultLeaseTTL(), + } + } + + // Generate the username, password and expiration. PG limits user to 63 characters + displayName := req.DisplayName + if len(displayName) > 26 { + displayName = displayName[:26] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if len(username) > 63 { + username = username[:63] + } + password, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + expiration := time.Now(). + Add(lease.Lease). + Format("2006-01-02 15:04:05-0700") + + // Get our handle + b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") + + b.RLock() + defer b.RUnlock() + db, ok := b.connections[role.DBName] + if !ok { + // TODO: return a resp error instead? + return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) + } + + err = db.CreateUser(role.CreationStatement, username, password, expiration) + if err != nil { + return nil, err + } + + b.logger.Trace("postgres/pathRoleCreateRead: generating secret") + resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ + "username": username, + "password": password, + }, map[string]interface{}{ + "username": username, + "role": name, + }) + resp.Secret.TTL = lease.Lease + return resp, nil +} + +const pathRoleCreateReadHelpSyn = ` +Request database credentials for a certain role. +` + +const pathRoleCreateReadHelpDesc = ` +This path reads database credentials for a certain role. The +database credentials will be generated on demand and will be automatically +revoked when the lease is up. +` diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go new file mode 100644 index 0000000000..e06518b289 --- /dev/null +++ b/builtin/logical/database/path_roles.go @@ -0,0 +1,161 @@ +package database + +import ( + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathListRoles(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "roles/?$", + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: b.pathRoleList, + }, + + HelpSynopsis: pathRoleHelpSyn, + HelpDescription: pathRoleHelpDesc, + } +} + +func pathRoles(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "roles/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": { + Type: framework.TypeString, + Description: "Name of the role.", + }, + + "db_name": { + Type: framework.TypeString, + Description: "Name of the database this role acts on.", + }, + + "creation_statement": { + Type: framework.TypeString, + Description: "SQL string to create a user. See help for more info.", + }, + + "revocation_statement": { + Type: framework.TypeString, + Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted.`, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRoleRead, + logical.UpdateOperation: b.pathRoleCreate, + logical.DeleteOperation: b.pathRoleDelete, + }, + + HelpSynopsis: pathRoleHelpSyn, + HelpDescription: pathRoleHelpDesc, + } +} + +func (b *databaseBackend) pathRoleDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + err := req.Storage.Delete("role/" + data.Get("name").(string)) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + role, err := b.Role(req.Storage, data.Get("name").(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "creation_statment": role.CreationStatement, + "revocation_statement": role.RevocationStatement, + }, + }, nil +} + +func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + entries, err := req.Storage.List("role/") + if err != nil { + return nil, err + } + + return logical.ListResponse(entries), nil +} + +func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + dbName := data.Get("db_name").(string) + creationStmt := data.Get("creation_statement").(string) + revocationStmt := data.Get("revocation_statement").(string) + + // TODO: Think about preparing the statments to test. + + // Store it + entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ + DBName: dbName, + CreationStatement: creationStmt, + RevocationStatement: revocationStmt, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +type roleEntry struct { + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` +} + +const pathRoleHelpSyn = ` +Manage the roles that can be created with this backend. +` + +const pathRoleHelpDesc = ` +This path lets you manage the roles that can be created with this backend. + +The "sql" parameter customizes the SQL string used to create the role. +This can be a sequence of SQL queries. Some substitution will be done to the +SQL string for certain keys. The names of the variables must be surrounded +by "{{" and "}}" to be replaced. + + * "name" - The random username generated for the DB user. + + * "password" - The random password generated for the DB user. + + * "expiration" - The timestamp when this user will expire. + +Example of a decent SQL query to use: + + CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; + +Note the above user would be able to access everything in schema public. +For more complex GRANT clauses, see the PostgreSQL manual. + +The "revocation_sql" parameter customizes the SQL string used to revoke a user. +Example of a decent revocation SQL query to use: + + REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; + REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; + REVOKE USAGE ON SCHEMA public FROM {{name}}; + DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go new file mode 100644 index 0000000000..30c4a6430f --- /dev/null +++ b/builtin/logical/database/secret_creds.go @@ -0,0 +1,147 @@ +package database + +import ( + "errors" + "fmt" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +const SecretCredsType = "creds" + +func secretCreds(b *databaseBackend) *framework.Secret { + return &framework.Secret{ + Type: SecretCredsType, + Fields: map[string]*framework.FieldSchema{ + "username": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Username", + }, + + "password": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Password", + }, + }, + + Renew: b.secretCredsRenew, + Revoke: b.secretCredsRevoke, + } +} + +func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + dbName := d.Get("name").(string) + + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) + + // Get our connection + db, ok := b.connections[dbName] + if !ok { + return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName)) + } + + // Get the lease information + lease, err := b.Lease(req.Storage) + if err != nil { + return nil, err + } + if lease == nil { + lease = &configLease{} + } + + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System()) + resp, err := f(req, d) + if err != nil { + return nil, err + } + + // Make sure we increase the VALID UNTIL endpoint for this user. + if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { + expiration := expireTime.Format("2006-01-02 15:04:05-0700") + + err := db.RenewUser(username, expiration) + if err != nil { + return nil, err + } + } + + return resp, nil +} + +func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) + + var revocationSQL string + var resp *logical.Response + + roleNameRaw, ok := req.Secret.InternalData["role"] + if !ok { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + } + + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + } + + /* TODO: think about how to handle this case. + if !ok { + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + if resp == nil { + resp = &logical.Response{} + } + resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) + } else { + revocationSQL = role.RevocationStatement + } + }*/ + + // Grab the read lock + b.RLock() + defer b.RUnlock() + + // Get our connection + db, ok := b.connections[role.DBName] + if !ok { + return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) + } + + // TODO: Maybe move this down into db package? + switch revocationSQL { + + // This is the default revocation logic. If revocation SQL is provided it + // is simply executed as-is. + case "": + err := db.DefaultRevokeUser(username) + if err != nil { + return nil, err + } + + // We have revocation SQL, execute directly, within a transaction + default: + err := db.CustomRevokeUser(username, revocationSQL) + if err != nil { + return nil, err + } + } + + return resp, nil +} diff --git a/cli/commands.go b/cli/commands.go index 1901111779..13f7c8b25a 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/vault/builtin/logical/aws" "github.com/hashicorp/vault/builtin/logical/cassandra" "github.com/hashicorp/vault/builtin/logical/consul" + "github.com/hashicorp/vault/builtin/logical/database" "github.com/hashicorp/vault/builtin/logical/mongodb" "github.com/hashicorp/vault/builtin/logical/mssql" "github.com/hashicorp/vault/builtin/logical/mysql" @@ -91,6 +92,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { "mysql": mysql.Factory, "ssh": ssh.Factory, "rabbitmq": rabbitmq.Factory, + "database": database.Factory, }, ShutdownCh: command.MakeShutdownCh(), SighupCh: command.MakeSighupCh(),