diff --git a/vault/rollback_test.go b/vault/rollback_test.go index a99060df69..eeca3dad9f 100644 --- a/vault/rollback_test.go +++ b/vault/rollback_test.go @@ -211,14 +211,9 @@ func TestRollbackManager_WorkerPool(t *testing.T) { }() } - core.mountsLock.RLock() - numMounts := len(core.mounts.Entries) - core.mountsLock.RUnlock() - timeout, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() got := make(map[string]bool) - gotLock := sync.RWMutex{} hasMore := true for hasMore { // we're using 5 workers, so we would expect to see 5 writes to the @@ -228,10 +223,8 @@ func TestRollbackManager_WorkerPool(t *testing.T) { case <-timeout.Done(): require.Fail(t, "test timed out") case i := <-ran: - gotLock.Lock() got[i] = true numGot := len(got) - gotLock.Unlock() if numGot == 5 { close(release) hasMore = false @@ -239,37 +232,38 @@ func TestRollbackManager_WorkerPool(t *testing.T) { } } done := make(chan struct{}) + defer close(done) // start a goroutine to consume the remaining items from the queued work + gotAllPaths := make(chan struct{}) go func() { + channelClosed := false for { select { case i := <-ran: - gotLock.Lock() got[i] = true - gotLock.Unlock() + + // keep this goroutine running even after there are 10 paths. + // More rollback operations might get queued before Stop() is + // called, and we don't want them to block on writing the to the + // ran channel + if len(got) == 10 && !channelClosed { + close(gotAllPaths) + channelClosed = true + } + case <-timeout.Done(): + require.Fail(t, "test timed out") case <-done: return } } }() - // wait for every mount to be rolled back at least once - numMountsDone := 0 - for numMountsDone < numMounts { - <-core.rollback.rollbacksDoneCh - numMountsDone++ - } - - // stop the rollback worker, which will wait for all inflight rollbacks to + // wait until all 10 backends have each ran at least once + <-gotAllPaths + // stop the rollback worker, which will wait for any inflight rollbacks to // complete core.rollback.Stop() - close(done) - - // we should have received at least 1 rollback for every backend - gotLock.RLock() - defer gotLock.RUnlock() - require.GreaterOrEqual(t, len(got), 10) } // TestRollbackManager_numRollbackWorkers verifies that the number of rollback