mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Vault Agent Cache Auto-Auth SSRF Protection (#7627)
* implement SSRF protection header * add test for SSRF protection header * cleanup * refactor * implement SSRF header on a per-listener basis * cleanup * cleanup * creat unit test for agent SSRF * improve unit test for agent SSRF * add VaultRequest SSRF header to CLI * fix unit test * cleanup * improve test suite * simplify check for Vault-Request header * add constant for Vault-Request header * improve test suite * change 'config' to 'agentConfig' * Revert "change 'config' to 'agentConfig'" This reverts commit 14ee72d21fff8027966ee3c89dd3ac41d849206f. * do not remove header from request * change header name to X-Vault-Request * simplify http.Handler logic * cleanup * simplify http.Handler logic * use stdlib errors package
This commit is contained in:
@@ -2,6 +2,7 @@ package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -28,13 +29,14 @@ import (
|
||||
"github.com/hashicorp/vault/command/agent/auth/jwt"
|
||||
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
||||
"github.com/hashicorp/vault/command/agent/cache"
|
||||
"github.com/hashicorp/vault/command/agent/config"
|
||||
agentConfig "github.com/hashicorp/vault/command/agent/config"
|
||||
"github.com/hashicorp/vault/command/agent/sink"
|
||||
"github.com/hashicorp/vault/command/agent/sink/file"
|
||||
"github.com/hashicorp/vault/command/agent/sink/inmem"
|
||||
gatedwriter "github.com/hashicorp/vault/helper/gated-writer"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/sdk/version"
|
||||
"github.com/kr/pretty"
|
||||
"github.com/mitchellh/cli"
|
||||
@@ -192,7 +194,7 @@ func (c *AgentCommand) Run(args []string) int {
|
||||
}
|
||||
|
||||
// Load the configuration
|
||||
config, err := config.LoadConfig(c.flagConfigs[0])
|
||||
config, err := agentConfig.LoadConfig(c.flagConfigs[0])
|
||||
if err != nil {
|
||||
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
|
||||
return 1
|
||||
@@ -418,11 +420,8 @@ func (c *AgentCommand) Run(args []string) int {
|
||||
})
|
||||
}
|
||||
|
||||
// Create a muxer and add paths relevant for the lease cache layer
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
|
||||
|
||||
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink))
|
||||
// Create the request handler
|
||||
cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink)
|
||||
|
||||
var listeners []net.Listener
|
||||
for i, lnConfig := range config.Listeners {
|
||||
@@ -434,6 +433,25 @@ func (c *AgentCommand) Run(args []string) int {
|
||||
|
||||
listeners = append(listeners, ln)
|
||||
|
||||
// Parse 'require_request_header' listener config option, and wrap
|
||||
// the request handler if necessary
|
||||
muxHandler := cacheHandler
|
||||
if v, ok := lnConfig.Config[agentConfig.RequireRequestHeader]; ok {
|
||||
switch v {
|
||||
case true:
|
||||
muxHandler = verifyRequestHeader(muxHandler)
|
||||
case false /* noop */ :
|
||||
default:
|
||||
c.UI.Error(fmt.Sprintf("Invalid value for 'require_request_header': %v", v))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// Create a muxer and add paths relevant for the lease cache layer
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
|
||||
mux.Handle("/", muxHandler)
|
||||
|
||||
scheme := "https://"
|
||||
if tlsConf == nil {
|
||||
scheme = "http://"
|
||||
@@ -536,6 +554,22 @@ func (c *AgentCommand) Run(args []string) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// verifyRequestHeader wraps an http.Handler inside a Handler that checks for
|
||||
// the request header that is used for SSRF protection.
|
||||
func verifyRequestHeader(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" {
|
||||
logical.RespondError(w,
|
||||
http.StatusPreconditionFailed,
|
||||
errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName)))
|
||||
return
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
|
||||
var isFlagSet bool
|
||||
f.Visit(func(f *flag.Flag) {
|
||||
|
Reference in New Issue
Block a user