Add functionaility to build db objects from disk so restarts work

This commit is contained in:
Brian Kassouf
2017-03-28 11:30:45 -07:00
parent d93378bb29
commit 6de5cfad5e
8 changed files with 92 additions and 74 deletions

View File

@@ -1,6 +1,7 @@
package database package database
import ( import (
"fmt"
"strings" "strings"
"sync" "sync"
@@ -52,14 +53,11 @@ type databaseBackend struct {
logger log.Logger logger log.Logger
*framework.Backend *framework.Backend
sync.RWMutex sync.Mutex
} }
// resetAllDBs closes all connections from all database types // resetAllDBs closes all connections from all database types
func (b *databaseBackend) closeAllDBs() { func (b *databaseBackend) closeAllDBs() {
b.logger.Trace("postgres/resetdb: enter")
defer b.logger.Trace("postgres/resetdb: exit")
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
@@ -68,6 +66,46 @@ func (b *databaseBackend) closeAllDBs() {
} }
} }
// This function is used to retrieve a database object either from the cached
// connection map or by using the database config in storage. The caller of this
// function needs to hold the backend's lock.
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) {
// if the object already is built and cached, return it
db, ok := b.connections[name]
if ok {
return db, nil
}
entry, err := s.Get(fmt.Sprintf("dbs/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration with name: %s", name)
}
if entry == nil {
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
}
var config dbs.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
factory := config.GetFactory()
db, err = factory(&config, b.System(), b.logger)
if err != nil {
return nil, err
}
err = db.Initialize(config.ConnectionDetails)
if err != nil {
return nil, err
}
b.connections[name] = db
return db, nil
}
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get("role/" + n) entry, err := s.Get("role/" + n)
if err != nil { if err != nil {

View File

@@ -20,7 +20,7 @@ import (
) )
var ( var (
errNotInitalized = errors.New("Connection has not been initalized") errNotInitalized = errors.New("connection has not been initalized")
) )
type ConnectionProducer interface { type ConnectionProducer interface {
@@ -142,7 +142,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) er
c.initalized = true c.initalized = true
if _, err := c.connection(); err != nil { if _, err := c.connection(); err != nil {
return fmt.Errorf("Error Initalizing Connection: %s", err) return fmt.Errorf("error Initalizing Connection: %s", err)
} }
return nil return nil
@@ -244,7 +244,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
session, err := clusterConfig.CreateSession() session, err := clusterConfig.CreateSession()
if err != nil { if err != nil {
return nil, fmt.Errorf("Error creating session: %s", err) return nil, fmt.Errorf("error creating session: %s", err)
} }
// Set consistency // Set consistency
@@ -260,7 +260,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
// Verify the info // Verify the info
err = session.Query(`LIST USERS`).Exec() err = session.Query(`LIST USERS`).Exec()
if err != nil { if err != nil {
return nil, fmt.Errorf("Error validating connection info: %s", err) return nil, fmt.Errorf("error validating connection info: %s", err)
} }
return session, nil return session, nil

View File

@@ -18,10 +18,11 @@ const (
) )
var ( var (
ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ErrUnsupportedDatabaseType = errors.New("unsupported database type")
ErrEmptyCreationStatement = errors.New("Empty creation statements") ErrEmptyCreationStatement = errors.New("empty creation statements")
) )
// Factory function for
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {

View File

@@ -1,7 +1,6 @@
package database package database
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -34,47 +33,24 @@ func pathResetConnection(b *databaseBackend) *framework.Path {
func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string) name := data.Get("name").(string)
if name == "" { if name == "" {
return nil, errors.New("No database name set") return logical.ErrorResponse("Empty name attribute given"), nil
} }
// Grab the mutex lock // Grab the mutex lock
b.Lock() b.Lock()
defer b.Unlock() defer b.Unlock()
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) db, ok := b.connections[name]
if err != nil { if ok {
return nil, fmt.Errorf("failed to read connection configuration") db.Close()
} delete(b.connections, name)
if entry == nil {
return nil, nil
} }
var config dbs.DatabaseConfig db, err := b.getOrCreateDBObj(req.Storage, name)
if err := entry.DecodeJSON(&config); err != nil { if err != nil {
return nil, err return nil, err
} }
db, ok := b.connections[name]
if !ok {
return logical.ErrorResponse("Can not change type of existing connection."), nil
}
db.Close()
factory := config.GetFactory()
db, err = factory(&config, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
err = db.Initialize(config.ConnectionDetails)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
b.connections[name] = db
return nil, nil return nil, nil
} }
@@ -306,7 +282,6 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
return nil, err return nil, err
} }
// Reset the DB connection
resp := &logical.Response{} resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")

