mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			475 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			475 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) HashiCorp, Inc.
 | |
| // SPDX-License-Identifier: MPL-2.0
 | |
| 
 | |
| package postgresql
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"os"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/armon/go-metrics"
 | |
| 	log "github.com/hashicorp/go-hclog"
 | |
| 	"github.com/hashicorp/go-uuid"
 | |
| 	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
 | |
| 	"github.com/hashicorp/vault/sdk/physical"
 | |
| 	_ "github.com/jackc/pgx/v4/stdlib"
 | |
| )
 | |
| 
 | |
| const (
 | |
| 
 | |
| 	// The lock TTL matches the default that Consul API uses, 15 seconds.
 | |
| 	// Used as part of SQL commands to set/extend lock expiry time relative to
 | |
| 	// database clock.
 | |
| 	PostgreSQLLockTTLSeconds = 15
 | |
| 
 | |
| 	// The amount of time to wait between the lock renewals
 | |
| 	PostgreSQLLockRenewInterval = 5 * time.Second
 | |
| 
 | |
| 	// PostgreSQLLockRetryInterval is the amount of time to wait
 | |
| 	// if a lock fails before trying again.
 | |
| 	PostgreSQLLockRetryInterval = time.Second
 | |
| )
 | |
| 
 | |
| // Verify PostgreSQLBackend satisfies the correct interfaces
 | |
| var _ physical.Backend = (*PostgreSQLBackend)(nil)
 | |
| 
 | |
| // HA backend was implemented based on the DynamoDB backend pattern
 | |
| // With distinction using central postgres clock, hereby avoiding
 | |
| // possible issues with multiple clocks
 | |
| var (
 | |
| 	_ physical.HABackend = (*PostgreSQLBackend)(nil)
 | |
| 	_ physical.Lock      = (*PostgreSQLLock)(nil)
 | |
| )
 | |
| 
 | |
| // PostgreSQL Backend is a physical backend that stores data
 | |
| // within a PostgreSQL database.
 | |
| type PostgreSQLBackend struct {
 | |
| 	table        string
 | |
| 	client       *sql.DB
 | |
| 	put_query    string
 | |
| 	get_query    string
 | |
| 	delete_query string
 | |
| 	list_query   string
 | |
| 
 | |
| 	ha_table                 string
 | |
| 	haGetLockValueQuery      string
 | |
| 	haUpsertLockIdentityExec string
 | |
| 	haDeleteLockExec         string
 | |
| 
 | |
| 	haEnabled  bool
 | |
| 	logger     log.Logger
 | |
| 	permitPool *physical.PermitPool
 | |
| }
 | |
| 
 | |
| // PostgreSQLLock implements a lock using an PostgreSQL client.
 | |
| type PostgreSQLLock struct {
 | |
| 	backend    *PostgreSQLBackend
 | |
| 	value, key string
 | |
| 	identity   string
 | |
| 	lock       sync.Mutex
 | |
| 
 | |
| 	renewTicker *time.Ticker
 | |
| 
 | |
| 	// ttlSeconds is how long a lock is valid for
 | |
| 	ttlSeconds int
 | |
| 
 | |
| 	// renewInterval is how much time to wait between lock renewals.  must be << ttl
 | |
| 	renewInterval time.Duration
 | |
| 
 | |
| 	// retryInterval is how much time to wait between attempts to grab the lock
 | |
| 	retryInterval time.Duration
 | |
| }
 | |
| 
 | |
| // NewPostgreSQLBackend constructs a PostgreSQL backend using the given
 | |
| // API client, server address, credentials, and database.
 | |
| func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
 | |
| 	// Get the PostgreSQL credentials to perform read/write operations.
 | |
| 	connURL := connectionURL(conf)
 | |
| 	if connURL == "" {
 | |
| 		return nil, fmt.Errorf("missing connection_url")
 | |
| 	}
 | |
| 
 | |
| 	unquoted_table, ok := conf["table"]
 | |
| 	if !ok {
 | |
| 		unquoted_table = "vault_kv_store"
 | |
| 	}
 | |
| 	quoted_table := dbutil.QuoteIdentifier(unquoted_table)
 | |
| 
 | |
| 	maxParStr, ok := conf["max_parallel"]
 | |
| 	var maxParInt int
 | |
| 	var err error
 | |
