feat: add retry package

This package provides a consistent way for us to retry arbitrary logic.
It provides the following backoff algorithms:

- exponential
- linear
- constant

Signed-off-by: Andrew Rynhard <andrew@andrewrynhard.com>
This commit is contained in:
Andrew Rynhard
2019-10-05 16:37:28 -07:00
parent a799b05012
commit 92de30715e
15 changed files with 984 additions and 91 deletions

View File

@@ -10,7 +10,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
stdlibnet "net" stdlibnet "net"
"os" "os"
"strings" "strings"
@@ -34,6 +33,7 @@ import (
"github.com/talos-systems/talos/pkg/crypto/x509" "github.com/talos-systems/talos/pkg/crypto/x509"
"github.com/talos-systems/talos/pkg/kubernetes" "github.com/talos-systems/talos/pkg/kubernetes"
"github.com/talos-systems/talos/pkg/net" "github.com/talos-systems/talos/pkg/net"
"github.com/talos-systems/talos/pkg/retry"
) )
var etcdImage = fmt.Sprintf("%s:%s", constants.EtcdImage, constants.DefaultEtcdVersion) var etcdImage = fmt.Sprintf("%s:%s", constants.EtcdImage, constants.DefaultEtcdVersion)
@@ -304,13 +304,12 @@ func buildInitialCluster(config config.Configurator, name, ip string) (initial s
return "", err return "", err
} }
for i := 0; i < 200; i++ { opts := []retry.Option{retry.WithUnits(3 * time.Second), retry.WithJitter(time.Second)}
endpoints, err := h.MasterIPs() err = retry.Constant(10*time.Minute, opts...).Retry(func() error {
var endpoints []string
endpoints, err = h.MasterIPs()
if err != nil { if err != nil {
log.Printf("failed to get client endpoints: %+v\n", err) return retry.ExpectedError(err)
time.Sleep(3 * time.Second)
continue
} }
// Etcd expects host:port format. // Etcd expects host:port format.
@@ -320,12 +319,10 @@ func buildInitialCluster(config config.Configurator, name, ip string) (initial s
peerAddrs := []string{"https://" + ip + ":2380"} peerAddrs := []string{"https://" + ip + ":2380"}
resp, err := addMember(endpoints, peerAddrs) var resp *clientv3.MemberAddResponse
resp, err = addMember(endpoints, peerAddrs)
if err != nil { if err != nil {
log.Printf("failed to add etcd member: %+v\n", err) return retry.ExpectedError(err)
time.Sleep(3 * time.Second)
continue
} }
newID := resp.Member.ID newID := resp.Member.ID
@@ -344,8 +341,12 @@ func buildInitialCluster(config config.Configurator, name, ip string) (initial s
initial = strings.Join(conf, ",") initial = strings.Join(conf, ",")
return initial, nil return nil
})
if err != nil {
return "", errors.New("failed to discover etcd cluster")
} }
return "", errors.New("failed to discover etcd cluster") return initial, nil
} }

View File

