diff --git a/builtin/logical/mssql/backend_test.go b/builtin/logical/mssql/backend_test.go index dd9a575eaf..76f7bc2033 100644 --- a/builtin/logical/mssql/backend_test.go +++ b/builtin/logical/mssql/backend_test.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "reflect" "testing" "github.com/hashicorp/vault/logical" @@ -11,6 +12,47 @@ import ( "github.com/mitchellh/mapstructure" ) +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + configData := map[string]interface{}{ + "connection_string": "sample_connection_string", + "max_open_connections": 7, + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + if resp != nil { + t.Fatalf("expected a nil response") + } + + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if !reflect.DeepEqual(configData, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data) + } +} + func TestBackend_basic(t *testing.T) { b, _ := Factory(logical.TestBackendConfig()) diff --git a/builtin/logical/mssql/path_config_connection.go b/builtin/logical/mssql/path_config_connection.go index e12ede0503..9125a4be83 100644 --- a/builtin/logical/mssql/path_config_connection.go +++ b/builtin/logical/mssql/path_config_connection.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" + "github.com/fatih/structs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -29,6 +30,7 @@ func pathConfigConnection(b *backend) *framework.Path { Callbacks: map[logical.Operation]framework.OperationFunc{ logical.UpdateOperation: b.pathConnectionWrite, + logical.ReadOperation: b.pathConnectionRead, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -36,8 +38,27 @@ func pathConfigConnection(b *backend) *framework.Path { } } -func (b *backend) pathConnectionWrite( - req *logical.Request, data *framework.FieldData) (*logical.Response, error) { +// pathConnectionWrite reads out the connection configuration +func (b *backend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + entry, err := req.Storage.Get("config/connection") + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config connectionConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil +} + +// pathConnectionWrite stores the connection configuration +func (b *backend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connString := data.Get("connection_string").(string) maxOpenConns := data.Get("max_open_connections").(int) @@ -66,6 +87,7 @@ func (b *backend) pathConnectionWrite( entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{ ConnectionString: connString, MaxOpenConnections: maxOpenConns, + VerifyConnection: verifyConnection, }) if err != nil { return nil, err @@ -80,8 +102,9 @@ func (b *backend) pathConnectionWrite( } type connectionConfig struct { - ConnectionString string `json:"connection_string"` - MaxOpenConnections int `json:"max_open_connections"` + ConnectionString string `json:"connection_string" structs:"connection_string" mapstructure:"connection_string"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + VerifyConnection bool `json:"verify_connection" structs:"verify_connection" mapstructure:"verify_connection"` } const pathConfigConnectionHelpSyn = ` diff --git a/builtin/logical/mysql/backend.go b/builtin/logical/mysql/backend.go index 70e6e33992..3b4761dda3 100644 --- a/builtin/logical/mysql/backend.go +++ b/builtin/logical/mysql/backend.go @@ -68,12 +68,7 @@ func (b *backend) DB(s logical.Storage) (*sql.DB, error) { return nil, err } - conn := connConfig.ConnectionURL - if len(conn) == 0 { - conn = connConfig.ConnectionString - } - - b.db, err = sql.Open("mysql", conn) + b.db, err = sql.Open("mysql", connConfig.ConnectionURL) if err != nil { return nil, err } diff --git a/builtin/logical/mysql/backend_test.go b/builtin/logical/mysql/backend_test.go index 804588a893..0c1b3a034d 100644 --- a/builtin/logical/mysql/backend_test.go +++ b/builtin/logical/mysql/backend_test.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "reflect" "testing" "github.com/hashicorp/vault/logical" @@ -11,6 +12,47 @@ import ( "github.com/mitchellh/mapstructure" ) +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "max_open_connections": 7, + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + if resp != nil { + t.Fatalf("expected a nil response") + } + + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if !reflect.DeepEqual(configData, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data) + } +} + func TestBackend_basic(t *testing.T) { b, _ := Factory(logical.TestBackendConfig()) diff --git a/builtin/logical/mysql/path_config_connection.go b/builtin/logical/mysql/path_config_connection.go index 34d693917f..81e38319c6 100644 --- a/builtin/logical/mysql/path_config_connection.go +++ b/builtin/logical/mysql/path_config_connection.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" + "github.com/fatih/structs" _ "github.com/go-sql-driver/mysql" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -17,12 +18,6 @@ func pathConfigConnection(b *backend) *framework.Path { Type: framework.TypeString, Description: "DB connection string", }, - "value": &framework.FieldSchema{ - Type: framework.TypeString, - Description: ` - DB connection string. Use 'connection_url' instead. -This name is deprecated.`, - }, "max_open_connections": &framework.FieldSchema{ Type: framework.TypeInt, Description: "Maximum number of open connections to database", @@ -36,6 +31,7 @@ This name is deprecated.`, Callbacks: map[logical.Operation]framework.OperationFunc{ logical.UpdateOperation: b.pathConnectionWrite, + logical.ReadOperation: b.pathConnectionRead, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -43,16 +39,30 @@ This name is deprecated.`, } } +// pathConnectionRead reads out the connection configuration +func (b *backend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + entry, err := req.Storage.Get("config/connection") + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config connectionConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil +} + func (b *backend) pathConnectionWrite( req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connValue := data.Get("value").(string) connURL := data.Get("connection_url").(string) if connURL == "" { - if connValue == "" { - return logical.ErrorResponse("the connection_url parameter must be supplied"), nil - } else { - connURL = connValue - } + return logical.ErrorResponse("the connection_url parameter must be supplied"), nil } maxOpenConns := data.Get("max_open_connections").(int) @@ -81,6 +91,7 @@ func (b *backend) pathConnectionWrite( entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{ ConnectionURL: connURL, MaxOpenConnections: maxOpenConns, + VerifyConnection: verifyConnection, }) if err != nil { return nil, err @@ -95,10 +106,9 @@ func (b *backend) pathConnectionWrite( } type connectionConfig struct { - ConnectionURL string `json:"connection_url"` - // Deprecate "value" in coming releases - ConnectionString string `json:"value"` - MaxOpenConnections int `json:"max_open_connections"` + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + VerifyConnection bool `json:"verify_connection" structs:"verify_connection" mapstructure:"verify_connection"` } const pathConfigConnectionHelpSyn = ` diff --git a/builtin/logical/postgresql/backend.go b/builtin/logical/postgresql/backend.go index 85cf2d1fff..6aad0d68d5 100644 --- a/builtin/logical/postgresql/backend.go +++ b/builtin/logical/postgresql/backend.go @@ -70,9 +70,6 @@ func (b *backend) DB(s logical.Storage) (*sql.DB, error) { } conn := connConfig.ConnectionURL - if len(conn) == 0 { - conn = connConfig.ConnectionString - } // Ensure timezone is set to UTC for all the conenctions if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { diff --git a/builtin/logical/postgresql/backend_test.go b/builtin/logical/postgresql/backend_test.go index ac2b39fbd3..7c2f73212f 100644 --- a/builtin/logical/postgresql/backend_test.go +++ b/builtin/logical/postgresql/backend_test.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "os" + "reflect" "testing" "github.com/hashicorp/vault/logical" @@ -13,6 +14,48 @@ import ( "github.com/mitchellh/mapstructure" ) +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "max_open_connections": 9, + "max_idle_connections": 7, + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + if resp != nil { + t.Fatalf("expected a nil response") + } + + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if !reflect.DeepEqual(configData, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data) + } +} + func TestBackend_basic(t *testing.T) { b, _ := Factory(logical.TestBackendConfig()) diff --git a/builtin/logical/postgresql/path_config_connection.go b/builtin/logical/postgresql/path_config_connection.go index c9dca1354a..8476bdc513 100644 --- a/builtin/logical/postgresql/path_config_connection.go +++ b/builtin/logical/postgresql/path_config_connection.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" + "github.com/fatih/structs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" _ "github.com/lib/pq" @@ -47,6 +48,7 @@ reduced to the same size.`, Callbacks: map[logical.Operation]framework.OperationFunc{ logical.UpdateOperation: b.pathConnectionWrite, + logical.ReadOperation: b.pathConnectionRead, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -54,6 +56,25 @@ reduced to the same size.`, } } +// pathConnectionRead reads out the connection configuration +func (b *backend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + entry, err := req.Storage.Get("config/connection") + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config connectionConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil +} + func (b *backend) pathConnectionWrite( req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connValue := data.Get("value").(string) @@ -97,10 +118,10 @@ func (b *backend) pathConnectionWrite( // Store it entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{ - ConnectionString: connValue, ConnectionURL: connURL, MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, + VerifyConnection: verifyConnection, }) if err != nil { return nil, err @@ -116,11 +137,10 @@ func (b *backend) pathConnectionWrite( } type connectionConfig struct { - ConnectionURL string `json:"connection_url"` - // Deprecate "value" in coming releases - ConnectionString string `json:"value"` - MaxOpenConnections int `json:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections"` + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + VerifyConnection bool `json:"verify_connection" structs:"verify_connection" mapstructure:"verify_connection"` } const pathConfigConnectionHelpSyn = `