mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 10:37:56 +00:00 
			
		
		
		
	This PR flips the logic for the Request Limiter so that it is disabled by default. Users can turn on the global Request Limiter, while the Listener configuration remains a per-Listener "disable" setting.
		
			
				
	
	
		
			214 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			214 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) HashiCorp, Inc.
 | |
| // SPDX-License-Identifier: BUSL-1.1
 | |
| 
 | |
| //go:build testonly
 | |
| 
 | |
| package command_testonly
 | |
| 
 | |
| import (
 | |
| 	"os"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/hashicorp/vault/api"
 | |
| 	"github.com/hashicorp/vault/command"
 | |
| 	"github.com/hashicorp/vault/limits"
 | |
| 	"github.com/hashicorp/vault/vault"
 | |
| 	"github.com/mitchellh/mapstructure"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| )
 | |
| 
 | |
| func init() {
 | |
| 	if signed := os.Getenv("VAULT_LICENSE_CI"); signed != "" {
 | |
| 		os.Setenv(command.EnvVaultLicense, signed)
 | |
| 	}
 | |
| }
 | |
| 
 | |
const (
	// baseHCL is the minimal server configuration shared by every case: an
	// in-memory backend, mlock disabled, and a plain-TCP listener on
	// 127.0.0.1:8209. It contains no request_limiter stanza.
	baseHCL = `
		backend "inmem" { }
		disable_mlock = true
		listener "tcp" {
			address     = "127.0.0.1:8209"
			tls_disable = "true"
		}
		api_addr = "http://127.0.0.1:8209"
	`

	// requestLimiterDisableHCL is a request_limiter stanza with
	// disable = true; it is appended to baseHCL.
	requestLimiterDisableHCL = `
  request_limiter {
	disable = true
  }
`

	// requestLimiterEnableHCL is a request_limiter stanza with
	// disable = false; it is appended to baseHCL.
	requestLimiterEnableHCL = `
  request_limiter {
	disable = false
  }
`
)
 | |
| 
 | |
| // TestServer_ReloadRequestLimiter tests a series of reloads and state
 | |
| // transitions between RequestLimiter enable and disable.
 | |