@@ -20,27 +20,27 @@ import (
gptpartition "github.com/talos-systems/talos/pkg/blockdevice/table/gpt/partition" gptpartition "github.com/talos-systems/talos/pkg/blockdevice/table/gpt/partition"
"github.com/talos-systems/talos/pkg/blockdevice/util" "github.com/talos-systems/talos/pkg/blockdevice/util"
"github.com/talos-systems/talos/pkg/constants" "github.com/talos-systems/talos/pkg/constants"
"github.com/talos-systems/talos/pkg/retry"
) )
// RetryFunc defines the requirements for retrying a mount point operation. // RetryFunc defines the requirements for retrying a mount point operation.
type RetryFunc func(*Point) error type RetryFunc func(*Point) error
func retry(f RetryFunc, p *Point) (err error) { func mountRetry(f RetryFunc, p *Point) (err error) {
for i := 0; i < 50; i++ { err = retry.Constant(5*time.Second, retry.WithUnits(50*time.Millisecond)).Retry(func() error {
if err = f(p); err != nil { if err = f(p); err != nil {
switch err { switch err {
case unix.EBUSY: case unix.EBUSY:
time.Sleep(100 * time.Millisecond) return retry.ExpectedError(err)
continue
default: default:
return err return retry.UnexpectedError(err)
} }
} }
return nil return nil
} })
return errors.Errorf("timeout: %+v", err) return err
} }
// Point represents a Linux mount point. // Point represents a Linux mount point.
@@ -123,9 +123,9 @@ func (p *Point) Mount() (err error) {
switch { switch {
case p.Overlay: case p.Overlay:
err = retry(overlay, p) err = mountRetry(overlay, p)
default: default:
err = retry(mount, p) err = mountRetry(mount, p)
} }
if err != nil { if err != nil {
@@ -133,7 +133,7 @@ func (p *Point) Mount() (err error) {
} }
if p.Shared { if p.Shared {
if err = retry(share, p); err != nil { if err = mountRetry(share, p); err != nil {
return errors.Errorf("error sharing mount point %s: %+v", p.target, err) return errors.Errorf("error sharing mount point %s: %+v", p.target, err)
} }
} }
@@ -145,7 +145,7 @@ func (p *Point) Mount() (err error) {
// retry every 100 milliseconds over the course of 5 seconds. // retry every 100 milliseconds over the course of 5 seconds.
func (p *Point) Unmount() (err error) { func (p *Point) Unmount() (err error) {
p.target = path.Join(p.Prefix, p.target) p.target = path.Join(p.Prefix, p.target)
if err := retry(unmount, p); err != nil { if err := mountRetry(unmount, p); err != nil {
return err return err
} }

View File

@@ -16,6 +16,7 @@ import (
"github.com/talos-systems/talos/pkg/blockdevice/table" "github.com/talos-systems/talos/pkg/blockdevice/table"
"github.com/talos-systems/talos/pkg/blockdevice/table/gpt" "github.com/talos-systems/talos/pkg/blockdevice/table/gpt"
"github.com/talos-systems/talos/pkg/retry"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -129,20 +130,19 @@ func (bd *BlockDevice) RereadPartitionTable() error {
) )
// Reread the partition table. // Reread the partition table.
for i := 0; i < 50; i++ { err = retry.Constant(5*time.Second, retry.WithUnits(50*time.Millisecond)).Retry(func() error {
if _, _, ret = unix.Syscall(unix.SYS_IOCTL, bd.f.Fd(), unix.BLKRRPART, 0); ret == 0 { if _, _, ret = unix.Syscall(unix.SYS_IOCTL, bd.f.Fd(), unix.BLKRRPART, 0); ret == 0 {
return nil return nil
} }
err = errors.Errorf("re-read partition table: %v", ret)
switch ret { switch ret {
case syscall.EBUSY: case syscall.EBUSY:
time.Sleep(100 * time.Millisecond) return retry.ExpectedError(err)
continue
default: default:
return err return retry.UnexpectedError(err)
} }
})
if err != nil {
return errors.Wrap(err, "failed to re-read partition table")
} }
return err return err

View File

@@ -22,6 +22,7 @@ import (
"github.com/talos-systems/talos/pkg/blockdevice/filesystem/iso9660" "github.com/talos-systems/talos/pkg/blockdevice/filesystem/iso9660"
"github.com/talos-systems/talos/pkg/blockdevice/filesystem/vfat" "github.com/talos-systems/talos/pkg/blockdevice/filesystem/vfat"
"github.com/talos-systems/talos/pkg/blockdevice/filesystem/xfs" "github.com/talos-systems/talos/pkg/blockdevice/filesystem/xfs"
"github.com/talos-systems/talos/pkg/retry"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@@ -64,17 +65,18 @@ func FileSystem(path string) (sb filesystem.SuperBlocker, err error) {
// Sleep for up to 5s to wait for kernel to create the necessary device files. // Sleep for up to 5s to wait for kernel to create the necessary device files.
// If we dont sleep this becomes racy in that the device file does not exist // If we dont sleep this becomes racy in that the device file does not exist
// and it will fail to open. // and it will fail to open.
for i := 0; i <= 100; i++ { err = retry.Constant(5*time.Second, retry.WithUnits((50 * time.Millisecond))).Retry(func() error {
if f, err = os.OpenFile(path, os.O_RDONLY|unix.O_CLOEXEC, os.ModeDevice); err != nil { if f, err = os.OpenFile(path, os.O_RDONLY|unix.O_CLOEXEC, os.ModeDevice); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
time.Sleep(50 * time.Millisecond) return retry.ExpectedError(err)
continue
} }
return retry.UnexpectedError(err)
return nil, err
} }
break return nil
})
if err != nil {
return nil, err
} }
// nolint: errcheck // nolint: errcheck

View File

@@ -8,11 +8,13 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"math"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
"github.com/pkg/errors"
"github.com/talos-systems/talos/pkg/retry"
) )
const b64 = "base64" const b64 = "base64"
@@ -20,8 +22,6 @@ const b64 = "base64"
type downloadOptions struct { type downloadOptions struct {
Headers map[string]string Headers map[string]string
Format string Format string
Retries int
Wait float64
} }
// Option configures the download options // Option configures the download options
@@ -30,8 +30,6 @@ type Option func(*downloadOptions)
func downloadDefaults() *downloadOptions { func downloadDefaults() *downloadOptions {
return &downloadOptions{ return &downloadOptions{
Headers: make(map[string]string), Headers: make(map[string]string),
Retries: 10,
Wait: float64(64),
} }
} }
@@ -57,21 +55,6 @@ func WithHeaders(headers map[string]string) Option {
} }
} }
// WithRetries specifies how many times download is retried before failing
func WithRetries(retries int) Option {
return func(d *downloadOptions) {
d.Retries = retries
}
}
// WithMaxWait specifies the maximum amount of time to wait between download
// attempts
func WithMaxWait(wait float64) Option {
return func(d *downloadOptions) {
d.Wait = wait
}
}
// Download downloads a config. // Download downloads a config.
// nolint: gocyclo // nolint: gocyclo
func Download(endpoint string, opts ...Option) (b []byte, err error) { func Download(endpoint string, opts ...Option) (b []byte, err error) {
@@ -95,33 +78,30 @@ func Download(endpoint string, opts ...Option) (b []byte, err error) {
req.Header.Set(k, v) req.Header.Set(k, v)
} }
for attempt := 0; attempt < dlOpts.Retries; attempt++ { err = retry.Exponential(60*time.Second, retry.WithUnits(time.Second), retry.WithJitter(time.Second)).Retry(func() error {
b, err = download(req) b, err = download(req)
if err != nil { if err != nil {
log.Printf("download failed: %+v", err) return retry.ExpectedError(err)
backoff(float64(attempt), dlOpts.Wait)
continue
} }
// Only need to do something 'extra' if base64 if dlOpts.Format == b64 {
// nolint: gocritic
switch dlOpts.Format {
case b64:
var b64 []byte var b64 []byte
b64, err = base64.StdEncoding.DecodeString(string(b)) b64, err = base64.StdEncoding.DecodeString(string(b))
if err != nil { if err != nil {
return b, err return err
} }
b = b64 b = b64
} }
return b, nil return nil
})
if err != nil {
return nil, errors.Wrapf(err, "failed to download config from: %s", u.String())
} }
return nil, fmt.Errorf("failed to download config from: %s", u.String()) return b, nil
} }
// download handles the actual http request // download handles the actual http request
@@ -146,14 +126,3 @@ func download(req *http.Request) (data []byte, err error) {
return data, err return data, err
} }
// backoff is a simple exponential sleep/backoff
func backoff(attempt float64, wait float64) {
snooze := math.Pow(2, attempt)
if snooze > wait {
snooze = wait
}
log.Printf("download attempt %g failed, retrying in %g seconds", attempt, snooze)
time.Sleep(time.Duration(snooze) * time.Second)
}

