Vault Agent Cache Auto-Auth SSRF Protection (#7627)

* implement SSRF protection header

* add test for SSRF protection header

* cleanup

* refactor

* implement SSRF header on a per-listener basis

* cleanup

* cleanup

* creat unit test for agent SSRF

* improve unit test for agent SSRF

* add VaultRequest SSRF header to CLI

* fix unit test

* cleanup

* improve test suite

* simplify check for Vault-Request header

* add constant for Vault-Request header

* improve test suite

* change 'config' to 'agentConfig'

* Revert "change 'config' to 'agentConfig'"

This reverts commit 14ee72d21fff8027966ee3c89dd3ac41d849206f.

* do not remove header from request

* change header name to X-Vault-Request

* simplify http.Handler logic

* cleanup

* simplify http.Handler logic

* use stdlib errors package
This commit is contained in:
Mike Jarmy
2019-10-11 18:56:07 -04:00
committed by GitHub
parent 68750b70a2
commit 77ceb7dde0
5 changed files with 304 additions and 9 deletions

View File

@@ -2,6 +2,7 @@ package command
import (
"context"
"errors"
"flag"
"fmt"
"io"
@@ -28,13 +29,14 @@ import (
"github.com/hashicorp/vault/command/agent/auth/jwt"
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
"github.com/hashicorp/vault/command/agent/cache"
"github.com/hashicorp/vault/command/agent/config"
agentConfig "github.com/hashicorp/vault/command/agent/config"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/command/agent/sink/file"
"github.com/hashicorp/vault/command/agent/sink/inmem"
gatedwriter "github.com/hashicorp/vault/helper/gated-writer"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/version"
"github.com/kr/pretty"
"github.com/mitchellh/cli"
@@ -192,7 +194,7 @@ func (c *AgentCommand) Run(args []string) int {
}
// Load the configuration
config, err := config.LoadConfig(c.flagConfigs[0])
config, err := agentConfig.LoadConfig(c.flagConfigs[0])
if err != nil {
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
return 1
@@ -418,11 +420,8 @@ func (c *AgentCommand) Run(args []string) int {
})
}
// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink))
// Create the request handler
cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink)
var listeners []net.Listener
for i, lnConfig := range config.Listeners {
@@ -434,6 +433,25 @@ func (c *AgentCommand) Run(args []string) int {
listeners = append(listeners, ln)
// Parse 'require_request_header' listener config option, and wrap
// the request handler if necessary
muxHandler := cacheHandler
if v, ok := lnConfig.Config[agentConfig.RequireRequestHeader]; ok {
switch v {
case true:
muxHandler = verifyRequestHeader(muxHandler)
case false /* noop */ :
default:
c.UI.Error(fmt.Sprintf("Invalid value for 'require_request_header': %v", v))
return 1
}
}
// Create a muxer and add paths relevant for the lease cache layer
mux := http.NewServeMux()
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
mux.Handle("/", muxHandler)
scheme := "https://"
if tlsConf == nil {
scheme = "http://"
@@ -536,6 +554,22 @@ func (c *AgentCommand) Run(args []string) int {
return 0
}
// verifyRequestHeader wraps an http.Handler inside a Handler that checks for
// the request header that is used for SSRF protection.
func verifyRequestHeader(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" {
logical.RespondError(w,
http.StatusPreconditionFailed,
errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName)))
return
}
handler.ServeHTTP(w, r)
})
}
func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
var isFlagSet bool
f.Visit(func(f *flag.Flag) {