Remove prepared stmnts from pgsql physical backend

Prepared statements prevent the use of connection multiplexing software
such as PGBouncer. Even when PGBouncer is configured for [session mode][1]
there's a possibility that a connection to PostgreSQL can be re-used by
different clients.  This leads to errors when clients use session based
features (like prepared statements).

This change removes prepared statements from the PostgreSQL physical
backend. This will allow vault to successfully work in infrastructures
that employ the use of PGBouncer or other connection multiplexing
software.

[1]: https://pgbouncer.github.io/config.html#poolmode
This commit is contained in:
Devin Christensen
2016-05-26 17:07:21 -06:00
parent 9c6aebf1c0
commit 3cbedeaa4d
2 changed files with 21 additions and 37 deletions

View File

@@ -16,7 +16,10 @@ import (
type PostgreSQLBackend struct { type PostgreSQLBackend struct {
table string table string
client *sql.DB client *sql.DB
statements map[string]*sql.Stmt put_query string
get_query string
delete_query string
list_query string
logger *log.Logger logger *log.Logger
} }
@@ -50,11 +53,11 @@ func newPostgreSQLBackend(conf map[string]string, logger *log.Logger) (Backend,
// Setup our put strategy based on the presence or absence of a native // Setup our put strategy based on the presence or absence of a native
// upsert. // upsert.
var put_statement string var put_query string
if upsert_required { if upsert_required {
put_statement = "SELECT vault_kv_put($1, $2, $3, $4)" put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
} else { } else {
put_statement = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" + put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
" ON CONFLICT (path, key) DO " + " ON CONFLICT (path, key) DO " +
" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)" " UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
} }
@@ -63,36 +66,17 @@ func newPostgreSQLBackend(conf map[string]string, logger *log.Logger) (Backend,
m := &PostgreSQLBackend{ m := &PostgreSQLBackend{
table: quoted_table, table: quoted_table,
client: db, client: db,
statements: make(map[string]*sql.Stmt), put_query: put_query,
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
"UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1",
logger: logger, logger: logger,
} }
// Prepare all the statements required
statements := map[string]string{
"put": put_statement,
"get": "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
"delete": "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
"list": "SELECT key FROM " + quoted_table + " WHERE path = $1" +
"UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1",
}
for name, query := range statements {
if err := m.prepare(name, query); err != nil {
return nil, err
}
}
return m, nil return m, nil
} }
// prepare is a helper to prepare a query for future execution
func (m *PostgreSQLBackend) prepare(name, query string) error {
stmt, err := m.client.Prepare(query)
if err != nil {
return fmt.Errorf("failed to prepare '%s': %v", name, err)
}
m.statements[name] = stmt
return nil
}
// splitKey is a helper to split a full path key into individual // splitKey is a helper to split a full path key into individual
// parts: parentPath, path, key // parts: parentPath, path, key
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) { func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
@@ -123,7 +107,7 @@ func (m *PostgreSQLBackend) Put(entry *Entry) error {
parentPath, path, key := m.splitKey(entry.Key) parentPath, path, key := m.splitKey(entry.Key)
_, err := m.statements["put"].Exec(parentPath, path, key, entry.Value) _, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value)
if err != nil { if err != nil {
return err return err
} }
@@ -137,7 +121,7 @@ func (m *PostgreSQLBackend) Get(fullPath string) (*Entry, error) {
_, path, key := m.splitKey(fullPath) _, path, key := m.splitKey(fullPath)
var result []byte var result []byte
err := m.statements["get"].QueryRow(path, key).Scan(&result) err := m.client.QueryRow(m.get_query, path, key).Scan(&result)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
@@ -158,7 +142,7 @@ func (m *PostgreSQLBackend) Delete(fullPath string) error {
_, path, key := m.splitKey(fullPath) _, path, key := m.splitKey(fullPath)
_, err := m.statements["delete"].Exec(path, key) _, err := m.client.Exec(m.delete_query, path, key)
if err != nil { if err != nil {
return err return err
} }
@@ -170,7 +154,7 @@ func (m *PostgreSQLBackend) Delete(fullPath string) error {
func (m *PostgreSQLBackend) List(prefix string) ([]string, error) { func (m *PostgreSQLBackend) List(prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now()) defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
rows, err := m.statements["list"].Query("/" + prefix) rows, err := m.client.Query(m.list_query, "/" + prefix)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -32,7 +32,7 @@ func TestPostgreSQLBackend(t *testing.T) {
defer func() { defer func() {
pg := b.(*PostgreSQLBackend) pg := b.(*PostgreSQLBackend)
_, err := pg.client.Exec("DROP TABLE " + pg.table) _, err := pg.client.Exec("TRUNCATE TABLE " + pg.table)
if err != nil { if err != nil {
t.Fatalf("Failed to drop table: %v", err) t.Fatalf("Failed to drop table: %v", err)
} }