View File

@@ -21,13 +21,13 @@ import (
"k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/strategicpatch" "k8s.io/apimachinery/pkg/util/strategicpatch"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
restclient "k8s.io/client-go/rest" restclient "k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/clientcmd"
"github.com/talos-systems/talos/pkg/constants" "github.com/talos-systems/talos/pkg/constants"
"github.com/talos-systems/talos/pkg/crypto/x509" "github.com/talos-systems/talos/pkg/crypto/x509"
"github.com/talos-systems/talos/pkg/retry"
) )
// Helper represents a set of helper methods for interacting with the // Helper represents a set of helper methods for interacting with the
@@ -289,17 +289,19 @@ func (h *Helper) evict(p corev1.Pod, gracePeriod int64) error {
} }
func (h *Helper) waitForPodDeleted(p *corev1.Pod) error { func (h *Helper) waitForPodDeleted(p *corev1.Pod) error {
return wait.PollImmediate(1*time.Second, 60*time.Second, func() (bool, error) { return retry.Constant(time.Minute, retry.WithUnits(3*time.Second)).Retry(func() error {
pod, err := h.client.CoreV1().Pods(p.GetNamespace()).Get(p.GetName(), metav1.GetOptions{}) pod, err := h.client.CoreV1().Pods(p.GetNamespace()).Get(p.GetName(), metav1.GetOptions{})
if apierrors.IsNotFound(err) { switch {
return true, nil case apierrors.IsNotFound(err):
} return nil
if err != nil { case err != nil:
return false, errors.Wrapf(err, "failed to get pod %s/%s", p.GetNamespace(), p.GetName()) return retry.UnexpectedError(errors.Wrapf(err, "failed to get pod %s/%s", p.GetNamespace(), p.GetName()))
} }
if pod.GetUID() != p.GetUID() { if pod.GetUID() != p.GetUID() {
return true, nil return nil
} }
return false, nil
return retry.ExpectedError(errors.New("pod is still running on the node"))
}) })
} }

