diff --git a/builtin/logical/mysql/backend.go b/builtin/logical/mysql/backend.go index ea719e6e88..b80574f87e 100644 --- a/builtin/logical/mysql/backend.go +++ b/builtin/logical/mysql/backend.go @@ -74,9 +74,9 @@ func (b *backend) DB(s logical.Storage) (*sql.DB, error) { return nil, err } - conn := connConfig.ConnectionString + conn := connConfig.ConnectionURL if len(conn) == 0 { - conn = connConfig.ConnectionURL + conn = connConfig.ConnectionString } b.db, err = sql.Open("mysql", conn) diff --git a/builtin/logical/mysql/backend_test.go b/builtin/logical/mysql/backend_test.go index b302e29caa..39dcaa2584 100644 --- a/builtin/logical/mysql/backend_test.go +++ b/builtin/logical/mysql/backend_test.go @@ -33,6 +33,7 @@ func TestBackend_basic(t *testing.T) { }, }) } + func TestBackend_configConnection(t *testing.T) { b := Backend() d1 := map[string]interface{}{ @@ -58,6 +59,7 @@ func TestBackend_configConnection(t *testing.T) { }, }) } + func TestBackend_roleCrud(t *testing.T) { b := Backend() @@ -108,7 +110,6 @@ func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) Data: d, ErrorOk: true, Check: func(resp *logical.Response) error { - log.Printf("vishal: testAccStepConfig: resp: %#v\n", resp) if expectError { if resp.Data == nil { return fmt.Errorf("data is nil") diff --git a/builtin/logical/postgresql/backend.go b/builtin/logical/postgresql/backend.go index ef7e4db8ee..7aee50487e 100644 --- a/builtin/logical/postgresql/backend.go +++ b/builtin/logical/postgresql/backend.go @@ -8,7 +8,6 @@ import ( "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" - "github.com/lib/pq" ) func Factory(conf *logical.BackendConfig) (logical.Backend, error) { @@ -76,20 +75,21 @@ func (b *backend) DB(s logical.Storage) (*sql.DB, error) { return nil, err } - conn := connConfig.ConnectionString + conn := connConfig.ConnectionURL if len(conn) == 0 { - conn = connConfig.ConnectionURL + conn = connConfig.ConnectionString } // Ensure timezone is set to UTC for all the conenctions if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { - var err error - conn, err = pq.ParseURL(conn) - if err != nil { - return nil, err + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" } + } else { + conn += " timezone=utc" } - conn += " timezone=utc" b.db, err = sql.Open("postgres", conn) if err != nil { diff --git a/builtin/logical/postgresql/backend_test.go b/builtin/logical/postgresql/backend_test.go index a16cb2646a..9723e2aef9 100644 --- a/builtin/logical/postgresql/backend_test.go +++ b/builtin/logical/postgresql/backend_test.go @@ -16,11 +16,21 @@ import ( func TestBackend_basic(t *testing.T) { b, _ := Factory(logical.TestBackendConfig()) + d1 := map[string]interface{}{ + "connection_url": os.Getenv("PG_URL"), + } + d2 := map[string]interface{}{ + "value": os.Getenv("PG_URL"), + } + logicaltest.Test(t, logicaltest.TestCase{ PreCheck: func() { testAccPreCheck(t) }, Backend: b, Steps: []logicaltest.TestStep{ - testAccStepConfig(t), + testAccStepConfig(t, d1, false), + testAccStepRole(t), + testAccStepReadCreds(t, b, "web"), + testAccStepConfig(t, d2, false), testAccStepRole(t), testAccStepReadCreds(t, b, "web"), }, @@ -30,12 +40,15 @@ func TestBackend_basic(t *testing.T) { func TestBackend_roleCrud(t *testing.T) { b, _ := Factory(logical.TestBackendConfig()) + d := map[string]interface{}{ + "connection_url": os.Getenv("PG_URL"), + } logicaltest.Test(t, logicaltest.TestCase{ PreCheck: func() { testAccPreCheck(t) }, Backend: b, Steps: []logicaltest.TestStep{ - testAccStepConfig(t), + testAccStepConfig(t, d, false), testAccStepRole(t), testAccStepReadRole(t, "web", testRole), testAccStepDeleteRole(t, "web"), @@ -44,18 +57,63 @@ func TestBackend_roleCrud(t *testing.T) { }) } +func TestBackend_configConnection(t *testing.T) { + b := Backend() + d1 := map[string]interface{}{ + "value": os.Getenv("PG_URL"), + } + d2 := map[string]interface{}{ + "connection_url": os.Getenv("PG_URL"), + } + d3 := map[string]interface{}{ + "value": os.Getenv("PG_URL"), + "connection_url": os.Getenv("PG_URL"), + } + d4 := map[string]interface{}{} + + logicaltest.Test(t, logicaltest.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, d1, false), + testAccStepConfig(t, d2, false), + testAccStepConfig(t, d3, false), + testAccStepConfig(t, d4, true), + }, + }) +} + func testAccPreCheck(t *testing.T) { if v := os.Getenv("PG_URL"); v == "" { t.Fatal("PG_URL must be set for acceptance tests") } } -func testAccStepConfig(t *testing.T) logicaltest.TestStep { +func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.UpdateOperation, Path: "config/connection", - Data: map[string]interface{}{ - "value": os.Getenv("PG_URL"), + Data: d, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if expectError { + if resp.Data == nil { + return fmt.Errorf("data is nil") + } + var e struct { + Error string `mapstructure:"error"` + } + if err := mapstructure.Decode(resp.Data, &e); err != nil { + return err + } + if len(e.Error) == 0 { + return fmt.Errorf("expected error, but write succeeded.") + } + return nil + } else if resp != nil { + return fmt.Errorf("response should be nil") + } + return nil }, } } diff --git a/builtin/logical/postgresql/path_config_connection.go b/builtin/logical/postgresql/path_config_connection.go index ee5cd26342..c9dca1354a 100644 --- a/builtin/logical/postgresql/path_config_connection.go +++ b/builtin/logical/postgresql/path_config_connection.go @@ -22,6 +22,11 @@ func pathConfigConnection(b *backend) *framework.Path { Description: `DB connection string. Use 'connection_url' instead. This will be deprecated.`, }, + "verify_connection": &framework.FieldSchema{ + Type: framework.TypeBool, + Default: true, + Description: `If set, connection_url is verified by actually connecting to the database`, + }, "max_open_connections": &framework.FieldSchema{ Type: framework.TypeInt, Description: `Maximum number of open connections to the database; @@ -51,8 +56,15 @@ reduced to the same size.`, func (b *backend) pathConnectionWrite( req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connString := data.Get("value").(string) + connValue := data.Get("value").(string) connURL := data.Get("connection_url").(string) + if connURL == "" { + if connValue == "" { + return logical.ErrorResponse("connection_url parameter must be supplied"), nil + } else { + connURL = connValue + } + } maxOpenConns := data.Get("max_open_connections").(int) if maxOpenConns == 0 { @@ -67,21 +79,25 @@ func (b *backend) pathConnectionWrite( maxIdleConns = maxOpenConns } - // Verify the string - db, err := sql.Open("postgres", connString) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil - } - defer db.Close() - if err := db.Ping(); err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } } // Store it entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{ - ConnectionString: connString, + ConnectionString: connValue, ConnectionURL: connURL, MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, diff --git a/website/source/docs/secrets/mysql/index.html.md b/website/source/docs/secrets/mysql/index.html.md index 52689b89f2..6e0ea36096 100644 --- a/website/source/docs/secrets/mysql/index.html.md +++ b/website/source/docs/secrets/mysql/index.html.md @@ -137,7 +137,7 @@ allowed to read.