Remove prepared stmnts from pgsql physical backend

Prepared statements prevent the use of connection multiplexing software
such as PGBouncer. Even when PGBouncer is configured for [session mode][1]
there's a possibility that a connection to PostgreSQL can be re-used by
different clients.  This leads to errors when clients use session based
features (like prepared statements).

This change removes prepared statements from the PostgreSQL physical
backend. This will allow vault to successfully work in infrastructures
that employ the use of PGBouncer or other connection multiplexing
software.

[1]: https://pgbouncer.github.io/config.html#poolmode
This commit is contained in:
Devin Christensen
2016-05-26 17:07:21 -06:00
parent 9c6aebf1c0
commit 3cbedeaa4d
2 changed files with 21 additions and 37 deletions

View File

@@ -14,9 +14,12 @@ import (
// PostgreSQL Backend is a physical backend that stores data
// within a PostgreSQL database.
type PostgreSQLBackend struct {
table string
client *sql.DB
statements map[string]*sql.Stmt
table string
client *sql.DB
put_query string
get_query string
delete_query string
list_query string
logger *log.Logger
}
@@ -50,49 +53,30 @@ func newPostgreSQLBackend(conf map[string]string, logger *log.Logger) (Backend,
// Setup our put strategy based on the presence or absence of a native
// upsert.
var put_statement string
var put_query string
if upsert_required {
put_statement = "SELECT vault_kv_put($1, $2, $3, $4)"
put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
} else {
put_statement = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
" ON CONFLICT (path, key) DO " +
" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
}
// Setup the backend.
m := &PostgreSQLBackend{
table: quoted_table,
client: db,
statements: make(map[string]*sql.Stmt),
table: quoted_table,
client: db,
put_query: put_query,
get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
"UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1",
logger: logger,
}
// Prepare all the statements required
statements := map[string]string{
"put": put_statement,
"get": "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
"delete": "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
"list": "SELECT key FROM " + quoted_table + " WHERE path = $1" +
"UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1",
}
for name, query := range statements {
if err := m.prepare(name, query); err != nil {
return nil, err
}
}
return m, nil
}
// prepare is a helper to prepare a query for future execution
func (m *PostgreSQLBackend) prepare(name, query string) error {
stmt, err := m.client.Prepare(query)
if err != nil {
return fmt.Errorf("failed to prepare '%s': %v", name, err)
}
m.statements[name] = stmt
return nil
}
// splitKey is a helper to split a full path key into individual
// parts: parentPath, path, key
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
@@ -123,7 +107,7 @@ func (m *PostgreSQLBackend) Put(entry *Entry) error {
parentPath, path, key := m.splitKey(entry.Key)
_, err := m.statements["put"].Exec(parentPath, path, key, entry.Value)
_, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value)
if err != nil {
return err
}
@@ -137,7 +121,7 @@ func (m *PostgreSQLBackend) Get(fullPath string) (*Entry, error) {
_, path, key := m.splitKey(fullPath)
var result []byte
err := m.statements["get"].QueryRow(path, key).Scan(&result)
err := m.client.QueryRow(m.get_query, path, key).Scan(&result)
if err == sql.ErrNoRows {
return nil, nil
}
@@ -158,7 +142,7 @@ func (m *PostgreSQLBackend) Delete(fullPath string) error {
_, path, key := m.splitKey(fullPath)
_, err := m.statements["delete"].Exec(path, key)
_, err := m.client.Exec(m.delete_query, path, key)
if err != nil {
return err
}
@@ -170,7 +154,7 @@ func (m *PostgreSQLBackend) Delete(fullPath string) error {
func (m *PostgreSQLBackend) List(prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
rows, err := m.statements["list"].Query("/" + prefix)
rows, err := m.client.Query(m.list_query, "/" + prefix)
if err != nil {
return nil, err
}