VAULT-24582: Refactor precomputed query worker to support ACME regeneration (#26364)

* refactor and add tests

* fix write options

* forgot that this endpoint was testonly

* refactor the refactor

* fix test failures and rename variable
This commit is contained in:
miagilepner
2024-04-29 14:28:37 +02:00
committed by GitHub
parent 6ccb2bd64a
commit 46cd5bbf32
5 changed files with 189 additions and 47 deletions

View File

@@ -267,7 +267,10 @@ func (d *ActivityLogDataGenerator) Segment(opts ...SegmentOption) *ActivityLogDa
}
// ToJSON returns the JSON representation of the data
func (d *ActivityLogDataGenerator) ToJSON() ([]byte, error) {
func (d *ActivityLogDataGenerator) ToJSON(writeOptions ...generation.WriteOptions) ([]byte, error) {
if len(writeOptions) > 0 {
d.data.Write = writeOptions
}
return protojson.Marshal(d.data)
}

View File

@@ -511,7 +511,7 @@ func parseSegmentNumberFromPath(path string) (int, bool) {
// availableLogs returns the start_time(s) (in UTC) associated with months for which logs exist,
// sorted last to first
func (a *ActivityLog) availableLogs(ctx context.Context) ([]time.Time, error) {
func (a *ActivityLog) availableLogs(ctx context.Context, upTo time.Time) ([]time.Time, error) {
paths := make([]string, 0)
for _, basePath := range []string{activityEntityBasePath, activityTokenBasePath} {
p, err := a.view.List(ctx, basePath)
@@ -526,14 +526,17 @@ func (a *ActivityLog) availableLogs(ctx context.Context) ([]time.Time, error) {
out := make([]time.Time, 0)
for _, path := range paths {
// generate a set of unique start times
time, err := timeutil.ParseTimeFromPath(path)
segmentTime, err := timeutil.ParseTimeFromPath(path)
if err != nil {
return nil, err
}
if segmentTime.After(upTo) {
continue
}
if _, present := pathSet[time]; !present {
pathSet[time] = struct{}{}
out = append(out, time)
if _, present := pathSet[segmentTime]; !present {
pathSet[segmentTime] = struct{}{}
out = append(out, segmentTime)
}
}
@@ -542,15 +545,15 @@ func (a *ActivityLog) availableLogs(ctx context.Context) ([]time.Time, error) {
return out[i].After(out[j])
})
a.logger.Trace("scanned existing logs", "out", out)
a.logger.Trace("scanned existing logs", "out", out, "up to", upTo)
return out, nil
}
// getMostRecentActivityLogSegment gets the times (in UTC) associated with the most recent
// contiguous set of activity logs, sorted in decreasing order (latest to earliest)
func (a *ActivityLog) getMostRecentActivityLogSegment(ctx context.Context) ([]time.Time, error) {
logTimes, err := a.availableLogs(ctx)
func (a *ActivityLog) getMostRecentActivityLogSegment(ctx context.Context, now time.Time) ([]time.Time, error) {
logTimes, err := a.availableLogs(ctx, now)
if err != nil {
return nil, err
}
@@ -892,7 +895,7 @@ func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGro
a.fragmentLock.Lock()
defer a.fragmentLock.Unlock()
decreasingLogTimes, err := a.getMostRecentActivityLogSegment(ctx)
decreasingLogTimes, err := a.getMostRecentActivityLogSegment(ctx, now)
if err != nil {
return err
}
@@ -1156,7 +1159,7 @@ func (c *Core) setupActivityLogLocked(ctx context.Context, wg *sync.WaitGroup) e
// Check for any intent log, in the background
manager.computationWorkerDone = make(chan struct{})
go func() {
manager.precomputedQueryWorker(ctx)
manager.precomputedQueryWorker(ctx, nil)
close(manager.computationWorkerDone)
}()
@@ -1174,7 +1177,7 @@ func (c *Core) setupActivityLogLocked(ctx context.Context, wg *sync.WaitGroup) e
func (a *ActivityLog) createRegenerationIntentLog(ctx context.Context, now time.Time) (*ActivityIntentLog, error) {
intentLog := &ActivityIntentLog{}
segments, err := a.availableLogs(ctx)
segments, err := a.availableLogs(ctx, now)
if err != nil {
return nil, fmt.Errorf("error fetching available logs: %w", err)
}
@@ -1439,7 +1442,7 @@ func (a *ActivityLog) HandleEndOfMonth(ctx context.Context, currentTime time.Tim
a.fragmentLock.Unlock()
// Work on precomputed queries in background
go a.precomputedQueryWorker(ctx)
go a.precomputedQueryWorker(ctx, nil)
return nil
}
@@ -2431,7 +2434,9 @@ func (a *ActivityLog) reportPrecomputedQueryMetrics(ctx context.Context, segment
// goroutine to process the request in the intent log, creating precomputed queries.
// We expect the return value won't be checked, so log errors as they occur
// (but for unit testing having the error return should help.)
func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error {
// If the intent log that's passed into the function is non-nil, we use that
// intent log. Otherwise, we read the intent log from storage
func (a *ActivityLog) precomputedQueryWorker(ctx context.Context, intent *ActivityIntentLog) (err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -2452,21 +2457,39 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error {
}(a.doneCh)
a.l.RUnlock()
// Load the intent log
rawIntentLog, err := a.view.Get(ctx, activityIntentLogKey)
if err != nil {
a.logger.Warn("could not load intent log", "error", err)
return err
strictEnforcement := intent == nil
shouldCleanupIntentLog := false
if intent == nil {
// Load the intent log
rawIntentLog, err := a.view.Get(ctx, activityIntentLogKey)
if err != nil {
a.logger.Warn("could not load intent log", "error", err)
return err
}
if rawIntentLog == nil {
a.logger.Trace("no intent log found")
return err
}
intent = new(ActivityIntentLog)
err = json.Unmarshal(rawIntentLog.Value, intent)
if err != nil {
a.logger.Warn("could not parse intent log", "error", err)
return err
}
shouldCleanupIntentLog = true
}
if rawIntentLog == nil {
a.logger.Trace("no intent log found")
return err
}
var intent ActivityIntentLog
err = json.Unmarshal(rawIntentLog.Value, &intent)
if err != nil {
a.logger.Warn("could not parse intent log", "error", err)
return err
cleanupIntentLog := func() {
if !shouldCleanupIntentLog {
return
}
// delete the intent log
// this should happen if the precomputed queries were generated
// successfully (i.e. err is nil) or if there's no data for the previous
// month.
// It should not happen in the general error case
a.view.Delete(ctx, activityIntentLogKey)
}
// currentMonth could change (from another month end) after we release the lock.
@@ -2479,28 +2502,29 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error {
// would work but this will be easier to control in tests.
retentionWindow := timeutil.MonthsPreviousTo(a.retentionMonths, time.Unix(intent.NextMonth, 0).UTC())
a.l.RUnlock()
if currentMonth != 0 && intent.NextMonth != currentMonth {
if strictEnforcement && currentMonth != 0 && intent.NextMonth != currentMonth {
a.logger.Warn("intent log does not match current segment",
"intent", intent.NextMonth, "current", currentMonth)
return errors.New("intent log is too far in the past")
}
lastMonth := intent.PreviousMonth
a.logger.Info("computing queries", "month", time.Unix(lastMonth, 0).UTC())
lastMonthTime := time.Unix(lastMonth, 0)
a.logger.Info("computing queries", "month", lastMonthTime.UTC())
times, err := a.availableLogs(ctx)
times, err := a.availableLogs(ctx, lastMonthTime)
if err != nil {
a.logger.Warn("could not list available logs", "error", err)
return err
}
if len(times) == 0 {
a.logger.Warn("no months in storage")
a.view.Delete(ctx, activityIntentLogKey)
cleanupIntentLog()
return errors.New("previous month not found")
}
if times[0].Unix() != lastMonth {
a.logger.Warn("last month not in storage", "latest", times[0].Unix())
a.view.Delete(ctx, activityIntentLogKey)
cleanupIntentLog()
return errors.New("previous month not found")
}
@@ -2537,9 +2561,7 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context) error {
return err
}
}
// delete the intent log
a.view.Delete(ctx, activityIntentLogKey)
cleanupIntentLog()
a.logger.Info("finished computing queries", "month", endTime)
@@ -2579,7 +2601,7 @@ func (a *ActivityLog) retentionWorker(ctx context.Context, currentTime time.Time
// everything >= the threshold is OK
retentionThreshold := timeutil.MonthsPreviousTo(retentionMonths, currentTime)
available, err := a.availableLogs(ctx)
available, err := a.availableLogs(ctx, retentionThreshold)
if err != nil {
a.logger.Warn("could not list segments", "error", err)
return err
@@ -2892,7 +2914,7 @@ func (a *ActivityLog) writeExport(ctx context.Context, rw http.ResponseWriter, f
// Find the months with activity log data that are between the start and end
// months. We want to walk this in cronological order so the oldest instance of a
// client usage is recorded, not the most recent.
times, err := a.availableLogs(ctx)
times, err := a.availableLogs(ctx, endTime)
if err != nil {
a.logger.Warn("failed to list available log segments", "error", err)
return fmt.Errorf("failed to list available log segments: %w", err)

View File

@@ -624,7 +624,7 @@ func TestActivityLog_availableLogsEmptyDirectory(t *testing.T) {
// verify that directory is empty, and nothing goes wrong
core, _, _ := TestCoreUnsealed(t)
a := core.activityLog
times, err := a.availableLogs(context.Background())
times, err := a.availableLogs(context.Background(), time.Now())
if err != nil {
t.Fatalf("error getting start_time(s) for empty activity log")
}
@@ -647,7 +647,7 @@ func TestActivityLog_availableLogs(t *testing.T) {
}
// verify above files are there, and dates in correct order
times, err := a.availableLogs(context.Background())
times, err := a.availableLogs(context.Background(), time.Now())
if err != nil {
t.Fatalf("error getting start_time(s) for activity log")
}
@@ -2725,7 +2725,7 @@ func TestActivityLog_CalculatePrecomputedQueriesWithMixedTWEs(t *testing.T) {
// Pretend we've successfully rolled over to the following month
a.SetStartTimestamp(tc.NextMonth)
err = a.precomputedQueryWorker(ctx)
err = a.precomputedQueryWorker(ctx, nil)
if err != nil {
t.Fatal(err)
}
@@ -3106,7 +3106,7 @@ func TestActivityLog_Precompute(t *testing.T) {
// Pretend we've successfully rolled over to the following month
a.SetStartTimestamp(tc.NextMonth)
err = a.precomputedQueryWorker(ctx)
err = a.precomputedQueryWorker(ctx, nil)
if err != nil {
t.Fatal(err)
}
@@ -3368,7 +3368,7 @@ func TestActivityLog_Precompute_SkipMonth(t *testing.T) {
// Pretend we've successfully rolled over to the following month
a.SetStartTimestamp(tc.NextMonth)
err = a.precomputedQueryWorker(ctx)
err = a.precomputedQueryWorker(ctx, nil)
if err != nil {
t.Fatal(err)
}
@@ -3634,7 +3634,7 @@ func TestActivityLog_PrecomputeNonEntityTokensWithID(t *testing.T) {
// Pretend we've successfully rolled over to the following month
a.SetStartTimestamp(tc.NextMonth)
err = a.precomputedQueryWorker(ctx)
err = a.precomputedQueryWorker(ctx, nil)
if err != nil {
t.Fatal(err)
}
@@ -3761,7 +3761,7 @@ func TestActivityLog_PrecomputeCancel(t *testing.T) {
// This will block if the shutdown didn't work.
go func() {
// We expect this to error because of BlockingInmemStorage
_ = a.precomputedQueryWorker(namespace.RootContext(nil))
_ = a.precomputedQueryWorker(namespace.RootContext(nil), nil)
close(done)
}()

View File

@@ -0,0 +1,117 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build testonly
package vault
import (
"context"
"testing"
"time"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/sdk/helper/clientcountutil"
"github.com/hashicorp/vault/sdk/helper/clientcountutil/generation"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestActivityLog_doPrecomputedQueryCreation creates segments for the last 4
// months and then calls doPrecomputedQueryCreation, in order of oldest to most
// recent month. The test verifies that the count of clients in the generated
// precomputed query is equal to the number of deduplicated clients.
func TestActivityLog_doPrecomputedQueryCreation(t *testing.T) {
core, _, token := TestCoreUnsealed(t)
a := core.activityLog
a.SetEnable(true)
j, err := clientcountutil.NewActivityLogData(nil).
// 8 new clients
// across two segments
NewPreviousMonthData(4).
Segment().NewClientsSeen(5).
Segment().NewClientsSeen(3).
// 2 repeated clients
// 10 new clients
// across 3 segments
NewPreviousMonthData(3).
Segment().RepeatedClientsSeen(2).
NewClientsSeen(3).
Segment().NewClientsSeen(2).
Segment().NewClientsSeen(5).
// 7 new clients
// single segment
NewPreviousMonthData(2).
NewClientsSeen(7).
// 6 repeated clients
// 5 new clients
// across 2 segments
NewPreviousMonthData(1).
Segment().NewClientsSeen(5).
Segment().RepeatedClientsSeen(6).
ToJSON(generation.WriteOptions_WRITE_ENTITIES)
require.NoError(t, err)
r := logical.TestRequest(t, logical.UpdateOperation, "sys/internal/counters/activity/write")
r.Data["input"] = string(j)
r.ClientToken = token
_, err = core.HandleRequest(namespace.RootContext(context.Background()), r)
require.NoError(t, err)
now := time.Now().UTC()
times := map[int]time.Time{}
for i := 1; i < 5; i++ {
times[i] = timeutil.StartOfMonth(timeutil.MonthsPreviousTo(i, now))
}
testCases := []struct {
name string
generateUpToMonth int
strictEnforcement bool
wantClients int
}{
{
name: "only 4 months ago",
generateUpToMonth: 4,
wantClients: 8, // 8 clients from month 4
},
{
name: "3 months ago",
generateUpToMonth: 3,
// 8 clients (month 4) + 10 new clients (month 3)
wantClients: 18,
},
{
name: "2 months ago",
generateUpToMonth: 2,
// 8 clients (month 4) + 10 new clients (month 3) + 7 new clients
// (month 2)
wantClients: 25,
},
{
name: "1 month ago",
generateUpToMonth: 1,
// 8 clients (month 4) + 10 new clients (month 3) + 7 new clients
// (month 2) + 5 new clients (month 1)
wantClients: 30,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
generateUpTo := times[tc.generateUpToMonth]
nextMonth := timeutil.StartOfNextMonth(generateUpTo)
err = a.precomputedQueryWorker(context.Background(), &ActivityIntentLog{PreviousMonth: generateUpTo.Unix(), NextMonth: nextMonth.Unix()})
require.NoError(t, err)
// get precomputed queries spanning the whole time period
pq, err := a.queryStore.Get(context.Background(), times[4], now)
require.NoError(t, err)
require.Equal(t, tc.wantClients, int(pq.Namespaces[0].Entities))
})
}
}

View File

@@ -526,7 +526,7 @@ func Test_handleActivityWriteData(t *testing.T) {
paths := resp.Data["paths"].([]string)
require.Len(t, paths, 9)
times, err := core.activityLog.availableLogs(context.Background())
times, err := core.activityLog.availableLogs(context.Background(), time.Now())
require.NoError(t, err)
require.Len(t, times, 4)
@@ -645,7 +645,7 @@ func Test_handleActivityWriteData(t *testing.T) {
require.Equal(t, timeutil.StartOfMonth(now), next.UTC())
require.Equal(t, timeutil.StartOfMonth(timeutil.MonthsPreviousTo(3, now)), prev.UTC())
times, err := core.activityLog.availableLogs(context.Background())
times, err := core.activityLog.availableLogs(context.Background(), time.Now())
require.NoError(t, err)
require.Len(t, times, 4)
})