mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	Vault Agent Cache Auto-Auth SSRF Protection (#7627)
* implement SSRF protection header * add test for SSRF protection header * cleanup * refactor * implement SSRF header on a per-listener basis * cleanup * cleanup * creat unit test for agent SSRF * improve unit test for agent SSRF * add VaultRequest SSRF header to CLI * fix unit test * cleanup * improve test suite * simplify check for Vault-Request header * add constant for Vault-Request header * improve test suite * change 'config' to 'agentConfig' * Revert "change 'config' to 'agentConfig'" This reverts commit 14ee72d21fff8027966ee3c89dd3ac41d849206f. * do not remove header from request * change header name to X-Vault-Request * simplify http.Handler logic * cleanup * simplify http.Handler logic * use stdlib errors package
This commit is contained in:
		@@ -427,10 +427,14 @@ func NewClient(c *Config) (*Client, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := &Client{
 | 
			
		||||
		addr:   u,
 | 
			
		||||
		config: c,
 | 
			
		||||
		addr:    u,
 | 
			
		||||
		config:  c,
 | 
			
		||||
		headers: make(http.Header),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Add the VaultRequest SSRF protection header
 | 
			
		||||
	client.headers[consts.RequestHeaderName] = []string{"true"}
 | 
			
		||||
 | 
			
		||||
	if token := os.Getenv(EnvVaultToken); token != "" {
 | 
			
		||||
		client.token = token
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package command
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"flag"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -28,13 +29,14 @@ import (
 | 
			
		||||
	"github.com/hashicorp/vault/command/agent/auth/jwt"
 | 
			
		||||
	"github.com/hashicorp/vault/command/agent/auth/kubernetes"
 | 
			
		||||
	"github.com/hashicorp/vault/command/agent/cache"
 | 
			
		||||
	"github.com/hashicorp/vault/command/agent/config"
 | 
			
		||||
	agentConfig "github.com/hashicorp/vault/command/agent/config"
 | 
			
		||||
	"github.com/hashicorp/vault/command/agent/sink"
 | 
			
		||||
	"github.com/hashicorp/vault/command/agent/sink/file"
 | 
			
		||||
	"github.com/hashicorp/vault/command/agent/sink/inmem"
 | 
			
		||||
	gatedwriter "github.com/hashicorp/vault/helper/gated-writer"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/helper/consts"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/helper/logging"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/version"
 | 
			
		||||
	"github.com/kr/pretty"
 | 
			
		||||
	"github.com/mitchellh/cli"
 | 
			
		||||
@@ -192,7 +194,7 @@ func (c *AgentCommand) Run(args []string) int {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Load the configuration
 | 
			
		||||
	config, err := config.LoadConfig(c.flagConfigs[0])
 | 
			
		||||
	config, err := agentConfig.LoadConfig(c.flagConfigs[0])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.UI.Error(fmt.Sprintf("Error loading configuration from %s: %s", c.flagConfigs[0], err))
 | 
			
		||||
		return 1
 | 
			
		||||
@@ -418,11 +420,8 @@ func (c *AgentCommand) Run(args []string) int {
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Create a muxer and add paths relevant for the lease cache layer
 | 
			
		||||
		mux := http.NewServeMux()
 | 
			
		||||
		mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
 | 
			
		||||
 | 
			
		||||
		mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink))
 | 
			
		||||
		// Create the request handler
 | 
			
		||||
		cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink)
 | 
			
		||||
 | 
			
		||||
		var listeners []net.Listener
 | 
			
		||||
		for i, lnConfig := range config.Listeners {
 | 
			
		||||
@@ -434,6 +433,25 @@ func (c *AgentCommand) Run(args []string) int {
 | 
			
		||||
 | 
			
		||||
			listeners = append(listeners, ln)
 | 
			
		||||
 | 
			
		||||
			// Parse 'require_request_header' listener config option, and wrap
 | 
			
		||||
			// the request handler if necessary
 | 
			
		||||
			muxHandler := cacheHandler
 | 
			
		||||
			if v, ok := lnConfig.Config[agentConfig.RequireRequestHeader]; ok {
 | 
			
		||||
				switch v {
 | 
			
		||||
				case true:
 | 
			
		||||
					muxHandler = verifyRequestHeader(muxHandler)
 | 
			
		||||
				case false /* noop */ :
 | 
			
		||||
				default:
 | 
			
		||||
					c.UI.Error(fmt.Sprintf("Invalid value for 'require_request_header': %v", v))
 | 
			
		||||
					return 1
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Create a muxer and add paths relevant for the lease cache layer
 | 
			
		||||
			mux := http.NewServeMux()
 | 
			
		||||
			mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))
 | 
			
		||||
			mux.Handle("/", muxHandler)
 | 
			
		||||
 | 
			
		||||
			scheme := "https://"
 | 
			
		||||
			if tlsConf == nil {
 | 
			
		||||
				scheme = "http://"
 | 
			
		||||
@@ -536,6 +554,22 @@ func (c *AgentCommand) Run(args []string) int {
 | 
			
		||||
	return 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// verifyRequestHeader wraps an http.Handler inside a Handler that checks for
 | 
			
		||||
// the request header that is used for SSRF protection.
 | 
			
		||||
func verifyRequestHeader(handler http.Handler) http.Handler {
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
 | 
			
		||||
		if val, ok := r.Header[consts.RequestHeaderName]; !ok || len(val) != 1 || val[0] != "true" {
 | 
			
		||||
			logical.RespondError(w,
 | 
			
		||||
				http.StatusPreconditionFailed,
 | 
			
		||||
				errors.New(fmt.Sprintf("missing '%s' header", consts.RequestHeaderName)))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		handler.ServeHTTP(w, r)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
 | 
			
		||||
	var isFlagSet bool
 | 
			
		||||
	f.Visit(func(f *flag.Flag) {
 | 
			
		||||
 
 | 
			
		||||
@@ -45,6 +45,9 @@ type Listener struct {
 | 
			
		||||
	Config map[string]interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequireRequestHeader is a listener configuration option
 | 
			
		||||
const RequireRequestHeader = "require_request_header"
 | 
			
		||||
 | 
			
		||||
type AutoAuth struct {
 | 
			
		||||
	Method *Method `hcl:"-"`
 | 
			
		||||
	Sinks  []*Sink `hcl:"sinks"`
 | 
			
		||||
 
 | 
			
		||||
@@ -1,16 +1,23 @@
 | 
			
		||||
package command
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	hclog "github.com/hashicorp/go-hclog"
 | 
			
		||||
	vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
 | 
			
		||||
	"github.com/hashicorp/vault/api"
 | 
			
		||||
	credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
 | 
			
		||||
	"github.com/hashicorp/vault/command/agent"
 | 
			
		||||
	vaulthttp "github.com/hashicorp/vault/http"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/helper/consts"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/helper/logging"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
			
		||||
	"github.com/hashicorp/vault/vault"
 | 
			
		||||
@@ -370,3 +377,246 @@ auto_auth {
 | 
			
		||||
		t.Fatal("sink 1/2 values don't match")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAgent_RequireRequestHeader(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	// request issues HTTP requests.
 | 
			
		||||
	request := func(client *api.Client, req *api.Request, expectedStatusCode int) map[string]interface{} {
 | 
			
		||||
		resp, err := client.RawRequest(req)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("err: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
		if resp.StatusCode != expectedStatusCode {
 | 
			
		||||
			t.Fatalf("expected status code %d, not %d", expectedStatusCode, resp.StatusCode)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		bytes, err := ioutil.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("err: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
		if len(bytes) == 0 {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var body map[string]interface{}
 | 
			
		||||
		err = json.Unmarshal(bytes, &body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("err: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
		return body
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// makeTempFile creates a temp file and populates it.
 | 
			
		||||
	makeTempFile := func(name, contents string) string {
 | 
			
		||||
		f, err := ioutil.TempFile("", name)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		path := f.Name()
 | 
			
		||||
		f.WriteString(contents)
 | 
			
		||||
		f.Close()
 | 
			
		||||
		return path
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// newApiClient creates an *api.Client.
 | 
			
		||||
	newApiClient := func(addr string, includeVaultRequestHeader bool) *api.Client {
 | 
			
		||||
		conf := api.DefaultConfig()
 | 
			
		||||
		conf.Address = addr
 | 
			
		||||
		cli, err := api.NewClient(conf)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("err: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		h := cli.Headers()
 | 
			
		||||
		val, ok := h[consts.RequestHeaderName]
 | 
			
		||||
		if !ok || !reflect.DeepEqual(val, []string{"true"}) {
 | 
			
		||||
			t.Fatalf("invalid %s header", consts.RequestHeaderName)
 | 
			
		||||
		}
 | 
			
		||||
		if !includeVaultRequestHeader {
 | 
			
		||||
			delete(h, consts.RequestHeaderName)
 | 
			
		||||
			cli.SetHeaders(h)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return cli
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//----------------------------------------------------
 | 
			
		||||
	// Start the server and agent
 | 
			
		||||
	//----------------------------------------------------
 | 
			
		||||
 | 
			
		||||
	// Start a vault server
 | 
			
		||||
	logger := logging.NewVaultLogger(hclog.Trace)
 | 
			
		||||
	cluster := vault.NewTestCluster(t,
 | 
			
		||||
		&vault.CoreConfig{
 | 
			
		||||
			Logger: logger,
 | 
			
		||||
			CredentialBackends: map[string]logical.Factory{
 | 
			
		||||
				"approle": credAppRole.Factory,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		&vault.TestClusterOptions{
 | 
			
		||||
			HandlerFunc: vaulthttp.Handler,
 | 
			
		||||
		})
 | 
			
		||||
	cluster.Start()
 | 
			
		||||
	defer cluster.Cleanup()
 | 
			
		||||
	vault.TestWaitActive(t, cluster.Cores[0].Core)
 | 
			
		||||
	serverClient := cluster.Cores[0].Client
 | 
			
		||||
 | 
			
		||||
	// Enable the approle auth method
 | 
			
		||||
	req := serverClient.NewRequest("POST", "/v1/sys/auth/approle")
 | 
			
		||||
	req.BodyBytes = []byte(`{
 | 
			
		||||
		"type": "approle"
 | 
			
		||||
	}`)
 | 
			
		||||
	request(serverClient, req, 204)
 | 
			
		||||
 | 
			
		||||
	// Create a named role
 | 
			
		||||
	req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role")
 | 
			
		||||
	req.BodyBytes = []byte(`{
 | 
			
		||||
	  "secret_id_num_uses": "10",
 | 
			
		||||
	  "secret_id_ttl": "1m",
 | 
			
		||||
	  "token_max_ttl": "1m",
 | 
			
		||||
	  "token_num_uses": "10",
 | 
			
		||||
	  "token_ttl": "1m"
 | 
			
		||||
	}`)
 | 
			
		||||
	request(serverClient, req, 204)
 | 
			
		||||
 | 
			
		||||
	// Fetch the RoleID of the named role
 | 
			
		||||
	req = serverClient.NewRequest("GET", "/v1/auth/approle/role/test-role/role-id")
 | 
			
		||||
	body := request(serverClient, req, 200)
 | 
			
		||||
	data := body["data"].(map[string]interface{})
 | 
			
		||||
	roleID := data["role_id"].(string)
 | 
			
		||||
 | 
			
		||||
	// Get a SecretID issued against the named role
 | 
			
		||||
	req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role/secret-id")
 | 
			
		||||
	body = request(serverClient, req, 200)
 | 
			
		||||
	data = body["data"].(map[string]interface{})
 | 
			
		||||
	secretID := data["secret_id"].(string)
 | 
			
		||||
 | 
			
		||||
	// Write the RoleID and SecretID to temp files
 | 
			
		||||
	roleIDPath := makeTempFile("role_id.txt", roleID+"\n")
 | 
			
		||||
	secretIDPath := makeTempFile("secret_id.txt", secretID+"\n")
 | 
			
		||||
	defer os.Remove(roleIDPath)
 | 
			
		||||
	defer os.Remove(secretIDPath)
 | 
			
		||||
 | 
			
		||||
	// Get a temp file path we can use for the sink
 | 
			
		||||
	sinkPath := makeTempFile("sink.txt", "")
 | 
			
		||||
	defer os.Remove(sinkPath)
 | 
			
		||||
 | 
			
		||||
	// Create a config file
 | 
			
		||||
	config := `
 | 
			
		||||
auto_auth {
 | 
			
		||||
    method "approle" {
 | 
			
		||||
        mount_path = "auth/approle"
 | 
			
		||||
        config = {
 | 
			
		||||
            role_id_file_path = "%s"
 | 
			
		||||
            secret_id_file_path = "%s"
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    sink "file" {
 | 
			
		||||
        config = {
 | 
			
		||||
            path = "%s"
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
cache {
 | 
			
		||||
    use_auto_auth_token = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
listener "tcp" {
 | 
			
		||||
    address = "127.0.0.1:8101"
 | 
			
		||||
    tls_disable = true
 | 
			
		||||
}
 | 
			
		||||
listener "tcp" {
 | 
			
		||||
    address = "127.0.0.1:8102"
 | 
			
		||||
    tls_disable = true
 | 
			
		||||
    require_request_header = false
 | 
			
		||||
}
 | 
			
		||||
listener "tcp" {
 | 
			
		||||
    address = "127.0.0.1:8103"
 | 
			
		||||
    tls_disable = true
 | 
			
		||||
    require_request_header = true
 | 
			
		||||
}
 | 
			
		||||
`
 | 
			
		||||
	config = fmt.Sprintf(config, roleIDPath, secretIDPath, sinkPath)
 | 
			
		||||
	configPath := makeTempFile("config.hcl", config)
 | 
			
		||||
	defer os.Remove(configPath)
 | 
			
		||||
 | 
			
		||||
	// Start the agent
 | 
			
		||||
	ui, cmd := testAgentCommand(t, logger)
 | 
			
		||||
	cmd.client = serverClient
 | 
			
		||||
	cmd.startedCh = make(chan struct{})
 | 
			
		||||
 | 
			
		||||
	wg := &sync.WaitGroup{}
 | 
			
		||||
	wg.Add(1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		code := cmd.Run([]string{"-config", configPath})
 | 
			
		||||
		if code != 0 {
 | 
			
		||||
			t.Errorf("non-zero return code when running agent: %d", code)
 | 
			
		||||
			t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String())
 | 
			
		||||
			t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String())
 | 
			
		||||
		}
 | 
			
		||||
		wg.Done()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-cmd.startedCh:
 | 
			
		||||
	case <-time.After(5 * time.Second):
 | 
			
		||||
		t.Errorf("timeout")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// defer agent shutdown
 | 
			
		||||
	defer func() {
 | 
			
		||||
		cmd.ShutdownCh <- struct{}{}
 | 
			
		||||
		wg.Wait()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	//----------------------------------------------------
 | 
			
		||||
	// Perform the tests
 | 
			
		||||
	//----------------------------------------------------
 | 
			
		||||
 | 
			
		||||
	// Test against a listener configuration that omits
 | 
			
		||||
	// 'require_request_header', with the header missing from the request.
 | 
			
		||||
	agentClient := newApiClient("http://127.0.0.1:8101", false)
 | 
			
		||||
	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | 
			
		||||
	request(agentClient, req, 200)
 | 
			
		||||
 | 
			
		||||
	// Test against a listener configuration that sets 'require_request_header'
 | 
			
		||||
	// to 'false', with the header missing from the request.
 | 
			
		||||
	agentClient = newApiClient("http://127.0.0.1:8102", false)
 | 
			
		||||
	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | 
			
		||||
	request(agentClient, req, 200)
 | 
			
		||||
 | 
			
		||||
	// Test against a listener configuration that sets 'require_request_header'
 | 
			
		||||
	// to 'true', with the header missing from the request.
 | 
			
		||||
	agentClient = newApiClient("http://127.0.0.1:8103", false)
 | 
			
		||||
	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | 
			
		||||
	resp, err := agentClient.RawRequest(req)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Fatalf("expected error")
 | 
			
		||||
	}
 | 
			
		||||
	if resp.StatusCode != http.StatusPreconditionFailed {
 | 
			
		||||
		t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test against a listener configuration that sets 'require_request_header'
 | 
			
		||||
	// to 'true', with an invalid header present in the request.
 | 
			
		||||
	agentClient = newApiClient("http://127.0.0.1:8103", false)
 | 
			
		||||
	h := agentClient.Headers()
 | 
			
		||||
	h[consts.RequestHeaderName] = []string{"bogus"}
 | 
			
		||||
	agentClient.SetHeaders(h)
 | 
			
		||||
	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | 
			
		||||
	resp, err = agentClient.RawRequest(req)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Fatalf("expected error")
 | 
			
		||||
	}
 | 
			
		||||
	if resp.StatusCode != http.StatusPreconditionFailed {
 | 
			
		||||
		t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Test against a listener configuration that sets 'require_request_header'
 | 
			
		||||
	// to 'true', with the proper header present in the request.
 | 
			
		||||
	agentClient = newApiClient("http://127.0.0.1:8103", true)
 | 
			
		||||
	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | 
			
		||||
	request(agentClient, req, 200)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -12,6 +12,10 @@ const (
 | 
			
		||||
	// AuthHeaderName is the name of the header containing the token.
 | 
			
		||||
	AuthHeaderName = "X-Vault-Token"
 | 
			
		||||
 | 
			
		||||
	// RequestHeaderName is the name of the header used by the Agent for
 | 
			
		||||
	// SSRF protection.
 | 
			
		||||
	RequestHeaderName = "X-Vault-Request"
 | 
			
		||||
 | 
			
		||||
	// PerformanceReplicationALPN is the negotiated protocol used for
 | 
			
		||||
	// performance replication.
 | 
			
		||||
	PerformanceReplicationALPN = "replication_v1"
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user