57
pkg/retry/constant.go Normal file
View File

@@ -0,0 +1,57 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
package retry
import (
"time"
)
type constantRetryer struct {
retryer
}
// ConstantTicker represents a ticker with a constant algorithm.
type ConstantTicker struct {
ticker
}
// Constant initializes and returns a constant Retryer.
func Constant(duration time.Duration, setters ...Option) Retryer {
opts := NewDefaultOptions(setters...)
return constantRetryer{
retryer: retryer{
duration: duration,
options: opts,
},
}
}
// NewConstantTicker is a ticker that sends the time on a channel using a
// constant algorithm.
func NewConstantTicker(opts *Options) *ConstantTicker {
l := &ConstantTicker{
ticker: ticker{
C: make(chan time.Time, 1),
options: opts,
s: make(chan struct{}, 1),
},
}
return l
}
// Retry implements the Retryer interface.
func (c constantRetryer) Retry(f RetryableFunc) error {
tick := NewConstantTicker(c.options)
defer tick.Stop()
return retry(f, c.duration, tick)
}
// Tick implements the Ticker interface.
func (c ConstantTicker) Tick() time.Duration {
return c.options.Units + c.Jitter()
}

172
pkg/retry/constant_test.go Normal file
View File

@@ -0,0 +1,172 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
// nolint: dupl
package retry
import (
"fmt"
"testing"
"time"
)
// nolint: scopelint
func Test_constantRetryer_Retry(t *testing.T) {
type fields struct {
retryer retryer
}
type args struct {
f RetryableFunc
}
count := 0
tests := []struct {
name string
fields fields
args args
expectedCount int
wantErr bool
}{
{
name: "test expected number of retries",
fields: fields{
retryer: retryer{
duration: 2 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 2,
wantErr: true,
},
{
name: "test expected number of retries with units",
fields: fields{
retryer: retryer{
duration: 2 * time.Second,
options: NewDefaultOptions(WithUnits(500 * time.Millisecond)),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 4,
wantErr: true,
},
{
name: "test unexpected error",
fields: fields{
retryer: retryer{
duration: 2 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
return UnexpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 1,
wantErr: true,
},
{
name: "test conditional unexpected error",
fields: fields{
retryer: retryer{
duration: 10 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
if count == 2 {
return UnexpectedError(fmt.Errorf("unexpected"))
}
return ExpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 2,
wantErr: true,
},
{
name: "test conditional no error",
fields: fields{
retryer: retryer{
duration: 10 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
if count == 2 {
return nil
}
return ExpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 2,
wantErr: false,
},
{
name: "no error",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
return nil
},
},
expectedCount: 0,
wantErr: false,
},
{
name: "test timeout",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(WithUnits(10 * time.Second)),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 1,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := constantRetryer{
retryer: tt.fields.retryer,
}
count = 0
if err := e.Retry(tt.args.f); (err != nil) != tt.wantErr {
t.Errorf("constantRetryer.Retry() error = %v, wantErr %v", err, tt.wantErr)
}
if count != tt.expectedCount {
t.Errorf("expected count of %d, got %d", tt.expectedCount, count)
}
})
}
}

66
pkg/retry/exponential.go Normal file
View File

@@ -0,0 +1,66 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
package retry
import (
"math"
"time"
)
type exponentialRetryer struct {
retryer
}
// ExponentialTicker represents a ticker with a truncated exponential algorithm.
// Please see https://en.wikipedia.org/wiki/Exponential_backoff for details on
// the algorithm.
type ExponentialTicker struct {
ticker
c float64
}
// Exponential initializes and returns a truncated exponential Retryer.
func Exponential(duration time.Duration, setters ...Option) Retryer {
opts := NewDefaultOptions(setters...)
return exponentialRetryer{
retryer: retryer{
duration: duration,
options: opts,
},
}
}
// NewExponentialTicker is a ticker that sends the time on a channel using a
// truncated exponential algorithm.
func NewExponentialTicker(opts *Options) *ExponentialTicker {
e := &ExponentialTicker{
ticker: ticker{
C: make(chan time.Time, 1),
options: opts,
s: make(chan struct{}, 1),
},
c: 1.0,
}
return e
}
// Retry implements the Retryer interface.
func (e exponentialRetryer) Retry(f RetryableFunc) error {
tick := NewExponentialTicker(e.options)
defer tick.Stop()
return retry(f, e.duration, tick)
}
// Tick implements the Ticker interface.
func (e *ExponentialTicker) Tick() time.Duration {
d := time.Duration((math.Pow(2, e.c)-1)/2)*e.options.Units + e.Jitter()
e.c++
return d
}

View File

@@ -0,0 +1,172 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
// nolint: dupl
package retry
import (
"fmt"
"testing"
"time"
)
// nolint: scopelint
func Test_exponentialRetryer_Retry(t *testing.T) {
type fields struct {
retryer retryer
}
type args struct {
f RetryableFunc
}
count := 0
tests := []struct {
name string
fields fields
args args
expectedCount int
wantErr bool
}{
{
name: "test expected number of retries",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(WithUnits(100 * time.Millisecond)),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 4,
wantErr: true,
},
{
name: "test expected number of retries with units",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(WithUnits(50 * time.Millisecond)),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 5,
wantErr: true,
},
{
name: "test unexpected error",
fields: fields{
retryer: retryer{
duration: 2 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
return UnexpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 1,
wantErr: true,
},
{
name: "test conditional unexpected error",
fields: fields{
retryer: retryer{
duration: 10 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
if count == 2 {
return UnexpectedError(fmt.Errorf("unexpected"))
}
return ExpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 2,
wantErr: true,
},
{
name: "test conditional no error",
fields: fields{
retryer: retryer{
duration: 10 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
if count == 2 {
return nil
}
return ExpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 2,
wantErr: false,
},
{
name: "no error",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
return nil
},
},
expectedCount: 0,
wantErr: false,
},
{
name: "test timeout",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(WithUnits(10 * time.Second)),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 2,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := exponentialRetryer{
retryer: tt.fields.retryer,
}
count = 0
if err := e.Retry(tt.args.f); (err != nil) != tt.wantErr {
t.Errorf("exponentialRetryer.Retry() error = %v, wantErr %v", err, tt.wantErr)
}
if count != tt.expectedCount {
t.Errorf("expected count of %d, got %d", tt.expectedCount, count)
}
})
}
}

63
pkg/retry/linear.go Normal file
View File

@@ -0,0 +1,63 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
package retry
import (
"time"
)
type linearRetryer struct {
retryer
}
// LinearTicker represents a ticker with a linear algorithm.
type LinearTicker struct {
ticker
c int
}
// Linear initializes and returns a linear Retryer.
func Linear(duration time.Duration, setters ...Option) Retryer {
opts := NewDefaultOptions(setters...)
return linearRetryer{
retryer: retryer{
duration: duration,
options: opts,
},
}
}
// NewLinearTicker is a ticker that sends the time on a channel using a
// linear algorithm.
func NewLinearTicker(opts *Options) *LinearTicker {
l := &LinearTicker{
ticker: ticker{
C: make(chan time.Time, 1),
options: opts,
s: make(chan struct{}, 1),
},
c: 1,
}
return l
}
// Retry implements the Retryer interface.
func (l linearRetryer) Retry(f RetryableFunc) error {
tick := NewLinearTicker(l.options)
defer tick.Stop()
return retry(f, l.duration, tick)
}
// Tick implements the Ticker interface.
func (l *LinearTicker) Tick() time.Duration {
d := time.Duration(l.c)*l.options.Units + l.Jitter()
l.c++
return d
}

172
pkg/retry/linear_test.go Normal file
View File

@@ -0,0 +1,172 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
// nolint: dupl
package retry
import (
"fmt"
"testing"
"time"
)
// nolint: scopelint
func Test_linearRetryer_Retry(t *testing.T) {
type fields struct {
retryer retryer
}
type args struct {
f RetryableFunc
}
count := 0
tests := []struct {
name string
fields fields
args args
expectedCount int
wantErr bool
}{
{
name: "test expected number of retries",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(WithUnits(100 * time.Millisecond)),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 4,
wantErr: true,
},
{
name: "test expected number of retries with units",
fields: fields{
retryer: retryer{
duration: 2 * time.Second,
options: NewDefaultOptions(WithUnits(50 * time.Millisecond)),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 9,
wantErr: true,
},
{
name: "test unexpected error",
fields: fields{
retryer: retryer{
duration: 2 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
return UnexpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 1,
wantErr: true,
},
{
name: "test conditional unexpected error",
fields: fields{
retryer: retryer{
duration: 10 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
if count == 1 {
return UnexpectedError(fmt.Errorf("unexpected"))
}
return ExpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 1,
wantErr: true,
},
{
name: "test conditional no error",
fields: fields{
retryer: retryer{
duration: 10 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
count++
if count == 2 {
return nil
}
return ExpectedError(fmt.Errorf("unexpected"))
},
},
expectedCount: 2,
wantErr: false,
},
{
name: "no error",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(),
},
},
args: args{
f: func() error {
return nil
},
},
expectedCount: 0,
wantErr: false,
},
{
name: "test timeout",
fields: fields{
retryer: retryer{
duration: 1 * time.Second,
options: NewDefaultOptions(WithUnits(10 * time.Second)),
},
},
args: args{
f: func() error {
count++
return ExpectedError(fmt.Errorf("expected"))
},
},
expectedCount: 1,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
l := linearRetryer{
retryer: tt.fields.retryer,
}
count = 0
if err := l.Retry(tt.args.f); (err != nil) != tt.wantErr {
t.Errorf("linearRetryer.Retry() error = %v, wantErr %v", err, tt.wantErr)
}
if count != tt.expectedCount {
t.Errorf("expected count of %d, got %d", tt.expectedCount, count)
}
})
}
}

44
pkg/retry/options.go Normal file
View File

@@ -0,0 +1,44 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
package retry
import "time"
// Options is the functional options struct.
type Options struct {
Units time.Duration
Jitter time.Duration
}
// Option is the functional option func.
type Option func(*Options)
// WithUnits is a functional option for setting the units of the ticker.
func WithUnits(o time.Duration) Option {
return func(args *Options) {
args.Units = o
}
}
// WithJitter is a functional option for setting the jitter flag.
func WithJitter(o time.Duration) Option {
return func(args *Options) {
args.Jitter = o
}
}
// NewDefaultOptions initializes a Options struct with default values.
func NewDefaultOptions(setters ...Option) *Options {
opts := &Options{
Units: time.Second,
Jitter: time.Duration(0),
}
for _, setter := range setters {
setter(opts)
}
return opts
}

51
pkg/retry/options_test.go Normal file
View File

@@ -0,0 +1,51 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
package retry
import (
"reflect"
"testing"
"time"
)
// nolint: scopelint
func TestNewDefaultOptions(t *testing.T) {
type args struct {
setters []Option
}
tests := []struct {
name string
args args
want *Options
}{
{
name: "with options",
args: args{
setters: []Option{WithUnits(time.Millisecond)},
},
want: &Options{
Units: time.Millisecond,
},
},
{
name: "default",
args: args{
setters: []Option{},
},
want: &Options{
Units: time.Second,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewDefaultOptions(tt.args.setters...); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewDefaultOptions() = %v, want %v", got, tt.want)
}
})
}
}

122
pkg/retry/retry.go Normal file
View File

@@ -0,0 +1,122 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
package retry
import (
"math/rand"
"time"
)
// RetryableFunc represents a function that can be retried.
type RetryableFunc func() error
// Retryer defines the requirements for retrying a function.
type Retryer interface {
Retry(RetryableFunc) error
}
// Ticker defines the requirements for providing a clock to the retry logic.
type Ticker interface {
Tick() time.Duration
StopChan() <-chan struct{}
Stop()
}
// TimeoutError represents a timeout error.
type TimeoutError struct{}
func (TimeoutError) Error() string {
return "timeout"
}
// IsTimeout reutrns if the provided error is a timeout error.
func IsTimeout(err error) bool {
_, ok := err.(TimeoutError)
return ok
}
type expectedError struct{ error }
type unexpectedError struct{ error }
type retryer struct {
duration time.Duration
options *Options
}
type ticker struct {
C chan time.Time
options *Options
rand *rand.Rand
s chan struct{}
}
func (t ticker) Jitter() time.Duration {
if int(t.options.Jitter) == 0 {
return time.Duration(0)
}
if t.rand == nil {
t.rand = rand.New(rand.NewSource(time.Now().UnixNano()))
}
return time.Duration(t.rand.Int63n(int64(t.options.Jitter)))
}
func (t ticker) StopChan() <-chan struct{} {
return t.s
}
func (t ticker) Stop() {
t.s <- struct{}{}
}
// ExpectedError error represents an error that is expected by the retrying
// function. This error is ignored.
func ExpectedError(err error) error {
return expectedError{err}
}
// UnexpectedError error represents an error that is unexpected by the retrying
// function. This error is fatal.
func UnexpectedError(err error) error {
return unexpectedError{err}
}
func retry(f RetryableFunc, d time.Duration, t Ticker) error {
timer := time.NewTimer(d)
defer timer.Stop()
// We run the func first to avoid having to wait for the next tick.
if err := f(); err != nil {
if _, ok := err.(unexpectedError); ok {
return err
}
} else {
return nil
}
for {
select {
case <-timer.C:
return TimeoutError{}
case <-t.StopChan():
return nil
case <-time.After(t.Tick()):
}
if err := f(); err != nil {
switch err.(type) {
case expectedError:
continue
case unexpectedError:
return err
}
}
return nil
}
}