mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-01 02:57:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			369 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			369 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package postgresql
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"os"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	log "github.com/hashicorp/go-hclog"
 | |
| 	"github.com/hashicorp/vault/helper/testhelpers/docker"
 | |
| 	"github.com/hashicorp/vault/sdk/helper/logging"
 | |
| 	"github.com/hashicorp/vault/sdk/physical"
 | |
| 	"github.com/ory/dockertest"
 | |
| 
 | |
| 	_ "github.com/lib/pq"
 | |
| )
 | |
| 
 | |
| func TestPostgreSQLBackend(t *testing.T) {
 | |
| 	logger := logging.NewVaultLogger(log.Debug)
 | |
| 
 | |
| 	// Use docker as pg backend if no url is provided via environment variables
 | |
| 	var cleanup func()
 | |
| 	connURL := os.Getenv("PGURL")
 | |
| 	if connURL == "" {
 | |
| 		cleanup, connURL = prepareTestContainer(t, logger)
 | |
| 		defer cleanup()
 | |
| 	}
 | |
| 
 | |
| 	table := os.Getenv("PGTABLE")
 | |
| 	if table == "" {
 | |
| 		table = "vault_kv_store"
 | |
| 	}
 | |
| 
 | |
| 	hae := os.Getenv("PGHAENABLED")
 | |
| 	if hae == "" {
 | |
| 		hae = "true"
 | |
| 	}
 | |
| 
 | |
| 	// Run vault tests
 | |
| 	logger.Info(fmt.Sprintf("Connection URL: %v", connURL))
 | |
| 
 | |
| 	b1, err := NewPostgreSQLBackend(map[string]string{
 | |
| 		"connection_url": connURL,
 | |
| 		"table":          table,
 | |
| 		"ha_enabled":     hae,
 | |
| 	}, logger)
 | |
| 
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to create new backend: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	b2, err := NewPostgreSQLBackend(map[string]string{
 | |
| 		"connection_url": connURL,
 | |
| 		"table":          table,
 | |
| 		"ha_enabled":     hae,
 | |
| 	}, logger)
 | |
| 
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to create new backend: %v", err)
 | |
| 	}
 | |
| 	pg := b1.(*PostgreSQLBackend)
 | |
| 
 | |
| 	// Read postgres version to test basic connects works
 | |
| 	var pgversion string
 | |
| 	if err = pg.client.QueryRow("SELECT current_setting('server_version_num')").Scan(&pgversion); err != nil {
 | |
| 		t.Fatalf("Failed to check for Postgres version: %v", err)
 | |
| 	}
 | |
| 	logger.Info(fmt.Sprintf("Postgres Version: %v", pgversion))
 | |
| 
 | |
| 	setupDatabaseObjects(t, logger, pg)
 | |
| 
 | |
| 	defer func() {
 | |
| 		pg := b1.(*PostgreSQLBackend)
 | |
| 		_, err := pg.client.Exec(fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table))
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("Failed to truncate table: %v", err)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	logger.Info("Running basic backend tests")
 | |
| 	physical.ExerciseBackend(t, b1)
 | |
| 	logger.Info("Running list prefix backend tests")
 | |
| 	physical.ExerciseBackend_ListPrefix(t, b1)
 | |
| 
 | |
| 	ha1, ok := b1.(physical.HABackend)
 | |
| 	if !ok {
 | |
| 		t.Fatalf("PostgreSQLDB does not implement HABackend")
 | |
| 	}
 | |
| 
 | |
| 	ha2, ok := b2.(physical.HABackend)
 | |
| 	if !ok {
 | |
| 		t.Fatalf("PostgreSQLDB does not implement HABackend")
 | |
| 	}
 | |
| 
 | |
