mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Move database connections map out to separate package (#21207)
The upcoming event main plugin will use a very similar pattern as the database plugin map, so it makes sense to refactor this and move this map out. It also cleans up the database plugin backend so that it does not have to keep track of the lock. Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
58b6cb1c42
commit
cf48236a3c
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/helper/metricsutil"
|
||||
"github.com/hashicorp/vault/helper/syncmap"
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
v4 "github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
@@ -43,6 +44,10 @@ type dbPluginInstance struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (dbi *dbPluginInstance) ID() string {
|
||||
return dbi.id
|
||||
}
|
||||
|
||||
func (dbi *dbPluginInstance) Close() error {
|
||||
dbi.Lock()
|
||||
defer dbi.Unlock()
|
||||
@@ -119,7 +124,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||
}
|
||||
|
||||
b.logger = conf.Logger
|
||||
b.connections = make(map[string]*dbPluginInstance)
|
||||
b.connections = syncmap.NewSyncMap[string, *dbPluginInstance]()
|
||||
b.queueCtx, b.cancelQueueCtx = context.WithCancel(context.Background())
|
||||
b.roleLocks = locksutil.CreateLocks()
|
||||
return &b
|
||||
@@ -127,17 +132,9 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||
|
||||
func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]metricsutil.GaugeLabelValues, error) {
|
||||
// copy the map so we can release the lock
|
||||
connMapCopy := func() map[string]*dbPluginInstance {
|
||||
b.connLock.RLock()
|
||||
defer b.connLock.RUnlock()
|
||||
mapCopy := map[string]*dbPluginInstance{}
|
||||
for k, v := range b.connections {
|
||||
mapCopy[k] = v
|
||||
}
|
||||
return mapCopy
|
||||
}()
|
||||
connectionsCopy := b.connections.Values()
|
||||
counts := map[string]int{}
|
||||
for _, v := range connMapCopy {
|
||||
for _, v := range connectionsCopy {
|
||||
dbType, err := v.database.Type()
|
||||
if err != nil {
|
||||
// there's a chance this will already be closed since we don't hold the lock
|
||||
@@ -156,10 +153,8 @@ func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]m
|
||||
}
|
||||
|
||||
type databaseBackend struct {
|
||||
// connLock is used to synchronize access to the connections map
|
||||
connLock sync.RWMutex
|
||||
// connections holds configured database connections by config name
|
||||
connections map[string]*dbPluginInstance
|
||||
connections *syncmap.SyncMap[string, *dbPluginInstance]
|
||||
logger log.Logger
|
||||
|
||||
*framework.Backend
|
||||
@@ -183,49 +178,6 @@ type databaseBackend struct {
|
||||
gaugeCollectionProcessStop sync.Once
|
||||
}
|
||||
|
||||
func (b *databaseBackend) connGet(name string) *dbPluginInstance {
|
||||
b.connLock.RLock()
|
||||
defer b.connLock.RUnlock()
|
||||
return b.connections[name]
|
||||
}
|
||||
|
||||
func (b *databaseBackend) connPop(name string) *dbPluginInstance {
|
||||
b.connLock.Lock()
|
||||
defer b.connLock.Unlock()
|
||||
dbi, ok := b.connections[name]
|
||||
if ok {
|
||||
delete(b.connections, name)
|
||||
}
|
||||
return dbi
|
||||
}
|
||||
|
||||
func (b *databaseBackend) connPopIfEqual(name, id string) *dbPluginInstance {
|
||||
b.connLock.Lock()
|
||||
defer b.connLock.Unlock()
|
||||
dbi, ok := b.connections[name]
|
||||
if ok && dbi.id == id {
|
||||
delete(b.connections, name)
|
||||
return dbi
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *databaseBackend) connPut(name string, newDbi *dbPluginInstance) *dbPluginInstance {
|
||||
b.connLock.Lock()
|
||||
defer b.connLock.Unlock()
|
||||
dbi := b.connections[name]
|
||||
b.connections[name] = newDbi
|
||||
return dbi
|
||||
}
|
||||
|
||||
func (b *databaseBackend) connClear() map[string]*dbPluginInstance {
|
||||
b.connLock.Lock()
|
||||
defer b.connLock.Unlock()
|
||||
old := b.connections
|
||||
b.connections = make(map[string]*dbPluginInstance)
|
||||
return old
|
||||
}
|
||||
|
||||
func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) {
|
||||
entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name))
|
||||
if err != nil {
|
||||
@@ -330,7 +282,7 @@ func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage,
|
||||
}
|
||||
|
||||
func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) {
|
||||
dbi := b.connGet(name)
|
||||
dbi := b.connections.Get(name)
|
||||
if dbi != nil {
|
||||
return dbi, nil
|
||||
}
|
||||
@@ -360,7 +312,7 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
|
||||
id: id,
|
||||
name: name,
|
||||
}
|
||||
oldConn := b.connPut(name, dbi)
|
||||
oldConn := b.connections.Put(name, dbi)
|
||||
if oldConn != nil {
|
||||
err := oldConn.Close()
|
||||
if err != nil {
|
||||
@@ -373,7 +325,7 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
|
||||
// ClearConnection closes the database connection and
|
||||
// removes it from the b.connections map.
|
||||
func (b *databaseBackend) ClearConnection(name string) error {
|
||||
db := b.connPop(name)
|
||||
db := b.connections.Pop(name)
|
||||
if db != nil {
|
||||
// Ignore error here since the database client is always killed
|
||||
db.Close()
|
||||
@@ -384,7 +336,7 @@ func (b *databaseBackend) ClearConnection(name string) error {
|
||||
// ClearConnectionId closes the database connection with a specific id and
|
||||
// removes it from the b.connections map.
|
||||
func (b *databaseBackend) ClearConnectionId(name, id string) error {
|
||||
db := b.connPopIfEqual(name, id)
|
||||
db := b.connections.PopIfEqual(name, id)
|
||||
if db != nil {
|
||||
// Ignore error here since the database client is always killed
|
||||
db.Close()
|
||||
@@ -403,7 +355,7 @@ func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) {
|
||||
db.Close()
|
||||
|
||||
// Delete the connection if it is still active.
|
||||
b.connPopIfEqual(db.name, db.id)
|
||||
b.connections.PopIfEqual(db.name, db.id)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -416,7 +368,7 @@ func (b *databaseBackend) clean(_ context.Context) {
|
||||
b.cancelQueueCtx()
|
||||
}
|
||||
|
||||
connections := b.connClear()
|
||||
connections := b.connections.Clear()
|
||||
for _, db := range connections {
|
||||
go db.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user