mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 02:02:43 +00:00 
			
		
		
		
	Request Limiter Reload tests (#25126)
This PR introduces a new testonly endpoint for introspecting the RequestLimiter state. It makes use of the endpoint to verify that changes to the request_limiter config are honored across reload. In the future, we may choose to make the sys/internal/request-limiter/status endpoint available in normal binaries, but this is an expedient way to expose the status for testing without having to rush the design. In order to re-use as much of the existing command package utility funcionality as possible without introducing sprawling code changes, I introduced a new server_util.go and exported some fields via accessors. The tests shook out a couple of bugs (including a deadlock and lack of locking around the core limiterRegistry state).
This commit is contained in:
		
							
								
								
									
										2
									
								
								.github/workflows/test-go.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/test-go.yml
									
									
									
									
										vendored
									
									
								
							| @@ -156,7 +156,7 @@ jobs: | ||||
|           # testonly tagged tests need an additional tag to be included | ||||
|           # also running some extra tests for sanity checking with the testonly build tag | ||||
|           ( | ||||
|             go list -tags=testonly ./vault/external_tests/{kv,token,*replication-perf*,*testonly*} ./vault/ | gotestsum tool ci-matrix --debug \ | ||||
|             go list -tags=testonly ./vault/external_tests/{kv,token,*replication-perf*,*testonly*} ./command/*testonly* ./vault/ | gotestsum tool ci-matrix --debug \ | ||||
|               --partitions "${{ inputs.total-runners }}" \ | ||||
|               --timing-files 'test-results/go-test/*.json' > matrix.json | ||||
|           ) | ||||
|   | ||||
							
								
								
									
										207
									
								
								command/command_testonly/server_testonly_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										207
									
								
								command/command_testonly/server_testonly_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,207 @@ | ||||
| // Copyright (c) HashiCorp, Inc. | ||||
| // SPDX-License-Identifier: BUSL-1.1 | ||||
|  | ||||
| //go:build testonly | ||||
|  | ||||
| package command_testonly | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/api" | ||||
| 	"github.com/hashicorp/vault/command" | ||||
| 	"github.com/hashicorp/vault/limits" | ||||
| 	"github.com/hashicorp/vault/vault" | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	baseHCL = ` | ||||
| 		backend "inmem" { } | ||||
| 		disable_mlock = true | ||||
| 		listener "tcp" { | ||||
| 			address     = "127.0.0.1:8209" | ||||
| 			tls_disable = "true" | ||||
| 		} | ||||
| 		api_addr = "http://127.0.0.1:8209" | ||||
| 	` | ||||
| 	requestLimiterDisableHCL = ` | ||||
|   request_limiter { | ||||
| 	disable = true | ||||
|   } | ||||
| ` | ||||
| 	requestLimiterEnableHCL = ` | ||||
|   request_limiter { | ||||
| 	disable = false | ||||
|   } | ||||
| ` | ||||
| ) | ||||
|  | ||||
| // TestServer_ReloadRequestLimiter tests a series of reloads and state | ||||
| // transitions between RequestLimiter enable and disable. | ||||
| func TestServer_ReloadRequestLimiter(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	enabledResponse := &vault.RequestLimiterResponse{ | ||||
| 		GlobalDisabled:   false, | ||||
| 		ListenerDisabled: false, | ||||
| 		Limiters: map[string]*vault.LimiterStatus{ | ||||
| 			limits.WriteLimiter: { | ||||
| 				Enabled: true, | ||||
| 				Flags:   limits.DefaultLimiterFlags[limits.WriteLimiter], | ||||
| 			}, | ||||
| 			limits.SpecialPathLimiter: { | ||||
| 				Enabled: true, | ||||
| 				Flags:   limits.DefaultLimiterFlags[limits.SpecialPathLimiter], | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	disabledResponse := &vault.RequestLimiterResponse{ | ||||
| 		GlobalDisabled:   true, | ||||
| 		ListenerDisabled: false, | ||||
| 		Limiters: map[string]*vault.LimiterStatus{ | ||||
| 			limits.WriteLimiter: { | ||||
| 				Enabled: false, | ||||
| 			}, | ||||
| 			limits.SpecialPathLimiter: { | ||||
| 				Enabled: false, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	cases := []struct { | ||||
| 		name             string | ||||
| 		configAfter      string | ||||
| 		expectedResponse *vault.RequestLimiterResponse | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"enable after default", | ||||
| 			baseHCL + requestLimiterEnableHCL, | ||||
| 			enabledResponse, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"enable after enable", | ||||
| 			baseHCL + requestLimiterEnableHCL, | ||||
| 			enabledResponse, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"disable after enable", | ||||
| 			baseHCL + requestLimiterDisableHCL, | ||||
| 			disabledResponse, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default after disable", | ||||
| 			baseHCL, | ||||
| 			enabledResponse, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default after default", | ||||
| 			baseHCL, | ||||
| 			enabledResponse, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"disable after default", | ||||
| 			baseHCL + requestLimiterDisableHCL, | ||||
| 			disabledResponse, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"disable after disable", | ||||
| 			baseHCL + requestLimiterDisableHCL, | ||||
| 			disabledResponse, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	ui, srv := command.TestServerCommand(t) | ||||
|  | ||||
| 	f, err := os.CreateTemp(t.TempDir(), "") | ||||
| 	require.NoErrorf(t, err, "error creating temp dir: %v", err) | ||||
|  | ||||
| 	_, err = f.WriteString(baseHCL) | ||||
| 	require.NoErrorf(t, err, "cannot write temp file contents") | ||||
|  | ||||
| 	configPath := f.Name() | ||||
|  | ||||
| 	var output string | ||||
| 	wg := &sync.WaitGroup{} | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		code := srv.Run([]string{"-config", configPath}) | ||||
| 		output = ui.ErrorWriter.String() + ui.OutputWriter.String() | ||||
| 		require.Equal(t, 0, code, output) | ||||
| 	}() | ||||
|  | ||||
| 	select { | ||||
| 	case <-srv.StartedCh(): | ||||
| 	case <-time.After(5 * time.Second): | ||||
| 		t.Fatalf("timeout") | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		srv.ShutdownCh <- struct{}{} | ||||
| 		wg.Wait() | ||||
| 	}() | ||||
|  | ||||
| 	err = f.Close() | ||||
| 	require.NoErrorf(t, err, "unable to close temp file") | ||||
|  | ||||
| 	// create a client and unseal vault | ||||
| 	cli, err := srv.Client() | ||||
| 	require.NoError(t, err) | ||||
| 	require.NoError(t, cli.SetAddress("http://127.0.0.1:8209")) | ||||
| 	initResp, err := cli.Sys().Init(&api.InitRequest{SecretShares: 1, SecretThreshold: 1}) | ||||
| 	require.NoError(t, err) | ||||
| 	_, err = cli.Sys().Unseal(initResp.Keys[0]) | ||||
| 	require.NoError(t, err) | ||||
| 	cli.SetToken(initResp.RootToken) | ||||
|  | ||||
| 	output = ui.ErrorWriter.String() + ui.OutputWriter.String() | ||||
| 	require.Contains(t, output, "Request Limiter: enabled") | ||||
|  | ||||
| 	verifyLimiters := func(t *testing.T, expectedResponse *vault.RequestLimiterResponse) { | ||||
| 		t.Helper() | ||||
|  | ||||
| 		statusResp, err := cli.Logical().Read("/sys/internal/request-limiter/status") | ||||
| 		require.NoError(t, err) | ||||
| 		require.NotNil(t, statusResp) | ||||
|  | ||||
| 		limitersResp, ok := statusResp.Data["request_limiter"] | ||||
| 		require.True(t, ok) | ||||
| 		require.NotNil(t, limitersResp) | ||||
|  | ||||
| 		var limiters *vault.RequestLimiterResponse | ||||
| 		err = mapstructure.Decode(limitersResp, &limiters) | ||||
| 		require.NoError(t, err) | ||||
| 		require.NotNil(t, limiters) | ||||
|  | ||||
| 		require.Equal(t, expectedResponse, limiters) | ||||
| 	} | ||||
|  | ||||
| 	// Start off with default enabled | ||||
| 	verifyLimiters(t, enabledResponse) | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			// Write the new contents and reload the server | ||||
| 			f, err = os.OpenFile(configPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) | ||||
| 			require.NoError(t, err) | ||||
| 			defer f.Close() | ||||
|  | ||||
| 			_, err = f.WriteString(tc.configAfter) | ||||
| 			require.NoErrorf(t, err, "cannot write temp file contents") | ||||
|  | ||||
| 			srv.SighupCh <- struct{}{} | ||||
| 			select { | ||||
| 			case <-srv.ReloadedCh(): | ||||
| 			case <-time.After(5 * time.Second): | ||||
| 				t.Fatalf("test timed out") | ||||
| 			} | ||||
|  | ||||
| 			verifyLimiters(t, tc.expectedResponse) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| @@ -22,11 +22,9 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/cli" | ||||
| 	"github.com/hashicorp/vault/command/server" | ||||
| 	"github.com/hashicorp/vault/helper/testhelpers/corehelpers" | ||||
| 	"github.com/hashicorp/vault/internalshared/configutil" | ||||
| 	"github.com/hashicorp/vault/sdk/physical" | ||||
| 	physInmem "github.com/hashicorp/vault/sdk/physical/inmem" | ||||
| 	"github.com/hashicorp/vault/vault" | ||||
| 	"github.com/hashicorp/vault/vault/seal" | ||||
| @@ -97,29 +95,6 @@ cloud { | ||||
| ` | ||||
| ) | ||||
|  | ||||
| func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { | ||||
| 	tb.Helper() | ||||
|  | ||||
| 	ui := cli.NewMockUi() | ||||
| 	return ui, &ServerCommand{ | ||||
| 		BaseCommand: &BaseCommand{ | ||||
| 			UI: ui, | ||||
| 		}, | ||||
| 		ShutdownCh: MakeShutdownCh(), | ||||
| 		SighupCh:   MakeSighupCh(), | ||||
| 		SigUSR2Ch:  MakeSigUSR2Ch(), | ||||
| 		PhysicalBackends: map[string]physical.Factory{ | ||||
| 			"inmem":    physInmem.NewInmem, | ||||
| 			"inmem_ha": physInmem.NewInmemHA, | ||||
| 		}, | ||||
|  | ||||
| 		// These prevent us from random sleep guessing... | ||||
| 		startedCh:         make(chan struct{}, 5), | ||||
| 		reloadedCh:        make(chan struct{}, 5), | ||||
| 		licenseReloadedCh: make(chan error), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestServer_ReloadListener(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
|   | ||||
							
								
								
									
										48
									
								
								command/server_util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								command/server_util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| // Copyright (c) HashiCorp, Inc. | ||||
| // SPDX-License-Identifier: BUSL-1.1 | ||||
|  | ||||
| package command | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/hashicorp/cli" | ||||
| 	"github.com/hashicorp/vault/sdk/physical" | ||||
| 	physInmem "github.com/hashicorp/vault/sdk/physical/inmem" | ||||
| ) | ||||
|  | ||||
| func TestServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { | ||||
| 	tb.Helper() | ||||
| 	return testServerCommand(tb) | ||||
| } | ||||
|  | ||||
| func (c *ServerCommand) StartedCh() chan struct{} { | ||||
| 	return c.startedCh | ||||
| } | ||||
|  | ||||
| func (c *ServerCommand) ReloadedCh() chan struct{} { | ||||
| 	return c.reloadedCh | ||||
| } | ||||
|  | ||||
| func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { | ||||
| 	tb.Helper() | ||||
|  | ||||
| 	ui := cli.NewMockUi() | ||||
| 	return ui, &ServerCommand{ | ||||
| 		BaseCommand: &BaseCommand{ | ||||
| 			UI: ui, | ||||
| 		}, | ||||
| 		ShutdownCh: MakeShutdownCh(), | ||||
| 		SighupCh:   MakeSighupCh(), | ||||
| 		SigUSR2Ch:  MakeSigUSR2Ch(), | ||||
| 		PhysicalBackends: map[string]physical.Factory{ | ||||
| 			"inmem":    physInmem.NewInmem, | ||||
| 			"inmem_ha": physInmem.NewInmemHA, | ||||
| 		}, | ||||
|  | ||||
| 		// These prevent us from random sleep guessing... | ||||
| 		startedCh:         make(chan struct{}, 5), | ||||
| 		reloadedCh:        make(chan struct{}, 5), | ||||
| 		licenseReloadedCh: make(chan error), | ||||
| 	} | ||||
| } | ||||
| @@ -919,6 +919,7 @@ func acquireLimiterListener(core *vault.Core, rawReq *http.Request, r *logical.R | ||||
| 	if disableRequestLimiter != nil { | ||||
| 		disable = disableRequestLimiter.(bool) | ||||
| 	} | ||||
| 	r.RequestLimiterDisabled = disable | ||||
| 	if disable { | ||||
| 		return &limits.RequestListener{}, true | ||||
| 	} | ||||
|   | ||||
| @@ -47,6 +47,7 @@ const ( | ||||
| // RequestLimiter is a thin wrapper for limiter.DefaultLimiter. | ||||
| type RequestLimiter struct { | ||||
| 	*limiter.DefaultLimiter | ||||
| 	Flags LimiterFlags | ||||
| } | ||||
|  | ||||
| // Acquire consults the underlying RequestLimiter to see if a new | ||||
| @@ -103,34 +104,31 @@ func concurrencyChanger(limit int) int { | ||||
| 	return int(change) | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	// DefaultWriteLimiterFlags have a less conservative MinLimit to prevent | ||||
| var DefaultLimiterFlags = map[string]LimiterFlags{ | ||||
| 	// WriteLimiter default flags have a less conservative MinLimit to prevent | ||||
| 	// over-optimizing the request latency, which would result in | ||||
| 	// under-utilization and client starvation. | ||||
| 	DefaultWriteLimiterFlags = LimiterFlags{ | ||||
| 		Name:     WriteLimiter, | ||||
| 		MinLimit: 100, | ||||
| 		MaxLimit: 5000, | ||||
| 	} | ||||
| 	WriteLimiter: { | ||||
| 		MinLimit:     100, | ||||
| 		MaxLimit:     5000, | ||||
| 		InitialLimit: 100, | ||||
| 	}, | ||||
|  | ||||
| 	// DefaultSpecialPathLimiterFlags have a conservative MinLimit to allow more | ||||
| 	// aggressive concurrency throttling for CPU-bound workloads such as | ||||
| 	// SpecialPathLimiter default flags have a conservative MinLimit to allow | ||||
| 	// more aggressive concurrency throttling for CPU-bound workloads such as | ||||
| 	// `pki/issue`. | ||||
| 	DefaultSpecialPathLimiterFlags = LimiterFlags{ | ||||
| 		Name:     SpecialPathLimiter, | ||||
| 		MinLimit: 5, | ||||
| 		MaxLimit: 5000, | ||||
| 	} | ||||
| ) | ||||
| 	SpecialPathLimiter: { | ||||
| 		MinLimit:     5, | ||||
| 		MaxLimit:     5000, | ||||
| 		InitialLimit: 5, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| // LimiterFlags establish some initial configuration for a new request limiter. | ||||
| type LimiterFlags struct { | ||||
| 	// Name specifies the limiter Name for registry lookup and logging. | ||||
| 	Name string | ||||
|  | ||||
| 	// MinLimit defines the minimum concurrency floor to prevent over-throttling | ||||
| 	// requests during periods of high traffic. | ||||
| 	MinLimit int | ||||
| 	MinLimit int `json:"min_limit,omitempty" mapstructure:"min_limit,omitempty"` | ||||
|  | ||||
| 	// MaxLimit defines the maximum concurrency ceiling to prevent skewing to a | ||||
| 	// point of no return. | ||||
| @@ -139,7 +137,7 @@ type LimiterFlags struct { | ||||
| 	// high-performing specs will tolerate higher limits, while the algorithm | ||||
| 	// will find its own steady-state concurrency well below this threshold in | ||||
| 	// most cases. | ||||
| 	MaxLimit int | ||||
| 	MaxLimit int `json:"max_limit,omitempty" mapstructure:"max_limit,omitempty"` | ||||
|  | ||||
| 	// InitialLimit defines the starting concurrency limit prior to any | ||||
| 	// measurements. | ||||
| @@ -150,13 +148,13 @@ type LimiterFlags struct { | ||||
| 	// rejection; however, the adaptive nature of the algorithm will prevent | ||||
| 	// this from being a prolonged state as the allowed concurrency will | ||||
| 	// increase during normal operation. | ||||
| 	InitialLimit int | ||||
| 	InitialLimit int `json:"initial_limit,omitempty" mapstructure:"initial_limit,omitempty"` | ||||
| } | ||||
|  | ||||
| // NewRequestLimiter is a basic constructor for the RequestLimiter wrapper. It | ||||
| // is responsible for setting up the Gradient2 Limit and instantiating a new | ||||
| // wrapped DefaultLimiter. | ||||
| func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter, error) { | ||||
| func NewRequestLimiter(logger hclog.Logger, name string, flags LimiterFlags) (*RequestLimiter, error) { | ||||
| 	logger.Info("setting up new request limiter", | ||||
| 		"initialLimit", flags.InitialLimit, | ||||
| 		"maxLimit", flags.MaxLimit, | ||||
| @@ -167,7 +165,7 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter | ||||
| 	// decisions. It gathers latency measurements and calculates an Exponential | ||||
| 	// Moving Average to determine whether latency deviation warrants a change | ||||
| 	// in the current concurrency limit. | ||||
| 	lim, err := limit.NewGradient2Limit(flags.Name, | ||||
| 	lim, err := limit.NewGradient2Limit(name, | ||||
| 		flags.InitialLimit, | ||||
| 		flags.MaxLimit, | ||||
| 		flags.MinLimit, | ||||
| @@ -178,7 +176,7 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter | ||||
| 		DefaultMetricsRegistry, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create gradient2 limit: %w", err) | ||||
| 		return &RequestLimiter{}, fmt.Errorf("failed to create gradient2 limit: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	strategy := strategy.NewSimpleStrategy(flags.InitialLimit) | ||||
| @@ -187,5 +185,5 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter | ||||
| 		return &RequestLimiter{}, err | ||||
| 	} | ||||
|  | ||||
| 	return &RequestLimiter{defLimiter}, nil | ||||
| 	return &RequestLimiter{Flags: flags, DefaultLimiter: defLimiter}, nil | ||||
| } | ||||
|   | ||||
| @@ -67,8 +67,8 @@ func NewLimiterRegistry(logger hclog.Logger) *LimiterRegistry { | ||||
| // processEnvVars consults Limiter-specific environment variables and tells the | ||||
| // caller if the Limiter should be disabled. If not, it adjusts the passed-in | ||||
| // limiterFlags as appropriate. | ||||
| func (r *LimiterRegistry) processEnvVars(flags *LimiterFlags, envDisabled, envMin, envMax string) bool { | ||||
| 	envFlagsLogger := r.Logger.With("name", flags.Name) | ||||
| func (r *LimiterRegistry) processEnvVars(name string, flags *LimiterFlags, envDisabled, envMin, envMax string) bool { | ||||
| 	envFlagsLogger := r.Logger.With("name", name) | ||||
| 	if disabledRaw := os.Getenv(envDisabled); disabledRaw != "" { | ||||
| 		disabled, err := strconv.ParseBool(disabledRaw) | ||||
| 		if err != nil { | ||||
| @@ -147,20 +147,22 @@ func (r *LimiterRegistry) Enable() { | ||||
|  | ||||
| 	r.Logger.Info("enabling request limiters") | ||||
| 	r.Limiters = map[string]*RequestLimiter{} | ||||
| 	r.Register(DefaultWriteLimiterFlags) | ||||
| 	r.Register(DefaultSpecialPathLimiterFlags) | ||||
|  | ||||
| 	for name, flags := range DefaultLimiterFlags { | ||||
| 		r.Register(name, flags) | ||||
| 	} | ||||
|  | ||||
| 	r.Enabled = true | ||||
| } | ||||
|  | ||||
| // Register creates a new request limiter and assigns it a slot in the | ||||
| // LimiterRegistry. Locking should be done in the caller. | ||||
| func (r *LimiterRegistry) Register(flags LimiterFlags) { | ||||
| func (r *LimiterRegistry) Register(name string, flags LimiterFlags) { | ||||
| 	var disabled bool | ||||
|  | ||||
| 	switch flags.Name { | ||||
| 	switch name { | ||||
| 	case WriteLimiter: | ||||
| 		disabled = r.processEnvVars(&flags, | ||||
| 		disabled = r.processEnvVars(name, &flags, | ||||
| 			EnvVaultDisableWriteLimiter, | ||||
| 			EnvVaultWriteLimiterMin, | ||||
| 			EnvVaultWriteLimiterMax, | ||||
| @@ -169,7 +171,7 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) { | ||||
| 			return | ||||
| 		} | ||||
| 	case SpecialPathLimiter: | ||||
| 		disabled = r.processEnvVars(&flags, | ||||
| 		disabled = r.processEnvVars(name, &flags, | ||||
| 			EnvVaultDisableSpecialPathLimiter, | ||||
| 			EnvVaultSpecialPathLimiterMin, | ||||
| 			EnvVaultSpecialPathLimiterMax, | ||||
| @@ -178,7 +180,7 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) { | ||||
| 			return | ||||
| 		} | ||||
| 	default: | ||||
| 		r.Logger.Warn("skipping invalid limiter type", "key", flags.Name) | ||||
| 		r.Logger.Warn("skipping invalid limiter type", "key", name) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -186,18 +188,19 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) { | ||||
| 	// equilibrium, since max might be too high. | ||||
| 	flags.InitialLimit = flags.MinLimit | ||||
|  | ||||
| 	limiter, err := NewRequestLimiter(r.Logger.Named(flags.Name), flags) | ||||
| 	limiter, err := NewRequestLimiter(r.Logger.Named(name), name, flags) | ||||
| 	if err != nil { | ||||
| 		r.Logger.Error("failed to register limiter", "name", flags.Name, "error", err) | ||||
| 		r.Logger.Error("failed to register limiter", "name", name, "error", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.Limiters[flags.Name] = limiter | ||||
| 	r.Limiters[name] = limiter | ||||
| } | ||||
|  | ||||
| // Disable drops its references to underlying limiters. | ||||
| func (r *LimiterRegistry) Disable() { | ||||
| 	r.Lock() | ||||
| 	defer r.Unlock() | ||||
|  | ||||
| 	if !r.Enabled { | ||||
| 		return | ||||
| @@ -209,7 +212,6 @@ func (r *LimiterRegistry) Disable() { | ||||
| 	// here and the garbage-collector should take care of the rest. | ||||
| 	r.Limiters = map[string]*RequestLimiter{} | ||||
| 	r.Enabled = false | ||||
| 	r.Unlock() | ||||
| } | ||||
|  | ||||
| // GetLimiter looks up a RequestLimiter by key in the LimiterRegistry. | ||||
|   | ||||
| @@ -255,6 +255,9 @@ type Request struct { | ||||
|  | ||||
| 	// Name of the chroot namespace for the listener that the request was made against | ||||
| 	ChrootNamespace string `json:"chroot_namespace,omitempty"` | ||||
|  | ||||
| 	// RequestLimiterDisabled tells whether the request context has Request Limiter applied. | ||||
| 	RequestLimiterDisabled bool `json:"request_limiter_disabled,omitempty"` | ||||
| } | ||||
|  | ||||
| // Clone returns a deep copy (almost) of the request. | ||||
|   | ||||
| @@ -725,6 +725,8 @@ func (c *Core) EchoDuration() time.Duration { | ||||
| } | ||||
|  | ||||
| func (c *Core) GetRequestLimiter(key string) *limits.RequestLimiter { | ||||
| 	c.limiterRegistryLock.Lock() | ||||
| 	defer c.limiterRegistryLock.Unlock() | ||||
| 	return c.limiterRegistry.GetLimiter(key) | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -226,6 +226,10 @@ func NewSystemBackend(core *Core, logger log.Logger, config *logical.BackendConf | ||||
| 	b.Backend.Paths = append(b.Backend.Paths, b.experimentPaths()...) | ||||
| 	b.Backend.Paths = append(b.Backend.Paths, b.introspectionPaths()...) | ||||
|  | ||||
| 	if requestLimiterRead := b.requestLimiterReadPath(); requestLimiterRead != nil { | ||||
| 		b.Backend.Paths = append(b.Backend.Paths, b.requestLimiterReadPath()) | ||||
| 	} | ||||
|  | ||||
| 	if core.rawEnabled { | ||||
| 		b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										12
									
								
								vault/logical_system_limits.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								vault/logical_system_limits.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | ||||
| // Copyright (c) HashiCorp, Inc. | ||||
| // SPDX-License-Identifier: BUSL-1.1 | ||||
|  | ||||
| //go:build !testonly | ||||
|  | ||||
| package vault | ||||
|  | ||||
| import ( | ||||
| 	"github.com/hashicorp/vault/sdk/framework" | ||||
| ) | ||||
|  | ||||
| func (b *SystemBackend) requestLimiterReadPath() *framework.Path { return nil } | ||||
							
								
								
									
										88
									
								
								vault/logical_system_limits_testonly.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								vault/logical_system_limits_testonly.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,88 @@ | ||||
| // Copyright (c) HashiCorp, Inc. | ||||
| // SPDX-License-Identifier: BUSL-1.1 | ||||
|  | ||||
| //go:build testonly | ||||
|  | ||||
| package vault | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/limits" | ||||
| 	"github.com/hashicorp/vault/sdk/framework" | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
|  | ||||
| // RequestLimiterResponse is a struct for marshalling Request Limiter status responses. | ||||
| type RequestLimiterResponse struct { | ||||
| 	GlobalDisabled   bool                      `json:"global_disabled" mapstructure:"global_disabled"` | ||||
| 	ListenerDisabled bool                      `json:"listener_disabled" mapstructure:"listener_disabled"` | ||||
| 	Limiters         map[string]*LimiterStatus `json:"types" mapstructure:"types"` | ||||
| } | ||||
|  | ||||
| // LimiterStatus holds the per-limiter status and flags for testing. | ||||
| type LimiterStatus struct { | ||||
| 	Enabled bool                `json:"enabled" mapstructure:"enabled"` | ||||
| 	Flags   limits.LimiterFlags `json:"flags,omitempty" mapstructure:"flags,omitempty"` | ||||
| } | ||||
|  | ||||
| const readRequestLimiterHelpText = ` | ||||
| Read the current status of the request limiter. | ||||
| ` | ||||
|  | ||||
| func (b *SystemBackend) requestLimiterReadPath() *framework.Path { | ||||
| 	return &framework.Path{ | ||||
| 		Pattern:         "internal/request-limiter/status$", | ||||
| 		HelpDescription: readRequestLimiterHelpText, | ||||
| 		HelpSynopsis:    readRequestLimiterHelpText, | ||||
| 		Operations: map[logical.Operation]framework.OperationHandler{ | ||||
| 			logical.ReadOperation: &framework.PathOperation{ | ||||
| 				Callback: b.handleReadRequestLimiter, | ||||
| 				DisplayAttrs: &framework.DisplayAttributes{ | ||||
| 					OperationVerb:   "read", | ||||
| 					OperationSuffix: "verbosity-level-for", | ||||
| 				}, | ||||
| 				Responses: map[int][]framework.Response{ | ||||
| 					http.StatusOK: {{ | ||||
| 						Description: "OK", | ||||
| 					}}, | ||||
| 				}, | ||||
| 				Summary: "Read the current status of the request limiter.", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // handleReadRequestLimiter returns the enabled Request Limiter status for this node. | ||||
| func (b *SystemBackend) handleReadRequestLimiter(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||
| 	resp := &RequestLimiterResponse{ | ||||
| 		Limiters: make(map[string]*LimiterStatus), | ||||
| 	} | ||||
|  | ||||
| 	b.Core.limiterRegistryLock.Lock() | ||||
| 	registry := b.Core.limiterRegistry | ||||
| 	b.Core.limiterRegistryLock.Unlock() | ||||
|  | ||||
| 	resp.GlobalDisabled = !registry.Enabled | ||||
| 	resp.ListenerDisabled = req.RequestLimiterDisabled | ||||
| 	enabled := !(resp.GlobalDisabled || resp.ListenerDisabled) | ||||
|  | ||||
| 	for name := range limits.DefaultLimiterFlags { | ||||
| 		var flags limits.LimiterFlags | ||||
| 		if requestLimiter := b.Core.GetRequestLimiter(name); requestLimiter != nil && enabled { | ||||
| 			flags = requestLimiter.Flags | ||||
| 		} | ||||
|  | ||||
| 		resp.Limiters[name] = &LimiterStatus{ | ||||
| 			Enabled: enabled, | ||||
| 			Flags:   flags, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return &logical.Response{ | ||||
| 		Data: map[string]interface{}{ | ||||
| 			"request_limiter": resp, | ||||
| 		}, | ||||
| 	}, nil | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Mike Palmiotto
					Mike Palmiotto