| 	if ha1.HAEnabled() && ha2.HAEnabled() {
 | |
| 		logger.Info("Running ha backend tests")
 | |
| 		physical.ExerciseHABackend(t, ha1, ha2)
 | |
| 		testPostgresSQLLockTTL(t, ha1)
 | |
| 		testPostgresSQLLockRenewal(t, ha1)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Similar to testHABackend, but using internal implementation details to
 | |
| // trigger the lock failure scenario by setting the lock renew period for one
 | |
| // of the locks to a higher value than the lock TTL.
 | |
| func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
 | |
| 	// Set much smaller lock times to speed up the test.
 | |
| 	lockTTL := 3
 | |
| 	renewInterval := time.Second * 1
 | |
| 	watchInterval := time.Second * 1
 | |
| 
 | |
| 	// Get the lock
 | |
| 	origLock, err := ha.LockWith("dynamodbttl", "bar")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	// set the first lock renew period to double the expected TTL.
 | |
| 	lock := origLock.(*PostgreSQLLock)
 | |
| 	lock.renewInterval = time.Duration(lockTTL*2) * time.Second
 | |
| 	lock.ttlSeconds = lockTTL
 | |
| 	// lock.retryInterval = watchInterval
 | |
| 
 | |
| 	// Attempt to lock
 | |
| 	leaderCh, err := lock.Lock(nil)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if leaderCh == nil {
 | |
| 		t.Fatalf("failed to get leader ch")
 | |
| 	}
 | |
| 
 | |
| 	// Check the value
 | |
| 	held, val, err := lock.Value()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if !held {
 | |
| 		t.Fatalf("should be held")
 | |
| 	}
 | |
| 	if val != "bar" {
 | |
| 		t.Fatalf("bad value: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Second acquisition should succeed because the first lock should
 | |
| 	// not renew within the 3 sec TTL.
 | |
| 	origLock2, err := ha.LockWith("dynamodbttl", "baz")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	lock2 := origLock2.(*PostgreSQLLock)
 | |
| 	lock2.renewInterval = renewInterval
 | |
| 	lock2.ttlSeconds = lockTTL
 | |
| 	// lock2.retryInterval = watchInterval
 | |
| 
 | |
| 	// Cancel attempt in 6 sec so as not to block unit tests forever
 | |
| 	stopCh := make(chan struct{})
 | |
| 	time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
 | |
| 		close(stopCh)
 | |
| 	})
 | |
| 
 | |
| 	// Attempt to lock should work
 | |
| 	leaderCh2, err := lock2.Lock(stopCh)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if leaderCh2 == nil {
 | |
| 		t.Fatalf("should get leader ch")
 | |
| 	}
 | |
| 
 | |
| 	// Check the value
 | |
| 	held, val, err = lock2.Value()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if !held {
 | |
| 		t.Fatalf("should be held")
 | |
| 	}
 | |
| 	if val != "baz" {
 | |
| 		t.Fatalf("bad value: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// The first lock should have lost the leader channel
 | |
| 	leaderChClosed := false
 | |
| 	blocking := make(chan struct{})
 | |
| 	// Attempt to read from the leader or the blocking channel, which ever one
 | |
| 	// happens first.
 | |
| 	go func() {
 | |
| 		select {
 | |
| 		case <-time.After(watchInterval * 3):
 | |
| 			return
 | |
| 		case <-leaderCh:
 | |
| 			leaderChClosed = true
 | |
| 			close(blocking)
 | |
| 		case <-blocking:
 | |
| 			return
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	<-blocking
 | |
| 	if !leaderChClosed {
 | |
| 		t.Fatalf("original lock did not have its leader channel closed.")
 | |
| 	}
 | |
| 
 | |
| 	// Cleanup
 | |
| 	lock2.Unlock()
 | |
| }
 | |
| 
 | |
| // Verify that once Unlock is called, we don't keep trying to renew the original
 | |
| // lock.
 | |
| func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) {
 | |
| 	// Get the lock
 | |
| 	origLock, err := ha.LockWith("pgrenewal", "bar")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// customize the renewal and watch intervals
 | |
| 	lock := origLock.(*PostgreSQLLock)
 | |
| 	// lock.renewInterval = time.Second * 1
 | |
| 
 | |
| 	// Attempt to lock
 | |
| 	leaderCh, err := lock.Lock(nil)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if leaderCh == nil {
 | |
| 		t.Fatalf("failed to get leader ch")
 | |
| 	}
 | |
| 
 | |
| 	// Check the value
 | |
| 	held, val, err := lock.Value()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if !held {
 | |
| 		t.Fatalf("should be held")
 | |
| 	}
 | |
| 	if val != "bar" {
 | |
| 		t.Fatalf("bad value: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Release the lock, which will delete the stored item
 | |
| 	if err := lock.Unlock(); err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Wait longer than the renewal time
 | |
| 	time.Sleep(1500 * time.Millisecond)
 | |
| 
 | |
| 	// Attempt to lock with new lock
 | |
| 	newLock, err := ha.LockWith("pgrenewal", "baz")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Cancel attempt after lock ttl + 1s so as not to block unit tests forever
 | |
| 	stopCh := make(chan struct{})
 | |
| 	timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second
 | |
| 	time.AfterFunc(timeout, func() {
 | |
| 		t.Logf("giving up on lock attempt after %v", timeout)
 | |
| 		close(stopCh)
 | |
| 	})
 | |
| 
 | |
| 	// Attempt to lock should work
 | |
| 	leaderCh2, err := newLock.Lock(stopCh)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if leaderCh2 == nil {
 | |
| 		t.Fatalf("should get leader ch")
 | |
| 	}
 | |
| 
 | |
| 	// Check the value
 | |
| 	held, val, err = newLock.Value()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if !held {
 | |
| 		t.Fatalf("should be held")
 | |
| 	}
 | |
| 	if val != "baz" {
 | |
| 		t.Fatalf("bad value: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Cleanup
 | |
| 	newLock.Unlock()
 | |
| }
 | |
| 
 | |
| func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retConnString string) {
 | |
| 	// If environment variable is set, use this connectionstring without starting docker container
 | |
| 	if os.Getenv("PGURL") != "" {
 | |
| 		return func() {}, os.Getenv("PGURL")
 | |
| 	}
 | |
| 
 | |
| 	pool, err := dockertest.NewPool("")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to connect to docker: %s", err)
 | |
| 	}
 | |
| 	// using 11.1 which is currently latest, use hard version for stability of tests
 | |
| 	resource, err := pool.Run("postgres", "11.1", []string{})
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Could not start docker Postgres: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	retConnString = fmt.Sprintf("postgres://postgres@localhost:%v/postgres?sslmode=disable", resource.GetPort("5432/tcp"))
 | |
| 
 | |
| 	cleanup = func() {
 | |
| 		docker.CleanupResource(t, pool, resource)
 | |
| 	}
 | |
| 
 | |
| 	// Provide a test function to the pool to test if docker instance service is up.
 | |
| 	// We try to setup a pg backend as test for successful connect
 | |
| 	// exponential backoff-retry, because the dockerinstance may not be able to accept
 | |
| 	// connections yet, test by trying to setup a postgres backend, max-timeout is 60s
 | |
| 	if err := pool.Retry(func() error {
 | |
| 		var err error
 | |
| 		_, err = NewPostgreSQLBackend(map[string]string{
 | |
| 			"connection_url": retConnString,
 | |
| 		}, logger)
 | |
| 		return err
 | |
| 
 | |
| 	}); err != nil {
 | |
| 		cleanup()
 | |
| 		t.Fatalf("Could not connect to docker: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	return cleanup, retConnString
 | |
| }
 | |
| 
 | |
| func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) {
 | |
| 	var err error
 | |
| 	// Setup tables and indexes if not exists.
 | |
| 	createTableSQL := fmt.Sprintf(
 | |
| 		"  CREATE TABLE IF NOT EXISTS %v ( "+
 | |
| 			"  parent_path TEXT COLLATE \"C\" NOT NULL, "+
 | |
| 			"  path        TEXT COLLATE \"C\", "+
 | |
| 			"  key         TEXT COLLATE \"C\", "+
 | |
| 			"  value       BYTEA, "+
 | |
| 			"  CONSTRAINT pkey PRIMARY KEY (path, key) "+
 | |
| 			" ); ", pg.table)
 | |
| 
 | |
| 	_, err = pg.client.Exec(createTableSQL)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to create table: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table)
 | |
| 
 | |
| 	_, err = pg.client.Exec(createIndexSQL)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to create index: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	createHaTableSQL :=
 | |
| 		" CREATE TABLE IF NOT EXISTS vault_ha_locks ( " +
 | |
| 			" ha_key                                      TEXT COLLATE \"C\" NOT NULL, " +
 | |
| 			" ha_identity                                 TEXT COLLATE \"C\" NOT NULL, " +
 | |
| 			" ha_value                                    TEXT COLLATE \"C\", " +
 | |
| 			" valid_until                                 TIMESTAMP WITH TIME ZONE NOT NULL, " +
 | |
| 			" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
 | |
| 			" ); "
 | |
| 
 | |
| 	_, err = pg.client.Exec(createHaTableSQL)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to create hatable: %v", err)
 | |
| 	}
 | |
| }
 | 
