From 52d9d43a1c39bfe73ca8037127de37b84c975b57 Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Thu, 4 Jan 2024 10:26:41 -0800 Subject: [PATCH] Refactor auto-auth backoff to helper package. (#24668) I have an upcoming PR for event notifications that needs similar exponential backoff logic, and I prefer the API and logic in the auto-auth exponential backoff rather than that of github.com/cenkalti/backoff/v3. This does have a small behavior change: the auto-auth min backoff will now be randomly reduced by up to 25% on the first call. This is a desirable property to avoid thundering herd problems, since it keeps a bunch of agents from all retrying with the same timeout. --- command/agentproxyshared/auth/auth.go | 92 +++++++++------------- command/agentproxyshared/auth/auth_test.go | 56 ++++++------- sdk/helper/backoff/backoff.go | 90 +++++++++++++++++++++ sdk/helper/backoff/backoff_test.go | 52 ++++++++++++ 4 files changed, 208 insertions(+), 82 deletions(-) create mode 100644 sdk/helper/backoff/backoff.go create mode 100644 sdk/helper/backoff/backoff_test.go diff --git a/command/agentproxyshared/auth/auth.go b/command/agentproxyshared/auth/auth.go index 36f775cb5e..0017acd34c 100644 --- a/command/agentproxyshared/auth/auth.go +++ b/command/agentproxyshared/auth/auth.go @@ -7,12 +7,14 @@ import ( "context" "encoding/json" "errors" + "math" "math/rand" "net/http" "time" "github.com/armon/go-metrics" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/backoff" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/sdk/helper/jsonutil" @@ -113,19 +115,15 @@ func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler { return ah } -func backoff(ctx context.Context, backoff *autoAuthBackoff) bool { - if backoff.exitOnErr { +func backoffSleep(ctx context.Context, backoff *autoAuthBackoff) bool { + nextSleep, err := backoff.backoff.Next() + if err != nil { return false } - select { - case <-time.After(backoff.current): + case <-time.After(nextSleep):
case <-ctx.Done(): } - - // Increase exponential backoff for the next time if we don't - // successfully auth/renew/etc. - backoff.next() return true } @@ -137,12 +135,13 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { if ah.minBackoff <= 0 { ah.minBackoff = defaultMinBackoff } - - backoffCfg := newAutoAuthBackoff(ah.minBackoff, ah.maxBackoff, ah.exitOnError) - - if backoffCfg.min >= backoffCfg.max { + if ah.maxBackoff <= 0 { + ah.maxBackoff = defaultMaxBackoff + } + if ah.minBackoff > ah.maxBackoff { return errors.New("auth handler: min_backoff cannot be greater than max_backoff") } + backoffCfg := newAutoAuthBackoff(ah.minBackoff, ah.maxBackoff, ah.exitOnError) ah.logger.Info("starting auth handler") defer func() { @@ -204,10 +203,10 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { case AuthMethodWithClient: clientToUse, err = am.(AuthMethodWithClient).AuthClient(ah.client) if err != nil { - ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff) + ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } @@ -234,7 +233,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("could not look up token", "err", err, "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -254,7 +253,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ 
-267,7 +266,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -305,7 +304,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("error authenticating", "error", err, "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -320,7 +319,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("authentication returned nil wrap info", "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -329,7 +328,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -339,7 +338,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -354,7 +353,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { } am.CredSuccess() - backoffCfg.reset() + backoffCfg.backoff.Reset() select { case <-ctx.Done(): @@ -378,7 +377,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { 
ah.logger.Error("token file validation failed, token may be invalid", "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -388,7 +387,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("token file validation returned empty client token", "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -420,7 +419,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("authentication returned nil auth info", "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -429,7 +428,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("authentication returned empty client token", "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -447,7 +446,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { } am.CredSuccess() - backoffCfg.reset() + backoffCfg.backoff.Reset() } if watcher != nil { @@ -461,7 +460,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { ah.logger.Error("error creating lifetime watcher", "error", err, "backoff", backoffCfg) metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1) - if backoff(ctx, backoffCfg) { + if backoffSleep(ctx, backoffCfg) { continue } return err @@ -507,10 +506,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error { // autoAuthBackoff tracks exponential backoff state. 
type autoAuthBackoff struct { - min time.Duration - max time.Duration - current time.Duration - exitOnErr bool + backoff *backoff.Backoff } func newAutoAuthBackoff(min, max time.Duration, exitErr bool) *autoAuthBackoff { @@ -522,32 +518,18 @@ func newAutoAuthBackoff(min, max time.Duration, exitErr bool) *autoAuthBackoff { min = defaultMinBackoff } + retries := math.MaxInt + if exitErr { + retries = 0 + } + + b := backoff.NewBackoff(retries, min, max) + return &autoAuthBackoff{ - current: min, - max: max, - min: min, - exitOnErr: exitErr, + backoff: b, } } -// next determines the next backoff duration that is roughly twice -// the current value, capped to a max value, with a measure of randomness. -func (b *autoAuthBackoff) next() { - maxBackoff := 2 * b.current - - if maxBackoff > b.max { - maxBackoff = b.max - } - - // Trim a random amount (0-25%) off the doubled duration - trim := rand.Int63n(int64(maxBackoff) / 4) - b.current = maxBackoff - time.Duration(trim) -} - -func (b *autoAuthBackoff) reset() { - b.current = b.min -} - func (b autoAuthBackoff) String() string { - return b.current.Truncate(10 * time.Millisecond).String() + return b.backoff.Current().Truncate(10 * time.Millisecond).String() } diff --git a/command/agentproxyshared/auth/auth_test.go b/command/agentproxyshared/auth/auth_test.go index d695a33a48..4ecfff03f5 100644 --- a/command/agentproxyshared/auth/auth_test.go +++ b/command/agentproxyshared/auth/auth_test.go @@ -113,35 +113,36 @@ func TestAgentBackoff(t *testing.T) { backoff := newAutoAuthBackoff(defaultMinBackoff, max, false) // Test initial value - if backoff.current != defaultMinBackoff { - t.Fatalf("expected 1s initial backoff, got: %v", backoff.current) + if backoff.backoff.Current() > defaultMinBackoff || backoff.backoff.Current() < defaultMinBackoff*3/4 { + t.Fatalf("expected 1s initial backoff, got: %v", backoff.backoff.Current()) } - // Test that backoff values are in expected range (75-100% of 2*previous) + // Test that backoffSleep 
values are in expected range (75-100% of 2*previous) + next, _ := backoff.backoff.Next() for i := 0; i < 9; i++ { - old := backoff.current - backoff.next() + old := next + next, _ = backoff.backoff.Next() expMax := 2 * old expMin := 3 * expMax / 4 - if backoff.current < expMin || backoff.current > expMax { - t.Fatalf("expected backoff in range %v to %v, got: %v", expMin, expMax, backoff) + if next < expMin || next > expMax { + t.Fatalf("expected backoffSleep in range %v to %v, got: %v", expMin, expMax, backoff) } } - // Test that backoff is capped + // Test that backoffSleep is capped for i := 0; i < 100; i++ { - backoff.next() - if backoff.current > max { + _, _ = backoff.backoff.Next() + if backoff.backoff.Current() > max { t.Fatalf("backoff exceeded max of 100s: %v", backoff) } } // Test reset - backoff.reset() - if backoff.current != defaultMinBackoff { - t.Fatalf("expected 1s backoff after reset, got: %v", backoff.current) + backoff.backoff.Reset() + if backoff.backoff.Current() > defaultMinBackoff || backoff.backoff.Current() < defaultMinBackoff*3/4 { + t.Fatalf("expected 1s backoff after reset, got: %v", backoff.backoff.Current()) } } @@ -163,35 +164,36 @@ func TestAgentMinBackoffCustom(t *testing.T) { backoff := newAutoAuthBackoff(test.minBackoff, max, false) // Test initial value - if backoff.current != test.want { - t.Fatalf("expected %d initial backoff, got: %v", test.want, backoff.current) + if backoff.backoff.Current() > test.want || backoff.backoff.Current() < test.want*3/4 { + t.Fatalf("expected %d initial backoffSleep, got: %v", test.want, backoff.backoff.Current()) } - // Test that backoff values are in expected range (75-100% of 2*previous) + // Test that backoffSleep values are in expected range (75-100% of 2*previous) + next, _ := backoff.backoff.Next() for i := 0; i < 5; i++ { - old := backoff.current - backoff.next() + old := next + next, _ = backoff.backoff.Next() expMax := 2 * old expMin := 3 * expMax / 4 - if backoff.current < expMin || 
backoff.current > expMax { - t.Fatalf("expected backoff in range %v to %v, got: %v", expMin, expMax, backoff) + if next < expMin || next > expMax { + t.Fatalf("expected backoffSleep in range %v to %v, got: %v", expMin, expMax, backoff) } } - // Test that backoff is capped + // Test that backoffSleep is capped for i := 0; i < 100; i++ { - backoff.next() - if backoff.current > max { - t.Fatalf("backoff exceeded max of 100s: %v", backoff) + next, _ = backoff.backoff.Next() + if next > max { + t.Fatalf("backoffSleep exceeded max of 100s: %v", backoff) } } // Test reset - backoff.reset() - if backoff.current != test.want { - t.Fatalf("expected %d backoff after reset, got: %v", test.want, backoff.current) + backoff.backoff.Reset() + if backoff.backoff.Current() > test.want || backoff.backoff.Current() < test.want*3/4 { + t.Fatalf("expected %d backoffSleep after reset, got: %v", test.want, backoff.backoff.Current()) } } } diff --git a/sdk/helper/backoff/backoff.go b/sdk/helper/backoff/backoff.go new file mode 100644 index 0000000000..35fb059538 --- /dev/null +++ b/sdk/helper/backoff/backoff.go @@ -0,0 +1,90 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package backoff + +import ( + "errors" + "math" + "math/rand" + "time" +) + +var ErrMaxRetry = errors.New("exceeded maximum number of retries") + +const maxJitter = 0.25 + +// Backoff is used to do capped exponential backoff with jitter, with a maximum number of retries. +// Generally, use this struct by calling Next() or NextSleep() after a failure. +// If configured for N max retries, Next() and NextSleep() will return an error on the call N+1. +// The jitter is set to 25%, so values returned will have up to 25% less than twice the previous value. +// The min value will also include jitter, so the first call will almost always be less than the requested minimum value. +// Backoff is not thread-safe. 
+type Backoff struct { + currentAttempt int + maxRetries int + min time.Duration + max time.Duration + current time.Duration +} + +// NewBackoff creates a new exponential backoff with the given number of maximum retries and min/max durations. +func NewBackoff(maxRetries int, min, max time.Duration) *Backoff { + b := &Backoff{ + maxRetries: maxRetries, + max: max, + min: min, + } + b.Reset() + return b +} + +// Current returns the current backoff duration: the value the first Next() call will return, or the value most recently returned by Next() thereafter. +func (b *Backoff) Current() time.Duration { + return b.current +} + +// Next determines the next backoff duration that is roughly twice +// the current value, capped to a max value, with a measure of randomness. +// It returns an error if there are no more retries left. +func (b *Backoff) Next() (time.Duration, error) { + if b.currentAttempt >= b.maxRetries { + return time.Duration(-1), ErrMaxRetry + } + defer func() { + b.currentAttempt += 1 + }() + if b.currentAttempt == 0 { + return b.current, nil + } + next := 2 * b.current + if next > b.max { + next = b.max + } + next = jitter(next) + b.current = next + return next, nil +} + +// NextSleep will synchronously sleep the next backoff amount (see Next()). +// It returns an error if there are no more retries left. +func (b *Backoff) NextSleep() error { + next, err := b.Next() + if err != nil { + return err + } + time.Sleep(next) + return nil +} + +// Reset resets the state to the initial backoff amount and 0 retries. +func (b *Backoff) Reset() { + b.current = b.min + b.current = jitter(b.current) + b.currentAttempt = 0 +} + +func jitter(t time.Duration) time.Duration { + f := float64(t) * (1.0 - maxJitter*rand.Float64()) + return time.Duration(math.Floor(f)) +} diff --git a/sdk/helper/backoff/backoff_test.go b/sdk/helper/backoff/backoff_test.go new file mode 100644 index 0000000000..46b85257ba --- /dev/null +++ b/sdk/helper/backoff/backoff_test.go @@ -0,0 +1,52 @@ +// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: MPL-2.0 + +package backoff + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestBackoff_Basic tests that basic exponential backoff works as expected up to a max of 3 times. +func TestBackoff_Basic(t *testing.T) { + for i := 0; i < 100; i++ { + b := NewBackoff(3, 1*time.Millisecond, 10*time.Millisecond) + x, err := b.Next() + assert.Nil(t, err) + assert.LessOrEqual(t, x, 1*time.Millisecond) + assert.GreaterOrEqual(t, x, 750*time.Microsecond) + + x2, err := b.Next() + assert.Nil(t, err) + assert.LessOrEqual(t, x2, x*2) + assert.GreaterOrEqual(t, x2, x*3/4) + + x3, err := b.Next() + assert.Nil(t, err) + assert.LessOrEqual(t, x3, x2*2) + assert.GreaterOrEqual(t, x3, x2*3/4) + + _, err = b.Next() + assert.NotNil(t, err) + } +} + +// TestBackoff_ZeroRetriesAlwaysFails checks that if retries is set to zero, then an error is returned immediately. +func TestBackoff_ZeroRetriesAlwaysFails(t *testing.T) { + b := NewBackoff(0, 1*time.Millisecond, 10*time.Millisecond) + _, err := b.Next() + assert.NotNil(t, err) +} + +// TestBackoff_MaxIsEnforced checks that the maximum backoff is enforced. +func TestBackoff_MaxIsEnforced(t *testing.T) { + b := NewBackoff(1001, 1*time.Millisecond, 2*time.Millisecond) + for i := 0; i < 1000; i++ { + x, err := b.Next() + assert.LessOrEqual(t, x, 2*time.Millisecond) + assert.Nil(t, err) + } +}