Allow reading of config in sql backends

This commit is contained in:
vishalnayak
2016-06-11 11:48:40 -04:00
parent 117200c88a
commit adbfef8561
8 changed files with 207 additions and 35 deletions

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"log"
"os"
"reflect"
"testing"
"github.com/hashicorp/vault/logical"
@@ -11,6 +12,47 @@ import (
"github.com/mitchellh/mapstructure"
)
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"connection_string": "sample_connection_string",
"max_open_connections": 7,
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/connection",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response")
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if !reflect.DeepEqual(configData, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
}
}
func TestBackend_basic(t *testing.T) {
b, _ := Factory(logical.TestBackendConfig())

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@@ -29,6 +30,7 @@ func pathConfigConnection(b *backend) *framework.Path {
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
logical.ReadOperation: b.pathConnectionRead,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
@@ -36,8 +38,27 @@ func pathConfigConnection(b *backend) *framework.Path {
}
}
func (b *backend) pathConnectionWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
// pathConnectionWrite reads out the connection configuration
func (b *backend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get("config/connection")
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config connectionConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
return &logical.Response{
Data: structs.New(config).Map(),
}, nil
}
// pathConnectionWrite stores the connection configuration
func (b *backend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connString := data.Get("connection_string").(string)
maxOpenConns := data.Get("max_open_connections").(int)
@@ -66,6 +87,7 @@ func (b *backend) pathConnectionWrite(
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
ConnectionString: connString,
MaxOpenConnections: maxOpenConns,
VerifyConnection: verifyConnection,
})
if err != nil {
return nil, err
@@ -80,8 +102,9 @@ func (b *backend) pathConnectionWrite(
}
type connectionConfig struct {
ConnectionString string `json:"connection_string"`
MaxOpenConnections int `json:"max_open_connections"`
ConnectionString string `json:"connection_string" structs:"connection_string" mapstructure:"connection_string"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
VerifyConnection bool `json:"verify_connection" structs:"verify_connection" mapstructure:"verify_connection"`
}
const pathConfigConnectionHelpSyn = `

View File

@@ -68,12 +68,7 @@ func (b *backend) DB(s logical.Storage) (*sql.DB, error) {
return nil, err
}
conn := connConfig.ConnectionURL
if len(conn) == 0 {
conn = connConfig.ConnectionString
}
b.db, err = sql.Open("mysql", conn)
b.db, err = sql.Open("mysql", connConfig.ConnectionURL)
if err != nil {
return nil, err
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"log"
"os"
"reflect"
"testing"
"github.com/hashicorp/vault/logical"
@@ -11,6 +12,47 @@ import (
"github.com/mitchellh/mapstructure"
)
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
"max_open_connections": 7,
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/connection",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response")
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if !reflect.DeepEqual(configData, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
}
}
func TestBackend_basic(t *testing.T) {
b, _ := Factory(logical.TestBackendConfig())

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"github.com/fatih/structs"
_ "github.com/go-sql-driver/mysql"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@@ -17,12 +18,6 @@ func pathConfigConnection(b *backend) *framework.Path {
Type: framework.TypeString,
Description: "DB connection string",
},
"value": &framework.FieldSchema{
Type: framework.TypeString,
Description: `
DB connection string. Use 'connection_url' instead.
This name is deprecated.`,
},
"max_open_connections": &framework.FieldSchema{
Type: framework.TypeInt,
Description: "Maximum number of open connections to database",
@@ -36,6 +31,7 @@ This name is deprecated.`,
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
logical.ReadOperation: b.pathConnectionRead,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
@@ -43,16 +39,30 @@ This name is deprecated.`,
}
}
// pathConnectionRead reads out the connection configuration
func (b *backend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get("config/connection")
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config connectionConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
return &logical.Response{
Data: structs.New(config).Map(),
}, nil
}
func (b *backend) pathConnectionWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connValue := data.Get("value").(string)
connURL := data.Get("connection_url").(string)
if connURL == "" {
if connValue == "" {
return logical.ErrorResponse("the connection_url parameter must be supplied"), nil
} else {
connURL = connValue
}
}
maxOpenConns := data.Get("max_open_connections").(int)
@@ -81,6 +91,7 @@ func (b *backend) pathConnectionWrite(
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
ConnectionURL: connURL,
MaxOpenConnections: maxOpenConns,
VerifyConnection: verifyConnection,
})
if err != nil {
return nil, err
@@ -95,10 +106,9 @@ func (b *backend) pathConnectionWrite(
}
type connectionConfig struct {
ConnectionURL string `json:"connection_url"`
// Deprecate "value" in coming releases
ConnectionString string `json:"value"`
MaxOpenConnections int `json:"max_open_connections"`
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
VerifyConnection bool `json:"verify_connection" structs:"verify_connection" mapstructure:"verify_connection"`
}
const pathConfigConnectionHelpSyn = `

View File

@@ -70,9 +70,6 @@ func (b *backend) DB(s logical.Storage) (*sql.DB, error) {
}
conn := connConfig.ConnectionURL
if len(conn) == 0 {
conn = connConfig.ConnectionString
}
// Ensure timezone is set to UTC for all the conenctions
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"log"
"os"
"reflect"
"testing"
"github.com/hashicorp/vault/logical"
@@ -13,6 +14,48 @@ import (
"github.com/mitchellh/mapstructure"
)
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
if err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
"max_open_connections": 9,
"max_idle_connections": 7,
"verify_connection": false,
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/connection",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if resp != nil {
t.Fatalf("expected a nil response")
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if !reflect.DeepEqual(configData, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data)
}
}
func TestBackend_basic(t *testing.T) {
b, _ := Factory(logical.TestBackendConfig())

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
_ "github.com/lib/pq"
@@ -47,6 +48,7 @@ reduced to the same size.`,
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathConnectionWrite,
logical.ReadOperation: b.pathConnectionRead,
},
HelpSynopsis: pathConfigConnectionHelpSyn,
@@ -54,6 +56,25 @@ reduced to the same size.`,
}
}
// pathConnectionRead reads out the connection configuration
func (b *backend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get("config/connection")
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration")
}
if entry == nil {
return nil, nil
}
var config connectionConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
return &logical.Response{
Data: structs.New(config).Map(),
}, nil
}
func (b *backend) pathConnectionWrite(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connValue := data.Get("value").(string)
@@ -97,10 +118,10 @@ func (b *backend) pathConnectionWrite(
// Store it
entry, err := logical.StorageEntryJSON("config/connection", connectionConfig{
ConnectionString: connValue,
ConnectionURL: connURL,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
VerifyConnection: verifyConnection,
})
if err != nil {
return nil, err
@@ -116,11 +137,10 @@ func (b *backend) pathConnectionWrite(
}
type connectionConfig struct {
ConnectionURL string `json:"connection_url"`
// Deprecate "value" in coming releases
ConnectionString string `json:"value"`
MaxOpenConnections int `json:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections"`
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
VerifyConnection bool `json:"verify_connection" structs:"verify_connection" mapstructure:"verify_connection"`
}
const pathConfigConnectionHelpSyn = `