View File

@@ -27,34 +27,28 @@ func pathRoleCreate(b *databaseBackend) *framework.Path {
} }
func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.logger.Trace("postgres/pathRoleCreateRead: enter")
defer b.logger.Trace("postgres/pathRoleCreateRead: exit")
name := data.Get("name").(string) name := data.Get("name").(string)
// Get the role // Get the role
b.logger.Trace("postgres/pathRoleCreateRead: getting role")
role, err := b.Role(req.Storage, name) role, err := b.Role(req.Storage, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if role == nil { if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil
}
b.Lock()
defer b.Unlock()
// Get the Database object
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if err != nil {
// TODO: return a resp error instead?
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
} }
// Generate the username, password and expiration // Generate the username, password and expiration
// Get our handle
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
b.RLock()
defer b.RUnlock()
db, ok := b.connections[role.DBName]
if !ok {
// TODO: return a resp error instead?
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
}
username, err := db.GenerateUsername(req.DisplayName) username, err := db.GenerateUsername(req.DisplayName)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -70,12 +64,12 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
return nil, err return nil, err
} }
// Create the user
err = db.CreateUser(role.Statements, username, password, expiration) err = db.CreateUser(role.Statements, username, password, expiration)
if err != nil { if err != nil {
return nil, err return nil, err
} }
b.logger.Trace("postgres/pathRoleCreateRead: generating secret")
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username, "username": username,
"password": password, "password": password,

View File

@@ -126,7 +126,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string) name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse("Empty role name attribute given"), nil
}
dbName := data.Get("db_name").(string) dbName := data.Get("db_name").(string)
if dbName == "" {
return logical.ErrorResponse("Empty database name attribute given"), nil
}
// Get statements // Get statements
creationStmts := data.Get("creation_statements").(string) creationStmts := data.Get("creation_statements").(string)

View File

@@ -29,7 +29,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
roleNameRaw, ok := req.Secret.InternalData["role"] roleNameRaw, ok := req.Secret.InternalData["role"]
if !ok { if !ok {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
} }
role, err := b.Role(req.Storage, roleNameRaw.(string)) role, err := b.Role(req.Storage, roleNameRaw.(string))
@@ -37,7 +37,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
return nil, err return nil, err
} }
if role == nil { if role == nil {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
} }
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
@@ -47,13 +47,13 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
} }
// Grab the read lock // Grab the read lock
b.RLock() b.Lock()
defer b.RUnlock() defer b.Unlock()
// Get our connection // Get our connection
db, ok := b.connections[role.DBName] db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if !ok { if err != nil {
return nil, fmt.Errorf("Could not find connection with name %s", role.DBName) return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err)
} }
// Make sure we increase the VALID UNTIL endpoint for this user. // Make sure we increase the VALID UNTIL endpoint for this user.
@@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
roleNameRaw, ok := req.Secret.InternalData["role"] roleNameRaw, ok := req.Secret.InternalData["role"]
if !ok { if !ok {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
} }
role, err := b.Role(req.Storage, roleNameRaw.(string)) role, err := b.Role(req.Storage, roleNameRaw.(string))
@@ -89,7 +89,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
return nil, err return nil, err
} }
if role == nil { if role == nil {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
} }
/* TODO: think about how to handle this case. /* TODO: think about how to handle this case.
@@ -109,13 +109,13 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
}*/ }*/
// Grab the read lock // Grab the read lock
b.RLock() b.Lock()
defer b.RUnlock() defer b.Unlock()
// Get our connection // Get our connection
db, ok := b.connections[role.DBName] db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if !ok { if err != nil {
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err)
} }
err = db.RevokeUser(role.Statements, username) err = db.RevokeUser(role.Statements, username)

View File

@@ -217,6 +217,9 @@ func VaultPluginTLSProvider() (*tls.Config, error) {
if err != nil { if err != nil {
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
} }
if secret == nil {
return nil, errors.New("error during token unwrap request secret is nil")
}
// Retrieve and parse the CA Certificate // Retrieve and parse the CA Certificate
CABytesRaw, ok := secret.Data["CACert"].(string) CABytesRaw, ok := secret.Data["CACert"].(string)