make namespace manager a field in the Manager struct (#24815)

This commit is contained in:
Marc Boudreau
2024-01-11 11:09:41 -05:00
committed by GitHub
parent 7697e8b14c
commit ae45b8eb04
4 changed files with 24 additions and 24 deletions

View File

@@ -64,7 +64,6 @@ import (
"github.com/hashicorp/vault/vault/plugincatalog" "github.com/hashicorp/vault/vault/plugincatalog"
"github.com/hashicorp/vault/vault/quotas" "github.com/hashicorp/vault/vault/quotas"
vaultseal "github.com/hashicorp/vault/vault/seal" vaultseal "github.com/hashicorp/vault/vault/seal"
uicustommessages "github.com/hashicorp/vault/vault/ui_custom_messages"
"github.com/hashicorp/vault/version" "github.com/hashicorp/vault/version"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
uberAtomic "go.uber.org/atomic" uberAtomic "go.uber.org/atomic"
@@ -1264,7 +1263,7 @@ func NewCore(conf *CoreConfig) (*Core, error) {
// UI // UI
uiStoragePrefix := systemBarrierPrefix + "ui" uiStoragePrefix := systemBarrierPrefix + "ui"
c.uiConfig = NewUIConfig(conf.EnableUI, physical.NewView(c.physical, uiStoragePrefix), NewBarrierView(c.barrier, uiStoragePrefix)) c.uiConfig = NewUIConfig(conf.EnableUI, physical.NewView(c.physical, uiStoragePrefix), NewBarrierView(c.barrier, uiStoragePrefix))
c.customMessageManager = uicustommessages.NewManager(c.barrier) c.customMessageManager = createCustomMessageManager(c.barrier, c)
// Listeners // Listeners
err = c.configureListeners(conf) err = c.configureListeners(conf)

View File

@@ -16,6 +16,7 @@ import (
"github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/vault/quotas" "github.com/hashicorp/vault/vault/quotas"
"github.com/hashicorp/vault/vault/replication" "github.com/hashicorp/vault/vault/replication"
uicustommessages "github.com/hashicorp/vault/vault/ui_custom_messages"
) )
const ( const (
@@ -202,3 +203,11 @@ func (c *Core) MissingRequiredState(raw []string, perfStandby bool) bool {
func DiagnoseCheckLicense(ctx context.Context, vaultCore *Core, coreConfig CoreConfig, generate bool) (bool, []string) { func DiagnoseCheckLicense(ctx context.Context, vaultCore *Core, coreConfig CoreConfig, generate bool) (bool, []string) {
return false, nil return false, nil
} }
// createCustomMessageManager is a function implemented differently for the
// community edition and the enterprise edition. This is the community
// edition implementation. It simply constructs a uicustommessages.Manager
// instance and returns a pointer to it.
func createCustomMessageManager(storage logical.Storage, _ *Core) CustomMessagesManager {
return uicustommessages.NewManager(storage)
}

View File

@@ -26,31 +26,28 @@ const (
MaximumMessageCountPerNamespace int = 100 MaximumMessageCountPerNamespace int = 100
) )
// nsManager is the NamespaceManager instance used to determine the set of
// Namespaces to consider when retrieving active Custom Message. This
// variable is re-assigned to point to a real NamespaceManager in the
// enterprise edition.
var nsManager NamespaceManager = &CommunityEditionNamespaceManager{}
// Manager is a struct that provides methods to manage messages stored in a // Manager is a struct that provides methods to manage messages stored in a
// logical.Storage. // logical.Storage.
type Manager struct { type Manager struct {
view logical.Storage view logical.Storage
l sync.RWMutex l sync.RWMutex
nsManager NamespaceManager
} }
// NewManager creates a new Manager struct that has been fully initialized. // NewManager creates a new Manager struct that has been fully initialized.
func NewManager(storage logical.Storage) *Manager { func NewManager(storage logical.Storage) *Manager {
return &Manager{ return &Manager{
view: storage, view: storage,
nsManager: &CommunityEditionNamespaceManager{},
} }
} }
// FindMessages handles getting a list of existing messages that match the // FindMessages handles getting a list of existing messages that match the
// criteria set in the provided FindFilter struct. // criteria set in the provided FindFilter struct.
func (m *Manager) FindMessages(ctx context.Context, filters FindFilter) ([]Message, error) { func (m *Manager) FindMessages(ctx context.Context, filters FindFilter) ([]Message, error) {
nsList, err := getNamespacesToSearch(ctx, filters) nsList, err := m.getNamespacesToSearch(ctx, filters)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -218,7 +215,7 @@ func (m *Manager) putEntry(ctx context.Context, entry *Entry) error {
// This function handles the complexity of gathering all of the applicable // This function handles the complexity of gathering all of the applicable
// namespaces depending on the namespace set in the context and whether the // namespaces depending on the namespace set in the context and whether the
// IncludeAncestors criterion is set to true in the provided FindFilter struct. // IncludeAncestors criterion is set to true in the provided FindFilter struct.
func getNamespacesToSearch(ctx context.Context, filters FindFilter) ([]*namespace.Namespace, error) { func (m *Manager) getNamespacesToSearch(ctx context.Context, filters FindFilter) ([]*namespace.Namespace, error) {
var nsList []*namespace.Namespace var nsList []*namespace.Namespace
ns, err := namespace.FromContext(ctx) ns, err := namespace.FromContext(ctx)
@@ -230,8 +227,8 @@ func getNamespacesToSearch(ctx context.Context, filters FindFilter) ([]*namespac
nsList = append(nsList, ns) nsList = append(nsList, ns)
if filters.IncludeAncestors { if filters.IncludeAncestors {
parentNs := nsManager.GetParentNamespace(ns.Path) parentNs := m.nsManager.GetParentNamespace(ns.Path)
for ; parentNs.ID != ns.ID; parentNs = nsManager.GetParentNamespace(ns.Path) { for ; parentNs.ID != ns.ID; parentNs = m.nsManager.GetParentNamespace(ns.Path) {
ns = parentNs ns = parentNs
nsList = append(nsList, ns) nsList = append(nsList, ns)
} }

View File

@@ -214,23 +214,18 @@ func TestManagerPutEntry(t *testing.T) {
// context (e.g. checking that the list contains 1 element and that it's equal // context (e.g. checking that the list contains 1 element and that it's equal
// to namespace.RootNamespace). // to namespace.RootNamespace).
func TestGetNamespacesToSearch(t *testing.T) { func TestGetNamespacesToSearch(t *testing.T) {
list, err := getNamespacesToSearch(context.Background(), FindFilter{}) testManager := &Manager{nsManager: &CommunityEditionNamespaceManager{}}
list, err := testManager.getNamespacesToSearch(context.Background(), FindFilter{})
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, list) assert.Nil(t, list)
list, err = getNamespacesToSearch(namespace.ContextWithNamespace(context.Background(), namespace.RootNamespace), FindFilter{}) list, err = testManager.getNamespacesToSearch(namespace.ContextWithNamespace(context.Background(), namespace.RootNamespace), FindFilter{})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, list, 1) assert.Len(t, list, 1)
assert.Equal(t, namespace.RootNamespace, list[0]) assert.Equal(t, namespace.RootNamespace, list[0])
// Verify with nsManager set to an instance of testNamespaceManager to testManager.nsManager = &testNamespaceManager{
// ensure that it is used to calculate the list of namespaces.
currentNsManager := nsManager
defer func() {
nsManager = currentNsManager
}()
nsManager = &testNamespaceManager{
results: []namespace.Namespace{ results: []namespace.Namespace{
{ {
ID: "ccc", ID: "ccc",
@@ -247,7 +242,7 @@ func TestGetNamespacesToSearch(t *testing.T) {
}, },
} }
list, err = getNamespacesToSearch(namespace.ContextWithNamespace(context.Background(), &namespace.Namespace{ID: "ddd", Path: "d/"}), FindFilter{IncludeAncestors: true}) list, err = testManager.getNamespacesToSearch(namespace.ContextWithNamespace(context.Background(), &namespace.Namespace{ID: "ddd", Path: "d/"}), FindFilter{IncludeAncestors: true})
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, list, 5) assert.Len(t, list, 5)
assert.Equal(t, list[0].Path, "d/") assert.Equal(t, list[0].Path, "d/")