diff --git a/builtin/logical/postgresql/path_role_create.go b/builtin/logical/postgresql/path_role_create.go index 5609bbc866..497b9a6c63 100644 --- a/builtin/logical/postgresql/path_role_create.go +++ b/builtin/logical/postgresql/path_role_create.go @@ -51,13 +51,7 @@ func (b *backend) pathRoleCreateRead( lease = &configLease{Lease: 1 * time.Hour} } - // Get our connection - db, err := b.DB(req.Storage) - if err != nil { - return nil, err - } - - // Generate our query + // Generate the username, password and expiration username := fmt.Sprintf( "vault-%s-%d-%d", req.DisplayName, time.Now().Unix(), rand.Int31n(10000)) @@ -65,19 +59,37 @@ func (b *backend) pathRoleCreateRead( expiration := time.Now().UTC(). Add(lease.Lease + time.Duration((float64(lease.Lease) * 0.1))). Format("2006-01-02 15:04:05") - query := Query(role.SQL, map[string]string{ - "name": username, - "password": password, - "expiration": expiration, - }) - // Prepare the statement and execute it - stmt, err := db.Prepare(query) + // Get our connection + db, err := b.DB(req.Storage) if err != nil { return nil, err } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + + // Execute each query + for _, query := range SplitSQL(role.SQL) { + stmt, err := db.Prepare(Query(query, map[string]string{ + "name": username, + "password": password, + "expiration": expiration, + })) + if err != nil { + return nil, err + } + if _, err := stmt.Exec(); err != nil { + return nil, err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { return nil, err } diff --git a/builtin/logical/postgresql/path_roles.go b/builtin/logical/postgresql/path_roles.go index 23750b555b..dfc389e537 100644 --- a/builtin/logical/postgresql/path_roles.go +++ b/builtin/logical/postgresql/path_roles.go @@ -90,16 +90,18 @@ func (b *backend) pathRoleCreate( } // Test the query by trying to prepare it - stmt, err := db.Prepare(Query(sql, map[string]string{ - "name": "foo", - "password": "bar", - "expiration": "", - })) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error testing query: %s", err)), nil + for _, query := range SplitSQL(sql) { + stmt, err := db.Prepare(Query(query, map[string]string{ + "name": "foo", + "password": "bar", + "expiration": "", + })) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error testing query: %s", err)), nil + } + stmt.Close() } - stmt.Close() // Store it entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ @@ -127,7 +129,7 @@ const pathRoleHelpDesc = ` This path lets you manage the roles that can be created with this backend. The "sql" parameter customizes the SQL string used to create the role. -This can only be a single SQL query. Some substitution will be done to the +This can be a sequence of SQL queries. Some substitution will be done to the SQL string for certain keys. The names of the variables must be surrounded by "{{" and "}}" to be replaced. @@ -139,12 +141,12 @@ by "{{" and "}}" to be replaced. Example of a decent SQL query to use: - CREATE ROLE "{{name}}" WITH + CREATE ROLE '{{name}}' WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA db1 TO '{{name}}'; -Note the above user wouldn't be able to access anything. To give a user access -to resources, create roles manually in PostgreSQL, then use the "IN ROLE" -clause for CREATE ROLE to add the user to more roles. +Note the above user would be able to access everything. In schema dc1. +For more complex GRANT clauses, see the PostgreSQL manuel. ` diff --git a/builtin/logical/postgresql/util.go b/builtin/logical/postgresql/util.go new file mode 100644 index 0000000000..4286a13f4c --- /dev/null +++ b/builtin/logical/postgresql/util.go @@ -0,0 +1,16 @@ +package postgresql + +import "strings" + +// SplitSQL is used to split a series of SQL statements +func SplitSQL(sql string) []string { + parts := strings.Split(sql, ";") + out := make([]string, 0, len(parts)) + for _, p := range parts { + clean := strings.TrimSpace(p) + if len(clean) > 0 { + out = append(out, clean) + } + } + return out +}