mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-04 04:28:08 +00:00 
			
		
		
		
	* Added flag and env var which will disable client redirection * Added changelog * Docs fix for unsaved file, and test single request made * Updated test for case when redirect is enabled, updated docs based on suggestions
		
			
				
	
	
		
			1436 lines
		
	
	
		
			35 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			1436 lines
		
	
	
		
			35 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package api
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"crypto/x509"
 | 
						|
	"encoding/base64"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"os"
 | 
						|
	"reflect"
 | 
						|
	"sort"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/go-test/deep"
 | 
						|
	"github.com/hashicorp/go-hclog"
 | 
						|
	"github.com/hashicorp/vault/sdk/helper/consts"
 | 
						|
)
 | 
						|
 | 
						|
func init() {
	// Ensure our special envvars are not present: blank out each of them so
	// the tests in this file do not pick up values from the caller's shell.
	for _, name := range []string{"VAULT_ADDR", "VAULT_TOKEN"} {
		os.Setenv(name, "")
	}
}
 | 
						|
 | 
						|
func TestDefaultConfig_envvar(t *testing.T) {
 | 
						|
	os.Setenv("VAULT_ADDR", "https://vault.mycompany.com")
 | 
						|
	defer os.Setenv("VAULT_ADDR", "")
 | 
						|
 | 
						|
	config := DefaultConfig()
 | 
						|
	if config.Address != "https://vault.mycompany.com" {
 | 
						|
		t.Fatalf("bad: %s", config.Address)
 | 
						|
	}
 | 
						|
 | 
						|
	os.Setenv("VAULT_TOKEN", "testing")
 | 
						|
	defer os.Setenv("VAULT_TOKEN", "")
 | 
						|
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if token := client.Token(); token != "testing" {
 | 
						|
		t.Fatalf("bad: %s", token)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientDefaultHttpClient(t *testing.T) {
 | 
						|
	_, err := NewClient(&Config{
 | 
						|
		HttpClient: http.DefaultClient,
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientNilConfig(t *testing.T) {
 | 
						|
	client, err := NewClient(nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if client == nil {
 | 
						|
		t.Fatal("expected a non-nil client")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientDefaultHttpClient_unixSocket(t *testing.T) {
 | 
						|
	os.Setenv("VAULT_AGENT_ADDR", "unix:///var/run/vault.sock")
 | 
						|
	defer os.Setenv("VAULT_AGENT_ADDR", "")
 | 
						|
 | 
						|
	client, err := NewClient(nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if client == nil {
 | 
						|
		t.Fatal("expected a non-nil client")
 | 
						|
	}
 | 
						|
	if client.addr.Scheme != "http" {
 | 
						|
		t.Fatalf("bad: %s", client.addr.Scheme)
 | 
						|
	}
 | 
						|
	if client.addr.Host != "/var/run/vault.sock" {
 | 
						|
		t.Fatalf("bad: %s", client.addr.Host)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientSetAddress(t *testing.T) {
 | 
						|
	client, err := NewClient(nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	// Start with TCP address using HTTP
 | 
						|
	if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if client.addr.Host != "172.168.2.1:8300" {
 | 
						|
		t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
 | 
						|
	}
 | 
						|
	// Test switching to Unix Socket address from TCP address
 | 
						|
	if err := client.SetAddress("unix:///var/run/vault.sock"); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if client.addr.Scheme != "http" {
 | 
						|
		t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme)
 | 
						|
	}
 | 
						|
	if client.addr.Host != "/var/run/vault.sock" {
 | 
						|
		t.Fatalf("bad: expected: '/var/run/vault.sock' actual: %q", client.addr.Host)
 | 
						|
	}
 | 
						|
	if client.addr.Path != "" {
 | 
						|
		t.Fatalf("bad: expected '' actual: %q", client.addr.Path)
 | 
						|
	}
 | 
						|
	if client.config.HttpClient.Transport.(*http.Transport).DialContext == nil {
 | 
						|
		t.Fatal("bad: expected DialContext to not be nil")
 | 
						|
	}
 | 
						|
	// Test switching to TCP address from Unix Socket address
 | 
						|
	if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if client.addr.Host != "172.168.2.1:8300" {
 | 
						|
		t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
 | 
						|
	}
 | 
						|
	if client.addr.Scheme != "http" {
 | 
						|
		t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientToken(t *testing.T) {
 | 
						|
	tokenValue := "foo"
 | 
						|
	handler := func(w http.ResponseWriter, req *http.Request) {}
 | 
						|
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	client.SetToken(tokenValue)
 | 
						|
 | 
						|
	// Verify the token is set
 | 
						|
	if v := client.Token(); v != tokenValue {
 | 
						|
		t.Fatalf("bad: %s", v)
 | 
						|
	}
 | 
						|
 | 
						|
	client.ClearToken()
 | 
						|
 | 
						|
	if v := client.Token(); v != "" {
 | 
						|
		t.Fatalf("bad: %s", v)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientHostHeader(t *testing.T) {
 | 
						|
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		w.Write([]byte(req.Host))
 | 
						|
	}
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	config.Address = strings.ReplaceAll(config.Address, "127.0.0.1", "localhost")
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Set the token manually
 | 
						|
	client.SetToken("foo")
 | 
						|
 | 
						|
	resp, err := client.RawRequest(client.NewRequest(http.MethodPut, "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Copy the response
 | 
						|
	var buf bytes.Buffer
 | 
						|
	io.Copy(&buf, resp.Body)
 | 
						|
 | 
						|
	// Verify we got the response from the primary
 | 
						|
	if buf.String() != strings.ReplaceAll(config.Address, "http://", "") {
 | 
						|
		t.Fatalf("Bad address: %s", buf.String())
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientBadToken(t *testing.T) {
 | 
						|
	handler := func(w http.ResponseWriter, req *http.Request) {}
 | 
						|
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	client.SetToken("foo")
 | 
						|
	_, err = client.RawRequest(client.NewRequest(http.MethodPut, "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	client.SetToken("foo\u007f")
 | 
						|
	_, err = client.RawRequest(client.NewRequest(http.MethodPut, "/"))
 | 
						|
	if err == nil || !strings.Contains(err.Error(), "printable") {
 | 
						|
		t.Fatalf("expected error due to bad token")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientDisableRedirects(t *testing.T) {
 | 
						|
	tests := map[string]struct {
 | 
						|
		statusCode       int
 | 
						|
		expectedNumReqs  int
 | 
						|
		disableRedirects bool
 | 
						|
	}{
 | 
						|
		"Disabled redirects: Moved permanently":  {statusCode: 301, expectedNumReqs: 1, disableRedirects: true},
 | 
						|
		"Disabled redirects: Found":              {statusCode: 302, expectedNumReqs: 1, disableRedirects: true},
 | 
						|
		"Disabled redirects: Temporary Redirect": {statusCode: 307, expectedNumReqs: 1, disableRedirects: true},
 | 
						|
		"Enable redirects: Moved permanently":    {statusCode: 301, expectedNumReqs: 2, disableRedirects: false},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, tc := range tests {
 | 
						|
		test := tc
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			t.Parallel()
 | 
						|
			numReqs := 0
 | 
						|
			var config *Config
 | 
						|
 | 
						|
			respFunc := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
				// Track how many requests the server has handled
 | 
						|
				numReqs++
 | 
						|
				// Send back the relevant status code and generate a location
 | 
						|
				w.Header().Set("Location", fmt.Sprintf(config.Address+"/reqs/%v", numReqs))
 | 
						|
				w.WriteHeader(test.statusCode)
 | 
						|
			}
 | 
						|
 | 
						|
			config, ln := testHTTPServer(t, http.HandlerFunc(respFunc))
 | 
						|
			config.DisableRedirects = test.disableRedirects
 | 
						|
			defer ln.Close()
 | 
						|
 | 
						|
			client, err := NewClient(config)
 | 
						|
			if err != nil {
 | 
						|
				t.Fatalf("%s: error %v", name, err)
 | 
						|
			}
 | 
						|
 | 
						|
			req := client.NewRequest("GET", "/")
 | 
						|
			resp, err := client.rawRequestWithContext(context.Background(), req)
 | 
						|
			if err != nil {
 | 
						|
				t.Fatalf("%s: error %v", name, err)
 | 
						|
			}
 | 
						|
 | 
						|
			if numReqs != test.expectedNumReqs {
 | 
						|
				t.Fatalf("%s: expected %v request(s) but got %v", name, test.expectedNumReqs, numReqs)
 | 
						|
			}
 | 
						|
 | 
						|
			if resp.StatusCode != test.statusCode {
 | 
						|
				t.Fatalf("%s: expected status code %v got %v", name, test.statusCode, resp.StatusCode)
 | 
						|
			}
 | 
						|
 | 
						|
			location, err := resp.Location()
 | 
						|
			if err != nil {
 | 
						|
				t.Fatalf("%s error %v", name, err)
 | 
						|
			}
 | 
						|
			if req.URL.String() == location.String() {
 | 
						|
				t.Fatalf("%s: expected request URL %v to be different from redirect URL %v", name, req.URL, resp.Request.URL)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientRedirect(t *testing.T) {
 | 
						|
	primary := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		w.Write([]byte("test"))
 | 
						|
	}
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(primary))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	standby := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		w.Header().Set("Location", config.Address)
 | 
						|
		w.WriteHeader(307)
 | 
						|
	}
 | 
						|
	config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby))
 | 
						|
	defer ln2.Close()
 | 
						|
 | 
						|
	client, err := NewClient(config2)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Set the token manually
 | 
						|
	client.SetToken("foo")
 | 
						|
 | 
						|
	// Do a raw "/" request
 | 
						|
	resp, err := client.RawRequest(client.NewRequest(http.MethodPut, "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Copy the response
 | 
						|
	var buf bytes.Buffer
 | 
						|
	io.Copy(&buf, resp.Body)
 | 
						|
 | 
						|
	// Verify we got the response from the primary
 | 
						|
	if buf.String() != "test" {
 | 
						|
		t.Fatalf("Bad: %s", buf.String())
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDefaulRetryPolicy(t *testing.T) {
 | 
						|
	cases := map[string]struct {
 | 
						|
		resp      *http.Response
 | 
						|
		err       error
 | 
						|
		expect    bool
 | 
						|
		expectErr error
 | 
						|
	}{
 | 
						|
		"retry on error": {
 | 
						|
			err:    fmt.Errorf("error"),
 | 
						|
			expect: true,
 | 
						|
		},
 | 
						|
		"don't retry connection failures": {
 | 
						|
			err: &url.Error{
 | 
						|
				Err: x509.UnknownAuthorityError{},
 | 
						|
			},
 | 
						|
		},
 | 
						|
		"don't retry on 200": {
 | 
						|
			resp: &http.Response{
 | 
						|
				StatusCode: http.StatusOK,
 | 
						|
			},
 | 
						|
		},
 | 
						|
		"don't retry on 4xx": {
 | 
						|
			resp: &http.Response{
 | 
						|
				StatusCode: http.StatusBadRequest,
 | 
						|
			},
 | 
						|
		},
 | 
						|
		"don't retry on 501": {
 | 
						|
			resp: &http.Response{
 | 
						|
				StatusCode: http.StatusNotImplemented,
 | 
						|
			},
 | 
						|
		},
 | 
						|
		"retry on 500": {
 | 
						|
			resp: &http.Response{
 | 
						|
				StatusCode: http.StatusInternalServerError,
 | 
						|
			},
 | 
						|
			expect: true,
 | 
						|
		},
 | 
						|
		"retry on 5xx": {
 | 
						|
			resp: &http.Response{
 | 
						|
				StatusCode: http.StatusGatewayTimeout,
 | 
						|
			},
 | 
						|
			expect: true,
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, test := range cases {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			retry, err := DefaultRetryPolicy(context.Background(), test.resp, test.err)
 | 
						|
			if retry != test.expect {
 | 
						|
				t.Fatalf("expected to retry request: '%t', but actual result was: '%t'", test.expect, retry)
 | 
						|
			}
 | 
						|
			if err != test.expectErr {
 | 
						|
				t.Fatalf("expected error from retry policy: %q, but actual result was: %q", err, test.expectErr)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientEnvSettings(t *testing.T) {
 | 
						|
	cwd, _ := os.Getwd()
 | 
						|
 | 
						|
	caCertBytes, err := os.ReadFile(cwd + "/test-fixtures/keys/cert.pem")
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("error reading %q cert file: %v", cwd+"/test-fixtures/keys/cert.pem", err)
 | 
						|
	}
 | 
						|
 | 
						|
	oldCACert := os.Getenv(EnvVaultCACert)
 | 
						|
	oldCACertBytes := os.Getenv(EnvVaultCACertBytes)
 | 
						|
	oldCAPath := os.Getenv(EnvVaultCAPath)
 | 
						|
	oldClientCert := os.Getenv(EnvVaultClientCert)
 | 
						|
	oldClientKey := os.Getenv(EnvVaultClientKey)
 | 
						|
	oldSkipVerify := os.Getenv(EnvVaultSkipVerify)
 | 
						|
	oldMaxRetries := os.Getenv(EnvVaultMaxRetries)
 | 
						|
	oldDisableRedirects := os.Getenv(EnvVaultDisableRedirects)
 | 
						|
 | 
						|
	os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem")
 | 
						|
	os.Setenv(EnvVaultCACertBytes, string(caCertBytes))
 | 
						|
	os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys")
 | 
						|
	os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem")
 | 
						|
	os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem")
 | 
						|
	os.Setenv(EnvVaultSkipVerify, "true")
 | 
						|
	os.Setenv(EnvVaultMaxRetries, "5")
 | 
						|
	os.Setenv(EnvVaultDisableRedirects, "true")
 | 
						|
 | 
						|
	defer func() {
 | 
						|
		os.Setenv(EnvVaultCACert, oldCACert)
 | 
						|
		os.Setenv(EnvVaultCACertBytes, oldCACertBytes)
 | 
						|
		os.Setenv(EnvVaultCAPath, oldCAPath)
 | 
						|
		os.Setenv(EnvVaultClientCert, oldClientCert)
 | 
						|
		os.Setenv(EnvVaultClientKey, oldClientKey)
 | 
						|
		os.Setenv(EnvVaultSkipVerify, oldSkipVerify)
 | 
						|
		os.Setenv(EnvVaultMaxRetries, oldMaxRetries)
 | 
						|
		os.Setenv(EnvVaultDisableRedirects, oldDisableRedirects)
 | 
						|
	}()
 | 
						|
 | 
						|
	config := DefaultConfig()
 | 
						|
	if err := config.ReadEnvironment(); err != nil {
 | 
						|
		t.Fatalf("error reading environment: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
 | 
						|
	if len(tlsConfig.RootCAs.Subjects()) == 0 {
 | 
						|
		t.Fatalf("bad: expected a cert pool with at least one subject")
 | 
						|
	}
 | 
						|
	if tlsConfig.GetClientCertificate == nil {
 | 
						|
		t.Fatalf("bad: expected client tls config to have a certificate getter")
 | 
						|
	}
 | 
						|
	if tlsConfig.InsecureSkipVerify != true {
 | 
						|
		t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
 | 
						|
	}
 | 
						|
	if config.DisableRedirects != true {
 | 
						|
		t.Fatalf("bad: expected disable redirects to be true: %v", config.DisableRedirects)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientDeprecatedEnvSettings(t *testing.T) {
 | 
						|
	oldInsecure := os.Getenv(EnvVaultInsecure)
 | 
						|
	os.Setenv(EnvVaultInsecure, "true")
 | 
						|
	defer os.Setenv(EnvVaultInsecure, oldInsecure)
 | 
						|
 | 
						|
	config := DefaultConfig()
 | 
						|
	if err := config.ReadEnvironment(); err != nil {
 | 
						|
		t.Fatalf("error reading environment: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
 | 
						|
	if tlsConfig.InsecureSkipVerify != true {
 | 
						|
		t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientEnvNamespace(t *testing.T) {
 | 
						|
	var seenNamespace string
 | 
						|
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		seenNamespace = req.Header.Get(consts.NamespaceHeaderName)
 | 
						|
	}
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	oldVaultNamespace := os.Getenv(EnvVaultNamespace)
 | 
						|
	defer os.Setenv(EnvVaultNamespace, oldVaultNamespace)
 | 
						|
	os.Setenv(EnvVaultNamespace, "test")
 | 
						|
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	_, err = client.RawRequest(client.NewRequest(http.MethodGet, "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if seenNamespace != "test" {
 | 
						|
		t.Fatalf("Bad: %s", seenNamespace)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParsingRateAndBurst(t *testing.T) {
 | 
						|
	var (
 | 
						|
		correctFormat                    = "400:400"
 | 
						|
		observedRate, observedBurst, err = parseRateLimit(correctFormat)
 | 
						|
		expectedRate, expectedBurst      = float64(400), 400
 | 
						|
	)
 | 
						|
	if err != nil {
 | 
						|
		t.Error(err)
 | 
						|
	}
 | 
						|
	if expectedRate != observedRate {
 | 
						|
		t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
 | 
						|
	}
 | 
						|
	if expectedBurst != observedBurst {
 | 
						|
		t.Errorf("Expected burst %v but found %v", expectedBurst, observedBurst)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParsingRateOnly(t *testing.T) {
 | 
						|
	var (
 | 
						|
		correctFormat                    = "400"
 | 
						|
		observedRate, observedBurst, err = parseRateLimit(correctFormat)
 | 
						|
		expectedRate, expectedBurst      = float64(400), 400
 | 
						|
	)
 | 
						|
	if err != nil {
 | 
						|
		t.Error(err)
 | 
						|
	}
 | 
						|
	if expectedRate != observedRate {
 | 
						|
		t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
 | 
						|
	}
 | 
						|
	if expectedBurst != observedBurst {
 | 
						|
		t.Errorf("Expected burst %v but found %v", expectedBurst, observedBurst)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParsingErrorCase(t *testing.T) {
 | 
						|
	incorrectFormat := "foobar"
 | 
						|
	_, _, err := parseRateLimit(incorrectFormat)
 | 
						|
	if err == nil {
 | 
						|
		t.Error("Expected error, found no error")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientTimeoutSetting(t *testing.T) {
 | 
						|
	oldClientTimeout := os.Getenv(EnvVaultClientTimeout)
 | 
						|
	os.Setenv(EnvVaultClientTimeout, "10")
 | 
						|
	defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout)
 | 
						|
	config := DefaultConfig()
 | 
						|
	config.ReadEnvironment()
 | 
						|
	_, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// roundTripperFunc adapts a plain function to the http.RoundTripper
// interface, so tests can give an http.Client a transport that is not an
// *http.Transport.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time assertion that roundTripperFunc implements http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip calls the wrapped function with r and returns its response and
// error unchanged.
func (rt roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	fn := (func(*http.Request) (*http.Response, error))(rt)
	return fn(r)
}
 | 
						|
 | 
						|
func TestClientNonTransportRoundTripper(t *testing.T) {
 | 
						|
	client := &http.Client{
 | 
						|
		Transport: roundTripperFunc(http.DefaultTransport.RoundTrip),
 | 
						|
	}
 | 
						|
 | 
						|
	_, err := NewClient(&Config{
 | 
						|
		HttpClient: client,
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientNonTransportRoundTripperUnixAddress(t *testing.T) {
 | 
						|
	client := &http.Client{
 | 
						|
		Transport: roundTripperFunc(http.DefaultTransport.RoundTrip),
 | 
						|
	}
 | 
						|
 | 
						|
	_, err := NewClient(&Config{
 | 
						|
		HttpClient: client,
 | 
						|
		Address:    "unix:///var/run/vault.sock",
 | 
						|
	})
 | 
						|
	if err == nil {
 | 
						|
		t.Fatal("bad: expected error got nil")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClone(t *testing.T) {
 | 
						|
	type fields struct{}
 | 
						|
	tests := []struct {
 | 
						|
		name    string
 | 
						|
		config  *Config
 | 
						|
		headers *http.Header
 | 
						|
		token   string
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name:   "default",
 | 
						|
			config: DefaultConfig(),
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "cloneHeaders",
 | 
						|
			config: &Config{
 | 
						|
				CloneHeaders: true,
 | 
						|
			},
 | 
						|
			headers: &http.Header{
 | 
						|
				"X-foo": []string{"bar"},
 | 
						|
				"X-baz": []string{"qux"},
 | 
						|
			},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "preventStaleReads",
 | 
						|
			config: &Config{
 | 
						|
				ReadYourWrites: true,
 | 
						|
			},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "cloneToken",
 | 
						|
			config: &Config{
 | 
						|
				CloneToken: true,
 | 
						|
			},
 | 
						|
			token: "cloneToken",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tt := range tests {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			parent, err := NewClient(tt.config)
 | 
						|
			if err != nil {
 | 
						|
				t.Fatalf("NewClient failed: %v", err)
 | 
						|
			}
 | 
						|
 | 
						|
			// Set all of the things that we provide setter methods for, which modify config values
 | 
						|
			err = parent.SetAddress("http://example.com:8080")
 | 
						|
			if err != nil {
 | 
						|
				t.Fatalf("SetAddress failed: %v", err)
 | 
						|
			}
 | 
						|
 | 
						|
			clientTimeout := time.Until(time.Now().AddDate(0, 0, 1))
 | 
						|
			parent.SetClientTimeout(clientTimeout)
 | 
						|
 | 
						|
			checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) {
 | 
						|
				return true, nil
 | 
						|
			}
 | 
						|
			parent.SetCheckRetry(checkRetry)
 | 
						|
 | 
						|
			parent.SetLogger(hclog.NewNullLogger())
 | 
						|
 | 
						|
			parent.SetLimiter(5.0, 10)
 | 
						|
			parent.SetMaxRetries(5)
 | 
						|
			parent.SetOutputCurlString(true)
 | 
						|
			parent.SetOutputPolicy(true)
 | 
						|
			parent.SetSRVLookup(true)
 | 
						|
 | 
						|
			if tt.headers != nil {
 | 
						|
				parent.SetHeaders(*tt.headers)
 | 
						|
			}
 | 
						|
 | 
						|
			if tt.token != "" {
 | 
						|
				parent.SetToken(tt.token)
 | 
						|
			}
 | 
						|
 | 
						|
			clone, err := parent.Clone()
 | 
						|
			if err != nil {
 | 
						|
				t.Fatalf("Clone failed: %v", err)
 | 
						|
			}
 | 
						|
 | 
						|
			if parent.Address() != clone.Address() {
 | 
						|
				t.Fatalf("addresses don't match: %v vs %v", parent.Address(), clone.Address())
 | 
						|
			}
 | 
						|
			if parent.ClientTimeout() != clone.ClientTimeout() {
 | 
						|
				t.Fatalf("timeouts don't match: %v vs %v", parent.ClientTimeout(), clone.ClientTimeout())
 | 
						|
			}
 | 
						|
			if parent.CheckRetry() != nil && clone.CheckRetry() == nil {
 | 
						|
				t.Fatal("checkRetry functions don't match. clone is nil.")
 | 
						|
			}
 | 
						|
			if (parent.Limiter() != nil && clone.Limiter() == nil) || (parent.Limiter() == nil && clone.Limiter() != nil) {
 | 
						|
				t.Fatalf("limiters don't match: %v vs %v", parent.Limiter(), clone.Limiter())
 | 
						|
			}
 | 
						|
			if parent.Limiter().Limit() != clone.Limiter().Limit() {
 | 
						|
				t.Fatalf("limiter limits don't match: %v vs %v", parent.Limiter().Limit(), clone.Limiter().Limit())
 | 
						|
			}
 | 
						|
			if parent.Limiter().Burst() != clone.Limiter().Burst() {
 | 
						|
				t.Fatalf("limiter bursts don't match: %v vs %v", parent.Limiter().Burst(), clone.Limiter().Burst())
 | 
						|
			}
 | 
						|
			if parent.MaxRetries() != clone.MaxRetries() {
 | 
						|
				t.Fatalf("maxRetries don't match: %v vs %v", parent.MaxRetries(), clone.MaxRetries())
 | 
						|
			}
 | 
						|
			if parent.OutputCurlString() == clone.OutputCurlString() {
 | 
						|
				t.Fatalf("outputCurlString was copied over when it shouldn't have been: %v and %v", parent.OutputCurlString(), clone.OutputCurlString())
 | 
						|
			}
 | 
						|
			if parent.SRVLookup() != clone.SRVLookup() {
 | 
						|
				t.Fatalf("SRVLookup doesn't match: %v vs %v", parent.SRVLookup(), clone.SRVLookup())
 | 
						|
			}
 | 
						|
			if tt.config.CloneHeaders {
 | 
						|
				if !reflect.DeepEqual(parent.Headers(), clone.Headers()) {
 | 
						|
					t.Fatalf("Headers() don't match: %v vs %v", parent.Headers(), clone.Headers())
 | 
						|
				}
 | 
						|
				if parent.config.CloneHeaders != clone.config.CloneHeaders {
 | 
						|
					t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", parent.config.CloneHeaders, clone.config.CloneHeaders)
 | 
						|
				}
 | 
						|
				if tt.headers != nil {
 | 
						|
					if !reflect.DeepEqual(*tt.headers, clone.Headers()) {
 | 
						|
						t.Fatalf("expected headers %v, actual %v", *tt.headers, clone.Headers())
 | 
						|
					}
 | 
						|
				}
 | 
						|
			}
 | 
						|
			if tt.config.ReadYourWrites && parent.replicationStateStore == nil {
 | 
						|
				t.Fatalf("replicationStateStore is nil")
 | 
						|
			}
 | 
						|
			if tt.config.CloneToken {
 | 
						|
				if tt.token == "" {
 | 
						|
					t.Fatalf("test requires a non-empty token")
 | 
						|
				}
 | 
						|
				if parent.config.CloneToken != clone.config.CloneToken {
 | 
						|
					t.Fatalf("config.CloneToken doesn't match: %v vs %v", parent.config.CloneToken, clone.config.CloneToken)
 | 
						|
				}
 | 
						|
				if parent.token != clone.token {
 | 
						|
					t.Fatalf("tokens do not match: %v vs %v", parent.token, clone.token)
 | 
						|
				}
 | 
						|
			} else {
 | 
						|
				// assumes `VAULT_TOKEN` is unset or has an empty value.
 | 
						|
				expected := ""
 | 
						|
				if clone.token != expected {
 | 
						|
					t.Fatalf("expected clone's token %q, actual %q", expected, clone.token)
 | 
						|
				}
 | 
						|
			}
 | 
						|
			if !reflect.DeepEqual(parent.replicationStateStore, clone.replicationStateStore) {
 | 
						|
				t.Fatalf("expected replicationStateStore %v, actual %v", parent.replicationStateStore,
 | 
						|
					clone.replicationStateStore)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestSetHeadersRaceSafe(t *testing.T) {
 | 
						|
	client, err1 := NewClient(nil)
 | 
						|
	if err1 != nil {
 | 
						|
		t.Fatalf("NewClient failed: %v", err1)
 | 
						|
	}
 | 
						|
 | 
						|
	start := make(chan interface{})
 | 
						|
	done := make(chan interface{})
 | 
						|
 | 
						|
	testPairs := map[string]string{
 | 
						|
		"soda":    "rootbeer",
 | 
						|
		"veggie":  "carrots",
 | 
						|
		"fruit":   "apples",
 | 
						|
		"color":   "red",
 | 
						|
		"protein": "egg",
 | 
						|
	}
 | 
						|
 | 
						|
	for key, value := range testPairs {
 | 
						|
		tmpKey := key
 | 
						|
		tmpValue := value
 | 
						|
		go func() {
 | 
						|
			<-start
 | 
						|
			// This test fails if here, you replace client.AddHeader(tmpKey, tmpValue) with:
 | 
						|
			// 	headerCopy := client.Header()
 | 
						|
			// 	headerCopy.AddHeader(tmpKey, tmpValue)
 | 
						|
			// 	client.SetHeader(headerCopy)
 | 
						|
			client.AddHeader(tmpKey, tmpValue)
 | 
						|
			done <- true
 | 
						|
		}()
 | 
						|
	}
 | 
						|
 | 
						|
	// Start everyone at once.
 | 
						|
	close(start)
 | 
						|
 | 
						|
	// Wait until everyone is done.
 | 
						|
	for i := 0; i < len(testPairs); i++ {
 | 
						|
		<-done
 | 
						|
	}
 | 
						|
 | 
						|
	// Check that all the test pairs are in the resulting
 | 
						|
	// headers.
 | 
						|
	resultingHeaders := client.Headers()
 | 
						|
	for key, value := range testPairs {
 | 
						|
		if resultingHeaders.Get(key) != value {
 | 
						|
			t.Fatal("expected " + value + " for " + key)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestMergeReplicationStates(t *testing.T) {
 | 
						|
	type testCase struct {
 | 
						|
		name     string
 | 
						|
		old      []string
 | 
						|
		new      string
 | 
						|
		expected []string
 | 
						|
	}
 | 
						|
 | 
						|
	testCases := []testCase{
 | 
						|
		{
 | 
						|
			name:     "empty-old",
 | 
						|
			old:      nil,
 | 
						|
			new:      "v1:cid:1:0:",
 | 
						|
			expected: []string{"v1:cid:1:0:"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:     "old-smaller",
 | 
						|
			old:      []string{"v1:cid:1:0:"},
 | 
						|
			new:      "v1:cid:2:0:",
 | 
						|
			expected: []string{"v1:cid:2:0:"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:     "old-bigger",
 | 
						|
			old:      []string{"v1:cid:2:0:"},
 | 
						|
			new:      "v1:cid:1:0:",
 | 
						|
			expected: []string{"v1:cid:2:0:"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:     "mixed-single",
 | 
						|
			old:      []string{"v1:cid:1:0:"},
 | 
						|
			new:      "v1:cid:0:1:",
 | 
						|
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:     "mixed-single-alt",
 | 
						|
			old:      []string{"v1:cid:0:1:"},
 | 
						|
			new:      "v1:cid:1:0:",
 | 
						|
			expected: []string{"v1:cid:0:1:", "v1:cid:1:0:"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:     "mixed-double",
 | 
						|
			old:      []string{"v1:cid:0:1:", "v1:cid:1:0:"},
 | 
						|
			new:      "v1:cid:2:0:",
 | 
						|
			expected: []string{"v1:cid:0:1:", "v1:cid:2:0:"},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:     "newer-both",
 | 
						|
			old:      []string{"v1:cid:0:1:", "v1:cid:1:0:"},
 | 
						|
			new:      "v1:cid:2:1:",
 | 
						|
			expected: []string{"v1:cid:2:1:"},
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	b64enc := func(ss []string) []string {
 | 
						|
		var ret []string
 | 
						|
		for _, s := range ss {
 | 
						|
			ret = append(ret, base64.StdEncoding.EncodeToString([]byte(s)))
 | 
						|
		}
 | 
						|
		return ret
 | 
						|
	}
 | 
						|
	b64dec := func(ss []string) []string {
 | 
						|
		var ret []string
 | 
						|
		for _, s := range ss {
 | 
						|
			d, err := base64.StdEncoding.DecodeString(s)
 | 
						|
			if err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
			ret = append(ret, string(d))
 | 
						|
		}
 | 
						|
		return ret
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tc := range testCases {
 | 
						|
		t.Run(tc.name, func(t *testing.T) {
 | 
						|
			out := b64dec(MergeReplicationStates(b64enc(tc.old), base64.StdEncoding.EncodeToString([]byte(tc.new))))
 | 
						|
			if diff := deep.Equal(out, tc.expected); len(diff) != 0 {
 | 
						|
				t.Errorf("got=%v, expected=%v, diff=%v", out, tc.expected, diff)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestReplicationStateStore_recordState(t *testing.T) {
 | 
						|
	b64enc := func(s string) string {
 | 
						|
		return base64.StdEncoding.EncodeToString([]byte(s))
 | 
						|
	}
 | 
						|
 | 
						|
	tests := []struct {
 | 
						|
		name     string
 | 
						|
		expected []string
 | 
						|
		resp     []*Response
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name: "single",
 | 
						|
			resp: []*Response{
 | 
						|
				{
 | 
						|
					Response: &http.Response{
 | 
						|
						Header: map[string][]string{
 | 
						|
							HeaderIndex: {
 | 
						|
								b64enc("v1:cid:1:0:"),
 | 
						|
							},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expected: []string{
 | 
						|
				b64enc("v1:cid:1:0:"),
 | 
						|
			},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "empty",
 | 
						|
			resp: []*Response{
 | 
						|
				{
 | 
						|
					Response: &http.Response{
 | 
						|
						Header: map[string][]string{},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expected: nil,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "multiple",
 | 
						|
			resp: []*Response{
 | 
						|
				{
 | 
						|
					Response: &http.Response{
 | 
						|
						Header: map[string][]string{
 | 
						|
							HeaderIndex: {
 | 
						|
								b64enc("v1:cid:0:1:"),
 | 
						|
							},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
				{
 | 
						|
					Response: &http.Response{
 | 
						|
						Header: map[string][]string{
 | 
						|
							HeaderIndex: {
 | 
						|
								b64enc("v1:cid:1:0:"),
 | 
						|
							},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expected: []string{
 | 
						|
				b64enc("v1:cid:0:1:"),
 | 
						|
				b64enc("v1:cid:1:0:"),
 | 
						|
			},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "duplicates",
 | 
						|
			resp: []*Response{
 | 
						|
				{
 | 
						|
					Response: &http.Response{
 | 
						|
						Header: map[string][]string{
 | 
						|
							HeaderIndex: {
 | 
						|
								b64enc("v1:cid:1:0:"),
 | 
						|
							},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
				{
 | 
						|
					Response: &http.Response{
 | 
						|
						Header: map[string][]string{
 | 
						|
							HeaderIndex: {
 | 
						|
								b64enc("v1:cid:1:0:"),
 | 
						|
							},
 | 
						|
						},
 | 
						|
					},
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expected: []string{
 | 
						|
				b64enc("v1:cid:1:0:"),
 | 
						|
			},
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tt := range tests {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			w := &replicationStateStore{}
 | 
						|
 | 
						|
			var wg sync.WaitGroup
 | 
						|
			for _, r := range tt.resp {
 | 
						|
				wg.Add(1)
 | 
						|
				go func(r *Response) {
 | 
						|
					defer wg.Done()
 | 
						|
					w.recordState(r)
 | 
						|
				}(r)
 | 
						|
			}
 | 
						|
			wg.Wait()
 | 
						|
 | 
						|
			if !reflect.DeepEqual(tt.expected, w.store) {
 | 
						|
				t.Errorf("recordState(): expected states %v, actual %v", tt.expected, w.store)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestReplicationStateStore_requireState(t *testing.T) {
 | 
						|
	tests := []struct {
 | 
						|
		name     string
 | 
						|
		states   []string
 | 
						|
		req      []*Request
 | 
						|
		expected []string
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name:   "empty",
 | 
						|
			states: []string{},
 | 
						|
			req: []*Request{
 | 
						|
				{
 | 
						|
					Headers: make(http.Header),
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expected: nil,
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name: "basic",
 | 
						|
			states: []string{
 | 
						|
				"v1:cid:0:1:",
 | 
						|
				"v1:cid:1:0:",
 | 
						|
			},
 | 
						|
			req: []*Request{
 | 
						|
				{
 | 
						|
					Headers: make(http.Header),
 | 
						|
				},
 | 
						|
			},
 | 
						|
			expected: []string{
 | 
						|
				"v1:cid:0:1:",
 | 
						|
				"v1:cid:1:0:",
 | 
						|
			},
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tt := range tests {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			store := &replicationStateStore{
 | 
						|
				store: tt.states,
 | 
						|
			}
 | 
						|
 | 
						|
			var wg sync.WaitGroup
 | 
						|
			for _, r := range tt.req {
 | 
						|
				wg.Add(1)
 | 
						|
				go func(r *Request) {
 | 
						|
					defer wg.Done()
 | 
						|
					store.requireState(r)
 | 
						|
				}(r)
 | 
						|
			}
 | 
						|
 | 
						|
			wg.Wait()
 | 
						|
 | 
						|
			var actual []string
 | 
						|
			for _, r := range tt.req {
 | 
						|
				if values := r.Headers.Values(HeaderIndex); len(values) > 0 {
 | 
						|
					actual = append(actual, values...)
 | 
						|
				}
 | 
						|
			}
 | 
						|
			sort.Strings(actual)
 | 
						|
			if !reflect.DeepEqual(tt.expected, actual) {
 | 
						|
				t.Errorf("requireState(): expected states %v, actual %v", tt.expected, actual)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClient_ReadYourWrites(t *testing.T) {
 | 
						|
	b64enc := func(s string) string {
 | 
						|
		return base64.StdEncoding.EncodeToString([]byte(s))
 | 
						|
	}
 | 
						|
 | 
						|
	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		w.Header().Set(HeaderIndex, strings.TrimLeft(req.URL.Path, "/"))
 | 
						|
	})
 | 
						|
 | 
						|
	tests := []struct {
 | 
						|
		name       string
 | 
						|
		handler    http.Handler
 | 
						|
		wantStates []string
 | 
						|
		values     [][]string
 | 
						|
		clone      bool
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name:    "multiple_duplicates",
 | 
						|
			clone:   false,
 | 
						|
			handler: handler,
 | 
						|
			wantStates: []string{
 | 
						|
				b64enc("v1:cid:0:4:"),
 | 
						|
			},
 | 
						|
			values: [][]string{
 | 
						|
				{
 | 
						|
					b64enc("v1:cid:0:4:"),
 | 
						|
					b64enc("v1:cid:0:2:"),
 | 
						|
				},
 | 
						|
				{
 | 
						|
					b64enc("v1:cid:0:4:"),
 | 
						|
					b64enc("v1:cid:0:2:"),
 | 
						|
				},
 | 
						|
			},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:    "basic_clone",
 | 
						|
			clone:   true,
 | 
						|
			handler: handler,
 | 
						|
			wantStates: []string{
 | 
						|
				b64enc("v1:cid:0:4:"),
 | 
						|
			},
 | 
						|
			values: [][]string{
 | 
						|
				{
 | 
						|
					b64enc("v1:cid:0:4:"),
 | 
						|
				},
 | 
						|
				{
 | 
						|
					b64enc("v1:cid:0:3:"),
 | 
						|
				},
 | 
						|
			},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:    "multiple_clone",
 | 
						|
			clone:   true,
 | 
						|
			handler: handler,
 | 
						|
			wantStates: []string{
 | 
						|
				b64enc("v1:cid:0:4:"),
 | 
						|
			},
 | 
						|
			values: [][]string{
 | 
						|
				{
 | 
						|
					b64enc("v1:cid:0:4:"),
 | 
						|
					b64enc("v1:cid:0:2:"),
 | 
						|
				},
 | 
						|
				{
 | 
						|
					b64enc("v1:cid:0:3:"),
 | 
						|
					b64enc("v1:cid:0:1:"),
 | 
						|
				},
 | 
						|
			},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:    "multiple_duplicates_clone",
 | 
						|
			clone:   true,
 | 
						|
			handler: handler,
 | 
						|
			wantStates: []string{
 | 
						|
				b64enc("v1:cid:0:4:"),
 | 
						|
			},
 | 
						|
			values: [][]string{
 | 
						|
				{
 | 
						|
					b64enc("v1:cid:0:4:"),
 | 
						|
					b64enc("v1:cid:0:2:"),
 | 
						|
				},
 | 
						|
				{
 | 
						|
					b64enc("v1:cid:0:4:"),
 | 
						|
					b64enc("v1:cid:0:2:"),
 | 
						|
				},
 | 
						|
			},
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for _, tt := range tests {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			testRequest := func(client *Client, val string) {
 | 
						|
				req := client.NewRequest(http.MethodGet, "/"+val)
 | 
						|
				req.Headers.Set(HeaderIndex, val)
 | 
						|
				resp, err := client.RawRequestWithContext(context.Background(), req)
 | 
						|
				if err != nil {
 | 
						|
					t.Fatal(err)
 | 
						|
				}
 | 
						|
 | 
						|
				// validate that the server provided a valid header value in its response
 | 
						|
				actual := resp.Header.Get(HeaderIndex)
 | 
						|
				if actual != val {
 | 
						|
					t.Errorf("expected header value %v, actual %v", val, actual)
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			config, ln := testHTTPServer(t, handler)
 | 
						|
			defer ln.Close()
 | 
						|
 | 
						|
			config.ReadYourWrites = true
 | 
						|
			config.Address = fmt.Sprintf("http://%s", ln.Addr())
 | 
						|
			parent, err := NewClient(config)
 | 
						|
			if err != nil {
 | 
						|
				t.Fatal(err)
 | 
						|
			}
 | 
						|
 | 
						|
			var wg sync.WaitGroup
 | 
						|
			for i := 0; i < len(tt.values); i++ {
 | 
						|
				var c *Client
 | 
						|
				if tt.clone {
 | 
						|
					c, err = parent.Clone()
 | 
						|
					if err != nil {
 | 
						|
						t.Fatal(err)
 | 
						|
					}
 | 
						|
				} else {
 | 
						|
					c = parent
 | 
						|
				}
 | 
						|
 | 
						|
				for _, val := range tt.values[i] {
 | 
						|
					wg.Add(1)
 | 
						|
					go func(val string) {
 | 
						|
						defer wg.Done()
 | 
						|
						testRequest(c, val)
 | 
						|
					}(val)
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			wg.Wait()
 | 
						|
 | 
						|
			if !reflect.DeepEqual(tt.wantStates, parent.replicationStateStore.states()) {
 | 
						|
				t.Errorf("expected states %v, actual %v", tt.wantStates, parent.replicationStateStore.states())
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClient_SetReadYourWrites(t *testing.T) {
 | 
						|
	tests := []struct {
 | 
						|
		name   string
 | 
						|
		config *Config
 | 
						|
		calls  []bool
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name:   "false",
 | 
						|
			config: &Config{},
 | 
						|
			calls:  []bool{false},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:   "true",
 | 
						|
			config: &Config{},
 | 
						|
			calls:  []bool{true},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:   "multi-false",
 | 
						|
			config: &Config{},
 | 
						|
			calls:  []bool{false, false},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:   "multi-true",
 | 
						|
			config: &Config{},
 | 
						|
			calls:  []bool{true, true},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:   "multi-mix",
 | 
						|
			config: &Config{},
 | 
						|
			calls:  []bool{false, true, false, true},
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	assertSetReadYourRights := func(t *testing.T, c *Client, v bool, s *replicationStateStore) {
 | 
						|
		t.Helper()
 | 
						|
		c.SetReadYourWrites(v)
 | 
						|
		if c.config.ReadYourWrites != v {
 | 
						|
			t.Fatalf("expected config.ReadYourWrites %#v, actual %#v", v, c.config.ReadYourWrites)
 | 
						|
		}
 | 
						|
		if !reflect.DeepEqual(s, c.replicationStateStore) {
 | 
						|
			t.Fatalf("expected replicationStateStore %#v, actual %#v", s, c.replicationStateStore)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	for _, tt := range tests {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			c := &Client{
 | 
						|
				config: tt.config,
 | 
						|
			}
 | 
						|
			for i, v := range tt.calls {
 | 
						|
				var expectStateStore *replicationStateStore
 | 
						|
				if v {
 | 
						|
					if c.replicationStateStore == nil {
 | 
						|
						c.replicationStateStore = &replicationStateStore{
 | 
						|
							store: []string{},
 | 
						|
						}
 | 
						|
					}
 | 
						|
					c.replicationStateStore.store = append(c.replicationStateStore.store,
 | 
						|
						fmt.Sprintf("%s-%d", tt.name, i))
 | 
						|
					expectStateStore = c.replicationStateStore
 | 
						|
				}
 | 
						|
				assertSetReadYourRights(t, c, v, expectStateStore)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClient_SetCloneToken(t *testing.T) {
 | 
						|
	tests := []struct {
 | 
						|
		name  string
 | 
						|
		calls []bool
 | 
						|
	}{
 | 
						|
		{
 | 
						|
			name:  "false",
 | 
						|
			calls: []bool{false},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:  "true",
 | 
						|
			calls: []bool{true},
 | 
						|
		},
 | 
						|
		{
 | 
						|
			name:  "multi",
 | 
						|
			calls: []bool{true, false, true},
 | 
						|
		},
 | 
						|
	}
 | 
						|
	for _, tt := range tests {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			c := &Client{
 | 
						|
				config: &Config{},
 | 
						|
			}
 | 
						|
 | 
						|
			var expected bool
 | 
						|
			for _, v := range tt.calls {
 | 
						|
				actual := c.CloneToken()
 | 
						|
				if expected != actual {
 | 
						|
					t.Fatalf("expected %v, actual %v", expected, actual)
 | 
						|
				}
 | 
						|
 | 
						|
				expected = v
 | 
						|
				c.SetCloneToken(expected)
 | 
						|
				actual = c.CloneToken()
 | 
						|
				if actual != expected {
 | 
						|
					t.Fatalf("SetCloneToken(): expected %v, actual %v", expected, actual)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientWithNamespace(t *testing.T) {
 | 
						|
	var ns string
 | 
						|
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		ns = req.Header.Get(consts.NamespaceHeaderName)
 | 
						|
	}
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	// set up a client with a namespace
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
	ogNS := "test"
 | 
						|
	client.SetNamespace(ogNS)
 | 
						|
	_, err = client.rawRequestWithContext(
 | 
						|
		context.Background(),
 | 
						|
		client.NewRequest(http.MethodGet, "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
	if ns != ogNS {
 | 
						|
		t.Fatalf("Expected namespace: %q, got %q", ogNS, ns)
 | 
						|
	}
 | 
						|
 | 
						|
	// make a call with a temporary namespace
 | 
						|
	newNS := "new-namespace"
 | 
						|
	_, err = client.WithNamespace(newNS).rawRequestWithContext(
 | 
						|
		context.Background(),
 | 
						|
		client.NewRequest(http.MethodGet, "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
	if ns != newNS {
 | 
						|
		t.Fatalf("Expected new namespace: %q, got %q", newNS, ns)
 | 
						|
	}
 | 
						|
	// ensure client has not been modified
 | 
						|
	_, err = client.rawRequestWithContext(
 | 
						|
		context.Background(),
 | 
						|
		client.NewRequest(http.MethodGet, "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
	if ns != ogNS {
 | 
						|
		t.Fatalf("Expected original namespace: %q, got %q", ogNS, ns)
 | 
						|
	}
 | 
						|
 | 
						|
	// make call with empty ns
 | 
						|
	_, err = client.WithNamespace("").rawRequestWithContext(
 | 
						|
		context.Background(),
 | 
						|
		client.NewRequest(http.MethodGet, "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
	if ns != "" {
 | 
						|
		t.Fatalf("Expected no namespace, got %q", ns)
 | 
						|
	}
 | 
						|
 | 
						|
	// ensure client has not been modified
 | 
						|
	if client.Namespace() != ogNS {
 | 
						|
		t.Fatalf("Expected original namespace: %q, got %q", ogNS, client.Namespace())
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestVaultProxy(t *testing.T) {
 | 
						|
	const NoProxy string = "NO_PROXY"
 | 
						|
 | 
						|
	tests := map[string]struct {
 | 
						|
		name                     string
 | 
						|
		vaultHttpProxy           string
 | 
						|
		vaultProxyAddr           string
 | 
						|
		noProxy                  string
 | 
						|
		requestUrl               string
 | 
						|
		expectedResolvedProxyUrl string
 | 
						|
	}{
 | 
						|
		"VAULT_HTTP_PROXY used when NO_PROXY env var doesn't include request host": {
 | 
						|
			vaultHttpProxy: "https://hashicorp.com",
 | 
						|
			vaultProxyAddr: "",
 | 
						|
			noProxy:        "terraform.io",
 | 
						|
			requestUrl:     "https://vaultproject.io",
 | 
						|
		},
 | 
						|
		"VAULT_HTTP_PROXY used when NO_PROXY env var includes request host": {
 | 
						|
			vaultHttpProxy: "https://hashicorp.com",
 | 
						|
			vaultProxyAddr: "",
 | 
						|
			noProxy:        "terraform.io,vaultproject.io",
 | 
						|
			requestUrl:     "https://vaultproject.io",
 | 
						|
		},
 | 
						|
		"VAULT_PROXY_ADDR used when NO_PROXY env var doesn't include request host": {
 | 
						|
			vaultHttpProxy: "",
 | 
						|
			vaultProxyAddr: "https://hashicorp.com",
 | 
						|
			noProxy:        "terraform.io",
 | 
						|
			requestUrl:     "https://vaultproject.io",
 | 
						|
		},
 | 
						|
		"VAULT_PROXY_ADDR used when NO_PROXY env var includes request host": {
 | 
						|
			vaultHttpProxy: "",
 | 
						|
			vaultProxyAddr: "https://hashicorp.com",
 | 
						|
			noProxy:        "terraform.io,vaultproject.io",
 | 
						|
			requestUrl:     "https://vaultproject.io",
 | 
						|
		},
 | 
						|
		"VAULT_PROXY_ADDR used when VAULT_HTTP_PROXY env var also supplied": {
 | 
						|
			vaultHttpProxy:           "https://hashicorp.com",
 | 
						|
			vaultProxyAddr:           "https://terraform.io",
 | 
						|
			noProxy:                  "",
 | 
						|
			requestUrl:               "https://vaultproject.io",
 | 
						|
			expectedResolvedProxyUrl: "https://terraform.io",
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	for name, tc := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			if tc.vaultHttpProxy != "" {
 | 
						|
				oldVaultHttpProxy := os.Getenv(EnvHTTPProxy)
 | 
						|
				os.Setenv(EnvHTTPProxy, tc.vaultHttpProxy)
 | 
						|
				defer os.Setenv(EnvHTTPProxy, oldVaultHttpProxy)
 | 
						|
			}
 | 
						|
 | 
						|
			if tc.vaultProxyAddr != "" {
 | 
						|
				oldVaultProxyAddr := os.Getenv(EnvVaultProxyAddr)
 | 
						|
				os.Setenv(EnvVaultProxyAddr, tc.vaultProxyAddr)
 | 
						|
				defer os.Setenv(EnvVaultProxyAddr, oldVaultProxyAddr)
 | 
						|
			}
 | 
						|
 | 
						|
			if tc.noProxy != "" {
 | 
						|
				oldNoProxy := os.Getenv(NoProxy)
 | 
						|
				os.Setenv(NoProxy, tc.noProxy)
 | 
						|
				defer os.Setenv(NoProxy, oldNoProxy)
 | 
						|
			}
 | 
						|
 | 
						|
			c := DefaultConfig()
 | 
						|
			if c.Error != nil {
 | 
						|
				t.Fatalf("Expected no error reading config, found error %v", c.Error)
 | 
						|
			}
 | 
						|
 | 
						|
			r, _ := http.NewRequest("GET", tc.requestUrl, nil)
 | 
						|
			proxyUrl, err := c.HttpClient.Transport.(*http.Transport).Proxy(r)
 | 
						|
			if err != nil {
 | 
						|
				t.Fatalf("Expected no error resolving proxy, found error %v", err)
 | 
						|
			}
 | 
						|
			if proxyUrl == nil || proxyUrl.String() == "" {
 | 
						|
				t.Fatalf("Expected proxy to be resolved but no proxy returned")
 | 
						|
			}
 | 
						|
			if tc.expectedResolvedProxyUrl != "" && proxyUrl.String() != tc.expectedResolvedProxyUrl {
 | 
						|
				t.Fatalf("Expected resolved proxy URL to be %v but was %v", tc.expectedResolvedProxyUrl, proxyUrl.String())
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParseAddressWithUnixSocket(t *testing.T) {
 | 
						|
	address := "unix:///var/run/vault.sock"
 | 
						|
	config := DefaultConfig()
 | 
						|
 | 
						|
	u, err := config.ParseAddress(address)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal("Error not expected")
 | 
						|
	}
 | 
						|
	if u.Scheme != "http" {
 | 
						|
		t.Fatal("Scheme not changed to http")
 | 
						|
	}
 | 
						|
	if u.Host != "/var/run/vault.sock" {
 | 
						|
		t.Fatal("Host not changed to socket name")
 | 
						|
	}
 | 
						|
	if u.Path != "" {
 | 
						|
		t.Fatal("Path expected to be blank")
 | 
						|
	}
 | 
						|
	if config.HttpClient.Transport.(*http.Transport).DialContext == nil {
 | 
						|
		t.Fatal("DialContext function not set in config.HttpClient.Transport")
 | 
						|
	}
 | 
						|
}
 |