From 018ea84997b49137ae3884e00e4dc9fc389f8b50 Mon Sep 17 00:00:00 2001 From: miagilepner Date: Tue, 23 May 2023 18:25:23 +0200 Subject: [PATCH] VAULT-15395: Support mocking time functions in the activity log (#20720) * mock time in the activity log * cleanup * fix comment * pr fixes * update comment to explain why new timer is needed --- helper/metricsutil/gauge_process.go | 25 ++------- helper/metricsutil/gauge_process_test.go | 6 +- helper/timeutil/timeutil.go | 23 ++++++++ vault/activity_log.go | 67 ++++++++++++++-------- vault/activity_log_test.go | 71 +++++++++++++++++++++++- vault/activity_log_testing_util.go | 6 +- vault/activity_log_util_common.go | 2 +- 7 files changed, 148 insertions(+), 52 deletions(-) diff --git a/helper/metricsutil/gauge_process.go b/helper/metricsutil/gauge_process.go index f471249d75..c6fcd56639 100644 --- a/helper/metricsutil/gauge_process.go +++ b/helper/metricsutil/gauge_process.go @@ -11,24 +11,9 @@ import ( "github.com/armon/go-metrics" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/timeutil" ) -// This interface allows unit tests to substitute in a simulated clock. -type clock interface { - Now() time.Time - NewTicker(time.Duration) *time.Ticker -} - -type defaultClock struct{} - -func (_ defaultClock) Now() time.Time { - return time.Now() -} - -func (_ defaultClock) NewTicker(d time.Duration) *time.Ticker { - return time.NewTicker(d) -} - // GaugeLabelValues is one gauge in a set sharing a single key, that // are measured in a batch. type GaugeLabelValues struct { @@ -76,7 +61,7 @@ type GaugeCollectionProcess struct { maxGaugeCardinality int // time source - clock clock + clock timeutil.Clock } // NewGaugeCollectionProcess creates a new collection process for the callback @@ -101,7 +86,7 @@ func NewGaugeCollectionProcess( gaugeInterval, maxGaugeCardinality, logger, - defaultClock{}, + timeutil.DefaultClock{}, ) } @@ -124,7 +109,7 @@ func (m *ClusterMetricSink) NewGaugeCollectionProcess( m.GaugeInterval, m.MaxGaugeCardinality, logger, - defaultClock{}, + timeutil.DefaultClock{}, ) } @@ -137,7 +122,7 @@ func newGaugeCollectionProcessWithClock( gaugeInterval time.Duration, maxGaugeCardinality int, logger log.Logger, - clock clock, + clock timeutil.Clock, ) (*GaugeCollectionProcess, error) { process := &GaugeCollectionProcess{ stop: make(chan struct{}, 1), diff --git a/helper/metricsutil/gauge_process_test.go b/helper/metricsutil/gauge_process_test.go index 83165a997b..efd74e707d 100644 --- a/helper/metricsutil/gauge_process_test.go +++ b/helper/metricsutil/gauge_process_test.go @@ -15,6 +15,7 @@ import ( "github.com/armon/go-metrics" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/timeutil" ) // SimulatedTime maintains a virtual clock so the test isn't @@ -24,9 +25,10 @@ import ( type SimulatedTime struct { now time.Time tickerBarrier chan *SimulatedTicker + timeutil.DefaultClock } -var _ clock = &SimulatedTime{} +var _ timeutil.Clock = &SimulatedTime{} type SimulatedTicker struct { ticker *time.Ticker @@ -121,7 +123,7 @@ func TestGauge_Creation(t *testing.T) { t.Fatalf("Error creating collection process: %v", err) } - if _, ok := p.clock.(defaultClock); !ok { + if _, ok := p.clock.(timeutil.DefaultClock); !ok { t.Error("Default clock not installed.") } diff --git a/helper/timeutil/timeutil.go b/helper/timeutil/timeutil.go index 89daab7d40..16f8343513 100644 --- a/helper/timeutil/timeutil.go +++ b/helper/timeutil/timeutil.go @@ -142,3 +142,26 @@ func SkipAtEndOfMonth(t *testing.T) { t.Skip("too close to end of month") } } + +// This interface allows unit tests to substitute in a simulated Clock. +type Clock interface { + Now() time.Time + NewTicker(time.Duration) *time.Ticker + NewTimer(time.Duration) *time.Timer +} + +type DefaultClock struct{} + +var _ Clock = (*DefaultClock)(nil) + +func (_ DefaultClock) Now() time.Time { + return time.Now() +} + +func (_ DefaultClock) NewTicker(d time.Duration) *time.Ticker { + return time.NewTicker(d) +} + +func (_ DefaultClock) NewTimer(d time.Duration) *time.Timer { + return time.NewTimer(d) +} diff --git a/vault/activity_log.go b/vault/activity_log.go index 32f0570ad0..8d11a7ab5c 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -188,6 +188,12 @@ type ActivityLog struct { // CensusReportInterval is the testing configuration for time between // Write() calls initiated in CensusReport. CensusReportInterval time.Duration + + // clock is used to support manipulating time in unit and integration tests + clock timeutil.Clock + // precomputedQueryWritten receives an element whenever a precomputed query + // is written. It's used for unit testing + precomputedQueryWritten chan struct{} } // These non-persistent configuration options allow us to disable @@ -205,6 +211,10 @@ type ActivityLogCoreConfig struct { // MinimumRetentionMonths defines the minimum value for retention MinimumRetentionMonths int + + // Clock holds a custom clock to modify time.Now, time.Ticker, time.Timer. + // If nil, the default functions from the time package are used + Clock timeutil.Clock } // NewActivityLog creates an activity log. @@ -214,6 +224,10 @@ func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics me return nil, err } + clock := core.activityLogConfig.Clock + if clock == nil { + clock = timeutil.DefaultClock{} + } a := &ActivityLog{ core: core, configOverrides: &core.activityLogConfig, @@ -227,7 +241,7 @@ func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics me doneCh: make(chan struct{}, 1), partialMonthClientTracker: make(map[string]*activity.EntityRecord), CensusReportInterval: time.Hour * 1, - + clock: clock, currentSegment: segmentInfo{ startTimestamp: 0, currentClients: &activity.EntityActivityLog{ @@ -243,6 +257,7 @@ func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics me }, standbyFragmentsReceived: make([]*activity.LogFragment, 0), inprocessExport: atomic.NewBool(false), + precomputedQueryWritten: make(chan struct{}), } config, err := a.loadConfigOrDefault(core.activeContext) @@ -274,7 +289,7 @@ func (a *ActivityLog) saveCurrentSegmentToStorage(ctx context.Context, force boo // :force: forces a save of tokens/entities even if the in-memory log is empty func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, force bool) error { defer a.metrics.MeasureSinceWithLabels([]string{"core", "activity", "segment_write"}, - time.Now(), []metricsutil.Label{}) + a.clock.Now(), []metricsutil.Label{}) // Swap out the pending fragments a.fragmentLock.Lock() @@ -433,7 +448,7 @@ func (a *ActivityLog) saveCurrentSegmentInternal(ctx context.Context, force bool case err != nil: a.logger.Error(fmt.Sprintf("unable to retrieve oldest version timestamp: %s", err.Error())) case len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 && - (oldestUpgradeTime.Add(time.Duration(trackedTWESegmentPeriod * time.Hour)).Before(time.Now())): + (oldestUpgradeTime.Add(time.Duration(trackedTWESegmentPeriod * time.Hour)).Before(a.clock.Now())): a.logger.Error(fmt.Sprintf("storing nonzero token count over a month after vault was upgraded to %s", oldestVersion)) default: if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { @@ -1005,7 +1020,7 @@ func (a *ActivityLog) SetConfig(ctx context.Context, config activityConfig) { forceSave := false if a.enabled && a.currentSegment.startTimestamp == 0 { - a.startNewCurrentLogLocked(time.Now().UTC()) + a.startNewCurrentLogLocked(a.clock.Now().UTC()) // Force a save so we can distinguish between // // Month N-1: present @@ -1031,7 +1046,7 @@ func (a *ActivityLog) SetConfig(ctx context.Context, config activityConfig) { } // check for segments out of retention period, if it has changed - go a.retentionWorker(ctx, time.Now(), a.retentionMonths) + go a.retentionWorker(ctx, a.clock.Now(), a.retentionMonths) } // update the enable flag and reset the current log @@ -1097,7 +1112,7 @@ func (c *Core) setupActivityLogLocked(ctx context.Context, wg *sync.WaitGroup) e c.activityLog = manager // load activity log for "this month" into memory - err = manager.refreshFromStoredLog(manager.core.activeContext, wg, time.Now().UTC()) + err = manager.refreshFromStoredLog(manager.core.activeContext, wg, manager.clock.Now().UTC()) if err != nil { return err } @@ -1121,7 +1136,7 @@ func (c *Core) setupActivityLogLocked(ctx context.Context, wg *sync.WaitGroup) e // Signal when this is done so that unit tests can proceed. manager.retentionDone = make(chan struct{}) go func(months int) { - manager.retentionWorker(ctx, time.Now(), months) + manager.retentionWorker(ctx, manager.clock.Now(), months) close(manager.retentionDone) }(manager.retentionMonths) @@ -1158,7 +1173,7 @@ func (a *ActivityLog) StartOfNextMonth() time.Time { defer a.l.RUnlock() var segmentStart time.Time if a.currentSegment.startTimestamp == 0 { - segmentStart = time.Now().UTC() + segmentStart = a.clock.Now().UTC() } else { segmentStart = time.Unix(a.currentSegment.startTimestamp, 0).UTC() } @@ -1170,12 +1185,12 @@ func (a *ActivityLog) StartOfNextMonth() time.Time { // perfStandbyFragmentWorker handles scheduling fragments // to send via RPC; it runs on perf standby nodes only. func (a *ActivityLog) perfStandbyFragmentWorker(ctx context.Context) { - timer := time.NewTimer(time.Duration(0)) + timer := a.clock.NewTimer(time.Duration(0)) fragmentWaiting := false // Eat first event, so timer is stopped <-timer.C - endOfMonth := time.NewTimer(a.StartOfNextMonth().Sub(time.Now())) + endOfMonth := a.clock.NewTimer(a.StartOfNextMonth().Sub(a.clock.Now())) if a.configOverrides.DisableTimers { endOfMonth.Stop() } @@ -1247,8 +1262,8 @@ func (a *ActivityLog) perfStandbyFragmentWorker(ctx context.Context) { // Set timer for next month. // The current segment *probably* hasn't been set yet (via invalidation), // so don't rely on it. - target := timeutil.StartOfNextMonth(time.Now().UTC()) - endOfMonth.Reset(target.Sub(time.Now())) + target := timeutil.StartOfNextMonth(a.clock.Now().UTC()) + endOfMonth.Reset(target.Sub(a.clock.Now())) } } } @@ -1256,9 +1271,9 @@ func (a *ActivityLog) perfStandbyFragmentWorker(ctx context.Context) { // activeFragmentWorker handles scheduling the write of the next // segment. It runs on active nodes only. func (a *ActivityLog) activeFragmentWorker(ctx context.Context) { - ticker := time.NewTicker(activitySegmentInterval) + ticker := a.clock.NewTicker(activitySegmentInterval) - endOfMonth := time.NewTimer(a.StartOfNextMonth().Sub(time.Now())) + endOfMonth := a.clock.NewTimer(a.StartOfNextMonth().Sub(a.clock.Now())) if a.configOverrides.DisableTimers { endOfMonth.Stop() } @@ -1308,7 +1323,7 @@ func (a *ActivityLog) activeFragmentWorker(ctx context.Context) { // Reset the schedule to wait 10 minutes from this forced write. ticker.Stop() - ticker = time.NewTicker(activitySegmentInterval) + ticker = a.clock.NewTicker(activitySegmentInterval) // Simpler, but ticker.Reset was introduced in go 1.15: // ticker.Reset(activitySegmentInterval) @@ -1324,7 +1339,7 @@ func (a *ActivityLog) activeFragmentWorker(ctx context.Context) { go a.retentionWorker(ctx, currentTime.UTC(), a.retentionMonths) a.l.RUnlock() - delta := a.StartOfNextMonth().Sub(time.Now()) + delta := a.StartOfNextMonth().Sub(a.clock.Now()) if delta < 20*time.Minute { delta = 20 * time.Minute } @@ -1513,7 +1528,7 @@ func (a *ActivityLog) createCurrentFragment() { Clients: make([]*activity.EntityRecord, 0, 120), NonEntityTokens: make(map[string]uint64), } - a.fragmentCreation = time.Now().UTC() + a.fragmentCreation = a.clock.Now().UTC() // Signal that a new segment is available, start // the timer to send it. @@ -1613,13 +1628,13 @@ func (a *ActivityLog) handleQuery(ctx context.Context, startTime, endTime time.T // with the endTime equal to the end of the last month, and add in the current month // data. precomputedQueryEndTime := endTime - if timeutil.IsCurrentMonth(endTime, time.Now().UTC()) { + if timeutil.IsCurrentMonth(endTime, a.clock.Now().UTC()) { precomputedQueryEndTime = timeutil.EndOfMonth(timeutil.MonthsPreviousTo(1, timeutil.StartOfMonth(endTime))) computePartial = true } pq := &activity.PrecomputedQuery{} - if startTime.After(precomputedQueryEndTime) && timeutil.IsCurrentMonth(startTime, time.Now().UTC()) { + if startTime.After(precomputedQueryEndTime) && timeutil.IsCurrentMonth(startTime, a.clock.Now().UTC()) { // We're only calculating the partial month client count. Skip the precomputation // get call. pq = &activity.PrecomputedQuery{ @@ -1794,7 +1809,7 @@ func (a *ActivityLog) handleQuery(ctx context.Context, startTime, endTime time.T a.sortActivityLogMonthsResponse(months) // Modify the final month output to make response more consumable based on API request - months = modifyResponseMonths(months, startTime, endTime) + months = a.modifyResponseMonths(months, startTime, endTime) responseData["months"] = months return responseData, nil @@ -1802,13 +1817,13 @@ func (a *ActivityLog) handleQuery(ctx context.Context, startTime, endTime time.T // modifyResponseMonths fills out various parts of the query structure to help // activity log clients parse the returned query. -func modifyResponseMonths(months []*ResponseMonth, start time.Time, end time.Time) []*ResponseMonth { +func (a *ActivityLog) modifyResponseMonths(months []*ResponseMonth, start time.Time, end time.Time) []*ResponseMonth { if len(months) == 0 { return months } start = timeutil.StartOfMonth(start) end = timeutil.EndOfMonth(end) - if timeutil.IsCurrentMonth(end, time.Now().UTC()) { + if timeutil.IsCurrentMonth(end, a.clock.Now().UTC()) { end = timeutil.EndOfMonth(timeutil.StartOfMonth(end).AddDate(0, -1, 0)) } modifiedResponseMonths := make([]*ResponseMonth, 0) @@ -2328,7 +2343,7 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error { // If there's an intent log, finish it even if the feature is currently disabled. a.l.RLock() currentMonth := a.currentSegment.startTimestamp - // Base retention period on the month we are generating (even in the past)--- time.Now() + // Base retention period on the month we are generating (even in the past)--- a.clock.Now() // would work but this will be easier to control in tests. retentionWindow := timeutil.MonthsPreviousTo(a.retentionMonths, time.Unix(intent.NextMonth, 0).UTC()) a.l.RUnlock() @@ -2396,6 +2411,10 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error { a.logger.Info("finished computing queries", "month", endTime) + select { + case a.precomputedQueryWritten <- struct{}{}: + default: + } return nil } @@ -2489,7 +2508,7 @@ func (a *ActivityLog) populateNamespaceAndMonthlyBreakdowns() (map[int64]*proces byNamespace := make(map[string]*processByNamespace) byMonth := make(map[int64]*processMonth) for _, e := range a.partialMonthClientTracker { - processClientRecord(e, byNamespace, byMonth, time.Now()) + processClientRecord(e, byNamespace, byMonth, a.clock.Now()) } return byMonth, byNamespace } diff --git a/vault/activity_log_test.go b/vault/activity_log_test.go index 19316611b8..db85e0463d 100644 --- a/vault/activity_log_test.go +++ b/vault/activity_log_test.go @@ -522,11 +522,13 @@ func TestActivityLog_StoreAndReadHyperloglog(t *testing.T) { // TestModifyResponseMonthsNilAppend calls modifyResponseMonths for a range of 5 months ago to now. It verifies that the // 5 months in the range are correct. func TestModifyResponseMonthsNilAppend(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + a := core.activityLog end := time.Now().UTC() start := timeutil.StartOfMonth(end).AddDate(0, -5, 0) responseMonthTimestamp := timeutil.StartOfMonth(end).AddDate(0, -3, 0).Format(time.RFC3339) responseMonths := []*ResponseMonth{{Timestamp: responseMonthTimestamp}} - months := modifyResponseMonths(responseMonths, start, end) + months := a.modifyResponseMonths(responseMonths, start, end) if len(months) != 5 { t.Fatal("wrong number of months padded") } @@ -4679,3 +4681,70 @@ func TestActivityLog_writePrecomputedQuery(t *testing.T) { require.Equal(t, 1, monthRecord.NewClients.Counts.EntityClients) require.Equal(t, 1, monthRecord.NewClients.Counts.NonEntityClients) } + +type mockTimeNowClock struct { + timeutil.DefaultClock + start time.Time + created time.Time +} + +func newMockTimeNowClock(startAt time.Time) timeutil.Clock { + return &mockTimeNowClock{start: startAt, created: time.Now()} +} + +// NewTimer returns a timer with a channel that will return the correct time, +// relative to the starting time. This is used when testing the +// activeFragmentWorker, as that function uses the returned value from timer.C +// to perform additional functionality +func (m mockTimeNowClock) NewTimer(d time.Duration) *time.Timer { + timerStarted := m.Now() + t := time.NewTimer(d) + readCh := t.C + writeCh := make(chan time.Time, 1) + go func() { + <-readCh + writeCh <- timerStarted.Add(d) + }() + t.C = writeCh + return t +} + +func (m mockTimeNowClock) Now() time.Time { + return m.start.Add(time.Since(m.created)) +} + +// TestActivityLog_HandleEndOfMonth runs the activity log with a mock clock. +// The current time is set to be 3 seconds before the end of a month. The test +// verifies that the precomputedQueryWorker runs and writes precomputed queries +// with the proper start and end times when the end of the month is triggered +func TestActivityLog_HandleEndOfMonth(t *testing.T) { + // 3 seconds until a new month + now := time.Date(2021, 1, 31, 23, 59, 57, 0, time.UTC) + core, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ActivityLogConfig: ActivityLogCoreConfig{Clock: newMockTimeNowClock(now)}}) + done := make(chan struct{}) + go func() { + defer close(done) + <-core.activityLog.precomputedQueryWritten + }() + core.activityLog.SetEnable(true) + core.activityLog.SetStartTimestamp(now.Unix()) + core.activityLog.AddClientToFragment("id", "ns", now.Unix(), false, "mount") + + // wait for the end of month to be triggered + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for precomputed query") + } + + // verify that a precomputed query was written + exists, err := core.activityLog.queryStore.QueriesAvailable(context.Background()) + require.NoError(t, err) + require.True(t, exists) + + // verify that the timestamp is correct + pq, err := core.activityLog.queryStore.Get(context.Background(), now, now.Add(24*time.Hour)) + require.NoError(t, err) + require.Equal(t, now, pq.StartTime) + require.Equal(t, timeutil.EndOfMonth(now), pq.EndTime) +} diff --git a/vault/activity_log_testing_util.go b/vault/activity_log_testing_util.go index 2561fc98b7..25e0c900c1 100644 --- a/vault/activity_log_testing_util.go +++ b/vault/activity_log_testing_util.go @@ -8,10 +8,8 @@ import ( "fmt" "math/rand" "testing" - "time" "github.com/hashicorp/vault/helper/constants" - "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/activity" ) @@ -32,7 +30,7 @@ func (c *Core) InjectActivityLogDataThisMonth(t *testing.T) map[string]*activity ClientID: fmt.Sprintf("testclientid-%d", i), NamespaceID: "root", MountAccessor: fmt.Sprintf("testmountaccessor-%d", i), - Timestamp: time.Now().Unix(), + Timestamp: c.activityLog.clock.Now().Unix(), NonEntity: i%2 == 0, } c.activityLog.partialMonthClientTracker[er.ClientID] = er @@ -45,7 +43,7 @@ func (c *Core) InjectActivityLogDataThisMonth(t *testing.T) map[string]*activity ClientID: fmt.Sprintf("ns-%d-testclientid-%d", j, i), NamespaceID: fmt.Sprintf("ns-%d", j), MountAccessor: fmt.Sprintf("ns-%d-testmountaccessor-%d", j, i), - Timestamp: time.Now().Unix(), + Timestamp: c.activityLog.clock.Now().Unix(), NonEntity: i%2 == 0, } c.activityLog.partialMonthClientTracker[er.ClientID] = er diff --git a/vault/activity_log_util_common.go b/vault/activity_log_util_common.go index 3ae8915534..10a3735e6f 100644 --- a/vault/activity_log_util_common.go +++ b/vault/activity_log_util_common.go @@ -75,7 +75,7 @@ func (a *ActivityLog) StoreHyperlogLog(ctx context.Context, startTime time.Time, } func (a *ActivityLog) computeCurrentMonthForBillingPeriodInternal(ctx context.Context, byMonth map[int64]*processMonth, hllGetFunc HLLGetter, startTime time.Time, endTime time.Time) (*activity.MonthRecord, error) { - if timeutil.IsCurrentMonth(startTime, time.Now().UTC()) { + if timeutil.IsCurrentMonth(startTime, a.clock.Now().UTC()) { monthlyComputation := a.transformMonthBreakdowns(byMonth) if len(monthlyComputation) > 1 { a.logger.Warn("monthly in-memory activitylog computation returned multiple months of data", "months returned", len(byMonth))