mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 10:37:56 +00:00 
			
		
		
		
	 7ad778541e
			
		
	
	7ad778541e
	
	
	
		
			
			This PR flips the Request Limiter's default so that it is disabled unless explicitly enabled. Users can turn on the global Request Limiter, while the listener configuration keeps its per-listener "disable" option.
		
			
				
	
	
		
			214 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			214 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) HashiCorp, Inc.
 | |
| // SPDX-License-Identifier: BUSL-1.1
 | |
| 
 | |
| //go:build testonly
 | |
| 
 | |
| package command_testonly
 | |
| 
 | |
| import (
 | |
| 	"os"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/hashicorp/vault/api"
 | |
| 	"github.com/hashicorp/vault/command"
 | |
| 	"github.com/hashicorp/vault/limits"
 | |
| 	"github.com/hashicorp/vault/vault"
 | |
| 	"github.com/mitchellh/mapstructure"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| )
 | |
| 
 | |
| func init() {
 | |
| 	if signed := os.Getenv("VAULT_LICENSE_CI"); signed != "" {
 | |
| 		os.Setenv(command.EnvVaultLicense, signed)
 | |
| 	}
 | |
| }
 | |
| 
 | |
// HCL configuration fragments used by TestServer_ReloadRequestLimiter to
// build the server config before and after each reload.
const (
	// baseHCL is the starting server config: the "inmem" backend, mlock
	// disabled, and one TCP listener on 127.0.0.1:8209 with TLS disabled.
	// It contains no request_limiter stanza.
	baseHCL = `
		backend "inmem" { }
		disable_mlock = true
		listener "tcp" {
			address     = "127.0.0.1:8209"
			tls_disable = "true"
		}
		api_addr = "http://127.0.0.1:8209"
	`
	// requestLimiterDisableHCL is appended to baseHCL to add a top-level
	// request_limiter stanza with disable = true.
	requestLimiterDisableHCL = `
  request_limiter {
	disable = true
  }
`
	// requestLimiterEnableHCL is appended to baseHCL to add a top-level
	// request_limiter stanza with disable = false.
	requestLimiterEnableHCL = `
  request_limiter {
	disable = false
  }
`
)
 | |
| 
 | |
| // TestServer_ReloadRequestLimiter tests a series of reloads and state
 | |
| // transitions between RequestLimiter enable and disable.
 | |
