mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-14 17:05:11 +00:00
Add functionaility to build db objects from disk so restarts work
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -52,14 +53,11 @@ type databaseBackend struct {
|
|||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
*framework.Backend
|
*framework.Backend
|
||||||
sync.RWMutex
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// resetAllDBs closes all connections from all database types
|
// resetAllDBs closes all connections from all database types
|
||||||
func (b *databaseBackend) closeAllDBs() {
|
func (b *databaseBackend) closeAllDBs() {
|
||||||
b.logger.Trace("postgres/resetdb: enter")
|
|
||||||
defer b.logger.Trace("postgres/resetdb: exit")
|
|
||||||
|
|
||||||
b.Lock()
|
b.Lock()
|
||||||
defer b.Unlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
@@ -68,6 +66,46 @@ func (b *databaseBackend) closeAllDBs() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This function is used to retrieve a database object either from the cached
|
||||||
|
// connection map or by using the database config in storage. The caller of this
|
||||||
|
// function needs to hold the backend's lock.
|
||||||
|
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) {
|
||||||
|
// if the object already is built and cached, return it
|
||||||
|
db, ok := b.connections[name]
|
||||||
|
if ok {
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entry, err := s.Get(fmt.Sprintf("dbs/%s", name))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read connection configuration with name: %s", name)
|
||||||
|
}
|
||||||
|
if entry == nil {
|
||||||
|
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
var config dbs.DatabaseConfig
|
||||||
|
if err := entry.DecodeJSON(&config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
factory := config.GetFactory()
|
||||||
|
|
||||||
|
db, err = factory(&config, b.System(), b.logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Initialize(config.ConnectionDetails)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.connections[name] = db
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
|
func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) {
|
||||||
entry, err := s.Get("role/" + n)
|
entry, err := s.Get("role/" + n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errNotInitalized = errors.New("Connection has not been initalized")
|
errNotInitalized = errors.New("connection has not been initalized")
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnectionProducer interface {
|
type ConnectionProducer interface {
|
||||||
@@ -142,7 +142,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) er
|
|||||||
c.initalized = true
|
c.initalized = true
|
||||||
|
|
||||||
if _, err := c.connection(); err != nil {
|
if _, err := c.connection(); err != nil {
|
||||||
return fmt.Errorf("Error Initalizing Connection: %s", err)
|
return fmt.Errorf("error Initalizing Connection: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -244,7 +244,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
|||||||
|
|
||||||
session, err := clusterConfig.CreateSession()
|
session, err := clusterConfig.CreateSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error creating session: %s", err)
|
return nil, fmt.Errorf("error creating session: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set consistency
|
// Set consistency
|
||||||
@@ -260,7 +260,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
|||||||
// Verify the info
|
// Verify the info
|
||||||
err = session.Query(`LIST USERS`).Exec()
|
err = session.Query(`LIST USERS`).Exec()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error validating connection info: %s", err)
|
return nil, fmt.Errorf("error validating connection info: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return session, nil
|
return session, nil
|
||||||
|
|||||||
@@ -18,10 +18,11 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUnsupportedDatabaseType = errors.New("Unsupported database type")
|
ErrUnsupportedDatabaseType = errors.New("unsupported database type")
|
||||||
ErrEmptyCreationStatement = errors.New("Empty creation statements")
|
ErrEmptyCreationStatement = errors.New("empty creation statements")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Factory function for
|
||||||
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
|
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
|
||||||
|
|
||||||
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
|
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -34,47 +33,24 @@ func pathResetConnection(b *databaseBackend) *framework.Path {
|
|||||||
func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||||
name := data.Get("name").(string)
|
name := data.Get("name").(string)
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return nil, errors.New("No database name set")
|
return logical.ErrorResponse("Empty name attribute given"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Grab the mutex lock
|
// Grab the mutex lock
|
||||||
b.Lock()
|
b.Lock()
|
||||||
defer b.Unlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name))
|
db, ok := b.connections[name]
|
||||||
if err != nil {
|
if ok {
|
||||||
return nil, fmt.Errorf("failed to read connection configuration")
|
db.Close()
|
||||||
}
|
delete(b.connections, name)
|
||||||
if entry == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var config dbs.DatabaseConfig
|
db, err := b.getOrCreateDBObj(req.Storage, name)
|
||||||
if err := entry.DecodeJSON(&config); err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
db, ok := b.connections[name]
|
|
||||||
if !ok {
|
|
||||||
return logical.ErrorResponse("Can not change type of existing connection."), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
db.Close()
|
|
||||||
|
|
||||||
factory := config.GetFactory()
|
|
||||||
|
|
||||||
db, err = factory(&config, b.System(), b.logger)
|
|
||||||
if err != nil {
|
|
||||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = db.Initialize(config.ConnectionDetails)
|
|
||||||
if err != nil {
|
|
||||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
b.connections[name] = db
|
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,7 +282,6 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset the DB connection
|
|
||||||
resp := &logical.Response{}
|
resp := &logical.Response{}
|
||||||
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
|
resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.")
|
||||||
|
|
||||||
|
|||||||
@@ -27,34 +27,28 @@ func pathRoleCreate(b *databaseBackend) *framework.Path {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||||
b.logger.Trace("postgres/pathRoleCreateRead: enter")
|
|
||||||
defer b.logger.Trace("postgres/pathRoleCreateRead: exit")
|
|
||||||
|
|
||||||
name := data.Get("name").(string)
|
name := data.Get("name").(string)
|
||||||
|
|
||||||
// Get the role
|
// Get the role
|
||||||
b.logger.Trace("postgres/pathRoleCreateRead: getting role")
|
|
||||||
role, err := b.Role(req.Storage, name)
|
role, err := b.Role(req.Storage, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if role == nil {
|
if role == nil {
|
||||||
return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil
|
return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Lock()
|
||||||
|
defer b.Unlock()
|
||||||
|
|
||||||
|
// Get the Database object
|
||||||
|
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: return a resp error instead?
|
||||||
|
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the username, password and expiration
|
// Generate the username, password and expiration
|
||||||
|
|
||||||
// Get our handle
|
|
||||||
b.logger.Trace("postgres/pathRoleCreateRead: getting database handle")
|
|
||||||
|
|
||||||
b.RLock()
|
|
||||||
defer b.RUnlock()
|
|
||||||
db, ok := b.connections[role.DBName]
|
|
||||||
if !ok {
|
|
||||||
// TODO: return a resp error instead?
|
|
||||||
return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName)
|
|
||||||
}
|
|
||||||
|
|
||||||
username, err := db.GenerateUsername(req.DisplayName)
|
username, err := db.GenerateUsername(req.DisplayName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -70,12 +64,12 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create the user
|
||||||
err = db.CreateUser(role.Statements, username, password, expiration)
|
err = db.CreateUser(role.Statements, username, password, expiration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
b.logger.Trace("postgres/pathRoleCreateRead: generating secret")
|
|
||||||
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
resp := b.Secret(SecretCredsType).Response(map[string]interface{}{
|
||||||
"username": username,
|
"username": username,
|
||||||
"password": password,
|
"password": password,
|
||||||
|
|||||||
@@ -126,7 +126,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD
|
|||||||
|
|
||||||
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||||
name := data.Get("name").(string)
|
name := data.Get("name").(string)
|
||||||
|
if name == "" {
|
||||||
|
return logical.ErrorResponse("Empty role name attribute given"), nil
|
||||||
|
}
|
||||||
|
|
||||||
dbName := data.Get("db_name").(string)
|
dbName := data.Get("db_name").(string)
|
||||||
|
if dbName == "" {
|
||||||
|
return logical.ErrorResponse("Empty database name attribute given"), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get statements
|
// Get statements
|
||||||
creationStmts := data.Get("creation_statements").(string)
|
creationStmts := data.Get("creation_statements").(string)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
|
|||||||
|
|
||||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||||
@@ -37,7 +37,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if role == nil {
|
if role == nil {
|
||||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
|
f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System())
|
||||||
@@ -47,13 +47,13 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Grab the read lock
|
// Grab the read lock
|
||||||
b.RLock()
|
b.Lock()
|
||||||
defer b.RUnlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
// Get our connection
|
// Get our connection
|
||||||
db, ok := b.connections[role.DBName]
|
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not find connection with name %s", role.DBName)
|
return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||||
@@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
|
|||||||
|
|
||||||
roleNameRaw, ok := req.Secret.InternalData["role"]
|
roleNameRaw, ok := req.Secret.InternalData["role"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
role, err := b.Role(req.Storage, roleNameRaw.(string))
|
||||||
@@ -89,7 +89,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if role == nil {
|
if role == nil {
|
||||||
return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"])
|
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
|
||||||
}
|
}
|
||||||
|
|
||||||
/* TODO: think about how to handle this case.
|
/* TODO: think about how to handle this case.
|
||||||
@@ -109,13 +109,13 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
|
|||||||
}*/
|
}*/
|
||||||
|
|
||||||
// Grab the read lock
|
// Grab the read lock
|
||||||
b.RLock()
|
b.Lock()
|
||||||
defer b.RUnlock()
|
defer b.Unlock()
|
||||||
|
|
||||||
// Get our connection
|
// Get our connection
|
||||||
db, ok := b.connections[role.DBName]
|
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not find database with name: %s", role.DBName)
|
return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.RevokeUser(role.Statements, username)
|
err = db.RevokeUser(role.Statements, username)
|
||||||
|
|||||||
@@ -217,6 +217,9 @@ func VaultPluginTLSProvider() (*tls.Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
|
return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err)
|
||||||
}
|
}
|
||||||
|
if secret == nil {
|
||||||
|
return nil, errors.New("error during token unwrap request secret is nil")
|
||||||
|
}
|
||||||
|
|
||||||
// Retrieve and parse the CA Certificate
|
// Retrieve and parse the CA Certificate
|
||||||
CABytesRaw, ok := secret.Data["CACert"].(string)
|
CABytesRaw, ok := secret.Data["CACert"].(string)
|
||||||
|
|||||||
Reference in New Issue
Block a user