mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	[VAULT-17827] Rollback manager worker pool (#22567)
* workerpool implementation * rollback tests * website documentation * add changelog * fix failing test
This commit is contained in:
		
							
								
								
									
										3
									
								
								changelog/22567.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/22567.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | ```release-note:improvement | ||||||
|  | core: Use a worker pool for the rollback manager. Add new metrics for the rollback manager to track the queued tasks.  | ||||||
|  | ``` | ||||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							| @@ -56,6 +56,7 @@ require ( | |||||||
| 	github.com/fatih/color v1.15.0 | 	github.com/fatih/color v1.15.0 | ||||||
| 	github.com/fatih/structs v1.1.0 | 	github.com/fatih/structs v1.1.0 | ||||||
| 	github.com/favadi/protoc-go-inject-tag v1.4.0 | 	github.com/favadi/protoc-go-inject-tag v1.4.0 | ||||||
|  | 	github.com/gammazero/workerpool v1.1.3 | ||||||
| 	github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 | 	github.com/ghodss/yaml v1.0.1-0.20190212211648-25d852aebe32 | ||||||
| 	github.com/go-errors/errors v1.4.2 | 	github.com/go-errors/errors v1.4.2 | ||||||
| 	github.com/go-git/go-git/v5 v5.7.0 | 	github.com/go-git/go-git/v5 v5.7.0 | ||||||
| @@ -338,7 +339,6 @@ require ( | |||||||
| 	github.com/fsnotify/fsnotify v1.6.0 // indirect | 	github.com/fsnotify/fsnotify v1.6.0 // indirect | ||||||
| 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect | ||||||
| 	github.com/gammazero/deque v0.2.1 // indirect | 	github.com/gammazero/deque v0.2.1 // indirect | ||||||
| 	github.com/gammazero/workerpool v1.1.3 // indirect |  | ||||||
| 	github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect | 	github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect | ||||||
| 	github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect | 	github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect | ||||||
| 	github.com/go-git/go-billy/v5 v5.4.1 // indirect | 	github.com/go-git/go-billy/v5 v5.4.1 // indirect | ||||||
|   | |||||||
| @@ -680,6 +680,7 @@ type Core struct { | |||||||
| 	// heartbeating with the active node. Default to the current SDK version. | 	// heartbeating with the active node. Default to the current SDK version. | ||||||
| 	effectiveSDKVersion string | 	effectiveSDKVersion string | ||||||
|  |  | ||||||
|  | 	numRollbackWorkers       int | ||||||
| 	rollbackPeriod           time.Duration | 	rollbackPeriod           time.Duration | ||||||
| 	rollbackMountPathMetrics bool | 	rollbackMountPathMetrics bool | ||||||
|  |  | ||||||
| @@ -866,6 +867,8 @@ type CoreConfig struct { | |||||||
| 	// AdministrativeNamespacePath is used to configure the administrative namespace, which has access to some sys endpoints that are | 	// AdministrativeNamespacePath is used to configure the administrative namespace, which has access to some sys endpoints that are | ||||||
| 	// only accessible in the root namespace, currently sys/audit-hash and sys/monitor. | 	// only accessible in the root namespace, currently sys/audit-hash and sys/monitor. | ||||||
| 	AdministrativeNamespacePath string | 	AdministrativeNamespacePath string | ||||||
|  |  | ||||||
|  | 	NumRollbackWorkers int | ||||||
| } | } | ||||||
|  |  | ||||||
| // SubloggerHook implements the SubloggerAdder interface. This implementation | // SubloggerHook implements the SubloggerAdder interface. This implementation | ||||||
| @@ -954,6 +957,9 @@ func CreateCore(conf *CoreConfig) (*Core, error) { | |||||||
| 		conf.NumExpirationWorkers = numExpirationWorkersDefault | 		conf.NumExpirationWorkers = numExpirationWorkersDefault | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if conf.NumRollbackWorkers == 0 { | ||||||
|  | 		conf.NumRollbackWorkers = RollbackDefaultNumWorkers | ||||||
|  | 	} | ||||||
| 	// Use imported logging deadlock if requested | 	// Use imported logging deadlock if requested | ||||||
| 	var stateLock locking.RWMutex | 	var stateLock locking.RWMutex | ||||||
| 	if strings.Contains(conf.DetectDeadlocks, "statelock") { | 	if strings.Contains(conf.DetectDeadlocks, "statelock") { | ||||||
| @@ -1038,6 +1044,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) { | |||||||
| 		pendingRemovalMountsAllowed:    conf.PendingRemovalMountsAllowed, | 		pendingRemovalMountsAllowed:    conf.PendingRemovalMountsAllowed, | ||||||
| 		expirationRevokeRetryBase:      conf.ExpirationRevokeRetryBase, | 		expirationRevokeRetryBase:      conf.ExpirationRevokeRetryBase, | ||||||
| 		rollbackMountPathMetrics:       conf.MetricSink.TelemetryConsts.RollbackMetricsIncludeMountPoint, | 		rollbackMountPathMetrics:       conf.MetricSink.TelemetryConsts.RollbackMetricsIncludeMountPoint, | ||||||
|  | 		numRollbackWorkers:             conf.NumRollbackWorkers, | ||||||
| 		impreciseLeaseRoleTracking:     conf.ImpreciseLeaseRoleTracking, | 		impreciseLeaseRoleTracking:     conf.ImpreciseLeaseRoleTracking, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,16 +6,25 @@ package vault | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | 	"strconv" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	metrics "github.com/armon/go-metrics" | 	metrics "github.com/armon/go-metrics" | ||||||
|  | 	"github.com/gammazero/workerpool" | ||||||
| 	log "github.com/hashicorp/go-hclog" | 	log "github.com/hashicorp/go-hclog" | ||||||
| 	"github.com/hashicorp/vault/helper/namespace" | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
| 	"github.com/hashicorp/vault/sdk/logical" | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	RollbackDefaultNumWorkers = 256 | ||||||
|  | 	RollbackWorkersEnvVar     = "VAULT_ROLLBACK_WORKERS" | ||||||
|  | ) | ||||||
|  |  | ||||||
| // RollbackManager is responsible for performing rollbacks of partial | // RollbackManager is responsible for performing rollbacks of partial | ||||||
| // secrets within logical backends. | // secrets within logical backends. | ||||||
| // | // | ||||||
| @@ -51,8 +60,8 @@ type RollbackManager struct { | |||||||
| 	stopTicker      chan struct{} | 	stopTicker      chan struct{} | ||||||
| 	tickerIsStopped bool | 	tickerIsStopped bool | ||||||
| 	quitContext     context.Context | 	quitContext     context.Context | ||||||
|  | 	runner          *workerpool.WorkerPool | ||||||
| 	core *Core | 	core            *Core | ||||||
| 	// This channel is used for testing | 	// This channel is used for testing | ||||||
| 	rollbacksDoneCh chan struct{} | 	rollbacksDoneCh chan struct{} | ||||||
| } | } | ||||||
| @@ -63,6 +72,9 @@ type rollbackState struct { | |||||||
| 	sync.WaitGroup | 	sync.WaitGroup | ||||||
| 	cancelLockGrabCtx       context.Context | 	cancelLockGrabCtx       context.Context | ||||||
| 	cancelLockGrabCtxCancel context.CancelFunc | 	cancelLockGrabCtxCancel context.CancelFunc | ||||||
|  | 	// scheduled is the time that this job was created and submitted to the | ||||||
|  | 	// rollback manager's worker pool (the runner field) | ||||||
|  | 	scheduled time.Time | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewRollbackManager is used to create a new rollback manager | // NewRollbackManager is used to create a new rollback manager | ||||||
| @@ -81,9 +93,26 @@ func NewRollbackManager(ctx context.Context, logger log.Logger, backendsFunc fun | |||||||
| 		rollbackMetricsMountName: core.rollbackMountPathMetrics, | 		rollbackMetricsMountName: core.rollbackMountPathMetrics, | ||||||
| 		rollbacksDoneCh:          make(chan struct{}), | 		rollbacksDoneCh:          make(chan struct{}), | ||||||
| 	} | 	} | ||||||
|  | 	numWorkers := r.numRollbackWorkers() | ||||||
|  | 	r.logger.Info(fmt.Sprintf("Starting the rollback manager with %d workers", numWorkers)) | ||||||
|  | 	r.runner = workerpool.New(numWorkers) | ||||||
| 	return r | 	return r | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (m *RollbackManager) numRollbackWorkers() int { | ||||||
|  | 	numWorkers := m.core.numRollbackWorkers | ||||||
|  | 	envOverride := os.Getenv(RollbackWorkersEnvVar) | ||||||
|  | 	if envOverride != "" { | ||||||
|  | 		envVarWorkers, err := strconv.Atoi(envOverride) | ||||||
|  | 		if err != nil || envVarWorkers < 1 { | ||||||
|  | 			m.logger.Warn(fmt.Sprintf("%s must be a positive integer, but was %s", RollbackWorkersEnvVar, envOverride)) | ||||||
|  | 		} else { | ||||||
|  | 			numWorkers = envVarWorkers | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return numWorkers | ||||||
|  | } | ||||||
|  |  | ||||||
| // Start starts the rollback manager | // Start starts the rollback manager | ||||||
| func (m *RollbackManager) Start() { | func (m *RollbackManager) Start() { | ||||||
| 	go m.run() | 	go m.run() | ||||||
| @@ -99,7 +128,7 @@ func (m *RollbackManager) Stop() { | |||||||
| 		close(m.shutdownCh) | 		close(m.shutdownCh) | ||||||
| 		<-m.doneCh | 		<-m.doneCh | ||||||
| 	} | 	} | ||||||
| 	m.inflightAll.Wait() | 	m.runner.StopWait() | ||||||
| } | } | ||||||
|  |  | ||||||
| // StopTicker stops the automatic Rollback manager's ticker, causing us | // StopTicker stops the automatic Rollback manager's ticker, causing us | ||||||
| @@ -168,6 +197,8 @@ func (m *RollbackManager) triggerRollbacks() { | |||||||
| func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState { | func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath string, grabStatelock bool) *rollbackState { | ||||||
| 	m.inflightLock.Lock() | 	m.inflightLock.Lock() | ||||||
| 	defer m.inflightLock.Unlock() | 	defer m.inflightLock.Unlock() | ||||||
|  | 	defer metrics.SetGauge([]string{"rollback", "queued"}, float32(m.runner.WaitingQueueSize())) | ||||||
|  | 	defer metrics.SetGauge([]string{"rollback", "inflight"}, float32(len(m.inflight))) | ||||||
| 	rsInflight, ok := m.inflight[fullPath] | 	rsInflight, ok := m.inflight[fullPath] | ||||||
| 	if ok { | 	if ok { | ||||||
| 		return rsInflight | 		return rsInflight | ||||||
| @@ -183,31 +214,48 @@ func (m *RollbackManager) startOrLookupRollback(ctx context.Context, fullPath st | |||||||
| 	m.inflight[fullPath] = rs | 	m.inflight[fullPath] = rs | ||||||
| 	rs.Add(1) | 	rs.Add(1) | ||||||
| 	m.inflightAll.Add(1) | 	m.inflightAll.Add(1) | ||||||
| 	go func() { | 	rs.scheduled = time.Now() | ||||||
| 		m.attemptRollback(ctx, fullPath, rs, grabStatelock) | 	select { | ||||||
| 		select { | 	case <-m.doneCh: | ||||||
| 		case m.rollbacksDoneCh <- struct{}{}: | 		// if we've already shut down, then don't submit the task to avoid a panic | ||||||
| 		default: | 		// we should still call finishRollback for the rollback state in order to remove | ||||||
| 		} | 		// it from the map and decrement the waitgroup. | ||||||
| 	}() |  | ||||||
|  | 		// we already have the inflight lock, so we can't grab it here | ||||||
|  | 		m.finishRollback(rs, errors.New("rollback manager is stopped"), fullPath, false) | ||||||
|  | 	default: | ||||||
|  | 		m.runner.Submit(func() { | ||||||
|  | 			m.attemptRollback(ctx, fullPath, rs, grabStatelock) | ||||||
|  | 			select { | ||||||
|  | 			case m.rollbacksDoneCh <- struct{}{}: | ||||||
|  | 			default: | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 	} | ||||||
| 	return rs | 	return rs | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (m *RollbackManager) finishRollback(rs *rollbackState, err error, fullPath string, grabInflightLock bool) { | ||||||
|  | 	rs.lastError = err | ||||||
|  | 	rs.Done() | ||||||
|  | 	m.inflightAll.Done() | ||||||
|  | 	if grabInflightLock { | ||||||
|  | 		m.inflightLock.Lock() | ||||||
|  | 		defer m.inflightLock.Unlock() | ||||||
|  | 	} | ||||||
|  | 	delete(m.inflight, fullPath) | ||||||
|  | } | ||||||
|  |  | ||||||
| // attemptRollback invokes a RollbackOperation for the given path | // attemptRollback invokes a RollbackOperation for the given path | ||||||
| func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string, rs *rollbackState, grabStatelock bool) (err error) { | func (m *RollbackManager) attemptRollback(ctx context.Context, fullPath string, rs *rollbackState, grabStatelock bool) (err error) { | ||||||
|  | 	metrics.MeasureSince([]string{"rollback", "waiting"}, rs.scheduled) | ||||||
| 	metricName := []string{"rollback", "attempt"} | 	metricName := []string{"rollback", "attempt"} | ||||||
| 	if m.rollbackMetricsMountName { | 	if m.rollbackMetricsMountName { | ||||||
| 		metricName = append(metricName, strings.ReplaceAll(fullPath, "/", "-")) | 		metricName = append(metricName, strings.ReplaceAll(fullPath, "/", "-")) | ||||||
| 	} | 	} | ||||||
| 	defer metrics.MeasureSince(metricName, time.Now()) | 	defer metrics.MeasureSince(metricName, time.Now()) | ||||||
| 	defer func() { | 	defer m.finishRollback(rs, err, fullPath, true) | ||||||
| 		rs.lastError = err |  | ||||||
| 		rs.Done() |  | ||||||
| 		m.inflightAll.Done() |  | ||||||
| 		m.inflightLock.Lock() |  | ||||||
| 		delete(m.inflight, fullPath) |  | ||||||
| 		m.inflightLock.Unlock() |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	ns, err := namespace.FromContext(ctx) | 	ns, err := namespace.FromContext(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ package vault | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| @@ -16,6 +17,7 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/helper/metricsutil" | 	"github.com/hashicorp/vault/helper/metricsutil" | ||||||
| 	"github.com/hashicorp/vault/helper/namespace" | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/logging" | 	"github.com/hashicorp/vault/sdk/helper/logging" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -81,6 +83,253 @@ func TestRollbackManager(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // TestRollbackManager_ManyWorkers adds 10 backends that require a rollback | ||||||
|  | // operation, with 20 workers. The test verifies that the 10 | ||||||
|  | // work items will run in parallel | ||||||
|  | func TestRollbackManager_ManyWorkers(t *testing.T) { | ||||||
|  | 	core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 20, RollbackPeriod: time.Millisecond * 10}) | ||||||
|  | 	view := NewBarrierView(core.barrier, "logical/") | ||||||
|  |  | ||||||
|  | 	ran := make(chan string) | ||||||
|  | 	release := make(chan struct{}) | ||||||
|  | 	core, _, _ = testCoreUnsealed(t, core) | ||||||
|  |  | ||||||
|  | 	// create 10 backends | ||||||
|  | 	// when a rollback happens, each backend will try to write to an unbuffered | ||||||
|  | 	// channel, then wait to be released | ||||||
|  | 	for i := 0; i < 10; i++ { | ||||||
|  | 		b := &NoopBackend{} | ||||||
|  | 		b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) { | ||||||
|  | 			if request.Operation == logical.RollbackOperation { | ||||||
|  | 				ran <- request.Path | ||||||
|  | 				<-release | ||||||
|  | 			} | ||||||
|  | 			return nil, nil | ||||||
|  | 		} | ||||||
|  | 		b.Root = []string{fmt.Sprintf("foo/%d", i)} | ||||||
|  | 		meUUID, err := uuid.GenerateUUID() | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		mountEntry := &MountEntry{ | ||||||
|  | 			Table:       mountTableType, | ||||||
|  | 			UUID:        meUUID, | ||||||
|  | 			Accessor:    fmt.Sprintf("accessor-%d", i), | ||||||
|  | 			NamespaceID: namespace.RootNamespaceID, | ||||||
|  | 			namespace:   namespace.RootNamespace, | ||||||
|  | 			Path:        fmt.Sprintf("logical/foo/%d", i), | ||||||
|  | 		} | ||||||
|  | 		func() { | ||||||
|  | 			core.mountsLock.Lock() | ||||||
|  | 			defer core.mountsLock.Unlock() | ||||||
|  | 			newTable := core.mounts.shallowClone() | ||||||
|  | 			newTable.Entries = append(newTable.Entries, mountEntry) | ||||||
|  | 			core.mounts = newTable | ||||||
|  | 			err = core.router.Mount(b, "logical", mountEntry, view) | ||||||
|  | 			require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local)) | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  | 	got := make(map[string]bool) | ||||||
|  | 	hasMore := true | ||||||
|  | 	for hasMore { | ||||||
|  | 		// the pool has 20 workers, which is more than the 10 blocking backends, | ||||||
|  | 		// so we would expect to see one write to the channel from every backend. | ||||||
|  | 		// Once that happens, close the release channel so that the functions can exit | ||||||
|  | 		select { | ||||||
|  | 		case <-timeout.Done(): | ||||||
|  | 			require.Fail(t, "test timed out") | ||||||
|  | 		case i := <-ran: | ||||||
|  | 			got[i] = true | ||||||
|  | 			if len(got) == 10 { | ||||||
|  | 				close(release) | ||||||
|  | 				hasMore = false | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	done := make(chan struct{}) | ||||||
|  |  | ||||||
|  | 	// start a goroutine to consume the remaining items from the queued work | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			select { | ||||||
|  | 			case <-ran: | ||||||
|  | 			case <-done: | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 	// stop the rollback worker, which will wait for all inflight rollbacks to | ||||||
|  | 	// complete | ||||||
|  | 	core.rollback.Stop() | ||||||
|  | 	close(done) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestRollbackManager_WorkerPool adds 10 backends that require a rollback | ||||||
|  | // operation, with 5 workers. The test verifies that the 5 work items can occur | ||||||
|  | // concurrently, and that the remainder of the work is queued and run when | ||||||
|  | // workers are available | ||||||
|  | func TestRollbackManager_WorkerPool(t *testing.T) { | ||||||
|  | 	core := TestCoreWithConfig(t, &CoreConfig{NumRollbackWorkers: 5, RollbackPeriod: time.Millisecond * 10}) | ||||||
|  | 	view := NewBarrierView(core.barrier, "logical/") | ||||||
|  |  | ||||||
|  | 	ran := make(chan string) | ||||||
|  | 	release := make(chan struct{}) | ||||||
|  | 	core, _, _ = testCoreUnsealed(t, core) | ||||||
|  |  | ||||||
|  | 	// create 10 backends | ||||||
|  | 	// when a rollback happens, each backend will try to write to an unbuffered | ||||||
|  | 	// channel, then wait to be released | ||||||
|  | 	for i := 0; i < 10; i++ { | ||||||
|  | 		b := &NoopBackend{} | ||||||
|  | 		b.RequestHandler = func(ctx context.Context, request *logical.Request) (*logical.Response, error) { | ||||||
|  | 			if request.Operation == logical.RollbackOperation { | ||||||
|  | 				ran <- request.Path | ||||||
|  | 				<-release | ||||||
|  | 			} | ||||||
|  | 			return nil, nil | ||||||
|  | 		} | ||||||
|  | 		b.Root = []string{fmt.Sprintf("foo/%d", i)} | ||||||
|  | 		meUUID, err := uuid.GenerateUUID() | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		mountEntry := &MountEntry{ | ||||||
|  | 			Table:       mountTableType, | ||||||
|  | 			UUID:        meUUID, | ||||||
|  | 			Accessor:    fmt.Sprintf("accessor-%d", i), | ||||||
|  | 			NamespaceID: namespace.RootNamespaceID, | ||||||
|  | 			namespace:   namespace.RootNamespace, | ||||||
|  | 			Path:        fmt.Sprintf("logical/foo/%d", i), | ||||||
|  | 		} | ||||||
|  | 		func() { | ||||||
|  | 			core.mountsLock.Lock() | ||||||
|  | 			defer core.mountsLock.Unlock() | ||||||
|  | 			newTable := core.mounts.shallowClone() | ||||||
|  | 			newTable.Entries = append(newTable.Entries, mountEntry) | ||||||
|  | 			core.mounts = newTable | ||||||
|  | 			err = core.router.Mount(b, "logical", mountEntry, view) | ||||||
|  | 			require.NoError(t, core.persistMounts(context.Background(), newTable, &mountEntry.Local)) | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	core.mountsLock.RLock() | ||||||
|  | 	numMounts := len(core.mounts.Entries) | ||||||
|  | 	core.mountsLock.RUnlock() | ||||||
|  |  | ||||||
|  | 	timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second) | ||||||
|  | 	defer cancel() | ||||||
|  | 	got := make(map[string]bool) | ||||||
|  | 	gotLock := sync.RWMutex{} | ||||||
|  | 	hasMore := true | ||||||
|  | 	for hasMore { | ||||||
|  | 		// we're using 5 workers, so we would expect to see 5 writes to the | ||||||
|  | 		// channel. Once that happens, close the release channel so that the | ||||||
|  | 		// functions can exit and new rollback operations can run | ||||||
|  | 		select { | ||||||
|  | 		case <-timeout.Done(): | ||||||
|  | 			require.Fail(t, "test timed out") | ||||||
|  | 		case i := <-ran: | ||||||
|  | 			gotLock.Lock() | ||||||
|  | 			got[i] = true | ||||||
|  | 			numGot := len(got) | ||||||
|  | 			gotLock.Unlock() | ||||||
|  | 			if numGot == 5 { | ||||||
|  | 				close(release) | ||||||
|  | 				hasMore = false | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	done := make(chan struct{}) | ||||||
|  |  | ||||||
|  | 	// start a goroutine to consume the remaining items from the queued work | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			select { | ||||||
|  | 			case i := <-ran: | ||||||
|  | 				gotLock.Lock() | ||||||
|  | 				got[i] = true | ||||||
|  | 				gotLock.Unlock() | ||||||
|  | 			case <-done: | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	// wait for every mount to be rolled back at least once | ||||||
|  | 	numMountsDone := 0 | ||||||
|  | 	for numMountsDone < numMounts { | ||||||
|  | 		<-core.rollback.rollbacksDoneCh | ||||||
|  | 		numMountsDone++ | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// stop the rollback worker, which will wait for all inflight rollbacks to | ||||||
|  | 	// complete | ||||||
|  | 	core.rollback.Stop() | ||||||
|  | 	close(done) | ||||||
|  |  | ||||||
|  | 	// we should have received at least 1 rollback for every backend | ||||||
|  | 	gotLock.RLock() | ||||||
|  | 	defer gotLock.RUnlock() | ||||||
|  | 	require.GreaterOrEqual(t, len(got), 10) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TestRollbackManager_numRollbackWorkers verifies that the number of rollback | ||||||
|  | // workers is parsed from the configuration, but can be overridden by an | ||||||
|  | // environment variable. This test cannot be run in parallel because of the | ||||||
|  | // environment variable | ||||||
|  | func TestRollbackManager_numRollbackWorkers(t *testing.T) { | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		name          string | ||||||
|  | 		configWorkers int | ||||||
|  | 		setEnvVar     bool | ||||||
|  | 		envVar        string | ||||||
|  | 		wantWorkers   int | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:          "default in config", | ||||||
|  | 			configWorkers: RollbackDefaultNumWorkers, | ||||||
|  | 			wantWorkers:   RollbackDefaultNumWorkers, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "invalid envvar", | ||||||
|  | 			configWorkers: RollbackDefaultNumWorkers, | ||||||
|  | 			wantWorkers:   RollbackDefaultNumWorkers, | ||||||
|  | 			setEnvVar:     true, | ||||||
|  | 			envVar:        "invalid", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "envvar overrides config", | ||||||
|  | 			configWorkers: RollbackDefaultNumWorkers, | ||||||
|  | 			wantWorkers:   20, | ||||||
|  | 			setEnvVar:     true, | ||||||
|  | 			envVar:        "20", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "envvar negative", | ||||||
|  | 			configWorkers: RollbackDefaultNumWorkers, | ||||||
|  | 			wantWorkers:   RollbackDefaultNumWorkers, | ||||||
|  | 			setEnvVar:     true, | ||||||
|  | 			envVar:        "-1", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:          "envvar zero", | ||||||
|  | 			configWorkers: RollbackDefaultNumWorkers, | ||||||
|  | 			wantWorkers:   RollbackDefaultNumWorkers, | ||||||
|  | 			setEnvVar:     true, | ||||||
|  | 			envVar:        "0", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tc := range testCases { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			if tc.setEnvVar { | ||||||
|  | 				t.Setenv(RollbackWorkersEnvVar, tc.envVar) | ||||||
|  | 			} | ||||||
|  | 			core := &Core{numRollbackWorkers: tc.configWorkers} | ||||||
|  | 			r := &RollbackManager{logger: logger.Named("test"), core: core} | ||||||
|  | 			require.Equal(t, tc.wantWorkers, r.numRollbackWorkers()) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestRollbackManager_Join(t *testing.T) { | func TestRollbackManager_Join(t *testing.T) { | ||||||
| 	m, backend := mockRollback(t) | 	m, backend := mockRollback(t) | ||||||
| 	if len(backend.Paths) > 0 { | 	if len(backend.Paths) > 0 { | ||||||
|   | |||||||
| @@ -249,6 +249,9 @@ func TestCoreWithSealAndUINoCleanup(t testing.T, opts *CoreConfig) *Core { | |||||||
| 	if opts.RollbackPeriod != time.Duration(0) { | 	if opts.RollbackPeriod != time.Duration(0) { | ||||||
| 		conf.RollbackPeriod = opts.RollbackPeriod | 		conf.RollbackPeriod = opts.RollbackPeriod | ||||||
| 	} | 	} | ||||||
|  | 	if opts.NumRollbackWorkers != 0 { | ||||||
|  | 		conf.NumRollbackWorkers = opts.NumRollbackWorkers | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	conf.ActivityLogConfig = opts.ActivityLogConfig | 	conf.ActivityLogConfig = opts.ActivityLogConfig | ||||||
| 	testApplyEntBaseConfig(conf, opts) | 	testApplyEntBaseConfig(conf, opts) | ||||||
| @@ -305,6 +308,7 @@ func testCoreConfig(t testing.T, physicalBackend physical.Backend, logger log.Lo | |||||||
| 		CredentialBackends: credentialBackends, | 		CredentialBackends: credentialBackends, | ||||||
| 		DisableMlock:       true, | 		DisableMlock:       true, | ||||||
| 		Logger:             logger, | 		Logger:             logger, | ||||||
|  | 		NumRollbackWorkers: 10, | ||||||
| 		BuiltinRegistry:    corehelpers.NewMockBuiltinRegistry(), | 		BuiltinRegistry:    corehelpers.NewMockBuiltinRegistry(), | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -618,6 +618,12 @@ alphabetic order by name. | |||||||
|  |  | ||||||
| @include 'telemetry-metrics/vault/rollback/attempt.mdx' | @include 'telemetry-metrics/vault/rollback/attempt.mdx' | ||||||
|  |  | ||||||
|  | @include 'telemetry-metrics/vault/rollback/inflight.mdx' | ||||||
|  |  | ||||||
|  | @include 'telemetry-metrics/vault/rollback/queued.mdx' | ||||||
|  |  | ||||||
|  | @include 'telemetry-metrics/vault/rollback/waiting.mdx' | ||||||
|  |  | ||||||
| @include 'telemetry-metrics/vault/route/create/mountpoint.mdx' | @include 'telemetry-metrics/vault/route/create/mountpoint.mdx' | ||||||
|  |  | ||||||
| @include 'telemetry-metrics/vault/route/delete/mountpoint.mdx' | @include 'telemetry-metrics/vault/route/delete/mountpoint.mdx' | ||||||
|   | |||||||
| @@ -114,6 +114,12 @@ Vault instance. | |||||||
|  |  | ||||||
| @include 'telemetry-metrics/vault/rollback/attempt.mdx' | @include 'telemetry-metrics/vault/rollback/attempt.mdx' | ||||||
|  |  | ||||||
|  | @include 'telemetry-metrics/vault/rollback/inflight.mdx' | ||||||
|  |  | ||||||
|  | @include 'telemetry-metrics/vault/rollback/queued.mdx' | ||||||
|  |  | ||||||
|  | @include 'telemetry-metrics/vault/rollback/waiting.mdx' | ||||||
|  |  | ||||||
| ## Route metrics | ## Route metrics | ||||||
|  |  | ||||||
| @include 'telemetry-metrics/route-intro.mdx' | @include 'telemetry-metrics/route-intro.mdx' | ||||||
|   | |||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | ### vault.rollback.inflight ((#vault-rollback-inflight)) | ||||||
|  |  | ||||||
|  | Metric type | Value  | Description | ||||||
|  | ----------- | ------ | ----------- | ||||||
|  | gauge       | number | Number of rollback operations inflight | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | ### vault.rollback.queued ((#vault-rollback-queued)) | ||||||
|  |  | ||||||
|  | Metric type | Value  | Description | ||||||
|  | ----------- | ------ | ----------- | ||||||
|  | gauge       | number | The number of rollback operations waiting to be started | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | ### vault.rollback.waiting ((#vault-rollback-waiting)) | ||||||
|  |  | ||||||
|  | Metric type | Value | Description | ||||||
|  | ----------- | ----- | ----------- | ||||||
|  | summary     | ms    | Time between queueing a rollback operation and the operation starting | ||||||
		Reference in New Issue
	
	Block a user
	 miagilepner
					miagilepner