| func TestServer_ReloadRequestLimiter(t *testing.T) {
 | |
| 	t.Parallel()
 | |
| 
 | |
| 	enabledResponse := &vault.RequestLimiterResponse{
 | |
| 		GlobalDisabled:   false,
 | |
| 		ListenerDisabled: false,
 | |
| 		Limiters: map[string]*vault.LimiterStatus{
 | |
| 			limits.WriteLimiter: {
 | |
| 				Enabled: true,
 | |
| 				Flags:   limits.DefaultLimiterFlags[limits.WriteLimiter],
 | |
| 			},
 | |
| 			limits.SpecialPathLimiter: {
 | |
| 				Enabled: true,
 | |
| 				Flags:   limits.DefaultLimiterFlags[limits.SpecialPathLimiter],
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	disabledResponse := &vault.RequestLimiterResponse{
 | |
| 		GlobalDisabled:   true,
 | |
| 		ListenerDisabled: false,
 | |
| 		Limiters: map[string]*vault.LimiterStatus{
 | |
| 			limits.WriteLimiter: {
 | |
| 				Enabled: false,
 | |
| 			},
 | |
| 			limits.SpecialPathLimiter: {
 | |
| 				Enabled: false,
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	cases := []struct {
 | |
| 		name             string
 | |
| 		configAfter      string
 | |
| 		expectedResponse *vault.RequestLimiterResponse
 | |
| 	}{
 | |
| 		{
 | |
| 			"disable after default",
 | |
| 			baseHCL + requestLimiterDisableHCL,
 | |
| 			disabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"disable after disable",
 | |
| 			baseHCL + requestLimiterDisableHCL,
 | |
| 			disabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"enable after disable",
 | |
| 			baseHCL + requestLimiterEnableHCL,
 | |
| 			enabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"default after enable",
 | |
| 			baseHCL,
 | |
| 			disabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"default after default",
 | |
| 			baseHCL,
 | |
| 			disabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"enable after default",
 | |
| 			baseHCL + requestLimiterEnableHCL,
 | |
| 			enabledResponse,
 | |
| 		},
 | |
| 		{
 | |
| 			"enable after enable",
 | |
| 			baseHCL + requestLimiterEnableHCL,
 | |
| 			enabledResponse,
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	ui, srv := command.TestServerCommand(t)
 | |
| 
 | |
| 	f, err := os.CreateTemp(t.TempDir(), "")
 | |
| 	require.NoErrorf(t, err, "error creating temp dir: %v", err)
 | |
| 
 | |
| 	_, err = f.WriteString(baseHCL)
 | |
| 	require.NoErrorf(t, err, "cannot write temp file contents")
 | |
| 
 | |
| 	configPath := f.Name()
 | |
| 
 | |
| 	var output string
 | |
| 	wg := &sync.WaitGroup{}
 | |
| 	wg.Add(1)
 | |
| 	go func() {
 | |
| 		defer wg.Done()
 | |
| 		code := srv.Run([]string{"-config", configPath})
 | |
| 		output = ui.ErrorWriter.String() + ui.OutputWriter.String()
 | |
| 		require.Equal(t, 0, code, output)
 | |
| 	}()
 | |
| 
 | |
| 	select {
 | |
| 	case <-srv.StartedCh():
 | |
| 	case <-time.After(5 * time.Second):
 | |
| 		t.Fatalf("timeout")
 | |
| 	}
 | |
| 	defer func() {
 | |
| 		srv.ShutdownCh <- struct{}{}
 | |
| 		wg.Wait()
 | |
| 	}()
 | |
| 
 | |
| 	err = f.Close()
 | |
| 	require.NoErrorf(t, err, "unable to close temp file")
 | |
| 
 | |
| 	// create a client and unseal vault
 | |
| 	cli, err := srv.Client()
 | |
| 	require.NoError(t, err)
 | |
| 	require.NoError(t, cli.SetAddress("http://127.0.0.1:8209"))
 | |
| 	initResp, err := cli.Sys().Init(&api.InitRequest{SecretShares: 1, SecretThreshold: 1})
 | |
| 	require.NoError(t, err)
 | |
| 	_, err = cli.Sys().Unseal(initResp.Keys[0])
 | |
| 	require.NoError(t, err)
 | |
| 	cli.SetToken(initResp.RootToken)
 | |
| 
 | |
| 	output = ui.ErrorWriter.String() + ui.OutputWriter.String()
 | |
| 	require.Contains(t, output, "Request Limiter: disabled")
 | |
| 
 | |
| 	verifyLimiters := func(t *testing.T, expectedResponse *vault.RequestLimiterResponse) {
 | |
| 		t.Helper()
 | |
| 
 | |
| 		statusResp, err := cli.Logical().Read("/sys/internal/request-limiter/status")
 | |
| 		require.NoError(t, err)
 | |
| 		require.NotNil(t, statusResp)
 | |
| 
 | |
| 		limitersResp, ok := statusResp.Data["request_limiter"]
 | |
| 		require.True(t, ok)
 | |
| 		require.NotNil(t, limitersResp)
 | |
| 
 | |
| 		var limiters *vault.RequestLimiterResponse
 | |
| 		err = mapstructure.Decode(limitersResp, &limiters)
 | |
| 		require.NoError(t, err)
 | |
| 		require.NotNil(t, limiters)
 | |
| 
 | |
| 		require.Equal(t, expectedResponse, limiters)
 | |
| 	}
 | |
| 
 | |
| 	// Start off with default disabled
 | |
| 	verifyLimiters(t, disabledResponse)
 | |
| 
 | |
| 	for _, tc := range cases {
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			// Write the new contents and reload the server
 | |
| 			f, err = os.OpenFile(configPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644)
 | |
| 			require.NoError(t, err)
 | |
| 			defer f.Close()
 | |
| 
 | |
| 			_, err = f.WriteString(tc.configAfter)
 | |
| 			require.NoErrorf(t, err, "cannot write temp file contents")
 | |
| 
 | |
| 			srv.SighupCh <- struct{}{}
 | |
| 			select {
 | |
| 			case <-srv.ReloadedCh():
 | |
| 			case <-time.After(5 * time.Second):
 | |
| 				t.Fatalf("test timed out")
 | |
| 			}
 | |
| 
 | |
| 			verifyLimiters(t, tc.expectedResponse)
 | |
| 		})
 | |
| 	}
 | |
| }
 |