Add backend test

This commit is contained in:
Brian Kassouf
2017-04-07 15:50:03 -07:00
parent 8e77bd98d8
commit 9ae5a2aede
8 changed files with 633 additions and 8 deletions

View File

@@ -0,0 +1,567 @@
package database
import (
"database/sql"
"errors"
"fmt"
"log"
"net"
"os"
"reflect"
"strings"
"sync"
"testing"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/helper/builtinplugins"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
"github.com/lib/pq"
"github.com/mitchellh/mapstructure"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
var (
testImagePull sync.Once
)
func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) {
if os.Getenv("PG_URL") != "" {
return func() {}, os.Getenv("PG_URL")
}
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
}
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"})
if err != nil {
t.Fatalf("Could not start local PostgreSQL docker container: %s", err)
}
cleanup = func() {
err := pool.Purge(resource)
if err != nil {
t.Fatalf("Failed to cleanup local container: %s", err)
}
}
retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp"))
// exponential backoff-retry
if err = pool.Retry(func() error {
// This will cause a validation to run
resp, err := b.HandleRequest(&logical.Request{
Storage: s,
Operation: logical.UpdateOperation,
Path: "config/postgresql",
Data: map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_url": retURL,
},
})
if err != nil || (resp != nil && resp.IsError()) {
// It's likely not up and running yet, so return error and try again
return fmt.Errorf("err:%s resp:%#v\n", err, resp)
}
if resp == nil {
t.Fatal("expected warning")
}
return nil
}); err != nil {
t.Fatalf("Could not connect to PostgreSQL docker container: %s", err)
}
return
}
func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, string) {
core, _, token, ln := vault.TestCoreUnsealedWithListener(t)
http.TestServerWithListener(t, ln, "", core)
sys := vault.TestDynamicSystemView(core)
vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", fmt.Sprintf("%s -test.run=TestBackend_PluginMain", os.Args[0]))
return core, ln, sys, token
}
func TestBackend_PluginMain(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
f, _ := builtinplugins.BuiltinPlugins.Get("postgresql-database-plugin")
f()
}
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
_, ln, sys, _ := getCore(t)
defer ln.Close()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
"plugin_name": "postgresql-database-plugin",
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": configData,
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
}
func TestBackend_basic(t *testing.T) {
_, ln, sys, _ := getCore(t)
defer ln.Close()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()
// Configure a connection
data := map[string]interface{}{
"connection_url": connURL,
"plugin_name": "postgresql-database-plugin",
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: data,
}
resp, err := b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Create a role
data = map[string]interface{}{
"db_name": "plugin-test",
"creation_statements": testRole,
"default_ttl": "5m",
"max_ttl": "10m",
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Get creds
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "creds/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
credsResp, err := b.HandleRequest(req)
if err != nil || (credsResp != nil && credsResp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
}
if testCredsByCount(t, credsResp, connURL) != 2 {
t.Fatalf("Got wrong number of creds")
}
// Revoke creds
resp, err = b.HandleRequest(&logical.Request{
Operation: logical.RevokeOperation,
Storage: config.StorageView,
Secret: &logical.Secret{
InternalData: map[string]interface{}{
"secret_type": "creds",
"username": credsResp.Data["username"],
"role": "plugin-role-test",
},
},
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if testCredsByCount(t, credsResp, connURL) != -1 {
t.Fatalf("Got wrong number of creds")
}
}
func TestBackend_roleCrud(t *testing.T) {
_, ln, sys, _ := getCore(t)
defer ln.Close()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()
// Configure a connection
data := map[string]interface{}{
"connection_url": connURL,
"plugin_name": "postgresql-database-plugin",
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: data,
}
resp, err := b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Create a role
data = map[string]interface{}{
"db_name": "plugin-test",
"creation_statements": testRole,
"revocation_statements": defaultRevocationSQL,
"default_ttl": "5m",
"max_ttl": "10m",
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Read the role
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
expected := dbplugin.Statements{
CreationStatements: testRole,
RevocationStatements: defaultRevocationSQL,
}
var actual dbplugin.Statements
if err := mapstructure.Decode(resp.Data, &actual); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("Statements did not match, exepected %#v, got %#v", expected, actual)
}
// Delete the role
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.DeleteOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Read the role
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Should be empty
if resp != nil {
t.Fatal("Expected response to be nil")
}
}
func TestBackend_roleReadOnly(t *testing.T) {
_, ln, sys, _ := getCore(t)
defer ln.Close()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
defer b.Cleanup()
cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b)
defer cleanup()
// Configure a connection
data := map[string]interface{}{
"connection_url": connURL,
"plugin_name": "postgresql-database-plugin",
}
req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: data,
}
resp, err := b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Create a role
data = map[string]interface{}{
"db_name": "plugin-test",
"creation_statements": testRole,
"default_ttl": "5m",
"max_ttl": "10m",
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Create a readonly role
data = map[string]interface{}{
"db_name": "plugin-test",
"creation_statements": testReadOnlyRole,
"default_ttl": "5m",
"max_ttl": "10m",
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "roles/plugin-readonly-role-test",
Storage: config.StorageView,
Data: data,
}
resp, err = b.HandleRequest(req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Get creds
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "creds/plugin-role-test",
Storage: config.StorageView,
Data: data,
}
credsResp, err := b.HandleRequest(req)
if err != nil || (credsResp != nil && credsResp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
}
if i := testCredsByCount(t, credsResp, connURL); i != 2 {
t.Fatalf("Got wrong number of creds got %d, expected 2", i)
}
// Get readonly creds
data = map[string]interface{}{}
req = &logical.Request{
Operation: logical.ReadOperation,
Path: "creds/plugin-readonly-role-test",
Storage: config.StorageView,
Data: data,
}
readOnlyCredsResp, err := b.HandleRequest(req)
if err != nil || (credsResp != nil && credsResp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, readOnlyCredsResp)
}
if i := testCredsByCount(t, readOnlyCredsResp, connURL); i != 2 {
t.Fatalf("Got wrong number of creds got %d, expected 2", i)
}
if err := testCreateTable(t, readOnlyCredsResp, connURL); err == nil {
t.Fatal("Read only creds should return error on table creation")
}
if err := testCreateTable(t, credsResp, connURL); err != nil {
t.Fatalf("Error on table creation: %s", err)
}
}
func testCredsByCount(t *testing.T, resp *logical.Response, connURL string) int {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
t.Fatal(err)
}
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
if err != nil {
t.Fatal(err)
}
returnedRows := func() int {
stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');")
if err != nil {
return -1
}
defer stmt.Close()
rows, err := stmt.Query(d.Username)
if err != nil {
return -1
}
defer rows.Close()
i := 0
for rows.Next() {
i++
}
return i
}
return returnedRows()
}
func testCreateTable(t *testing.T, resp *logical.Response, connURL string) error {
var d struct {
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
t.Fatal(err)
}
connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", d.Username, d.Password), 1)
fmt.Println(connURL)
log.Printf("[TRACE] Generated credentials: %v", d)
conn, err := pq.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
conn += " timezone=utc"
db, err := sql.Open("postgres", conn)
if err != nil {
t.Fatal(err)
}
r, err := db.Exec("CREATE TABLE test1 (id SERIAL PRIMARY KEY);")
if err != nil {
return err
}
if i, _ := r.RowsAffected(); i != 1 {
return errors.New("Did not create db")
}
return nil
}
const testRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
`
const testReadOnlyRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
REVOKE ALL ON SCHEMA public FROM "{{name}}";
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
`
const defaultRevocationSQL = `
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}};
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}};
REVOKE USAGE ON SCHEMA public FROM {{name}};
DROP ROLE IF EXISTS {{name}};
`

