Request Limiter Reload tests (#25126)

This PR introduces a new testonly endpoint for introspecting the
RequestLimiter state and uses that endpoint to verify that changes to the
request_limiter config are honored across reload.

In the future, we may choose to make the sys/internal/request-limiter/status
endpoint available in normal binaries, but this is an expedient way to expose
the status for testing without having to rush the design.
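
For reference, the handler nests its payload under a "request_limiter" key in the
response data. A minimal sketch of reading it follows; readLimiterStatus is a
hypothetical helper (not part of the PR), and both the endpoint and
vault.RequestLimiterResponse are only compiled in with the testonly tag, against an
initialized, unsealed server with a root-token client:

// Minimal sketch (hypothetical helper, not part of the PR); requires a testonly build.
//go:build testonly

package statussketch

import (
	"fmt"

	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/vault"
	"github.com/mitchellh/mapstructure"
)

func readLimiterStatus(cli *api.Client) (*vault.RequestLimiterResponse, error) {
	secret, err := cli.Logical().Read("/sys/internal/request-limiter/status")
	if err != nil {
		return nil, err
	}
	if secret == nil || secret.Data == nil {
		return nil, fmt.Errorf("empty response from status endpoint")
	}
	var status vault.RequestLimiterResponse
	// The handler returns its payload under Data["request_limiter"].
	if err := mapstructure.Decode(secret.Data["request_limiter"], &status); err != nil {
		return nil, err
	}
	return &status, nil
}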

In order to reuse as much of the existing command package utility functionality
as possible without introducing sprawling code changes, I introduced a new
server_util.go and exported some fields via accessors.
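
Concretely, those accessors let the testonly package drive the server much like the
in-package command tests do. A rough sketch of the flow the new test follows
(startAndReload is a hypothetical helper; imports match the test file below):

func startAndReload(t *testing.T, configPath string) {
	ui, srv := command.TestServerCommand(t) // exported wrapper around testServerCommand

	go func() {
		code := srv.Run([]string{"-config", configPath})
		require.Equal(t, 0, code, ui.ErrorWriter.String()+ui.OutputWriter.String())
	}()

	// startedCh and reloadedCh are unexported; the new accessors expose them.
	select {
	case <-srv.StartedCh():
	case <-time.After(5 * time.Second):
		t.Fatal("server did not start")
	}

	srv.SighupCh <- struct{}{} // trigger a config reload
	select {
	case <-srv.ReloadedCh():
	case <-time.After(5 * time.Second):
		t.Fatal("server did not reload")
	}
}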

The tests shook out a couple of bugs, including a deadlock and missing
locking around the core's limiterRegistry state.
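
For context, the registry fix below replaces a bare Unlock at the end of Disable
(which the early return could skip, leaving the mutex held) with a deferred Unlock,
and Core.GetRequestLimiter now takes the registry lock around the lookup. A
simplified shape of the fixed code:

// Disable drops its references to underlying limiters.
func (r *LimiterRegistry) Disable() {
	r.Lock()
	defer r.Unlock() // previously a trailing r.Unlock(), skipped by the early return below

	if !r.Enabled {
		return
	}
	r.Limiters = map[string]*RequestLimiter{}
	r.Enabled = false
}

func (c *Core) GetRequestLimiter(key string) *limits.RequestLimiter {
	c.limiterRegistryLock.Lock()
	defer c.limiterRegistryLock.Unlock()
	return c.limiterRegistry.GetLimiter(key)
}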
Authored by Mike Palmiotto on 2024-02-01 09:11:08 -05:00; committed by GitHub
parent eb2b905af0
commit e4a11ae7cd
12 changed files with 404 additions and 64 deletions

View File

@@ -156,7 +156,7 @@ jobs:
# testonly tagged tests need an additional tag to be included
# also running some extra tests for sanity checking with the testonly build tag
(
go list -tags=testonly ./vault/external_tests/{kv,token,*replication-perf*,*testonly*} ./vault/ | gotestsum tool ci-matrix --debug \
go list -tags=testonly ./vault/external_tests/{kv,token,*replication-perf*,*testonly*} ./command/*testonly* ./vault/ | gotestsum tool ci-matrix --debug \
--partitions "${{ inputs.total-runners }}" \
--timing-files 'test-results/go-test/*.json' > matrix.json
)

View File

@@ -0,0 +1,207 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build testonly
package command_testonly
import (
"os"
"sync"
"testing"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
)
const (
baseHCL = `
backend "inmem" { }
disable_mlock = true
listener "tcp" {
address = "127.0.0.1:8209"
tls_disable = "true"
}
api_addr = "http://127.0.0.1:8209"
`
requestLimiterDisableHCL = `
request_limiter {
disable = true
}
`
requestLimiterEnableHCL = `
request_limiter {
disable = false
}
`
)
// TestServer_ReloadRequestLimiter tests a series of reloads and state
// transitions between RequestLimiter enable and disable.
func TestServer_ReloadRequestLimiter(t *testing.T) {
t.Parallel()
enabledResponse := &vault.RequestLimiterResponse{
GlobalDisabled: false,
ListenerDisabled: false,
Limiters: map[string]*vault.LimiterStatus{
limits.WriteLimiter: {
Enabled: true,
Flags: limits.DefaultLimiterFlags[limits.WriteLimiter],
},
limits.SpecialPathLimiter: {
Enabled: true,
Flags: limits.DefaultLimiterFlags[limits.SpecialPathLimiter],
},
},
}
disabledResponse := &vault.RequestLimiterResponse{
GlobalDisabled: true,
ListenerDisabled: false,
Limiters: map[string]*vault.LimiterStatus{
limits.WriteLimiter: {
Enabled: false,
},
limits.SpecialPathLimiter: {
Enabled: false,
},
},
}
cases := []struct {
name string
configAfter string
expectedResponse *vault.RequestLimiterResponse
}{
{
"enable after default",
baseHCL + requestLimiterEnableHCL,
enabledResponse,
},
{
"enable after enable",
baseHCL + requestLimiterEnableHCL,
enabledResponse,
},
{
"disable after enable",
baseHCL + requestLimiterDisableHCL,
disabledResponse,
},
{
"default after disable",
baseHCL,
enabledResponse,
},
{
"default after default",
baseHCL,
enabledResponse,
},
{
"disable after default",
baseHCL + requestLimiterDisableHCL,
disabledResponse,
},
{
"disable after disable",
baseHCL + requestLimiterDisableHCL,
disabledResponse,
},
}
ui, srv := command.TestServerCommand(t)
f, err := os.CreateTemp(t.TempDir(), "")
require.NoErrorf(t, err, "error creating temp dir: %v", err)
_, err = f.WriteString(baseHCL)
require.NoErrorf(t, err, "cannot write temp file contents")
configPath := f.Name()
var output string
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
code := srv.Run([]string{"-config", configPath})
output = ui.ErrorWriter.String() + ui.OutputWriter.String()
require.Equal(t, 0, code, output)
}()
select {
case <-srv.StartedCh():
case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
defer func() {
srv.ShutdownCh <- struct{}{}
wg.Wait()
}()
err = f.Close()
require.NoErrorf(t, err, "unable to close temp file")
// create a client and unseal vault
cli, err := srv.Client()
require.NoError(t, err)
require.NoError(t, cli.SetAddress("http://127.0.0.1:8209"))
initResp, err := cli.Sys().Init(&api.InitRequest{SecretShares: 1, SecretThreshold: 1})
require.NoError(t, err)
_, err = cli.Sys().Unseal(initResp.Keys[0])
require.NoError(t, err)
cli.SetToken(initResp.RootToken)
output = ui.ErrorWriter.String() + ui.OutputWriter.String()
require.Contains(t, output, "Request Limiter: enabled")
verifyLimiters := func(t *testing.T, expectedResponse *vault.RequestLimiterResponse) {
t.Helper()
statusResp, err := cli.Logical().Read("/sys/internal/request-limiter/status")
require.NoError(t, err)
require.NotNil(t, statusResp)
limitersResp, ok := statusResp.Data["request_limiter"]
require.True(t, ok)
require.NotNil(t, limitersResp)
var limiters *vault.RequestLimiterResponse
err = mapstructure.Decode(limitersResp, &limiters)
require.NoError(t, err)
require.NotNil(t, limiters)
require.Equal(t, expectedResponse, limiters)
}
// Start off with default enabled
verifyLimiters(t, enabledResponse)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// Write the new contents and reload the server
f, err = os.OpenFile(configPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644)
require.NoError(t, err)
defer f.Close()
_, err = f.WriteString(tc.configAfter)
require.NoErrorf(t, err, "cannot write temp file contents")
srv.SighupCh <- struct{}{}
select {
case <-srv.ReloadedCh():
case <-time.After(5 * time.Second):
t.Fatalf("test timed out")
}
verifyLimiters(t, tc.expectedResponse)
})
}
}

View File

@@ -22,11 +22,9 @@ import (
"testing"
"time"
"github.com/hashicorp/cli"
"github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/physical"
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/seal"
@@ -97,29 +95,6 @@ cloud {
`
)
func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) {
tb.Helper()
ui := cli.NewMockUi()
return ui, &ServerCommand{
BaseCommand: &BaseCommand{
UI: ui,
},
ShutdownCh: MakeShutdownCh(),
SighupCh: MakeSighupCh(),
SigUSR2Ch: MakeSigUSR2Ch(),
PhysicalBackends: map[string]physical.Factory{
"inmem": physInmem.NewInmem,
"inmem_ha": physInmem.NewInmemHA,
},
// These prevent us from random sleep guessing...
startedCh: make(chan struct{}, 5),
reloadedCh: make(chan struct{}, 5),
licenseReloadedCh: make(chan error),
}
}
func TestServer_ReloadListener(t *testing.T) {
t.Parallel()

command/server_util.go (new file, 48 lines)
View File

@@ -0,0 +1,48 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package command
import (
"testing"
"github.com/hashicorp/cli"
"github.com/hashicorp/vault/sdk/physical"
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
)
func TestServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) {
tb.Helper()
return testServerCommand(tb)
}
func (c *ServerCommand) StartedCh() chan struct{} {
return c.startedCh
}
func (c *ServerCommand) ReloadedCh() chan struct{} {
return c.reloadedCh
}
func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) {
tb.Helper()
ui := cli.NewMockUi()
return ui, &ServerCommand{
BaseCommand: &BaseCommand{
UI: ui,
},
ShutdownCh: MakeShutdownCh(),
SighupCh: MakeSighupCh(),
SigUSR2Ch: MakeSigUSR2Ch(),
PhysicalBackends: map[string]physical.Factory{
"inmem": physInmem.NewInmem,
"inmem_ha": physInmem.NewInmemHA,
},
// These prevent us from random sleep guessing...
startedCh: make(chan struct{}, 5),
reloadedCh: make(chan struct{}, 5),
licenseReloadedCh: make(chan error),
}
}

View File

@@ -919,6 +919,7 @@ func acquireLimiterListener(core *vault.Core, rawReq *http.Request, r *logical.R
if disableRequestLimiter != nil {
disable = disableRequestLimiter.(bool)
}
r.RequestLimiterDisabled = disable
if disable {
return &limits.RequestListener{}, true
}

View File

@@ -47,6 +47,7 @@ const (
// RequestLimiter is a thin wrapper for limiter.DefaultLimiter.
type RequestLimiter struct {
*limiter.DefaultLimiter
Flags LimiterFlags
}
// Acquire consults the underlying RequestLimiter to see if a new
@@ -103,34 +104,31 @@ func concurrencyChanger(limit int) int {
return int(change)
}
var (
// DefaultWriteLimiterFlags have a less conservative MinLimit to prevent
var DefaultLimiterFlags = map[string]LimiterFlags{
// WriteLimiter default flags have a less conservative MinLimit to prevent
// over-optimizing the request latency, which would result in
// under-utilization and client starvation.
DefaultWriteLimiterFlags = LimiterFlags{
Name: WriteLimiter,
MinLimit: 100,
MaxLimit: 5000,
}
WriteLimiter: {
MinLimit: 100,
MaxLimit: 5000,
InitialLimit: 100,
},
// DefaultSpecialPathLimiterFlags have a conservative MinLimit to allow more
// aggressive concurrency throttling for CPU-bound workloads such as
// SpecialPathLimiter default flags have a conservative MinLimit to allow
// more aggressive concurrency throttling for CPU-bound workloads such as
// `pki/issue`.
DefaultSpecialPathLimiterFlags = LimiterFlags{
Name: SpecialPathLimiter,
MinLimit: 5,
MaxLimit: 5000,
}
)
SpecialPathLimiter: {
MinLimit: 5,
MaxLimit: 5000,
InitialLimit: 5,
},
}
// LimiterFlags establish some initial configuration for a new request limiter.
type LimiterFlags struct {
// Name specifies the limiter Name for registry lookup and logging.
Name string
// MinLimit defines the minimum concurrency floor to prevent over-throttling
// requests during periods of high traffic.
MinLimit int
MinLimit int `json:"min_limit,omitempty" mapstructure:"min_limit,omitempty"`
// MaxLimit defines the maximum concurrency ceiling to prevent skewing to a
// point of no return.
@@ -139,7 +137,7 @@ type LimiterFlags struct {
// high-performing specs will tolerate higher limits, while the algorithm
// will find its own steady-state concurrency well below this threshold in
// most cases.
MaxLimit int
MaxLimit int `json:"max_limit,omitempty" mapstructure:"max_limit,omitempty"`
// InitialLimit defines the starting concurrency limit prior to any
// measurements.
@@ -150,13 +148,13 @@ type LimiterFlags struct {
// rejection; however, the adaptive nature of the algorithm will prevent
// this from being a prolonged state as the allowed concurrency will
// increase during normal operation.
InitialLimit int
InitialLimit int `json:"initial_limit,omitempty" mapstructure:"initial_limit,omitempty"`
}
// NewRequestLimiter is a basic constructor for the RequestLimiter wrapper. It
// is responsible for setting up the Gradient2 Limit and instantiating a new
// wrapped DefaultLimiter.
func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter, error) {
func NewRequestLimiter(logger hclog.Logger, name string, flags LimiterFlags) (*RequestLimiter, error) {
logger.Info("setting up new request limiter",
"initialLimit", flags.InitialLimit,
"maxLimit", flags.MaxLimit,
@@ -167,7 +165,7 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter
// decisions. It gathers latency measurements and calculates an Exponential
// Moving Average to determine whether latency deviation warrants a change
// in the current concurrency limit.
lim, err := limit.NewGradient2Limit(flags.Name,
lim, err := limit.NewGradient2Limit(name,
flags.InitialLimit,
flags.MaxLimit,
flags.MinLimit,
@@ -178,7 +176,7 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter
DefaultMetricsRegistry,
)
if err != nil {
return nil, fmt.Errorf("failed to create gradient2 limit: %w", err)
return &RequestLimiter{}, fmt.Errorf("failed to create gradient2 limit: %w", err)
}
strategy := strategy.NewSimpleStrategy(flags.InitialLimit)
@@ -187,5 +185,5 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter
return &RequestLimiter{}, err
}
return &RequestLimiter{defLimiter}, nil
return &RequestLimiter{Flags: flags, DefaultLimiter: defLimiter}, nil
}

View File

@@ -67,8 +67,8 @@ func NewLimiterRegistry(logger hclog.Logger) *LimiterRegistry {
// processEnvVars consults Limiter-specific environment variables and tells the
// caller if the Limiter should be disabled. If not, it adjusts the passed-in
// limiterFlags as appropriate.
func (r *LimiterRegistry) processEnvVars(flags *LimiterFlags, envDisabled, envMin, envMax string) bool {
envFlagsLogger := r.Logger.With("name", flags.Name)
func (r *LimiterRegistry) processEnvVars(name string, flags *LimiterFlags, envDisabled, envMin, envMax string) bool {
envFlagsLogger := r.Logger.With("name", name)
if disabledRaw := os.Getenv(envDisabled); disabledRaw != "" {
disabled, err := strconv.ParseBool(disabledRaw)
if err != nil {
@@ -147,20 +147,22 @@ func (r *LimiterRegistry) Enable() {
r.Logger.Info("enabling request limiters")
r.Limiters = map[string]*RequestLimiter{}
r.Register(DefaultWriteLimiterFlags)
r.Register(DefaultSpecialPathLimiterFlags)
for name, flags := range DefaultLimiterFlags {
r.Register(name, flags)
}
r.Enabled = true
}
// Register creates a new request limiter and assigns it a slot in the
// LimiterRegistry. Locking should be done in the caller.
func (r *LimiterRegistry) Register(flags LimiterFlags) {
func (r *LimiterRegistry) Register(name string, flags LimiterFlags) {
var disabled bool
switch flags.Name {
switch name {
case WriteLimiter:
disabled = r.processEnvVars(&flags,
disabled = r.processEnvVars(name, &flags,
EnvVaultDisableWriteLimiter,
EnvVaultWriteLimiterMin,
EnvVaultWriteLimiterMax,
@@ -169,7 +171,7 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) {
return
}
case SpecialPathLimiter:
disabled = r.processEnvVars(&flags,
disabled = r.processEnvVars(name, &flags,
EnvVaultDisableSpecialPathLimiter,
EnvVaultSpecialPathLimiterMin,
EnvVaultSpecialPathLimiterMax,
@@ -178,7 +180,7 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) {
return
}
default:
r.Logger.Warn("skipping invalid limiter type", "key", flags.Name)
r.Logger.Warn("skipping invalid limiter type", "key", name)
return
}
@@ -186,18 +188,19 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) {
// equilibrium, since max might be too high.
flags.InitialLimit = flags.MinLimit
limiter, err := NewRequestLimiter(r.Logger.Named(flags.Name), flags)
limiter, err := NewRequestLimiter(r.Logger.Named(name), name, flags)
if err != nil {
r.Logger.Error("failed to register limiter", "name", flags.Name, "error", err)
r.Logger.Error("failed to register limiter", "name", name, "error", err)
return
}
r.Limiters[flags.Name] = limiter
r.Limiters[name] = limiter
}
// Disable drops its references to underlying limiters.
func (r *LimiterRegistry) Disable() {
r.Lock()
defer r.Unlock()
if !r.Enabled {
return
@@ -209,7 +212,6 @@ func (r *LimiterRegistry) Disable() {
// here and the garbage-collector should take care of the rest.
r.Limiters = map[string]*RequestLimiter{}
r.Enabled = false
r.Unlock()
}
// GetLimiter looks up a RequestLimiter by key in the LimiterRegistry.

View File

@@ -255,6 +255,9 @@ type Request struct {
// Name of the chroot namespace for the listener that the request was made against
ChrootNamespace string `json:"chroot_namespace,omitempty"`
// RequestLimiterDisabled tells whether the request context has Request Limiter applied.
RequestLimiterDisabled bool `json:"request_limiter_disabled,omitempty"`
}
// Clone returns a deep copy (almost) of the request.

View File

@@ -725,6 +725,8 @@ func (c *Core) EchoDuration() time.Duration {
}
func (c *Core) GetRequestLimiter(key string) *limits.RequestLimiter {
c.limiterRegistryLock.Lock()
defer c.limiterRegistryLock.Unlock()
return c.limiterRegistry.GetLimiter(key)
}

View File

@@ -226,6 +226,10 @@ func NewSystemBackend(core *Core, logger log.Logger, config *logical.BackendConf
b.Backend.Paths = append(b.Backend.Paths, b.experimentPaths()...)
b.Backend.Paths = append(b.Backend.Paths, b.introspectionPaths()...)
if requestLimiterRead := b.requestLimiterReadPath(); requestLimiterRead != nil {
b.Backend.Paths = append(b.Backend.Paths, b.requestLimiterReadPath())
}
if core.rawEnabled {
b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...)
}

View File

@@ -0,0 +1,12 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !testonly
package vault
import (
"github.com/hashicorp/vault/sdk/framework"
)
func (b *SystemBackend) requestLimiterReadPath() *framework.Path { return nil }

View File

@@ -0,0 +1,88 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build testonly
package vault
import (
"context"
"net/http"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
// RequestLimiterResponse is a struct for marshalling Request Limiter status responses.
type RequestLimiterResponse struct {
GlobalDisabled bool `json:"global_disabled" mapstructure:"global_disabled"`
ListenerDisabled bool `json:"listener_disabled" mapstructure:"listener_disabled"`
Limiters map[string]*LimiterStatus `json:"types" mapstructure:"types"`
}
// LimiterStatus holds the per-limiter status and flags for testing.
type LimiterStatus struct {
Enabled bool `json:"enabled" mapstructure:"enabled"`
Flags limits.LimiterFlags `json:"flags,omitempty" mapstructure:"flags,omitempty"`
}
const readRequestLimiterHelpText = `
Read the current status of the request limiter.
`
func (b *SystemBackend) requestLimiterReadPath() *framework.Path {
return &framework.Path{
Pattern: "internal/request-limiter/status$",
HelpDescription: readRequestLimiterHelpText,
HelpSynopsis: readRequestLimiterHelpText,
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
Callback: b.handleReadRequestLimiter,
DisplayAttrs: &framework.DisplayAttributes{
OperationVerb: "read",
OperationSuffix: "verbosity-level-for",
},
Responses: map[int][]framework.Response{
http.StatusOK: {{
Description: "OK",
}},
},
Summary: "Read the current status of the request limiter.",
},
},
}
}
// handleReadRequestLimiter returns the enabled Request Limiter status for this node.
func (b *SystemBackend) handleReadRequestLimiter(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
resp := &RequestLimiterResponse{
Limiters: make(map[string]*LimiterStatus),
}
b.Core.limiterRegistryLock.Lock()
registry := b.Core.limiterRegistry
b.Core.limiterRegistryLock.Unlock()
resp.GlobalDisabled = !registry.Enabled
resp.ListenerDisabled = req.RequestLimiterDisabled
enabled := !(resp.GlobalDisabled || resp.ListenerDisabled)
for name := range limits.DefaultLimiterFlags {
var flags limits.LimiterFlags
if requestLimiter := b.Core.GetRequestLimiter(name); requestLimiter != nil && enabled {
flags = requestLimiter.Flags
}
resp.Limiters[name] = &LimiterStatus{
Enabled: enabled,
Flags: flags,
}
}
return &logical.Response{
Data: map[string]interface{}{
"request_limiter": resp,
},
}, nil
}