Upgrade go-limiter to fix building on 1.17 (#12358)

* Upgrade go-limiter

* Modify quota system to pass contexts to upgraded go-limiter

* One more spot

* Add context vars to unit tests

* missed one
This commit is contained in:
Scott Miller
2021-09-01 16:28:47 -05:00
committed by GitHub
parent 566767a3c7
commit b368a67595
8 changed files with 30 additions and 23 deletions

2
go.mod
View File

@@ -159,7 +159,7 @@ require (
github.com/ryanuber/go-glob v1.0.0 github.com/ryanuber/go-glob v1.0.0
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da
github.com/sasha-s/go-deadlock v0.2.0 github.com/sasha-s/go-deadlock v0.2.0
github.com/sethvargo/go-limiter v0.3.0 github.com/sethvargo/go-limiter v0.7.0
github.com/shirou/gopsutil v3.21.5+incompatible github.com/shirou/gopsutil v3.21.5+incompatible
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
github.com/tencentcloud/tencentcloud-sdk-go v3.0.171+incompatible // indirect github.com/tencentcloud/tencentcloud-sdk-go v3.0.171+incompatible // indirect

2
go.sum
View File

@@ -1148,6 +1148,8 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUt
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/sethvargo/go-limiter v0.3.0 h1:yRMc+Qs2yqw6YJp6UxrO2iUs6DOSq4zcnljbB7/rMns= github.com/sethvargo/go-limiter v0.3.0 h1:yRMc+Qs2yqw6YJp6UxrO2iUs6DOSq4zcnljbB7/rMns=
github.com/sethvargo/go-limiter v0.3.0/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU= github.com/sethvargo/go-limiter v0.3.0/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/sethvargo/go-limiter v0.7.0 h1:CSvIHUxzNBVmsopHcMmYANZMsJFFJTi9kO+Ms+EYIhM=
github.com/sethvargo/go-limiter v0.7.0/go.mod h1:C0kbSFbiriE5k2FFOe18M1YZbAR2Fiwf72uGu0CXCcU=
github.com/shirou/gopsutil v3.21.5+incompatible h1:OloQyEerMi7JUrXiNzy8wQ5XN+baemxSl12QgIzt0jc= github.com/shirou/gopsutil v3.21.5+incompatible h1:OloQyEerMi7JUrXiNzy8wQ5XN+baemxSl12QgIzt0jc=
github.com/shirou/gopsutil v3.21.5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/gopsutil v3.21.5+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 h1:pntxY8Ary0t43dCZ5dqY4YTJCObLY1kIXl0uzMv+7DE=

View File

@@ -48,7 +48,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
return return
} }
quotaResp, err := core.ApplyRateLimitQuota(&quotas.Request{ quotaResp, err := core.ApplyRateLimitQuota(r.Context(), &quotas.Request{
Type: quotas.TypeRateLimit, Type: quotas.TypeRateLimit,
Path: path, Path: path,
MountPath: strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path), MountPath: strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path),

View File

@@ -2733,7 +2733,7 @@ func (c *Core) setupQuotas(ctx context.Context, isPerfStandby bool) error {
// ApplyRateLimitQuota checks the request against all the applicable quota rules. // ApplyRateLimitQuota checks the request against all the applicable quota rules.
// If the given request's path is exempt, no rate limiting will be applied. // If the given request's path is exempt, no rate limiting will be applied.
func (c *Core) ApplyRateLimitQuota(req *quotas.Request) (quotas.Response, error) { func (c *Core) ApplyRateLimitQuota(ctx context.Context, req *quotas.Request) (quotas.Response, error) {
req.Type = quotas.TypeRateLimit req.Type = quotas.TypeRateLimit
resp := quotas.Response{ resp := quotas.Response{
@@ -2747,7 +2747,7 @@ func (c *Core) ApplyRateLimitQuota(req *quotas.Request) (quotas.Response, error)
return resp, nil return resp, nil
} }
return c.quotaManager.ApplyQuota(req) return c.quotaManager.ApplyQuota(ctx, req)
} }
return resp, nil return resp, nil

View File

@@ -168,7 +168,7 @@ type Manager struct {
// Quota represents the common properties of every quota type // Quota represents the common properties of every quota type
type Quota interface { type Quota interface {
// allow checks the if the request is allowed by the quota type implementation. // allow checks the if the request is allowed by the quota type implementation.
allow(*Request) (Response, error) allow(context.Context, *Request) (Response, error)
// quotaID is the identifier of the quota rule // quotaID is the identifier of the quota rule
quotaID() string quotaID() string
@@ -181,7 +181,7 @@ type Quota interface {
// close defines any cleanup behavior that needs to be executed when a quota // close defines any cleanup behavior that needs to be executed when a quota
// rule is deleted. // rule is deleted.
close() error close(context.Context) error
// handleRemount takes in the new mount path in the quota // handleRemount takes in the new mount path in the quota
handleRemount(string) handleRemount(string)
@@ -287,7 +287,7 @@ func (m *Manager) setQuotaLocked(ctx context.Context, qType string, quota Quota,
// If there already exists an entry in the db, remove that first. // If there already exists an entry in the db, remove that first.
if raw != nil { if raw != nil {
quota := raw.(Quota) quota := raw.(Quota)
if err := quota.close(); err != nil { if err := quota.close(ctx); err != nil {
return err return err
} }
err = txn.Delete(qType, raw) err = txn.Delete(qType, raw)
@@ -518,7 +518,7 @@ func (m *Manager) DeleteQuota(ctx context.Context, qType string, name string) er
} }
quota := raw.(Quota) quota := raw.(Quota)
if err := quota.close(); err != nil { if err := quota.close(ctx); err != nil {
return err return err
} }
@@ -541,7 +541,7 @@ func (m *Manager) DeleteQuota(ctx context.Context, qType string, name string) er
// ApplyQuota runs the request against any quota rule that is applicable to it. If // ApplyQuota runs the request against any quota rule that is applicable to it. If
// there are multiple quota rule that matches the request parameters, rule that // there are multiple quota rule that matches the request parameters, rule that
// takes precedence will be used to allow/reject the request. // takes precedence will be used to allow/reject the request.
func (m *Manager) ApplyQuota(req *Request) (Response, error) { func (m *Manager) ApplyQuota(ctx context.Context, req *Request) (Response, error) {
var resp Response var resp Response
quota, err := m.QueryQuota(req) quota, err := m.QueryQuota(req)
@@ -562,7 +562,7 @@ func (m *Manager) ApplyQuota(req *Request) (Response, error) {
return resp, nil return resp, nil
} }
return quota.allow(req) return quota.allow(ctx, req)
} }
// SetEnableRateLimitAuditLogging updates the operator preference regarding the // SetEnableRateLimitAuditLogging updates the operator preference regarding the

View File

@@ -1,6 +1,7 @@
package quotas package quotas
import ( import (
"context"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"math" "math"
@@ -264,7 +265,7 @@ func (rlq *RateLimitQuota) QuotaName() string {
// returned if the request ID or address is empty. If the path is exempt, the // returned if the request ID or address is empty. If the path is exempt, the
// quota will not be evaluated. Otherwise, the client rate limiter is retrieved // quota will not be evaluated. Otherwise, the client rate limiter is retrieved
// by address and the rate limit quota is checked against that limiter. // by address and the rate limit quota is checked against that limiter.
func (rlq *RateLimitQuota) allow(req *Request) (Response, error) { func (rlq *RateLimitQuota) allow(ctx context.Context, req *Request) (Response, error) {
resp := Response{ resp := Response{
Headers: make(map[string]string), Headers: make(map[string]string),
} }
@@ -300,7 +301,11 @@ func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
} }
} }
limit, remaining, reset, allow := rlq.store.Take(req.ClientAddress) limit, remaining, reset, allow, err := rlq.store.Take(ctx, req.ClientAddress)
if err != nil {
return resp, err
}
resp.Allowed = allow resp.Allowed = allow
resp.Headers[httplimit.HeaderRateLimitLimit] = strconv.FormatUint(limit, 10) resp.Headers[httplimit.HeaderRateLimitLimit] = strconv.FormatUint(limit, 10)
resp.Headers[httplimit.HeaderRateLimitRemaining] = strconv.FormatUint(remaining, 10) resp.Headers[httplimit.HeaderRateLimitRemaining] = strconv.FormatUint(remaining, 10)
@@ -320,13 +325,13 @@ func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
// close stops the current running client purge loop. // close stops the current running client purge loop.
// It should be called with the write lock held. // It should be called with the write lock held.
func (rlq *RateLimitQuota) close() error { func (rlq *RateLimitQuota) close(ctx context.Context) error {
if rlq.purgeBlocked { if rlq.purgeBlocked {
close(rlq.closePurgeBlockedCh) close(rlq.closePurgeBlockedCh)
} }
if rlq.store != nil { if rlq.store != nil {
return rlq.store.Close() return rlq.store.Close(ctx)
} }
return nil return nil

View File

@@ -37,7 +37,7 @@ func TestNewRateLimitQuota(t *testing.T) {
err := tc.rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()) err := tc.rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink())
require.Equal(t, tc.expectErr, err != nil, err) require.Equal(t, tc.expectErr, err != nil, err)
if err == nil { if err == nil {
require.Nil(t, tc.rlq.close()) require.Nil(t, tc.rlq.close(context.Background()))
} }
}) })
} }
@@ -46,7 +46,7 @@ func TestNewRateLimitQuota(t *testing.T) {
func TestRateLimitQuota_Close(t *testing.T) { func TestRateLimitQuota_Close(t *testing.T) {
rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, time.Second, time.Minute) rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, time.Second, time.Minute)
require.NoError(t, rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink())) require.NoError(t, rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()))
require.NoError(t, rlq.close()) require.NoError(t, rlq.close(context.Background()))
time.Sleep(time.Second) // allow enough time for purgeClientsLoop to receive on closeCh time.Sleep(time.Second) // allow enough time for purgeClientsLoop to receive on closeCh
require.False(t, rlq.getPurgeBlocked(), "expected blocked client purging to be disabled after explicit close") require.False(t, rlq.getPurgeBlocked(), "expected blocked client purging to be disabled after explicit close")
@@ -66,14 +66,14 @@ func TestRateLimitQuota_Allow(t *testing.T) {
} }
require.NoError(t, rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink())) require.NoError(t, rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()))
defer rlq.close() defer rlq.close(context.Background())
var wg sync.WaitGroup var wg sync.WaitGroup
reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) { reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) {
defer wg.Done() defer wg.Done()
resp, err := rlq.allow(&Request{ClientAddress: addr}) resp, err := rlq.allow(context.Background(), &Request{ClientAddress: addr})
if err != nil { if err != nil {
return return
} }
@@ -141,7 +141,7 @@ func TestRateLimitQuota_Allow_WithBlock(t *testing.T) {
} }
require.NoError(t, rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink())) require.NoError(t, rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()))
defer rlq.close() defer rlq.close(context.Background())
require.True(t, rlq.getPurgeBlocked()) require.True(t, rlq.getPurgeBlocked())
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -149,7 +149,7 @@ func TestRateLimitQuota_Allow_WithBlock(t *testing.T) {
reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) { reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) {
defer wg.Done() defer wg.Done()
resp, err := rlq.allow(&Request{ClientAddress: addr}) resp, err := rlq.allow(context.Background(), &Request{ClientAddress: addr})
if err != nil { if err != nil {
return return
} }
@@ -221,5 +221,5 @@ func TestRateLimitQuota_Update(t *testing.T) {
require.NoError(t, qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true)) require.NoError(t, qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true))
require.NoError(t, qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true)) require.NoError(t, qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true))
require.Nil(t, quota.close()) require.Nil(t, quota.close(context.Background()))
} }

View File

@@ -40,7 +40,7 @@ func (*entManager) Reset() error {
type LeaseCountQuota struct{} type LeaseCountQuota struct{}
func (l LeaseCountQuota) allow(request *Request) (Response, error) { func (l LeaseCountQuota) allow(_ context.Context, _ *Request) (Response, error) {
panic("implement me") panic("implement me")
} }
@@ -56,7 +56,7 @@ func (l LeaseCountQuota) initialize(logger log.Logger, sink *metricsutil.Cluster
panic("implement me") panic("implement me")
} }
func (l LeaseCountQuota) close() error { func (l LeaseCountQuota) close(_ context.Context) error {
panic("implement me") panic("implement me")
} }