mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	 77ceb7dde0
			
		
	
	77ceb7dde0
	
	
	
		
			
			* implement SSRF protection header * add test for SSRF protection header * cleanup * refactor * implement SSRF header on a per-listener basis * cleanup * cleanup * creat unit test for agent SSRF * improve unit test for agent SSRF * add VaultRequest SSRF header to CLI * fix unit test * cleanup * improve test suite * simplify check for Vault-Request header * add constant for Vault-Request header * improve test suite * change 'config' to 'agentConfig' * Revert "change 'config' to 'agentConfig'" This reverts commit 14ee72d21fff8027966ee3c89dd3ac41d849206f. * do not remove header from request * change header name to X-Vault-Request * simplify http.Handler logic * cleanup * simplify http.Handler logic * use stdlib errors package
		
			
				
	
	
		
			623 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			623 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package command
 | |
| 
 | |
| import (
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"io/ioutil"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"reflect"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	hclog "github.com/hashicorp/go-hclog"
 | |
| 	vaultjwt "github.com/hashicorp/vault-plugin-auth-jwt"
 | |
| 	"github.com/hashicorp/vault/api"
 | |
| 	credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
 | |
| 	"github.com/hashicorp/vault/command/agent"
 | |
| 	vaulthttp "github.com/hashicorp/vault/http"
 | |
| 	"github.com/hashicorp/vault/sdk/helper/consts"
 | |
| 	"github.com/hashicorp/vault/sdk/helper/logging"
 | |
| 	"github.com/hashicorp/vault/sdk/logical"
 | |
| 	"github.com/hashicorp/vault/vault"
 | |
| 	"github.com/mitchellh/cli"
 | |
| )
 | |
| 
 | |
