Merge pull request #1573 from mickhansen/logical-postgresql-revoke-sequences

handle revocations for roles that have privileges on sequences
This commit is contained in:
Jeff Mitchell
2016-07-18 13:30:42 -04:00
committed by GitHub
2 changed files with 174 additions and 12 deletions

View File

@@ -9,6 +9,7 @@ import (
"sync"
"testing"
"time"
"path"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
@@ -131,7 +132,7 @@ func TestBackend_basic(t *testing.T) {
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepRole(t),
testAccStepCreateRole(t, "web", testRole),
testAccStepReadCreds(t, b, config.StorageView, "web", connURL),
},
})
@@ -157,7 +158,7 @@ func TestBackend_roleCrud(t *testing.T) {
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepRole(t),
testAccStepCreateRole(t, "web", testRole),
testAccStepReadRole(t, "web", testRole),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web", ""),
@@ -165,6 +166,39 @@ func TestBackend_roleCrud(t *testing.T) {
})
}
func TestBackend_roleReadOnly(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
cid, connURL := prepareTestContainer(t, config.StorageView, b)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
connData := map[string]interface{}{
"connection_url": connURL,
}
logicaltest.Test(t, logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{
testAccStepConfig(t, connData, false),
testAccStepCreateRole(t, "web", testRole),
testAccStepCreateRole(t, "web-readonly", testReadOnlyRole),
testAccStepReadRole(t, "web-readonly", testReadOnlyRole),
testAccStepCreateTable(t, b, config.StorageView, "web", connURL),
testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL),
testAccStepDropTable(t, b, config.StorageView, "web", connURL),
testAccStepDeleteRole(t, "web-readonly"),
testAccStepDeleteRole(t, "web"),
testAccStepReadRole(t, "web-readonly", ""),
},
})
}
func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
@@ -194,27 +228,27 @@ func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool)
}
}
func testAccStepRole(t *testing.T) logicaltest.TestStep {
func testAccStepCreateRole(t *testing.T, name string, sql string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "roles/web",
Path: path.Join("roles", name),
Data: map[string]interface{}{
"sql": testRole,
"sql": sql,
},
}
}
func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep {
func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "roles/" + n,
Path: path.Join("roles", name),
}
}
func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "creds/" + name,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
@@ -223,9 +257,9 @@ func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, na
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[WARN] Generated credentials: %v", d)
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
@@ -257,8 +291,11 @@ func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, na
return i
}
// minNumPermissions is the minimum number of permissions that will always be present.
const minNumPermissions = 2
userRows := returnedRows()
if userRows != 2 {
if userRows < minNumPermissions {
t.Fatalf("did not get expected number of rows, got %d", userRows)
}
@@ -292,6 +329,117 @@ func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, na
}
}
func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);")
if err != nil {
t.Fatal(err)
}
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.RevokeOperation,
Storage: s,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("Error on resp: %#v", *resp)
}
}
return nil
},
}
}
func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: path.Join("creds", name),
Check: func(resp *logical.Response) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec("DROP TABLE test;")
if err != nil {
t.Fatal(err)
}
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.RevokeOperation,
Storage: s,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": d.Username,
},
},
})
if err != nil {
return err
}
if resp != nil {
if resp.IsError() {
return fmt.Errorf("Error on resp: %#v", *resp)
}
}
return nil
},
}
}
func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
@@ -328,3 +476,12 @@ CREATE ROLE "{{name}}" WITH
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
`
const testReadOnlyRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
`

View File

@@ -91,6 +91,7 @@ func (b *backend) secretCredsRevoke(
}
username, ok := usernameRaw.(string)
// Get our connection
db, err := b.DB(req.Storage)
if err != nil {
@@ -150,7 +151,11 @@ func (b *backend) secretCredsRevoke(
pq.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE USAGE ON SCHEMA public FROM %s;`,
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
pq.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE USAGE ON SCHEMA public FROM %s;",
pq.QuoteIdentifier(username)))
// get the current database name so we can issue a REVOKE CONNECT for