| 	if ok {
 | |
| 		maxParInt, err = strconv.Atoi(maxParStr)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
 | |
| 		}
 | |
| 		if logger.IsDebug() {
 | |
| 			logger.Debug("max_parallel set", "max_parallel", maxParInt)
 | |
| 		}
 | |
| 	} else {
 | |
| 		maxParInt = physical.DefaultParallelOperations
 | |
| 	}
 | |
| 
 | |
| 	maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"]
 | |
| 	var maxIdleConns int
 | |
| 	if maxIdleConnsIsSet {
 | |
| 		maxIdleConns, err = strconv.Atoi(maxIdleConnsStr)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
 | |
| 		}
 | |
| 		if logger.IsDebug() {
 | |
| 			logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Create PostgreSQL handle for the database.
 | |
| 	db, err := sql.Open("pgx", connURL)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to connect to postgres: %w", err)
 | |
| 	}
 | |
| 	db.SetMaxOpenConns(maxParInt)
 | |
| 
 | |
| 	if maxIdleConnsIsSet {
 | |
| 		db.SetMaxIdleConns(maxIdleConns)
 | |
| 	}
 | |
| 
 | |
| 	// Determine if we should use a function to work around lack of upsert (versions < 9.5)
 | |
| 	var upsertAvailable bool
 | |
| 	upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500"
 | |
| 	if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil {
 | |
| 		return nil, fmt.Errorf("failed to check for native upsert: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	if !upsertAvailable && conf["ha_enabled"] == "true" {
 | |
| 		return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5")
 | |
| 	}
 | |
| 
 | |
| 	// Setup our put strategy based on the presence or absence of a native
 | |
| 	// upsert.
 | |
| 	var put_query string
 | |
| 	if !upsertAvailable {
 | |
| 		put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
 | |
| 	} else {
 | |
| 		put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
 | |
| 			" ON CONFLICT (path, key) DO " +
 | |
| 			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
 | |
| 	}
 | |
| 
 | |
| 	unquoted_ha_table, ok := conf["ha_table"]
 | |
| 	if !ok {
 | |
| 		unquoted_ha_table = "vault_ha_locks"
 | |
| 	}
 | |
| 	quoted_ha_table := dbutil.QuoteIdentifier(unquoted_ha_table)
 | |
| 
 | |
| 	// Setup the backend.
 | |
| 	m := &PostgreSQLBackend{
 | |
| 		table:        quoted_table,
 | |
| 		client:       db,
 | |
| 		put_query:    put_query,
 | |
| 		get_query:    "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
 | |
| 		delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
 | |
| 		list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
 | |
| 			" UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table +
 | |
| 			" WHERE parent_path LIKE $1 || '%'",
 | |
| 		haGetLockValueQuery:
 | |
| 		// only read non expired data
 | |
| 		" SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ",
 | |
| 		haUpsertLockIdentityExec:
 | |
| 		// $1=identity $2=ha_key $3=ha_value $4=TTL in seconds
 | |
| 		// update either steal expired lock OR update expiry for lock owned by me
 | |
| 		" INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds'  ) " +
 | |
| 			" ON CONFLICT (ha_key) DO " +
 | |
| 			" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " +
 | |
| 			" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
 | |
| 			" (t.ha_identity = $1 AND t.ha_key = $2)  ",
 | |
| 		haDeleteLockExec:
 | |
| 		// $1=ha_identity $2=ha_key
 | |
| 		" DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ",
 | |
| 		logger:     logger,
 | |
| 		permitPool: physical.NewPermitPool(maxParInt),
 | |
| 		haEnabled:  conf["ha_enabled"] == "true",
 | |
| 	}
 | |
| 
 | |
| 	return m, nil
 | |
| }
 | |
| 
 | |
// connectionURL returns the PostgreSQL connection URL to use. A non-empty
// VAULT_PG_CONNECTION_URL environment variable takes precedence over the
// connection_url entry in conf. The result is "" when neither is set; the
// caller is responsible for treating that as a missing required setting.
func connectionURL(conf map[string]string) string {
	if fromEnv := os.Getenv("VAULT_PG_CONNECTION_URL"); fromEnv != "" {
		return fromEnv
	}
	return conf["connection_url"]
}
 | |
| 
 | |
| // splitKey is a helper to split a full path key into individual
 | |
| // parts: parentPath, path, key
 | |