| func testAgentCommand(tb testing.TB, logger hclog.Logger) (*cli.MockUi, *AgentCommand) {
 | |
| 	tb.Helper()
 | |
| 
 | |
| 	ui := cli.NewMockUi()
 | |
| 	return ui, &AgentCommand{
 | |
| 		BaseCommand: &BaseCommand{
 | |
| 			UI: ui,
 | |
| 		},
 | |
| 		ShutdownCh: MakeShutdownCh(),
 | |
| 		logger:     logger,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| /*
 | |
| func TestAgent_Cache_UnixListener(t *testing.T) {
 | |
| 	logger := logging.NewVaultLogger(hclog.Trace)
 | |
| 	coreConfig := &vault.CoreConfig{
 | |
| 		Logger: logger.Named("core"),
 | |
| 		CredentialBackends: map[string]logical.Factory{
 | |
| 			"jwt": vaultjwt.Factory,
 | |
| 		},
 | |
| 	}
 | |
| 	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
 | |
| 		HandlerFunc: vaulthttp.Handler,
 | |
| 	})
 | |
| 	cluster.Start()
 | |
| 	defer cluster.Cleanup()
 | |
| 
 | |
| 	vault.TestWaitActive(t, cluster.Cores[0].Core)
 | |
| 	client := cluster.Cores[0].Client
 | |
| 
 | |
| 	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
 | |
| 	os.Setenv(api.EnvVaultAddress, client.Address())
 | |
| 
 | |
| 	defer os.Setenv(api.EnvVaultCACert, os.Getenv(api.EnvVaultCACert))
 | |
| 	os.Setenv(api.EnvVaultCACert, fmt.Sprintf("%s/ca_cert.pem", cluster.TempDir))
 | |
| 
 | |
| 	// Setup Vault
 | |
| 	err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
 | |
| 		Type: "jwt",
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
 | |
| 		"bound_issuer":           "https://team-vault.auth0.com/",
 | |
| 		"jwt_validation_pubkeys": agent.TestECDSAPubKey,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
 | |
| 		"role_type":       "jwt",
 | |
| 		"bound_subject":   "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
 | |
| 		"bound_audiences": "https://vault.plugin.auth.jwt.test",
 | |
| 		"user_claim":      "https://vault/user",
 | |
| 		"groups_claim":    "https://vault/groups",
 | |
| 		"policies":        "test",
 | |
| 		"period":          "3s",
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	inf, err := ioutil.TempFile("", "auth.jwt.test.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	in := inf.Name()
 | |
| 	inf.Close()
 | |
| 	os.Remove(in)
 | |
| 	t.Logf("input: %s", in)
 | |
| 
 | |
| 	sink1f, err := ioutil.TempFile("", "sink1.jwt.test.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	sink1 := sink1f.Name()
 | |
| 	sink1f.Close()
 | |
| 	os.Remove(sink1)
 | |
| 	t.Logf("sink1: %s", sink1)
 | |
| 
 | |
| 	sink2f, err := ioutil.TempFile("", "sink2.jwt.test.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	sink2 := sink2f.Name()
 | |
| 	sink2f.Close()
 | |
| 	os.Remove(sink2)
 | |
| 	t.Logf("sink2: %s", sink2)
 | |
| 
 | |
| 	conff, err := ioutil.TempFile("", "conf.jwt.test.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	conf := conff.Name()
 | |
| 	conff.Close()
 | |
| 	os.Remove(conf)
 | |
| 	t.Logf("config: %s", conf)
 | |
| 
 | |
| 	jwtToken, _ := agent.GetTestJWT(t)
 | |
| 	if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else {
 | |
| 		logger.Trace("wrote test jwt", "path", in)
 | |
| 	}
 | |
| 
 | |
| 	socketff, err := ioutil.TempFile("", "cache.socket.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	socketf := socketff.Name()
 | |
| 	socketff.Close()
 | |
| 	os.Remove(socketf)
 | |
| 	t.Logf("socketf: %s", socketf)
 | |
| 
 | |
| 	config := `
 | |
| auto_auth {
 | |
|         method {
 | |
|                 type = "jwt"
 | |
|                 config = {
 | |
|                         role = "test"
 | |
|                         path = "%s"
 | |
|                 }
 | |
|         }
 | |
| 
 | |
|         sink {
 | |
|                 type = "file"
 | |
|                 config = {
 | |
|                         path = "%s"
 | |
|                 }
 | |
|         }
 | |
| 
 | |
|         sink "file" {
 | |
|                 config = {
 | |
|                         path = "%s"
 | |
|                 }
 | |
|         }
 | |
| }
 | |
| 
 | |
| cache {
 | |
| 	use_auto_auth_token = true
 | |
| 
 | |
| 	listener "unix" {
 | |
| 		address = "%s"
 | |
| 		tls_disable = true
 | |
| 	}
 | |
| }
 | |
| `
 | |
| 
 | |
| 	config = fmt.Sprintf(config, in, sink1, sink2, socketf)
 | |
| 	if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else {
 | |
| 		logger.Trace("wrote test config", "path", conf)
 | |
| 	}
 | |
| 
 | |
| 	_, cmd := testAgentCommand(t, logger)
 | |
| 	cmd.client = client
 | |
| 
 | |
| 	// Kill the command 5 seconds after it starts
 | |
| 	go func() {
 | |
| 		select {
 | |
| 		case <-cmd.ShutdownCh:
 | |
| 		case <-time.After(5 * time.Second):
 | |
| 			cmd.ShutdownCh <- struct{}{}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddr)
 | |
| 
 | |
| 	// Create a client that talks to the agent
 | |
| 	os.Setenv(api.EnvVaultAgentAddr, socketf)
 | |
| 	testClient, err := api.NewClient(api.DefaultConfig())
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	os.Setenv(api.EnvVaultAgentAddr, originalVaultAgentAddress)
 | |
| 
 | |
| 	// Start the agent
 | |
| 	go cmd.Run([]string{"-config", conf})
 | |
| 
 | |
| 	// Give some time for the auto-auth to complete
 | |
| 	time.Sleep(1 * time.Second)
 | |
| 
 | |
| 	// Invoke lookup self through the agent
 | |
| 	secret, err := testClient.Auth().Token().LookupSelf()
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	if secret == nil || secret.Data == nil || secret.Data["id"].(string) == "" {
 | |
| 		t.Fatalf("failed to perform lookup self through agent")
 | |
| 	}
 | |
| }
 | |
| */
 | |
| 
 | |
| func TestExitAfterAuth(t *testing.T) {
 | |
| 	logger := logging.NewVaultLogger(hclog.Trace)
 | |
| 	coreConfig := &vault.CoreConfig{
 | |
| 		Logger: logger,
 | |
| 		CredentialBackends: map[string]logical.Factory{
 | |
| 			"jwt": vaultjwt.Factory,
 | |
| 		},
 | |
| 	}
 | |
| 	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
 | |
| 		HandlerFunc: vaulthttp.Handler,
 | |
| 	})
 | |
| 	cluster.Start()
 | |
| 	defer cluster.Cleanup()
 | |
| 
 | |
| 	vault.TestWaitActive(t, cluster.Cores[0].Core)
 | |
| 	client := cluster.Cores[0].Client
 | |
| 
 | |
| 	// Setup Vault
 | |
| 	err := client.Sys().EnableAuthWithOptions("jwt", &api.EnableAuthOptions{
 | |
| 		Type: "jwt",
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
 | |
| 		"bound_issuer":           "https://team-vault.auth0.com/",
 | |
| 		"jwt_validation_pubkeys": agent.TestECDSAPubKey,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	_, err = client.Logical().Write("auth/jwt/role/test", map[string]interface{}{
 | |
| 		"role_type":       "jwt",
 | |
| 		"bound_subject":   "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
 | |
| 		"bound_audiences": "https://vault.plugin.auth.jwt.test",
 | |
| 		"user_claim":      "https://vault/user",
 | |
| 		"groups_claim":    "https://vault/groups",
 | |
| 		"policies":        "test",
 | |
| 		"period":          "3s",
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 
 | |
| 	inf, err := ioutil.TempFile("", "auth.jwt.test.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	in := inf.Name()
 | |
| 	inf.Close()
 | |
| 	os.Remove(in)
 | |
| 	t.Logf("input: %s", in)
 | |
| 
 | |
| 	sink1f, err := ioutil.TempFile("", "sink1.jwt.test.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	sink1 := sink1f.Name()
 | |
| 	sink1f.Close()
 | |
| 	os.Remove(sink1)
 | |
| 	t.Logf("sink1: %s", sink1)
 | |
| 
 | |
| 	sink2f, err := ioutil.TempFile("", "sink2.jwt.test.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	sink2 := sink2f.Name()
 | |
| 	sink2f.Close()
 | |
| 	os.Remove(sink2)
 | |
| 	t.Logf("sink2: %s", sink2)
 | |
| 
 | |
| 	conff, err := ioutil.TempFile("", "conf.jwt.test.")
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	conf := conff.Name()
 | |
| 	conff.Close()
 | |
| 	os.Remove(conf)
 | |
| 	t.Logf("config: %s", conf)
 | |
| 
 | |
| 	jwtToken, _ := agent.GetTestJWT(t)
 | |
| 	if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else {
 | |
| 		logger.Trace("wrote test jwt", "path", in)
 | |
| 	}
 | |
| 
 | |
| 	config := `
 | |
| exit_after_auth = true
 | |
| 
 | |
| auto_auth {
 | |
|         method {
 | |
|                 type = "jwt"
 | |
|                 config = {
 | |
|                         role = "test"
 | |
|                         path = "%s"
 | |
|                 }
 | |
|         }
 | |
| 
 | |
|         sink {
 | |
|                 type = "file"
 | |
|                 config = {
 | |
|                         path = "%s"
 | |
|                 }
 | |
|         }
 | |
| 
 | |
|         sink "file" {
 | |
|                 config = {
 | |
|                         path = "%s"
 | |
|                 }
 | |
|         }
 | |
| }
 | |
| `
 | |
| 
 | |
| 	config = fmt.Sprintf(config, in, sink1, sink2)
 | |
| 	if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	} else {
 | |
| 		logger.Trace("wrote test config", "path", conf)
 | |
| 	}
 | |
| 
 | |
| 	// If this hangs forever until the test times out, exit-after-auth isn't
 | |
| 	// working
 | |
| 	ui, cmd := testAgentCommand(t, logger)
 | |
| 	cmd.client = client
 | |
| 
 | |
| 	code := cmd.Run([]string{"-config", conf})
 | |
| 	if code != 0 {
 | |
| 		t.Errorf("expected %d to be %d", code, 0)
 | |
| 		t.Logf("output from agent:\n%s", ui.OutputWriter.String())
 | |
| 		t.Logf("error from agent:\n%s", ui.ErrorWriter.String())
 | |
| 	}
 | |
| 
 | |
| 	sink1Bytes, err := ioutil.ReadFile(sink1)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	if len(sink1Bytes) == 0 {
 | |
| 		t.Fatal("got no output from sink 1")
 | |
| 	}
 | |
| 
 | |
| 	sink2Bytes, err := ioutil.ReadFile(sink2)
 | |
| 	if err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| 	if len(sink2Bytes) == 0 {
 | |
| 		t.Fatal("got no output from sink 2")
 | |
| 	}
 | |
| 
 | |
| 	if string(sink1Bytes) != string(sink2Bytes) {
 | |
| 		t.Fatal("sink 1/2 values don't match")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestAgent_RequireRequestHeader(t *testing.T) {
 | |
| 
 | |
| 	// request issues HTTP requests.
 | |
| 	request := func(client *api.Client, req *api.Request, expectedStatusCode int) map[string]interface{} {
 | |
| 		resp, err := client.RawRequest(req)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("err: %s", err)
 | |
| 		}
 | |
| 		if resp.StatusCode != expectedStatusCode {
 | |
| 			t.Fatalf("expected status code %d, not %d", expectedStatusCode, resp.StatusCode)
 | |
| 		}
 | |
| 
 | |
| 		bytes, err := ioutil.ReadAll(resp.Body)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("err: %s", err)
 | |
| 		}
 | |
| 		if len(bytes) == 0 {
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		var body map[string]interface{}
 | |
| 		err = json.Unmarshal(bytes, &body)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("err: %s", err)
 | |
| 		}
 | |
| 		return body
 | |
| 	}
 | |
| 
 | |
| 	// makeTempFile creates a temp file and populates it.
 | |
| 	makeTempFile := func(name, contents string) string {
 | |
| 		f, err := ioutil.TempFile("", name)
 | |
| 		if err != nil {
 | |
| 			t.Fatal(err)
 | |
| 		}
 | |
| 		path := f.Name()
 | |
| 		f.WriteString(contents)
 | |
| 		f.Close()
 | |
| 		return path
 | |
| 	}
 | |
| 
 | |
| 	// newApiClient creates an *api.Client.
 | |
| 	newApiClient := func(addr string, includeVaultRequestHeader bool) *api.Client {
 | |
| 		conf := api.DefaultConfig()
 | |
| 		conf.Address = addr
 | |
| 		cli, err := api.NewClient(conf)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("err: %s", err)
 | |
| 		}
 | |
| 
 | |
| 		h := cli.Headers()
 | |
| 		val, ok := h[consts.RequestHeaderName]
 | |
| 		if !ok || !reflect.DeepEqual(val, []string{"true"}) {
 | |
| 			t.Fatalf("invalid %s header", consts.RequestHeaderName)
 | |
| 		}
 | |
| 		if !includeVaultRequestHeader {
 | |
| 			delete(h, consts.RequestHeaderName)
 | |
| 			cli.SetHeaders(h)
 | |
| 		}
 | |
| 
 | |
| 		return cli
 | |
| 	}
 | |
| 
 | |
| 	//----------------------------------------------------
 | |
| 	// Start the server and agent
 | |
| 	//----------------------------------------------------
 | |
| 
 | |
| 	// Start a vault server
 | |
| 	logger := logging.NewVaultLogger(hclog.Trace)
 | |
| 	cluster := vault.NewTestCluster(t,
 | |
| 		&vault.CoreConfig{
 | |
| 			Logger: logger,
 | |
| 			CredentialBackends: map[string]logical.Factory{
 | |
| 				"approle": credAppRole.Factory,
 | |
| 			},
 | |
| 		},
 | |
| 		&vault.TestClusterOptions{
 | |
| 			HandlerFunc: vaulthttp.Handler,
 | |
| 		})
 | |
| 	cluster.Start()
 | |
| 	defer cluster.Cleanup()
 | |
| 	vault.TestWaitActive(t, cluster.Cores[0].Core)
 | |
| 	serverClient := cluster.Cores[0].Client
 | |
| 
 | |
| 	// Enable the approle auth method
 | |
| 	req := serverClient.NewRequest("POST", "/v1/sys/auth/approle")
 | |
| 	req.BodyBytes = []byte(`{
 | |
| 		"type": "approle"
 | |
| 	}`)
 | |
| 	request(serverClient, req, 204)
 | |
| 
 | |
| 	// Create a named role
 | |
| 	req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role")
 | |
| 	req.BodyBytes = []byte(`{
 | |
| 	  "secret_id_num_uses": "10",
 | |
| 	  "secret_id_ttl": "1m",
 | |
| 	  "token_max_ttl": "1m",
 | |
| 	  "token_num_uses": "10",
 | |
| 	  "token_ttl": "1m"
 | |
| 	}`)
 | |
| 	request(serverClient, req, 204)
 | |
| 
 | |
| 	// Fetch the RoleID of the named role
 | |
| 	req = serverClient.NewRequest("GET", "/v1/auth/approle/role/test-role/role-id")
 | |
| 	body := request(serverClient, req, 200)
 | |
| 	data := body["data"].(map[string]interface{})
 | |
| 	roleID := data["role_id"].(string)
 | |
| 
 | |
| 	// Get a SecretID issued against the named role
 | |
| 	req = serverClient.NewRequest("PUT", "/v1/auth/approle/role/test-role/secret-id")
 | |
| 	body = request(serverClient, req, 200)
 | |
| 	data = body["data"].(map[string]interface{})
 | |
| 	secretID := data["secret_id"].(string)
 | |
| 
 | |
| 	// Write the RoleID and SecretID to temp files
 | |
| 	roleIDPath := makeTempFile("role_id.txt", roleID+"\n")
 | |
| 	secretIDPath := makeTempFile("secret_id.txt", secretID+"\n")
 | |
| 	defer os.Remove(roleIDPath)
 | |
| 	defer os.Remove(secretIDPath)
 | |
| 
 | |
| 	// Get a temp file path we can use for the sink
 | |
| 	sinkPath := makeTempFile("sink.txt", "")
 | |
| 	defer os.Remove(sinkPath)
 | |
| 
 | |
| 	// Create a config file
 | |
| 	config := `
 | |
| auto_auth {
 | |
|     method "approle" {
 | |
|         mount_path = "auth/approle"
 | |
|         config = {
 | |
|             role_id_file_path = "%s"
 | |
|             secret_id_file_path = "%s"
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     sink "file" {
 | |
|         config = {
 | |
|             path = "%s"
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| cache {
 | |
|     use_auto_auth_token = true
 | |
| }
 | |
| 
 | |
| listener "tcp" {
 | |
|     address = "127.0.0.1:8101"
 | |
|     tls_disable = true
 | |
| }
 | |
| listener "tcp" {
 | |
|     address = "127.0.0.1:8102"
 | |
|     tls_disable = true
 | |
|     require_request_header = false
 | |
| }
 | |
| listener "tcp" {
 | |
|     address = "127.0.0.1:8103"
 | |
|     tls_disable = true
 | |
|     require_request_header = true
 | |
| }
 | |
| `
 | |
| 	config = fmt.Sprintf(config, roleIDPath, secretIDPath, sinkPath)
 | |
| 	configPath := makeTempFile("config.hcl", config)
 | |
| 	defer os.Remove(configPath)
 | |
| 
 | |
| 	// Start the agent
 | |
| 	ui, cmd := testAgentCommand(t, logger)
 | |
| 	cmd.client = serverClient
 | |
| 	cmd.startedCh = make(chan struct{})
 | |
| 
 | |
| 	wg := &sync.WaitGroup{}
 | |
| 	wg.Add(1)
 | |
| 	go func() {
 | |
| 		code := cmd.Run([]string{"-config", configPath})
 | |
| 		if code != 0 {
 | |
| 			t.Errorf("non-zero return code when running agent: %d", code)
 | |
| 			t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String())
 | |
| 			t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String())
 | |
| 		}
 | |
| 		wg.Done()
 | |
| 	}()
 | |
| 
 | |
| 	select {
 | |
| 	case <-cmd.startedCh:
 | |
| 	case <-time.After(5 * time.Second):
 | |
| 		t.Errorf("timeout")
 | |
| 	}
 | |
| 
 | |
| 	// defer agent shutdown
 | |
| 	defer func() {
 | |
| 		cmd.ShutdownCh <- struct{}{}
 | |
| 		wg.Wait()
 | |
| 	}()
 | |
| 
 | |
| 	//----------------------------------------------------
 | |
| 	// Perform the tests
 | |
| 	//----------------------------------------------------
 | |
| 
 | |
| 	// Test against a listener configuration that omits
 | |
| 	// 'require_request_header', with the header missing from the request.
 | |
| 	agentClient := newApiClient("http://127.0.0.1:8101", false)
 | |
| 	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | |
| 	request(agentClient, req, 200)
 | |
| 
 | |
| 	// Test against a listener configuration that sets 'require_request_header'
 | |
| 	// to 'false', with the header missing from the request.
 | |
| 	agentClient = newApiClient("http://127.0.0.1:8102", false)
 | |
| 	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | |
| 	request(agentClient, req, 200)
 | |
| 
 | |
| 	// Test against a listener configuration that sets 'require_request_header'
 | |
| 	// to 'true', with the header missing from the request.
 | |
| 	agentClient = newApiClient("http://127.0.0.1:8103", false)
 | |
| 	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | |
| 	resp, err := agentClient.RawRequest(req)
 | |
| 	if err == nil {
 | |
| 		t.Fatalf("expected error")
 | |
| 	}
 | |
| 	if resp.StatusCode != http.StatusPreconditionFailed {
 | |
| 		t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode)
 | |
| 	}
 | |
| 
 | |
| 	// Test against a listener configuration that sets 'require_request_header'
 | |
| 	// to 'true', with an invalid header present in the request.
 | |
| 	agentClient = newApiClient("http://127.0.0.1:8103", false)
 | |
| 	h := agentClient.Headers()
 | |
| 	h[consts.RequestHeaderName] = []string{"bogus"}
 | |
| 	agentClient.SetHeaders(h)
 | |
| 	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | |
| 	resp, err = agentClient.RawRequest(req)
 | |
| 	if err == nil {
 | |
| 		t.Fatalf("expected error")
 | |
| 	}
 | |
| 	if resp.StatusCode != http.StatusPreconditionFailed {
 | |
| 		t.Fatalf("expected status code %d, not %d", http.StatusPreconditionFailed, resp.StatusCode)
 | |
| 	}
 | |
| 
 | |
| 	// Test against a listener configuration that sets 'require_request_header'
 | |
| 	// to 'true', with the proper header present in the request.
 | |
| 	agentClient = newApiClient("http://127.0.0.1:8103", true)
 | |
| 	req = agentClient.NewRequest("GET", "/v1/sys/health")
 | |
| 	request(agentClient, req, 200)
 | |
| }
 |