diff --git a/builtin/logical/postgresql/backend.go b/builtin/logical/postgresql/backend.go index e86e2a6368..65cd2230c3 100644 --- a/builtin/logical/postgresql/backend.go +++ b/builtin/logical/postgresql/backend.go @@ -1,7 +1,10 @@ package postgresql import ( + "database/sql" + "fmt" "strings" + "sync" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -23,7 +26,8 @@ func Backend() *framework.Backend { }, Paths: []*framework.Path{ - pathConfigConnection(), + pathConfigConnection(&b), + pathRoles(&b), }, } @@ -32,6 +36,58 @@ func Backend() *framework.Backend { type backend struct { *framework.Backend + + db *sql.DB + lock sync.Mutex +} + +// DB returns the database connection. +func (b *backend) DB(s logical.Storage) (*sql.DB, error) { + b.lock.Lock() + defer b.lock.Unlock() + + // If we already have a DB, we got it! + if b.db != nil { + return b.db, nil + } + + // Otherwise, attempt to make connection + entry, err := s.Get("config/connection") + if err != nil { + return nil, err + } + if entry == nil { + return nil, + fmt.Errorf("configure the DB connection with config/connection first") + } + + var conn string + if err := entry.DecodeJSON(&conn); err != nil { + return nil, err + } + + b.db, err = sql.Open("postgres", conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + b.db.SetMaxOpenConns(2) + + return b.db, nil +} + +// ResetDB forces a connection next time DB() is called. +func (b *backend) ResetDB() { + b.lock.Lock() + defer b.lock.Unlock() + + if b.db != nil { + b.db.Close() + } + + b.db = nil } const backendHelp = ` diff --git a/builtin/logical/postgresql/backend_test.go b/builtin/logical/postgresql/backend_test.go index 541e560d70..909279715b 100644 --- a/builtin/logical/postgresql/backend_test.go +++ b/builtin/logical/postgresql/backend_test.go @@ -14,6 +14,7 @@ func TestBackend_basic(t *testing.T) { Backend: Backend(), Steps: []logicaltest.TestStep{ testAccStepConfig(t), + testAccStepRole(t), }, }) } @@ -33,3 +34,20 @@ func testAccStepConfig(t *testing.T) logicaltest.TestStep { }, } } + +func testAccStepRole(t *testing.T) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "roles/web", + Data: map[string]interface{}{ + "sql": testRole, + }, + } +} + +const testRole = ` +CREATE ROLE {{name}} WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +` diff --git a/builtin/logical/postgresql/path_config.go b/builtin/logical/postgresql/path_config.go index 90132b9fa6..3f4e87b0e1 100644 --- a/builtin/logical/postgresql/path_config.go +++ b/builtin/logical/postgresql/path_config.go @@ -9,7 +9,7 @@ import ( _ "github.com/lib/pq" ) -func pathConfigConnection() *framework.Path { +func pathConfigConnection(b *backend) *framework.Path { return &framework.Path{ Pattern: "config/connection", Fields: map[string]*framework.FieldSchema{ @@ -20,7 +20,7 @@ func pathConfigConnection() *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.WriteOperation: pathConnectionWrite, + logical.WriteOperation: b.pathConnectionWrite, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -28,7 +28,7 @@ func pathConfigConnection() *framework.Path { } } -func pathConnectionWrite( +func (b *backend) pathConnectionWrite( req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connString := data.Get("value").(string) @@ -53,6 +53,9 @@ func pathConnectionWrite( return nil, err } + // Reset the DB connection + b.ResetDB() + return nil, nil } diff --git a/builtin/logical/postgresql/path_roles.go b/builtin/logical/postgresql/path_roles.go new file mode 100644 index 0000000000..b32cbc115b --- /dev/null +++ b/builtin/logical/postgresql/path_roles.go @@ -0,0 +1,100 @@ +package postgresql + +import ( + "fmt" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + _ "github.com/lib/pq" +) + +func pathRoles(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "roles/(?P\\w+)", + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the role.", + }, + + "sql": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "SQL string to create a user. See help for more info.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: b.pathRoleCreate, + }, + + HelpSynopsis: pathRoleHelpSyn, + HelpDescription: pathRoleHelpDesc, + } +} + +func (b *backend) pathRoleCreate( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + sql := data.Get("sql").(string) + + // Get our connection + db, err := b.DB(req.Storage) + if err != nil { + return nil, err + } + + // Test the query by trying to prepare it + stmt, err := db.Prepare(Query(sql, map[string]string{ + "name": "foo", + "password": "bar", + "expiration": "", + })) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error testing query: %s", err)), nil + } + stmt.Close() + + // Store it + entry, err := logical.StorageEntryJSON("role/"+name, map[string]interface{}{ + "sql": sql, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +const pathRoleHelpSyn = ` +Manage the roles that can be created with this backend. +` + +const pathRoleHelpDesc = ` +This path lets you manage the roles that can be created with this backend. + +The "sql" parameter customizes the SQL string used to create the role. +This can only be a single SQL query. Some substitution will be done to the +SQL string for certain keys. The names of the variables must be surrounded +by "{{" and "}}" to be replaced. + + * "name" - The random username generated for the DB user. + + * "password" - The random password generated for the DB user. + + * "expiration" - The timestamp when this user will expire. + +Example of a decent SQL query to use: + + CREATE ROLE {{name}} WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + +Note the above user wouldn't be able to access anything. To give a user access +to resources, create roles manually in PostgreSQL, then use the "IN ROLE" +clause for CREATE ROLE to add the user to more roles. +` diff --git a/builtin/logical/postgresql/query.go b/builtin/logical/postgresql/query.go new file mode 100644 index 0000000000..e4f7f59ddf --- /dev/null +++ b/builtin/logical/postgresql/query.go @@ -0,0 +1,15 @@ +package postgresql + +import ( + "fmt" + "strings" +) + +// Query templates a query for us. +func Query(tpl string, data map[string]string) string { + for k, v := range data { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + + return tpl +}