| func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
 | |
| 	var parentPath string
 | |
| 	var path string
 | |
| 
 | |
| 	pieces := strings.Split(fullPath, "/")
 | |
| 	depth := len(pieces)
 | |
| 	key := pieces[depth-1]
 | |
| 
 | |
| 	if depth == 1 {
 | |
| 		parentPath = ""
 | |
| 		path = "/"
 | |
| 	} else if depth == 2 {
 | |
| 		parentPath = "/"
 | |
| 		path = "/" + pieces[0] + "/"
 | |
| 	} else {
 | |
| 		parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
 | |
| 		path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
 | |
| 	}
 | |
| 
 | |
| 	return parentPath, path, key
 | |
| }
 | |
| 
 | |
| // Put is used to insert or update an entry.
 | |
| func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
 | |
| 	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())
 | |
| 
 | |
| 	m.permitPool.Acquire()
 | |
| 	defer m.permitPool.Release()
 | |
| 
 | |
| 	parentPath, path, key := m.splitKey(entry.Key)
 | |
| 
 | |
| 	_, err := m.client.ExecContext(ctx, m.put_query, parentPath, path, key, entry.Value)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Get is used to fetch and entry.
 | |
| func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
 | |
| 	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())
 | |
| 
 | |
| 	m.permitPool.Acquire()
 | |
| 	defer m.permitPool.Release()
 | |
| 
 | |
| 	_, path, key := m.splitKey(fullPath)
 | |
| 
 | |
| 	var result []byte
 | |
| 	err := m.client.QueryRowContext(ctx, m.get_query, path, key).Scan(&result)
 | |
| 	if err == sql.ErrNoRows {
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	ent := &physical.Entry{
 | |
| 		Key:   fullPath,
 | |
| 		Value: result,
 | |
| 	}
 | |
| 	return ent, nil
 | |
| }
 | |
| 
 | |
| // Delete is used to permanently delete an entry
 | |
| func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
 | |
| 	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())
 | |
| 
 | |
| 	m.permitPool.Acquire()
 | |
| 	defer m.permitPool.Release()
 | |
| 
 | |
| 	_, path, key := m.splitKey(fullPath)
 | |
| 
 | |
| 	_, err := m.client.ExecContext(ctx, m.delete_query, path, key)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // List is used to list all the keys under a given
 | |
| // prefix, up to the next prefix.
 | |
| func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
 | |
| 	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())
 | |
| 
 | |
| 	m.permitPool.Acquire()
 | |
| 	defer m.permitPool.Release()
 | |
| 
 | |
| 	rows, err := m.client.QueryContext(ctx, m.list_query, "/"+prefix)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	defer rows.Close()
 | |
| 
 | |
| 	var keys []string
 | |
| 	for rows.Next() {
 | |
| 		var key string
 | |
| 		err = rows.Scan(&key)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("failed to scan rows: %w", err)
 | |
| 		}
 | |
| 
 | |
| 		keys = append(keys, key)
 | |
| 	}
 | |
| 
 | |
| 	return keys, nil
 | |
| }
 | |
| 
 | |
| // LockWith is used for mutual exclusion based on the given key.
 | |
