mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-03 03:58:01 +00:00
Refactor auto-auth backoff to helper package. (#24668)
I have an upcoming PR for event notifications that needs similar exponential backoff logic, and I prefer the API and logic in the auto-auth exponential backoff rather than that of github.com/cenkalti/backoff/v3. This does have a small behavior change: the auto-auth min backoff will now be randomly reduced by up to 25% on the first call. This is a desirable property to avoid thundering herd problems, where a bunch of agents won't all try have the same retry timeout.
This commit is contained in:
committed by
GitHub
parent
edaa48ad90
commit
52d9d43a1c
@@ -7,12 +7,14 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/armon/go-metrics"
|
"github.com/armon/go-metrics"
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/backoff"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||||
@@ -113,19 +115,15 @@ func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
|
|||||||
return ah
|
return ah
|
||||||
}
|
}
|
||||||
|
|
||||||
func backoff(ctx context.Context, backoff *autoAuthBackoff) bool {
|
func backoffSleep(ctx context.Context, backoff *autoAuthBackoff) bool {
|
||||||
if backoff.exitOnErr {
|
nextSleep, err := backoff.backoff.Next()
|
||||||
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-time.After(backoff.current):
|
case <-time.After(nextSleep):
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increase exponential backoff for the next time if we don't
|
|
||||||
// successfully auth/renew/etc.
|
|
||||||
backoff.next()
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,12 +135,13 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
if ah.minBackoff <= 0 {
|
if ah.minBackoff <= 0 {
|
||||||
ah.minBackoff = defaultMinBackoff
|
ah.minBackoff = defaultMinBackoff
|
||||||
}
|
}
|
||||||
|
if ah.maxBackoff <= 0 {
|
||||||
backoffCfg := newAutoAuthBackoff(ah.minBackoff, ah.maxBackoff, ah.exitOnError)
|
ah.maxBackoff = defaultMaxBackoff
|
||||||
|
}
|
||||||
if backoffCfg.min >= backoffCfg.max {
|
if ah.minBackoff > ah.maxBackoff {
|
||||||
return errors.New("auth handler: min_backoff cannot be greater than max_backoff")
|
return errors.New("auth handler: min_backoff cannot be greater than max_backoff")
|
||||||
}
|
}
|
||||||
|
backoffCfg := newAutoAuthBackoff(ah.minBackoff, ah.maxBackoff, ah.exitOnError)
|
||||||
|
|
||||||
ah.logger.Info("starting auth handler")
|
ah.logger.Info("starting auth handler")
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -204,10 +203,10 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
case AuthMethodWithClient:
|
case AuthMethodWithClient:
|
||||||
clientToUse, err = am.(AuthMethodWithClient).AuthClient(ah.client)
|
clientToUse, err = am.(AuthMethodWithClient).AuthClient(ah.client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoff)
|
ah.logger.Error("error creating client for authentication call", "error", err, "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,7 +233,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("could not look up token", "err", err, "backoff", backoffCfg)
|
ah.logger.Error("could not look up token", "err", err, "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -254,7 +253,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoffCfg)
|
ah.logger.Error("error getting path or data from method", "error", err, "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -267,7 +266,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoffCfg)
|
ah.logger.Error("error creating client for wrapped call", "error", err, "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -305,7 +304,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("error authenticating", "error", err, "backoff", backoffCfg)
|
ah.logger.Error("error authenticating", "error", err, "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -320,7 +319,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("authentication returned nil wrap info", "backoff", backoffCfg)
|
ah.logger.Error("authentication returned nil wrap info", "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -329,7 +328,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoffCfg)
|
ah.logger.Error("authentication returned empty wrapped client token", "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -339,7 +338,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoffCfg)
|
ah.logger.Error("failed to encode wrapinfo", "error", err, "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -354,7 +353,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
am.CredSuccess()
|
am.CredSuccess()
|
||||||
backoffCfg.reset()
|
backoffCfg.backoff.Reset()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -378,7 +377,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("token file validation failed, token may be invalid", "backoff", backoffCfg)
|
ah.logger.Error("token file validation failed, token may be invalid", "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -388,7 +387,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("token file validation returned empty client token", "backoff", backoffCfg)
|
ah.logger.Error("token file validation returned empty client token", "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -420,7 +419,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("authentication returned nil auth info", "backoff", backoffCfg)
|
ah.logger.Error("authentication returned nil auth info", "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -429,7 +428,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("authentication returned empty client token", "backoff", backoffCfg)
|
ah.logger.Error("authentication returned empty client token", "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -447,7 +446,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
am.CredSuccess()
|
am.CredSuccess()
|
||||||
backoffCfg.reset()
|
backoffCfg.backoff.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
if watcher != nil {
|
if watcher != nil {
|
||||||
@@ -461,7 +460,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
ah.logger.Error("error creating lifetime watcher", "error", err, "backoff", backoffCfg)
|
ah.logger.Error("error creating lifetime watcher", "error", err, "backoff", backoffCfg)
|
||||||
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
metrics.IncrCounter([]string{ah.metricsSignifier, "auth", "failure"}, 1)
|
||||||
|
|
||||||
if backoff(ctx, backoffCfg) {
|
if backoffSleep(ctx, backoffCfg) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -507,10 +506,7 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
|
|||||||
|
|
||||||
// autoAuthBackoff tracks exponential backoff state.
|
// autoAuthBackoff tracks exponential backoff state.
|
||||||
type autoAuthBackoff struct {
|
type autoAuthBackoff struct {
|
||||||
min time.Duration
|
backoff *backoff.Backoff
|
||||||
max time.Duration
|
|
||||||
current time.Duration
|
|
||||||
exitOnErr bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAutoAuthBackoff(min, max time.Duration, exitErr bool) *autoAuthBackoff {
|
func newAutoAuthBackoff(min, max time.Duration, exitErr bool) *autoAuthBackoff {
|
||||||
@@ -522,32 +518,18 @@ func newAutoAuthBackoff(min, max time.Duration, exitErr bool) *autoAuthBackoff {
|
|||||||
min = defaultMinBackoff
|
min = defaultMinBackoff
|
||||||
}
|
}
|
||||||
|
|
||||||
|
retries := math.MaxInt
|
||||||
|
if exitErr {
|
||||||
|
retries = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
b := backoff.NewBackoff(retries, min, max)
|
||||||
|
|
||||||
return &autoAuthBackoff{
|
return &autoAuthBackoff{
|
||||||
current: min,
|
backoff: b,
|
||||||
max: max,
|
|
||||||
min: min,
|
|
||||||
exitOnErr: exitErr,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// next determines the next backoff duration that is roughly twice
|
|
||||||
// the current value, capped to a max value, with a measure of randomness.
|
|
||||||
func (b *autoAuthBackoff) next() {
|
|
||||||
maxBackoff := 2 * b.current
|
|
||||||
|
|
||||||
if maxBackoff > b.max {
|
|
||||||
maxBackoff = b.max
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trim a random amount (0-25%) off the doubled duration
|
|
||||||
trim := rand.Int63n(int64(maxBackoff) / 4)
|
|
||||||
b.current = maxBackoff - time.Duration(trim)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *autoAuthBackoff) reset() {
|
|
||||||
b.current = b.min
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b autoAuthBackoff) String() string {
|
func (b autoAuthBackoff) String() string {
|
||||||
return b.current.Truncate(10 * time.Millisecond).String()
|
return b.backoff.Current().Truncate(10 * time.Millisecond).String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,35 +113,36 @@ func TestAgentBackoff(t *testing.T) {
|
|||||||
backoff := newAutoAuthBackoff(defaultMinBackoff, max, false)
|
backoff := newAutoAuthBackoff(defaultMinBackoff, max, false)
|
||||||
|
|
||||||
// Test initial value
|
// Test initial value
|
||||||
if backoff.current != defaultMinBackoff {
|
if backoff.backoff.Current() > defaultMinBackoff || backoff.backoff.Current() < defaultMinBackoff*3/4 {
|
||||||
t.Fatalf("expected 1s initial backoff, got: %v", backoff.current)
|
t.Fatalf("expected 1s initial backoff, got: %v", backoff.backoff.Current())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that backoff values are in expected range (75-100% of 2*previous)
|
// Test that backoffSleep values are in expected range (75-100% of 2*previous)
|
||||||
|
next, _ := backoff.backoff.Next()
|
||||||
for i := 0; i < 9; i++ {
|
for i := 0; i < 9; i++ {
|
||||||
old := backoff.current
|
old := next
|
||||||
backoff.next()
|
next, _ = backoff.backoff.Next()
|
||||||
|
|
||||||
expMax := 2 * old
|
expMax := 2 * old
|
||||||
expMin := 3 * expMax / 4
|
expMin := 3 * expMax / 4
|
||||||
|
|
||||||
if backoff.current < expMin || backoff.current > expMax {
|
if next < expMin || next > expMax {
|
||||||
t.Fatalf("expected backoff in range %v to %v, got: %v", expMin, expMax, backoff)
|
t.Fatalf("expected backoffSleep in range %v to %v, got: %v", expMin, expMax, backoff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that backoff is capped
|
// Test that backoffSleep is capped
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
backoff.next()
|
_, _ = backoff.backoff.Next()
|
||||||
if backoff.current > max {
|
if backoff.backoff.Current() > max {
|
||||||
t.Fatalf("backoff exceeded max of 100s: %v", backoff)
|
t.Fatalf("backoff exceeded max of 100s: %v", backoff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test reset
|
// Test reset
|
||||||
backoff.reset()
|
backoff.backoff.Reset()
|
||||||
if backoff.current != defaultMinBackoff {
|
if backoff.backoff.Current() > defaultMinBackoff || backoff.backoff.Current() < defaultMinBackoff*3/4 {
|
||||||
t.Fatalf("expected 1s backoff after reset, got: %v", backoff.current)
|
t.Fatalf("expected 1s backoff after reset, got: %v", backoff.backoff.Current())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,35 +164,36 @@ func TestAgentMinBackoffCustom(t *testing.T) {
|
|||||||
backoff := newAutoAuthBackoff(test.minBackoff, max, false)
|
backoff := newAutoAuthBackoff(test.minBackoff, max, false)
|
||||||
|
|
||||||
// Test initial value
|
// Test initial value
|
||||||
if backoff.current != test.want {
|
if backoff.backoff.Current() > test.want || backoff.backoff.Current() < test.want*3/4 {
|
||||||
t.Fatalf("expected %d initial backoff, got: %v", test.want, backoff.current)
|
t.Fatalf("expected %d initial backoffSleep, got: %v", test.want, backoff.backoff.Current())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that backoff values are in expected range (75-100% of 2*previous)
|
// Test that backoffSleep values are in expected range (75-100% of 2*previous)
|
||||||
|
next, _ := backoff.backoff.Next()
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
old := backoff.current
|
old := next
|
||||||
backoff.next()
|
next, _ = backoff.backoff.Next()
|
||||||
|
|
||||||
expMax := 2 * old
|
expMax := 2 * old
|
||||||
expMin := 3 * expMax / 4
|
expMin := 3 * expMax / 4
|
||||||
|
|
||||||
if backoff.current < expMin || backoff.current > expMax {
|
if next < expMin || next > expMax {
|
||||||
t.Fatalf("expected backoff in range %v to %v, got: %v", expMin, expMax, backoff)
|
t.Fatalf("expected backoffSleep in range %v to %v, got: %v", expMin, expMax, backoff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that backoff is capped
|
// Test that backoffSleep is capped
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
backoff.next()
|
next, _ = backoff.backoff.Next()
|
||||||
if backoff.current > max {
|
if next > max {
|
||||||
t.Fatalf("backoff exceeded max of 100s: %v", backoff)
|
t.Fatalf("backoffSleep exceeded max of 100s: %v", backoff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test reset
|
// Test reset
|
||||||
backoff.reset()
|
backoff.backoff.Reset()
|
||||||
if backoff.current != test.want {
|
if backoff.backoff.Current() > test.want || backoff.backoff.Current() < test.want*3/4 {
|
||||||
t.Fatalf("expected %d backoff after reset, got: %v", test.want, backoff.current)
|
t.Fatalf("expected %d backoffSleep after reset, got: %v", test.want, backoff.backoff.Current())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
90
sdk/helper/backoff/backoff.go
Normal file
90
sdk/helper/backoff/backoff.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
// Copyright (c) HashiCorp, Inc.
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package backoff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrMaxRetry = errors.New("exceeded maximum number of retries")
|
||||||
|
|
||||||
|
const maxJitter = 0.25
|
||||||
|
|
||||||
|
// Backoff is used to do capped exponential backoff with jitter, with a maximum number of retries.
|
||||||
|
// Generally, use this struct by calling Next() or NextSleep() after a failure.
|
||||||
|
// If configured for N max retries, Next() and NextSleep() will return an error on the call N+1.
|
||||||
|
// The jitter is set to 25%, so values returned will have up to 25% less than twice the previous value.
|
||||||
|
// The min value will also include jitter, so the first call will almost always be less than the requested minimum value.
|
||||||
|
// Backoff is not thread-safe.
|
||||||
|
type Backoff struct {
|
||||||
|
currentAttempt int
|
||||||
|
maxRetries int
|
||||||
|
min time.Duration
|
||||||
|
max time.Duration
|
||||||
|
current time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBackoff creates a new exponential backoff with the given number of maximum retries and min/max durations.
|
||||||
|
func NewBackoff(maxRetries int, min, max time.Duration) *Backoff {
|
||||||
|
b := &Backoff{
|
||||||
|
maxRetries: maxRetries,
|
||||||
|
max: max,
|
||||||
|
min: min,
|
||||||
|
}
|
||||||
|
b.Reset()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Current returns the next time that will be returned by Next() (or slept in NextSleep()).
|
||||||
|
func (b *Backoff) Current() time.Duration {
|
||||||
|
return b.current
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next determines the next backoff duration that is roughly twice
|
||||||
|
// the current value, capped to a max value, with a measure of randomness.
|
||||||
|
// It returns an error if there are no more retries left.
|
||||||
|
func (b *Backoff) Next() (time.Duration, error) {
|
||||||
|
if b.currentAttempt >= b.maxRetries {
|
||||||
|
return time.Duration(-1), ErrMaxRetry
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
b.currentAttempt += 1
|
||||||
|
}()
|
||||||
|
if b.currentAttempt == 0 {
|
||||||
|
return b.current, nil
|
||||||
|
}
|
||||||
|
next := 2 * b.current
|
||||||
|
if next > b.max {
|
||||||
|
next = b.max
|
||||||
|
}
|
||||||
|
next = jitter(next)
|
||||||
|
b.current = next
|
||||||
|
return next, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextSleep will synchronously sleep the next backoff amount (see Next()).
|
||||||
|
// It returns an error if there are no more retries left.
|
||||||
|
func (b *Backoff) NextSleep() error {
|
||||||
|
next, err := b.Next()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(next)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets the state to the initial backoff amount and 0 retries.
|
||||||
|
func (b *Backoff) Reset() {
|
||||||
|
b.current = b.min
|
||||||
|
b.current = jitter(b.current)
|
||||||
|
b.currentAttempt = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func jitter(t time.Duration) time.Duration {
|
||||||
|
f := float64(t) * (1.0 - maxJitter*rand.Float64())
|
||||||
|
return time.Duration(math.Floor(f))
|
||||||
|
}
|
||||||
52
sdk/helper/backoff/backoff_test.go
Normal file
52
sdk/helper/backoff/backoff_test.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
// Copyright (c) HashiCorp, Inc.
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package backoff
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestBackoff_Basic tests that basic exponential backoff works as expected up to a max of 3 times.
|
||||||
|
func TestBackoff_Basic(t *testing.T) {
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
b := NewBackoff(3, 1*time.Millisecond, 10*time.Millisecond)
|
||||||
|
x, err := b.Next()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.LessOrEqual(t, x, 1*time.Millisecond)
|
||||||
|
assert.GreaterOrEqual(t, x, 750*time.Microsecond)
|
||||||
|
|
||||||
|
x2, err := b.Next()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.LessOrEqual(t, x2, x*2)
|
||||||
|
assert.GreaterOrEqual(t, x2, x*3/4)
|
||||||
|
|
||||||
|
x3, err := b.Next()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.LessOrEqual(t, x3, x2*2)
|
||||||
|
assert.GreaterOrEqual(t, x3, x2*3/4)
|
||||||
|
|
||||||
|
_, err = b.Next()
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBackoff_ZeroRetriesAlwaysFails checks that if retries is set to zero, then an error is returned immediately.
|
||||||
|
func TestBackoff_ZeroRetriesAlwaysFails(t *testing.T) {
|
||||||
|
b := NewBackoff(0, 1*time.Millisecond, 10*time.Millisecond)
|
||||||
|
_, err := b.Next()
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBackoff_MaxIsEnforced checks that the maximum backoff is enforced.
|
||||||
|
func TestBackoff_MaxIsEnforced(t *testing.T) {
|
||||||
|
b := NewBackoff(1001, 1*time.Millisecond, 2*time.Millisecond)
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
x, err := b.Next()
|
||||||
|
assert.LessOrEqual(t, x, 2*time.Millisecond)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user