mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-03 03:58:01 +00:00
Vault Agent Cache Auto-Auth SSRF Protection (#7627)
* implement SSRF protection header * add test for SSRF protection header * cleanup * refactor * implement SSRF header on a per-listener basis * cleanup * cleanup * creat unit test for agent SSRF * improve unit test for agent SSRF * add VaultRequest SSRF header to CLI * fix unit test * cleanup * improve test suite * simplify check for Vault-Request header * add constant for Vault-Request header * improve test suite * change 'config' to 'agentConfig' * Revert "change 'config' to 'agentConfig'" This reverts commit 14ee72d21fff8027966ee3c89dd3ac41d849206f. * do not remove header from request * change header name to X-Vault-Request * simplify http.Handler logic * cleanup * simplify http.Handler logic * use stdlib errors package
This commit is contained in:
@@ -427,10 +427,14 @@ func NewClient(c *Config) (*Client, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := &Client{
|
client := &Client{
|
||||||
addr: u,
|
addr: u,
|
||||||
config: c,
|
config: c,
|
||||||
|
headers: make(http.Header),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add the VaultRequest SSRF protection header
|
||||||
|
client.headers[consts.RequestHeaderName] = []string{"true"}
|
||||||
|
|
||||||
if token := os.Getenv(EnvVaultToken); token != "" {
|
if token := os.Getenv(EnvVaultToken); token != "" {
|
||||||
client.token = token
|
client.token = token
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package command
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -28,13 +29,14 @@ import (
|
|||||||
"github.com/hashicorp/vault/command/agent/auth/jwt"
|
"github.com/hashicorp/vault/command/agent/auth/jwt"
|
||||||
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
"github.com/hashicorp/vault/command/agent/auth/kubernetes"
|
||||||
"github.com/hashicorp/vault/command/agent/cache"
|
"github.com/hashicorp/vault/command/agent/cache"
|
||||||
"github.com/hashicorp/vault/command/agent/config"
|
agentConfig "github.com/hashicorp/vault/command/agent/config"
|
||||||
"github.com/hashicorp/vault/command/agent/sink"
|
"github.com/hashicorp/vault/command/agent/sink"
|
||||||
"github.com/hashicorp/vault/command/agent/sink/file"
|
"github.com/hashicorp/vault/command/agent/sink/file"
|
||||||
"github.com/hashicorp/vault/command/agent/sink/inmem"
|
"github.com/hashicorp/vault/command/agent/sink/inmem"
|
||||||
gatedwriter "github.com/hashicorp/vault/helper/gated-writer"
|
gatedwriter "github.com/hashicorp/vault/helper/gated-writer"
|
||||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||||
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/sdk/version"
|
"github.com/hashicorp/vault/sdk/version"
|
||||||
"github.com/kr/pretty"
|
"github.com/kr/pretty"
|
||||||
"github.com/mitchellh/cli"
|
"github.com/mitchellh/cli"
|
||||||
@@ -192,7 +194,7 @@ func (c *AgentCommand) Run(args []string) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load the configuration
|
// Load the configuration
|
||||||
config, err := config.LoadConfig(c.flagConfigs[0])
|
config, err := agentConfig.LoadConfig(c.flagConfigs[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
|
c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
|
||||||
return 1
|
return 1
|
||||||
@@ -418,11 +420,8 @@ func (c *AgentCommand) Run(args []string) int {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a muxer and add paths relevant for the lease cache layer
|
// Create the request handler
|
||||||
mux := http.NewServeMux()
|
cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink)
|
||||||
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
|
|
||||||
|
|
||||||
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink))
|
|
||||||
|
|
||||||
var listeners []net.Listener
|
var listeners []net.Listener
|
||||||
for i, lnConfig := range config.Listeners {
|
for i, lnConfig := range config.Listeners {
|
||||||
@@ -434,6 +433,25 @@ func (c *AgentCommand) Run(args []string) int {
|
|||||||
|
|
||||||
listeners = append(listeners, ln)
|
listeners = append(listeners, ln)
|
||||||
|
|
||||||
|
// Parse 'require_request_header' listener config option, and wrap
|
||||||
|
// the request handler if necessary
|
||||||
|
muxHandler := cacheHandler
|
||||||
|
if v, ok := lnConfig.Config[agentConfig.RequireRequestHeader]; ok {
|
||||||
|
switch v {
|
||||||
|
case true:
|
||||||
|
muxHandler = verifyRequestHeader(muxHandler)
|
||||||
|
case false /* noop */ :
|
||||||
|
default:
|
||||||
|
c.UI.Error(fmt.Sprintf("Invalid value for 'require_request_header': %v", v))
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a muxer and add paths relevant for the lease cache layer
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
|
||||||
|
mux.Handle("/", muxHandler)
|
||||||
|
|
||||||
scheme := "https://"
|
scheme := "https://"
|
||||||
if tlsConf == nil {
|
if tlsConf == nil {
|
||||||
scheme = "http://"
|
scheme = "http://"
|
||||||
@@ -536,6 +554,22 @@ func (c *AgentCommand) Run(args []string) int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// verifyRequestHeader wraps an http.Handler inside a Handler that checks for
|
||||||
|
// the request header that is used for SSRF protection.
|
||||||
|
func verifyRequestHeader(handler http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" {
|
||||||
|
logical.RespondError(w,
|
||||||
|
http.StatusPreconditionFailed,
|
||||||
|
errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
|
func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
|
||||||
var isFlagSet bool
|
var isFlagSet bool
|
||||||
f.Visit(func(f *flag.Flag) {
|
f.Visit(func(f *flag.Flag) {
|
||||||
|
|||||||
@@ -45,6 +45,9 @@ type Listener struct {
|
|||||||
Config map[string]interface{}
|
Config map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequireRequestHeader is a listener configuration option
|
||||||
|
const RequireRequestHeader = "require_request_header"
|
||||||
|
|
||||||
type AutoAuth struct {
|
type AutoAuth struct {
|
||||||
Method *Method `hcl:"-"`
|
Method *Method `hcl:"-"`
|
||||||
Sinks []*Sink `hcl:"sinks"`
|
Sinks []*Sink `hcl:"sinks"`
|
||||||
|
|||||||
@@ -1,16 +1,23 @@
|
|||||||
package command
|
package command
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
hclog "github.com/hashicorp/go-hclog"
|
hclog "github.com/hashicorp/go-hclog"
|
||||||
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
|
vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
|
||||||
"github.com/hashicorp/vault/api"
|
"github.com/hashicorp/vault/api"
|
||||||
|
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
|
||||||
"github.com/hashicorp/vault/command/agent"
|
"github.com/hashicorp/vault/command/agent"
|
||||||
vaulthttp "github.com/hashicorp/vault/http"
|
vaulthttp "github.com/hashicorp/vault/http"
|
||||||
|
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||||
"github.com/hashicorp/vault/sdk/helper/logging"
|
"github.com/hashicorp/vault/sdk/helper/logging"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
@@ -370,3 +377,246 @@ auto_auth {
|
|||||||
t.Fatal("sink 1/2 values don't match")
|
t.Fatal("sink 1/2 values don't match")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAgent_RequireRequestHeader(t *testing.T) {
|
||||||
|
|
||||||
|
// request issues HTTP requests.
|
||||||
|
request := func(client *api.Client, req *api.Request, expectedStatusCode int) map[string]interface{} {
|
||||||
|
resp, err := client.RawRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != expectedStatusCode {
|
||||||
|
t.Fatalf("expected status code %d, not %d", expectedStatusCode, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
if len(bytes) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var body map[string]interface{}
|
||||||
|
err = json.Unmarshal(bytes, &body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeTempFile creates a temp file and populates it.
|
||||||
|
makeTempFile := func(name, contents string) string {
|
||||||
|
f, err := ioutil.TempFile("", name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
path := f.Name()
|
||||||
|
f.WriteString(contents)
|
||||||
|
f.Close()
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
// newApiClient creates an *api.Client.
|
||||||
|
newApiClient := func(addr string, includeVaultRequestHeader bool) *api.Client {
|
||||||
|
conf := api.DefaultConfig()
|
||||||
|
conf.Address = addr
|
||||||
|
cli, err := api.NewClient(conf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := cli.Headers()
|
||||||
|
val, ok := h[consts.RequestHeaderName]
|
||||||
|
if !ok || !reflect.DeepEqual(val, []string{"true"}) {
|
||||||
|
t.Fatalf("invalid %s header", consts.RequestHeaderName)
|
||||||
|
}
|
||||||
|
if !includeVaultRequestHeader {
|
||||||
|
delete(h, consts.RequestHeaderName)
|
||||||
|
cli.SetHeaders(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cli
|
||||||
|
}
|
||||||
|
|
||||||
|
//----------------------------------------------------
|
||||||
|
// Start the server and agent
|
||||||
|
//----------------------------------------------------
|
||||||
|
|
||||||
|
// Start a vault server
|
||||||
|
logger := logging.NewVaultLogger(hclog.Trace)
|
||||||
|
cluster := vault.NewTestCluster(t,
|
||||||
|
&vault.CoreConfig{
|
||||||
|
Logger: logger,
|
||||||
|
CredentialBackends: map[string]logical.Factory{
|
||||||
|
"approle": credAppRole.Factory,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&vault.TestClusterOptions{
|
||||||
|
HandlerFunc: vaulthttp.Handler,
|
||||||
|
})
|
||||||
|
cluster.Start()
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
vault.TestWaitActive(t, cluster.Cores[0].Core)
|
||||||
|
serverClient := cluster.Cores[0].Client
|
||||||
|
|
||||||
|
// Enable the approle auth method
|
||||||
|
req := serverClient.NewRequest("POST", "/v1/sys/auth/approle")
|
||||||
|
req.BodyBytes = []byte(`{
|
||||||
|
"type": "approle"
|
||||||
|
}`)
|
||||||
|
request(serverClient, req, 204)
|
||||||
|
|
||||||
|
// Create a named role
|
||||||
|
req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role")
|
||||||
|
req.BodyBytes = []byte(`{
|
||||||
|
"secret_id_num_uses": "10",
|
||||||
|
"secret_id_ttl": "1m",
|
||||||
|
"token_max_ttl": "1m",
|
||||||
|
"token_num_uses": "10",
|
||||||
|
"token_ttl": "1m"
|
||||||
|
}`)
|
||||||
|
request(serverClient, req, 204)
|
||||||
|
|
||||||
|
// Fetch the RoleID of the named role
|
||||||
|
req = serverClient.NewRequest("GET", "/v1/auth/approle/role/test-role/role-id")
|
||||||
|
body := request(serverClient, req, 200)
|
||||||
|
data := body["data"].(map[string]interface{})
|
||||||
|
roleID := data["role_id"].(string)
|
||||||
|
|
||||||
|
// Get a SecretID issued against the named role
|
||||||
|
req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role/secret-id")
|
||||||
|
body = request(serverClient, req, 200)
|
||||||
|
data = body["data"].(map[string]interface{})
|
||||||
|
secretID := data["secret_id"].(string)
|
||||||
|
|
||||||
|
// Write the RoleID and SecretID to temp files
|
||||||
|
roleIDPath := makeTempFile("role_id.txt", roleID+"\n")
|
||||||
|
secretIDPath := makeTempFile("secret_id.txt", secretID+"\n")
|
||||||
|
defer os.Remove(roleIDPath)
|
||||||
|
defer os.Remove(secretIDPath)
|
||||||
|
|
||||||
|
// Get a temp file path we can use for the sink
|
||||||
|
sinkPath := makeTempFile("sink.txt", "")
|
||||||
|
defer os.Remove(sinkPath)
|
||||||
|
|
||||||
|
// Create a config file
|
||||||
|
config := `
|
||||||
|
auto_auth {
|
||||||
|
method "approle" {
|
||||||
|
mount_path = "auth/approle"
|
||||||
|
config = {
|
||||||
|
role_id_file_path = "%s"
|
||||||
|
secret_id_file_path = "%s"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sink "file" {
|
||||||
|
config = {
|
||||||
|
path = "%s"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cache {
|
||||||
|
use_auto_auth_token = true
|
||||||
|
}
|
||||||
|
|
||||||
|
listener "tcp" {
|
||||||
|
address = "127.0.0.1:8101"
|
||||||
|
tls_disable = true
|
||||||
|
}
|
||||||
|
listener "tcp" {
|
||||||
|
address = "127.0.0.1:8102"
|
||||||
|
tls_disable = true
|
||||||
|
require_request_header = false
|
||||||
|
}
|
||||||
|
listener "tcp" {
|
||||||
|
address = "127.0.0.1:8103"
|
||||||
|
tls_disable = true
|
||||||
|
require_request_header = true
|
||||||
|
}
|
||||||
|
`
|
||||||
|
config = fmt.Sprintf(config, roleIDPath, secretIDPath, sinkPath)
|
||||||
|
configPath := makeTempFile("config.hcl", config)
|
||||||
|
defer os.Remove(configPath)
|
||||||
|
|
||||||
|
// Start the agent
|
||||||
|
ui, cmd := testAgentCommand(t, logger)
|
||||||
|
cmd.client = serverClient
|
||||||
|
cmd.startedCh = make(chan struct{})
|
||||||
|
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
code := cmd.Run([]string{"-config", configPath})
|
||||||
|
if code != 0 {
|
||||||
|
t.Errorf("non-zero return code when running agent: %d", code)
|
||||||
|
t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String())
|
||||||
|
t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String())
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-cmd.startedCh:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// defer agent shutdown
|
||||||
|
defer func() {
|
||||||
|
cmd.ShutdownCh <- struct{}{}
|
||||||
|
wg.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
//----------------------------------------------------
|
||||||
|
// Perform the tests
|
||||||
|
//----------------------------------------------------
|
||||||
|
|
||||||
|
// Test against a listener configuration that omits
|
||||||
|
// 'require_request_header', with the header missing from the request.
|
||||||
|
agentClient := newApiClient("http://127.0.0.1:8101", false)
|
||||||
|
req = agentClient.NewRequest("GET", "/v1/sys/health")
|
||||||
|
request(agentClient, req, 200)
|
||||||
|
|
||||||
|
// Test against a listener configuration that sets 'require_request_header'
|
||||||
|
// to 'false', with the header missing from the request.
|
||||||
|
agentClient = newApiClient("http://127.0.0.1:8102", false)
|
||||||
|
req = agentClient.NewRequest("GET", "/v1/sys/health")
|
||||||
|
request(agentClient, req, 200)
|
||||||
|
|
||||||
|
// Test against a listener configuration that sets 'require_request_header'
|
||||||
|
// to 'true', with the header missing from the request.
|
||||||
|
agentClient = newApiClient("http://127.0.0.1:8103", false)
|
||||||
|
req = agentClient.NewRequest("GET", "/v1/sys/health")
|
||||||
|
resp, err := agentClient.RawRequest(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusPreconditionFailed {
|
||||||
|
t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test against a listener configuration that sets 'require_request_header'
|
||||||
|
// to 'true', with an invalid header present in the request.
|
||||||
|
agentClient = newApiClient("http://127.0.0.1:8103", false)
|
||||||
|
h := agentClient.Headers()
|
||||||
|
h[consts.RequestHeaderName] = []string{"bogus"}
|
||||||
|
agentClient.SetHeaders(h)
|
||||||
|
req = agentClient.NewRequest("GET", "/v1/sys/health")
|
||||||
|
resp, err = agentClient.RawRequest(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusPreconditionFailed {
|
||||||
|
t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test against a listener configuration that sets 'require_request_header'
|
||||||
|
// to 'true', with the proper header present in the request.
|
||||||
|
agentClient = newApiClient("http://127.0.0.1:8103", true)
|
||||||
|
req = agentClient.NewRequest("GET", "/v1/sys/health")
|
||||||
|
request(agentClient, req, 200)
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ const (
|
|||||||
// AuthHeaderName is the name of the header containing the token.
|
// AuthHeaderName is the name of the header containing the token.
|
||||||
AuthHeaderName = "X-Vault-Token"
|
AuthHeaderName = "X-Vault-Token"
|
||||||
|
|
||||||
|
// RequestHeaderName is the name of the header used by the Agent for
|
||||||
|
// SSRF protection.
|
||||||
|
RequestHeaderName = "X-Vault-Request"
|
||||||
|
|
||||||
// PerformanceReplicationALPN is the negotiated protocol used for
|
// PerformanceReplicationALPN is the negotiated protocol used for
|
||||||
// performance replication.
|
// performance replication.
|
||||||
PerformanceReplicationALPN = "replication_v1"
|
PerformanceReplicationALPN = "replication_v1"
|
||||||
|
|||||||
Reference in New Issue
Block a user