mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 19:47:54 +00:00
Add functionaility to build db objects from disk so restarts work
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -52,14 +53,11 @@ type databaseBackend struct {
|
||||
logger log.Logger
|
||||
|
||||
*framework.Backend
|
||||
sync.RWMutex
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// resetAllDBs closes all connections from all database types
|
||||
func (b *databaseBackend) closeAllDBs() {
|
||||
b.logger.Trace("postgres/resetdb: enter")
|
||||
defer b.logger.Trace("postgres/resetdb: exit")
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
@@ -68,6 +66,46 @@ func (b *databaseBackend) closeAllDBs() {
|
||||
}
|
||||
}
|
||||
|
||||
// This function is used to retrieve a database object either from the cached
|
||||
// connection map or by using the database config in storage. The caller of this
|
||||
// function needs to hold the backend's lock.
|
||||
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) {
|
||||
// if the object already is built and cached, return it
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
entry, err := s.Get(fmt.Sprintf("dbs/%s", name))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration with name: %s", name)
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
|
||||
}
|
||||
|
||||
var config dbs.DatabaseConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
factory := config.GetFactory()
|
||||
|
||||
db, err = factory(&config, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.connections[name] = db
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
|
||||
entry, err := s.Get("role/" + n)
|
||||
if err != nil {
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errNotInitalized = errors.New("Connection has not been initalized")
|
||||
errNotInitalized = errors.New("connection has not been initalized")
|
||||
)
|
||||
|
||||
type ConnectionProducer interface {
|
||||
@@ -142,7 +142,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) er
|
||||
c.initalized = true
|
||||
|
||||
if _, err := c.connection(); err != nil {
|
||||
return fmt.Errorf("Error Initalizing Connection: %s", err)
|
||||
return fmt.Errorf("error Initalizing Connection: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -244,7 +244,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error creating session: %s", err)
|
||||
return nil, fmt.Errorf("error creating session: %s", err)
|
||||
}
|
||||
|
||||
// Set consistency
|
||||
@@ -260,7 +260,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
||||
// Verify the info
|
||||
err = session.Query(`LIST USERS`).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error validating connection info: %s", err)
|
||||
return nil, fmt.Errorf("error validating connection info: %s", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
|
||||
@@ -18,10 +18,11 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
|
||||
ErrEmptyCreationStatement = errors.New("Empty creation statements")
|
||||
ErrUnsupportedDatabaseType = errors.New("unsupported database type")
|
||||
ErrEmptyCreationStatement = errors.New("empty creation statements")
|
||||
)
|
||||
|
||||
// Factory function for
|
||||
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
|
||||
|
||||
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -34,47 +33,24 @@ func pathResetConnection(b *databaseBackend) *framework.Path {
|
||||
func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return nil, errors.New("No database name set")
|
||||
return logical.ErrorResponse("Empty name attribute given"), nil
|
||||
}
|
||||
|
||||
// Grab the mutex lock
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read connection configuration")
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
db.Close()
|
||||
delete(b.connections, name)
|
||||
}
|
||||
|
||||
var config dbs.DatabaseConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
db, err := b.getOrCreateDBObj(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, ok := b.connections[name]
|
||||
if !ok {
|
||||
return logical.ErrorResponse("Can not change type of existing connection."), nil
|
||||
}
|
||||
|
||||
db.Close()
|
||||
|
||||
factory := config.GetFactory()
|
||||
|
||||
db, err = factory(&config, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
err = db.Initialize(config.ConnectionDetails)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
b.connections[name] = db
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -306,7 +282,6 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reset the DB connection
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
|
||||
|
||||
|
||||
@@ -27,34 +27,28 @@ func pathRoleCreate(b *databaseBackend) *framework.Path {
|
||||
}
|
||||
|
||||
func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: enter")
|
||||
defer b.logger.Trace("postgres/pathRoleCreateRead: exit")
|
||||
|
||||
name := data.Get("name").(string)
|
||||
|
||||
// Get the role
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: getting role")
|
||||
role, err := b.Role(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
||||
return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil
|
||||
}
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Get the Database object
|
||||
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
// TODO: return a resp error instead?
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
|
||||
// Generate the username, password and expiration
|
||||
|
||||
// Get our handle
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
|
||||
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
db, ok := b.connections[role.DBName]
|
||||
if !ok {
|
||||
// TODO: return a resp error instead?
|
||||
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername(req.DisplayName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -70,12 +64,12 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the user
|
||||
err = db.CreateUser(role.Statements, username, password, expiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.logger.Trace("postgres/pathRoleCreateRead: generating secret")
|
||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
|
||||
@@ -126,7 +126,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD
|
||||
|
||||
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
name := data.Get("name").(string)
|
||||
if name == "" {
|
||||
return logical.ErrorResponse("Empty role name attribute given"), nil
|
||||
}
|
||||
|
||||
dbName := data.Get("db_name").(string)
|
||||
if dbName == "" {
|
||||
return logical.ErrorResponse("Empty database name attribute given"), nil
|
||||
}
|
||||
|
||||
// Get statements
|
||||
creationStmts := data.Get("creation_statements").(string)
|
||||
|
||||
@@ -29,7 +29,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
|
||||
|
||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
@@ -37,7 +37,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
|
||||
@@ -47,13 +47,13 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
|
||||
}
|
||||
|
||||
// Grab the read lock
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Get our connection
|
||||
db, ok := b.connections[role.DBName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not find connection with name %s", role.DBName)
|
||||
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err)
|
||||
}
|
||||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
@@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
|
||||
|
||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||
@@ -89,7 +89,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
|
||||
return nil, err
|
||||
}
|
||||
if role == nil {
|
||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||
}
|
||||
|
||||
/* TODO: think about how to handle this case.
|
||||
@@ -109,13 +109,13 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
|
||||
}*/
|
||||
|
||||
// Grab the read lock
|
||||
b.RLock()
|
||||
defer b.RUnlock()
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Get our connection
|
||||
db, ok := b.connections[role.DBName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
|
||||
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
|
||||
err = db.RevokeUser(role.Statements, username)
|
||||
|
||||
@@ -217,6 +217,9 @@ func VaultPluginTLSProvider() (*tls.Config, error) {
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
|
||||
}
|
||||
if secret == nil {
|
||||
return nil, errors.New("error during token unwrap request secret is nil")
|
||||
}
|
||||
|
||||
// Retrieve and parse the CA Certificate
|
||||
CABytesRaw, ok := secret.Data["CACert"].(string)
|
||||
|
||||
Reference in New Issue
Block a user