Add functionaility to build db objects from disk so restarts work

This commit is contained in:
Brian Kassouf
2017-03-28 11:30:45 -07:00
parent d93378bb29
commit 6de5cfad5e
8 changed files with 92 additions and 74 deletions

View File

@@ -1,6 +1,7 @@
package database
import (
"fmt"
"strings"
"sync"
@@ -52,14 +53,11 @@ type databaseBackend struct {
logger log.Logger
*framework.Backend
sync.RWMutex
sync.Mutex
}
// resetAllDBs closes all connections from all database types
func (b *databaseBackend) closeAllDBs() {
b.logger.Trace("postgres/resetdb: enter")
defer b.logger.Trace("postgres/resetdb: exit")
b.Lock()
defer b.Unlock()
@@ -68,6 +66,46 @@ func (b *databaseBackend) closeAllDBs() {
}
}
// This function is used to retrieve a database object either from the cached
// connection map or by using the database config in storage. The caller of this
// function needs to hold the backend's lock.
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) {
// if the object already is built and cached, return it
db, ok := b.connections[name]
if ok {
return db, nil
}
entry, err := s.Get(fmt.Sprintf("dbs/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration with name: %s", name)
}
if entry == nil {
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
}
var config dbs.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
factory := config.GetFactory()
db, err = factory(&config, b.System(), b.logger)
if err != nil {
return nil, err
}
err = db.Initialize(config.ConnectionDetails)
if err != nil {
return nil, err
}
b.connections[name] = db
return db, nil
}
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
entry, err := s.Get("role/" + n)
if err != nil {

View File

@@ -20,7 +20,7 @@ import (
)
var (
errNotInitalized = errors.New("Connection has not been initalized")
errNotInitalized = errors.New("connection has not been initalized")
)
type ConnectionProducer interface {
@@ -142,7 +142,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) er
c.initalized = true
if _, err := c.connection(); err != nil {
return fmt.Errorf("Error Initalizing Connection: %s", err)
return fmt.Errorf("error Initalizing Connection: %s", err)
}
return nil
@@ -244,7 +244,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
session, err := clusterConfig.CreateSession()
if err != nil {
return nil, fmt.Errorf("Error creating session: %s", err)
return nil, fmt.Errorf("error creating session: %s", err)
}
// Set consistency
@@ -260,7 +260,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
// Verify the info
err = session.Query(`LIST USERS`).Exec()
if err != nil {
return nil, fmt.Errorf("Error validating connection info: %s", err)
return nil, fmt.Errorf("error validating connection info: %s", err)
}
return session, nil

View File

@@ -18,10 +18,11 @@ const (
)
var (
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
ErrEmptyCreationStatement = errors.New("Empty creation statements")
ErrUnsupportedDatabaseType = errors.New("unsupported database type")
ErrEmptyCreationStatement = errors.New("empty creation statements")
)
// Factory function for
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {

View File

@@ -1,7 +1,6 @@
package database
import (
"errors"
"fmt"
"strings"
"time"
@@ -34,47 +33,24 @@ func pathResetConnection(b *databaseBackend) *framework.Path {
func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return nil, errors.New("No database name set")
return logical.ErrorResponse("Empty name attribute given"), nil
}
// Grab the mutex lock
b.Lock()
defer b.Unlock()
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
db, ok := b.connections[name]
if ok {
db.Close()
delete(b.connections, name)
}
var config dbs.DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
db, err := b.getOrCreateDBObj(req.Storage, name)
if err != nil {
return nil, err
}
db, ok := b.connections[name]
if !ok {
return logical.ErrorResponse("Can not change type of existing connection."), nil
}
db.Close()
factory := config.GetFactory()
db, err = factory(&config, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
err = db.Initialize(config.ConnectionDetails)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
b.connections[name] = db
return nil, nil
}
@@ -306,7 +282,6 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
return nil, err
}
// Reset the DB connection
resp := &logical.Response{}
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")

View File

@@ -27,34 +27,28 @@ func pathRoleCreate(b *databaseBackend) *framework.Path {
}
func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.logger.Trace("postgres/pathRoleCreateRead: enter")
defer b.logger.Trace("postgres/pathRoleCreateRead: exit")
name := data.Get("name").(string)
// Get the role
b.logger.Trace("postgres/pathRoleCreateRead: getting role")
role, err := b.Role(req.Storage, name)
if err != nil {
return nil, err
}
if role == nil {
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil
}
b.Lock()
defer b.Unlock()
// Get the Database object
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if err != nil {
// TODO: return a resp error instead?
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
}
// Generate the username, password and expiration
// Get our handle
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
b.RLock()
defer b.RUnlock()
db, ok := b.connections[role.DBName]
if !ok {
// TODO: return a resp error instead?
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
}
username, err := db.GenerateUsername(req.DisplayName)
if err != nil {
return nil, err
@@ -70,12 +64,12 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
return nil, err
}
// Create the user
err = db.CreateUser(role.Statements, username, password, expiration)
if err != nil {
return nil, err
}
b.logger.Trace("postgres/pathRoleCreateRead: generating secret")
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
"username": username,
"password": password,

View File

@@ -126,7 +126,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse("Empty role name attribute given"), nil
}
dbName := data.Get("db_name").(string)
if dbName == "" {
return logical.ErrorResponse("Empty database name attribute given"), nil
}
// Get statements
creationStmts := data.Get("creation_statements").(string)

View File

@@ -29,7 +29,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
roleNameRaw, ok := req.Secret.InternalData["role"]
if !ok {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
}
role, err := b.Role(req.Storage, roleNameRaw.(string))
@@ -37,7 +37,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
return nil, err
}
if role == nil {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
}
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
@@ -47,13 +47,13 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
}
// Grab the read lock
b.RLock()
defer b.RUnlock()
b.Lock()
defer b.Unlock()
// Get our connection
db, ok := b.connections[role.DBName]
if !ok {
return nil, fmt.Errorf("Could not find connection with name %s", role.DBName)
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if err != nil {
return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err)
}
// Make sure we increase the VALID UNTIL endpoint for this user.
@@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
roleNameRaw, ok := req.Secret.InternalData["role"]
if !ok {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
}
role, err := b.Role(req.Storage, roleNameRaw.(string))
@@ -89,7 +89,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
return nil, err
}
if role == nil {
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
}
/* TODO: think about how to handle this case.
@@ -109,13 +109,13 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
}*/
// Grab the read lock
b.RLock()
defer b.RUnlock()
b.Lock()
defer b.Unlock()
// Get our connection
db, ok := b.connections[role.DBName]
if !ok {
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if err != nil {
return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err)
}
err = db.RevokeUser(role.Statements, username)

View File

@@ -217,6 +217,9 @@ func VaultPluginTLSProvider() (*tls.Config, error) {
if err != nil {
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
}
if secret == nil {
return nil, errors.New("error during token unwrap request secret is nil")
}
// Retrieve and parse the CA Certificate
CABytesRaw, ok := secret.Data["CACert"].(string)