View File

@@ -172,10 +172,12 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
err = db.Initialize(config.ConnectionDetails)
if err != nil {
if !strings.Contains(err.Error(), "Error Initializing Connection") {
db.Close()
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}
if verifyConnection {
db.Close()
return logical.ErrorResponse("Could not verify connection"), nil
}
}

View File

@@ -105,7 +105,7 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie
return &logical.Response{
Data: map[string]interface{}{
"creation_statments": role.Statements.CreationStatements,
"creation_statements": role.Statements.CreationStatements,
"revocation_statements": role.Statements.RevocationStatements,
"rollback_statements": role.Statements.RollbackStatements,
"renew_statements": role.Statements.RenewStatements,

View File

@@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F
roleNameRaw, ok := req.Secret.InternalData["role"]
if !ok {
return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"])
return nil, fmt.Errorf("no role name was provided")
}
role, err := b.Role(req.Storage, roleNameRaw.(string))

View File

@@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int {
pluginName := args[0]
runner, ok := builtinplugins.BuiltinPlugins[pluginName]
runner, ok := builtinplugins.BuiltinPlugins.Get(pluginName)
if !ok {
c.Ui.Error(fmt.Sprintf(
"No plugin with the name %s found", pluginName))

View File

@@ -5,7 +5,20 @@ import (
"github.com/hashicorp/vault-plugins/database/postgresql"
)
var BuiltinPlugins = map[string]func() error{
var BuiltinPlugins *builtinPlugins = &builtinPlugins{
plugins: map[string]func() error{
"mysql-database-plugin": mysql.Run,
"postgresql-database-plugin": postgresql.Run,
},
}
// The list of builtin plugins should not be changed by any other package, so we
// store them in an unexported variable in this unexported struct.
type builtinPlugins struct {
plugins map[string]func() error
}
func (b *builtinPlugins) Get(name string) (func() error, bool) {
f, ok := b.plugins[name]
return f, ok
}

View File

@@ -54,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) {
}
// Look for builtin plugins
if _, ok := builtinplugins.BuiltinPlugins[name]; !ok {
if _, ok := builtinplugins.BuiltinPlugins.Get(name); !ok {
return nil, fmt.Errorf("no plugin found with name: %s", name)
}

View File

@@ -8,9 +8,13 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
@@ -306,7 +310,46 @@ func TestKeyCopy(key []byte) []byte {
}
func TestDynamicSystemView(c *Core) *dynamicSystemView {
return &dynamicSystemView{c, nil}
me := &MountEntry{
Config: MountConfig{
DefaultLeaseTTL: 24 * time.Hour,
MaxLeaseTTL: 2 * 24 * time.Hour,
},
}
return &dynamicSystemView{c, me}
}
func TestAddTestPlugin(t testing.TB, c *Core, name, command string) {
parts := strings.Split(command, " ")
file, err := os.Open(parts[0])
if err != nil {
t.Fatal(err)
}
defer file.Close()
hash := sha256.New()
_, err = io.Copy(hash, file)
if err != nil {
t.Fatal(err)
}
sum := hash.Sum(nil)
c.pluginCatalog.directory, err = filepath.EvalSymlinks(parts[0])
if err != nil {
t.Fatal(err)
}
c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory)
parts[0] = filepath.Base(parts[0])
command = strings.Join(parts, " ")
err = c.pluginCatalog.Set(name, command, sum)
if err != nil {
t.Fatal(err)
}
}
var testLogicalBackends = map[string]logical.Factory{}