From 6de5cfad5e5796c160833a27f11d809fae2cf96c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 11:30:45 -0700 Subject: [PATCH] Add functionaility to build db objects from disk so restarts work --- builtin/logical/database/backend.go | 46 +++++++++++++++++-- .../database/dbs/connectionproducer.go | 8 ++-- builtin/logical/database/dbs/db.go | 5 +- .../database/path_config_connection.go | 39 +++------------- builtin/logical/database/path_role_create.go | 30 +++++------- builtin/logical/database/path_roles.go | 7 +++ builtin/logical/database/secret_creds.go | 28 +++++------ helper/pluginutil/tls.go | 3 ++ 8 files changed, 92 insertions(+), 74 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 6108652532..f8bcc60f1d 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -1,6 +1,7 @@ package database import ( + "fmt" "strings" "sync" @@ -52,14 +53,11 @@ type databaseBackend struct { logger log.Logger *framework.Backend - sync.RWMutex + sync.Mutex } // resetAllDBs closes all connections from all database types func (b *databaseBackend) closeAllDBs() { - b.logger.Trace("postgres/resetdb: enter") - defer b.logger.Trace("postgres/resetdb: exit") - b.Lock() defer b.Unlock() @@ -68,6 +66,46 @@ func (b *databaseBackend) closeAllDBs() { } } +// This function is used to retrieve a database object either from the cached +// connection map or by using the database config in storage. The caller of this +// function needs to hold the backend's lock. +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) { + // if the object already is built and cached, return it + db, ok := b.connections[name] + if ok { + return db, nil + } + + entry, err := s.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) + } + if entry == nil { + return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) + } + + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + factory := config.GetFactory() + + db, err = factory(&config, b.System(), b.logger) + if err != nil { + return nil, err + } + + err = db.Initialize(config.ConnectionDetails) + if err != nil { + return nil, err + } + + b.connections[name] = db + + return db, nil +} + func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 1e944c7b96..dae8d9400e 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -20,7 +20,7 @@ import ( ) var ( - errNotInitalized = errors.New("Connection has not been initalized") + errNotInitalized = errors.New("connection has not been initalized") ) type ConnectionProducer interface { @@ -142,7 +142,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) er c.initalized = true if _, err := c.connection(); err != nil { - return fmt.Errorf("Error Initalizing Connection: %s", err) + return fmt.Errorf("error Initalizing Connection: %s", err) } return nil @@ -244,7 +244,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { session, err := clusterConfig.CreateSession() if err != nil { - return nil, fmt.Errorf("Error creating session: %s", err) + return nil, fmt.Errorf("error creating session: %s", err) } // Set consistency @@ -260,7 +260,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { // Verify the info err = session.Query(`LIST USERS`).Exec() if err != nil { - return nil, fmt.Errorf("Error validating connection info: %s", err) + return nil, fmt.Errorf("error validating connection info: %s", err) } return session, nil diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 54581e465d..74f5a26057 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -18,10 +18,11 @@ const ( ) var ( - ErrUnsupportedDatabaseType = errors.New("Unsupported database type") - ErrEmptyCreationStatement = errors.New("Empty creation statements") + ErrUnsupportedDatabaseType = errors.New("unsupported database type") + ErrEmptyCreationStatement = errors.New("empty creation statements") ) +// Factory function for type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index ff633e7456..b4c699750d 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,7 +1,6 @@ package database import ( - "errors" "fmt" "strings" "time" @@ -34,47 +33,24 @@ func pathResetConnection(b *databaseBackend) *framework.Path { func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return nil, errors.New("No database name set") + return logical.ErrorResponse("Empty name attribute given"), nil } // Grab the mutex lock b.Lock() defer b.Unlock() - entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) - if err != nil { - return nil, fmt.Errorf("failed to read connection configuration") - } - if entry == nil { - return nil, nil + db, ok := b.connections[name] + if ok { + db.Close() + delete(b.connections, name) } - var config dbs.DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { + db, err := b.getOrCreateDBObj(req.Storage, name) + if err != nil { return nil, err } - db, ok := b.connections[name] - if !ok { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - - db.Close() - - factory := config.GetFactory() - - db, err = factory(&config, b.System(), b.logger) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - } - - err = db.Initialize(config.ConnectionDetails) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - } - - b.connections[name] = db - return nil, nil } @@ -306,7 +282,6 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. return nil, err } - // Reset the DB connection resp := &logical.Response{} resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 14b65cbb31..d379ef2673 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -27,34 +27,28 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { } func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - b.logger.Trace("postgres/pathRoleCreateRead: enter") - defer b.logger.Trace("postgres/pathRoleCreateRead: exit") - name := data.Get("name").(string) // Get the role - b.logger.Trace("postgres/pathRoleCreateRead: getting role") role, err := b.Role(req.Storage, name) if err != nil { return nil, err } if role == nil { - return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil + return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil + } + + b.Lock() + defer b.Unlock() + + // Get the Database object + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + // TODO: return a resp error instead? + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } // Generate the username, password and expiration - - // Get our handle - b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") - - b.RLock() - defer b.RUnlock() - db, ok := b.connections[role.DBName] - if !ok { - // TODO: return a resp error instead? - return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) - } - username, err := db.GenerateUsername(req.DisplayName) if err != nil { return nil, err @@ -70,12 +64,12 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, err } + // Create the user err = db.CreateUser(role.Statements, username, password, expiration) if err != nil { return nil, err } - b.logger.Trace("postgres/pathRoleCreateRead: generating secret") resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ "username": username, "password": password, diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 9a5bb9324d..6f62c79d98 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -126,7 +126,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty role name attribute given"), nil + } + dbName := data.Get("db_name").(string) + if dbName == "" { + return logical.ErrorResponse("Empty database name attribute given"), nil + } // Get statements creationStmts := data.Get("creation_statements").(string) diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index e39525a18c..2b63ea1f89 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -29,7 +29,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } role, err := b.Role(req.Storage, roleNameRaw.(string)) @@ -37,7 +37,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi return nil, err } if role == nil { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) @@ -47,13 +47,13 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi } // Grab the read lock - b.RLock() - defer b.RUnlock() + b.Lock() + defer b.Unlock() // Get our connection - db, ok := b.connections[role.DBName] - if !ok { - return nil, fmt.Errorf("Could not find connection with name %s", role.DBName) + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) } // Make sure we increase the VALID UNTIL endpoint for this user. @@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } role, err := b.Role(req.Storage, roleNameRaw.(string)) @@ -89,7 +89,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F return nil, err } if role == nil { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } /* TODO: think about how to handle this case. @@ -109,13 +109,13 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F }*/ // Grab the read lock - b.RLock() - defer b.RUnlock() + b.Lock() + defer b.Unlock() // Get our connection - db, ok := b.connections[role.DBName] - if !ok { - return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) } err = db.RevokeUser(role.Statements, username) diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 08f24985d2..63ae2932f1 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -217,6 +217,9 @@ func VaultPluginTLSProvider() (*tls.Config, error) { if err != nil { return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) } + if secret == nil { + return nil, errors.New("error during token unwrap request secret is nil") + } // Retrieve and parse the CA Certificate CABytesRaw, ok := secret.Data["CACert"].(string)