diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go new file mode 100644 index 0000000000..5cb84476d4 --- /dev/null +++ b/builtin/logical/database/backend_test.go @@ -0,0 +1,567 @@ +package database + +import ( + "database/sql" + "errors" + "fmt" + "log" + "net" + "os" + "reflect" + "strings" + "sync" + "testing" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" + "github.com/lib/pq" + "github.com/mitchellh/mapstructure" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testImagePull sync.Once +) + +func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) { + if os.Getenv("PG_URL") != "" { + return func() {}, os.Getenv("PG_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"}) + if err != nil { + t.Fatalf("Could not start local PostgreSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + // This will cause a validation to run + resp, err := b.HandleRequest(&logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: "config/postgresql", + Data: map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_url": retURL, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + // It's likely not up and running yet, so return error and try again + return fmt.Errorf("err:%s resp:%#v\n", err, resp) + } + if resp == nil { + t.Fatal("expected warning") + } + + return nil + }); err != nil { + t.Fatalf("Could not connect to PostgreSQL docker container: %s", err) + } + + return +} + +func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, string) { + core, _, token, ln := vault.TestCoreUnsealedWithListener(t) + http.TestServerWithListener(t, ln, "", core) + sys := vault.TestDynamicSystemView(core) + vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", fmt.Sprintf("%s -test.run=TestBackend_PluginMain", os.Args[0])) + + return core, ln, sys, token +} + +func TestBackend_PluginMain(t *testing.T) { + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + return + } + + f, _ := builtinplugins.BuiltinPlugins.Get("postgresql-database-plugin") + f() +} + +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "plugin_name": "postgresql-database-plugin", + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": configData, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } +} + +func TestBackend_basic(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if testCredsByCount(t, credsResp, connURL) != 2 { + t.Fatalf("Got wrong number of creds") + } + + // Revoke creds + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: config.StorageView, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": credsResp.Data["username"], + "role": "plugin-role-test", + }, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if testCredsByCount(t, credsResp, connURL) != -1 { + t.Fatalf("Got wrong number of creds") + } + +} + +func TestBackend_roleCrud(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "revocation_statements": defaultRevocationSQL, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + expected := dbplugin.Statements{ + CreationStatements: testRole, + RevocationStatements: defaultRevocationSQL, + } + + var actual dbplugin.Statements + if err := mapstructure.Decode(resp.Data, &actual); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Statements did not match, exepected %#v, got %#v", expected, actual) + } + + // Delete the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.DeleteOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Should be empty + if resp != nil { + t.Fatal("Expected response to be nil") + } +} + +func TestBackend_roleReadOnly(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a readonly role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testReadOnlyRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-readonly-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if i := testCredsByCount(t, credsResp, connURL); i != 2 { + t.Fatalf("Got wrong number of creds got %d, expected 2", i) + } + + // Get readonly creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-readonly-role-test", + Storage: config.StorageView, + Data: data, + } + readOnlyCredsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, readOnlyCredsResp) + } + + if i := testCredsByCount(t, readOnlyCredsResp, connURL); i != 2 { + t.Fatalf("Got wrong number of creds got %d, expected 2", i) + } + + if err := testCreateTable(t, readOnlyCredsResp, connURL); err == nil { + t.Fatal("Read only creds should return error on table creation") + } + + if err := testCreateTable(t, credsResp, connURL); err != nil { + t.Fatalf("Error on table creation: %s", err) + } +} + +func testCredsByCount(t *testing.T, resp *logical.Response, connURL string) int { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + t.Fatal(err) + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + returnedRows := func() int { + stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');") + if err != nil { + return -1 + } + defer stmt.Close() + + rows, err := stmt.Query(d.Username) + if err != nil { + return -1 + } + defer rows.Close() + + i := 0 + for rows.Next() { + i++ + } + return i + } + + return returnedRows() +} + +func testCreateTable(t *testing.T, resp *logical.Response, connURL string) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + t.Fatal(err) + } + + connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", d.Username, d.Password), 1) + + fmt.Println(connURL) + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + r, err := db.Exec("CREATE TABLE test1 (id SERIAL PRIMARY KEY);") + if err != nil { + return err + } + + if i, _ := r.RowsAffected(); i != 1 { + return errors.New("Did not create db") + } + + return nil +} + +const testRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +REVOKE ALL ON SCHEMA public FROM "{{name}}"; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const defaultRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; +REVOKE USAGE ON SCHEMA public FROM {{name}}; + +DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4af6e70a05..1b8a658315 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -172,10 +172,12 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { err = db.Initialize(config.ConnectionDetails) if err != nil { if !strings.Contains(err.Error(), "Error Initializing Connection") { + db.Close() return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } if verifyConnection { + db.Close() return logical.ErrorResponse("Could not verify connection"), nil } } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index b3ef6f753d..a6989df248 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -105,7 +105,7 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie return &logical.Response{ Data: map[string]interface{}{ - "creation_statments": role.Statements.CreationStatements, + "creation_statements": role.Statements.CreationStatements, "revocation_statements": role.Statements.RevocationStatements, "rollback_statements": role.Statements.RollbackStatements, "renew_statements": role.Statements.RenewStatements, diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 2b63ea1f89..353541c0cc 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("no role name was provided") } role, err := b.Role(req.Storage, roleNameRaw.(string)) diff --git a/command/plugin-exec.go b/command/plugin-exec.go index 70bc8ae1d4..575be14b7d 100644 --- a/command/plugin-exec.go +++ b/command/plugin-exec.go @@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int { pluginName := args[0] - runner, ok := builtinplugins.BuiltinPlugins[pluginName] + runner, ok := builtinplugins.BuiltinPlugins.Get(pluginName) if !ok { c.Ui.Error(fmt.Sprintf( "No plugin with the name %s found", pluginName)) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index ceaf10edf9..ba3769c900 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -5,7 +5,20 @@ import ( "github.com/hashicorp/vault-plugins/database/postgresql" ) -var BuiltinPlugins = map[string]func() error{ - "mysql-database-plugin": mysql.Run, - "postgresql-database-plugin": postgresql.Run, +var BuiltinPlugins *builtinPlugins = &builtinPlugins{ + plugins: map[string]func() error{ + "mysql-database-plugin": mysql.Run, + "postgresql-database-plugin": postgresql.Run, + }, +} + +// The list of builtin plugins should not be changed by any other package, so we +// store them in an unexported variable in this unexported struct. +type builtinPlugins struct { + plugins map[string]func() error +} + +func (b *builtinPlugins) Get(name string) (func() error, bool) { + f, ok := b.plugins[name] + return f, ok } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index b9c15db22a..a42f85ec11 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -54,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if _, ok := builtinplugins.BuiltinPlugins[name]; !ok { + if _, ok := builtinplugins.BuiltinPlugins.Get(name); !ok { return nil, fmt.Errorf("no plugin found with name: %s", name) } diff --git a/vault/testing.go b/vault/testing.go index 7b914bbdbb..fdf55b4e59 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -8,9 +8,13 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "io" "net" "net/http" + "os" "os/exec" + "path/filepath" + "strings" "testing" "time" @@ -306,7 +310,46 @@ func TestKeyCopy(key []byte) []byte { } func TestDynamicSystemView(c *Core) *dynamicSystemView { - return &dynamicSystemView{c, nil} + me := &MountEntry{ + Config: MountConfig{ + DefaultLeaseTTL: 24 * time.Hour, + MaxLeaseTTL: 2 * 24 * time.Hour, + }, + } + + return &dynamicSystemView{c, me} +} + +func TestAddTestPlugin(t testing.TB, c *Core, name, command string) { + parts := strings.Split(command, " ") + + file, err := os.Open(parts[0]) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + hash := sha256.New() + + _, err = io.Copy(hash, file) + if err != nil { + t.Fatal(err) + } + + sum := hash.Sum(nil) + c.pluginCatalog.directory, err = filepath.EvalSymlinks(parts[0]) + if err != nil { + t.Fatal(err) + } + c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory) + + parts[0] = filepath.Base(parts[0]) + command = strings.Join(parts, " ") + + err = c.pluginCatalog.Set(name, command, sum) + if err != nil { + t.Fatal(err) + } } var testLogicalBackends = map[string]logical.Factory{}