diff --git a/builtin/logical/pki/acme_billing_test.go b/builtin/logical/pki/acme_billing_test.go index f8db67e644..b1948d7be2 100644 --- a/builtin/logical/pki/acme_billing_test.go +++ b/builtin/logical/pki/acme_billing_test.go @@ -104,17 +104,15 @@ func TestACMEBilling(t *testing.T) { expectedCount = validateClientCount(t, client, "ns2/pki", expectedCount+1, "unique identifier in a different namespace") // Check the current fragment - localFragment, globalFragment := cluster.Cores[0].Core.ResetActivityLog() - if globalFragment == nil || localFragment == nil { + fragment := cluster.Cores[0].Core.ResetActivityLog()[0] + if fragment == nil { t.Fatal("no fragment created") } - validateAcmeClientTypes(t, localFragment[0], 0) - validateAcmeClientTypes(t, globalFragment[0], expectedCount) + validateAcmeClientTypes(t, fragment, expectedCount) } func validateAcmeClientTypes(t *testing.T, fragment *activity.LogFragment, expectedCount int64) { t.Helper() - if int64(len(fragment.Clients)) != expectedCount { t.Fatalf("bad number of entities, expected %v: got %v, entities are: %v", expectedCount, len(fragment.Clients), fragment.Clients) } diff --git a/command/command_testonly/operator_usage_testonly_test.go b/command/command_testonly/operator_usage_testonly_test.go index 4cdfc0536a..74d67291fd 100644 --- a/command/command_testonly/operator_usage_testonly_test.go +++ b/command/command_testonly/operator_usage_testonly_test.go @@ -53,7 +53,7 @@ func TestOperatorUsageCommandRun(t *testing.T) { now := time.Now().UTC() - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(6, clientcountutil.WithClientType("entity")). NewClientsSeen(4, clientcountutil.WithClientType("non-entity-token")). diff --git a/sdk/helper/clientcountutil/clientcountutil.go b/sdk/helper/clientcountutil/clientcountutil.go index 85b25dab43..d09c5be13d 100644 --- a/sdk/helper/clientcountutil/clientcountutil.go +++ b/sdk/helper/clientcountutil/clientcountutil.go @@ -280,30 +280,39 @@ func (d *ActivityLogDataGenerator) ToProto() *generation.ActivityLogMockInput { } // Write writes the data to the API with the given write options. The method -// returns the new local and global paths that have been written. Note that the API endpoint will +// returns the new paths that have been written. Note that the API endpoint will // only be present when Vault has been compiled with the "testonly" flag. -func (d *ActivityLogDataGenerator) Write(ctx context.Context, writeOptions ...generation.WriteOptions) ([]string, []string, error) { +func (d *ActivityLogDataGenerator) Write(ctx context.Context, writeOptions ...generation.WriteOptions) ([]string, []string, []string, error) { d.data.Write = writeOptions err := VerifyInput(d.data) if err != nil { - return nil, nil, err + return nil, nil, nil, err } data, err := d.ToJSON() if err != nil { - return nil, nil, err + return nil, nil, nil, err } resp, err := d.client.Logical().WriteWithContext(ctx, "sys/internal/counters/activity/write", map[string]interface{}{"input": string(data)}) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if resp.Data == nil { - return nil, nil, fmt.Errorf("received no data") + return nil, nil, nil, fmt.Errorf("received no data") + } + paths := resp.Data["paths"] + castedPaths, ok := paths.([]interface{}) + if !ok { + return nil, nil, nil, fmt.Errorf("invalid paths data: %v", paths) + } + returnPaths := make([]string, 0, len(castedPaths)) + for _, path := range castedPaths { + returnPaths = append(returnPaths, path.(string)) } localPaths := resp.Data["local_paths"] localCastedPaths, ok := localPaths.([]interface{}) if !ok { - return nil, nil, fmt.Errorf("invalid local paths data: %v", localPaths) + return nil, nil, nil, fmt.Errorf("invalid local paths data: %v", localPaths) } returnLocalPaths := make([]string, 0, len(localCastedPaths)) for _, path := range localCastedPaths { @@ -313,13 +322,13 @@ func (d *ActivityLogDataGenerator) Write(ctx context.Context, writeOptions ...ge globalPaths := resp.Data["global_paths"] globalCastedPaths, ok := globalPaths.([]interface{}) if !ok { - return nil, nil, fmt.Errorf("invalid global paths data: %v", globalPaths) + return nil, nil, nil, fmt.Errorf("invalid global paths data: %v", globalPaths) } returnGlobalPaths := make([]string, 0, len(globalCastedPaths)) for _, path := range globalCastedPaths { returnGlobalPaths = append(returnGlobalPaths, path.(string)) } - return returnLocalPaths, returnGlobalPaths, nil + return returnPaths, returnLocalPaths, returnGlobalPaths, nil } // VerifyInput checks that the input data is valid diff --git a/sdk/helper/clientcountutil/clientcountutil_test.go b/sdk/helper/clientcountutil/clientcountutil_test.go index 6374074365..4ea987fed0 100644 --- a/sdk/helper/clientcountutil/clientcountutil_test.go +++ b/sdk/helper/clientcountutil/clientcountutil_test.go @@ -116,7 +116,7 @@ func TestNewCurrentMonthData_AddClients(t *testing.T) { // sent to the server is correct. func TestWrite(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := io.WriteString(w, `{"data":{"global_paths":["path2","path3"], "local_paths":["path3","path4"]}}`) + _, err := io.WriteString(w, `{"data":{"paths":["path1","path2"],"global_paths":["path2","path3"], "local_paths":["path3","path4"]}}`) require.NoError(t, err) body, err := io.ReadAll(r.Body) require.NoError(t, err) @@ -131,7 +131,7 @@ func TestWrite(t *testing.T) { Address: ts.URL, }) require.NoError(t, err) - localPaths, globalPaths, err := NewActivityLogData(client). + paths, localPaths, globalPaths, err := NewActivityLogData(client). NewPreviousMonthData(3). NewClientSeen(). NewPreviousMonthData(2). @@ -140,6 +140,7 @@ func TestWrite(t *testing.T) { NewCurrentMonthData().Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) require.NoError(t, err) + require.Equal(t, []string{"path1", "path2"}, paths) require.Equal(t, []string{"path2", "path3"}, globalPaths) require.Equal(t, []string{"path3", "path4"}, localPaths) } diff --git a/vault/activity_log.go b/vault/activity_log.go index 757165f3e1..3ad43d31b4 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -51,6 +51,7 @@ const ( distinctClientsBasePath = "log/distinctclients/" // for testing purposes (public as needed) + ActivityLogPrefix = "sys/counters/activity/log/" ActivityGlobalLogPrefix = "sys/counters/activity/global/log/" ActivityLogLocalPrefix = "sys/counters/activity/local/log/" ActivityPrefix = "sys/counters/activity/" @@ -146,7 +147,8 @@ type ActivityLog struct { // Acquire "l" before fragmentLock, globalFragmentLock, and localFragmentLock if all must be held. l sync.RWMutex - // fragmentLock protects enable + // fragmentLock protects enable, partialMonthClientTracker, fragment, + // standbyFragmentsReceived. fragmentLock sync.RWMutex // localFragmentLock protects partialMonthLocalClientTracker, localFragment, @@ -178,6 +180,9 @@ type ActivityLog struct { // could be adapted to use a secondary in the future. nodeID string + // current log fragment (may be nil) + fragment *activity.LogFragment + // Channel to signal a new fragment has been created // so it's appropriate to start the timer. newFragmentCh chan struct{} @@ -205,6 +210,9 @@ type ActivityLog struct { // track metadata and contents of the most recent local log segment currentLocalSegment segmentInfo + // Fragments received from performance standbys + standbyFragmentsReceived []*activity.LogFragment + // Local fragments received from performance standbys standbyLocalFragmentsReceived []*activity.LogFragment @@ -230,6 +238,9 @@ type ActivityLog struct { // for testing: is config currently being invalidated. protected by l configInvalidationInProgress bool + // partialMonthClientTracker tracks active clients this month. Protected by fragmentLock. + partialMonthClientTracker map[string]*activity.EntityRecord + // partialMonthLocalClientTracker tracks active local clients this month. Protected by localFragmentLock. partialMonthLocalClientTracker map[string]*activity.EntityRecord @@ -359,6 +370,7 @@ func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics me newFragmentCh: make(chan struct{}, 1), sendCh: make(chan struct{}, 1), // buffered so it can be triggered by fragment size doneCh: make(chan struct{}, 1), + partialMonthClientTracker: make(map[string]*activity.EntityRecord), partialMonthLocalClientTracker: make(map[string]*activity.EntityRecord), newGlobalClientFragmentCh: make(chan struct{}, 1), globalPartialMonthClientTracker: make(map[string]*activity.EntityRecord), @@ -402,6 +414,7 @@ func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics me }, clientSequenceNumber: 0, }, + standbyFragmentsReceived: make([]*activity.LogFragment, 0), standbyLocalFragmentsReceived: make([]*activity.LogFragment, 0), standbyGlobalFragmentsReceived: make([]*activity.LogFragment, 0), secondaryGlobalClientFragments: make([]*activity.LogFragment, 0), @@ -449,7 +462,14 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for defer a.metrics.MeasureSinceWithLabels([]string{"core", "activity", "segment_write"}, a.clock.Now(), []metricsutil.Label{}) - // Swap out the pending global fragments + // Swap out the pending regular fragments + a.fragmentLock.Lock() + currentFragment := a.fragment + a.fragment = nil + standbys := a.standbyFragmentsReceived + a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) + a.fragmentLock.Unlock() + a.globalFragmentLock.Lock() secondaryGlobalClients := a.secondaryGlobalClientFragments a.secondaryGlobalClientFragments = make([]*activity.LogFragment, 0) @@ -485,10 +505,14 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for // If segment start time is zero, do not update or write // (even if force is true). This can happen if activityLog is // disabled after a save as been triggered. - if a.currentGlobalSegment.startTimestamp == 0 { + if a.currentSegment.startTimestamp == 0 { return nil } + if ret := a.createCurrentSegmentFromFragments(ctx, append(standbys, currentFragment), &a.currentSegment, force, ""); ret != nil { + return ret + } + // If we are the primary, store global clients // Create fragments from global clients and store the segment if !a.core.IsPerfSecondary() { @@ -551,7 +575,7 @@ func (a *ActivityLog) createCurrentSegmentFromFragments(ctx context.Context, fra // month when the client upgrades to 1.9, we must retain this functionality. for ns, val := range f.NonEntityTokens { // We track these pre-1.9 values in the old location, which is - // currentSegment.tokenCount, as opposed to the counter that stores tokens + // a.currentSegment.tokenCount, as opposed to the counter that stores tokens // without entities that have client IDs, namely // a.partialMonthClientTracker.nonEntityCountByNamespaceID. This preserves backward // compatibility for the precomputedQueryWorkers and the segment storing @@ -712,7 +736,7 @@ func parseSegmentNumberFromPath(path string) (int, bool) { // sorted last to first func (a *ActivityLog) availableLogs(ctx context.Context, upTo time.Time) ([]time.Time, error) { paths := make([]string, 0) - for _, basePath := range []string{activityLocalPathPrefix + activityEntityBasePath, activityGlobalPathPrefix + activityEntityBasePath, activityTokenLocalBasePath} { + for _, basePath := range []string{activityEntityBasePath, activityLocalPathPrefix + activityEntityBasePath, activityGlobalPathPrefix + activityEntityBasePath, activityTokenLocalBasePath} { p, err := a.view.List(ctx, basePath) if err != nil { return nil, err @@ -761,17 +785,21 @@ func (a *ActivityLog) getMostRecentActivityLogSegment(ctx context.Context, now t } // getLastEntitySegmentNumber returns the (non-negative) last segment number for the :startTime:, if it exists -func (a *ActivityLog) getLastEntitySegmentNumber(ctx context.Context, startTime time.Time) (uint64, uint64, bool, error) { +func (a *ActivityLog) getLastEntitySegmentNumber(ctx context.Context, startTime time.Time) (uint64, uint64, uint64, bool, error) { + segmentHighestNum, segmentPresent, err := a.getLastSegmentNumberByEntityPath(ctx, activityEntityBasePath+fmt.Sprint(startTime.Unix())+"/") + if err != nil { + return 0, 0, 0, false, err + } globalHighestNum, globalSegmentPresent, err := a.getLastSegmentNumberByEntityPath(ctx, activityGlobalPathPrefix+activityEntityBasePath+fmt.Sprint(startTime.Unix())+"/") if err != nil { - return 0, 0, false, err + return 0, 0, 0, false, err } localHighestNum, localSegmentPresent, err := a.getLastSegmentNumberByEntityPath(ctx, activityLocalPathPrefix+activityEntityBasePath+fmt.Sprint(startTime.Unix())+"/") if err != nil { - return 0, 0, false, err + return 0, 0, 0, false, err } - return uint64(localHighestNum), uint64(globalHighestNum), (localSegmentPresent || globalSegmentPresent), nil + return segmentHighestNum, uint64(localHighestNum), uint64(globalHighestNum), (segmentPresent || localSegmentPresent || globalSegmentPresent), nil } func (a *ActivityLog) getLastSegmentNumberByEntityPath(ctx context.Context, entityPath string) (uint64, bool, error) { @@ -801,33 +829,30 @@ func (a *ActivityLog) getLastSegmentNumberByEntityPath(ctx context.Context, enti // WalkEntitySegments loads each of the entity segments for a particular start time func (a *ActivityLog) WalkEntitySegments(ctx context.Context, startTime time.Time, hll *hyperloglog.Sketch, walkFn func(*activity.EntityActivityLog, time.Time, *hyperloglog.Sketch) error) error { - baseGlobalPath := activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" - baseLocalPath := activityLocalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + basePath := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + pathList, err := a.view.List(ctx, basePath) + if err != nil { + return err + } - for _, basePath := range []string{baseGlobalPath, baseLocalPath} { - pathList, err := a.view.List(ctx, basePath) + for _, path := range pathList { + raw, err := a.view.Get(ctx, basePath+path) if err != nil { return err } - for _, path := range pathList { - raw, err := a.view.Get(ctx, basePath+path) - if err != nil { - return err - } - if raw == nil { - a.logger.Warn("expected log segment not found", "startTime", startTime, "segment", path) - continue - } + if raw == nil { + a.logger.Warn("expected log segment not found", "startTime", startTime, "segment", path) + continue + } - out := &activity.EntityActivityLog{} - err = proto.Unmarshal(raw.Value, out) - if err != nil { - return fmt.Errorf("unable to parse segment %v%v: %w", basePath, path, err) - } - err = walkFn(out, startTime, hll) - if err != nil { - return fmt.Errorf("unable to walk entities: %w", err) - } + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(raw.Value, out) + if err != nil { + return fmt.Errorf("unable to parse segment %v%v: %w", basePath, path, err) + } + err = walkFn(out, startTime, hll) + if err != nil { + return fmt.Errorf("unable to walk entities: %w", err) } } return nil @@ -864,57 +889,73 @@ func (a *ActivityLog) WalkTokenSegments(ctx context.Context, } // loadPriorEntitySegment populates the in-memory tracker for entity IDs that have -// been active "this month". If the entity segment to load is global, globalPartialMonthClientTracker -// is updated else partialMonthLocalClientTracker gets updated. -func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, startTime time.Time, sequenceNum uint64, isLocal bool) error { - a.l.RLock() - defer a.l.RUnlock() - - // protecting a.enabled - a.fragmentLock.Lock() - defer a.fragmentLock.Unlock() - - // load all the active global clients - if !isLocal { - globalPath := activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) - data, err := a.view.Get(ctx, globalPath) - if err != nil { - return err - } - if data == nil { - return nil - } - out := &activity.EntityActivityLog{} - err = proto.Unmarshal(data.Value, out) - if err != nil { - return err - } - a.globalFragmentLock.Lock() - // Handle the (unlikely) case where the end of the month has been reached while background loading. - // Or the feature has been disabled. - if a.enabled && startTime.Unix() == a.currentGlobalSegment.startTimestamp { - for _, ent := range out.Clients { - a.globalPartialMonthClientTracker[ent.ClientID] = ent - } - } - a.globalFragmentLock.Unlock() - return nil - } - - // load all the active local clients - localPath := activityLocalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) - data, err := a.view.Get(ctx, localPath) +// been active "this month" +func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, startTime time.Time, sequenceNum uint64) error { + path := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) + data, err := a.view.Get(ctx, path) if err != nil { return err } if data == nil { return nil } + out := &activity.EntityActivityLog{} err = proto.Unmarshal(data.Value, out) if err != nil { return err } + + a.l.RLock() + defer a.l.RUnlock() + a.fragmentLock.Lock() + // Handle the (unlikely) case where the end of the month has been reached while background loading. + // Or the feature has been disabled. + if a.enabled && startTime.Unix() == a.currentSegment.startTimestamp { + for _, ent := range out.Clients { + a.partialMonthClientTracker[ent.ClientID] = ent + } + } + a.fragmentLock.Unlock() + + // load all the active global clients + globalPath := activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) + data, err = a.view.Get(ctx, globalPath) + if err != nil { + return err + } + if data == nil { + return nil + } + out = &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } + a.globalFragmentLock.Lock() + // Handle the (unlikely) case where the end of the month has been reached while background loading. + // Or the feature has been disabled. + if a.enabled && startTime.Unix() == a.currentGlobalSegment.startTimestamp { + for _, ent := range out.Clients { + a.globalPartialMonthClientTracker[ent.ClientID] = ent + } + } + a.globalFragmentLock.Unlock() + + // load all the active local clients + localPath := activityLocalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) + data, err = a.view.Get(ctx, localPath) + if err != nil { + return err + } + if data == nil { + return nil + } + out = &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } a.localFragmentLock.Lock() // Handle the (unlikely) case where the end of the month has been reached while background loading. // Or the feature has been disabled. @@ -929,44 +970,75 @@ func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, startTime time } // loadCurrentClientSegment loads the most recent segment (for "this month") -// into memory (to append new entries), and to the globalPartialMonthClientTracker and partialMonthLocalClientTracker to +// into memory (to append new entries), and to the partialMonthClientTracker to // avoid duplication call with fragmentLock, globalFragmentLock, localFragmentLock and l held. -func (a *ActivityLog) loadCurrentClientSegment(ctx context.Context, startTime time.Time, localSegmentSequenceNumber uint64, globalSegmentSequenceNumber uint64) error { - // load current global segment - path := activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(globalSegmentSequenceNumber, 10) - - // setting a.currentSegment timestamp to support upgrades - a.currentSegment.startTimestamp = startTime.Unix() - +func (a *ActivityLog) loadCurrentClientSegment(ctx context.Context, startTime time.Time, sequenceNum uint64, localSegmentSequenceNumber uint64, globalSegmentSequenceNumber uint64) error { + path := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) data, err := a.view.Get(ctx, path) if err != nil { return err } - if data != nil { - out := &activity.EntityActivityLog{} - err = proto.Unmarshal(data.Value, out) - if err != nil { - return err - } + if data == nil { + return nil + } - if !a.core.perfStandby { - a.currentGlobalSegment = segmentInfo{ - startTimestamp: startTime.Unix(), - currentClients: &activity.EntityActivityLog{ - Clients: out.Clients, - }, - tokenCount: &activity.TokenCount{ - CountByNamespaceID: make(map[string]uint64), - }, - clientSequenceNumber: globalSegmentSequenceNumber, - } - } else { - // populate this for edge case checking (if end of month passes while background loading on standby) - a.currentGlobalSegment.startTimestamp = startTime.Unix() + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } + + if !a.core.perfStandby { + a.currentSegment = segmentInfo{ + startTimestamp: startTime.Unix(), + currentClients: &activity.EntityActivityLog{ + Clients: out.Clients, + }, + tokenCount: a.currentSegment.tokenCount, + clientSequenceNumber: sequenceNum, } - for _, client := range out.Clients { - a.globalPartialMonthClientTracker[client.ClientID] = client + } else { + // populate this for edge case checking (if end of month passes while background loading on standby) + a.currentSegment.startTimestamp = startTime.Unix() + } + + for _, client := range out.Clients { + a.partialMonthClientTracker[client.ClientID] = client + } + + // load current global segment + path = activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(globalSegmentSequenceNumber, 10) + data, err = a.view.Get(ctx, path) + if err != nil { + return err + } + if data == nil { + return nil + } + + out = &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } + + if !a.core.perfStandby { + a.currentGlobalSegment = segmentInfo{ + startTimestamp: startTime.Unix(), + currentClients: &activity.EntityActivityLog{ + Clients: out.Clients, + }, + tokenCount: &activity.TokenCount{ + CountByNamespaceID: make(map[string]uint64), + }, + clientSequenceNumber: sequenceNum, } + } else { + // populate this for edge case checking (if end of month passes while background loading on standby) + a.currentGlobalSegment.startTimestamp = startTime.Unix() + } + for _, client := range out.Clients { + a.globalPartialMonthClientTracker[client.ClientID] = client } // load current local segment @@ -975,30 +1047,31 @@ func (a *ActivityLog) loadCurrentClientSegment(ctx context.Context, startTime ti if err != nil { return err } - if data != nil { - out := &activity.EntityActivityLog{} - err = proto.Unmarshal(data.Value, out) - if err != nil { - return err - } + if data == nil { + return nil + } - if !a.core.perfStandby { - a.currentLocalSegment = segmentInfo{ - startTimestamp: startTime.Unix(), - currentClients: &activity.EntityActivityLog{ - Clients: out.Clients, - }, - tokenCount: a.currentLocalSegment.tokenCount, - clientSequenceNumber: localSegmentSequenceNumber, - } - } else { - // populate this for edge case checking (if end of month passes while background loading on standby) - a.currentLocalSegment.startTimestamp = startTime.Unix() - } - for _, client := range out.Clients { - a.partialMonthLocalClientTracker[client.ClientID] = client - } + out = &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } + if !a.core.perfStandby { + a.currentLocalSegment = segmentInfo{ + startTimestamp: startTime.Unix(), + currentClients: &activity.EntityActivityLog{ + Clients: out.Clients, + }, + tokenCount: a.currentLocalSegment.tokenCount, + clientSequenceNumber: sequenceNum, + } + } else { + // populate this for edge case checking (if end of month passes while background loading on standby) + a.currentLocalSegment.startTimestamp = startTime.Unix() + } + for _, client := range out.Clients { + a.partialMonthLocalClientTracker[client.ClientID] = client } return nil @@ -1055,15 +1128,14 @@ func (a *ActivityLog) loadTokenCount(ctx context.Context, startTime time.Time) e // We must load the tokenCount of the current segment into the activity log // so that TWEs counted before the introduction of a client ID for TWEs are // still reported in the partial client counts. + a.currentSegment.tokenCount = out a.currentLocalSegment.tokenCount = out return nil } -// entityBackgroundLoader loads entity activity log records for start_date `t`. -// If isLocal is true, it loads the local entity activity log records else it -// loads global entity activity log records. -func (a *ActivityLog) entityBackgroundLoader(ctx context.Context, wg *sync.WaitGroup, t time.Time, seqNums <-chan uint64, isLocal bool) { +// entityBackgroundLoader loads entity activity log records for start_date `t` +func (a *ActivityLog) entityBackgroundLoader(ctx context.Context, wg *sync.WaitGroup, t time.Time, seqNums <-chan uint64) { defer wg.Done() for seqNum := range seqNums { select { @@ -1073,7 +1145,7 @@ func (a *ActivityLog) entityBackgroundLoader(ctx context.Context, wg *sync.WaitG default: } - err := a.loadPriorEntitySegment(ctx, t, seqNum, isLocal) + err := a.loadPriorEntitySegment(ctx, t, seqNum) if err != nil { a.logger.Error("error loading entity activity log", "time", t, "sequence", seqNum, "err", err) } @@ -1097,7 +1169,7 @@ func (a *ActivityLog) newMonthCurrentLogLocked(currentTime time.Time) { } // Initialize a new current segment, based on the given time -// should be called with globalFragmentLock, localFragmentLock and l held. +// should be called with fragmentLock, globalFragmentLock, localFragmentLock and l held. func (a *ActivityLog) newSegmentAtGivenTime(t time.Time) { timestamp := t.Unix() @@ -1110,17 +1182,26 @@ func (a *ActivityLog) newSegmentAtGivenTime(t time.Time) { // should be called with l held. func (a *ActivityLog) setCurrentSegmentTimeLocked(t time.Time) { timestamp := t.Unix() + a.currentSegment.startTimestamp = timestamp a.currentGlobalSegment.startTimestamp = timestamp a.currentLocalSegment.startTimestamp = timestamp - // setting a.currentSegment timestamp to support upgrades - a.currentSegment.startTimestamp = timestamp } // Reset all the current segment state. -// Should be called with globalFragmentLock, localFragmentLock and l held. +// Should be called with fragmentLock, globalFragmentLock, localFragmentLock and l held. func (a *ActivityLog) resetCurrentLog() { - // setting a.currentSegment timestamp to support upgrades a.currentSegment.startTimestamp = 0 + a.currentSegment.currentClients = &activity.EntityActivityLog{ + Clients: make([]*activity.EntityRecord, 0), + } + + // We must still initialize the tokenCount to recieve tokenCounts from fragments + // during the month where customers upgrade to 1.9 + a.currentSegment.tokenCount = &activity.TokenCount{ + CountByNamespaceID: make(map[string]uint64), + } + + a.currentSegment.clientSequenceNumber = 0 // global segment a.currentGlobalSegment.startTimestamp = 0 @@ -1136,12 +1217,16 @@ func (a *ActivityLog) resetCurrentLog() { } a.currentLocalSegment.clientSequenceNumber = 0 + a.fragment = nil + a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) + a.currentGlobalFragment = nil a.globalPartialMonthClientTracker = make(map[string]*activity.EntityRecord) a.localFragment = nil a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) + a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) a.standbyLocalFragmentsReceived = make([]*activity.LogFragment, 0) a.standbyGlobalFragmentsReceived = make([]*activity.LogFragment, 0) a.secondaryGlobalClientFragments = make([]*activity.LogFragment, 0) @@ -1149,6 +1234,7 @@ func (a *ActivityLog) resetCurrentLog() { func (a *ActivityLog) deleteLogWorker(ctx context.Context, startTimestamp int64, whenDone chan struct{}) { entityPathsToDelete := make([]string, 0) + entityPathsToDelete = append(entityPathsToDelete, fmt.Sprintf("%v%v/", activityEntityBasePath, startTimestamp)) entityPathsToDelete = append(entityPathsToDelete, fmt.Sprintf("%s%v%v/", activityGlobalPathPrefix, activityEntityBasePath, startTimestamp)) entityPathsToDelete = append(entityPathsToDelete, fmt.Sprintf("%s%v%v/", activityLocalPathPrefix, activityEntityBasePath, startTimestamp)) entityPathsToDelete = append(entityPathsToDelete, fmt.Sprintf("%v%v/", activityTokenLocalBasePath, startTimestamp)) @@ -1264,7 +1350,7 @@ func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGro } // load entity logs from storage into memory - localLastSegment, globalLastSegment, segmentsExist, err := a.getLastEntitySegmentNumber(ctx, mostRecent) + lastSegment, localLastSegment, globalLastSegment, segmentsExist, err := a.getLastEntitySegmentNumber(ctx, mostRecent) if err != nil { return err } @@ -1273,39 +1359,20 @@ func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGro return nil } - err = a.loadCurrentClientSegment(ctx, mostRecent, localLastSegment, globalLastSegment) - // if both localLastSegment and globalLastSegment are 0, it will return nil here - if err != nil || (localLastSegment == 0 && globalLastSegment == 0) { + err = a.loadCurrentClientSegment(ctx, mostRecent, lastSegment, localLastSegment, globalLastSegment) + if err != nil || lastSegment == 0 { return err } + lastSegment-- - // if last local segment that got loaded using loadCurrentClientSegment is not 0, there are more local segments to load - if localLastSegment != 0 { - localLastSegment-- + seqNums := make(chan uint64, lastSegment+1) + wg.Add(1) + go a.entityBackgroundLoader(ctx, wg, mostRecent, seqNums) - localSeqNums := make(chan uint64, localLastSegment+1) - wg.Add(1) - go a.entityBackgroundLoader(ctx, wg, mostRecent, localSeqNums, true) - - for n := int(localLastSegment); n >= 0; n-- { - localSeqNums <- uint64(n) - } - close(localSeqNums) - } - - // if last global segment that got loaded using loadCurrentClientSegment is not 0, there are more global segments to load - if globalLastSegment != 0 { - globalLastSegment-- - - globalSeqNums := make(chan uint64, globalLastSegment+1) - wg.Add(1) - go a.entityBackgroundLoader(ctx, wg, mostRecent, globalSeqNums, false) - - for n := int(globalLastSegment); n >= 0; n-- { - globalSeqNums <- uint64(n) - } - close(globalSeqNums) + for n := int(lastSegment); n >= 0; n-- { + seqNums <- uint64(n) } + close(seqNums) return nil } @@ -1358,16 +1425,16 @@ func (a *ActivityLog) SetConfig(ctx context.Context, config activityConfig) { a.logger.Info("activity log enable changed", "original", originalEnabled, "current", a.enabled) } - if !a.enabled && a.currentGlobalSegment.startTimestamp != 0 && a.currentLocalSegment.startTimestamp != 0 { + if !a.enabled && a.currentSegment.startTimestamp != 0 && a.currentGlobalSegment.startTimestamp != 0 && a.currentLocalSegment.startTimestamp != 0 { a.logger.Trace("deleting current segment") a.deleteDone = make(chan struct{}) // this is called from a request under stateLock, so use activeContext - go a.deleteLogWorker(a.core.activeContext, a.currentGlobalSegment.startTimestamp, a.deleteDone) + go a.deleteLogWorker(a.core.activeContext, a.currentSegment.startTimestamp, a.deleteDone) a.resetCurrentLog() } forceSave := false - if a.enabled && a.currentGlobalSegment.startTimestamp == 0 && a.currentLocalSegment.startTimestamp == 0 { + if a.enabled && a.currentSegment.startTimestamp == 0 && a.currentGlobalSegment.startTimestamp == 0 && a.currentLocalSegment.startTimestamp == 0 { a.startNewCurrentLogLocked(a.clock.Now().UTC()) // Force a save so we can distinguish between // @@ -1386,6 +1453,7 @@ func (a *ActivityLog) SetConfig(ctx context.Context, config activityConfig) { if forceSave { // l is still held here + a.saveCurrentSegmentInternal(ctx, true, a.currentSegment, "") a.saveCurrentSegmentInternal(ctx, true, a.currentGlobalSegment, activityGlobalPathPrefix) a.saveCurrentSegmentInternal(ctx, true, a.currentLocalSegment, activityLocalPathPrefix) } @@ -1622,10 +1690,10 @@ func (a *ActivityLog) StartOfNextMonth() time.Time { a.l.RLock() defer a.l.RUnlock() var segmentStart time.Time - if a.currentGlobalSegment.startTimestamp == 0 { + if a.currentSegment.startTimestamp == 0 { segmentStart = a.clock.Now().UTC() } else { - segmentStart = time.Unix(a.currentGlobalSegment.startTimestamp, 0).UTC() + segmentStart = time.Unix(a.currentSegment.startTimestamp, 0).UTC() } // Basing this on the segment start will mean we trigger EOM rollover when // necessary because we were down. @@ -1800,6 +1868,12 @@ func (a *ActivityLog) perfStandbyFragmentWorker(ctx context.Context) { } sendFunc() + // clear active entity set + a.fragmentLock.Lock() + a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) + + a.fragmentLock.Unlock() + // clear local active entity set a.localFragmentLock.Lock() a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) @@ -1916,7 +1990,7 @@ func (a *ActivityLog) HandleEndOfMonth(ctx context.Context, currentTime time.Tim a.logger.Trace("starting end of month processing", "rolloverTime", currentTime) - err := a.writeIntentLog(ctx, a.currentGlobalSegment.startTimestamp, currentTime) + err := a.writeIntentLog(ctx, a.currentSegment.startTimestamp, currentTime) if err != nil { return err } @@ -1975,38 +2049,42 @@ func (a *ActivityLog) writeIntentLog(ctx context.Context, prevSegmentTimestamp i return nil } -// ResetActivityLog is used to extract the current local and global fragment(s) during +// ResetActivityLog is used to extract the current fragment(s) during // integration testing, so that it can be checked in a race-free way. -func (c *Core) ResetActivityLog() ([]*activity.LogFragment, []*activity.LogFragment) { +func (c *Core) ResetActivityLog() []*activity.LogFragment { c.stateLock.RLock() a := c.activityLog c.stateLock.RUnlock() if a == nil { - return nil, nil + return nil } - localFragments := make([]*activity.LogFragment, 0) - globalFragments := make([]*activity.LogFragment, 0) + allFragments := make([]*activity.LogFragment, 1) + a.fragmentLock.Lock() + + allFragments[0] = a.fragment + a.fragment = nil + allFragments = append(allFragments, a.standbyFragmentsReceived...) + a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) + a.secondaryGlobalClientFragments = make([]*activity.LogFragment, 0) + a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) + a.fragmentLock.Unlock() // local fragments a.localFragmentLock.Lock() - localFragments = append(localFragments, a.localFragment) + allFragments = append(allFragments, a.localFragment) a.localFragment = nil - localFragments = append(localFragments, a.standbyLocalFragmentsReceived...) + allFragments = append(allFragments, a.standbyLocalFragmentsReceived...) a.standbyLocalFragmentsReceived = make([]*activity.LogFragment, 0) a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) a.localFragmentLock.Unlock() // global fragments a.globalFragmentLock.Lock() - globalFragments = append(globalFragments, a.currentGlobalFragment) - a.currentGlobalFragment = nil - globalFragments = append(globalFragments, a.standbyGlobalFragmentsReceived...) a.globalPartialMonthClientTracker = make(map[string]*activity.EntityRecord) a.standbyGlobalFragmentsReceived = make([]*activity.LogFragment, 0) - a.secondaryGlobalClientFragments = make([]*activity.LogFragment, 0) a.globalFragmentLock.Unlock() - return localFragments, globalFragments + return allFragments } func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, timestamp int64) { @@ -2043,7 +2121,7 @@ func (a *ActivityLog) AddActivityToFragment(clientID string, namespaceID string, a.fragmentLock.RLock() if a.enabled { - _, presentInRegularClientMap := a.globalPartialMonthClientTracker[clientID] + _, presentInRegularClientMap := a.partialMonthClientTracker[clientID] _, presentInLocalClientmap := a.partialMonthLocalClientTracker[clientID] if presentInRegularClientMap || presentInLocalClientmap { present = true @@ -2068,7 +2146,7 @@ func (a *ActivityLog) AddActivityToFragment(clientID string, namespaceID string, defer a.globalFragmentLock.Unlock() // Re-check entity ID after re-acquiring lock - _, presentInRegularClientMap := a.globalPartialMonthClientTracker[clientID] + _, presentInRegularClientMap := a.partialMonthClientTracker[clientID] _, presentInLocalClientmap := a.partialMonthLocalClientTracker[clientID] if presentInRegularClientMap || presentInLocalClientmap { present = true @@ -2096,6 +2174,10 @@ func (a *ActivityLog) AddActivityToFragment(clientID string, namespaceID string, clientRecord.NonEntity = true } + // add the clients to the regular fragment + a.fragment.Clients = append(a.fragment.Clients, clientRecord) + a.partialMonthClientTracker[clientRecord.ClientID] = clientRecord + if local, _ := a.isClientLocal(clientRecord); local { // If the client is local then add the client to the current local fragment a.localFragment.Clients = append(a.localFragment.Clients, clientRecord) @@ -2130,10 +2212,17 @@ func (a *ActivityLog) isClientLocal(client *activity.EntityRecord) (bool, error) return false, nil } -// Create the fragments (local fragment and global fragment) if it doesn't already exist. +// Create the fragments (regular fragment, local fragment and global fragment) if it doesn't already exist. // Must be called with the fragmentLock, localFragmentLock and globalFragmentLock held. func (a *ActivityLog) createCurrentFragment() { - if a.currentGlobalFragment == nil { + if a.fragment == nil { + // create regular fragment + a.fragment = &activity.LogFragment{ + OriginatingNode: a.nodeID, + Clients: make([]*activity.EntityRecord, 0, 120), + NonEntityTokens: make(map[string]uint64), + } + // create local fragment a.localFragment = &activity.LogFragment{ OriginatingNode: a.nodeID, @@ -2143,7 +2232,6 @@ func (a *ActivityLog) createCurrentFragment() { // create global fragment a.currentGlobalFragment = &activity.LogFragment{ - OriginatingNode: a.nodeID, OriginatingCluster: a.core.ClusterID(), Clients: make([]*activity.EntityRecord, 0), } @@ -2205,6 +2293,7 @@ func (a *ActivityLog) receivedFragment(fragment *activity.LogFragment) { } for _, e := range fragment.Clients { + a.partialMonthClientTracker[e.ClientID] = e if isLocalFragment { a.partialMonthLocalClientTracker[e.ClientID] = e } else { @@ -2212,6 +2301,8 @@ func (a *ActivityLog) receivedFragment(fragment *activity.LogFragment) { } } + a.standbyFragmentsReceived = append(a.standbyFragmentsReceived, fragment) + if isLocalFragment { a.standbyLocalFragmentsReceived = append(a.standbyLocalFragmentsReceived, fragment) } else { @@ -2875,7 +2966,7 @@ func (a *ActivityLog) segmentToPrecomputedQuery(ctx context.Context, segmentTime // Iterate through entities, adding them to the hyperloglog and the summary maps in opts for { - entity, err := reader.ReadGlobalEntity(ctx) + entity, err := reader.ReadEntity(ctx) if errors.Is(err, io.EOF) { break } @@ -2890,23 +2981,6 @@ func (a *ActivityLog) segmentToPrecomputedQuery(ctx context.Context, segmentTime } } - for { - entity, err := reader.ReadLocalEntity(ctx) - if errors.Is(err, io.EOF) { - break - } - if err != nil { - a.logger.Warn("failed to read segment", "error", err) - return err - } - err = a.handleEntitySegment(entity, segmentTime, hyperloglog, opts) - if err != nil { - a.logger.Warn("failed to handle entity segment", "error", err) - return err - } - - } - // Store the hyperloglog err = a.StoreHyperlogLog(ctx, segmentTime, hyperloglog) if err != nil { @@ -3059,7 +3133,7 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context, intent *Activi // too old, and startTimestamp should only go forward (unless it is zero.) // If there's an intent log, finish it even if the feature is currently disabled. a.l.RLock() - currentMonth := a.currentGlobalSegment.startTimestamp + currentMonth := a.currentSegment.startTimestamp // Base retention period on the month we are generating (even in the past)--- a.clock.Now() // would work but this will be easier to control in tests. retentionWindow := timeutil.MonthsPreviousTo(a.retentionMonths, time.Unix(intent.NextMonth, 0).UTC()) @@ -3198,7 +3272,7 @@ func (a *ActivityLog) PartialMonthMetrics(ctx context.Context) ([]metricsutil.Ga // Empty list return []metricsutil.GaugeLabelValues{}, nil } - count := len(a.globalPartialMonthClientTracker) + len(a.partialMonthLocalClientTracker) + count := len(a.partialMonthClientTracker) return []metricsutil.GaugeLabelValues{ { @@ -3224,7 +3298,7 @@ func (a *ActivityLog) populateNamespaceAndMonthlyBreakdowns() (map[int64]*proces // Parse the monthly clients and prepare the breakdowns. byNamespace := make(map[string]*processByNamespace) byMonth := make(map[int64]*processMonth) - for _, e := range a.globalPartialMonthClientTracker { + for _, e := range a.partialMonthClientTracker { processClientRecord(e, byNamespace, byMonth, a.clock.Now()) } for _, e := range a.partialMonthLocalClientTracker { diff --git a/vault/activity_log_test.go b/vault/activity_log_test.go index 1f36a78565..4742d11467 100644 --- a/vault/activity_log_test.go +++ b/vault/activity_log_test.go @@ -34,7 +34,7 @@ import ( "github.com/stretchr/testify/require" ) -// TestActivityLog_Creation calls AddEntityToFragment and verifies that it appears correctly in a.currentGlobalFragment. +// TestActivityLog_Creation calls AddEntityToFragment and verifies that it appears correctly in a.fragment. func TestActivityLog_Creation(t *testing.T) { storage := &logical.InmemStorage{} coreConfig := &CoreConfig{ @@ -56,13 +56,11 @@ func TestActivityLog_Creation(t *testing.T) { if a.logger == nil || a.view == nil { t.Fatal("activity log not initialized") } - currentGlobalFragment := core.GetActiveGlobalFragment() - if currentGlobalFragment != nil { - t.Fatal("activity log already has global fragment") + if a.fragment != nil || a.currentGlobalFragment != nil { + t.Fatal("activity log already has fragment") } - localFragment := core.GetActiveLocalFragment() - if localFragment != nil { + if a.localFragment != nil { t.Fatal("activity log already has a local fragment") } @@ -71,29 +69,44 @@ func TestActivityLog_Creation(t *testing.T) { ts := time.Now() a.AddEntityToFragment(entity_id, namespace_id, ts.Unix()) - currentGlobalFragment = core.GetActiveGlobalFragment() - localFragment = core.GetActiveLocalFragment() - - if currentGlobalFragment == nil { + if a.fragment == nil || a.currentGlobalFragment == nil { t.Fatal("no fragment created") } - if a.currentGlobalFragment.OriginatingNode != a.nodeID { - t.Errorf("mismatched node ID, %q vs %q", currentGlobalFragment.OriginatingNode, a.nodeID) + if a.fragment.OriginatingNode != a.nodeID { + t.Errorf("mismatched node ID, %q vs %q", a.fragment.OriginatingNode, a.nodeID) } - if currentGlobalFragment.OriginatingCluster != a.core.ClusterID() { - t.Errorf("mismatched cluster ID, %q vs %q", currentGlobalFragment.GetOriginatingCluster(), a.core.ClusterID()) + if a.currentGlobalFragment.OriginatingCluster != a.core.ClusterID() { + t.Errorf("mismatched cluster ID, %q vs %q", a.currentGlobalFragment.GetOriginatingCluster(), a.core.ClusterID()) } - if currentGlobalFragment.Clients == nil { + if a.fragment.Clients == nil || a.currentGlobalFragment.Clients == nil { t.Fatal("no fragment entity slice") } - if len(currentGlobalFragment.Clients) != 1 { - t.Fatalf("wrong number of entities %v", len(currentGlobalFragment.Clients)) + if a.fragment.NonEntityTokens == nil { + t.Fatal("no fragment token map") } - er := currentGlobalFragment.Clients[0] + if len(a.fragment.Clients) != 1 { + t.Fatalf("wrong number of entities %v", len(a.fragment.Clients)) + } + if len(a.currentGlobalFragment.Clients) != 1 { + t.Fatalf("wrong number of entities %v", len(a.currentGlobalFragment.Clients)) + } + + er := a.fragment.Clients[0] + if er.ClientID != entity_id { + t.Errorf("mimatched entity ID, %q vs %q", er.ClientID, entity_id) + } + if er.NamespaceID != namespace_id { + t.Errorf("mimatched namespace ID, %q vs %q", er.NamespaceID, namespace_id) + } + if er.Timestamp != ts.Unix() { + t.Errorf("mimatched timestamp, %v vs %v", er.Timestamp, ts.Unix()) + } + + er = a.currentGlobalFragment.Clients[0] if er.ClientID != entity_id { t.Errorf("mimatched entity ID, %q vs %q", er.ClientID, entity_id) } @@ -105,14 +118,22 @@ func TestActivityLog_Creation(t *testing.T) { } // Reset and test the other code path + a.fragment = nil a.AddTokenToFragment(namespace_id) - currentGlobalFragment = core.GetActiveGlobalFragment() - localFragment = core.GetActiveLocalFragment() - if currentGlobalFragment == nil { + if a.fragment == nil { t.Fatal("no fragment created") } + if a.fragment.NonEntityTokens == nil { + t.Fatal("no fragment token map") + } + + actual := a.fragment.NonEntityTokens[namespace_id] + if actual != 1 { + t.Errorf("mismatched number of tokens, %v vs %v", actual, 1) + } + // test local fragment localMe := &MountEntry{ Table: credentialTableType, @@ -128,25 +149,24 @@ func TestActivityLog_Creation(t *testing.T) { local_ts := time.Now() a.AddClientToFragment(local_entity_id, "root", local_ts.Unix(), false, "local_mount_accessor") - localFragment = core.GetActiveLocalFragment() - if localFragment.OriginatingNode != a.nodeID { - t.Errorf("mismatched node ID, %q vs %q", localFragment.OriginatingNode, a.nodeID) + if a.localFragment.OriginatingNode != a.nodeID { + t.Errorf("mismatched node ID, %q vs %q", a.localFragment.OriginatingNode, a.nodeID) } - if localFragment.Clients == nil { + if a.localFragment.Clients == nil { t.Fatal("no local fragment entity slice") } - if localFragment.NonEntityTokens == nil { + if a.localFragment.NonEntityTokens == nil { t.Fatal("no local fragment token map") } - if len(localFragment.Clients) != 1 { - t.Fatalf("wrong number of entities %v", len(localFragment.Clients)) + if len(a.localFragment.Clients) != 1 { + t.Fatalf("wrong number of entities %v", len(a.localFragment.Clients)) } - er = localFragment.Clients[0] + er = a.localFragment.Clients[0] if er.ClientID != local_entity_id { t.Errorf("mimatched entity ID, %q vs %q", er.ClientID, local_entity_id) } @@ -172,13 +192,17 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { if a.logger == nil || a.view == nil { t.Fatal("activity log not initialized") } - if core.GetActiveGlobalFragment() != nil { + a.fragmentLock.Lock() + if a.fragment != nil || a.currentGlobalFragment != nil { t.Fatal("activity log already has fragment") } + a.fragmentLock.Unlock() - if core.GetActiveLocalFragment() != nil { + a.localFragmentLock.Lock() + if a.localFragment != nil { t.Fatal("activity log already has local fragment") } + a.localFragmentLock.Unlock() const namespace_id = "ns123" @@ -196,9 +220,11 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { t.Fatal(err) } - if core.GetActiveGlobalFragment() != nil { + a.fragmentLock.Lock() + if a.fragment != nil || a.currentGlobalFragment != nil { t.Fatal("fragment created") } + a.fragmentLock.Unlock() teNew := &logical.TokenEntry{ Path: "test", @@ -214,9 +240,11 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { t.Fatal(err) } - if core.GetActiveGlobalFragment() != nil { + a.fragmentLock.Lock() + if a.fragment != nil || a.currentGlobalFragment != nil { t.Fatal("fragment created") } + a.fragmentLock.Unlock() } func checkExpectedEntitiesInMap(t *testing.T, a *ActivityLog, entityIDs []string) { @@ -252,15 +280,36 @@ func TestActivityLog_UniqueEntities(t *testing.T) { a.AddEntityToFragment(id2, "root", t3.Unix()) a.AddEntityToFragment(id1, "root", t3.Unix()) - currentGlobalFragment := core.GetActiveGlobalFragment() - if currentGlobalFragment == nil { - t.Fatal("no current global fragment") - } - if len(currentGlobalFragment.Clients) != 2 { - t.Fatalf("number of entities is %v", len(currentGlobalFragment.Clients)) + if a.fragment == nil || a.currentGlobalFragment == nil { + t.Fatal("no current fragment") } - for i, e := range currentGlobalFragment.Clients { + if len(a.fragment.Clients) != 2 { + t.Fatalf("number of entities is %v", len(a.fragment.Clients)) + } + if len(a.currentGlobalFragment.Clients) != 2 { + t.Fatalf("number of entities is %v", len(a.currentGlobalFragment.Clients)) + } + + for i, e := range a.fragment.Clients { + expectedID := id1 + expectedTime := t1.Unix() + expectedNS := "root" + if i == 1 { + expectedID = id2 + expectedTime = t2.Unix() + } + if e.ClientID != expectedID { + t.Errorf("%v: expected %q, got %q", i, expectedID, e.ClientID) + } + if e.NamespaceID != expectedNS { + t.Errorf("%v: expected %q, got %q", i, expectedNS, e.NamespaceID) + } + if e.Timestamp != expectedTime { + t.Errorf("%v: expected %v, got %v", i, expectedTime, e.Timestamp) + } + } + for i, e := range a.currentGlobalFragment.Clients { expectedID := id1 expectedTime := t1.Unix() expectedNS := "root" @@ -361,11 +410,11 @@ func TestActivityLog_SaveTokensToStorage(t *testing.T) { if err != nil { t.Fatalf("got error writing tokens to storage: %v", err) } - if core.GetActiveGlobalFragment() != nil { + if a.fragment != nil || a.currentGlobalFragment != nil { t.Errorf("fragment was not reset after write to storage") } - if core.GetActiveLocalFragment() != nil { + if a.localFragment != nil { t.Errorf("local fragment was not reset after write to storage") } @@ -397,12 +446,11 @@ func TestActivityLog_SaveTokensToStorage(t *testing.T) { if err != nil { t.Fatalf("got error writing tokens to storage: %v", err) } - - if core.GetActiveGlobalFragment() != nil { + if a.fragment != nil || a.currentGlobalFragment != nil { t.Errorf("fragment was not reset after write to storage") } - if core.GetActiveLocalFragment() != nil { + if a.localFragment != nil { t.Errorf("local fragment was not reset after write to storage") } @@ -444,8 +492,7 @@ func TestActivityLog_SaveTokensToStorageDoesNotUpdateTokenCount(t *testing.T) { a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment tokenPath := fmt.Sprintf("%sdirecttokens/%d/0", ActivityLogLocalPrefix, a.GetStartTimestamp()) - clientPath := fmt.Sprintf("sys/counters/activity/global/log/entity/%d/0", a.GetStartTimestamp()) - localPath := fmt.Sprintf("sys/counters/activity/local/log/entity/%d/0", a.GetStartTimestamp()) + clientPath := fmt.Sprintf("sys/counters/activity/log/entity/%d/0", a.GetStartTimestamp()) // Create some entries without entityIDs tokenEntryOne := logical.TokenEntry{NamespaceID: namespace.RootNamespaceID, Policies: []string{"hi"}} entityEntry := logical.TokenEntry{EntityID: "foo", NamespaceID: namespace.RootNamespaceID, Policies: []string{"hi"}} @@ -459,9 +506,6 @@ func TestActivityLog_SaveTokensToStorageDoesNotUpdateTokenCount(t *testing.T) { } } - // verify that the client got added to a local fragment - require.Len(t, core.GetActiveLocalFragment().Clients, 1) - idEntity, isTWE := entityEntry.CreateClientID() for i := 0; i < 2; i++ { err := a.HandleTokenUsage(ctx, &entityEntry, idEntity, isTWE) @@ -469,53 +513,35 @@ func TestActivityLog_SaveTokensToStorageDoesNotUpdateTokenCount(t *testing.T) { t.Fatal(err) } } - - // verify that the client got added to the global fragment - require.Len(t, core.GetActiveGlobalFragment().Clients, 1) - err := a.saveCurrentSegmentToStorage(ctx, false) if err != nil { t.Fatalf("got error writing TWEs to storage: %v", err) } // Assert that new elements have been written to the fragment - if core.GetActiveGlobalFragment() != nil { + if a.fragment != nil || a.currentGlobalFragment != nil { t.Errorf("fragment was not reset after write to storage") } - if core.GetActiveLocalFragment() != nil { + if a.localFragment != nil { t.Errorf("local fragment was not reset after write to storage") } // Assert that no tokens have been written to the fragment readSegmentFromStorageNil(t, core, tokenPath) - allClients := make([]*activity.EntityRecord, 0) e := readSegmentFromStorage(t, core, clientPath) out := &activity.EntityActivityLog{} err = proto.Unmarshal(e.Value, out) if err != nil { t.Fatalf("could not unmarshal protobuf: %v", err) } - if len(out.Clients) != 1 { - t.Fatalf("added 2 distinct entity tokens that should all result in the same ID, got: %d", len(out.Clients)) + if len(out.Clients) != 2 { + t.Fatalf("added 3 distinct TWEs and 2 distinct entity tokens that should all result in the same ID, got: %d", len(out.Clients)) } - allClients = append(allClients, out.Clients...) - - e = readSegmentFromStorage(t, core, localPath) - out = &activity.EntityActivityLog{} - err = proto.Unmarshal(e.Value, out) - if err != nil { - t.Fatalf("could not unmarshal protobuf: %v", err) - } - if len(out.Clients) != 1 { - t.Fatalf("added 3 distinct TWEs that should all result in the same ID, got: %d", len(out.Clients)) - } - allClients = append(allClients, out.Clients...) - nonEntityTokenFlag := false entityTokenFlag := false - for _, client := range allClients { + for _, client := range out.Clients { if client.NonEntity == true { nonEntityTokenFlag = true if client.ClientID != idNonEntity { @@ -552,6 +578,7 @@ func TestActivityLog_SaveEntitiesToStorage(t *testing.T) { now.Add(1 * time.Second).Unix(), now.Add(2 * time.Second).Unix(), } + path := fmt.Sprintf("%sentity/%d/0", ActivityLogPrefix, a.GetStartTimestamp()) globalPath := fmt.Sprintf("%sentity/%d/0", ActivityGlobalLogPrefix, a.GetStartTimestamp()) a.AddEntityToFragment(ids[0], "root", times[0]) @@ -560,14 +587,14 @@ func TestActivityLog_SaveEntitiesToStorage(t *testing.T) { if err != nil { t.Fatalf("got error writing entities to storage: %v", err) } - if core.GetActiveGlobalFragment() != nil { + if a.fragment != nil || a.currentGlobalFragment != nil { t.Errorf("fragment was not reset after write to storage") } - if core.GetActiveLocalFragment() != nil { + if a.localFragment != nil { t.Errorf("local fragment was not reset after write to storage") } - protoSegment := readSegmentFromStorage(t, core, globalPath) + protoSegment := readSegmentFromStorage(t, core, path) out := &activity.EntityActivityLog{} err = proto.Unmarshal(protoSegment.Value, out) if err != nil { @@ -582,6 +609,14 @@ func TestActivityLog_SaveEntitiesToStorage(t *testing.T) { t.Fatalf("got error writing segments to storage: %v", err) } + protoSegment = readSegmentFromStorage(t, core, path) + out = &activity.EntityActivityLog{} + err = proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + expectedEntityIDs(t, out, ids) + protoSegment = readSegmentFromStorage(t, core, globalPath) out = &activity.EntityActivityLog{} err = proto.Unmarshal(protoSegment.Value, out) @@ -651,7 +686,7 @@ func TestActivityLog_SaveEntitiesToStorageCommon(t *testing.T) { if err != nil { t.Fatalf("got error writing entities to storage: %v", err) } - if core.GetActiveGlobalFragment() != nil || core.GetActiveLocalFragment() != nil { + if a.fragment != nil { t.Errorf("fragment was not reset after write to storage") } @@ -740,8 +775,8 @@ func TestModifyResponseMonthsNilAppend(t *testing.T) { } // TestActivityLog_ReceivedFragment calls receivedFragment with a fragment and verifies it gets added to -// standbyGlobalFragmentsReceived. Send the same fragment again and then verify that it doesn't change the entity map but does -// get added to standbyGlobalFragmentsReceived. +// standbyFragmentsReceived and standbyGlobalFragmentsReceived. Send the same fragment again and then verify that it doesn't change the entity map but does +// get added to standbyFragmentsReceived and standbyGlobalFragmentsReceived. func TestActivityLog_ReceivedFragment(t *testing.T) { core, _, _ := TestCoreUnsealed(t) a := core.activityLog @@ -771,7 +806,7 @@ func TestActivityLog_ReceivedFragment(t *testing.T) { NonEntityTokens: make(map[string]uint64), } - if len(a.standbyGlobalFragmentsReceived) != 0 { + if len(a.standbyFragmentsReceived) != 0 { t.Fatalf("fragment already received") } @@ -779,6 +814,10 @@ func TestActivityLog_ReceivedFragment(t *testing.T) { checkExpectedEntitiesInMap(t, a, ids) + if len(a.standbyFragmentsReceived) != 1 { + t.Fatalf("fragment count is %v, expected 1", len(a.standbyFragmentsReceived)) + } + if len(a.standbyGlobalFragmentsReceived) != 1 { t.Fatalf("fragment count is %v, expected 1", len(a.standbyGlobalFragmentsReceived)) } @@ -788,6 +827,9 @@ func TestActivityLog_ReceivedFragment(t *testing.T) { checkExpectedEntitiesInMap(t, a, ids) + if len(a.standbyFragmentsReceived) != 2 { + t.Fatalf("fragment count is %v, expected 2", len(a.standbyFragmentsReceived)) + } if len(a.standbyGlobalFragmentsReceived) != 2 { t.Fatalf("fragment count is %v, expected 2", len(a.standbyGlobalFragmentsReceived)) } @@ -814,17 +856,12 @@ func TestActivityLog_availableLogs(t *testing.T) { // set up a few files in storage core, _, _ := TestCoreUnsealed(t) a := core.activityLog - globalPaths := [...]string{"entity/1111/1", "entity/992/3", "entity/991/1"} - localPaths := [...]string{"entity/1111/1", "entity/992/3", "entity/990/1"} + paths := [...]string{"entity/1111/1", "entity/992/3"} tokenPaths := [...]string{"directtokens/1111/1", "directtokens/1000000/1", "directtokens/992/1"} - expectedTimes := [...]time.Time{time.Unix(1000000, 0), time.Unix(1111, 0), time.Unix(992, 0), time.Unix(991, 0), time.Unix(990, 0)} + expectedTimes := [...]time.Time{time.Unix(1000000, 0), time.Unix(1111, 0), time.Unix(992, 0)} - for _, path := range globalPaths { - WriteToStorage(t, core, ActivityGlobalLogPrefix+path, []byte("test")) - } - - for _, path := range localPaths { - WriteToStorage(t, core, ActivityLogLocalPrefix+path, []byte("test")) + for _, path := range paths { + WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test")) } for _, path := range tokenPaths { @@ -913,7 +950,7 @@ func TestActivityLog_createRegenerationIntentLog(t *testing.T) { } for _, subPath := range paths { - fullPath := ActivityGlobalLogPrefix + subPath + fullPath := ActivityLogPrefix + subPath WriteToStorage(t, core, fullPath, []byte("test")) deletePaths = append(deletePaths, fullPath) } @@ -962,9 +999,9 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment startTimestamp := a.GetStartTimestamp() - path0 := fmt.Sprintf("sys/counters/activity/global/log/entity/%d/0", startTimestamp) - path1 := fmt.Sprintf("sys/counters/activity/global/log/entity/%d/1", startTimestamp) - path2 := fmt.Sprintf("sys/counters/activity/global/log/entity/%d/2", startTimestamp) + path0 := fmt.Sprintf("sys/counters/activity/log/entity/%d/0", startTimestamp) + path1 := fmt.Sprintf("sys/counters/activity/log/entity/%d/1", startTimestamp) + path2 := fmt.Sprintf("sys/counters/activity/log/entity/%d/2", startTimestamp) tokenPath := fmt.Sprintf("sys/counters/activity/local/log/directtokens/%d/0", startTimestamp) genID := func(i int) string { @@ -1057,6 +1094,11 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { t.Fatalf("got error writing entities to storage: %v", err) } + seqNum := a.GetEntitySequenceNumber() + if seqNum != 2 { + t.Fatalf("expected sequence number 2, got %v", seqNum) + } + protoSegment0 = readSegmentFromStorage(t, core, path0) err = proto.Unmarshal(protoSegment0.Value, &entityLog0) if err != nil { @@ -1263,8 +1305,12 @@ func TestActivityLog_parseSegmentNumberFromPath(t *testing.T) { func TestActivityLog_getLastEntitySegmentNumber(t *testing.T) { core, _, _ := TestCoreUnsealed(t) a := core.activityLog + paths := [...]string{"entity/992/0", "entity/1000/-1", "entity/1001/foo", "entity/1111/0", "entity/1111/1"} globalPaths := [...]string{"entity/992/0", "entity/1000/-1", "entity/1001/foo", "entity/1111/1"} localPaths := [...]string{"entity/992/0", "entity/1000/-1", "entity/1001/foo", "entity/1111/0", "entity/1111/1"} + for _, path := range paths { + WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test")) + } for _, path := range globalPaths { WriteToStorage(t, core, ActivityGlobalLogPrefix+path, []byte("test")) } @@ -1274,36 +1320,42 @@ func TestActivityLog_getLastEntitySegmentNumber(t *testing.T) { testCases := []struct { input int64 + expectedVal uint64 expectedGlobalVal uint64 expectedLocalVal uint64 expectExists bool }{ { input: 992, + expectedVal: 0, expectedGlobalVal: 0, expectedLocalVal: 0, expectExists: true, }, { input: 1000, + expectedVal: 0, expectedGlobalVal: 0, expectedLocalVal: 0, expectExists: false, }, { input: 1001, + expectedVal: 0, expectedGlobalVal: 0, expectedLocalVal: 0, expectExists: false, }, { input: 1111, + expectedVal: 1, expectedGlobalVal: 1, expectedLocalVal: 1, expectExists: true, }, { input: 2222, + expectedVal: 0, expectedGlobalVal: 0, expectedLocalVal: 0, expectExists: false, @@ -1312,13 +1364,16 @@ func TestActivityLog_getLastEntitySegmentNumber(t *testing.T) { ctx := context.Background() for _, tc := range testCases { - localSegmentNumber, globalSegmentNumber, exists, err := a.getLastEntitySegmentNumber(ctx, time.Unix(tc.input, 0)) + result, localSegmentNumber, globalSegmentNumber, exists, err := a.getLastEntitySegmentNumber(ctx, time.Unix(tc.input, 0)) if err != nil { t.Fatalf("unexpected error for input %d: %v", tc.input, err) } if exists != tc.expectExists { t.Errorf("expected result exists: %t, got: %t for input: %d", tc.expectExists, exists, tc.input) } + if result != tc.expectedVal { + t.Errorf("expected: %d got: %d for input: %d", tc.expectedVal, result, tc.input) + } if globalSegmentNumber != tc.expectedGlobalVal { t.Errorf("expected: %d got: %d for input: %d", tc.expectedGlobalVal, globalSegmentNumber, tc.input) } @@ -1450,6 +1505,15 @@ func (a *ActivityLog) resetEntitiesInMemory(t *testing.T) { a.globalFragmentLock.Lock() defer a.globalFragmentLock.Unlock() + a.currentSegment = segmentInfo{ + startTimestamp: time.Time{}.Unix(), + currentClients: &activity.EntityActivityLog{ + Clients: make([]*activity.EntityRecord, 0), + }, + tokenCount: a.currentSegment.tokenCount, + clientSequenceNumber: 0, + } + a.currentGlobalSegment = segmentInfo{ startTimestamp: time.Time{}.Unix(), currentClients: &activity.EntityActivityLog{ @@ -1468,6 +1532,7 @@ func (a *ActivityLog) resetEntitiesInMemory(t *testing.T) { clientSequenceNumber: 0, } + a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) a.globalPartialMonthClientTracker = make(map[string]*activity.EntityRecord) } @@ -1484,6 +1549,7 @@ func TestActivityLog_loadCurrentClientSegment(t *testing.T) { CountByNamespaceID: tokenRecords, } a.l.Lock() + a.currentSegment.tokenCount = tokenCount a.currentLocalSegment.tokenCount = tokenCount a.l.Unlock() @@ -1544,6 +1610,7 @@ func TestActivityLog_loadCurrentClientSegment(t *testing.T) { if err != nil { t.Fatalf(err.Error()) } + WriteToStorage(t, core, ActivityLogPrefix+tc.path, data) WriteToStorage(t, core, ActivityGlobalLogPrefix+tc.path, data) WriteToStorage(t, core, ActivityLogLocalPrefix+tc.path, data) } @@ -1557,7 +1624,7 @@ func TestActivityLog_loadCurrentClientSegment(t *testing.T) { // loadCurrentClientSegment requires us to grab the fragment lock and the // activityLog lock, as per the comment in the loadCurrentClientSegment // function - err := a.loadCurrentClientSegment(ctx, time.Unix(tc.time, 0), tc.seqNum, tc.seqNum) + err := a.loadCurrentClientSegment(ctx, time.Unix(tc.time, 0), tc.seqNum, tc.seqNum, tc.seqNum) a.localFragmentLock.Unlock() a.globalFragmentLock.Unlock() a.fragmentLock.Unlock() @@ -1572,9 +1639,15 @@ func TestActivityLog_loadCurrentClientSegment(t *testing.T) { // verify accurate data in in-memory current segment require.Equal(t, tc.time, a.GetStartTimestamp()) + require.Equal(t, tc.seqNum, a.GetEntitySequenceNumber()) require.Equal(t, tc.seqNum, a.GetGlobalEntitySequenceNumber()) require.Equal(t, tc.seqNum, a.GetLocalEntitySequenceNumber()) + currentEntities := a.GetCurrentEntities() + if !entityRecordsEqual(t, currentEntities.Clients, tc.entities.Clients) { + t.Errorf("bad data loaded. expected: %v, got: %v for path %q", tc.entities.Clients, currentEntities, tc.path) + } + globalClients := core.GetActiveGlobalClientsList() if err := ActiveEntitiesEqual(globalClients, tc.entities.Clients); err != nil { t.Errorf("bad data loaded into active global entities. expected only set of EntityID from %v in %v for path %q: %v", tc.entities.Clients, globalClients, tc.path, err) @@ -1666,6 +1739,7 @@ func TestActivityLog_loadPriorEntitySegment(t *testing.T) { if err != nil { t.Fatalf(err.Error()) } + WriteToStorage(t, core, ActivityLogPrefix+tc.path, data) WriteToStorage(t, core, ActivityGlobalLogPrefix+tc.path, data) WriteToStorage(t, core, ActivityLogLocalPrefix+tc.path, data) } @@ -1674,22 +1748,20 @@ func TestActivityLog_loadPriorEntitySegment(t *testing.T) { for _, tc := range testCases { if tc.refresh { a.l.Lock() + a.fragmentLock.Lock() a.localFragmentLock.Lock() + a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) a.globalPartialMonthClientTracker = make(map[string]*activity.EntityRecord) + a.currentSegment.startTimestamp = tc.time a.currentGlobalSegment.startTimestamp = tc.time a.currentLocalSegment.startTimestamp = tc.time + a.fragmentLock.Unlock() a.localFragmentLock.Unlock() a.l.Unlock() } - // load global segments - err := a.loadPriorEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum, false) - if err != nil { - t.Fatalf("got error loading data for %q: %v", tc.path, err) - } - // load local segments - err = a.loadPriorEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum, true) + err := a.loadPriorEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum) if err != nil { t.Fatalf("got error loading data for %q: %v", tc.path, err) } @@ -1865,12 +1937,14 @@ func setupActivityRecordsInStorage(t *testing.T, base time.Time, includeEntities } switch i { case 0: + WriteToStorage(t, core, ActivityLogPrefix+"entity/"+fmt.Sprint(monthsAgo.Unix())+"/0", entityData) WriteToStorage(t, core, ActivityGlobalLogPrefix+"entity/"+fmt.Sprint(monthsAgo.Unix())+"/0", entityData) case len(entityRecords) - 1: // local data WriteToStorage(t, core, ActivityLogLocalPrefix+"entity/"+fmt.Sprint(base.Unix())+"/"+strconv.Itoa(i-1), entityData) default: + WriteToStorage(t, core, ActivityLogPrefix+"entity/"+fmt.Sprint(base.Unix())+"/"+strconv.Itoa(i-1), entityData) WriteToStorage(t, core, ActivityGlobalLogPrefix+"entity/"+fmt.Sprint(base.Unix())+"/"+strconv.Itoa(i-1), entityData) } } @@ -1914,23 +1988,15 @@ func TestActivityLog_refreshFromStoredLog(t *testing.T) { } wg.Wait() - // active clients for the entire month expectedActive := &activity.EntityActivityLog{ Clients: expectedClientRecords[1:], } - expectedActiveGlobal := &activity.EntityActivityLog{ - Clients: expectedClientRecords[1 : len(expectedClientRecords)-1], - } - - // local client is only added to the newest segment for the current month. This should also appear in the active clients for the entire month. - expectedCurrentLocal := &activity.EntityActivityLog{ - Clients: expectedClientRecords[len(expectedClientRecords)-1:], - } - - // global clients added to the newest local entity segment expectedCurrent := &activity.EntityActivityLog{ Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], } + expectedCurrentLocal := &activity.EntityActivityLog{ + Clients: expectedClientRecords[len(expectedClientRecords)-1:], + } currentEntities := a.GetCurrentGlobalEntities() if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { @@ -1955,19 +2021,6 @@ func TestActivityLog_refreshFromStoredLog(t *testing.T) { // we expect activeClients to be loaded for the entire month t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v: %v", expectedActive.Clients, activeClients, err) } - - // verify active global clients list - activeGlobalClients := a.core.GetActiveGlobalClientsList() - if err := ActiveEntitiesEqual(activeGlobalClients, expectedActiveGlobal.Clients); err != nil { - // we expect activeClients to be loaded for the entire month - t.Errorf("bad data loaded into active global entities. expected only set of EntityID from %v in %v: %v", expectedActiveGlobal.Clients, activeGlobalClients, err) - } - // verify active local clients list - activeLocalClients := a.core.GetActiveLocalClientsList() - if err := ActiveEntitiesEqual(activeLocalClients, expectedCurrentLocal.Clients); err != nil { - // we expect activeClients to be loaded for the entire month - t.Errorf("bad data loaded into active local entities. expected only set of EntityID from %v in %v: %v", expectedCurrentLocal.Clients, activeLocalClients, err) - } } // TestActivityLog_refreshFromStoredLogWithBackgroundLoadingCancelled writes data from 3 months ago to this month. The @@ -2029,18 +2082,6 @@ func TestActivityLog_refreshFromStoredLogWithBackgroundLoadingCancelled(t *testi // we only expect activeClients to be loaded for the newest segment (for the current month) t.Error(err) } - - // verify if the right global clients are loaded for the newest segment (for the current month) - activeGlobalClients := a.core.GetActiveGlobalClientsList() - if err := ActiveEntitiesEqual(activeGlobalClients, expectedCurrent.Clients); err != nil { - t.Error(err) - } - - // the right local clients are loaded for the newest segment (for the current month) - activeLocalClients := a.core.GetActiveLocalClientsList() - if err := ActiveEntitiesEqual(activeLocalClients, currentLocalEntities.Clients); err != nil { - t.Error(err) - } } // TestActivityLog_refreshFromStoredLogContextCancelled writes data from 3 months ago to this month and calls @@ -2074,6 +2115,9 @@ func TestActivityLog_refreshFromStoredLogNoTokens(t *testing.T) { expectedActive := &activity.EntityActivityLog{ Clients: expectedClientRecords[1:], } + expectedCurrent := &activity.EntityActivityLog{ + Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], + } expectedCurrentGlobal := &activity.EntityActivityLog{ Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], } @@ -2081,6 +2125,12 @@ func TestActivityLog_refreshFromStoredLogNoTokens(t *testing.T) { Clients: expectedClientRecords[len(expectedClientRecords)-1:], } + currentEntities := a.GetCurrentEntities() + if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { + // we expect all segments for the current month to be loaded + t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities) + } + currentGlobalEntities := a.GetCurrentGlobalEntities() if !entityRecordsEqual(t, currentGlobalEntities.Clients, expectedCurrentGlobal.Clients) { // we only expect the newest entity segment to be loaded (for the current month) @@ -2124,7 +2174,7 @@ func TestActivityLog_refreshFromStoredLogNoEntities(t *testing.T) { t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount) } - currentEntities := a.GetCurrentGlobalEntities() + currentEntities := a.GetCurrentEntities() if len(currentEntities.Clients) > 0 { t.Errorf("expected no current entity segment to be loaded. got: %v", currentEntities) } @@ -2195,7 +2245,7 @@ func TestActivityLog_refreshFromStoredLogPreviousMonth(t *testing.T) { Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], } - currentEntities := a.GetCurrentGlobalEntities() + currentEntities := a.GetCurrentEntities() if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { // we only expect the newest entity segment to be loaded (for the current month) t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities) @@ -2293,7 +2343,16 @@ func TestActivityLog_DeleteWorker(t *testing.T) { "entity/1112/1", } for _, path := range paths { - WriteToStorage(t, core, ActivityGlobalLogPrefix+path, []byte("test")) + WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test")) + } + + localPaths := []string{ + "entity/1111/1", + "entity/1111/2", + "entity/1111/3", + "entity/1112/1", + } + for _, path := range localPaths { WriteToStorage(t, core, ActivityLogLocalPrefix+path, []byte("test")) } @@ -2317,14 +2376,14 @@ func TestActivityLog_DeleteWorker(t *testing.T) { } // Check segments still present - readSegmentFromStorage(t, core, ActivityGlobalLogPrefix+"entity/1112/1") + readSegmentFromStorage(t, core, ActivityLogPrefix+"entity/1112/1") readSegmentFromStorage(t, core, ActivityLogLocalPrefix+"entity/1112/1") readSegmentFromStorage(t, core, ActivityLogLocalPrefix+"directtokens/1112/1") // Check other segments not present - expectMissingSegment(t, core, ActivityGlobalLogPrefix+"entity/1111/1") - expectMissingSegment(t, core, ActivityGlobalLogPrefix+"entity/1111/2") - expectMissingSegment(t, core, ActivityGlobalLogPrefix+"entity/1111/3") + expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/1") + expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/2") + expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/3") expectMissingSegment(t, core, ActivityLogLocalPrefix+"entity/1111/1") expectMissingSegment(t, core, ActivityLogLocalPrefix+"entity/1111/2") expectMissingSegment(t, core, ActivityLogLocalPrefix+"entity/1111/3") @@ -2414,7 +2473,7 @@ func TestActivityLog_EnableDisable(t *testing.T) { } // verify segment exists - path := fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, seg1) + path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, seg1) readSegmentFromStorage(t, core, path) // Add in-memory fragment @@ -2444,7 +2503,7 @@ func TestActivityLog_EnableDisable(t *testing.T) { } // Verify empty segments are present - path = fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, seg2) + path = fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, seg2) readSegmentFromStorage(t, core, path) path = fmt.Sprintf("%vdirecttokens/%v/0", ActivityLogLocalPrefix, seg2) @@ -2485,8 +2544,6 @@ func TestActivityLog_EndOfMonth(t *testing.T) { id2 := "22222222-2222-2222-2222-222222222222" id3 := "33333333-3333-3333-3333-333333333333" id4 := "44444444-4444-4444-4444-444444444444" - - // add global data a.AddEntityToFragment(id1, "root", time.Now().Unix()) // add local data @@ -2510,13 +2567,22 @@ func TestActivityLog_EndOfMonth(t *testing.T) { a.HandleEndOfMonth(ctx, month1) // Check segment is present, with 1 entity - path := fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, segment0) + path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, segment0) protoSegment := readSegmentFromStorage(t, core, path) out := &activity.EntityActivityLog{} err = proto.Unmarshal(protoSegment.Value, out) if err != nil { t.Fatal(err) } + expectedEntityIDs(t, out, []string{id1, id4}) + + path = fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, segment0) + protoSegment = readSegmentFromStorage(t, core, path) + out = &activity.EntityActivityLog{} + err = proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatal(err) + } expectedEntityIDs(t, out, []string{id1}) path = fmt.Sprintf("%ventity/%v/0", ActivityLogLocalPrefix, segment0) @@ -2583,6 +2649,18 @@ func TestActivityLog_EndOfMonth(t *testing.T) { for i, tc := range testCases { t.Logf("checking segment %v timestamp %v", i, tc.SegmentTimestamp) + expectedAllEntities := make([]string, 0) + expectedAllEntities = append(expectedAllEntities, tc.ExpectedGlobalEntityIDs...) + expectedAllEntities = append(expectedAllEntities, tc.ExpectedLocalEntityIDs...) + path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, tc.SegmentTimestamp) + protoSegment := readSegmentFromStorage(t, core, path) + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(protoSegment.Value, out) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + expectedEntityIDs(t, out, expectedAllEntities) + // Check for global entities at global storage path path = fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, tc.SegmentTimestamp) protoSegment = readSegmentFromStorage(t, core, path) @@ -2766,7 +2844,7 @@ func TestActivityLog_CalculatePrecomputedQueriesWithMixedTWEs(t *testing.T) { if err != nil { t.Fatal(err) } - path := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, segment.StartTime, segment.Segment) + path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) WriteToStorage(t, core, path, data) } expectedCounts := []struct { @@ -3047,10 +3125,10 @@ func TestActivityLog_SaveAfterDisable(t *testing.T) { t.Fatal(err) } - path := ActivityGlobalLogPrefix + "entity/0/0" + path := ActivityLogPrefix + "entity/0/0" expectMissingSegment(t, core, path) - path = fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, startTimestamp) + path = fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, startTimestamp) expectMissingSegment(t, core, path) } @@ -3151,7 +3229,7 @@ func TestActivityLog_Precompute(t *testing.T) { if err != nil { t.Fatal(err) } - path := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, segment.StartTime, segment.Segment) + path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) WriteToStorage(t, core, path, data) } @@ -3462,7 +3540,7 @@ func TestActivityLog_Precompute_SkipMonth(t *testing.T) { if err != nil { t.Fatal(err) } - path := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, segment.StartTime, segment.Segment) + path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) WriteToStorage(t, core, path, data) } @@ -3679,7 +3757,7 @@ func TestActivityLog_PrecomputeNonEntityTokensWithID(t *testing.T) { if err != nil { t.Fatal(err) } - path := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, segment.StartTime, segment.Segment) + path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) WriteToStorage(t, core, path, data) } @@ -4068,7 +4146,7 @@ func TestActivityLog_Deletion(t *testing.T) { for i, start := range times { // no entities in some months, just for fun for j := 0; j < (i+3)%5; j++ { - entityPath := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, start.Unix(), j) + entityPath := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, start.Unix(), j) paths[i] = append(paths[i], entityPath) WriteToStorage(t, core, entityPath, []byte("test")) } @@ -4444,7 +4522,7 @@ func TestActivityLog_partialMonthClientCountWithMultipleMountPaths(t *testing.T) if err != nil { t.Fatalf(err.Error()) } - storagePath := fmt.Sprintf("%sentity/%d/%d", ActivityGlobalLogPrefix, timeutil.StartOfMonth(now).Unix(), i) + storagePath := fmt.Sprintf("%sentity/%d/%d", ActivityLogPrefix, timeutil.StartOfMonth(now).Unix(), i) WriteToStorage(t, core, storagePath, entityData) } @@ -5084,6 +5162,7 @@ func TestAddActivityToFragment(t *testing.T) { a := core.activityLog a.SetEnable(true) + require.Nil(t, a.fragment) require.Nil(t, a.localFragment) require.Nil(t, a.currentGlobalFragment) @@ -5175,6 +5254,10 @@ func TestAddActivityToFragment(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var mountAccessor string + a.fragmentLock.RLock() + numClientsBefore := len(a.fragment.Clients) + a.fragmentLock.RUnlock() + a.globalFragmentLock.RLock() globalClientsBefore := len(a.currentGlobalFragment.Clients) a.globalFragmentLock.RUnlock() @@ -5198,6 +5281,9 @@ func TestAddActivityToFragment(t *testing.T) { a.AddActivityToFragment(tc.id, ns, 0, tc.activityType, mount) } + a.fragmentLock.RLock() + defer a.fragmentLock.RUnlock() + numClientsAfter := len(a.fragment.Clients) a.globalFragmentLock.RLock() defer a.globalFragmentLock.RUnlock() globalClientsAfter := len(a.currentGlobalFragment.Clients) @@ -5226,6 +5312,24 @@ func TestAddActivityToFragment(t *testing.T) { } } + // for now local clients are added to both regular fragment and local fragment. + // this will be modified in ticket vault-31234 + if tc.isAdded { + require.Equal(t, numClientsBefore+1, numClientsAfter) + } else { + require.Equal(t, numClientsBefore, numClientsAfter) + } + + require.Contains(t, a.partialMonthClientTracker, tc.expectedID) + require.True(t, proto.Equal(&activity.EntityRecord{ + ClientID: tc.expectedID, + NamespaceID: ns, + Timestamp: 0, + NonEntity: tc.isNonEntity, + MountAccessor: mountAccessor, + ClientType: tc.activityType, + }, a.partialMonthClientTracker[tc.expectedID])) + if tc.isLocal { require.Contains(t, a.partialMonthLocalClientTracker, tc.expectedID) require.True(t, proto.Equal(&activity.EntityRecord{ @@ -5267,6 +5371,7 @@ func TestGetAllPartialMonthClients(t *testing.T) { a := core.activityLog a.SetEnable(true) + require.Nil(t, a.fragment) require.Nil(t, a.localFragment) require.Nil(t, a.currentGlobalFragment) @@ -5280,6 +5385,7 @@ func TestGetAllPartialMonthClients(t *testing.T) { a.AddActivityToFragment(clientID, ns, 0, entityActivityType, mount) require.NotNil(t, a.localFragment) + require.NotNil(t, a.fragment) require.NotNil(t, a.currentGlobalFragment) // create a local mount accessor @@ -5677,6 +5783,37 @@ func TestCreateSegment_StoreSegment(t *testing.T) { global: true, forceStore: true, }, + + { + testName: "[non-global] max segment size", + numClients: ActivitySegmentClientCapacity, + maxClientsPerFragment: ActivitySegmentClientCapacity, + global: false, + }, + { + testName: "[non-global] max segment size, multiple fragments", + numClients: ActivitySegmentClientCapacity, + maxClientsPerFragment: ActivitySegmentClientCapacity - 1, + global: false, + }, + { + testName: "[non-global] roll over", + numClients: ActivitySegmentClientCapacity + 2, + maxClientsPerFragment: ActivitySegmentClientCapacity, + global: false, + }, + { + testName: "[non-global] max segment size, rollover multiple fragments", + numClients: ActivitySegmentClientCapacity * 2, + maxClientsPerFragment: ActivitySegmentClientCapacity - 1, + global: false, + }, + { + testName: "[non-global] max client size, drop clients", + numClients: ActivitySegmentClientCapacity*2 + 1, + maxClientsPerFragment: ActivitySegmentClientCapacity, + global: false, + }, { testName: "[local] max client size, drop clients", numClients: ActivitySegmentClientCapacity*2 + 1, @@ -5773,7 +5910,10 @@ func TestCreateSegment_StoreSegment(t *testing.T) { segment := &a.currentGlobalSegment if !test.global { - segment = &a.currentLocalSegment + segment = &a.currentSegment + if test.pathPrefix == activityLocalPathPrefix { + segment = &a.currentLocalSegment + } } // Create segments and write to storage @@ -5792,13 +5932,24 @@ func TestCreateSegment_StoreSegment(t *testing.T) { clientTotal += len(entity.GetClients()) } } else { - for { - entity, err := reader.ReadLocalEntity(ctx) - if errors.Is(err, io.EOF) { - break + if test.pathPrefix == activityLocalPathPrefix { + for { + entity, err := reader.ReadLocalEntity(ctx) + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + clientTotal += len(entity.GetClients()) + } + } else { + for { + entity, err := reader.ReadEntity(ctx) + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + clientTotal += len(entity.GetClients()) } - require.NoError(t, err) - clientTotal += len(entity.GetClients()) } } diff --git a/vault/activity_log_testing_util.go b/vault/activity_log_testing_util.go index d0fd4b7b35..f9bb25ba14 100644 --- a/vault/activity_log_testing_util.go +++ b/vault/activity_log_testing_util.go @@ -36,7 +36,7 @@ func (c *Core) InjectActivityLogDataThisMonth(t *testing.T) map[string]*activity Timestamp: c.activityLog.clock.Now().Unix(), NonEntity: i%2 == 0, } - c.activityLog.globalPartialMonthClientTracker[er.ClientID] = er + c.activityLog.partialMonthClientTracker[er.ClientID] = er } if constants.IsEnterprise { @@ -49,12 +49,12 @@ func (c *Core) InjectActivityLogDataThisMonth(t *testing.T) map[string]*activity Timestamp: c.activityLog.clock.Now().Unix(), NonEntity: i%2 == 0, } - c.activityLog.globalPartialMonthClientTracker[er.ClientID] = er + c.activityLog.partialMonthClientTracker[er.ClientID] = er } } } - return c.activityLog.globalPartialMonthClientTracker + return c.activityLog.partialMonthClientTracker } // GetActiveClients returns the in-memory globalPartialMonthClientTracker and partialMonthLocalClientTracker from an @@ -93,7 +93,6 @@ func (c *Core) GetActiveClientsList() []*activity.EntityRecord { return out } -// GetActiveLocalClientsList returns the active clients from globalPartialMonthClientTracker in activity log func (c *Core) GetActiveGlobalClientsList() []*activity.EntityRecord { out := []*activity.EntityRecord{} c.activityLog.globalFragmentLock.RLock() @@ -105,7 +104,6 @@ func (c *Core) GetActiveGlobalClientsList() []*activity.EntityRecord { return out } -// GetActiveLocalClientsList returns the active clients from partialMonthLocalClientTracker in activity log func (c *Core) GetActiveLocalClientsList() []*activity.EntityRecord { out := []*activity.EntityRecord{} c.activityLog.localFragmentLock.RLock() @@ -117,14 +115,21 @@ func (c *Core) GetActiveLocalClientsList() []*activity.EntityRecord { return out } -// GetCurrentGlobalEntities returns the current clients from currentGlobalSegment in activity log +// GetCurrentEntities returns the current entity activity log +func (a *ActivityLog) GetCurrentEntities() *activity.EntityActivityLog { + a.l.RLock() + defer a.l.RUnlock() + return a.currentSegment.currentClients +} + +// GetCurrentGlobalEntities returns the current global entity activity log func (a *ActivityLog) GetCurrentGlobalEntities() *activity.EntityActivityLog { a.l.RLock() defer a.l.RUnlock() return a.currentGlobalSegment.currentClients } -// GetCurrentLocalEntities returns the current clients from currentLocalSegment in activity log +// GetCurrentLocalEntities returns the current local entity activity log func (a *ActivityLog) GetCurrentLocalEntities() *activity.EntityActivityLog { a.l.RLock() defer a.l.RUnlock() @@ -164,11 +169,8 @@ func (a *ActivityLog) SetStandbyEnable(ctx context.Context, enabled bool) { // NOTE: AddTokenToFragment is deprecated and can no longer be used, except for // testing backward compatibility. Please use AddClientToFragment instead. func (a *ActivityLog) AddTokenToFragment(namespaceID string) { - a.globalFragmentLock.Lock() - defer a.globalFragmentLock.Unlock() - - a.localFragmentLock.Lock() - defer a.localFragmentLock.Unlock() + a.fragmentLock.Lock() + defer a.fragmentLock.Unlock() if !a.enabled { return @@ -176,7 +178,7 @@ func (a *ActivityLog) AddTokenToFragment(namespaceID string) { a.createCurrentFragment() - a.localFragment.NonEntityTokens[namespaceID] += 1 + a.fragment.NonEntityTokens[namespaceID] += 1 } func RandStringBytes(n int) string { @@ -197,29 +199,20 @@ func (a *ActivityLog) ExpectCurrentSegmentRefreshed(t *testing.T, expectedStart defer a.l.RUnlock() a.fragmentLock.RLock() defer a.fragmentLock.RUnlock() - if a.currentGlobalSegment.currentClients == nil { + if a.currentSegment.currentClients == nil { t.Fatalf("expected non-nil currentSegment.currentClients") } - if a.currentGlobalSegment.currentClients.Clients == nil { + if a.currentSegment.currentClients.Clients == nil { t.Errorf("expected non-nil currentSegment.currentClients.Entities") } - if a.currentGlobalSegment.tokenCount == nil { + if a.currentSegment.tokenCount == nil { t.Fatalf("expected non-nil currentSegment.tokenCount") } - if a.currentGlobalSegment.tokenCount.CountByNamespaceID == nil { + if a.currentSegment.tokenCount.CountByNamespaceID == nil { t.Errorf("expected non-nil currentSegment.tokenCount.CountByNamespaceID") } - if a.currentLocalSegment.currentClients == nil { - t.Fatalf("expected non-nil currentSegment.currentClients") - } - if a.currentLocalSegment.currentClients.Clients == nil { - t.Errorf("expected non-nil currentSegment.currentClients.Entities") - } - if a.currentLocalSegment.tokenCount == nil { - t.Fatalf("expected non-nil currentSegment.tokenCount") - } - if a.currentLocalSegment.tokenCount.CountByNamespaceID == nil { - t.Errorf("expected non-nil currentSegment.tokenCount.CountByNamespaceID") + if a.partialMonthClientTracker == nil { + t.Errorf("expected non-nil partialMonthClientTracker") } if a.partialMonthLocalClientTracker == nil { t.Errorf("expected non-nil partialMonthLocalClientTracker") @@ -227,14 +220,14 @@ func (a *ActivityLog) ExpectCurrentSegmentRefreshed(t *testing.T, expectedStart if a.globalPartialMonthClientTracker == nil { t.Errorf("expected non-nil globalPartialMonthClientTracker") } - if len(a.currentGlobalSegment.currentClients.Clients) > 0 { - t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentGlobalSegment.currentClients) + if len(a.currentSegment.currentClients.Clients) > 0 { + t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentSegment.currentClients) } - if len(a.currentLocalSegment.currentClients.Clients) > 0 { - t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentLocalSegment.currentClients) + if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { + t.Errorf("expected no token counts to be loaded. got: %v", a.currentSegment.tokenCount.CountByNamespaceID) } - if len(a.currentLocalSegment.tokenCount.CountByNamespaceID) > 0 { - t.Errorf("expected no token counts to be loaded. got: %v", a.currentLocalSegment.tokenCount.CountByNamespaceID) + if len(a.partialMonthClientTracker) > 0 { + t.Errorf("expected no active entity segment to be loaded. got: %v", a.partialMonthClientTracker) } if len(a.partialMonthLocalClientTracker) > 0 { t.Errorf("expected no active entity segment to be loaded. got: %v", a.partialMonthLocalClientTracker) @@ -244,12 +237,17 @@ func (a *ActivityLog) ExpectCurrentSegmentRefreshed(t *testing.T, expectedStart } if verifyTimeNotZero { + if a.currentSegment.startTimestamp == 0 { + t.Error("bad start timestamp. expected no reset but timestamp was reset") + } if a.currentGlobalSegment.startTimestamp == 0 { t.Error("bad start timestamp. expected no reset but timestamp was reset") } if a.currentLocalSegment.startTimestamp == 0 { t.Error("bad start timestamp. expected no reset but timestamp was reset") } + } else if a.currentSegment.startTimestamp != expectedStart { + t.Errorf("bad start timestamp. expected: %v got: %v", expectedStart, a.currentSegment.startTimestamp) } else if a.currentGlobalSegment.startTimestamp != expectedStart { t.Errorf("bad start timestamp. expected: %v got: %v", expectedStart, a.currentGlobalSegment.startTimestamp) } else if a.currentLocalSegment.startTimestamp != expectedStart { @@ -272,7 +270,9 @@ func ActiveEntitiesEqual(active []*activity.EntityRecord, test []*activity.Entit func (a *ActivityLog) GetStartTimestamp() int64 { a.l.RLock() defer a.l.RUnlock() - if a.currentGlobalSegment.startTimestamp != a.currentLocalSegment.startTimestamp { + // TODO: We will substitute a.currentSegment with a.currentLocalSegment when we remove + // a.currentSegment from the code + if a.currentGlobalSegment.startTimestamp != a.currentSegment.startTimestamp { return -1 } return a.currentGlobalSegment.startTimestamp @@ -282,6 +282,7 @@ func (a *ActivityLog) GetStartTimestamp() int64 { func (a *ActivityLog) SetStartTimestamp(timestamp int64) { a.l.Lock() defer a.l.Unlock() + a.currentSegment.startTimestamp = timestamp a.currentGlobalSegment.startTimestamp = timestamp a.currentLocalSegment.startTimestamp = timestamp } @@ -293,6 +294,13 @@ func (a *ActivityLog) GetStoredTokenCountByNamespaceID() map[string]uint64 { return a.currentLocalSegment.tokenCount.CountByNamespaceID } +// GetEntitySequenceNumber returns the current entity sequence number +func (a *ActivityLog) GetEntitySequenceNumber() uint64 { + a.l.RLock() + defer a.l.RUnlock() + return a.currentSegment.clientSequenceNumber +} + // GetGlobalEntitySequenceNumber returns the current entity sequence number func (a *ActivityLog) GetGlobalEntitySequenceNumber() uint64 { a.l.RLock() @@ -347,6 +355,12 @@ func (c *Core) GetActiveLocalFragment() *activity.LogFragment { return c.activityLog.localFragment } +func (c *Core) GetActiveFragment() *activity.LogFragment { + c.activityLog.fragmentLock.RLock() + defer c.activityLog.fragmentLock.RUnlock() + return c.activityLog.fragment +} + // StoreCurrentSegment is a test only method to create and store // segments from fragments. This allows createCurrentSegmentFromFragments to remain // private diff --git a/vault/activity_log_util_common.go b/vault/activity_log_util_common.go index f3cd616ed9..c019d03a47 100644 --- a/vault/activity_log_util_common.go +++ b/vault/activity_log_util_common.go @@ -425,6 +425,7 @@ type singleTypeSegmentReader struct { } type segmentReader struct { tokens *singleTypeSegmentReader + entities *singleTypeSegmentReader globalEntities *singleTypeSegmentReader localEntities *singleTypeSegmentReader } @@ -432,11 +433,16 @@ type segmentReader struct { // SegmentReader is an interface that provides methods to read tokens and entities in order type SegmentReader interface { ReadToken(ctx context.Context) (*activity.TokenCount, error) + ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) ReadGlobalEntity(ctx context.Context) (*activity.EntityActivityLog, error) ReadLocalEntity(ctx context.Context) (*activity.EntityActivityLog, error) } func (a *ActivityLog) NewSegmentFileReader(ctx context.Context, startTime time.Time) (SegmentReader, error) { + entities, err := a.newSingleTypeSegmentReader(ctx, startTime, activityEntityBasePath) + if err != nil { + return nil, err + } globalEntities, err := a.newSingleTypeSegmentReader(ctx, startTime, activityGlobalPathPrefix+activityEntityBasePath) if err != nil { return nil, err @@ -449,7 +455,7 @@ func (a *ActivityLog) NewSegmentFileReader(ctx context.Context, startTime time.T if err != nil { return nil, err } - return &segmentReader{globalEntities: globalEntities, localEntities: localEntities, tokens: tokens}, nil + return &segmentReader{entities: entities, globalEntities: globalEntities, localEntities: localEntities, tokens: tokens}, nil } func (a *ActivityLog) newSingleTypeSegmentReader(ctx context.Context, startTime time.Time, prefix string) (*singleTypeSegmentReader, error) { @@ -504,6 +510,17 @@ func (e *segmentReader) ReadToken(ctx context.Context) (*activity.TokenCount, er return out, nil } +// ReadEntity reads an entity from the segment +// If there is none available, then the error will be io.EOF +func (e *segmentReader) ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) { + out := &activity.EntityActivityLog{} + err := e.entities.nextValue(ctx, out) + if err != nil { + return nil, err + } + return out, nil +} + // ReadGlobalEntity reads a global entity from the global segment // If there is none available, then the error will be io.EOF func (e *segmentReader) ReadGlobalEntity(ctx context.Context) (*activity.EntityActivityLog, error) { diff --git a/vault/activity_log_util_common_test.go b/vault/activity_log_util_common_test.go index f84775da3f..7201cdc651 100644 --- a/vault/activity_log_util_common_test.go +++ b/vault/activity_log_util_common_test.go @@ -1006,6 +1006,14 @@ func writeLocalEntitySegment(t *testing.T, core *Core, ts time.Time, index int, WriteToStorage(t, core, makeSegmentPath(t, activityLocalPathPrefix+activityEntityBasePath, ts, index), protoItem) } +// writeEntitySegment writes a single segment file with the given time and index for an entity +func writeEntitySegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.EntityActivityLog) { + t.Helper() + protoItem, err := proto.Marshal(item) + require.NoError(t, err) + WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, ts, index), protoItem) +} + // writeTokenSegment writes a single segment file with the given time and index for a token func writeTokenSegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.TokenCount) { t.Helper() @@ -1029,6 +1037,7 @@ func TestSegmentFileReader_BadData(t *testing.T) { // write bad data that won't be able to be unmarshaled at index 0 WriteToStorage(t, core, makeSegmentPath(t, activityTokenLocalBasePath, now, 0), []byte("fake data")) + WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, 0), []byte("fake data")) WriteToStorage(t, core, makeSegmentPath(t, activityGlobalPathPrefix+activityEntityBasePath, now, 0), []byte("fake data")) WriteToStorage(t, core, makeSegmentPath(t, activityLocalPathPrefix+activityEntityBasePath, now, 0), []byte("fake data")) @@ -1038,6 +1047,8 @@ func TestSegmentFileReader_BadData(t *testing.T) { ClientID: "id", }, }} + writeEntitySegment(t, core, now, 1, entity) + // write global data at index 1 writeGlobalEntitySegment(t, core, now, 1, entity) @@ -1052,19 +1063,25 @@ func TestSegmentFileReader_BadData(t *testing.T) { reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) require.NoError(t, err) + // first the bad entity is read, which returns an error + _, err = reader.ReadEntity(context.Background()) + require.Error(t, err) + // then, the reader can read the good entity at index 1 + gotEntity, err := reader.ReadEntity(context.Background()) + require.True(t, proto.Equal(gotEntity, entity)) + require.Nil(t, err) + // first the bad global entity is read, which returns an error _, err = reader.ReadGlobalEntity(context.Background()) require.Error(t, err) - // then, the reader can read the good entity at index 1 - gotEntity, err := reader.ReadGlobalEntity(context.Background()) + gotEntity, err = reader.ReadGlobalEntity(context.Background()) require.True(t, proto.Equal(gotEntity, entity)) require.Nil(t, err) // first the bad local entity is read, which returns an error _, err = reader.ReadLocalEntity(context.Background()) require.Error(t, err) - // then, the reader can read the good entity at index 1 gotEntity, err = reader.ReadLocalEntity(context.Background()) require.True(t, proto.Equal(gotEntity, entity)) @@ -1073,7 +1090,6 @@ func TestSegmentFileReader_BadData(t *testing.T) { // the bad token causes an error _, err = reader.ReadToken(context.Background()) require.Error(t, err) - // but the good token is able to be read gotToken, err := reader.ReadToken(context.Background()) require.True(t, proto.Equal(gotToken, token)) @@ -1088,7 +1104,9 @@ func TestSegmentFileReader_MissingData(t *testing.T) { // write entities and tokens at indexes 0, 1, 2 for i := 0; i < 3; i++ { WriteToStorage(t, core, makeSegmentPath(t, activityTokenLocalBasePath, now, i), []byte("fake data")) + WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, i), []byte("fake data")) WriteToStorage(t, core, makeSegmentPath(t, activityGlobalPathPrefix+activityEntityBasePath, now, i), []byte("fake data")) + } // write entity at index 3 entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{ @@ -1096,6 +1114,7 @@ func TestSegmentFileReader_MissingData(t *testing.T) { ClientID: "id", }, }} + writeEntitySegment(t, core, now, 3, entity) // write global entity at index 3 writeGlobalEntitySegment(t, core, now, 3, entity) @@ -1114,18 +1133,25 @@ func TestSegmentFileReader_MissingData(t *testing.T) { // delete the indexes 0, 1, 2 for i := 0; i < 3; i++ { require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityTokenLocalBasePath, now, i))) + require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityEntityBasePath, now, i))) require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityGlobalPathPrefix+activityEntityBasePath, now, i))) require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityLocalPathPrefix+activityEntityBasePath, now, i))) } // we expect the reader to only return the data at index 3, and then be done + gotEntity, err := reader.ReadEntity(context.Background()) + require.NoError(t, err) + require.True(t, proto.Equal(gotEntity, entity)) + _, err = reader.ReadEntity(context.Background()) + require.Equal(t, err, io.EOF) + gotToken, err := reader.ReadToken(context.Background()) require.NoError(t, err) require.True(t, proto.Equal(gotToken, token)) _, err = reader.ReadToken(context.Background()) require.Equal(t, err, io.EOF) - gotEntity, err := reader.ReadGlobalEntity(context.Background()) + gotEntity, err = reader.ReadGlobalEntity(context.Background()) require.NoError(t, err) require.True(t, proto.Equal(gotEntity, entity)) _, err = reader.ReadGlobalEntity(context.Background()) @@ -1144,7 +1170,7 @@ func TestSegmentFileReader_NoData(t *testing.T) { now := time.Now() reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) require.NoError(t, err) - entity, err := reader.ReadGlobalEntity(context.Background()) + entity, err := reader.ReadEntity(context.Background()) require.Nil(t, entity) require.Equal(t, err, io.EOF) token, err := reader.ReadToken(context.Background()) @@ -1170,8 +1196,7 @@ func TestSegmentFileReader(t *testing.T) { token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{ fmt.Sprintf("ns-%d", i): uint64(i), }} - writeGlobalEntitySegment(t, core, now, i, entity) - writeLocalEntitySegment(t, core, now, i, entity) + writeEntitySegment(t, core, now, i, entity) writeTokenSegment(t, core, now, i, token) entities = append(entities, entity) tokens = append(tokens, token) @@ -1180,20 +1205,13 @@ func TestSegmentFileReader(t *testing.T) { reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) require.NoError(t, err) - gotGlobalEntities := make([]*activity.EntityActivityLog, 0, 3) - gotLocalEntities := make([]*activity.EntityActivityLog, 0, 3) + gotEntities := make([]*activity.EntityActivityLog, 0, 3) gotTokens := make([]*activity.TokenCount, 0, 3) - // read the global entities from the reader - for entity, err := reader.ReadGlobalEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadGlobalEntity(context.Background()) { + // read the entities from the reader + for entity, err := reader.ReadEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadEntity(context.Background()) { require.NoError(t, err) - gotGlobalEntities = append(gotGlobalEntities, entity) - } - - // read the local entities from the reader - for entity, err := reader.ReadLocalEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadLocalEntity(context.Background()) { - require.NoError(t, err) - gotLocalEntities = append(gotLocalEntities, entity) + gotEntities = append(gotEntities, entity) } // read the tokens from the reader @@ -1201,15 +1219,13 @@ func TestSegmentFileReader(t *testing.T) { require.NoError(t, err) gotTokens = append(gotTokens, token) } - require.Len(t, gotGlobalEntities, 3) - require.Len(t, gotLocalEntities, 3) + require.Len(t, gotEntities, 3) require.Len(t, gotTokens, 3) // verify that the entities and tokens we got from the reader are correct // we can't use require.Equals() here because there are protobuf differences in unexported fields for i := 0; i < 3; i++ { - require.True(t, proto.Equal(gotGlobalEntities[i], entities[i])) - require.True(t, proto.Equal(gotLocalEntities[i], entities[i])) + require.True(t, proto.Equal(gotEntities[i], entities[i])) require.True(t, proto.Equal(gotTokens[i], tokens[i])) } } diff --git a/vault/external_tests/activity_testonly/acme_regeneration_test.go b/vault/external_tests/activity_testonly/acme_regeneration_test.go index 5d70dc0c21..c663b174b8 100644 --- a/vault/external_tests/activity_testonly/acme_regeneration_test.go +++ b/vault/external_tests/activity_testonly/acme_regeneration_test.go @@ -54,7 +54,7 @@ func TestACMERegeneration_RegenerateWithCurrentMonth(t *testing.T) { }) require.NoError(t, err) now := time.Now().UTC() - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(3). // 3 months ago, 15 non-entity clients and 10 ACME clients NewClientsSeen(15, clientcountutil.WithClientType("non-entity-token")). @@ -116,7 +116,7 @@ func TestACMERegeneration_RegenerateMuchOlder(t *testing.T) { client := cluster.Cores[0].Client now := time.Now().UTC() - _, _, err := clientcountutil.NewActivityLogData(client). + _, _, _, err := clientcountutil.NewActivityLogData(client). NewPreviousMonthData(5). // 5 months ago, 15 non-entity clients and 10 ACME clients NewClientsSeen(15, clientcountutil.WithClientType("non-entity-token")). @@ -159,7 +159,7 @@ func TestACMERegeneration_RegeneratePreviousMonths(t *testing.T) { client := cluster.Cores[0].Client now := time.Now().UTC() - _, _, err := clientcountutil.NewActivityLogData(client). + _, _, _, err := clientcountutil.NewActivityLogData(client). NewPreviousMonthData(3). // 3 months ago, 15 non-entity clients and 10 ACME clients NewClientsSeen(15, clientcountutil.WithClientType("non-entity-token")). diff --git a/vault/external_tests/activity_testonly/activity_testonly_oss_test.go b/vault/external_tests/activity_testonly/activity_testonly_oss_test.go index c5463bb801..4b59142008 100644 --- a/vault/external_tests/activity_testonly/activity_testonly_oss_test.go +++ b/vault/external_tests/activity_testonly/activity_testonly_oss_test.go @@ -29,7 +29,7 @@ func Test_ActivityLog_Disable(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(5). NewCurrentMonthData(). diff --git a/vault/external_tests/activity_testonly/activity_testonly_test.go b/vault/external_tests/activity_testonly/activity_testonly_test.go index cd9dfb2157..3e3a1259b2 100644 --- a/vault/external_tests/activity_testonly/activity_testonly_test.go +++ b/vault/external_tests/activity_testonly/activity_testonly_test.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -////go:build testonly +//go:build testonly package activity_testonly @@ -86,7 +86,7 @@ func Test_ActivityLog_LoseLeadership(t *testing.T) { }) require.NoError(t, err) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10). Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) @@ -121,7 +121,7 @@ func Test_ActivityLog_ClientsOverlapping(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(7). NewCurrentMonthData(). @@ -169,7 +169,7 @@ func Test_ActivityLog_ClientsNewCurrentMonth(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(5). NewCurrentMonthData(). @@ -203,7 +203,7 @@ func Test_ActivityLog_EmptyDataMonths(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10). Write(context.Background(), generation.WriteOptions_WRITE_PRECOMPUTED_QUERIES, generation.WriteOptions_WRITE_ENTITIES) @@ -243,7 +243,7 @@ func Test_ActivityLog_FutureEndDate(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(10). NewCurrentMonthData(). @@ -316,7 +316,7 @@ func Test_ActivityLog_ClientTypeResponse(t *testing.T) { _, err := client.Logical().Write("sys/internal/counters/config", map[string]interface{}{ "enabled": "enable", }) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10, clientcountutil.WithClientType(tc.clientType)). Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) @@ -369,7 +369,7 @@ func Test_ActivityLogCurrentMonth_Response(t *testing.T) { _, err := client.Logical().Write("sys/internal/counters/config", map[string]interface{}{ "enabled": "enable", }) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10, clientcountutil.WithClientType(tc.clientType)). Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) @@ -420,7 +420,7 @@ func Test_ActivityLog_Deduplication(t *testing.T) { _, err := client.Logical().Write("sys/internal/counters/config", map[string]interface{}{ "enabled": "enable", }) - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(3). NewClientsSeen(10, clientcountutil.WithClientType(tc.clientType)). NewPreviousMonthData(2). @@ -462,7 +462,7 @@ func Test_ActivityLog_MountDeduplication(t *testing.T) { require.NoError(t, err) now := time.Now().UTC() - localPaths, globalPaths, err := clientcountutil.NewActivityLogData(client). + _, localPaths, globalPaths, err := clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientSeen(clientcountutil.WithClientMount("sys")). NewClientSeen(clientcountutil.WithClientMount("secret")). @@ -673,7 +673,7 @@ func Test_ActivityLog_Export_Sudo(t *testing.T) { rootToken := client.Token() - _, _, err = clientcountutil.NewActivityLogData(client). + _, _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10). Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) @@ -849,7 +849,7 @@ func TestHandleQuery_MultipleMounts(t *testing.T) { } // Write all the client count data - _, _, err = activityLogGenerator.Write(context.Background(), generation.WriteOptions_WRITE_PRECOMPUTED_QUERIES, generation.WriteOptions_WRITE_ENTITIES) + _, _, _, err = activityLogGenerator.Write(context.Background(), generation.WriteOptions_WRITE_PRECOMPUTED_QUERIES, generation.WriteOptions_WRITE_ENTITIES) require.NoError(t, err) endOfCurrentMonth := timeutil.EndOfMonth(time.Now().UTC()) diff --git a/vault/logical_system_activity_write_testonly.go b/vault/logical_system_activity_write_testonly.go index 51fe65e61e..3f6f4caa56 100644 --- a/vault/logical_system_activity_write_testonly.go +++ b/vault/logical_system_activity_write_testonly.go @@ -85,13 +85,14 @@ func (b *SystemBackend) handleActivityWriteData(ctx context.Context, request *lo for _, opt := range input.Write { opts[opt] = struct{}{} } - localPaths, globalPaths, err := generated.write(ctx, opts, b.Core.activityLog, now) + paths, localPaths, globalPaths, err := generated.write(ctx, opts, b.Core.activityLog, now) if err != nil { b.logger.Debug("failed to write activity log data", "error", err.Error()) return logical.ErrorResponse("failed to write data"), err } return &logical.Response{ Data: map[string]interface{}{ + "paths": paths, "local_paths": localPaths, "global_paths": globalPaths, }, @@ -100,10 +101,15 @@ func (b *SystemBackend) handleActivityWriteData(ctx context.Context, request *lo // singleMonthActivityClients holds a single month's client IDs, in the order they were seen type singleMonthActivityClients struct { + // clients are indexed by ID + clients []*activity.EntityRecord // globalClients are indexed by ID globalClients []*activity.EntityRecord // localClients are indexed by ID localClients []*activity.EntityRecord + // predefinedSegments map from the segment number to the client's index in + // the clients slice + predefinedSegments map[int][]int // predefinedGlobalSegments map from the segment number to the client's index in // the clients slice predefinedGlobalSegments map[int][]int @@ -120,13 +126,17 @@ type multipleMonthsActivityClients struct { months []*singleMonthActivityClients } -func (s *singleMonthActivityClients) addEntityRecord(core *Core, record *activity.EntityRecord, segmentIndex *int, local bool) { +func (s *singleMonthActivityClients) addEntityRecord(core *Core, record *activity.EntityRecord, segmentIndex *int) { + s.clients = append(s.clients, record) + local, _ := core.activityLog.isClientLocal(record) if !local { s.globalClients = append(s.globalClients, record) } else { s.localClients = append(s.localClients, record) } if segmentIndex != nil { + index := len(s.clients) - 1 + s.predefinedSegments[*segmentIndex] = append(s.predefinedSegments[*segmentIndex], index) if !local { globalIndex := len(s.globalClients) - 1 s.predefinedGlobalSegments[*segmentIndex] = append(s.predefinedGlobalSegments[*segmentIndex], globalIndex) @@ -220,15 +230,9 @@ func (s *singleMonthActivityClients) addNewClients(c *generation.Client, mountAc if c.Count > 1 { count = int(c.Count) } + isNonEntity := c.ClientType != entityActivityType ts := timeutil.MonthsPreviousTo(int(monthsAgo), now) - // identify is client is local or global - isLocal, err := isClientLocal(core, c.ClientType, mountAccessor) - if err != nil { - return err - } - - isNonEntity := c.ClientType != entityActivityType for i := 0; i < count; i++ { record := &activity.EntityRecord{ ClientID: c.Id, @@ -246,7 +250,7 @@ func (s *singleMonthActivityClients) addNewClients(c *generation.Client, mountAc } } - s.addEntityRecord(core, record, segmentIndex, isLocal) + s.addEntityRecord(core, record, segmentIndex) } return nil } @@ -355,25 +359,13 @@ func (m *multipleMonthsActivityClients) addRepeatedClients(monthsAgo int32, c *g repeatedFromMonth = c.RepeatedFromMonth } repeatedFrom := m.months[repeatedFromMonth] - - // identify is client is local or global - isLocal, err := isClientLocal(core, c.ClientType, mountAccessor) - if err != nil { - return err - } - numClients := 1 if c.Count > 0 { numClients = int(c.Count) } - - repeatedClients := repeatedFrom.globalClients - if isLocal { - repeatedClients = repeatedFrom.localClients - } - for _, client := range repeatedClients { + for _, client := range repeatedFrom.clients { if c.ClientType == client.ClientType && mountAccessor == client.MountAccessor && c.Namespace == client.NamespaceID { - addingTo.addEntityRecord(core, client, segmentIndex, isLocal) + addingTo.addEntityRecord(core, client, segmentIndex) numClients-- if numClients == 0 { break @@ -386,23 +378,6 @@ func (m *multipleMonthsActivityClients) addRepeatedClients(monthsAgo int32, c *g return nil } -// isClientLocal checks whether the given client is on a local mount. -// In all other cases, we will assume it is a global client. -func isClientLocal(core *Core, clientType string, mountAccessor string) (bool, error) { - // Tokens are not replicated to performance secondary clusters - if clientType == nonEntityTokenActivityType { - return true, nil - } - mountEntry := core.router.MatchingMountByAccessor(mountAccessor) - // If the mount entry is nil, this means the mount has been deleted. We will assume it was replicated because we do not want to - // over count clients - if mountEntry != nil && mountEntry.Local { - return true, nil - } - - return false, nil -} - func (m *multipleMonthsActivityClients) addMissingCurrentMonth() { missing := m.months[0].generationParameters == nil && len(m.months) > 1 && @@ -420,7 +395,8 @@ func (m *multipleMonthsActivityClients) timestampForMonth(i int, now time.Time) return now } -func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[generation.WriteOptions]struct{}, activityLog *ActivityLog, now time.Time) ([]string, []string, error) { +func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[generation.WriteOptions]struct{}, activityLog *ActivityLog, now time.Time) ([]string, []string, []string, error) { + paths := []string{} globalPaths := []string{} localPaths := []string{} @@ -435,10 +411,30 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene continue } timestamp := m.timestampForMonth(i, now) + segments, err := month.populateSegments(month.predefinedSegments, month.clients) + if err != nil { + return nil, nil, nil, err + } + for segmentIndex, segment := range segments { + if segment == nil { + // skip the index + continue + } + entityPath, err := activityLog.saveSegmentEntitiesInternal(ctx, segmentInfo{ + startTimestamp: timestamp.Unix(), + currentClients: &activity.EntityActivityLog{Clients: segment}, + clientSequenceNumber: uint64(segmentIndex), + tokenCount: &activity.TokenCount{}, + }, true, "") + if err != nil { + return nil, nil, nil, err + } + paths = append(paths, entityPath) + } if len(month.globalClients) > 0 { globalSegments, err := month.populateSegments(month.predefinedGlobalSegments, month.globalClients) if err != nil { - return nil, nil, err + return nil, nil, nil, err } for segmentIndex, segment := range globalSegments { if segment == nil { @@ -452,7 +448,7 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene tokenCount: &activity.TokenCount{}, }, true, activityGlobalPathPrefix) if err != nil { - return nil, nil, err + return nil, nil, nil, err } globalPaths = append(globalPaths, entityPath) } @@ -460,7 +456,7 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene if len(month.localClients) > 0 { localSegments, err := month.populateSegments(month.predefinedLocalSegments, month.localClients) if err != nil { - return nil, nil, err + return nil, nil, nil, err } for segmentIndex, segment := range localSegments { if segment == nil { @@ -474,7 +470,7 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene tokenCount: &activity.TokenCount{}, }, true, activityLocalPathPrefix) if err != nil { - return nil, nil, err + return nil, nil, nil, err } localPaths = append(localPaths, entityPath) } @@ -499,16 +495,16 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene if writeIntentLog { err := activityLog.writeIntentLog(ctx, m.latestTimestamp(now, false).Unix(), m.latestTimestamp(now, true).UTC()) if err != nil { - return nil, nil, err + return nil, nil, nil, err } } wg := sync.WaitGroup{} err := activityLog.refreshFromStoredLog(ctx, &wg, now) if err != nil { - return nil, nil, err + return nil, nil, nil, err } wg.Wait() - return localPaths, globalPaths, nil + return paths, localPaths, globalPaths, nil } func (m *multipleMonthsActivityClients) latestTimestamp(now time.Time, includeCurrentMonth bool) time.Time { @@ -536,6 +532,7 @@ func newMultipleMonthsActivityClients(numberOfMonths int) *multipleMonthsActivit } for i := 0; i < numberOfMonths; i++ { m.months[i] = &singleMonthActivityClients{ + predefinedSegments: make(map[int][]int), predefinedGlobalSegments: make(map[int][]int), predefinedLocalSegments: make(map[int][]int), } @@ -586,3 +583,12 @@ func (p *sliceSegmentReader) ReadLocalEntity(ctx context.Context) (*activity.Ent func (p *sliceSegmentReader) ReadToken(ctx context.Context) (*activity.TokenCount, error) { return nil, io.EOF } + +func (p *sliceSegmentReader) ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) { + if p.i == len(p.records) { + return nil, io.EOF + } + record := p.records[p.i] + p.i++ + return &activity.EntityActivityLog{Clients: record}, nil +} diff --git a/vault/logical_system_activity_write_testonly_test.go b/vault/logical_system_activity_write_testonly_test.go index 420e2079d0..4df992172d 100644 --- a/vault/logical_system_activity_write_testonly_test.go +++ b/vault/logical_system_activity_write_testonly_test.go @@ -12,7 +12,6 @@ import ( "testing" "time" - "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/timeutil" "github.com/hashicorp/vault/sdk/helper/clientcountutil/generation" @@ -27,12 +26,11 @@ import ( // correctly validated func TestSystemBackend_handleActivityWriteData(t *testing.T) { testCases := []struct { - name string - operation logical.Operation - input map[string]interface{} - hasLocalClients bool - wantError error - wantPaths int + name string + operation logical.Operation + input map[string]interface{} + wantError error + wantPaths int }{ { name: "read fails", @@ -86,13 +84,6 @@ func TestSystemBackend_handleActivityWriteData(t *testing.T) { input: map[string]interface{}{"input": `{"write":["WRITE_ENTITIES"],"data":[{"current_month":true,"num_segments":3,"all":{"clients":[{"count":5}]}}]}`}, wantPaths: 3, }, - { - name: "entities with multiple segments", - operation: logical.UpdateOperation, - input: map[string]interface{}{"input": `{"write":["WRITE_ENTITIES"],"data":[{"current_month":true,"num_segments":3,"all":{"clients":[{"count":5, "mount":"cubbyhole/"}]}}]}`}, - hasLocalClients: true, - wantPaths: 3, - }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -104,16 +95,8 @@ func TestSystemBackend_handleActivityWriteData(t *testing.T) { require.Equal(t, tc.wantError, err, resp.Error()) } else { require.NoError(t, err) - globalPaths := resp.Data["global_paths"].([]string) - localPaths := resp.Data["local_paths"].([]string) - if tc.hasLocalClients { - require.Len(t, globalPaths, 0) - require.Len(t, localPaths, tc.wantPaths) - } else { - require.Len(t, globalPaths, tc.wantPaths) - require.Len(t, localPaths, 0) - } - + paths := resp.Data["paths"].([]string) + require.Len(t, paths, tc.wantPaths) } }) } @@ -133,7 +116,6 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { wantNamespace string wantMount string wantID string - isLocal bool segmentIndex *int }{ { @@ -171,13 +153,6 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { ClientType: "non-entity", }, }, - { - name: "non entity token client", - clients: &generation.Client{ - ClientType: nonEntityTokenActivityType, - }, - isLocal: true, - }, { name: "acme client", clients: &generation.Client{ @@ -194,8 +169,8 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { t.Run(tt.name, func(t *testing.T) { core, _, _ := TestCoreUnsealed(t) m := &singleMonthActivityClients{ + predefinedSegments: make(map[int][]int), predefinedGlobalSegments: make(map[int][]int), - predefinedLocalSegments: make(map[int][]int), } err := m.addNewClients(tt.clients, tt.mount, tt.segmentIndex, 0, time.Now().UTC(), core) require.NoError(t, err) @@ -203,16 +178,8 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { if numNew == 0 { numNew = 1 } - - var clients []*activity.EntityRecord - if tt.isLocal { - require.Len(t, m.localClients, int(numNew)) - clients = m.localClients - } else { - require.Len(t, m.globalClients, int(numNew)) - clients = m.globalClients - } - for i, rec := range clients { + require.Len(t, m.clients, int(numNew)) + for i, rec := range m.clients { require.NotNil(t, rec) require.Equal(t, tt.wantNamespace, rec.NamespaceID) require.Equal(t, tt.wantMount, rec.MountAccessor) @@ -222,11 +189,8 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { } else { require.NotEqual(t, "", rec.ClientID) } - if tt.segmentIndex != nil && tt.isLocal { - require.Contains(t, m.predefinedLocalSegments[*tt.segmentIndex], i) - } - if tt.segmentIndex != nil && !tt.isLocal { - require.Contains(t, m.predefinedGlobalSegments[*tt.segmentIndex], i) + if tt.segmentIndex != nil { + require.Contains(t, m.predefinedSegments[*tt.segmentIndex], i) } } }) @@ -242,7 +206,6 @@ func Test_multipleMonthsActivityClients_processMonth(t *testing.T) { name string clients *generation.Data wantError bool - isLocal bool numMonths int }{ { @@ -255,16 +218,6 @@ func Test_multipleMonthsActivityClients_processMonth(t *testing.T) { }, numMonths: 1, }, - { - name: "specified namespace and local mount exist", - clients: &generation.Data{ - Clients: &generation.Data_All{All: &generation.Clients{Clients: []*generation.Client{{ - Mount: "cubbyhole/", - }}}}, - }, - numMonths: 1, - isLocal: true, - }, { name: "mount missing slash", clients: &generation.Data{ @@ -329,24 +282,13 @@ func Test_multipleMonthsActivityClients_processMonth(t *testing.T) { require.Error(t, err) } else { require.NoError(t, err) - if tt.isLocal { - require.Len(t, m.months[tt.clients.GetMonthsAgo()].localClients, len(tt.clients.GetAll().Clients)) - for _, month := range m.months { - for _, c := range month.localClients { - require.NotEmpty(t, c.NamespaceID) - require.NotEmpty(t, c.MountAccessor) - } - } - } else { - require.Len(t, m.months[tt.clients.GetMonthsAgo()].globalClients, len(tt.clients.GetAll().Clients)) - for _, month := range m.months { - for _, c := range month.globalClients { - require.NotEmpty(t, c.NamespaceID) - require.NotEmpty(t, c.MountAccessor) - } + require.Len(t, m.months[tt.clients.GetMonthsAgo()].clients, len(tt.clients.GetAll().Clients)) + for _, month := range m.months { + for _, c := range month.clients { + require.NotEmpty(t, c.NamespaceID) + require.NotEmpty(t, c.MountAccessor) } } - } }) } @@ -381,95 +323,58 @@ func Test_multipleMonthsActivityClients_processMonth_segmented(t *testing.T) { m := newMultipleMonthsActivityClients(1) core, _, _ := TestCoreUnsealed(t) require.NoError(t, m.processMonth(context.Background(), core, data, time.Now().UTC())) - require.Len(t, m.months[0].predefinedGlobalSegments, 3) - require.Len(t, m.months[0].globalClients, 3) + require.Len(t, m.months[0].predefinedSegments, 3) + require.Len(t, m.months[0].clients, 3) // segment indexes are correct - require.Contains(t, m.months[0].predefinedGlobalSegments, 0) - require.Contains(t, m.months[0].predefinedGlobalSegments, 1) - require.Contains(t, m.months[0].predefinedGlobalSegments, 7) + require.Contains(t, m.months[0].predefinedSegments, 0) + require.Contains(t, m.months[0].predefinedSegments, 1) + require.Contains(t, m.months[0].predefinedSegments, 7) // the data in each segment is correct - require.Contains(t, m.months[0].predefinedGlobalSegments[0], 0) - require.Contains(t, m.months[0].predefinedGlobalSegments[1], 1) - require.Contains(t, m.months[0].predefinedGlobalSegments[7], 2) + require.Contains(t, m.months[0].predefinedSegments[0], 0) + require.Contains(t, m.months[0].predefinedSegments[1], 1) + require.Contains(t, m.months[0].predefinedSegments[7], 2) } // Test_multipleMonthsActivityClients_addRepeatedClients adds repeated clients // from 1 month ago and 2 months ago, and verifies that the correct clients are // added based on namespace, mount, and non-entity attributes func Test_multipleMonthsActivityClients_addRepeatedClients(t *testing.T) { - storage := &logical.InmemStorage{} - coreConfig := &CoreConfig{ - CredentialBackends: map[string]logical.Factory{ - "userpass": userpass.Factory, - }, - Physical: storage.Underlying(), - } - - cluster := NewTestCluster(t, coreConfig, nil) - core := cluster.Cores[0].Core + core, _, _ := TestCoreUnsealed(t) now := time.Now().UTC() m := newMultipleMonthsActivityClients(3) defaultMount := "default" - // add global clients require.NoError(t, m.addClientToMonth(2, &generation.Client{Count: 2}, "identity", nil, now, core)) require.NoError(t, m.addClientToMonth(2, &generation.Client{Count: 2, Namespace: "other_ns"}, defaultMount, nil, now, core)) require.NoError(t, m.addClientToMonth(1, &generation.Client{Count: 2}, defaultMount, nil, now, core)) require.NoError(t, m.addClientToMonth(1, &generation.Client{Count: 2, ClientType: "non-entity"}, defaultMount, nil, now, core)) - // create a local mount - localMount := "localMountAccessor" - localMe := &MountEntry{ - Table: credentialTableType, - Path: "userpass-local/", - Type: "userpass", - Local: true, - Accessor: localMount, - } - err := core.enableCredential(namespace.RootContext(nil), localMe) - require.NoError(t, err) - - // add a local client - require.NoError(t, m.addClientToMonth(2, &generation.Client{Count: 2}, localMount, nil, now, core)) - require.NoError(t, m.addClientToMonth(1, &generation.Client{Count: 2}, localMount, nil, now, core)) - - month2GlobalClients := m.months[2].globalClients - month1GlobalClients := m.months[1].globalClients - - month2LocalClients := m.months[2].localClients - month1LocalClients := m.months[1].localClients + month2Clients := m.months[2].clients + month1Clients := m.months[1].clients thisMonth := m.months[0] // this will match the first client in month 1 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, Repeated: true}, defaultMount, nil, core)) - require.Contains(t, month1GlobalClients, thisMonth.globalClients[0]) - - // this will match the first local client in month 1 - require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, Repeated: true}, localMount, nil, core)) - require.Contains(t, month1LocalClients, thisMonth.localClients[0]) + require.Contains(t, month1Clients, thisMonth.clients[0]) // this will match the 3rd client in month 1 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, Repeated: true, ClientType: "non-entity"}, defaultMount, nil, core)) - require.Equal(t, month1GlobalClients[2], thisMonth.globalClients[1]) + require.Equal(t, month1Clients[2], thisMonth.clients[1]) // this will match the first two clients in month 1 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 2, Repeated: true}, defaultMount, nil, core)) - require.Equal(t, month1GlobalClients[0:2], thisMonth.globalClients[2:4]) + require.Equal(t, month1Clients[0:2], thisMonth.clients[2:4]) // this will match the first client in month 2 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, RepeatedFromMonth: 2}, "identity", nil, core)) - require.Equal(t, month2GlobalClients[0], thisMonth.globalClients[4]) - - // this will match the first local client in month 2 - require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, RepeatedFromMonth: 2}, localMount, nil, core)) - require.Equal(t, month2LocalClients[0], thisMonth.localClients[1]) + require.Equal(t, month2Clients[0], thisMonth.clients[4]) // this will match the 3rd client in month 2 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, RepeatedFromMonth: 2, Namespace: "other_ns"}, defaultMount, nil, core)) - require.Equal(t, month2GlobalClients[2], thisMonth.globalClients[5]) + require.Equal(t, month2Clients[2], thisMonth.clients[5]) require.Error(t, m.addRepeatedClients(0, &generation.Client{Count: 1, RepeatedFromMonth: 2, Namespace: "other_ns"}, "other_mount", nil, core)) } @@ -553,8 +458,8 @@ func Test_singleMonthActivityClients_populateSegments(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - s := singleMonthActivityClients{predefinedGlobalSegments: tc.segments, globalClients: clients, generationParameters: &generation.Data{EmptySegmentIndexes: tc.emptyIndexes, SkipSegmentIndexes: tc.skipIndexes, NumSegments: int32(tc.numSegments)}} - gotSegments, err := s.populateSegments(s.predefinedGlobalSegments, s.globalClients) + s := singleMonthActivityClients{predefinedSegments: tc.segments, clients: clients, generationParameters: &generation.Data{EmptySegmentIndexes: tc.emptyIndexes, SkipSegmentIndexes: tc.skipIndexes, NumSegments: int32(tc.numSegments)}} + gotSegments, err := s.populateSegments(s.predefinedSegments, s.clients) require.NoError(t, err) require.Equal(t, tc.wantSegments, gotSegments) }) @@ -624,7 +529,7 @@ func Test_handleActivityWriteData(t *testing.T) { req.Data = map[string]interface{}{"input": string(marshaled)} resp, err := core.systemBackend.HandleRequest(namespace.RootContext(nil), req) require.NoError(t, err) - paths := resp.Data["global_paths"].([]string) + paths := resp.Data["paths"].([]string) require.Len(t, paths, 9) times, err := core.activityLog.availableLogs(context.Background(), time.Now())