From 4e3b91d91f379b6368e778849c044fadfa7e67e5 Mon Sep 17 00:00:00 2001 From: miagilepner Date: Mon, 4 Sep 2023 15:48:09 +0200 Subject: [PATCH] [VAULT-17827] Rollback manager worker pool (#22567) * workerpool implementation * rollback tests * website documentation * add changelog * fix failing test --- changelog/22567.txt | 3 + go.mod | 2 +- vault/core.go | 7 + vault/rollback.go | 84 ++++-- vault/rollback_test.go | 249 ++++++++++++++++++ vault/testing.go | 4 + .../docs/internals/telemetry/metrics/all.mdx | 6 + .../telemetry/metrics/core-system.mdx | 6 + .../vault/rollback/inflight.mdx | 5 + .../vault/rollback/queued.mdx | 5 + .../vault/rollback/waiting.mdx | 5 + 11 files changed, 357 insertions(+), 19 deletions(-) create mode 100644 changelog/22567.txt create mode 100644 website/content/partials/telemetry-metrics/vault/rollback/inflight.mdx create mode 100644 website/content/partials/telemetry-metrics/vault/rollback/queued.mdx create mode 100644 website/content/partials/telemetry-metrics/vault/rollback/waiting.mdx diff --git a/changelog/22567.txt b/changelog/22567.txt new file mode 100644 index 0000000000..d9e5570139 --- /dev/null +++ b/changelog/22567.txt @@ -0,0 +1,3 @@ +```release-note:improvement +core: Use a worker pool for the rollback manager. Add new metrics for the rollback manager to track the queued tasks. +``` \ No newline at end of file diff --git a/go.mod b/go.mod index e212640edd..8084f365fe 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/fatih/color v1.15.0 github.com/fatih/structs v1.1.0 github.com/favadi/protoc-go-inject-tag v1.4.0 + github.com/gammazero/workerpool v1.1.3 github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 github.com/go-errors/errors v1.4.2 github.com/go-git/go-git/v5 v5.7.0 @@ -338,7 +339,6 @@ require ( github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gammazero/deque v0.2.1 // indirect - github.com/gammazero/workerpool v1.1.3 // indirect github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/go-billy/v5 v5.4.1 // indirect diff --git a/vault/core.go b/vault/core.go index 69c05fcbcb..ba5fc8aa7f 100644 --- a/vault/core.go +++ b/vault/core.go @@ -680,6 +680,7 @@ type Core struct { // heartbeating with the active node. Default to the current SDK version. effectiveSDKVersion string + numRollbackWorkers int rollbackPeriod time.Duration rollbackMountPathMetrics bool @@ -866,6 +867,8 @@ type CoreConfig struct { // AdministrativeNamespacePath is used to configure the administrative namespace, which has access to some sys endpoints that are // only accessible in the root namespace, currently sys/audit-hash and sys/monitor. AdministrativeNamespacePath string + + NumRollbackWorkers int } // SubloggerHook implements the SubloggerAdder interface. This implementation @@ -954,6 +957,9 @@ func CreateCore(conf *CoreConfig) (*Core, error) { conf.NumExpirationWorkers = numExpirationWorkersDefault } + if conf.NumRollbackWorkers == 0 { + conf.NumRollbackWorkers = RollbackDefaultNumWorkers + } // Use imported logging deadlock if requested var stateLock locking.RWMutex if strings.Contains(conf.DetectDeadlocks, "statelock") { @@ -1038,6 +1044,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) { pendingRemovalMountsAllowed: conf.PendingRemovalMountsAllowed, expirationRevokeRetryBase: conf.ExpirationRevokeRetryBase, rollbackMountPathMetrics: conf.MetricSink.TelemetryConsts.RollbackMetricsIncludeMountPoint, + numRollbackWorkers: conf.NumRollbackWorkers, impreciseLeaseRoleTracking: conf.ImpreciseLeaseRoleTracking, } diff --git a/vault/rollback.go b/vault/rollback.go index 3a9f92a424..c40a98937b 100644 --- a/vault/rollback.go +++ b/vault/rollback.go @@ -6,16 +6,25 @@ package vault import ( "context" "errors" + "fmt" + "os" + "strconv" "strings" "sync" "time" metrics "github.com/armon/go-metrics" + "github.com/gammazero/workerpool" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/logical" ) +const ( + RollbackDefaultNumWorkers = 256 + RollbackWorkersEnvVar = "VAULT_ROLLBACK_WORKERS" +) + // RollbackManager is responsible for performing rollbacks of partial // secrets within logical backends. // @@ -51,8 +60,8 @@ type RollbackManager struct { stopTicker chan struct{} tickerIsStopped bool quitContext context.Context - - core *Core + runner *workerpool.WorkerPool + core *Core // This channel is used for testing rollbacksDoneCh chan struct{} } @@ -63,6 +72,9 @@ type rollbackState struct { sync.WaitGroup cancelLockGrabCtx context.Context cancelLockGrabCtxCancel context.CancelFunc + // scheduled is the time that this job was created and submitted to the + // rollbackRunner + scheduled time.Time } // NewRollbackManager is used to create a new rollback manager @@ -81,9 +93,26 @@ func NewRollbackManager(ctx context.Context, logger log.Logger, backendsFunc fun rollbackMetricsMountName: core.rollbackMountPathMetrics, rollbacksDoneCh: make(chan struct{}), } + numWorkers := r.numRollbackWorkers() + r.logger.Info(fmt.Sprintf("Starting the rollback manager with %d workers", numWorkers)) + r.runner = workerpool.New(numWorkers) return r } +func (m *RollbackManager) numRollbackWorkers() int { + numWorkers := m.core.numRollbackWorkers + envOverride := os.Getenv(RollbackWorkersEnvVar) + if envOverride != "" { + envVarWorkers, err := strconv.Atoi(envOverride) + if err != nil || envVarWorkers < 1 { + m.logger.Warn(fmt.Sprintf("%s must be a positive integer, but was %s", RollbackWorkersEnvVar, envOverride)) + } else { + numWorkers = envVarWorkers + } + } + return numWorkers +} + // Start starts the rollback manager func (m *RollbackManager) Start() { go m.run() @@ -99,7 +128,7 @@ func (m *RollbackManager) Stop() { close(m.shutdownCh) <-m.doneCh } - m.inflightAll.Wait() + m.runner.StopWait() } // StopTicker stops the automatic Rollback manager's ticker, causing us @@ -168,6 +197,8 @@ func (m *RollbackManager) triggerRollbacks() { func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState { m.inflightLock.Lock() defer m.inflightLock.Unlock() + defer metrics.SetGauge([]string{"rollback", "queued"}, float32(m.runner.WaitingQueueSize())) + defer metrics.SetGauge([]string{"rollback", "inflight"}, float32(len(m.inflight))) rsInflight, ok := m.inflight[fullPath] if ok { return rsInflight @@ -183,31 +214,48 @@ func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath st m.inflight[fullPath] = rs rs.Add(1) m.inflightAll.Add(1) - go func() { - m.attemptRollback(ctx, fullPath, rs, grabStatelock) - select { - case m.rollbacksDoneCh <- struct{}{}: - default: - } - }() + rs.scheduled = time.Now() + select { + case <-m.doneCh: + // if we've already shut down, then don't submit the task to avoid a panic + // we should still call finishRollback for the rollback state in order to remove + // it from the map and decrement the waitgroup. + + // we already have the inflight lock, so we can't grab it here + m.finishRollback(rs, errors.New("rollback manager is stopped"), fullPath, false) + default: + m.runner.Submit(func() { + m.attemptRollback(ctx, fullPath, rs, grabStatelock) + select { + case m.rollbacksDoneCh <- struct{}{}: + default: + } + }) + + } return rs } +func (m *RollbackManager) finishRollback(rs *rollbackState, err error, fullPath string, grabInflightLock bool) { + rs.lastError = err + rs.Done() + m.inflightAll.Done() + if grabInflightLock { + m.inflightLock.Lock() + defer m.inflightLock.Unlock() + } + delete(m.inflight, fullPath) +} + // attemptRollback invokes a RollbackOperation for the given path func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string, rs *rollbackState, grabStatelock bool) (err error) { + metrics.MeasureSince([]string{"rollback", "waiting"}, rs.scheduled) metricName := []string{"rollback", "attempt"} if m.rollbackMetricsMountName { metricName = append(metricName, strings.ReplaceAll(fullPath, "/", "-")) } defer metrics.MeasureSince(metricName, time.Now()) - defer func() { - rs.lastError = err - rs.Done() - m.inflightAll.Done() - m.inflightLock.Lock() - delete(m.inflight, fullPath) - m.inflightLock.Unlock() - }() + defer m.finishRollback(rs, err, fullPath, true) ns, err := namespace.FromContext(ctx) if err != nil { diff --git a/vault/rollback_test.go b/vault/rollback_test.go index 0eaba1bb7d..a99060df69 100644 --- a/vault/rollback_test.go +++ b/vault/rollback_test.go @@ -5,6 +5,7 @@ package vault import ( "context" + "fmt" "strings" "sync" "testing" @@ -16,6 +17,7 @@ import ( "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" "github.com/stretchr/testify/require" ) @@ -81,6 +83,253 @@ func TestRollbackManager(t *testing.T) { } } +// TestRollbackManager_ManyWorkers adds 10 backends that require a rollback +// operation, with 20 workers. The test verifies that the 10 +// work items will run in parallel +func TestRollbackManager_ManyWorkers(t *testing.T) { + core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 20, RollbackPeriod: time.Millisecond * 10}) + view := NewBarrierView(core.barrier, "logical/") + + ran := make(chan string) + release := make(chan struct{}) + core, _, _ = testCoreUnsealed(t, core) + + // create 10 backends + // when a rollback happens, each backend will try to write to an unbuffered + // channel, then wait to be released + for i := 0; i < 10; i++ { + b := &NoopBackend{} + b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) { + if request.Operation == logical.RollbackOperation { + ran <- request.Path + <-release + } + return nil, nil + } + b.Root = []string{fmt.Sprintf("foo/%d", i)} + meUUID, err := uuid.GenerateUUID() + require.NoError(t, err) + mountEntry := &MountEntry{ + Table: mountTableType, + UUID: meUUID, + Accessor: fmt.Sprintf("accessor-%d", i), + NamespaceID: namespace.RootNamespaceID, + namespace: namespace.RootNamespace, + Path: fmt.Sprintf("logical/foo/%d", i), + } + func() { + core.mountsLock.Lock() + defer core.mountsLock.Unlock() + newTable := core.mounts.shallowClone() + newTable.Entries = append(newTable.Entries, mountEntry) + core.mounts = newTable + err = core.router.Mount(b, "logical", mountEntry, view) + require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local)) + }() + } + + timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + got := make(map[string]bool) + hasMore := true + for hasMore { + // we're not bounding the number of workers, so we would expect to see + // all 10 writes to the channel from each of the backends. Once that + // happens, close the release channel so that the functions can exit + select { + case <-timeout.Done(): + require.Fail(t, "test timed out") + case i := <-ran: + got[i] = true + if len(got) == 10 { + close(release) + hasMore = false + } + } + } + done := make(chan struct{}) + + // start a goroutine to consume the remaining items from the queued work + go func() { + for { + select { + case <-ran: + case <-done: + return + } + } + }() + // stop the rollback worker, which will wait for all inflight rollbacks to + // complete + core.rollback.Stop() + close(done) +} + +// TestRollbackManager_WorkerPool adds 10 backends that require a rollback +// operation, with 5 workers. The test verifies that the 5 work items can occur +// concurrently, and that the remainder of the work is queued and run when +// workers are available +func TestRollbackManager_WorkerPool(t *testing.T) { + core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 5, RollbackPeriod: time.Millisecond * 10}) + view := NewBarrierView(core.barrier, "logical/") + + ran := make(chan string) + release := make(chan struct{}) + core, _, _ = testCoreUnsealed(t, core) + + // create 10 backends + // when a rollback happens, each backend will try to write to an unbuffered + // channel, then wait to be released + for i := 0; i < 10; i++ { + b := &NoopBackend{} + b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) { + if request.Operation == logical.RollbackOperation { + ran <- request.Path + <-release + } + return nil, nil + } + b.Root = []string{fmt.Sprintf("foo/%d", i)} + meUUID, err := uuid.GenerateUUID() + require.NoError(t, err) + mountEntry := &MountEntry{ + Table: mountTableType, + UUID: meUUID, + Accessor: fmt.Sprintf("accessor-%d", i), + NamespaceID: namespace.RootNamespaceID, + namespace: namespace.RootNamespace, + Path: fmt.Sprintf("logical/foo/%d", i), + } + func() { + core.mountsLock.Lock() + defer core.mountsLock.Unlock() + newTable := core.mounts.shallowClone() + newTable.Entries = append(newTable.Entries, mountEntry) + core.mounts = newTable + err = core.router.Mount(b, "logical", mountEntry, view) + require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local)) + }() + } + + core.mountsLock.RLock() + numMounts := len(core.mounts.Entries) + core.mountsLock.RUnlock() + + timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + got := make(map[string]bool) + gotLock := sync.RWMutex{} + hasMore := true + for hasMore { + // we're using 5 workers, so we would expect to see 5 writes to the + // channel. Once that happens, close the release channel so that the + // functions can exit and new rollback operations can run + select { + case <-timeout.Done(): + require.Fail(t, "test timed out") + case i := <-ran: + gotLock.Lock() + got[i] = true + numGot := len(got) + gotLock.Unlock() + if numGot == 5 { + close(release) + hasMore = false + } + } + } + done := make(chan struct{}) + + // start a goroutine to consume the remaining items from the queued work + go func() { + for { + select { + case i := <-ran: + gotLock.Lock() + got[i] = true + gotLock.Unlock() + case <-done: + return + } + } + }() + + // wait for every mount to be rolled back at least once + numMountsDone := 0 + for numMountsDone < numMounts { + <-core.rollback.rollbacksDoneCh + numMountsDone++ + } + + // stop the rollback worker, which will wait for all inflight rollbacks to + // complete + core.rollback.Stop() + close(done) + + // we should have received at least 1 rollback for every backend + gotLock.RLock() + defer gotLock.RUnlock() + require.GreaterOrEqual(t, len(got), 10) +} + +// TestRollbackManager_numRollbackWorkers verifies that the number of rollback +// workers is parsed from the configuration, but can be overridden by an +// environment variable. This test cannot be run in parallel because of the +// environment variable +func TestRollbackManager_numRollbackWorkers(t *testing.T) { + testCases := []struct { + name string + configWorkers int + setEnvVar bool + envVar string + wantWorkers int + }{ + { + name: "default in config", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: RollbackDefaultNumWorkers, + }, + { + name: "invalid envvar", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: RollbackDefaultNumWorkers, + setEnvVar: true, + envVar: "invalid", + }, + { + name: "envvar overrides config", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: 20, + setEnvVar: true, + envVar: "20", + }, + { + name: "envvar negative", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: RollbackDefaultNumWorkers, + setEnvVar: true, + envVar: "-1", + }, + { + name: "envvar zero", + configWorkers: RollbackDefaultNumWorkers, + wantWorkers: RollbackDefaultNumWorkers, + setEnvVar: true, + envVar: "0", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.setEnvVar { + t.Setenv(RollbackWorkersEnvVar, tc.envVar) + } + core := &Core{numRollbackWorkers: tc.configWorkers} + r := &RollbackManager{logger: logger.Named("test"), core: core} + require.Equal(t, tc.wantWorkers, r.numRollbackWorkers()) + }) + } +} + func TestRollbackManager_Join(t *testing.T) { m, backend := mockRollback(t) if len(backend.Paths) > 0 { diff --git a/vault/testing.go b/vault/testing.go index cb4df9597b..d9dd631022 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -249,6 +249,9 @@ func TestCoreWithSealAndUINoCleanup(t testing.T, opts *CoreConfig) *Core { if opts.RollbackPeriod != time.Duration(0) { conf.RollbackPeriod = opts.RollbackPeriod } + if opts.NumRollbackWorkers != 0 { + conf.NumRollbackWorkers = opts.NumRollbackWorkers + } conf.ActivityLogConfig = opts.ActivityLogConfig testApplyEntBaseConfig(conf, opts) @@ -305,6 +308,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo CredentialBackends: credentialBackends, DisableMlock: true, Logger: logger, + NumRollbackWorkers: 10, BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(), } diff --git a/website/content/docs/internals/telemetry/metrics/all.mdx b/website/content/docs/internals/telemetry/metrics/all.mdx index 26a454e0cd..7fd38ea1ac 100644 --- a/website/content/docs/internals/telemetry/metrics/all.mdx +++ b/website/content/docs/internals/telemetry/metrics/all.mdx @@ -618,6 +618,12 @@ alphabetic order by name. @include 'telemetry-metrics/vault/rollback/attempt.mdx' +@include 'telemetry-metrics/vault/rollback/inflight.mdx' + +@include 'telemetry-metrics/vault/rollback/queued.mdx' + +@include 'telemetry-metrics/vault/rollback/waiting.mdx' + @include 'telemetry-metrics/vault/route/create/mountpoint.mdx' @include 'telemetry-metrics/vault/route/delete/mountpoint.mdx' diff --git a/website/content/docs/internals/telemetry/metrics/core-system.mdx b/website/content/docs/internals/telemetry/metrics/core-system.mdx index 15f4b51306..90476dd3b7 100644 --- a/website/content/docs/internals/telemetry/metrics/core-system.mdx +++ b/website/content/docs/internals/telemetry/metrics/core-system.mdx @@ -114,6 +114,12 @@ Vault instance. @include 'telemetry-metrics/vault/rollback/attempt.mdx' +@include 'telemetry-metrics/vault/rollback/inflight.mdx' + +@include 'telemetry-metrics/vault/rollback/queued.mdx' + +@include 'telemetry-metrics/vault/rollback/waiting.mdx' + ## Route metrics @include 'telemetry-metrics/route-intro.mdx' diff --git a/website/content/partials/telemetry-metrics/vault/rollback/inflight.mdx b/website/content/partials/telemetry-metrics/vault/rollback/inflight.mdx new file mode 100644 index 0000000000..832cb30888 --- /dev/null +++ b/website/content/partials/telemetry-metrics/vault/rollback/inflight.mdx @@ -0,0 +1,5 @@ +### vault.rollback.inflight ((#vault-rollback-inflight)) + +Metric type | Value | Description +----------- | ------ | ----------- +gauge | number | Number of rollback operations inflight diff --git a/website/content/partials/telemetry-metrics/vault/rollback/queued.mdx b/website/content/partials/telemetry-metrics/vault/rollback/queued.mdx new file mode 100644 index 0000000000..e8a7d099f4 --- /dev/null +++ b/website/content/partials/telemetry-metrics/vault/rollback/queued.mdx @@ -0,0 +1,5 @@ +### vault.rollback.queued ((#vault-rollback-queued)) + +Metric type | Value | Description +----------- | ------ | ----------- +guage | number | The number of rollback operations waiting to be started diff --git a/website/content/partials/telemetry-metrics/vault/rollback/waiting.mdx b/website/content/partials/telemetry-metrics/vault/rollback/waiting.mdx new file mode 100644 index 0000000000..2fb0e2eab0 --- /dev/null +++ b/website/content/partials/telemetry-metrics/vault/rollback/waiting.mdx @@ -0,0 +1,5 @@ +### vault.rollback.waiting ((#vault-rollback-waiting)) + +Metric type | Value | Description +----------- | ----- | ----------- +summary | ms | Time between queueing a rollback operation and the operation starting