diff --git a/builtin/logical/mysql/backend_test.go b/builtin/logical/mysql/backend_test.go new file mode 100644 index 0000000000..937b054132 --- /dev/null +++ b/builtin/logical/mysql/backend_test.go @@ -0,0 +1,128 @@ +package mysql + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/hashicorp/vault/logical" + logicaltest "github.com/hashicorp/vault/logical/testing" + "github.com/mitchellh/mapstructure" +) + +func TestBackend_basic(t *testing.T) { + b := Backend() + + logicaltest.Test(t, logicaltest.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t), + testAccStepRole(t), + testAccStepReadCreds(t, "web"), + }, + }) +} + +func TestBackend_roleCrud(t *testing.T) { + b := Backend() + + logicaltest.Test(t, logicaltest.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t), + testAccStepRole(t), + testAccStepReadRole(t, "web", testRole), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web", ""), + }, + }) +} + +func testAccPreCheck(t *testing.T) { + if v := os.Getenv("MYSQL_DSN"); v == "" { + t.Fatal("MYSQL_DSN must be set for acceptance tests") + } +} + +func testAccStepConfig(t *testing.T) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "config/connection", + Data: map[string]interface{}{ + "value": os.Getenv("MYSQL_DSN"), + }, + } +} + +func testAccStepRole(t *testing.T) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "roles/web", + Data: map[string]interface{}{ + "sql": testRole, + }, + } +} + +func testAccStepDeleteRole(t *testing.T, n string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.DeleteOperation, + Path: "roles/" + n, + } +} + +func testAccStepReadCreds(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "creds/" + name, + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[WARN] Generated credentials: %v", d) + + return nil + }, + } +} + +func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "roles/" + name, + Check: func(resp *logical.Response) error { + if resp == nil { + if sql == "" { + return nil + } + + return fmt.Errorf("bad: %#v", resp) + } + + var d struct { + SQL string `mapstructure:"sql"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + if d.SQL != sql { + return fmt.Errorf("bad: %#v", resp) + } + + return nil + }, + } +} + +const testRole = ` +CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; +GRANT SELECT ON *.* TO '{{name}}'@'%'; +`