diff --git a/sdk/helper/clientcountutil/clientcountutil.go b/sdk/helper/clientcountutil/clientcountutil.go index dfafd4bee8..7d0be5526e 100644 --- a/sdk/helper/clientcountutil/clientcountutil.go +++ b/sdk/helper/clientcountutil/clientcountutil.go @@ -267,7 +267,10 @@ func (d *ActivityLogDataGenerator) Segment(opts ...SegmentOption) *ActivityLogDa } // ToJSON returns the JSON representation of the data -func (d *ActivityLogDataGenerator) ToJSON() ([]byte, error) { +func (d *ActivityLogDataGenerator) ToJSON(writeOptions ...generation.WriteOptions) ([]byte, error) { + if len(writeOptions) > 0 { + d.data.Write = writeOptions + } return protojson.Marshal(d.data) } diff --git a/vault/activity_log.go b/vault/activity_log.go index 5357ffc093..48c48fb07b 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -511,7 +511,7 @@ func parseSegmentNumberFromPath(path string) (int, bool) { // availableLogs returns the start_time(s) (in UTC) associated with months for which logs exist, // sorted last to first -func (a *ActivityLog) availableLogs(ctx context.Context) ([]time.Time, error) { +func (a *ActivityLog) availableLogs(ctx context.Context, upTo time.Time) ([]time.Time, error) { paths := make([]string, 0) for _, basePath := range []string{activityEntityBasePath, activityTokenBasePath} { p, err := a.view.List(ctx, basePath) @@ -526,14 +526,17 @@ func (a *ActivityLog) availableLogs(ctx context.Context) ([]time.Time, error) { out := make([]time.Time, 0) for _, path := range paths { // generate a set of unique start times - time, err := timeutil.ParseTimeFromPath(path) + segmentTime, err := timeutil.ParseTimeFromPath(path) if err != nil { return nil, err } + if segmentTime.After(upTo) { + continue + } - if _, present := pathSet[time]; !present { - pathSet[time] = struct{}{} - out = append(out, time) + if _, present := pathSet[segmentTime]; !present { + pathSet[segmentTime] = struct{}{} + out = append(out, segmentTime) } } @@ -542,15 +545,15 @@ func (a *ActivityLog) availableLogs(ctx context.Context) ([]time.Time, error) { return out[i].After(out[j]) }) - a.logger.Trace("scanned existing logs", "out", out) + a.logger.Trace("scanned existing logs", "out", out, "up to", upTo) return out, nil } // getMostRecentActivityLogSegment gets the times (in UTC) associated with the most recent // contiguous set of activity logs, sorted in decreasing order (latest to earliest) -func (a *ActivityLog) getMostRecentActivityLogSegment(ctx context.Context) ([]time.Time, error) { - logTimes, err := a.availableLogs(ctx) +func (a *ActivityLog) getMostRecentActivityLogSegment(ctx context.Context, now time.Time) ([]time.Time, error) { + logTimes, err := a.availableLogs(ctx, now) if err != nil { return nil, err } @@ -892,7 +895,7 @@ func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGro a.fragmentLock.Lock() defer a.fragmentLock.Unlock() - decreasingLogTimes, err := a.getMostRecentActivityLogSegment(ctx) + decreasingLogTimes, err := a.getMostRecentActivityLogSegment(ctx, now) if err != nil { return err } @@ -1156,7 +1159,7 @@ func (c *Core) setupActivityLogLocked(ctx context.Context, wg *sync.WaitGroup) e // Check for any intent log, in the background manager.computationWorkerDone = make(chan struct{}) go func() { - manager.precomputedQueryWorker(ctx) + manager.precomputedQueryWorker(ctx, nil) close(manager.computationWorkerDone) }() @@ -1174,7 +1177,7 @@ func (c *Core) setupActivityLogLocked(ctx context.Context, wg *sync.WaitGroup) e func (a *ActivityLog) createRegenerationIntentLog(ctx context.Context, now time.Time) (*ActivityIntentLog, error) { intentLog := &ActivityIntentLog{} - segments, err := a.availableLogs(ctx) + segments, err := a.availableLogs(ctx, now) if err != nil { return nil, fmt.Errorf("error fetching available logs: %w", err) } @@ -1439,7 +1442,7 @@ func (a *ActivityLog) HandleEndOfMonth(ctx context.Context, currentTime time.Tim a.fragmentLock.Unlock() // Work on precomputed queries in background - go a.precomputedQueryWorker(ctx) + go a.precomputedQueryWorker(ctx, nil) return nil } @@ -2431,7 +2434,9 @@ func (a *ActivityLog) reportPrecomputedQueryMetrics(ctx context.Context, segment // goroutine to process the request in the intent log, creating precomputed queries. // We expect the return value won't be checked, so log errors as they occur // (but for unit testing having the error return should help.) -func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error { +// If the intent log that's passed into the function is non-nil, we use that +// intent log. Otherwise, we read the intent log from storage +func (a *ActivityLog) precomputedQueryWorker(ctx context.Context, intent *ActivityIntentLog) (err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -2452,21 +2457,39 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error { }(a.doneCh) a.l.RUnlock() - // Load the intent log - rawIntentLog, err := a.view.Get(ctx, activityIntentLogKey) - if err != nil { - a.logger.Warn("could not load intent log", "error", err) - return err + strictEnforcement := intent == nil + shouldCleanupIntentLog := false + if intent == nil { + + // Load the intent log + rawIntentLog, err := a.view.Get(ctx, activityIntentLogKey) + if err != nil { + a.logger.Warn("could not load intent log", "error", err) + return err + } + if rawIntentLog == nil { + a.logger.Trace("no intent log found") + return err + } + intent = new(ActivityIntentLog) + err = json.Unmarshal(rawIntentLog.Value, intent) + if err != nil { + a.logger.Warn("could not parse intent log", "error", err) + return err + } + shouldCleanupIntentLog = true } - if rawIntentLog == nil { - a.logger.Trace("no intent log found") - return err - } - var intent ActivityIntentLog - err = json.Unmarshal(rawIntentLog.Value, &intent) - if err != nil { - a.logger.Warn("could not parse intent log", "error", err) - return err + + cleanupIntentLog := func() { + if !shouldCleanupIntentLog { + return + } + // delete the intent log + // this should happen if the precomputed queries were generated + // successfully (i.e. err is nil) or if there's no data for the previous + // month. + // It should not happen in the general error case + a.view.Delete(ctx, activityIntentLogKey) } // currentMonth could change (from another month end) after we release the lock. @@ -2479,28 +2502,29 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error { // would work but this will be easier to control in tests. retentionWindow := timeutil.MonthsPreviousTo(a.retentionMonths, time.Unix(intent.NextMonth, 0).UTC()) a.l.RUnlock() - if currentMonth != 0 && intent.NextMonth != currentMonth { + if strictEnforcement && currentMonth != 0 && intent.NextMonth != currentMonth { a.logger.Warn("intent log does not match current segment", "intent", intent.NextMonth, "current", currentMonth) return errors.New("intent log is too far in the past") } lastMonth := intent.PreviousMonth - a.logger.Info("computing queries", "month", time.Unix(lastMonth, 0).UTC()) + lastMonthTime := time.Unix(lastMonth, 0) + a.logger.Info("computing queries", "month", lastMonthTime.UTC()) - times, err := a.availableLogs(ctx) + times, err := a.availableLogs(ctx, lastMonthTime) if err != nil { a.logger.Warn("could not list available logs", "error", err) return err } if len(times) == 0 { a.logger.Warn("no months in storage") - a.view.Delete(ctx, activityIntentLogKey) + cleanupIntentLog() return errors.New("previous month not found") } if times[0].Unix() != lastMonth { a.logger.Warn("last month not in storage", "latest", times[0].Unix()) - a.view.Delete(ctx, activityIntentLogKey) + cleanupIntentLog() return errors.New("previous month not found") } @@ -2537,9 +2561,7 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error { return err } } - - // delete the intent log - a.view.Delete(ctx, activityIntentLogKey) + cleanupIntentLog() a.logger.Info("finished computing queries", "month", endTime) @@ -2579,7 +2601,7 @@ func (a *ActivityLog) retentionWorker(ctx context.Context, currentTime time.Time // everything >= the threshold is OK retentionThreshold := timeutil.MonthsPreviousTo(retentionMonths, currentTime) - available, err := a.availableLogs(ctx) + available, err := a.availableLogs(ctx, retentionThreshold) if err != nil { a.logger.Warn("could not list segments", "error", err) return err @@ -2892,7 +2914,7 @@ func (a *ActivityLog) writeExport(ctx context.Context, rw http.ResponseWriter, f // Find the months with activity log data that are between the start and end // months. We want to walk this in cronological order so the oldest instance of a // client usage is recorded, not the most recent. - times, err := a.availableLogs(ctx) + times, err := a.availableLogs(ctx, endTime) if err != nil { a.logger.Warn("failed to list available log segments", "error", err) return fmt.Errorf("failed to list available log segments: %w", err) diff --git a/vault/activity_log_test.go b/vault/activity_log_test.go index 80048a85a0..3869502ca9 100644 --- a/vault/activity_log_test.go +++ b/vault/activity_log_test.go @@ -624,7 +624,7 @@ func TestActivityLog_availableLogsEmptyDirectory(t *testing.T) { // verify that directory is empty, and nothing goes wrong core, _, _ := TestCoreUnsealed(t) a := core.activityLog - times, err := a.availableLogs(context.Background()) + times, err := a.availableLogs(context.Background(), time.Now()) if err != nil { t.Fatalf("error getting start_time(s) for empty activity log") } @@ -647,7 +647,7 @@ func TestActivityLog_availableLogs(t *testing.T) { } // verify above files are there, and dates in correct order - times, err := a.availableLogs(context.Background()) + times, err := a.availableLogs(context.Background(), time.Now()) if err != nil { t.Fatalf("error getting start_time(s) for activity log") } @@ -2725,7 +2725,7 @@ func TestActivityLog_CalculatePrecomputedQueriesWithMixedTWEs(t *testing.T) { // Pretend we've successfully rolled over to the following month a.SetStartTimestamp(tc.NextMonth) - err = a.precomputedQueryWorker(ctx) + err = a.precomputedQueryWorker(ctx, nil) if err != nil { t.Fatal(err) } @@ -3106,7 +3106,7 @@ func TestActivityLog_Precompute(t *testing.T) { // Pretend we've successfully rolled over to the following month a.SetStartTimestamp(tc.NextMonth) - err = a.precomputedQueryWorker(ctx) + err = a.precomputedQueryWorker(ctx, nil) if err != nil { t.Fatal(err) } @@ -3368,7 +3368,7 @@ func TestActivityLog_Precompute_SkipMonth(t *testing.T) { // Pretend we've successfully rolled over to the following month a.SetStartTimestamp(tc.NextMonth) - err = a.precomputedQueryWorker(ctx) + err = a.precomputedQueryWorker(ctx, nil) if err != nil { t.Fatal(err) } @@ -3634,7 +3634,7 @@ func TestActivityLog_PrecomputeNonEntityTokensWithID(t *testing.T) { // Pretend we've successfully rolled over to the following month a.SetStartTimestamp(tc.NextMonth) - err = a.precomputedQueryWorker(ctx) + err = a.precomputedQueryWorker(ctx, nil) if err != nil { t.Fatal(err) } @@ -3761,7 +3761,7 @@ func TestActivityLog_PrecomputeCancel(t *testing.T) { // This will block if the shutdown didn't work. go func() { // We expect this to error because of BlockingInmemStorage - _ = a.precomputedQueryWorker(namespace.RootContext(nil)) + _ = a.precomputedQueryWorker(namespace.RootContext(nil), nil) close(done) }() diff --git a/vault/activity_log_testonly_test.go b/vault/activity_log_testonly_test.go new file mode 100644 index 0000000000..991d9302d8 --- /dev/null +++ b/vault/activity_log_testonly_test.go @@ -0,0 +1,117 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build testonly + +package vault + +import ( + "context" + "testing" + "time" + + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/timeutil" + "github.com/hashicorp/vault/sdk/helper/clientcountutil" + "github.com/hashicorp/vault/sdk/helper/clientcountutil/generation" + "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/require" +) + +// TestActivityLog_doPrecomputedQueryCreation creates segments for the last 4 +// months and then calls doPrecomputedQueryCreation, in order of oldest to most +// recent month. The test verifies that the count of clients in the generated +// precomputed query is equal to the number of deduplicated clients. +func TestActivityLog_doPrecomputedQueryCreation(t *testing.T) { + core, _, token := TestCoreUnsealed(t) + a := core.activityLog + a.SetEnable(true) + + j, err := clientcountutil.NewActivityLogData(nil). + // 8 new clients + // across two segments + NewPreviousMonthData(4). + Segment().NewClientsSeen(5). + Segment().NewClientsSeen(3). + + // 2 repeated clients + // 10 new clients + // across 3 segments + NewPreviousMonthData(3). + Segment().RepeatedClientsSeen(2). + NewClientsSeen(3). + Segment().NewClientsSeen(2). + Segment().NewClientsSeen(5). + + // 7 new clients + // single segment + NewPreviousMonthData(2). + NewClientsSeen(7). + + // 6 repeated clients + // 5 new clients + // across 2 segments + NewPreviousMonthData(1). + Segment().NewClientsSeen(5). + Segment().RepeatedClientsSeen(6). + ToJSON(generation.WriteOptions_WRITE_ENTITIES) + require.NoError(t, err) + + r := logical.TestRequest(t, logical.UpdateOperation, "sys/internal/counters/activity/write") + r.Data["input"] = string(j) + r.ClientToken = token + _, err = core.HandleRequest(namespace.RootContext(context.Background()), r) + require.NoError(t, err) + + now := time.Now().UTC() + times := map[int]time.Time{} + for i := 1; i < 5; i++ { + times[i] = timeutil.StartOfMonth(timeutil.MonthsPreviousTo(i, now)) + } + + testCases := []struct { + name string + generateUpToMonth int + strictEnforcement bool + wantClients int + }{ + { + name: "only 4 months ago", + generateUpToMonth: 4, + wantClients: 8, // 8 clients from month 4 + }, + { + name: "3 months ago", + generateUpToMonth: 3, + // 8 clients (month 4) + 10 new clients (month 3) + wantClients: 18, + }, + { + name: "2 months ago", + generateUpToMonth: 2, + // 8 clients (month 4) + 10 new clients (month 3) + 7 new clients + // (month 2) + wantClients: 25, + }, + { + name: "1 month ago", + generateUpToMonth: 1, + // 8 clients (month 4) + 10 new clients (month 3) + 7 new clients + // (month 2) + 5 new clients (month 1) + wantClients: 30, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + generateUpTo := times[tc.generateUpToMonth] + nextMonth := timeutil.StartOfNextMonth(generateUpTo) + err = a.precomputedQueryWorker(context.Background(), &ActivityIntentLog{PreviousMonth: generateUpTo.Unix(), NextMonth: nextMonth.Unix()}) + require.NoError(t, err) + + // get precomputed queries spanning the whole time period + pq, err := a.queryStore.Get(context.Background(), times[4], now) + require.NoError(t, err) + require.Equal(t, tc.wantClients, int(pq.Namespaces[0].Entities)) + }) + } +} diff --git a/vault/logical_system_activity_write_testonly_test.go b/vault/logical_system_activity_write_testonly_test.go index fd2ce2b38d..f3f2577354 100644 --- a/vault/logical_system_activity_write_testonly_test.go +++ b/vault/logical_system_activity_write_testonly_test.go @@ -526,7 +526,7 @@ func Test_handleActivityWriteData(t *testing.T) { paths := resp.Data["paths"].([]string) require.Len(t, paths, 9) - times, err := core.activityLog.availableLogs(context.Background()) + times, err := core.activityLog.availableLogs(context.Background(), time.Now()) require.NoError(t, err) require.Len(t, times, 4) @@ -645,7 +645,7 @@ func Test_handleActivityWriteData(t *testing.T) { require.Equal(t, timeutil.StartOfMonth(now), next.UTC()) require.Equal(t, timeutil.StartOfMonth(timeutil.MonthsPreviousTo(3, now)), prev.UTC()) - times, err := core.activityLog.availableLogs(context.Background()) + times, err := core.activityLog.availableLogs(context.Background(), time.Now()) require.NoError(t, err) require.Len(t, times, 4) })