| func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) {
 | |
| 	identity, err := uuid.GenerateUUID()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return &PostgreSQLLock{
 | |
| 		backend:       p,
 | |
| 		key:           key,
 | |
| 		value:         value,
 | |
| 		identity:      identity,
 | |
| 		ttlSeconds:    PostgreSQLLockTTLSeconds,
 | |
| 		renewInterval: PostgreSQLLockRenewInterval,
 | |
| 		retryInterval: PostgreSQLLockRetryInterval,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func (p *PostgreSQLBackend) HAEnabled() bool {
 | |
| 	return p.haEnabled
 | |
| }
 | |
| 
 | |
| // Lock tries to acquire the lock by repeatedly trying to create a record in the
 | |
| // PostgreSQL table. It will block until either the stop channel is closed or
 | |
| // the lock could be acquired successfully. The returned channel will be closed
 | |
| // once the lock in the PostgreSQL table cannot be renewed, either due to an
 | |
| // error speaking to PostgreSQL or because someone else has taken it.
 | |
| func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
 | |
| 	l.lock.Lock()
 | |
| 	defer l.lock.Unlock()
 | |
| 
 | |
| 	var (
 | |
| 		success = make(chan struct{})
 | |
| 		errors  = make(chan error)
 | |
| 		leader  = make(chan struct{})
 | |
| 	)
 | |
| 	// try to acquire the lock asynchronously
 | |
| 	go l.tryToLock(stopCh, success, errors)
 | |
| 
 | |
| 	select {
 | |
| 	case <-success:
 | |
| 		// after acquiring it successfully, we must renew the lock periodically
 | |
| 		l.renewTicker = time.NewTicker(l.renewInterval)
 | |
| 		go l.periodicallyRenewLock(leader)
 | |
| 	case err := <-errors:
 | |
| 		return nil, err
 | |
| 	case <-stopCh:
 | |
| 		return nil, nil
 | |
| 	}
 | |
| 
 | |
| 	return leader, nil
 | |
| }
 | |
| 
 | |
| // Unlock releases the lock by deleting the lock record from the
 | |
| // PostgreSQL table.
 | |
| func (l *PostgreSQLLock) Unlock() error {
 | |
| 	pg := l.backend
 | |
| 	pg.permitPool.Acquire()
 | |
| 	defer pg.permitPool.Release()
 | |
| 
 | |
| 	if l.renewTicker != nil {
 | |
| 		l.renewTicker.Stop()
 | |
| 	}
 | |
| 
 | |
| 	// Delete lock owned by me
 | |
| 	_, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // Value checks whether or not the lock is held by any instance of PostgreSQLLock,
 | |
| // including this one, and returns the current value.
 | |
| func (l *PostgreSQLLock) Value() (bool, string, error) {
 | |
| 	pg := l.backend
 | |
| 	pg.permitPool.Acquire()
 | |
| 	defer pg.permitPool.Release()
 | |
| 	var result string
 | |
| 	err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result)
 | |
| 
 | |
| 	switch err {
 | |
| 	case nil:
 | |
| 		return true, result, nil
 | |
| 	case sql.ErrNoRows:
 | |
| 		return false, "", nil
 | |
| 	default:
 | |
| 		return false, "", err
 | |
| 
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // tryToLock tries to create a new item in PostgreSQL every `retryInterval`.
 | |
| // As long as the item cannot be created (because it already exists), it will
 | |
| // be retried. If the operation fails due to an error, it is sent to the errors
 | |
| // channel. When the lock could be acquired successfully, the success channel
 | |
| // is closed.
 | |
| func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
 | |
| 	ticker := time.NewTicker(l.retryInterval)
 | |
| 	defer ticker.Stop()
 | |
| 
 | |
| 	for {
 | |
| 		select {
 | |
| 		case <-stop:
 | |
| 			return
 | |
| 		case <-ticker.C:
 | |
| 			gotlock, err := l.writeItem()
 | |
| 			switch {
 | |
| 			case err != nil:
 | |
| 				errors <- err
 | |
| 				return
 | |
| 			case gotlock:
 | |
| 				close(success)
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) {
 | |
| 	for range l.renewTicker.C {
 | |
| 		gotlock, err := l.writeItem()
 | |
| 		if err != nil || !gotlock {
 | |
| 			close(done)
 | |
| 			l.renewTicker.Stop()
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Attempts to put/update the PostgreSQL item using condition expressions to
 | |
| // evaluate the TTL.  Returns true if the lock was obtained, false if not.
 | |
| // If false error may be nil or non-nil: nil indicates simply that someone
 | |
| // else has the lock, whereas non-nil means that something unexpected happened.
 | |
| func (l *PostgreSQLLock) writeItem() (bool, error) {
 | |
| 	pg := l.backend
 | |
| 	pg.permitPool.Acquire()
 | |
| 	defer pg.permitPool.Release()
 | |
| 
 | |
| 	// Try steal lock or update expiry on my lock
 | |
| 
 | |
| 	sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds)
 | |
| 	if err != nil {
 | |
| 		return false, err
 | |
| 	}
 | |
| 	if sqlResult == nil {
 | |
| 		return false, fmt.Errorf("empty SQL response received")
 | |
| 	}
 | |
| 
 | |
| 	ar, err := sqlResult.RowsAffected()
 | |
| 	if err != nil {
 | |
| 		return false, err
 | |
| 	}
 | |
| 	return ar == 1, nil
 | |
| }
 | 
