diff --git a/.github/scripts/generate-test-package-lists.sh b/.github/scripts/generate-test-package-lists.sh index b71b1d72ea..6d5e41ada5 100755 --- a/.github/scripts/generate-test-package-lists.sh +++ b/.github/scripts/generate-test-package-lists.sh @@ -120,6 +120,7 @@ test_packages[6]+=" $base/helper/namespace" test_packages[6]+=" $base/helper/osutil" test_packages[6]+=" $base/helper/parseip" test_packages[6]+=" $base/helper/policies" +test_packages[6]+=" $base/helper/syncmap" test_packages[6]+=" $base/helper/testhelpers/logical" test_packages[6]+=" $base/helper/timeutil" test_packages[6]+=" $base/helper/useragent" diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 94091e2019..f4e5ef31bd 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/metricsutil" + "github.com/hashicorp/vault/helper/syncmap" "github.com/hashicorp/vault/internalshared/configutil" v4 "github.com/hashicorp/vault/sdk/database/dbplugin" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" @@ -43,6 +44,10 @@ type dbPluginInstance struct { closed bool } +func (dbi *dbPluginInstance) ID() string { + return dbi.id +} + func (dbi *dbPluginInstance) Close() error { dbi.Lock() defer dbi.Unlock() @@ -119,7 +124,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]*dbPluginInstance) + b.connections = syncmap.NewSyncMap[string, *dbPluginInstance]() b.queueCtx, b.cancelQueueCtx = context.WithCancel(context.Background()) b.roleLocks = locksutil.CreateLocks() return &b @@ -127,17 +132,9 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]metricsutil.GaugeLabelValues, error) { // copy the map so we can release the lock - connMapCopy := func() map[string]*dbPluginInstance { - b.connLock.RLock() - defer b.connLock.RUnlock() - mapCopy := map[string]*dbPluginInstance{} - for k, v := range b.connections { - mapCopy[k] = v - } - return mapCopy - }() + connectionsCopy := b.connections.Values() counts := map[string]int{} - for _, v := range connMapCopy { + for _, v := range connectionsCopy { dbType, err := v.database.Type() if err != nil { // there's a chance this will already be closed since we don't hold the lock @@ -156,10 +153,8 @@ func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]m } type databaseBackend struct { - // connLock is used to synchronize access to the connections map - connLock sync.RWMutex // connections holds configured database connections by config name - connections map[string]*dbPluginInstance + connections *syncmap.SyncMap[string, *dbPluginInstance] logger log.Logger *framework.Backend @@ -183,49 +178,6 @@ type databaseBackend struct { gaugeCollectionProcessStop sync.Once } -func (b *databaseBackend) connGet(name string) *dbPluginInstance { - b.connLock.RLock() - defer b.connLock.RUnlock() - return b.connections[name] -} - -func (b *databaseBackend) connPop(name string) *dbPluginInstance { - b.connLock.Lock() - defer b.connLock.Unlock() - dbi, ok := b.connections[name] - if ok { - delete(b.connections, name) - } - return dbi -} - -func (b *databaseBackend) connPopIfEqual(name, id string) *dbPluginInstance { - b.connLock.Lock() - defer b.connLock.Unlock() - dbi, ok := b.connections[name] - if ok && dbi.id == id { - delete(b.connections, name) - return dbi - } - return nil -} - -func (b *databaseBackend) connPut(name string, newDbi *dbPluginInstance) *dbPluginInstance { - b.connLock.Lock() - defer b.connLock.Unlock() - dbi := b.connections[name] - b.connections[name] = newDbi - return dbi -} - -func (b *databaseBackend) connClear() map[string]*dbPluginInstance { - b.connLock.Lock() - defer b.connLock.Unlock() - old := b.connections - b.connections = make(map[string]*dbPluginInstance) - return old -} - func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) { entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name)) if err != nil { @@ -330,7 +282,7 @@ func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage, } func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) { - dbi := b.connGet(name) + dbi := b.connections.Get(name) if dbi != nil { return dbi, nil } @@ -360,7 +312,7 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri id: id, name: name, } - oldConn := b.connPut(name, dbi) + oldConn := b.connections.Put(name, dbi) if oldConn != nil { err := oldConn.Close() if err != nil { @@ -373,7 +325,7 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri // ClearConnection closes the database connection and // removes it from the b.connections map. func (b *databaseBackend) ClearConnection(name string) error { - db := b.connPop(name) + db := b.connections.Pop(name) if db != nil { // Ignore error here since the database client is always killed db.Close() @@ -384,7 +336,7 @@ func (b *databaseBackend) ClearConnection(name string) error { // ClearConnectionId closes the database connection with a specific id and // removes it from the b.connections map. func (b *databaseBackend) ClearConnectionId(name, id string) error { - db := b.connPopIfEqual(name, id) + db := b.connections.PopIfEqual(name, id) if db != nil { // Ignore error here since the database client is always killed db.Close() @@ -403,7 +355,7 @@ func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) { db.Close() // Delete the connection if it is still active. - b.connPopIfEqual(db.name, db.id) + b.connections.PopIfEqual(db.name, db.id) }() } } @@ -416,7 +368,7 @@ func (b *databaseBackend) clean(_ context.Context) { b.cancelQueueCtx() } - connections := b.connClear() + connections := b.connections.Clear() for _, db := range connections { go db.Close() } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index b869facef0..a50499280f 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -462,7 +462,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName) // Close and remove the old connection - oldConn := b.connPut(name, &dbPluginInstance{ + oldConn := b.connections.Put(name, &dbPluginInstance{ database: dbw, name: name, id: id, diff --git a/builtin/logical/database/rotation_test.go b/builtin/logical/database/rotation_test.go index e0cb96dd67..5dfa096593 100644 --- a/builtin/logical/database/rotation_test.go +++ b/builtin/logical/database/rotation_test.go @@ -1390,7 +1390,7 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase { id: "foo-id", name: "mockV5", } - b.connections["mockv5"] = dbi + b.connections.Put("mockv5", dbi) return mockDB } diff --git a/helper/syncmap/syncmap.go b/helper/syncmap/syncmap.go new file mode 100644 index 0000000000..ce3d9e8ca0 --- /dev/null +++ b/helper/syncmap/syncmap.go @@ -0,0 +1,86 @@ +package syncmap + +import "sync" + +// SyncMap implements a map similar to sync.Map, but with generics and with an equality +// in the values specified by an "ID()" method. +type SyncMap[K comparable, V IDer] struct { + // lock is used to synchronize access to the map + lock sync.RWMutex + // data holds the actual data + data map[K]V +} + +// NewSyncMap returns a new, empty SyncMap. +func NewSyncMap[K comparable, V IDer]() *SyncMap[K, V] { + return &SyncMap[K, V]{ + data: make(map[K]V), + } +} + +// Get returns the value for the given key. +func (m *SyncMap[K, V]) Get(k K) V { + m.lock.RLock() + defer m.lock.RUnlock() + return m.data[k] +} + +// Pop deletes and returns the value for the given key, if it exists. +func (m *SyncMap[K, V]) Pop(k K) V { + m.lock.Lock() + defer m.lock.Unlock() + v, ok := m.data[k] + if ok { + delete(m.data, k) + } + return v +} + +// PopIfEqual deletes and returns the value for the given key, if it exists +// and only if the ID is equal to the provided string. +func (m *SyncMap[K, V]) PopIfEqual(k K, id string) V { + m.lock.Lock() + defer m.lock.Unlock() + v, ok := m.data[k] + if ok && v.ID() == id { + delete(m.data, k) + return v + } + var zero V + return zero +} + +// Put adds the given key-value pair to the map and returns the previous value, if any. +func (m *SyncMap[K, V]) Put(k K, v V) V { + m.lock.Lock() + defer m.lock.Unlock() + oldV := m.data[k] + m.data[k] = v + return oldV +} + +// Clear deletes all entries from the map, and returns the previous map. +func (m *SyncMap[K, V]) Clear() map[K]V { + m.lock.Lock() + defer m.lock.Unlock() + old := m.data + m.data = make(map[K]V) + return old +} + +// Values returns a copy of all values in the map. +func (m *SyncMap[K, V]) Values() []V { + m.lock.RLock() + defer m.lock.RUnlock() + + values := make([]V, 0, len(m.data)) + for _, v := range m.data { + values = append(values, v) + } + return values +} + +// IDer is used to extract an ID that SyncMap uses for equality checking. +type IDer interface { + ID() string +} diff --git a/helper/syncmap/syncmap_test.go b/helper/syncmap/syncmap_test.go new file mode 100644 index 0000000000..a62de301fa --- /dev/null +++ b/helper/syncmap/syncmap_test.go @@ -0,0 +1,75 @@ +package syncmap + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +type stringID struct { + val string + id string +} + +func (s stringID) ID() string { + return s.id +} + +var _ IDer = stringID{"", ""} + +// TestSyncMap_Get tests that basic getting and putting works. +func TestSyncMap_Get(t *testing.T) { + m := NewSyncMap[string, stringID]() + m.Put("a", stringID{"b", "b"}) + assert.Equal(t, stringID{"b", "b"}, m.Get("a")) + assert.Equal(t, stringID{"", ""}, m.Get("c")) +} + +// TestSyncMap_Pop tests that basic Pop operations work. +func TestSyncMap_Pop(t *testing.T) { + m := NewSyncMap[string, stringID]() + m.Put("a", stringID{"b", "b"}) + assert.Equal(t, stringID{"b", "b"}, m.Pop("a")) + assert.Equal(t, stringID{"", ""}, m.Pop("a")) + assert.Equal(t, stringID{"", ""}, m.Pop("c")) +} + +// TestSyncMap_PopIfEqual tests that basic PopIfEqual operations pop only if the IDs are equal. +func TestSyncMap_PopIfEqual(t *testing.T) { + m := NewSyncMap[string, stringID]() + m.Put("a", stringID{"b", "c"}) + assert.Equal(t, stringID{"", ""}, m.PopIfEqual("a", "b")) + assert.Equal(t, stringID{"b", "c"}, m.PopIfEqual("a", "c")) + assert.Equal(t, stringID{"", ""}, m.PopIfEqual("a", "c")) +} + +// TestSyncMap_Clear checks that clearing works as expected and returns a copy of the original map. +func TestSyncMap_Clear(t *testing.T) { + m := NewSyncMap[string, stringID]() + assert.Equal(t, map[string]stringID{}, m.data) + oldMap := m.Clear() + assert.Equal(t, map[string]stringID{}, m.data) + assert.Equal(t, map[string]stringID{}, oldMap) + + m.Put("a", stringID{"b", "b"}) + m.Put("c", stringID{"d", "d"}) + oldMap = m.Clear() + + assert.Equal(t, map[string]stringID{"a": {"b", "b"}, "c": {"d", "d"}}, oldMap) + assert.Equal(t, map[string]stringID{}, m.data) +} + +// TestSyncMap_Values checks that the Values method returns an array of the values. +func TestSyncMap_Values(t *testing.T) { + m := NewSyncMap[string, stringID]() + assert.Equal(t, []stringID{}, m.Values()) + m.Put("a", stringID{"b", "b"}) + assert.Equal(t, []stringID{{"b", "b"}}, m.Values()) + m.Put("c", stringID{"d", "d"}) + values := m.Values() + sort.Slice(values, func(i, j int) bool { + return values[i].val < values[j].val + }) + assert.Equal(t, []stringID{{"b", "b"}, {"d", "d"}}, m.Values()) +}