diff --git a/command/agent.go b/command/agent.go index 3f9e7b1417..0130c554e1 100644 --- a/command/agent.go +++ b/command/agent.go @@ -669,15 +669,17 @@ func (c *AgentCommand) Run(args []string) int { Logger: apiProxyLogger, Sink: inmemSink, }) + useAutoAuthToken := false if config.APIProxy != nil { + useAutoAuthToken = true proxyVaultToken = !config.APIProxy.ForceAutoAuthToken } var muxHandler http.Handler if leaseCache != nil { - muxHandler = cache.ProxyHandler(ctx, apiProxyLogger, leaseCache, inmemSink, proxyVaultToken, authInProgress, invalidTokenErrCh) + muxHandler = cache.ProxyHandler(ctx, apiProxyLogger, leaseCache, inmemSink, proxyVaultToken, useAutoAuthToken, authInProgress, invalidTokenErrCh) } else { - muxHandler = cache.ProxyHandler(ctx, apiProxyLogger, apiProxy, inmemSink, proxyVaultToken, authInProgress, invalidTokenErrCh) + muxHandler = cache.ProxyHandler(ctx, apiProxyLogger, apiProxy, inmemSink, proxyVaultToken, useAutoAuthToken, authInProgress, invalidTokenErrCh) } // Parse 'require_request_header' listener config option, and wrap diff --git a/command/agent/cache_end_to_end_test.go b/command/agent/cache_end_to_end_test.go index 7fa6c0fc23..3d9305de9a 100644 --- a/command/agent/cache_end_to_end_test.go +++ b/command/agent/cache_end_to_end_test.go @@ -318,8 +318,8 @@ func TestCache_UsingAutoAuthToken(t *testing.T) { mux := http.NewServeMux() mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx)) - // Passing a non-nil inmemsink tells the agent to use the auto-auth token - mux.Handle("/", cache.ProxyHandler(ctx, cacheLogger, leaseCache, inmemSink, true, nil, nil)) + // Setting useAutoAuthToken to true to ensure that the auto-auth token is used + mux.Handle("/", cache.ProxyHandler(ctx, cacheLogger, leaseCache, inmemSink, true, true, nil, nil)) server := &http.Server{ Handler: mux, ReadHeaderTimeout: 10 * time.Second, diff --git a/command/agentproxyshared/cache/api_proxy_test.go b/command/agentproxyshared/cache/api_proxy_test.go index 234c6ae6ed..fceb317de6 100644 --- a/command/agentproxyshared/cache/api_proxy_test.go +++ b/command/agentproxyshared/cache/api_proxy_test.go @@ -285,9 +285,9 @@ func setupClusterAndAgentCommon(ctx context.Context, t *testing.T, coreConfig *v mux.Handle("/agent/v1/cache-clear", leaseCache.HandleCacheClear(ctx)) - mux.Handle("/", ProxyHandler(ctx, cacheLogger, leaseCache, nil, true, nil, nil)) + mux.Handle("/", ProxyHandler(ctx, cacheLogger, leaseCache, nil, true, false, nil, nil)) } else { - mux.Handle("/", ProxyHandler(ctx, apiProxyLogger, apiProxy, nil, true, nil, nil)) + mux.Handle("/", ProxyHandler(ctx, apiProxyLogger, apiProxy, nil, true, false, nil, nil)) } server := &http.Server{ diff --git a/command/agentproxyshared/cache/cache_test.go b/command/agentproxyshared/cache/cache_test.go index f19267e055..2e83bee38e 100644 --- a/command/agentproxyshared/cache/cache_test.go +++ b/command/agentproxyshared/cache/cache_test.go @@ -81,7 +81,7 @@ func TestCache_AutoAuthTokenStripping(t *testing.T) { mux := http.NewServeMux() mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx)) - mux.Handle("/", ProxyHandler(ctx, cacheLogger, leaseCache, mock.NewSink("testid"), true, nil, nil)) + mux.Handle("/", ProxyHandler(ctx, cacheLogger, leaseCache, mock.NewSink("testid"), true, true, nil, nil)) server := &http.Server{ Handler: mux, ReadHeaderTimeout: 10 * time.Second, @@ -170,7 +170,7 @@ func TestCache_AutoAuthClientTokenProxyStripping(t *testing.T) { mux := http.NewServeMux() // mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx)) - mux.Handle("/", ProxyHandler(ctx, cacheLogger, leaseCache, mock.NewSink(realToken), false, nil, nil)) + mux.Handle("/", ProxyHandler(ctx, cacheLogger, leaseCache, mock.NewSink(realToken), false, true, nil, nil)) server := &http.Server{ Handler: mux, ReadHeaderTimeout: 10 * time.Second, diff --git a/command/agentproxyshared/cache/handler.go b/command/agentproxyshared/cache/handler.go index 107c384024..f8cc1d1e54 100644 --- a/command/agentproxyshared/cache/handler.go +++ b/command/agentproxyshared/cache/handler.go @@ -25,11 +25,11 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func ProxyHandler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSink sink.Sink, proxyVaultToken bool, authInProgress *atomic.Bool, invalidTokenErrCh chan error) http.Handler { +func ProxyHandler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSink sink.Sink, useProxyVaultToken bool, useAutoAuthToken bool, authInProgress *atomic.Bool, invalidTokenErrCh chan error) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Info("received request", "method", r.Method, "path", r.URL.Path) - if !proxyVaultToken { + if !useProxyVaultToken { r.Header.Del(consts.AuthHeaderName) } @@ -38,7 +38,7 @@ func ProxyHandler(ctx context.Context, logger hclog.Logger, proxier Proxier, inm var autoAuthToken string if inmemSink != nil { autoAuthToken = inmemSink.(sink.SinkReader).Token() - if token == "" { + if token == "" && useAutoAuthToken { logger.Debug("using auto auth token", "method", r.Method, "path", r.URL.Path) token = autoAuthToken } diff --git a/command/proxy.go b/command/proxy.go index 3e06f15dd6..00e1020f48 100644 --- a/command/proxy.go +++ b/command/proxy.go @@ -627,16 +627,18 @@ func (c *ProxyCommand) Run(args []string) int { Logger: apiProxyLogger, Sink: inmemSink, }) + useAutoAuthToken := false proxyVaultToken := true if config.APIProxy != nil { + useAutoAuthToken = true proxyVaultToken = !config.APIProxy.ForceAutoAuthToken } var muxHandler http.Handler if leaseCache != nil { - muxHandler = cache.ProxyHandler(ctx, apiProxyLogger, leaseCache, inmemSink, proxyVaultToken, authInProgress, invalidTokenErrCh) + muxHandler = cache.ProxyHandler(ctx, apiProxyLogger, leaseCache, inmemSink, proxyVaultToken, useAutoAuthToken, authInProgress, invalidTokenErrCh) } else { - muxHandler = cache.ProxyHandler(ctx, apiProxyLogger, apiProxy, inmemSink, proxyVaultToken, authInProgress, invalidTokenErrCh) + muxHandler = cache.ProxyHandler(ctx, apiProxyLogger, apiProxy, inmemSink, proxyVaultToken, useAutoAuthToken, authInProgress, invalidTokenErrCh) } // Parse 'require_request_header' listener config option, and wrap