From ae45b8eb047f4b39705824002399a3f0e66ebb15 Mon Sep 17 00:00:00 2001 From: Marc Boudreau Date: Thu, 11 Jan 2024 11:09:41 -0500 Subject: [PATCH] make namespace manager a field in the Manager struct (#24815) --- vault/core.go | 3 +-- vault/core_util.go | 9 +++++++++ vault/ui_custom_messages/manager.go | 19 ++++++++----------- vault/ui_custom_messages/manager_test.go | 17 ++++++----------- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/vault/core.go b/vault/core.go index 22f19fc5b0..f6b5cd0add 100644 --- a/vault/core.go +++ b/vault/core.go @@ -64,7 +64,6 @@ import ( "github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/vault/quotas" vaultseal "github.com/hashicorp/vault/vault/seal" - uicustommessages "github.com/hashicorp/vault/vault/ui_custom_messages" "github.com/hashicorp/vault/version" "github.com/patrickmn/go-cache" uberAtomic "go.uber.org/atomic" @@ -1264,7 +1263,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { // UI uiStoragePrefix := systemBarrierPrefix + "ui" c.uiConfig = NewUIConfig(conf.EnableUI, physical.NewView(c.physical, uiStoragePrefix), NewBarrierView(c.barrier, uiStoragePrefix)) - c.customMessageManager = uicustommessages.NewManager(c.barrier) + c.customMessageManager = createCustomMessageManager(c.barrier, c) // Listeners err = c.configureListeners(conf) diff --git a/vault/core_util.go b/vault/core_util.go index 5e0d7438c0..b9f858e25b 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/vault/quotas" "github.com/hashicorp/vault/vault/replication" + uicustommessages "github.com/hashicorp/vault/vault/ui_custom_messages" ) const ( @@ -202,3 +203,11 @@ func (c *Core) MissingRequiredState(raw []string, perfStandby bool) bool { func DiagnoseCheckLicense(ctx context.Context, vaultCore *Core, coreConfig CoreConfig, generate bool) (bool, []string) { return false, nil } + +// createCustomMessageManager is a function implemented differently for the +// community edition and the enterprise edition. This is the community +// edition implementation. It simply constructs a uicustommessages.Manager +// instance and returns a pointer to it. +func createCustomMessageManager(storage logical.Storage, _ *Core) CustomMessagesManager { + return uicustommessages.NewManager(storage) +} diff --git a/vault/ui_custom_messages/manager.go b/vault/ui_custom_messages/manager.go index 47439e620b..db2ea24c74 100644 --- a/vault/ui_custom_messages/manager.go +++ b/vault/ui_custom_messages/manager.go @@ -26,31 +26,28 @@ const ( MaximumMessageCountPerNamespace int = 100 ) -// nsManager is the NamespaceManager instance used to determine the set of -// Namespaces to consider when retrieving active Custom Message. This -// variable is re-assigned to point to a real NamespaceManager in the -// enterprise edition. -var nsManager NamespaceManager = &CommunityEditionNamespaceManager{} - // Manager is a struct that provides methods to manage messages stored in a // logical.Storage. type Manager struct { view logical.Storage l sync.RWMutex + + nsManager NamespaceManager } // NewManager creates a new Manager struct that has been fully initialized. func NewManager(storage logical.Storage) *Manager { return &Manager{ - view: storage, + view: storage, + nsManager: &CommunityEditionNamespaceManager{}, } } // FindMessages handles getting a list of existing messages that match the // criteria set in the provided FindFilter struct. func (m *Manager) FindMessages(ctx context.Context, filters FindFilter) ([]Message, error) { - nsList, err := getNamespacesToSearch(ctx, filters) + nsList, err := m.getNamespacesToSearch(ctx, filters) if err != nil { return nil, err } @@ -218,7 +215,7 @@ func (m *Manager) putEntry(ctx context.Context, entry *Entry) error { // This function handles the complexity of gathering all of the applicable // namespaces depending on the namespace set in the context and whether the // IncludeAncestors criterion is set to true in the provided FindFilter struct. -func getNamespacesToSearch(ctx context.Context, filters FindFilter) ([]*namespace.Namespace, error) { +func (m *Manager) getNamespacesToSearch(ctx context.Context, filters FindFilter) ([]*namespace.Namespace, error) { var nsList []*namespace.Namespace ns, err := namespace.FromContext(ctx) @@ -230,8 +227,8 @@ func getNamespacesToSearch(ctx context.Context, filters FindFilter) ([]*namespac nsList = append(nsList, ns) if filters.IncludeAncestors { - parentNs := nsManager.GetParentNamespace(ns.Path) - for ; parentNs.ID != ns.ID; parentNs = nsManager.GetParentNamespace(ns.Path) { + parentNs := m.nsManager.GetParentNamespace(ns.Path) + for ; parentNs.ID != ns.ID; parentNs = m.nsManager.GetParentNamespace(ns.Path) { ns = parentNs nsList = append(nsList, ns) } diff --git a/vault/ui_custom_messages/manager_test.go b/vault/ui_custom_messages/manager_test.go index e0d8e43ddf..2029407248 100644 --- a/vault/ui_custom_messages/manager_test.go +++ b/vault/ui_custom_messages/manager_test.go @@ -214,23 +214,18 @@ func TestManagerPutEntry(t *testing.T) { // context (e.g. checking that the list contains 1 element and that it's equal // to namespace.RootNamespace). func TestGetNamespacesToSearch(t *testing.T) { - list, err := getNamespacesToSearch(context.Background(), FindFilter{}) + testManager := &Manager{nsManager: &CommunityEditionNamespaceManager{}} + + list, err := testManager.getNamespacesToSearch(context.Background(), FindFilter{}) assert.Error(t, err) assert.Nil(t, list) - list, err = getNamespacesToSearch(namespace.ContextWithNamespace(context.Background(), namespace.RootNamespace), FindFilter{}) + list, err = testManager.getNamespacesToSearch(namespace.ContextWithNamespace(context.Background(), namespace.RootNamespace), FindFilter{}) assert.NoError(t, err) assert.Len(t, list, 1) assert.Equal(t, namespace.RootNamespace, list[0]) - // Verify with nsManager set to an instance of testNamespaceManager to - // ensure that it is used to calculate the list of namespaces. - currentNsManager := nsManager - defer func() { - nsManager = currentNsManager - }() - - nsManager = &testNamespaceManager{ + testManager.nsManager = &testNamespaceManager{ results: []namespace.Namespace{ { ID: "ccc", @@ -247,7 +242,7 @@ func TestGetNamespacesToSearch(t *testing.T) { }, } - list, err = getNamespacesToSearch(namespace.ContextWithNamespace(context.Background(), &namespace.Namespace{ID: "ddd", Path: "d/"}), FindFilter{IncludeAncestors: true}) + list, err = testManager.getNamespacesToSearch(namespace.ContextWithNamespace(context.Background(), &namespace.Namespace{ID: "ddd", Path: "d/"}), FindFilter{IncludeAncestors: true}) assert.NoError(t, err) assert.Len(t, list, 5) assert.Equal(t, list[0].Path, "d/")