mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Added HA backend for postgres based on dynamodb model (#5731)
Add optional HA support for postgres backend if Postgres version >= 9.5.
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/docker"
|
||||
@@ -30,56 +31,260 @@ func TestPostgreSQLBackend(t *testing.T) {
|
||||
table = "vault_kv_store"
|
||||
}
|
||||
|
||||
hae := os.Getenv("PGHAENABLED")
|
||||
if hae == "" {
|
||||
hae = "true"
|
||||
}
|
||||
|
||||
// Run vault tests
|
||||
logger.Info(fmt.Sprintf("Connection URL: %v", connURL))
|
||||
|
||||
b, err := NewPostgreSQLBackend(map[string]string{
|
||||
b1, err := NewPostgreSQLBackend(map[string]string{
|
||||
"connection_url": connURL,
|
||||
"table": table,
|
||||
"ha_enabled": hae,
|
||||
}, logger)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
pg := b.(*PostgreSQLBackend)
|
||||
|
||||
//Read postgres version to test basic connects works
|
||||
b2, err := NewPostgreSQLBackend(map[string]string{
|
||||
"connection_url": connURL,
|
||||
"table": table,
|
||||
"ha_enabled": hae,
|
||||
}, logger)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create new backend: %v", err)
|
||||
}
|
||||
pg := b1.(*PostgreSQLBackend)
|
||||
|
||||
// Read postgres version to test basic connects works
|
||||
var pgversion string
|
||||
if err = pg.client.QueryRow("SELECT current_setting('server_version_num')").Scan(&pgversion); err != nil {
|
||||
t.Fatalf("Failed to check for Postgres version: %v", err)
|
||||
}
|
||||
logger.Info(fmt.Sprintf("Postgres Version: %v", pgversion))
|
||||
|
||||
//Setup tables and indexes if not exists.
|
||||
createTableSQL := fmt.Sprintf(
|
||||
" CREATE TABLE IF NOT EXISTS %v ( "+
|
||||
" parent_path TEXT COLLATE \"C\" NOT NULL, "+
|
||||
" path TEXT COLLATE \"C\", "+
|
||||
" key TEXT COLLATE \"C\", "+
|
||||
" value BYTEA, "+
|
||||
" CONSTRAINT pkey PRIMARY KEY (path, key) "+
|
||||
" ); ", table)
|
||||
|
||||
_, err = pg.client.Exec(createTableSQL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", table)
|
||||
|
||||
_, err = pg.client.Exec(createIndexSQL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create index: %v", err)
|
||||
}
|
||||
setupDatabaseObjects(t, logger, pg)
|
||||
|
||||
defer func() {
|
||||
pg := b.(*PostgreSQLBackend)
|
||||
pg := b1.(*PostgreSQLBackend)
|
||||
_, err := pg.client.Exec(fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to truncate table: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
physical.ExerciseBackend(t, b)
|
||||
physical.ExerciseBackend_ListPrefix(t, b)
|
||||
logger.Info("Running basic backend tests")
|
||||
physical.ExerciseBackend(t, b1)
|
||||
logger.Info("Running list prefix backend tests")
|
||||
physical.ExerciseBackend_ListPrefix(t, b1)
|
||||
|
||||
ha1, ok := b1.(physical.HABackend)
|
||||
if !ok {
|
||||
t.Fatalf("PostgreSQLDB does not implement HABackend")
|
||||
}
|
||||
|
||||
ha2, ok := b2.(physical.HABackend)
|
||||
if !ok {
|
||||
t.Fatalf("PostgreSQLDB does not implement HABackend")
|
||||
}
|
||||
|
||||
if ha1.HAEnabled() && ha2.HAEnabled() {
|
||||
logger.Info("Running ha backend tests")
|
||||
physical.ExerciseHABackend(t, ha1, ha2)
|
||||
testPostgresSQLLockTTL(t, ha1)
|
||||
testPostgresSQLLockRenewal(t, ha1)
|
||||
}
|
||||
}
|
||||
|
||||
// Similar to testHABackend, but using internal implementation details to
|
||||
// trigger the lock failure scenario by setting the lock renew period for one
|
||||
// of the locks to a higher value than the lock TTL.
|
||||
func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
|
||||
// Set much smaller lock times to speed up the test.
|
||||
lockTTL := 3
|
||||
renewInterval := time.Second * 1
|
||||
watchInterval := time.Second * 1
|
||||
|
||||
// Get the lock
|
||||
origLock, err := ha.LockWith("dynamodbttl", "bar")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
// set the first lock renew period to double the expected TTL.
|
||||
lock := origLock.(*PostgreSQLLock)
|
||||
lock.renewInterval = time.Duration(lockTTL*2) * time.Second
|
||||
lock.ttlSeconds = lockTTL
|
||||
// lock.retryInterval = watchInterval
|
||||
|
||||
// Attempt to lock
|
||||
leaderCh, err := lock.Lock(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if leaderCh == nil {
|
||||
t.Fatalf("failed to get leader ch")
|
||||
}
|
||||
|
||||
// Check the value
|
||||
held, val, err := lock.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
t.Fatalf("should be held")
|
||||
}
|
||||
if val != "bar" {
|
||||
t.Fatalf("bad value: %v", err)
|
||||
}
|
||||
|
||||
// Second acquisition should succeed because the first lock should
|
||||
// not renew within the 3 sec TTL.
|
||||
origLock2, err := ha.LockWith("dynamodbttl", "baz")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
lock2 := origLock2.(*PostgreSQLLock)
|
||||
lock2.renewInterval = renewInterval
|
||||
lock2.ttlSeconds = lockTTL
|
||||
// lock2.retryInterval = watchInterval
|
||||
|
||||
// Cancel attempt in 6 sec so as not to block unit tests forever
|
||||
stopCh := make(chan struct{})
|
||||
time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
|
||||
close(stopCh)
|
||||
})
|
||||
|
||||
// Attempt to lock should work
|
||||
leaderCh2, err := lock2.Lock(stopCh)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if leaderCh2 == nil {
|
||||
t.Fatalf("should get leader ch")
|
||||
}
|
||||
|
||||
// Check the value
|
||||
held, val, err = lock2.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
t.Fatalf("should be held")
|
||||
}
|
||||
if val != "baz" {
|
||||
t.Fatalf("bad value: %v", err)
|
||||
}
|
||||
|
||||
// The first lock should have lost the leader channel
|
||||
leaderChClosed := false
|
||||
blocking := make(chan struct{})
|
||||
// Attempt to read from the leader or the blocking channel, which ever one
|
||||
// happens first.
|
||||
go func() {
|
||||
select {
|
||||
case <-time.After(watchInterval * 3):
|
||||
return
|
||||
case <-leaderCh:
|
||||
leaderChClosed = true
|
||||
close(blocking)
|
||||
case <-blocking:
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
<-blocking
|
||||
if !leaderChClosed {
|
||||
t.Fatalf("original lock did not have its leader channel closed.")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
lock2.Unlock()
|
||||
}
|
||||
|
||||
// Verify that once Unlock is called, we don't keep trying to renew the original
|
||||
// lock.
|
||||
func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) {
|
||||
// Get the lock
|
||||
origLock, err := ha.LockWith("pgrenewal", "bar")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// customize the renewal and watch intervals
|
||||
lock := origLock.(*PostgreSQLLock)
|
||||
// lock.renewInterval = time.Second * 1
|
||||
|
||||
// Attempt to lock
|
||||
leaderCh, err := lock.Lock(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if leaderCh == nil {
|
||||
t.Fatalf("failed to get leader ch")
|
||||
}
|
||||
|
||||
// Check the value
|
||||
held, val, err := lock.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
t.Fatalf("should be held")
|
||||
}
|
||||
if val != "bar" {
|
||||
t.Fatalf("bad value: %v", err)
|
||||
}
|
||||
|
||||
// Release the lock, which will delete the stored item
|
||||
if err := lock.Unlock(); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Wait longer than the renewal time
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
// Attempt to lock with new lock
|
||||
newLock, err := ha.LockWith("pgrenewal", "baz")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Cancel attempt after lock ttl + 1s so as not to block unit tests forever
|
||||
stopCh := make(chan struct{})
|
||||
timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second
|
||||
time.AfterFunc(timeout, func() {
|
||||
t.Logf("giving up on lock attempt after %v", timeout)
|
||||
close(stopCh)
|
||||
})
|
||||
|
||||
// Attempt to lock should work
|
||||
leaderCh2, err := newLock.Lock(stopCh)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if leaderCh2 == nil {
|
||||
t.Fatalf("should get leader ch")
|
||||
}
|
||||
|
||||
// Check the value
|
||||
held, val, err = newLock.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if !held {
|
||||
t.Fatalf("should be held")
|
||||
}
|
||||
if val != "baz" {
|
||||
t.Fatalf("bad value: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
newLock.Unlock()
|
||||
}
|
||||
|
||||
func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retConnString string) {
|
||||
@@ -92,7 +297,7 @@ func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retC
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
//using 11.1 which is currently latest, use hard version for stabillity of tests
|
||||
// using 11.1 which is currently latest, use hard version for stability of tests
|
||||
resource, err := pool.Run("postgres", "11.1", []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start docker Postgres: %s", err)
|
||||
@@ -122,3 +327,42 @@ func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retC
|
||||
|
||||
return cleanup, retConnString
|
||||
}
|
||||
|
||||
func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) {
|
||||
var err error
|
||||
// Setup tables and indexes if not exists.
|
||||
createTableSQL := fmt.Sprintf(
|
||||
" CREATE TABLE IF NOT EXISTS %v ( "+
|
||||
" parent_path TEXT COLLATE \"C\" NOT NULL, "+
|
||||
" path TEXT COLLATE \"C\", "+
|
||||
" key TEXT COLLATE \"C\", "+
|
||||
" value BYTEA, "+
|
||||
" CONSTRAINT pkey PRIMARY KEY (path, key) "+
|
||||
" ); ", pg.table)
|
||||
|
||||
_, err = pg.client.Exec(createTableSQL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table)
|
||||
|
||||
_, err = pg.client.Exec(createIndexSQL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create index: %v", err)
|
||||
}
|
||||
|
||||
createHaTableSQL :=
|
||||
" CREATE TABLE IF NOT EXISTS vault_ha_locks ( " +
|
||||
" ha_key TEXT COLLATE \"C\" NOT NULL, " +
|
||||
" ha_identity TEXT COLLATE \"C\" NOT NULL, " +
|
||||
" ha_value TEXT COLLATE \"C\", " +
|
||||
" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
|
||||
" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
|
||||
" ); "
|
||||
|
||||
_, err = pg.client.Exec(createHaTableSQL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create hatable: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user