diff --git a/physical/postgresql.go b/physical/postgresql.go index 3d22f8fc36..2fe2b4fafe 100644 --- a/physical/postgresql.go +++ b/physical/postgresql.go @@ -14,9 +14,12 @@ import ( // PostgreSQL Backend is a physical backend that stores data // within a PostgreSQL database. type PostgreSQLBackend struct { - table string - client *sql.DB - statements map[string]*sql.Stmt + table string + client *sql.DB + put_query string + get_query string + delete_query string + list_query string logger *log.Logger } @@ -50,49 +53,30 @@ func newPostgreSQLBackend(conf map[string]string, logger *log.Logger) (Backend, // Setup our put strategy based on the presence or absence of a native // upsert. - var put_statement string + var put_query string if upsert_required { - put_statement = "SELECT vault_kv_put($1, $2, $3, $4)" + put_query = "SELECT vault_kv_put($1, $2, $3, $4)" } else { - put_statement = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" + + put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" + " ON CONFLICT (path, key) DO " + " UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)" } // Setup the backend. m := &PostgreSQLBackend{ - table: quoted_table, - client: db, - statements: make(map[string]*sql.Stmt), + table: quoted_table, + client: db, + put_query: put_query, + get_query: "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2", + delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2", + list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" + + "UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1", logger: logger, } - // Prepare all the statements required - statements := map[string]string{ - "put": put_statement, - "get": "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2", - "delete": "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2", - "list": "SELECT key FROM " + quoted_table + " WHERE path = $1" + - "UNION SELECT substr(path, length($1)+1) FROM " + quoted_table + "WHERE parent_path = $1", - } - for name, query := range statements { - if err := m.prepare(name, query); err != nil { - return nil, err - } - } return m, nil } -// prepare is a helper to prepare a query for future execution -func (m *PostgreSQLBackend) prepare(name, query string) error { - stmt, err := m.client.Prepare(query) - if err != nil { - return fmt.Errorf("failed to prepare '%s': %v", name, err) - } - m.statements[name] = stmt - return nil -} - // splitKey is a helper to split a full path key into individual // parts: parentPath, path, key func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) { @@ -123,7 +107,7 @@ func (m *PostgreSQLBackend) Put(entry *Entry) error { parentPath, path, key := m.splitKey(entry.Key) - _, err := m.statements["put"].Exec(parentPath, path, key, entry.Value) + _, err := m.client.Exec(m.put_query, parentPath, path, key, entry.Value) if err != nil { return err } @@ -137,7 +121,7 @@ func (m *PostgreSQLBackend) Get(fullPath string) (*Entry, error) { _, path, key := m.splitKey(fullPath) var result []byte - err := m.statements["get"].QueryRow(path, key).Scan(&result) + err := m.client.QueryRow(m.get_query, path, key).Scan(&result) if err == sql.ErrNoRows { return nil, nil } @@ -158,7 +142,7 @@ func (m *PostgreSQLBackend) Delete(fullPath string) error { _, path, key := m.splitKey(fullPath) - _, err := m.statements["delete"].Exec(path, key) + _, err := m.client.Exec(m.delete_query, path, key) if err != nil { return err } @@ -170,7 +154,7 @@ func (m *PostgreSQLBackend) Delete(fullPath string) error { func (m *PostgreSQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now()) - rows, err := m.statements["list"].Query("/" + prefix) + rows, err := m.client.Query(m.list_query, "/" + prefix) if err != nil { return nil, err } diff --git a/physical/postgresql_test.go b/physical/postgresql_test.go index 92b014ec2d..026ed29619 100644 --- a/physical/postgresql_test.go +++ b/physical/postgresql_test.go @@ -32,7 +32,7 @@ func TestPostgreSQLBackend(t *testing.T) { defer func() { pg := b.(*PostgreSQLBackend) - _, err := pg.client.Exec("DROP TABLE " + pg.table) + _, err := pg.client.Exec("TRUNCATE TABLE " + pg.table) if err != nil { t.Fatalf("Failed to drop table: %v", err) }