| func TestServer_ReloadRequestLimiter(t *testing.T) {
 | |
| 	t.Parallel()
 | |
| 
 | |
| 	enabledResponse := &vault.RequestLimiterResponse{
 | |
| 		GlobalDisabled:   false,
 | |
| 		ListenerDisabled: false,
 | |
| 		Limiters: map[string]*vault.LimiterStatus{
 | |
| 			limits.WriteLimiter: {
 | |
| 				Enabled: true,
 | |
| 				Flags:   limits.DefaultLimiterFlags[limits.WriteLimiter],
 | |
| 			},
 | |
| 			limits.SpecialPathLimiter: {
 | |
| 				Enabled: true,
 | |
| 				Flags:   limits.DefaultLimiterFlags[limits.SpecialPathLimiter],
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	disabledResponse := &vault.RequestLimiterResponse{
 | |
| 		GlobalDisabled:   true,
 | |
| 		ListenerDisabled: false,
 | |
| 		Limiters: map[string]*vault.LimiterStatus{
 | |
| 			limits.WriteLimiter: {
 | |
| 				Enabled: false,
 | |
| 			},
 | |
| 			limits.SpecialPathLimiter: {
 | |
| 				Enabled: false,
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	cases := []struct {
 | |
| 		name             string
 | |
| 		configAfter      string
 | |
| 		expectedResponse *vault.RequestLimiterResponse
 | |
| 	}{
 | |
| 		{
 | |
| 			"disable after default",
 | |
| 			baseHCL + requestLimiterDisableHCL,
 | |
| 			disabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"disable after disable",
 | |
| 			baseHCL + requestLimiterDisableHCL,
 | |
| 			disabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"enable after disable",
 | |
| 			baseHCL + requestLimiterEnableHCL,
 | |
| 			enabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"default after enable",
 | |
| 			baseHCL,
 | |
| 			disabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"default after default",
 | |
| 			baseHCL,
 | |
| 			disabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"enable after default",
 | |
| 			baseHCL + requestLimiterEnableHCL,
 | |
| 			enabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"enable after enable",
 | |
| 			baseHCL + requestLimiterEnableHCL,
 | |
| 			enabledResponse,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	ui, srv := command.TestServerCommand(t)
 | |
| 
 | |
| 	f, err := os.CreateTemp(t.TempDir(), "")
 | |
| 	require.NoErrorf(t, err, "error creating temp dir: %v", err)
 | |
| 
 | |
| 	_, err = f.WriteString(baseHCL)
 | |
| 	require.NoErrorf(t, err, "cannot write temp file contents")
 | |
| 
 | |
| 	configPath := f.Name()
 | |
| 
 | |
| 	var output string
 | |
| 	wg := &sync.WaitGroup{}
 | |
| 	wg.Add(1)
 | |
| 	go func() {
 | |
| 		defer wg.Done()
 | |
| 		code := srv.Run([]string{"-config", configPath})
 | |
| 		output = ui.ErrorWriter.String() + ui.OutputWriter.String()
 | |
| 		require.Equal(t, 0, code, output)
 | |
| 	}()
 | |
| 
 | |
| 	select {
 | |
| 	case <-srv.StartedCh():
 | |
| 	case <-time.After(5 * time.Second):
 | |
| 		t.Fatalf("timeout")
 | |
| 	}
 | |
| 	defer func() {
 | |
| 		srv.ShutdownCh <- struct{}{}
 | |
| 		wg.Wait()
 | |
| 	}()
 | |
| 
 | |
| 	err = f.Close()
 | |
| 	require.NoErrorf(t, err, "unable to close temp file")
 | |
| 
 | |
| 	// create a client and unseal vault
 | |
| 	cli, err := srv.Client()
 | |
| 	require.NoError(t, err)
 | |
| 	require.NoError(t, cli.SetAddress("http://127.0.0.1:8209"))
 | |
| 	initResp, err := cli.Sys().Init(&api.InitRequest{SecretShares: 1, SecretThreshold: 1})
 | |
| 	require.NoError(t, err)
 | |
| 	_, err = cli.Sys().Unseal(initResp.Keys[0])
 | |
| 	require.NoError(t, err)
 | |
| 	cli.SetToken(initResp.RootToken)
 | |
| 
 | |
| 	output = ui.ErrorWriter.String() + ui.OutputWriter.String()
 | |
| 	require.Contains(t, output, "Request Limiter: disabled")
 | |
| 
 | |
| 	verifyLimiters := func(t *testing.T, expectedResponse *vault.RequestLimiterResponse) {
 | |
| 		t.Helper()
 | |
| 
 | |
| 		statusResp, err := cli.Logical().Read("/sys/internal/request-limiter/status")
 | |
| 		require.NoError(t, err)
 | |
| 		require.NotNil(t, statusResp)
 | |
| 
 | |
| 		limitersResp, ok := statusResp.Data["request_limiter"]
 | |
| 		require.True(t, ok)
 | |
| 		require.NotNil(t, limitersResp)
 | |
| 
 | |
| 		var limiters *vault.RequestLimiterResponse
 | |
| 		err = mapstructure.Decode(limitersResp, &limiters)
 | |
| 		require.NoError(t, err)
 | |
| 		require.NotNil(t, limiters)
 | |
| 
 | |
| 		require.Equal(t, expectedResponse, limiters)
 | |
| 	}
 | |
| 
 | |
| 	// Start off with default disabled
 | |
| 	verifyLimiters(t, disabledResponse)
 | |
| 
 | |
| 	for _, tc := range cases {
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			// Write the new contents and reload the server
 | |
| 			f, err = os.OpenFile(configPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644)
 | |
| 			require.NoError(t, err)
 | |
| 			defer f.Close()
 | |
| 
 | |
| 			_, err = f.WriteString(tc.configAfter)
 | |
| 			require.NoErrorf(t, err, "cannot write temp file contents")
 | |
| 
 | |
| 			srv.SighupCh <- struct{}{}
 | |
| 			select {
 | |
| 			case <-srv.ReloadedCh():
 | |
| 			case <-time.After(5 * time.Second):
 | |
| 				t.Fatalf("test timed out")
 | |
| 			}
 | |
| 
 | |
| 			verifyLimiters(t, tc.expectedResponse)
 | |
| 		})
 | |
| 	}
 | |
| }
 |