diff --git a/vault/core.go b/vault/core.go index 857eb603ce..668d0e4c4d 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2482,7 +2482,7 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c } // Run setup-like functions - if err := runUnsealSetupFunctions(ctx, buildUnsealSetupFunctionSlice(c)); err != nil { + if err := runUnsealSetupFunctions(ctx, buildUnsealSetupFunctionSlice(c, true)); err != nil { return err } @@ -2583,7 +2583,7 @@ func (c *Core) setupPluginCatalog(ctx context.Context) error { // buildUnsealSetupFunctionSlice returns a slice of functions, tailored for this // Core's replication state, that can be passed to the runUnsealSetupFunctions // function. -func buildUnsealSetupFunctionSlice(c *Core) []func(context.Context) error { +func buildUnsealSetupFunctionSlice(c *Core, isActive bool) []func(context.Context) error { // setupFunctions is a slice of functions that need to be called in order, // that if any return an error, processing should immediately cease. setupFunctions := []func(context.Context) error{ @@ -2643,7 +2643,7 @@ func buildUnsealSetupFunctionSlice(c *Core) []func(context.Context) error { if c.identityStore == nil { return nil } - return c.identityStore.loadArtifacts(ctx) + return c.identityStore.loadArtifacts(ctx, isActive) }) setupFunctions = append(setupFunctions, func(ctx context.Context) error { return loadPolicyMFAConfigs(ctx, c) diff --git a/vault/core_test.go b/vault/core_test.go index b4089d5119..57739dc9a9 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -3674,7 +3674,7 @@ func TestBuildUnsealSetupFunctionSlice(t *testing.T) { expectedLength: 14, }, } { - funcs := buildUnsealSetupFunctionSlice(testcase.core) + funcs := buildUnsealSetupFunctionSlice(testcase.core, true) assert.Equal(t, testcase.expectedLength, len(funcs), testcase.name) } } diff --git a/vault/identity_store_conflicts.go b/vault/identity_store_conflicts.go index 83f87bc641..ff0ace5238 100644 --- a/vault/identity_store_conflicts.go +++ b/vault/identity_store_conflicts.go @@ -20,11 +20,14 @@ var errDuplicateIdentityName = errors.New("duplicate identity name") // ConflictResolver defines the interface for resolving conflicts between // entities, groups, and aliases. All methods should implement a check for // existing=nil. This is an intentional design choice to allow the caller to -// search for extra information if necessary. +// search for extra information if necessary. Resolvers may not modify existing. +// If they choose to modify duplicate the modified version will be inserted into +// MemDB but they must return true in this case to all the calling code to take +// appropriate actions like persisting the change. type ConflictResolver interface { - ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) error - ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) error - ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) error + ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error) + ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error) + ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) } // errorResolver is a ConflictResolver that logs a warning message when a @@ -36,9 +39,9 @@ type errorResolver struct { // ResolveEntities logs a warning message when a pre-existing Entity is found // and returns a duplicate name error, which should be handled by the caller by // putting the system in case-sensitive mode. -func (r *errorResolver) ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) error { +func (r *errorResolver) ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error) { if existing == nil { - return nil + return false, nil } r.logger.Warn(errDuplicateIdentityName.Error(), @@ -47,15 +50,15 @@ func (r *errorResolver) ResolveEntities(ctx context.Context, existing, duplicate "duplicate_of_id", existing.ID, "action", "merge the duplicate entities into one") - return errDuplicateIdentityName + return false, errDuplicateIdentityName } // ResolveGroups logs a warning message when a pre-existing Group is found and // returns a duplicate name error, which should be handled by the caller by // putting the system in case-sensitive mode. -func (r *errorResolver) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) error { +func (r *errorResolver) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error) { if existing == nil { - return nil + return false, nil } r.logger.Warn(errDuplicateIdentityName.Error(), @@ -64,15 +67,15 @@ func (r *errorResolver) ResolveGroups(ctx context.Context, existing, duplicate * "duplicate_of_id", existing.ID, "action", "merge the contents of duplicated groups into one and delete the other") - return errDuplicateIdentityName + return false, errDuplicateIdentityName } // ResolveAliases logs a warning message when a pre-existing Alias is found and // returns a duplicate name error, which should be handled by the caller by // putting the system in case-sensitive mode. -func (r *errorResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) error { +func (r *errorResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) { if existing == nil { - return nil + return false, nil } r.logger.Warn(errDuplicateIdentityName.Error(), @@ -85,7 +88,7 @@ func (r *errorResolver) ResolveAliases(ctx context.Context, parent *identity.Ent "duplicate_of_canonical_id", existing.CanonicalID, "action", "merge the canonical entity IDs into one") - return errDuplicateIdentityName + return false, errDuplicateIdentityName } // duplicateReportingErrorResolver collects duplicate information and optionally @@ -119,26 +122,26 @@ func newDuplicateReportingErrorResolver(logger hclog.Logger) *duplicateReporting } } -func (r *duplicateReportingErrorResolver) ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) error { +func (r *duplicateReportingErrorResolver) ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error) { entityKey := fmt.Sprintf("%s/%s", duplicate.NamespaceID, strings.ToLower(duplicate.Name)) r.seenEntities[entityKey] = append(r.seenEntities[entityKey], duplicate) - return errDuplicateIdentityName + return false, errDuplicateIdentityName } -func (r *duplicateReportingErrorResolver) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) error { +func (r *duplicateReportingErrorResolver) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error) { groupKey := fmt.Sprintf("%s/%s", duplicate.NamespaceID, strings.ToLower(duplicate.Name)) r.seenGroups[groupKey] = append(r.seenGroups[groupKey], duplicate) - return errDuplicateIdentityName + return false, errDuplicateIdentityName } -func (r *duplicateReportingErrorResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) error { +func (r *duplicateReportingErrorResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) { aliasKey := fmt.Sprintf("%s/%s", duplicate.MountAccessor, strings.ToLower(duplicate.Name)) if duplicate.Local { r.seenLocalAliases[aliasKey] = append(r.seenLocalAliases[aliasKey], duplicate) } else { r.seenAliases[aliasKey] = append(r.seenAliases[aliasKey], duplicate) } - return errDuplicateIdentityName + return false, errDuplicateIdentityName } type identityDuplicateReportEntry struct { @@ -375,9 +378,9 @@ type renameResolver struct { // pre-existing entity such that only the last occurrence retains its unmodified // name. Note that this is potentially destructive but is the best option // available to resolve duplicates in storage caused by bugs in our validation. -func (r *renameResolver) ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) error { +func (r *renameResolver) ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error) { if existing == nil { - return nil + return false, nil } duplicate.Name = duplicate.Name + "-" + duplicate.ID @@ -394,7 +397,7 @@ func (r *renameResolver) ResolveEntities(ctx context.Context, existing, duplicat "renamed_to", duplicate.Name, ) - return nil + return true, nil } // ResolveGroups deals with group name duplicates by renaming those that @@ -403,9 +406,9 @@ func (r *renameResolver) ResolveEntities(ctx context.Context, existing, duplicat // nodes. We use the ID to ensure the new name is unique bit also // deterministic. For now, don't persist this. The user can choose to // resolve it permanently by renaming or deleting explicitly. -func (r *renameResolver) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) error { +func (r *renameResolver) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error) { if existing == nil { - return nil + return false, nil } duplicate.Name = duplicate.Name + "-" + duplicate.ID @@ -420,10 +423,10 @@ func (r *renameResolver) ResolveGroups(ctx context.Context, existing, duplicate "renamed_from", existing.Name, "renamed_to", duplicate.Name, ) - return nil + return true, nil } // ResolveAliases is a no-op for the renameResolver implementation. -func (r *renameResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) error { - return nil +func (r *renameResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) { + return false, nil } diff --git a/vault/identity_store_conflicts_test.go b/vault/identity_store_conflicts_test.go index e7fe844586..27aa83c4db 100644 --- a/vault/identity_store_conflicts_test.go +++ b/vault/identity_store_conflicts_test.go @@ -119,7 +119,7 @@ end of identity duplicate report, refer to https://developer.hashicorp.com/vault // Call ResolveEntities, assume existing is nil for now. In real life we // should be passed the existing entity for the exact match dupes but we // don't depend on that so it's fine to omit. - _ = r.ResolveEntities(context.Background(), nil, entity) + _, _ = r.ResolveEntities(context.Background(), nil, entity) // Don't care about the actual error here since it would be ignored in // case-sensitive mode anyway. @@ -129,7 +129,7 @@ end of identity duplicate report, refer to https://developer.hashicorp.com/vault Name: pair[1], NamespaceID: pair[0], } - _ = r.ResolveGroups(context.Background(), nil, group) + _, _ = r.ResolveGroups(context.Background(), nil, group) } // Load aliases second because that is realistic and yet we want to report on @@ -148,7 +148,7 @@ end of identity duplicate report, refer to https://developer.hashicorp.com/vault // Parse our hacky DSL to define some alias mounts as local Local: strings.HasPrefix(pair[0], "local-"), } - _ = r.ResolveAliases(context.Background(), entity, nil, alias) + _, _ = r.ResolveAliases(context.Background(), entity, nil, alias) } // "log" the report and check it matches expected report below. @@ -220,13 +220,15 @@ func TestDuplicateRenameResolver(t *testing.T) { // Simulate a MemDB lookup existingEntity := seenEntities[name] - err := r.ResolveEntities(context.Background(), existingEntity, entity) + renamed, err := r.ResolveEntities(context.Background(), existingEntity, entity) require.NoError(t, err) if existingEntity != nil { + require.True(t, renamed) require.Equal(t, name+"-"+id, entity.Name) require.Equal(t, existingEntity.ID, entity.Metadata["duplicate_of_canonical_id"]) } else { + require.False(t, renamed) seenEntities[name] = entity } @@ -239,13 +241,15 @@ func TestDuplicateRenameResolver(t *testing.T) { // More MemDB mocking existingGroup := seenGroups[name] - err = r.ResolveGroups(context.Background(), existingGroup, group) + renamed, err = r.ResolveGroups(context.Background(), existingGroup, group) require.NoError(t, err) if existingGroup != nil { + require.True(t, renamed) require.Equal(t, name+"-"+id, group.Name) require.Equal(t, existingGroup.ID, group.Metadata["duplicate_of_canonical_id"]) } else { + require.False(t, renamed) seenGroups[name] = group } } diff --git a/vault/identity_store_test.go b/vault/identity_store_test.go index ed3f23580e..2eee174ba2 100644 --- a/vault/identity_store_test.go +++ b/vault/identity_store_test.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "math/rand" + "os" "regexp" "slices" "strconv" @@ -19,7 +20,6 @@ import ( "github.com/hashicorp/go-hclog" uuid "github.com/hashicorp/go-uuid" credGithub "github.com/hashicorp/vault/builtin/credential/github" - "github.com/hashicorp/vault/builtin/credential/userpass" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/helper/activationflags" "github.com/hashicorp/vault/helper/identity" @@ -1416,12 +1416,65 @@ func TestIdentityStoreInvalidate_TemporaryEntity(t *testing.T) { // the identity cleanup rename resolver to ensure that loading is deterministic // for both. func TestIdentityStoreLoadingIsDeterministic(t *testing.T) { - t.Run(t.Name()+"-error-resolver", func(t *testing.T) { - identityStoreLoadingIsDeterministic(t, false) - }) - t.Run(t.Name()+"-identity-cleanup", func(t *testing.T) { - identityStoreLoadingIsDeterministic(t, true) - }) + seedval := rand.Int63() + if os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED") != "" { + var err error + seedval, err = strconv.ParseInt(os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED"), 10, 64) + require.NoError(t, err) + } + seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test + defer t.Logf("Test generated with seed: %d", seedval) + + tests := []struct { + name string + flags *determinismTestFlags + }{ + { + name: "error-resolver-primary", + flags: &determinismTestFlags{ + identityDeduplication: false, + secondary: false, + seed: seed, + }, + }, + { + name: "identity-cleanup-primary", + flags: &determinismTestFlags{ + identityDeduplication: true, + secondary: false, + seed: seed, + }, + }, + + { + name: "error-resolver-secondary", + flags: &determinismTestFlags{ + identityDeduplication: false, + secondary: true, + seed: seed, + }, + }, + { + name: "identity-cleanup-secondary", + flags: &determinismTestFlags{ + identityDeduplication: true, + secondary: true, + seed: seed, + }, + }, + } + + for _, test := range tests { + t.Run(t.Name()+"-"+test.name, func(t *testing.T) { + identityStoreLoadingIsDeterministic(t, test.flags) + }) + } +} + +type determinismTestFlags struct { + identityDeduplication bool + secondary bool + seed *rand.Rand } // identityStoreLoadingIsDeterministic is a property-based test helper that @@ -1432,7 +1485,7 @@ func TestIdentityStoreLoadingIsDeterministic(t *testing.T) { // deterministic anyway if all data in storage was correct see comments inline // for examples of ways storage can be corrupt with respect to the expected // schema invariants. -func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication bool) { +func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFlags) { // Create some state in store that could trigger non-deterministic behavior. // The nature of the identity store schema is such that the order of loading // entities etc shouldn't matter even if it was non-deterministic, however due @@ -1456,7 +1509,7 @@ func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication boo Logger: logger, BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(), CredentialBackends: map[string]logical.Factory{ - "userpass": userpass.Factory, + "userpass": credUserpass.Factory, }, } @@ -1470,6 +1523,10 @@ func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication boo ctx := context.Background() + seed := flags.seed + identityDeduplication := flags.identityDeduplication + secondary := flags.secondary + // We create 100 entities each with 1 non-local alias and 1 local alias. We // then randomly create duplicate alias or local alias entries with a // probability that is unrealistic but ensures we have duplicates on every @@ -1478,9 +1535,9 @@ func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication boo name := fmt.Sprintf("entity-%d", i) alias := fmt.Sprintf("alias-%d", i) localAlias := fmt.Sprintf("localalias-%d", i) - e := makeEntityForPacker(t, name, c.identityStore.entityPacker) - attachAlias(t, e, alias, upme) - attachAlias(t, e, localAlias, localMe) + e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed) + attachAlias(t, e, alias, upme, seed) + attachAlias(t, e, localAlias, localMe, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) require.NoError(t, err) @@ -1489,35 +1546,35 @@ func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication boo // few double and maybe triple duplicates of each type every few test runs // and may have duplicates of both types or neither etc. pDup := 0.3 - rnd := rand.Float64() + rnd := seed.Float64() dupeNum := 1 for rnd < pDup && dupeNum < 10 { - e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-dup-%d", i, dupeNum), c.identityStore.entityPacker) - attachAlias(t, e, alias, upme) + e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-dup-%d", i, dupeNum), c.identityStore.entityPacker, seed) + attachAlias(t, e, alias, upme, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) require.NoError(t, err) // Toss again to see if we continue - rnd = rand.Float64() + rnd = seed.Float64() dupeNum++ } // Toss the coin again to see if there are any local dupes dupeNum = 1 - rnd = rand.Float64() + rnd = seed.Float64() for rnd < pDup && dupeNum < 10 { - e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-localdup-%d", i, dupeNum), c.identityStore.entityPacker) - attachAlias(t, e, localAlias, localMe) + e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-localdup-%d", i, dupeNum), c.identityStore.entityPacker, seed) + attachAlias(t, e, localAlias, localMe, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) require.NoError(t, err) - rnd = rand.Float64() + rnd = seed.Float64() dupeNum++ } // See if we should add entity _name_ duplicates too (with no aliases) - rnd = rand.Float64() + rnd = seed.Float64() for rnd < pDup { - e := makeEntityForPacker(t, name, c.identityStore.entityPacker) + e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) require.NoError(t, err) - rnd = rand.Float64() + rnd = seed.Float64() } // One more edge case is that it's currently possible as of the time of // writing for a failure during entity invalidation to result in a permanent @@ -1543,7 +1600,7 @@ func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication boo if i%2 == 0 { alias = fmt.Sprintf("groupalias-%d", i) } - e := makeGroupWithNameAndAlias(t, name, alias, c.identityStore.groupPacker, upme) + e := makeGroupWithNameAndAlias(t, name, alias, c.identityStore.groupPacker, upme, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.groupPacker, e.ID, e) require.NoError(t, err) } @@ -1551,18 +1608,20 @@ func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication boo // non-deterministic behavior. for i := 0; i <= 10; i++ { name := fmt.Sprintf("group-dup-%d", i) - e := makeGroupWithNameAndAlias(t, name, "groupalias-dup", c.identityStore.groupPacker, upme) + e := makeGroupWithNameAndAlias(t, name, "groupalias-dup", c.identityStore.groupPacker, upme, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.groupPacker, e.ID, e) require.NoError(t, err) } // Add a second and third groups with duplicate names too. for _, name := range []string{"group-0", "group-1", "group-1"} { - e := makeGroupWithNameAndAlias(t, name, "", c.identityStore.groupPacker, upme) + e := makeGroupWithNameAndAlias(t, name, "", c.identityStore.groupPacker, upme, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.groupPacker, e.ID, e) require.NoError(t, err) } - entIdentityStoreDeterminismTestSetup(t, ctx, c, upme, localMe) + if secondary { + entIdentityStoreDeterminismSecondaryTestSetup(t, ctx, c, upme, localMe, seed) + } // Storage is now primed for the test. @@ -1599,7 +1658,7 @@ func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication boo err := c.identityStore.resetDB() require.NoError(t, err) - err = c.identityStore.loadArtifacts(ctx) + err = c.identityStore.loadArtifacts(ctx, true) if i > 0 { require.Equal(t, prevErr, err) } @@ -1652,7 +1711,9 @@ func identityStoreLoadingIsDeterministic(t *testing.T, identityDeduplication boo // note `lastIDs` argument is not needed anymore but we can't change the // signature without breaking enterprise. It's simpler to keep it unused // for now until both parts of this merge. - entIdentityStoreDeterminismAssert(t, i, loadedNames, nil) + if secondary { + entIdentityStoreDeterminismSecondaryAssert(t, i, loadedNames, nil) + } if i > 0 { // Should be in the same order if we are deterministic since MemDB has strong ordering. @@ -1676,7 +1737,7 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) { Logger: logger, BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(), CredentialBackends: map[string]logical.Factory{ - "userpass": userpass.Factory, + "userpass": credUserpass.Factory, }, } @@ -1690,9 +1751,16 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) { ctx := namespace.RootContext(nil) - identityCreateCaseDuplicates(t, ctx, c, upme, localMe) + seedval := rand.Int63() + if os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED") != "" { + seedval, err = strconv.ParseInt(os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED"), 10, 64) + require.NoError(t, err) + } + seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test + defer t.Logf("Test generated with seed %d", seedval) + identityCreateCaseDuplicates(t, ctx, c, upme, localMe, seed) - entIdentityStoreDuplicateReportTestSetup(t, ctx, c, rootToken) + entIdentityStoreDuplicateReportTestSetup(t, ctx, c, rootToken, seed) // Storage is now primed for the test. @@ -1711,7 +1779,7 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) { } logger.RegisterSink(unsealLogger) - err = c.identityStore.loadArtifacts(ctx) + err = c.identityStore.loadArtifacts(ctx, true) require.NoError(t, err) logger.DeregisterSink(unsealLogger) diff --git a/vault/identity_store_test_stubs_oss.go b/vault/identity_store_test_stubs_oss.go index 821b62428c..02a78851a6 100644 --- a/vault/identity_store_test_stubs_oss.go +++ b/vault/identity_store_test_stubs_oss.go @@ -7,20 +7,21 @@ package vault import ( "context" + "math/rand" "testing" ) //go:generate go run github.com/hashicorp/vault/tools/stubmaker -func entIdentityStoreDeterminismTestSetup(t *testing.T, ctx context.Context, c *Core, me, localme *MountEntry) { +func entIdentityStoreDeterminismSecondaryTestSetup(t *testing.T, ctx context.Context, c *Core, me, localme *MountEntry, seed *rand.Rand) { // no op } -func entIdentityStoreDeterminismAssert(t *testing.T, i int, loadedIDs, lastIDs []string) { +func entIdentityStoreDeterminismSecondaryAssert(t *testing.T, i int, loadedIDs, lastIDs []string) { // no op } -func entIdentityStoreDuplicateReportTestSetup(t *testing.T, ctx context.Context, c *Core, rootToken string) { +func entIdentityStoreDuplicateReportTestSetup(t *testing.T, ctx context.Context, c *Core, rootToken string, seed *rand.Rand) { // no op } diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 1b08540db6..39e203b45a 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "math/rand" "strings" "sync" "testing" @@ -38,7 +39,7 @@ var ( // loadArtifacts is responsible for loading entities, groups, and aliases from // storage into MemDB. -func (i *IdentityStore) loadArtifacts(ctx context.Context) error { +func (i *IdentityStore) loadArtifacts(ctx context.Context, isActive bool) error { if i == nil { return nil } @@ -48,10 +49,10 @@ func (i *IdentityStore) loadArtifacts(ctx context.Context) error { "case_sensitive", !i.disableLowerCasedNames, "conflict_resolver", i.conflictResolver) - if err := i.loadEntities(ctx); err != nil { + if err := i.loadEntities(ctx, isActive); err != nil { return fmt.Errorf("failed to load entities: %w", err) } - if err := i.loadGroups(ctx); err != nil { + if err := i.loadGroups(ctx, isActive); err != nil { return fmt.Errorf("failed to load groups: %w", err) } if err := i.loadOIDCClients(ctx); err != nil { @@ -144,7 +145,7 @@ func (i *IdentityStore) activateDeduplication(ctx context.Context, req *logical. return fmt.Errorf("failed to reset existing identity state: %w", err) } - if err := i.loadArtifacts(ctx); err != nil { + if err := i.loadArtifacts(ctx, i.localNode.HAState() == consts.Active); err != nil { return fmt.Errorf("failed to activate identity deduplication: %w", err) } @@ -159,7 +160,7 @@ func (i *IdentityStore) sanitizeName(name string) string { return strings.ToLower(name) } -func (i *IdentityStore) loadGroups(ctx context.Context) error { +func (i *IdentityStore) loadGroups(ctx context.Context, isActive bool) error { i.logger.Debug("identity loading groups") existing, err := i.groupPacker.View().List(ctx, groupBucketsPrefix) if err != nil { @@ -210,9 +211,31 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error { if err != nil { return err } - if err := i.conflictResolver.ResolveGroups(ctx, groupByName, group); err != nil && !i.disableLowerCasedNames { + modified, err := i.conflictResolver.ResolveGroups(ctx, groupByName, group) + if err != nil && !i.disableLowerCasedNames { return err } + persist := false + if modified { + // If we modified the group we need to persist the changes to avoid bugs + // where memDB and storage are out of sync in the future (e.g. after + // invalidations of other items in the same bucket later). We do this + // _even_ if `persist=false` because it is in general during unseal but + // this is exactly when we need to fix these. We must be _really_ + // careful to only do this on primary active node though which is the + // only source of truth that should have write access to groups across a + // cluster since they are always non-local. Note that we check !Standby + // and !secondary because we still need to write back even if this is a + // single cluster with no replication setup and I'm not _sure_ that we + // report such a cluster as a primary. + if !i.localNode.ReplicationState().HasState( + consts.ReplicationDRSecondary| + consts.ReplicationPerformanceSecondary| + consts.ReplicationPerformanceStandby, + ) && isActive { + persist = true + } + } if i.logger.IsDebug() { i.logger.Debug("loading group", "namespace", ns.ID, "name", group.Name, "id", group.ID) @@ -224,7 +247,6 @@ func (i *IdentityStore) loadGroups(ctx context.Context) error { // updated when respective entities were deleted. This is here to // check that the entity IDs in the group are indeed valid, and if // not remove them. - persist := false for _, memberEntityID := range group.MemberEntityIDs { entity, err := i.MemDBEntityByID(memberEntityID, false) if err != nil { @@ -403,7 +425,7 @@ func (i *IdentityStore) loadCachedEntitiesOfLocalAliases(ctx context.Context) er return nil } -func (i *IdentityStore) loadEntities(ctx context.Context) error { +func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool) error { // Accumulate existing entities i.logger.Debug("loading entities") existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix) @@ -544,9 +566,31 @@ LOOP: if err != nil { return nil } - if err := i.conflictResolver.ResolveEntities(ctx, entityByName, entity); err != nil && !i.disableLowerCasedNames { + modified, err := i.conflictResolver.ResolveEntities(ctx, entityByName, entity) + if err != nil && !i.disableLowerCasedNames { return err } + persist := false + if modified { + // If we modified the group we need to persist the changes to avoid bugs + // where memDB and storage are out of sync in the future (e.g. after + // invalidations of other items in the same bucket later). We do this + // _even_ if `persist=false` because it is in general during unseal but + // this is exactly when we need to fix these. We must be _really_ + // careful to only do this on primary active node though which is the + // only source of truth that should have write access to groups across a + // cluster since they are always non-local. Note that we check !Stadby + // and !secondary because we still need to write back even if this is a + // single cluster with no replication setup and I'm not _sure_ that we + // report such a cluster as a primary. + if !i.localNode.ReplicationState().HasState( + consts.ReplicationDRSecondary| + consts.ReplicationPerformanceSecondary| + consts.ReplicationPerformanceStandby, + ) && isActive { + persist = true + } + } mountAccessors := getAccessorsOnDuplicateAliases(entity.Aliases) @@ -573,7 +617,7 @@ LOOP: defer tx.Abort() } // Only update MemDB and don't hit the storage again - err = i.upsertEntityInTxn(nsCtx, tx, entity, nil, false) + err = i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist) if err != nil { return fmt.Errorf("failed to update entity in MemDB: %w", err) } @@ -780,8 +824,9 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e // problem to the user and are already logged. We care about different-case // duplicates that are not being considered duplicates right now because we // are in case-sensitive mode so we can report these to the operator ahead - // of them disabling case-sensitive mode. - conflictErr := i.conflictResolver.ResolveAliases(ctx, entity, aliasByFactors, alias) + // of them disabling case-sensitive mode. Note that alias resolvers don't + // ever modify right now so ignore the bool. + _, conflictErr := i.conflictResolver.ResolveAliases(ctx, entity, aliasByFactors, alias) // This appears to be accounting for any duplicate aliases for the same // Entity. In that case we would have skipped over the merge above in the @@ -2783,14 +2828,14 @@ func (i *IdentityStore) countEntitiesByMountAccessor(ctx context.Context) (map[s return byMountAccessor, nil } -func makeEntityForPacker(t *testing.T, name string, p *storagepacker.StoragePacker) *identity.Entity { +func makeEntityForPacker(t *testing.T, name string, p *storagepacker.StoragePacker, seed *rand.Rand) *identity.Entity { t.Helper() - return makeEntityForPackerWithNamespace(t, namespace.RootNamespaceID, name, p) + return makeEntityForPackerWithNamespace(t, namespace.RootNamespaceID, name, p, seed) } -func makeEntityForPackerWithNamespace(t *testing.T, namespaceID, name string, p *storagepacker.StoragePacker) *identity.Entity { +func makeEntityForPackerWithNamespace(t *testing.T, namespaceID, name string, p *storagepacker.StoragePacker, seed *rand.Rand) *identity.Entity { t.Helper() - id, err := uuid.GenerateUUID() + id, err := uuid.GenerateUUIDWithReader(seed) require.NoError(t, err) return &identity.Entity{ ID: id, @@ -2800,9 +2845,10 @@ func makeEntityForPackerWithNamespace(t *testing.T, namespaceID, name string, p } } -func attachAlias(t *testing.T, e *identity.Entity, name string, me *MountEntry) *identity.Alias { +func attachAlias(t *testing.T, e *identity.Entity, name string, me *MountEntry, seed *rand.Rand) *identity.Alias { t.Helper() - id, err := uuid.GenerateUUID() + + id, err := uuid.GenerateUUIDWithReader(seed) require.NoError(t, err) if e.NamespaceID != me.NamespaceID { panic("mount and entity in different namespaces") @@ -2821,7 +2867,7 @@ func attachAlias(t *testing.T, e *identity.Entity, name string, me *MountEntry) return a } -func identityCreateCaseDuplicates(t *testing.T, ctx context.Context, c *Core, upme, localme *MountEntry) { +func identityCreateCaseDuplicates(t *testing.T, ctx context.Context, c *Core, upme, localme *MountEntry, seed *rand.Rand) { t.Helper() if upme.NamespaceID != localme.NamespaceID { @@ -2832,31 +2878,31 @@ func identityCreateCaseDuplicates(t *testing.T, ctx context.Context, c *Core, up // suffixes. for i, suffix := range []string{"-case", "-case", "-cAsE"} { // Entity duplicated by name - e := makeEntityForPackerWithNamespace(t, upme.NamespaceID, "entity"+suffix, c.identityStore.entityPacker) + e := makeEntityForPackerWithNamespace(t, upme.NamespaceID, "entity"+suffix, c.identityStore.entityPacker, seed) err := TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) require.NoError(t, err) // Entity that isn't a dupe itself but has duplicated aliases - e2 := makeEntityForPackerWithNamespace(t, upme.NamespaceID, fmt.Sprintf("entity-%d", i), c.identityStore.entityPacker) + e2 := makeEntityForPackerWithNamespace(t, upme.NamespaceID, fmt.Sprintf("entity-%d", i), c.identityStore.entityPacker, seed) // Add local and non-local aliases for this entity (which will also be // duplicated) - attachAlias(t, e2, "alias"+suffix, upme) - attachAlias(t, e2, "local-alias"+suffix, localme) + attachAlias(t, e2, "alias"+suffix, upme, seed) + attachAlias(t, e2, "local-alias"+suffix, localme, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e2.ID, e2) require.NoError(t, err) // Group duplicated by name - g := makeGroupWithNameAndAlias(t, "group"+suffix, "", c.identityStore.groupPacker, upme) + g := makeGroupWithNameAndAlias(t, "group"+suffix, "", c.identityStore.groupPacker, upme, seed) err = TestHelperWriteToStoragePacker(ctx, c.identityStore.groupPacker, g.ID, g) require.NoError(t, err) } } -func makeGroupWithNameAndAlias(t *testing.T, name, alias string, p *storagepacker.StoragePacker, me *MountEntry) *identity.Group { +func makeGroupWithNameAndAlias(t *testing.T, name, alias string, p *storagepacker.StoragePacker, me *MountEntry, seed *rand.Rand) *identity.Group { t.Helper() - id, err := uuid.GenerateUUID() + id, err := uuid.GenerateUUIDWithReader(seed) require.NoError(t, err) - id2, err := uuid.GenerateUUID() + id2, err := uuid.GenerateUUIDWithReader(seed) require.NoError(t, err) g := &identity.Group{ ID: id, @@ -2877,9 +2923,9 @@ func makeGroupWithNameAndAlias(t *testing.T, name, alias string, p *storagepacke return g } -func makeLocalAliasWithName(t *testing.T, name, entityID string, bucketKey string, me *MountEntry) *identity.LocalAliases { +func makeLocalAliasWithName(t *testing.T, name, entityID string, bucketKey string, me *MountEntry, seed *rand.Rand) *identity.LocalAliases { t.Helper() - id, err := uuid.GenerateUUID() + id, err := uuid.GenerateUUIDWithReader(seed) require.NoError(t, err) return &identity.LocalAliases{ Aliases: []*identity.Alias{