mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 09:42:25 +00:00
database: Avoid race condition in connection creation (#26147)
When creating database connections, there is a race condition when multiple goroutines try to create the connection at the same time. This happens, for example, on leadership changes in a cluster. Normally, the extra database connections are cleaned up when this is detected. However, some database implementations, notably Postgres, do not seem to clean up in a timely manner, and can leak in these scenarios. To fix this, we create a global lock when creating database connections to prevent multiple connections from being created at the same time. We also clean up the logic at the end so that if (somehow) we ended up creating an additional connection, we use the existing one rather than the new one. This by itself would solve our problem long-term; however, it would still involve many transient database connections being created and immediately killed on leadership changes. It's not ideal to have a single global lock for database connection creation. Some potential alternatives: * a map of locks from the connection name to the lock. The biggest downside is that we will probably want to garbage collect this map so that we don't have an unbounded number of locks. * a small pool of locks, where we hash the connection names to pick the lock. Using such a pool generally is a good way to introduce deadlock, but since we will only use it in a specific case, and the purpose is to improve performance for concurrent connection creation, this is probably acceptable. Co-authored-by: Jason O'Donnell <2160810+jasonodonnell@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
7bd75eb858
commit
a65d9133a1
@@ -161,8 +161,9 @@ func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]m
|
||||
|
||||
type databaseBackend struct {
|
||||
// connections holds configured database connections by config name
|
||||
connections *syncmap.SyncMap[string, *dbPluginInstance]
|
||||
logger log.Logger
|
||||
createConnectionLock sync.Mutex
|
||||
connections *syncmap.SyncMap[string, *dbPluginInstance]
|
||||
logger log.Logger
|
||||
|
||||
*framework.Backend
|
||||
// credRotationQueue is an in-memory priority queue used to track Static Roles
|
||||
@@ -291,11 +292,23 @@ func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage,
|
||||
}
|
||||
|
||||
func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) {
|
||||
// fast path, reuse the existing connection
|
||||
dbi := b.connections.Get(name)
|
||||
if dbi != nil {
|
||||
return dbi, nil
|
||||
}
|
||||
|
||||
// slow path, create a new connection
|
||||
// if we don't lock the rest of the operation, there is a race condition for multiple callers of this function
|
||||
b.createConnectionLock.Lock()
|
||||
defer b.createConnectionLock.Unlock()
|
||||
|
||||
// check again in case we lost the race
|
||||
dbi = b.connections.Get(name)
|
||||
if dbi != nil {
|
||||
return dbi, nil
|
||||
}
|
||||
|
||||
id, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -332,14 +345,17 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
|
||||
name: name,
|
||||
runningPluginVersion: pluginVersion,
|
||||
}
|
||||
oldConn := b.connections.Put(name, dbi)
|
||||
if oldConn != nil {
|
||||
err := oldConn.Close()
|
||||
conn, ok := b.connections.PutIfEmpty(name, dbi)
|
||||
if !ok {
|
||||
// this is a bug
|
||||
b.Logger().Warn("BUG: there was a race condition adding to the database connection map")
|
||||
// There was already an existing connection, so we will use that and close our new one to avoid a race condition.
|
||||
err := dbi.Close()
|
||||
if err != nil {
|
||||
b.Logger().Warn("Error closing database connection", "error", err)
|
||||
b.Logger().Warn("Error closing new database connection", "error", err)
|
||||
}
|
||||
}
|
||||
return dbi, nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ClearConnection closes the database connection and
|
||||
|
||||
109
builtin/logical/database/backend_get_test.go
Normal file
109
builtin/logical/database/backend_get_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/queue"
|
||||
)
|
||||
|
||||
func newSystemViewWrapper(view logical.SystemView) logical.SystemView {
|
||||
return &systemViewWrapper{
|
||||
view,
|
||||
}
|
||||
}
|
||||
|
||||
type systemViewWrapper struct {
|
||||
logical.SystemView
|
||||
}
|
||||
|
||||
var _ logical.ExtendedSystemView = (*systemViewWrapper)(nil)
|
||||
|
||||
func (s *systemViewWrapper) RequestWellKnownRedirect(ctx context.Context, src, dest string) error {
|
||||
panic("nope")
|
||||
}
|
||||
|
||||
func (s *systemViewWrapper) DeregisterWellKnownRedirect(ctx context.Context, src string) bool {
|
||||
panic("nope")
|
||||
}
|
||||
|
||||
func (s *systemViewWrapper) Auditor() logical.Auditor {
|
||||
panic("nope")
|
||||
}
|
||||
|
||||
func (s *systemViewWrapper) ForwardGenericRequest(ctx context.Context, request *logical.Request) (*logical.Response, error) {
|
||||
panic("nope")
|
||||
}
|
||||
|
||||
func (s *systemViewWrapper) APILockShouldBlockRequest() (bool, error) {
|
||||
panic("nope")
|
||||
}
|
||||
|
||||
func (s *systemViewWrapper) GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) {
|
||||
return nil, pluginutil.ErrPinnedVersionNotFound
|
||||
}
|
||||
|
||||
func (s *systemViewWrapper) LookupPluginVersion(ctx context.Context, pluginName string, pluginType consts.PluginType, version string) (*pluginutil.PluginRunner, error) {
|
||||
return &pluginutil.PluginRunner{
|
||||
Name: mockv5,
|
||||
Type: consts.PluginTypeDatabase,
|
||||
Builtin: true,
|
||||
BuiltinFactory: New,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getDbBackend(t *testing.T) (*databaseBackend, logical.Storage) {
|
||||
t.Helper()
|
||||
config := logical.TestBackendConfig()
|
||||
config.System = newSystemViewWrapper(config.System)
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
// Create and init the backend ourselves instead of using a Factory because
|
||||
// the factory function kicks off threads that cause racy tests.
|
||||
b := Backend(config)
|
||||
if err := b.Setup(context.Background(), config); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b.schedule = &TestSchedule{}
|
||||
b.credRotationQueue = queue.New()
|
||||
b.populateQueue(context.Background(), config.StorageView)
|
||||
|
||||
return b, config.StorageView
|
||||
}
|
||||
|
||||
// TestGetConnectionRaceCondition checks that GetConnection always returns the same instance, even when asked
|
||||
// by multiple goroutines in parallel.
|
||||
func TestGetConnectionRaceCondition(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
b, s := getDbBackend(t)
|
||||
defer b.Cleanup(ctx)
|
||||
configureDBMount(t, s)
|
||||
|
||||
goroutines := 16
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(goroutines)
|
||||
dbis := make([]*dbPluginInstance, goroutines)
|
||||
errs := make([]error, goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
dbis[i], errs[i] = b.GetConnection(ctx, s, mockv5)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
for i := 0; i < goroutines; i++ {
|
||||
if errs[i] != nil {
|
||||
t.Fatal(errs[i])
|
||||
}
|
||||
if dbis[0] != dbis[i] {
|
||||
t.Fatal("Error: database instances did not match")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -51,6 +51,9 @@ func (m MockDatabaseV5) Initialize(ctx context.Context, req v5.InitializeRequest
|
||||
"req", req)
|
||||
|
||||
config := req.Config
|
||||
if config == nil {
|
||||
config = map[string]interface{}{}
|
||||
}
|
||||
config["from-plugin"] = "this value is from the plugin itself"
|
||||
|
||||
resp := v5.InitializeResponse{
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
mockv5 = "mockv5"
|
||||
dbUser = "vaultstatictest"
|
||||
dbUserDefaultPassword = "password"
|
||||
testMinRotationWindowSeconds = 5
|
||||
@@ -1446,7 +1447,7 @@ func TestStoredWALsCorrectlyProcessed(t *testing.T) {
|
||||
|
||||
rotationPeriodData := map[string]interface{}{
|
||||
"username": "hashicorp",
|
||||
"db_name": "mockv5",
|
||||
"db_name": mockv5,
|
||||
"rotation_period": "86400s",
|
||||
}
|
||||
|
||||
@@ -1500,7 +1501,7 @@ func TestStoredWALsCorrectlyProcessed(t *testing.T) {
|
||||
},
|
||||
map[string]interface{}{
|
||||
"username": "hashicorp",
|
||||
"db_name": "mockv5",
|
||||
"db_name": mockv5,
|
||||
"rotation_schedule": "*/10 * * * * *",
|
||||
},
|
||||
},
|
||||
@@ -1699,9 +1700,9 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase {
|
||||
dbi := &dbPluginInstance{
|
||||
database: dbw,
|
||||
id: "foo-id",
|
||||
name: "mockV5",
|
||||
name: mockv5,
|
||||
}
|
||||
b.connections.Put("mockv5", dbi)
|
||||
b.connections.Put(mockv5, dbi)
|
||||
|
||||
return mockDB
|
||||
}
|
||||
@@ -1710,7 +1711,7 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase {
|
||||
// plugin init code paths, allowing us to use a manually populated mock DB object.
|
||||
func configureDBMount(t *testing.T, storage logical.Storage) {
|
||||
t.Helper()
|
||||
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/mockv5"), &DatabaseConfig{
|
||||
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/"+mockv5), &DatabaseConfig{
|
||||
AllowedRoles: []string{"*"},
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
3
changelog/26147.txt
Normal file
3
changelog/26147.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
```release-note:bug
|
||||
secret/database: Fixed race condition where database mounts may leak connections
|
||||
```
|
||||
@@ -62,6 +62,20 @@ func (m *SyncMap[K, V]) Put(k K, v V) V {
|
||||
return oldV
|
||||
}
|
||||
|
||||
// PutIfEmpty adds the given key-value pair to the map only if there is no value already in it,
|
||||
// and returns the new value and true if so.
|
||||
// If there is already a value, it returns the existing value and false.
|
||||
func (m *SyncMap[K, V]) PutIfEmpty(k K, v V) (V, bool) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
oldV, ok := m.data[k]
|
||||
if ok {
|
||||
return oldV, false
|
||||
}
|
||||
m.data[k] = v
|
||||
return v, true
|
||||
}
|
||||
|
||||
// Clear deletes all entries from the map, and returns the previous map.
|
||||
func (m *SyncMap[K, V]) Clear() map[K]V {
|
||||
m.lock.Lock()
|
||||
|
||||
Reference in New Issue
Block a user