mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	Move Request Limiter to enterprise (#25615)
This commit is contained in:
		| @@ -1,5 +1,5 @@ | |||||||
| ```release-note:feature | ```release-note:feature | ||||||
| **Request Limiter**: Add adaptive concurrency limits to write-based HTTP | **Request Limiter (enterprise)**: Add adaptive concurrency limits to | ||||||
| methods and special-case `pki/issue` requests to prevent overloading the Vault | write-based HTTP methods and special-case `pki/issue` requests to prevent | ||||||
| server. | overloading the Vault server. | ||||||
| ``` | ``` | ||||||
|   | |||||||
| @@ -31,3 +31,7 @@ func entCheckStorageType(coreConfig *vault.CoreConfig) bool { | |||||||
| func entGetFIPSInfoKey() string { | func entGetFIPSInfoKey() string { | ||||||
| 	return "" | 	return "" | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func entGetRequestLimiterStatus(coreConfig vault.CoreConfig) string { | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,213 +0,0 @@ | |||||||
| // Copyright (c) HashiCorp, Inc. |  | ||||||
| // SPDX-License-Identifier: BUSL-1.1 |  | ||||||
|  |  | ||||||
| //go:build testonly |  | ||||||
|  |  | ||||||
| package command_testonly |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"os" |  | ||||||
| 	"sync" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/api" |  | ||||||
| 	"github.com/hashicorp/vault/command" |  | ||||||
| 	"github.com/hashicorp/vault/limits" |  | ||||||
| 	"github.com/hashicorp/vault/vault" |  | ||||||
| 	"github.com/mitchellh/mapstructure" |  | ||||||
| 	"github.com/stretchr/testify/require" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func init() { |  | ||||||
| 	if signed := os.Getenv("VAULT_LICENSE_CI"); signed != "" { |  | ||||||
| 		os.Setenv(command.EnvVaultLicense, signed) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	baseHCL = ` |  | ||||||
| 		backend "inmem" { } |  | ||||||
| 		disable_mlock = true |  | ||||||
| 		listener "tcp" { |  | ||||||
| 			address     = "127.0.0.1:8209" |  | ||||||
| 			tls_disable = "true" |  | ||||||
| 		} |  | ||||||
| 		api_addr = "http://127.0.0.1:8209" |  | ||||||
| 	` |  | ||||||
| 	requestLimiterDisableHCL = ` |  | ||||||
|   request_limiter { |  | ||||||
| 	disable = true |  | ||||||
|   } |  | ||||||
| ` |  | ||||||
| 	requestLimiterEnableHCL = ` |  | ||||||
|   request_limiter { |  | ||||||
| 	disable = false |  | ||||||
|   } |  | ||||||
| ` |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // TestServer_ReloadRequestLimiter tests a series of reloads and state |  | ||||||
| // transitions between RequestLimiter enable and disable. |  | ||||||
| func TestServer_ReloadRequestLimiter(t *testing.T) { |  | ||||||
| 	t.Parallel() |  | ||||||
|  |  | ||||||
| 	enabledResponse := &vault.RequestLimiterResponse{ |  | ||||||
| 		GlobalDisabled:   false, |  | ||||||
| 		ListenerDisabled: false, |  | ||||||
| 		Limiters: map[string]*vault.LimiterStatus{ |  | ||||||
| 			limits.WriteLimiter: { |  | ||||||
| 				Enabled: true, |  | ||||||
| 				Flags:   limits.DefaultLimiterFlags[limits.WriteLimiter], |  | ||||||
| 			}, |  | ||||||
| 			limits.SpecialPathLimiter: { |  | ||||||
| 				Enabled: true, |  | ||||||
| 				Flags:   limits.DefaultLimiterFlags[limits.SpecialPathLimiter], |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	disabledResponse := &vault.RequestLimiterResponse{ |  | ||||||
| 		GlobalDisabled:   true, |  | ||||||
| 		ListenerDisabled: false, |  | ||||||
| 		Limiters: map[string]*vault.LimiterStatus{ |  | ||||||
| 			limits.WriteLimiter: { |  | ||||||
| 				Enabled: false, |  | ||||||
| 			}, |  | ||||||
| 			limits.SpecialPathLimiter: { |  | ||||||
| 				Enabled: false, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	cases := []struct { |  | ||||||
| 		name             string |  | ||||||
| 		configAfter      string |  | ||||||
| 		expectedResponse *vault.RequestLimiterResponse |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			"disable after default", |  | ||||||
| 			baseHCL + requestLimiterDisableHCL, |  | ||||||
| 			disabledResponse, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			"disable after disable", |  | ||||||
| 			baseHCL + requestLimiterDisableHCL, |  | ||||||
| 			disabledResponse, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			"enable after disable", |  | ||||||
| 			baseHCL + requestLimiterEnableHCL, |  | ||||||
| 			enabledResponse, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			"default after enable", |  | ||||||
| 			baseHCL, |  | ||||||
| 			disabledResponse, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			"default after default", |  | ||||||
| 			baseHCL, |  | ||||||
| 			disabledResponse, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			"enable after default", |  | ||||||
| 			baseHCL + requestLimiterEnableHCL, |  | ||||||
| 			enabledResponse, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			"enable after enable", |  | ||||||
| 			baseHCL + requestLimiterEnableHCL, |  | ||||||
| 			enabledResponse, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	ui, srv := command.TestServerCommand(t) |  | ||||||
|  |  | ||||||
| 	f, err := os.CreateTemp(t.TempDir(), "") |  | ||||||
| 	require.NoErrorf(t, err, "error creating temp dir: %v", err) |  | ||||||
|  |  | ||||||
| 	_, err = f.WriteString(baseHCL) |  | ||||||
| 	require.NoErrorf(t, err, "cannot write temp file contents") |  | ||||||
|  |  | ||||||
| 	configPath := f.Name() |  | ||||||
|  |  | ||||||
| 	var output string |  | ||||||
| 	wg := &sync.WaitGroup{} |  | ||||||
| 	wg.Add(1) |  | ||||||
| 	go func() { |  | ||||||
| 		defer wg.Done() |  | ||||||
| 		code := srv.Run([]string{"-config", configPath}) |  | ||||||
| 		output = ui.ErrorWriter.String() + ui.OutputWriter.String() |  | ||||||
| 		require.Equal(t, 0, code, output) |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	select { |  | ||||||
| 	case <-srv.StartedCh(): |  | ||||||
| 	case <-time.After(5 * time.Second): |  | ||||||
| 		t.Fatalf("timeout") |  | ||||||
| 	} |  | ||||||
| 	defer func() { |  | ||||||
| 		srv.ShutdownCh <- struct{}{} |  | ||||||
| 		wg.Wait() |  | ||||||
| 	}() |  | ||||||
|  |  | ||||||
| 	err = f.Close() |  | ||||||
| 	require.NoErrorf(t, err, "unable to close temp file") |  | ||||||
|  |  | ||||||
| 	// create a client and unseal vault |  | ||||||
| 	cli, err := srv.Client() |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	require.NoError(t, cli.SetAddress("http://127.0.0.1:8209")) |  | ||||||
| 	initResp, err := cli.Sys().Init(&api.InitRequest{SecretShares: 1, SecretThreshold: 1}) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	_, err = cli.Sys().Unseal(initResp.Keys[0]) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	cli.SetToken(initResp.RootToken) |  | ||||||
|  |  | ||||||
| 	output = ui.ErrorWriter.String() + ui.OutputWriter.String() |  | ||||||
| 	require.Contains(t, output, "Request Limiter: disabled") |  | ||||||
|  |  | ||||||
| 	verifyLimiters := func(t *testing.T, expectedResponse *vault.RequestLimiterResponse) { |  | ||||||
| 		t.Helper() |  | ||||||
|  |  | ||||||
| 		statusResp, err := cli.Logical().Read("/sys/internal/request-limiter/status") |  | ||||||
| 		require.NoError(t, err) |  | ||||||
| 		require.NotNil(t, statusResp) |  | ||||||
|  |  | ||||||
| 		limitersResp, ok := statusResp.Data["request_limiter"] |  | ||||||
| 		require.True(t, ok) |  | ||||||
| 		require.NotNil(t, limitersResp) |  | ||||||
|  |  | ||||||
| 		var limiters *vault.RequestLimiterResponse |  | ||||||
| 		err = mapstructure.Decode(limitersResp, &limiters) |  | ||||||
| 		require.NoError(t, err) |  | ||||||
| 		require.NotNil(t, limiters) |  | ||||||
|  |  | ||||||
| 		require.Equal(t, expectedResponse, limiters) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Start off with default disabled |  | ||||||
| 	verifyLimiters(t, disabledResponse) |  | ||||||
|  |  | ||||||
| 	for _, tc := range cases { |  | ||||||
| 		t.Run(tc.name, func(t *testing.T) { |  | ||||||
| 			// Write the new contents and reload the server |  | ||||||
| 			f, err = os.OpenFile(configPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) |  | ||||||
| 			require.NoError(t, err) |  | ||||||
| 			defer f.Close() |  | ||||||
|  |  | ||||||
| 			_, err = f.WriteString(tc.configAfter) |  | ||||||
| 			require.NoErrorf(t, err, "cannot write temp file contents") |  | ||||||
|  |  | ||||||
| 			srv.SighupCh <- struct{}{} |  | ||||||
| 			select { |  | ||||||
| 			case <-srv.ReloadedCh(): |  | ||||||
| 			case <-time.After(5 * time.Second): |  | ||||||
| 				t.Fatalf("test timed out") |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			verifyLimiters(t, tc.expectedResponse) |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -1437,15 +1437,15 @@ func (c *ServerCommand) Run(args []string) int { | |||||||
| 		info["HCP resource ID"] = config.HCPLinkConf.Resource.ID | 		info["HCP resource ID"] = config.HCPLinkConf.Resource.ID | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	requestLimiterStatus := entGetRequestLimiterStatus(coreConfig) | ||||||
|  | 	if requestLimiterStatus != "" { | ||||||
|  | 		infoKeys = append(infoKeys, "request_limiter") | ||||||
|  | 		info["request_limiter"] = requestLimiterStatus | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	infoKeys = append(infoKeys, "administrative namespace") | 	infoKeys = append(infoKeys, "administrative namespace") | ||||||
| 	info["administrative namespace"] = config.AdministrativeNamespacePath | 	info["administrative namespace"] = config.AdministrativeNamespacePath | ||||||
|  |  | ||||||
| 	infoKeys = append(infoKeys, "request limiter") |  | ||||||
| 	info["request limiter"] = "disabled" |  | ||||||
| 	if config.RequestLimiter != nil && !config.RequestLimiter.Disable { |  | ||||||
| 		info["request limiter"] = "enabled" |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	sort.Strings(infoKeys) | 	sort.Strings(infoKeys) | ||||||
| 	c.UI.Output("==> Vault server configuration:\n") | 	c.UI.Output("==> Vault server configuration:\n") | ||||||
|  |  | ||||||
| @@ -3118,12 +3118,6 @@ func createCoreConfig(c *ServerCommand, config *server.Config, backend physical. | |||||||
| 		AdministrativeNamespacePath:    config.AdministrativeNamespacePath, | 		AdministrativeNamespacePath:    config.AdministrativeNamespacePath, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if config.RequestLimiter != nil { |  | ||||||
| 		coreConfig.DisableRequestLimiter = config.RequestLimiter.Disable |  | ||||||
| 	} else { |  | ||||||
| 		coreConfig.DisableRequestLimiter = true |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if c.flagDev { | 	if c.flagDev { | ||||||
| 		coreConfig.EnableRaw = true | 		coreConfig.EnableRaw = true | ||||||
| 		coreConfig.EnableIntrospection = true | 		coreConfig.EnableIntrospection = true | ||||||
|   | |||||||
| @@ -613,7 +613,6 @@ func testLoadConfigFile_json(t *testing.T) { | |||||||
| 					Type:                  "tcp", | 					Type:                  "tcp", | ||||||
| 					Address:               "127.0.0.1:443", | 					Address:               "127.0.0.1:443", | ||||||
| 					CustomResponseHeaders: DefaultCustomHeaders, | 					CustomResponseHeaders: DefaultCustomHeaders, | ||||||
| 					DisableRequestLimiter: false, |  | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  |  | ||||||
| @@ -904,6 +903,7 @@ listener "unix" { | |||||||
|   redact_addresses = true |   redact_addresses = true | ||||||
|   redact_cluster_name = true |   redact_cluster_name = true | ||||||
|   redact_version = true |   redact_version = true | ||||||
|  |   disable_request_limiter = true | ||||||
| }`)) | }`)) | ||||||
|  |  | ||||||
| 	config := Config{ | 	config := Config{ | ||||||
| @@ -968,6 +968,7 @@ listener "unix" { | |||||||
| 					RedactAddresses:       false, | 					RedactAddresses:       false, | ||||||
| 					RedactClusterName:     false, | 					RedactClusterName:     false, | ||||||
| 					RedactVersion:         false, | 					RedactVersion:         false, | ||||||
|  | 					DisableRequestLimiter: true, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
|   | |||||||
| @@ -6,7 +6,6 @@ | |||||||
| package server | package server | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" |  | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/internalshared/configutil" | 	"github.com/hashicorp/vault/internalshared/configutil" | ||||||
| @@ -87,55 +86,3 @@ func TestCheckSealConfig(t *testing.T) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // TestRequestLimiterConfig verifies that the census config is correctly instantiated from HCL |  | ||||||
| func TestRequestLimiterConfig(t *testing.T) { |  | ||||||
| 	testCases := []struct { |  | ||||||
| 		name              string |  | ||||||
| 		inConfig          string |  | ||||||
| 		outErr            bool |  | ||||||
| 		outRequestLimiter *configutil.RequestLimiter |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			name:              "empty", |  | ||||||
| 			outRequestLimiter: nil, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "disabled", |  | ||||||
| 			inConfig: ` |  | ||||||
| request_limiter { |  | ||||||
| 	disable = true |  | ||||||
| }`, |  | ||||||
| 			outRequestLimiter: &configutil.RequestLimiter{Disable: true}, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			name: "invalid disable", |  | ||||||
| 			inConfig: ` |  | ||||||
| request_limiter { |  | ||||||
| 	disable = "people make mistakes" |  | ||||||
| }`, |  | ||||||
| 			outErr: true, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	for _, tc := range testCases { |  | ||||||
| 		t.Run(tc.name, func(t *testing.T) { |  | ||||||
| 			config := fmt.Sprintf(` |  | ||||||
| ui = false |  | ||||||
| storage "file" { |  | ||||||
| 	path = "/tmp/test" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| listener "tcp" { |  | ||||||
| 	address = "0.0.0.0:8200" |  | ||||||
| } |  | ||||||
| %s`, tc.inConfig) |  | ||||||
| 			gotConfig, err := ParseConfig(config, "") |  | ||||||
| 			if tc.outErr { |  | ||||||
| 				require.Error(t, err) |  | ||||||
| 			} else { |  | ||||||
| 				require.NoError(t, err) |  | ||||||
| 				require.Equal(t, tc.outRequestLimiter, gotConfig.RequestLimiter) |  | ||||||
| 			} |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							| @@ -198,7 +198,6 @@ require ( | |||||||
| 	github.com/patrickmn/go-cache v2.1.0+incompatible | 	github.com/patrickmn/go-cache v2.1.0+incompatible | ||||||
| 	github.com/pires/go-proxyproto v0.6.1 | 	github.com/pires/go-proxyproto v0.6.1 | ||||||
| 	github.com/pkg/errors v0.9.1 | 	github.com/pkg/errors v0.9.1 | ||||||
| 	github.com/platinummonkey/go-concurrency-limits v0.7.0 |  | ||||||
| 	github.com/posener/complete v1.2.3 | 	github.com/posener/complete v1.2.3 | ||||||
| 	github.com/pquerna/otp v1.2.1-0.20191009055518-468c2dd2b58d | 	github.com/pquerna/otp v1.2.1-0.20191009055518-468c2dd2b58d | ||||||
| 	github.com/prometheus/client_golang v1.14.0 | 	github.com/prometheus/client_golang v1.14.0 | ||||||
|   | |||||||
							
								
								
									
										5
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1274,7 +1274,6 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym | |||||||
| github.com/DataDog/datadog-go v2.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= | github.com/DataDog/datadog-go v2.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= | ||||||
| github.com/DataDog/datadog-go v3.2.0+incompatible h1:qSG2N4FghB1He/r2mFrWKCaL7dXCilEuNEeAn20fdD4= | github.com/DataDog/datadog-go v3.2.0+incompatible h1:qSG2N4FghB1He/r2mFrWKCaL7dXCilEuNEeAn20fdD4= | ||||||
| github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= | github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= | ||||||
| github.com/DataDog/datadog-go/v5 v5.0.2/go.mod h1:ZI9JFB4ewXbw1sBnF4sxsR2k1H3xjV+PUAOUsHvKpcU= |  | ||||||
| github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= | github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= | ||||||
| github.com/Jeffail/gabs v1.1.1 h1:V0uzR08Hj22EX8+8QMhyI9sX2hwRu+/RJhJUmnwda/E= | github.com/Jeffail/gabs v1.1.1 h1:V0uzR08Hj22EX8+8QMhyI9sX2hwRu+/RJhJUmnwda/E= | ||||||
| github.com/Jeffail/gabs v1.1.1/go.mod h1:6xMvQMK4k33lb7GUUpaAPh6nKMmemQeg5d4gn7/bOXc= | github.com/Jeffail/gabs v1.1.1/go.mod h1:6xMvQMK4k33lb7GUUpaAPh6nKMmemQeg5d4gn7/bOXc= | ||||||
| @@ -1299,7 +1298,6 @@ github.com/Microsoft/go-winio v0.4.16/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugX | |||||||
| github.com/Microsoft/go-winio v0.4.17-0.20210211115548-6eac466e5fa3/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= | github.com/Microsoft/go-winio v0.4.17-0.20210211115548-6eac466e5fa3/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= | ||||||
| github.com/Microsoft/go-winio v0.4.17-0.20210324224401-5516f17a5958/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= | github.com/Microsoft/go-winio v0.4.17-0.20210324224401-5516f17a5958/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= | ||||||
| github.com/Microsoft/go-winio v0.4.17/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= | github.com/Microsoft/go-winio v0.4.17/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= | ||||||
| github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= |  | ||||||
| github.com/Microsoft/go-winio v0.5.1/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= | github.com/Microsoft/go-winio v0.5.1/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= | ||||||
| github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= | github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= | ||||||
| github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= | github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= | ||||||
| @@ -3140,8 +3138,6 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= | |||||||
| github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | ||||||
| github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= | github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= | ||||||
| github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= | github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= | ||||||
| github.com/platinummonkey/go-concurrency-limits v0.7.0 h1:Bl9E74+67BrlRLBeryHOaFy0e1L3zD9g436/3vo6akQ= |  | ||||||
| github.com/platinummonkey/go-concurrency-limits v0.7.0/go.mod h1:Xxr6BywMVH3QyLyd0PanLnkkkmByTTPET3azMpdfmng= |  | ||||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||||
| github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= | ||||||
| github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||||
| @@ -3210,7 +3206,6 @@ github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0ua | |||||||
| github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= | github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= | ||||||
| github.com/rboyer/safeio v0.2.1 h1:05xhhdRNAdS3apYm7JRjOqngf4xruaW959jmRxGDuSU= | github.com/rboyer/safeio v0.2.1 h1:05xhhdRNAdS3apYm7JRjOqngf4xruaW959jmRxGDuSU= | ||||||
| github.com/rboyer/safeio v0.2.1/go.mod h1:Cq/cEPK+YXFn622lsQ0K4KsPZSPtaptHHEldsy7Fmig= | github.com/rboyer/safeio v0.2.1/go.mod h1:Cq/cEPK+YXFn622lsQ0K4KsPZSPtaptHHEldsy7Fmig= | ||||||
| github.com/rcrowley/go-metrics v0.0.0-20180503174638-e2704e165165/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= |  | ||||||
| github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= | github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= | ||||||
| github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= | github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= | ||||||
| github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= | ||||||
|   | |||||||
| @@ -918,35 +918,15 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) { | |||||||
| 	w.Write(retBytes) | 	w.Write(retBytes) | ||||||
| } | } | ||||||
|  |  | ||||||
| func acquireLimiterListener(core *vault.Core, rawReq *http.Request, r *logical.Request) (*limits.RequestListener, bool) { |  | ||||||
| 	var disable bool |  | ||||||
| 	disableRequestLimiter := rawReq.Context().Value(logical.CtxKeyDisableRequestLimiter{}) |  | ||||||
| 	if disableRequestLimiter != nil { |  | ||||||
| 		disable = disableRequestLimiter.(bool) |  | ||||||
| 	} |  | ||||||
| 	r.RequestLimiterDisabled = disable |  | ||||||
| 	if disable { |  | ||||||
| 		return &limits.RequestListener{}, true |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	lim := &limits.RequestLimiter{} |  | ||||||
| 	if r.PathLimited { |  | ||||||
| 		lim = core.GetRequestLimiter(limits.SpecialPathLimiter) |  | ||||||
| 	} else { |  | ||||||
| 		switch rawReq.Method { |  | ||||||
| 		case http.MethodGet, http.MethodHead, http.MethodTrace, http.MethodOptions: |  | ||||||
| 			// We're only interested in the inverse, so do nothing here. |  | ||||||
| 		default: |  | ||||||
| 			lim = core.GetRequestLimiter(limits.WriteLimiter) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return lim.Acquire(rawReq.Context()) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // request is a helper to perform a request and properly exit in the | // request is a helper to perform a request and properly exit in the | ||||||
| // case of an error. | // case of an error. | ||||||
| func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool, bool) { | func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool, bool) { | ||||||
| 	lsnr, ok := acquireLimiterListener(core, rawReq, r) | 	lim := &limits.HTTPLimiter{ | ||||||
|  | 		Method:      rawReq.Method, | ||||||
|  | 		PathLimited: r.PathLimited, | ||||||
|  | 		LookupFunc:  core.GetRequestLimiter, | ||||||
|  | 	} | ||||||
|  | 	lsnr, ok := lim.Acquire(rawReq.Context()) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		resp := &logical.Response{} | 		resp := &logical.Response{} | ||||||
| 		logical.RespondWithStatusCode(resp, r, http.StatusServiceUnavailable) | 		logical.RespondWithStatusCode(resp, r, http.StatusServiceUnavailable) | ||||||
|   | |||||||
| @@ -14,6 +14,7 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-multierror" | 	"github.com/hashicorp/go-multierror" | ||||||
| 	"github.com/hashicorp/vault/helper/namespace" | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
|  | 	"github.com/hashicorp/vault/limits" | ||||||
| 	"github.com/hashicorp/vault/sdk/logical" | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
| 	"github.com/hashicorp/vault/vault" | 	"github.com/hashicorp/vault/vault" | ||||||
| 	"github.com/hashicorp/vault/vault/quotas" | 	"github.com/hashicorp/vault/vault/quotas" | ||||||
| @@ -47,7 +48,7 @@ func wrapRequestLimiterHandler(handler http.Handler, props *vault.HandlerPropert | |||||||
| 		request := r.WithContext( | 		request := r.WithContext( | ||||||
| 			context.WithValue( | 			context.WithValue( | ||||||
| 				r.Context(), | 				r.Context(), | ||||||
| 				logical.CtxKeyDisableRequestLimiter{}, | 				limits.CtxKeyDisableRequestLimiter{}, | ||||||
| 				props.ListenerConfig.DisableRequestLimiter, | 				props.ListenerConfig.DisableRequestLimiter, | ||||||
| 			), | 			), | ||||||
| 		) | 		) | ||||||
|   | |||||||
| @@ -55,8 +55,6 @@ type SharedConfig struct { | |||||||
| 	ClusterName string `hcl:"cluster_name"` | 	ClusterName string `hcl:"cluster_name"` | ||||||
|  |  | ||||||
| 	AdministrativeNamespacePath string `hcl:"administrative_namespace_path"` | 	AdministrativeNamespacePath string `hcl:"administrative_namespace_path"` | ||||||
|  |  | ||||||
| 	RequestLimiter *RequestLimiter `hcl:"request_limiter"` |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func ParseConfig(d string) (*SharedConfig, error) { | func ParseConfig(d string) (*SharedConfig, error) { | ||||||
| @@ -158,13 +156,6 @@ func ParseConfig(d string) (*SharedConfig, error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if o := list.Filter("request_limiter"); len(o.Items) > 0 { |  | ||||||
| 		result.found("request_limiter", "RequestLimiter") |  | ||||||
| 		if err := parseRequestLimiter(&result, o); err != nil { |  | ||||||
| 			return nil, fmt.Errorf("error parsing 'request_limiter': %w", err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	entConfig := &(result.EntSharedConfig) | 	entConfig := &(result.EntSharedConfig) | ||||||
| 	if err := entConfig.ParseConfig(list); err != nil { | 	if err := entConfig.ParseConfig(list); err != nil { | ||||||
| 		return nil, fmt.Errorf("error parsing enterprise config: %w", err) | 		return nil, fmt.Errorf("error parsing enterprise config: %w", err) | ||||||
| @@ -293,13 +284,6 @@ func (c *SharedConfig) Sanitized() map[string]interface{} { | |||||||
| 		result["telemetry"] = sanitizedTelemetry | 		result["telemetry"] = sanitizedTelemetry | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if c.RequestLimiter != nil { |  | ||||||
| 		sanitizedRequestLimiter := map[string]interface{}{ |  | ||||||
| 			"disable": c.RequestLimiter.Disable, |  | ||||||
| 		} |  | ||||||
| 		result["request_limiter"] = sanitizedRequestLimiter |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return result | 	return result | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -98,10 +98,5 @@ func (c *SharedConfig) Merge(c2 *SharedConfig) *SharedConfig { | |||||||
| 		result.ClusterName = c2.ClusterName | 		result.ClusterName = c2.ClusterName | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	result.RequestLimiter = c.RequestLimiter |  | ||||||
| 	if c2.RequestLimiter != nil { |  | ||||||
| 		result.RequestLimiter = c2.RequestLimiter |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return result | 	return result | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,58 +0,0 @@ | |||||||
| // Copyright (c) HashiCorp, Inc. |  | ||||||
| // SPDX-License-Identifier: BUSL-1.1 |  | ||||||
|  |  | ||||||
| package configutil |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-multierror" |  | ||||||
| 	"github.com/hashicorp/go-secure-stdlib/parseutil" |  | ||||||
| 	"github.com/hashicorp/hcl" |  | ||||||
| 	"github.com/hashicorp/hcl/hcl/ast" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type RequestLimiter struct { |  | ||||||
| 	UnusedKeys UnusedKeyMap `hcl:",unusedKeyPositions"` |  | ||||||
|  |  | ||||||
| 	Disable    bool        `hcl:"-"` |  | ||||||
| 	DisableRaw interface{} `hcl:"disable"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (r *RequestLimiter) Validate(source string) []ConfigError { |  | ||||||
| 	return ValidateUnusedFields(r.UnusedKeys, source) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (r *RequestLimiter) GoString() string { |  | ||||||
| 	return fmt.Sprintf("*%#v", *r) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DefaultRequestLimiter = &RequestLimiter{ |  | ||||||
| 	Disable: true, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func parseRequestLimiter(result *SharedConfig, list *ast.ObjectList) error { |  | ||||||
| 	if len(list.Items) > 1 { |  | ||||||
| 		return fmt.Errorf("only one 'request_limiter' block is permitted") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	result.RequestLimiter = DefaultRequestLimiter |  | ||||||
|  |  | ||||||
| 	// Get our one item |  | ||||||
| 	item := list.Items[0] |  | ||||||
|  |  | ||||||
| 	if err := hcl.DecodeObject(&result.RequestLimiter, item.Val); err != nil { |  | ||||||
| 		return multierror.Prefix(err, "request_limiter:") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	result.RequestLimiter.Disable = true |  | ||||||
| 	if result.RequestLimiter.DisableRaw != nil { |  | ||||||
| 		var err error |  | ||||||
| 		if result.RequestLimiter.Disable, err = parseutil.ParseBool(result.RequestLimiter.DisableRaw); err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		result.RequestLimiter.DisableRaw = nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
							
								
								
									
										56
									
								
								limits/http_limiter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								limits/http_limiter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | |||||||
|  | // Copyright (c) HashiCorp, Inc. | ||||||
|  | // SPDX-License-Identifier: BUSL-1.1 | ||||||
|  |  | ||||||
|  | package limits | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | //lint:ignore ST1005 Vault is the product name | ||||||
|  | var ErrCapacity = errors.New("Vault server temporarily overloaded") | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	WriteLimiter       = "write" | ||||||
|  | 	SpecialPathLimiter = "special-path" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // HTTPLimiter is a convenience struct that we use to wrap some logical request | ||||||
|  | // context and prevent dependence on Core. | ||||||
|  | type HTTPLimiter struct { | ||||||
|  | 	Method      string | ||||||
|  | 	PathLimited bool | ||||||
|  | 	LookupFunc  func(key string) *RequestLimiter | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CtxKeyDisableRequestLimiter holds the HTTP Listener's disable config if set. | ||||||
|  | type CtxKeyDisableRequestLimiter struct{} | ||||||
|  |  | ||||||
|  | func (c CtxKeyDisableRequestLimiter) String() string { | ||||||
|  | 	return "disable_request_limiter" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Acquire checks the HTTPLimiter metadata to determine if an HTTP request | ||||||
|  | // should be limited, or simply passed through as a no-op. | ||||||
|  | func (h *HTTPLimiter) Acquire(ctx context.Context) (*RequestListener, bool) { | ||||||
|  | 	// If the limiter is disabled, return an empty wrapper so the limiter is a | ||||||
|  | 	// no-op and indicate that the request can proceed. | ||||||
|  | 	if disable := ctx.Value(CtxKeyDisableRequestLimiter{}); disable != nil && disable.(bool) { | ||||||
|  | 		return &RequestListener{}, true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	lim := &RequestLimiter{} | ||||||
|  | 	if h.PathLimited { | ||||||
|  | 		lim = h.LookupFunc(SpecialPathLimiter) | ||||||
|  | 	} else { | ||||||
|  | 		switch h.Method { | ||||||
|  | 		case http.MethodGet, http.MethodHead, http.MethodTrace, http.MethodOptions: | ||||||
|  | 			// We're only interested in the inverse, so do nothing here. | ||||||
|  | 		default: | ||||||
|  | 			lim = h.LookupFunc(WriteLimiter) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return lim.Acquire(ctx) | ||||||
|  | } | ||||||
| @@ -1,189 +1,20 @@ | |||||||
| // Copyright (c) HashiCorp, Inc. | // Copyright (c) HashiCorp, Inc. | ||||||
| // SPDX-License-Identifier: BUSL-1.1 | // SPDX-License-Identifier: BUSL-1.1 | ||||||
|  |  | ||||||
|  | //go:build !enterprise | ||||||
|  |  | ||||||
| package limits | package limits | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" |  | ||||||
| 	"fmt" |  | ||||||
| 	"math" |  | ||||||
| 	"sync/atomic" |  | ||||||
|  |  | ||||||
| 	"github.com/armon/go-metrics" |  | ||||||
| 	"github.com/hashicorp/go-hclog" |  | ||||||
| 	"github.com/platinummonkey/go-concurrency-limits/core" |  | ||||||
| 	"github.com/platinummonkey/go-concurrency-limits/limit" |  | ||||||
| 	"github.com/platinummonkey/go-concurrency-limits/limiter" |  | ||||||
| 	"github.com/platinummonkey/go-concurrency-limits/strategy" |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | type RequestLimiter struct{} | ||||||
| 	// ErrCapacity is a new error type to indicate that Vault is not accepting new |  | ||||||
| 	// requests. This should be handled by callers in request paths to return |  | ||||||
| 	// http.StatusServiceUnavailable to the client. |  | ||||||
| 	ErrCapacity = errors.New("Vault server temporarily overloaded") |  | ||||||
|  |  | ||||||
| 	// DefaultDebugLogger opts out of the go-concurrency-limits internal Debug | // Acquire is a no-op on CE | ||||||
| 	// logger, since it's rather noisy. We're generating logs of interest in | func (l *RequestLimiter) Acquire(_ctx context.Context) (*RequestListener, bool) { | ||||||
| 	// Vault. |  | ||||||
| 	DefaultDebugLogger limit.Logger = nil |  | ||||||
|  |  | ||||||
| 	// DefaultMetricsRegistry opts out of the go-concurrency-limits internal |  | ||||||
| 	// metrics because we're tracking what we care about in Vault. |  | ||||||
| 	DefaultMetricsRegistry core.MetricRegistry = core.EmptyMetricRegistryInstance |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	// Smoothing adjusts how heavily we weight newer high-latency detection. |  | ||||||
| 	// Higher values (>1) place more emphasis on recent measurements. We set |  | ||||||
| 	// this below 1 to better tolerate short-lived spikes in request rate. |  | ||||||
| 	DefaultSmoothing = .1 |  | ||||||
|  |  | ||||||
| 	// DefaultLongWindow is chosen as a minimum of 1000 samples. longWindow |  | ||||||
| 	// defines sliding window size used for the Exponential Moving Average. |  | ||||||
| 	DefaultLongWindow = 1000 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // RequestLimiter is a thin wrapper for limiter.DefaultLimiter. |  | ||||||
| type RequestLimiter struct { |  | ||||||
| 	*limiter.DefaultLimiter |  | ||||||
| 	Flags LimiterFlags |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Acquire consults the underlying RequestLimiter to see if a new |  | ||||||
| // RequestListener can be acquired. |  | ||||||
| // |  | ||||||
| // The return values are a *RequestListener, which the caller can use to perform |  | ||||||
| // latency measurements, and a bool to indicate whether or not a RequestListener |  | ||||||
| // was acquired. |  | ||||||
| // |  | ||||||
| // The returned RequestListener is short-lived and eventually garbage-collected; |  | ||||||
| // however, the RequestLimiter keeps track of in-flight concurrency using a |  | ||||||
| // token bucket implementation. The caller must release the resulting Limiter |  | ||||||
| // token by conducting a measurement. |  | ||||||
| // |  | ||||||
| // There are three return cases: |  | ||||||
| // |  | ||||||
| // 1) If Request Limiting is disabled, we return an empty RequestListener so all |  | ||||||
| // measurements are no-ops. |  | ||||||
| // |  | ||||||
| // 2) If the request limit has been exceeded, we will not acquire a |  | ||||||
| // RequestListener and instead return nil, false. No measurement is required, |  | ||||||
| // since we immediately return from callers with ErrCapacity. |  | ||||||
| // |  | ||||||
| // 3) If we have not exceeded the request limit, the caller must call one of |  | ||||||
| // OnSuccess(), OnDropped(), or OnIgnore() to return a measurement and release |  | ||||||
| // the underlying Limiter token. |  | ||||||
| func (l *RequestLimiter) Acquire(ctx context.Context) (*RequestListener, bool) { |  | ||||||
| 	// Transparently handle the case where the limiter is disabled. |  | ||||||
| 	if l == nil || l.DefaultLimiter == nil { |  | ||||||
| 	return &RequestListener{}, true | 	return &RequestListener{}, true | ||||||
| } | } | ||||||
|  |  | ||||||
| 	lsnr, ok := l.DefaultLimiter.Acquire(ctx) | // EstimatedLimit is effectively 0, since we're not limiting requests on CE. | ||||||
| 	if !ok { | func (l *RequestLimiter) EstimatedLimit() int { return 0 } | ||||||
| 		metrics.IncrCounter(([]string{"limits", "concurrency", "service_unavailable"}), 1) |  | ||||||
| 		// If the token acquisition fails, we've reached capacity and we won't |  | ||||||
| 		// get a listener, so just return nil. |  | ||||||
| 		return nil, false |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &RequestListener{ |  | ||||||
| 		DefaultListener: lsnr.(*limiter.DefaultListener), |  | ||||||
| 		released:        new(atomic.Bool), |  | ||||||
| 	}, true |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // concurrencyChanger adjusts the current allowed concurrency with an |  | ||||||
| // exponential backoff as we approach the max limit. |  | ||||||
| func concurrencyChanger(limit int) int { |  | ||||||
| 	change := math.Sqrt(float64(limit)) |  | ||||||
| 	if change < 1.0 { |  | ||||||
| 		change = 1.0 |  | ||||||
| 	} |  | ||||||
| 	return int(change) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var DefaultLimiterFlags = map[string]LimiterFlags{ |  | ||||||
| 	// WriteLimiter default flags have a less conservative MinLimit to prevent |  | ||||||
| 	// over-optimizing the request latency, which would result in |  | ||||||
| 	// under-utilization and client starvation. |  | ||||||
| 	WriteLimiter: { |  | ||||||
| 		MinLimit:     100, |  | ||||||
| 		MaxLimit:     5000, |  | ||||||
| 		InitialLimit: 100, |  | ||||||
| 	}, |  | ||||||
|  |  | ||||||
| 	// SpecialPathLimiter default flags have a conservative MinLimit to allow |  | ||||||
| 	// more aggressive concurrency throttling for CPU-bound workloads such as |  | ||||||
| 	// `pki/issue`. |  | ||||||
| 	SpecialPathLimiter: { |  | ||||||
| 		MinLimit:     5, |  | ||||||
| 		MaxLimit:     5000, |  | ||||||
| 		InitialLimit: 5, |  | ||||||
| 	}, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // LimiterFlags establish some initial configuration for a new request limiter. |  | ||||||
| type LimiterFlags struct { |  | ||||||
| 	// MinLimit defines the minimum concurrency floor to prevent over-throttling |  | ||||||
| 	// requests during periods of high traffic. |  | ||||||
| 	MinLimit int `json:"min_limit,omitempty" mapstructure:"min_limit,omitempty"` |  | ||||||
|  |  | ||||||
| 	// MaxLimit defines the maximum concurrency ceiling to prevent skewing to a |  | ||||||
| 	// point of no return. |  | ||||||
| 	// |  | ||||||
| 	// We set this to a high value (5000) with the expectation that systems with |  | ||||||
| 	// high-performing specs will tolerate higher limits, while the algorithm |  | ||||||
| 	// will find its own steady-state concurrency well below this threshold in |  | ||||||
| 	// most cases. |  | ||||||
| 	MaxLimit int `json:"max_limit,omitempty" mapstructure:"max_limit,omitempty"` |  | ||||||
|  |  | ||||||
| 	// InitialLimit defines the starting concurrency limit prior to any |  | ||||||
| 	// measurements. |  | ||||||
| 	// |  | ||||||
| 	// If we start this value off too high, Vault could become |  | ||||||
| 	// overloaded before the algorithm has a chance to adapt. Setting the value |  | ||||||
| 	// to the minimum is a safety measure which could result in early request |  | ||||||
| 	// rejection; however, the adaptive nature of the algorithm will prevent |  | ||||||
| 	// this from being a prolonged state as the allowed concurrency will |  | ||||||
| 	// increase during normal operation. |  | ||||||
| 	InitialLimit int `json:"initial_limit,omitempty" mapstructure:"initial_limit,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewRequestLimiter is a basic constructor for the RequestLimiter wrapper. It |  | ||||||
| // is responsible for setting up the Gradient2 Limit and instantiating a new |  | ||||||
| // wrapped DefaultLimiter. |  | ||||||
| func NewRequestLimiter(logger hclog.Logger, name string, flags LimiterFlags) (*RequestLimiter, error) { |  | ||||||
| 	logger.Info("setting up new request limiter", |  | ||||||
| 		"initialLimit", flags.InitialLimit, |  | ||||||
| 		"maxLimit", flags.MaxLimit, |  | ||||||
| 		"minLimit", flags.MinLimit, |  | ||||||
| 	) |  | ||||||
|  |  | ||||||
| 	// NewGradient2Limit is the algorithm which drives request limiting |  | ||||||
| 	// decisions. It gathers latency measurements and calculates an Exponential |  | ||||||
| 	// Moving Average to determine whether latency deviation warrants a change |  | ||||||
| 	// in the current concurrency limit. |  | ||||||
| 	lim, err := limit.NewGradient2Limit(name, |  | ||||||
| 		flags.InitialLimit, |  | ||||||
| 		flags.MaxLimit, |  | ||||||
| 		flags.MinLimit, |  | ||||||
| 		concurrencyChanger, |  | ||||||
| 		DefaultSmoothing, |  | ||||||
| 		DefaultLongWindow, |  | ||||||
| 		DefaultDebugLogger, |  | ||||||
| 		DefaultMetricsRegistry, |  | ||||||
| 	) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return &RequestLimiter{}, fmt.Errorf("failed to create gradient2 limit: %w", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	strategy := strategy.NewSimpleStrategy(flags.InitialLimit) |  | ||||||
| 	defLimiter, err := limiter.NewDefaultLimiter(lim, 1e9, 1e9, 10, 100, strategy, nil, DefaultMetricsRegistry) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return &RequestLimiter{}, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &RequestLimiter{Flags: flags, DefaultLimiter: defLimiter}, nil |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -1,51 +1,14 @@ | |||||||
| // Copyright (c) HashiCorp, Inc. | // Copyright (c) HashiCorp, Inc. | ||||||
| // SPDX-License-Identifier: BUSL-1.1 | // SPDX-License-Identifier: BUSL-1.1 | ||||||
|  |  | ||||||
|  | //go:build !enterprise | ||||||
|  |  | ||||||
| package limits | package limits | ||||||
|  |  | ||||||
| import ( | type RequestListener struct{} | ||||||
| 	"sync/atomic" |  | ||||||
|  |  | ||||||
| 	"github.com/armon/go-metrics" | func (l *RequestListener) OnSuccess() {} | ||||||
| 	"github.com/platinummonkey/go-concurrency-limits/limiter" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // RequestListener is a thin wrapper for limiter.DefaultLimiter to handle the | func (l *RequestListener) OnDropped() {} | ||||||
| // case where request limiting is turned off. |  | ||||||
| type RequestListener struct { |  | ||||||
| 	*limiter.DefaultListener |  | ||||||
| 	released *atomic.Bool |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // OnSuccess is called as a notification that the operation succeeded and | func (l *RequestListener) OnIgnore() {} | ||||||
| // internally measured latency should be used as an RTT sample. |  | ||||||
| func (l *RequestListener) OnSuccess() { |  | ||||||
| 	if l.DefaultListener != nil { |  | ||||||
| 		metrics.IncrCounter(([]string{"limits", "concurrency", "success"}), 1) |  | ||||||
| 		l.DefaultListener.OnSuccess() |  | ||||||
| 		l.released.Store(true) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // OnDropped is called to indicate the request failed and was dropped due to an |  | ||||||
| // internal server error. Note that this does not include ErrCapacity. |  | ||||||
| func (l *RequestListener) OnDropped() { |  | ||||||
| 	if l.DefaultListener != nil { |  | ||||||
| 		metrics.IncrCounter(([]string{"limits", "concurrency", "dropped"}), 1) |  | ||||||
| 		l.DefaultListener.OnDropped() |  | ||||||
| 		l.released.Store(true) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // OnIgnore is called to indicate the operation failed before any meaningful RTT |  | ||||||
| // measurement could be made and should be ignored to not introduce an |  | ||||||
| // artificially low RTT. It also provides an extra layer of protection against |  | ||||||
| // leaks of the underlying StrategyToken during recoverable panics in the |  | ||||||
| // request handler. We treat these as Ignored, discard the measurement, and mark |  | ||||||
| // the listener as released. |  | ||||||
| func (l *RequestListener) OnIgnore() { |  | ||||||
| 	if l.DefaultListener != nil && l.released.Load() != true { |  | ||||||
| 		metrics.IncrCounter(([]string{"limits", "concurrency", "ignored"}), 1) |  | ||||||
| 		l.DefaultListener.OnIgnore() |  | ||||||
| 		l.released.Store(true) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -1,222 +1,9 @@ | |||||||
| // Copyright (c) HashiCorp, Inc. | // Copyright (c) HashiCorp, Inc. | ||||||
| // SPDX-License-Identifier: BUSL-1.1 | // SPDX-License-Identifier: BUSL-1.1 | ||||||
|  |  | ||||||
|  | //go:build !enterprise | ||||||
|  |  | ||||||
| package limits | package limits | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"os" |  | ||||||
| 	"strconv" |  | ||||||
| 	"sync" |  | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-hclog" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	WriteLimiter         = "write" |  | ||||||
| 	SpecialPathLimiter   = "special-path" |  | ||||||
| 	LimitsBadEnvVariable = "failed to process limiter environment variable, using default" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // NOTE: Great care should be taken when setting any of these variables to avoid |  | ||||||
| // adverse affects in optimal request servicing. It is strongly advised that |  | ||||||
| // these variables not be used unless there is a very good reason. These are |  | ||||||
| // intentionally undocumented environment variables that may be removed in |  | ||||||
| // future versions of Vault. |  | ||||||
| const ( |  | ||||||
| 	// EnvVaultDisableWriteLimiter is used to turn off the |  | ||||||
| 	// RequestLimiter for write-based HTTP methods. |  | ||||||
| 	EnvVaultDisableWriteLimiter = "VAULT_DISABLE_WRITE_LIMITER" |  | ||||||
|  |  | ||||||
| 	// EnvVaultWriteLimiterMin is used to modify the minimum |  | ||||||
| 	// concurrency limit for write-based HTTP methods. |  | ||||||
| 	EnvVaultWriteLimiterMin = "VAULT_WRITE_LIMITER_MIN" |  | ||||||
|  |  | ||||||
| 	// EnvVaultWriteLimiterMax is used to modify the maximum |  | ||||||
| 	// concurrency limit for write-based HTTP methods. |  | ||||||
| 	EnvVaultWriteLimiterMax = "VAULT_WRITE_LIMITER_MAX" |  | ||||||
|  |  | ||||||
| 	// EnvVaultDisablePathBasedRequestLimiting is used to turn off the |  | ||||||
| 	// RequestLimiter for special-cased paths, specified in |  | ||||||
| 	// Backend.PathsSpecial. |  | ||||||
| 	EnvVaultDisableSpecialPathLimiter = "VAULT_DISABLE_SPECIAL_PATH_LIMITER" |  | ||||||
|  |  | ||||||
| 	// EnvVaultSpecialPathLimiterMin is used to modify the minimum |  | ||||||
| 	// concurrency limit for write-based HTTP methods. |  | ||||||
| 	EnvVaultSpecialPathLimiterMin = "VAULT_SPECIAL_PATH_LIMITER_MIN" |  | ||||||
|  |  | ||||||
| 	// EnvVaultSpecialPathLimiterMax is used to modify the maximum |  | ||||||
| 	// concurrency limit for write-based HTTP methods. |  | ||||||
| 	EnvVaultSpecialPathLimiterMax = "VAULT_SPECIAL_PATH_LIMITER_MAX" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // LimiterRegistry holds the map of RequestLimiters mapped to keys. | // LimiterRegistry holds the map of RequestLimiters mapped to keys. | ||||||
| type LimiterRegistry struct { | type LimiterRegistry struct{} | ||||||
| 	Limiters map[string]*RequestLimiter |  | ||||||
| 	Logger   hclog.Logger |  | ||||||
| 	Enabled  bool |  | ||||||
| 	sync.RWMutex |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NewLimiterRegistry is a basic LimiterRegistry constructor. |  | ||||||
| func NewLimiterRegistry(logger hclog.Logger) *LimiterRegistry { |  | ||||||
| 	return &LimiterRegistry{ |  | ||||||
| 		Limiters: make(map[string]*RequestLimiter), |  | ||||||
| 		Logger:   logger, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // processEnvVars consults Limiter-specific environment variables and tells the |  | ||||||
| // caller if the Limiter should be disabled. If not, it adjusts the passed-in |  | ||||||
| // limiterFlags as appropriate. |  | ||||||
| func (r *LimiterRegistry) processEnvVars(name string, flags *LimiterFlags, envDisabled, envMin, envMax string) bool { |  | ||||||
| 	envFlagsLogger := r.Logger.With("name", name) |  | ||||||
| 	if disabledRaw := os.Getenv(envDisabled); disabledRaw != "" { |  | ||||||
| 		disabled, err := strconv.ParseBool(disabledRaw) |  | ||||||
| 		if err != nil { |  | ||||||
| 			envFlagsLogger.Warn(LimitsBadEnvVariable, |  | ||||||
| 				"env", envDisabled, |  | ||||||
| 				"val", disabledRaw, |  | ||||||
| 				"default", false, |  | ||||||
| 				"error", err, |  | ||||||
| 			) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if disabled { |  | ||||||
| 			envFlagsLogger.Warn("limiter disabled by environment variable", "env", envDisabled, "val", disabledRaw) |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	envFlags := &LimiterFlags{} |  | ||||||
| 	if minRaw := os.Getenv(envMin); minRaw != "" { |  | ||||||
| 		min, err := strconv.Atoi(minRaw) |  | ||||||
| 		if err != nil { |  | ||||||
| 			envFlagsLogger.Warn(LimitsBadEnvVariable, |  | ||||||
| 				"env", envMin, |  | ||||||
| 				"val", minRaw, |  | ||||||
| 				"default", flags.MinLimit, |  | ||||||
| 				"error", err, |  | ||||||
| 			) |  | ||||||
| 		} else { |  | ||||||
| 			envFlags.MinLimit = min |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if maxRaw := os.Getenv(envMax); maxRaw != "" { |  | ||||||
| 		max, err := strconv.Atoi(maxRaw) |  | ||||||
| 		if err != nil { |  | ||||||
| 			envFlagsLogger.Warn(LimitsBadEnvVariable, |  | ||||||
| 				"env", envMax, |  | ||||||
| 				"val", maxRaw, |  | ||||||
| 				"default", flags.MaxLimit, |  | ||||||
| 				"error", err, |  | ||||||
| 			) |  | ||||||
| 		} else { |  | ||||||
| 			envFlags.MaxLimit = max |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	switch { |  | ||||||
| 	case envFlags.MinLimit == 0: |  | ||||||
| 		// Assume no environment variable was provided. |  | ||||||
| 	case envFlags.MinLimit > 0: |  | ||||||
| 		flags.MinLimit = envFlags.MinLimit |  | ||||||
| 	default: |  | ||||||
| 		r.Logger.Warn("min limit must be greater than zero, falling back to defaults", "minLimit", flags.MinLimit) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	switch { |  | ||||||
| 	case envFlags.MaxLimit == 0: |  | ||||||
| 		// Assume no environment variable was provided. |  | ||||||
| 	case envFlags.MaxLimit > flags.MinLimit: |  | ||||||
| 		flags.MaxLimit = envFlags.MaxLimit |  | ||||||
| 	default: |  | ||||||
| 		r.Logger.Warn("max limit must be greater than min, falling back to defaults", "maxLimit", flags.MaxLimit) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Enable sets up a new LimiterRegistry and marks it Enabled. |  | ||||||
| func (r *LimiterRegistry) Enable() { |  | ||||||
| 	r.Lock() |  | ||||||
| 	defer r.Unlock() |  | ||||||
|  |  | ||||||
| 	if r.Enabled { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	r.Logger.Info("enabling request limiters") |  | ||||||
| 	r.Limiters = map[string]*RequestLimiter{} |  | ||||||
|  |  | ||||||
| 	for name, flags := range DefaultLimiterFlags { |  | ||||||
| 		r.Register(name, flags) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	r.Enabled = true |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Register creates a new request limiter and assigns it a slot in the |  | ||||||
| // LimiterRegistry. Locking should be done in the caller. |  | ||||||
| func (r *LimiterRegistry) Register(name string, flags LimiterFlags) { |  | ||||||
| 	var disabled bool |  | ||||||
|  |  | ||||||
| 	switch name { |  | ||||||
| 	case WriteLimiter: |  | ||||||
| 		disabled = r.processEnvVars(name, &flags, |  | ||||||
| 			EnvVaultDisableWriteLimiter, |  | ||||||
| 			EnvVaultWriteLimiterMin, |  | ||||||
| 			EnvVaultWriteLimiterMax, |  | ||||||
| 		) |  | ||||||
| 		if disabled { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	case SpecialPathLimiter: |  | ||||||
| 		disabled = r.processEnvVars(name, &flags, |  | ||||||
| 			EnvVaultDisableSpecialPathLimiter, |  | ||||||
| 			EnvVaultSpecialPathLimiterMin, |  | ||||||
| 			EnvVaultSpecialPathLimiterMax, |  | ||||||
| 		) |  | ||||||
| 		if disabled { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	default: |  | ||||||
| 		r.Logger.Warn("skipping invalid limiter type", "key", name) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Always set the initial limit to min so the system can find its own |  | ||||||
| 	// equilibrium, since max might be too high. |  | ||||||
| 	flags.InitialLimit = flags.MinLimit |  | ||||||
|  |  | ||||||
| 	limiter, err := NewRequestLimiter(r.Logger.Named(name), name, flags) |  | ||||||
| 	if err != nil { |  | ||||||
| 		r.Logger.Error("failed to register limiter", "name", name, "error", err) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	r.Limiters[name] = limiter |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Disable drops its references to underlying limiters. |  | ||||||
| func (r *LimiterRegistry) Disable() { |  | ||||||
| 	r.Lock() |  | ||||||
| 	defer r.Unlock() |  | ||||||
|  |  | ||||||
| 	if !r.Enabled { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	r.Logger.Info("disabling request limiters") |  | ||||||
| 	// Any outstanding tokens will be flushed when their request completes, as |  | ||||||
| 	// they've already acquired a listener. Just drop the limiter references |  | ||||||
| 	// here and the garbage-collector should take care of the rest. |  | ||||||
| 	r.Limiters = map[string]*RequestLimiter{} |  | ||||||
| 	r.Enabled = false |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // GetLimiter looks up a RequestLimiter by key in the LimiterRegistry. |  | ||||||
| func (r *LimiterRegistry) GetLimiter(key string) *RequestLimiter { |  | ||||||
| 	r.RLock() |  | ||||||
| 	defer r.RUnlock() |  | ||||||
| 	return r.Limiters[key] |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -546,9 +546,3 @@ func ContextOriginalBodyValue(ctx context.Context) (io.ReadCloser, bool) { | |||||||
| func CreateContextOriginalBody(parent context.Context, body io.ReadCloser) context.Context { | func CreateContextOriginalBody(parent context.Context, body io.ReadCloser) context.Context { | ||||||
| 	return context.WithValue(parent, ctxKeyOriginalBody{}, body) | 	return context.WithValue(parent, ctxKeyOriginalBody{}, body) | ||||||
| } | } | ||||||
|  |  | ||||||
| type CtxKeyDisableRequestLimiter struct{} |  | ||||||
|  |  | ||||||
| func (c CtxKeyDisableRequestLimiter) String() string { |  | ||||||
| 	return "disable_request_limiter" |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -49,7 +49,6 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/helper/metricsutil" | 	"github.com/hashicorp/vault/helper/metricsutil" | ||||||
| 	"github.com/hashicorp/vault/helper/namespace" | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
| 	"github.com/hashicorp/vault/helper/osutil" | 	"github.com/hashicorp/vault/helper/osutil" | ||||||
| 	"github.com/hashicorp/vault/limits" |  | ||||||
| 	"github.com/hashicorp/vault/physical/raft" | 	"github.com/hashicorp/vault/physical/raft" | ||||||
| 	"github.com/hashicorp/vault/plugins/event" | 	"github.com/hashicorp/vault/plugins/event" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/certutil" | 	"github.com/hashicorp/vault/sdk/helper/certutil" | ||||||
| @@ -715,9 +714,6 @@ type Core struct { | |||||||
| 	periodicLeaderRefreshInterval time.Duration | 	periodicLeaderRefreshInterval time.Duration | ||||||
|  |  | ||||||
| 	clusterAddrBridge *raft.ClusterAddrBridge | 	clusterAddrBridge *raft.ClusterAddrBridge | ||||||
|  |  | ||||||
| 	limiterRegistry     *limits.LimiterRegistry |  | ||||||
| 	limiterRegistryLock sync.Mutex |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Core) ActiveNodeClockSkewMillis() int64 { | func (c *Core) ActiveNodeClockSkewMillis() int64 { | ||||||
| @@ -728,12 +724,6 @@ func (c *Core) EchoDuration() time.Duration { | |||||||
| 	return c.echoDuration.Load() | 	return c.echoDuration.Load() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Core) GetRequestLimiter(key string) *limits.RequestLimiter { |  | ||||||
| 	c.limiterRegistryLock.Lock() |  | ||||||
| 	defer c.limiterRegistryLock.Unlock() |  | ||||||
| 	return c.limiterRegistry.GetLimiter(key) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // c.stateLock needs to be held in read mode before calling this function. | // c.stateLock needs to be held in read mode before calling this function. | ||||||
| func (c *Core) HAState() consts.HAState { | func (c *Core) HAState() consts.HAState { | ||||||
| 	switch { | 	switch { | ||||||
| @@ -902,9 +892,6 @@ type CoreConfig struct { | |||||||
| 	PeriodicLeaderRefreshInterval time.Duration | 	PeriodicLeaderRefreshInterval time.Duration | ||||||
|  |  | ||||||
| 	ClusterAddrBridge *raft.ClusterAddrBridge | 	ClusterAddrBridge *raft.ClusterAddrBridge | ||||||
|  |  | ||||||
| 	DisableRequestLimiter bool |  | ||||||
| 	LimiterRegistry       *limits.LimiterRegistry |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // GetServiceRegistration returns the config's ServiceRegistration, or nil if it does | // GetServiceRegistration returns the config's ServiceRegistration, or nil if it does | ||||||
| @@ -1007,10 +994,6 @@ func CreateCore(conf *CoreConfig) (*Core, error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if conf.LimiterRegistry == nil { |  | ||||||
| 		conf.LimiterRegistry = limits.NewLimiterRegistry(conf.Logger.Named("limits")) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Use imported logging deadlock if requested | 	// Use imported logging deadlock if requested | ||||||
| 	var stateLock locking.RWMutex | 	var stateLock locking.RWMutex | ||||||
| 	stateLock = &locking.SyncRWMutex{} | 	stateLock = &locking.SyncRWMutex{} | ||||||
| @@ -1315,14 +1298,6 @@ func NewCore(conf *CoreConfig) (*Core, error) { | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c.limiterRegistry = conf.LimiterRegistry |  | ||||||
| 	c.limiterRegistryLock.Lock() |  | ||||||
| 	c.limiterRegistry.Disable() |  | ||||||
| 	if !conf.DisableRequestLimiter { |  | ||||||
| 		c.limiterRegistry.Enable() |  | ||||||
| 	} |  | ||||||
| 	c.limiterRegistryLock.Unlock() |  | ||||||
|  |  | ||||||
| 	err = c.adjustForSealMigration(conf.UnwrapSeal) | 	err = c.adjustForSealMigration(conf.UnwrapSeal) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -4109,27 +4084,6 @@ func (c *Core) ReloadLogRequestsLevel() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Core) ReloadRequestLimiter() { |  | ||||||
| 	c.limiterRegistry.Logger.Info("reloading request limiter config") |  | ||||||
| 	conf := c.rawConfig.Load() |  | ||||||
| 	if conf == nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	disable := true |  | ||||||
| 	requestLimiterConfig := conf.(*server.Config).RequestLimiter |  | ||||||
| 	if requestLimiterConfig != nil { |  | ||||||
| 		disable = requestLimiterConfig.Disable |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	switch disable { |  | ||||||
| 	case true: |  | ||||||
| 		c.limiterRegistry.Disable() |  | ||||||
| 	default: |  | ||||||
| 		c.limiterRegistry.Enable() |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (c *Core) ReloadIntrospectionEndpointEnabled() { | func (c *Core) ReloadIntrospectionEndpointEnabled() { | ||||||
| 	conf := c.rawConfig.Load() | 	conf := c.rawConfig.Load() | ||||||
| 	if conf == nil { | 	if conf == nil { | ||||||
|   | |||||||
| @@ -11,6 +11,7 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-hclog" | 	"github.com/hashicorp/go-hclog" | ||||||
| 	"github.com/hashicorp/vault/helper/namespace" | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
|  | 	"github.com/hashicorp/vault/limits" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/license" | 	"github.com/hashicorp/vault/sdk/helper/license" | ||||||
| 	"github.com/hashicorp/vault/sdk/logical" | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
| 	"github.com/hashicorp/vault/sdk/physical" | 	"github.com/hashicorp/vault/sdk/physical" | ||||||
| @@ -213,3 +214,11 @@ func DiagnoseCheckLicense(ctx context.Context, vaultCore *Core, coreConfig CoreC | |||||||
| func createCustomMessageManager(storage logical.Storage, _ *Core) CustomMessagesManager { | func createCustomMessageManager(storage logical.Storage, _ *Core) CustomMessagesManager { | ||||||
| 	return uicustommessages.NewManager(storage) | 	return uicustommessages.NewManager(storage) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // GetRequestLimiter is a stub for CE. The caller will handle the nil case as a no-op. | ||||||
|  | func (c *Core) GetRequestLimiter(key string) *limits.RequestLimiter { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ReloadRequestLimiter is a no-op on CE. | ||||||
|  | func (c *Core) ReloadRequestLimiter() {} | ||||||
|   | |||||||
| @@ -226,10 +226,6 @@ func NewSystemBackend(core *Core, logger log.Logger, config *logical.BackendConf | |||||||
| 	b.Backend.Paths = append(b.Backend.Paths, b.experimentPaths()...) | 	b.Backend.Paths = append(b.Backend.Paths, b.experimentPaths()...) | ||||||
| 	b.Backend.Paths = append(b.Backend.Paths, b.introspectionPaths()...) | 	b.Backend.Paths = append(b.Backend.Paths, b.introspectionPaths()...) | ||||||
|  |  | ||||||
| 	if requestLimiterRead := b.requestLimiterReadPath(); requestLimiterRead != nil { |  | ||||||
| 		b.Backend.Paths = append(b.Backend.Paths, b.requestLimiterReadPath()) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if core.rawEnabled { | 	if core.rawEnabled { | ||||||
| 		b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) | 		b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -1,88 +0,0 @@ | |||||||
| // Copyright (c) HashiCorp, Inc. |  | ||||||
| // SPDX-License-Identifier: BUSL-1.1 |  | ||||||
|  |  | ||||||
| //go:build testonly |  | ||||||
|  |  | ||||||
| package vault |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"net/http" |  | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/limits" |  | ||||||
| 	"github.com/hashicorp/vault/sdk/framework" |  | ||||||
| 	"github.com/hashicorp/vault/sdk/logical" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // RequestLimiterResponse is a struct for marshalling Request Limiter status responses. |  | ||||||
| type RequestLimiterResponse struct { |  | ||||||
| 	GlobalDisabled   bool                      `json:"global_disabled" mapstructure:"global_disabled"` |  | ||||||
| 	ListenerDisabled bool                      `json:"listener_disabled" mapstructure:"listener_disabled"` |  | ||||||
| 	Limiters         map[string]*LimiterStatus `json:"types" mapstructure:"types"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // LimiterStatus holds the per-limiter status and flags for testing. |  | ||||||
| type LimiterStatus struct { |  | ||||||
| 	Enabled bool                `json:"enabled" mapstructure:"enabled"` |  | ||||||
| 	Flags   limits.LimiterFlags `json:"flags,omitempty" mapstructure:"flags,omitempty"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| const readRequestLimiterHelpText = ` |  | ||||||
| Read the current status of the request limiter. |  | ||||||
| ` |  | ||||||
|  |  | ||||||
| func (b *SystemBackend) requestLimiterReadPath() *framework.Path { |  | ||||||
| 	return &framework.Path{ |  | ||||||
| 		Pattern:         "internal/request-limiter/status$", |  | ||||||
| 		HelpDescription: readRequestLimiterHelpText, |  | ||||||
| 		HelpSynopsis:    readRequestLimiterHelpText, |  | ||||||
| 		Operations: map[logical.Operation]framework.OperationHandler{ |  | ||||||
| 			logical.ReadOperation: &framework.PathOperation{ |  | ||||||
| 				Callback: b.handleReadRequestLimiter, |  | ||||||
| 				DisplayAttrs: &framework.DisplayAttributes{ |  | ||||||
| 					OperationVerb:   "read", |  | ||||||
| 					OperationSuffix: "verbosity-level-for", |  | ||||||
| 				}, |  | ||||||
| 				Responses: map[int][]framework.Response{ |  | ||||||
| 					http.StatusOK: {{ |  | ||||||
| 						Description: "OK", |  | ||||||
| 					}}, |  | ||||||
| 				}, |  | ||||||
| 				Summary: "Read the current status of the request limiter.", |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // handleReadRequestLimiter returns the enabled Request Limiter status for this node. |  | ||||||
| func (b *SystemBackend) handleReadRequestLimiter(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { |  | ||||||
| 	resp := &RequestLimiterResponse{ |  | ||||||
| 		Limiters: make(map[string]*LimiterStatus), |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	b.Core.limiterRegistryLock.Lock() |  | ||||||
| 	registry := b.Core.limiterRegistry |  | ||||||
| 	b.Core.limiterRegistryLock.Unlock() |  | ||||||
|  |  | ||||||
| 	resp.GlobalDisabled = !registry.Enabled |  | ||||||
| 	resp.ListenerDisabled = req.RequestLimiterDisabled |  | ||||||
| 	enabled := !(resp.GlobalDisabled || resp.ListenerDisabled) |  | ||||||
|  |  | ||||||
| 	for name := range limits.DefaultLimiterFlags { |  | ||||||
| 		var flags limits.LimiterFlags |  | ||||||
| 		if requestLimiter := b.Core.GetRequestLimiter(name); requestLimiter != nil && enabled { |  | ||||||
| 			flags = requestLimiter.Flags |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		resp.Limiters[name] = &LimiterStatus{ |  | ||||||
| 			Enabled: enabled, |  | ||||||
| 			Flags:   flags, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return &logical.Response{ |  | ||||||
| 		Data: map[string]interface{}{ |  | ||||||
| 			"request_limiter": resp, |  | ||||||
| 		}, |  | ||||||
| 	}, nil |  | ||||||
| } |  | ||||||
| @@ -44,7 +44,6 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/helper/testhelpers/corehelpers" | 	"github.com/hashicorp/vault/helper/testhelpers/corehelpers" | ||||||
| 	"github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" | 	"github.com/hashicorp/vault/helper/testhelpers/pluginhelpers" | ||||||
| 	"github.com/hashicorp/vault/internalshared/configutil" | 	"github.com/hashicorp/vault/internalshared/configutil" | ||||||
| 	"github.com/hashicorp/vault/limits" |  | ||||||
| 	"github.com/hashicorp/vault/sdk/framework" | 	"github.com/hashicorp/vault/sdk/framework" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/consts" | 	"github.com/hashicorp/vault/sdk/helper/consts" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/logging" | 	"github.com/hashicorp/vault/sdk/helper/logging" | ||||||
| @@ -1133,8 +1132,6 @@ type TestClusterOptions struct { | |||||||
|  |  | ||||||
| 	// ABCDLoggerNames names the loggers according to our ABCD convention when generating 4 clusters | 	// ABCDLoggerNames names the loggers according to our ABCD convention when generating 4 clusters | ||||||
| 	ABCDLoggerNames bool | 	ABCDLoggerNames bool | ||||||
|  |  | ||||||
| 	LimiterRegistry *limits.LimiterRegistry |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type TestPluginConfig struct { | type TestPluginConfig struct { | ||||||
| @@ -1425,7 +1422,6 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te | |||||||
| 		EnableUI:           true, | 		EnableUI:           true, | ||||||
| 		EnableRaw:          true, | 		EnableRaw:          true, | ||||||
| 		BuiltinRegistry:    corehelpers.NewMockBuiltinRegistry(), | 		BuiltinRegistry:    corehelpers.NewMockBuiltinRegistry(), | ||||||
| 		LimiterRegistry:    limits.NewLimiterRegistry(testCluster.Logger), |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if base != nil { | 	if base != nil { | ||||||
| @@ -1515,10 +1511,6 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te | |||||||
| 		coreConfig.PeriodicLeaderRefreshInterval = base.PeriodicLeaderRefreshInterval | 		coreConfig.PeriodicLeaderRefreshInterval = base.PeriodicLeaderRefreshInterval | ||||||
| 		coreConfig.ClusterAddrBridge = base.ClusterAddrBridge | 		coreConfig.ClusterAddrBridge = base.ClusterAddrBridge | ||||||
|  |  | ||||||
| 		if base.LimiterRegistry != nil { |  | ||||||
| 			coreConfig.LimiterRegistry = base.LimiterRegistry |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		testApplyEntBaseConfig(coreConfig, base) | 		testApplyEntBaseConfig(coreConfig, base) | ||||||
| 	} | 	} | ||||||
| 	if coreConfig.ClusterName == "" { | 	if coreConfig.ClusterName == "" { | ||||||
| @@ -1912,10 +1904,6 @@ func (testCluster *TestCluster) newCore(t testing.T, idx int, coreConfig *CoreCo | |||||||
|  |  | ||||||
| 	localConfig.NumExpirationWorkers = numExpirationWorkersTest | 	localConfig.NumExpirationWorkers = numExpirationWorkersTest | ||||||
|  |  | ||||||
| 	if opts != nil && opts.LimiterRegistry != nil { |  | ||||||
| 		localConfig.LimiterRegistry = opts.LimiterRegistry |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	c, err := NewCore(&localConfig) | 	c, err := NewCore(&localConfig) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("err: %v", err) | 		t.Fatalf("err: %v", err) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Mike Palmiotto
					Mike Palmiotto