mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-04 04:28:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			311 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			311 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package api
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"github.com/hashicorp/vault/sdk/helper/consts"
 | 
						|
	"io"
 | 
						|
	"net/http"
 | 
						|
	"os"
 | 
						|
	"strings"
 | 
						|
	"testing"
 | 
						|
)
 | 
						|
 | 
						|
func init() {
	// Blank out the Vault-specific environment variables before any test
	// runs so values from the developer's shell do not leak into the tests.
	for _, name := range []string{"VAULT_ADDR", "VAULT_TOKEN"} {
		os.Setenv(name, "")
	}
}
 | 
						|
 | 
						|
func TestDefaultConfig_envvar(t *testing.T) {
 | 
						|
	os.Setenv("VAULT_ADDR", "https://vault.mycompany.com")
 | 
						|
	defer os.Setenv("VAULT_ADDR", "")
 | 
						|
 | 
						|
	config := DefaultConfig()
 | 
						|
	if config.Address != "https://vault.mycompany.com" {
 | 
						|
		t.Fatalf("bad: %s", config.Address)
 | 
						|
	}
 | 
						|
 | 
						|
	os.Setenv("VAULT_TOKEN", "testing")
 | 
						|
	defer os.Setenv("VAULT_TOKEN", "")
 | 
						|
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if token := client.Token(); token != "testing" {
 | 
						|
		t.Fatalf("bad: %s", token)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientDefaultHttpClient(t *testing.T) {
 | 
						|
	_, err := NewClient(&Config{
 | 
						|
		HttpClient: http.DefaultClient,
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientNilConfig(t *testing.T) {
 | 
						|
	client, err := NewClient(nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if client == nil {
 | 
						|
		t.Fatal("expected a non-nil client")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientSetAddress(t *testing.T) {
 | 
						|
	client, err := NewClient(nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if client.addr.Host != "172.168.2.1:8300" {
 | 
						|
		t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientToken(t *testing.T) {
 | 
						|
	tokenValue := "foo"
 | 
						|
	handler := func(w http.ResponseWriter, req *http.Request) {}
 | 
						|
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	client.SetToken(tokenValue)
 | 
						|
 | 
						|
	// Verify the token is set
 | 
						|
	if v := client.Token(); v != tokenValue {
 | 
						|
		t.Fatalf("bad: %s", v)
 | 
						|
	}
 | 
						|
 | 
						|
	client.ClearToken()
 | 
						|
 | 
						|
	if v := client.Token(); v != "" {
 | 
						|
		t.Fatalf("bad: %s", v)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientBadToken(t *testing.T) {
 | 
						|
	handler := func(w http.ResponseWriter, req *http.Request) {}
 | 
						|
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	client.SetToken("foo")
 | 
						|
	_, err = client.RawRequest(client.NewRequest("PUT", "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
 | 
						|
	client.SetToken("foo\u007f")
 | 
						|
	_, err = client.RawRequest(client.NewRequest("PUT", "/"))
 | 
						|
	if err == nil || !strings.Contains(err.Error(), "printable") {
 | 
						|
		t.Fatalf("expected error due to bad token")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientRedirect(t *testing.T) {
 | 
						|
	primary := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		w.Write([]byte("test"))
 | 
						|
	}
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(primary))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	standby := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		w.Header().Set("Location", config.Address)
 | 
						|
		w.WriteHeader(307)
 | 
						|
	}
 | 
						|
	config2, ln2 := testHTTPServer(t, http.HandlerFunc(standby))
 | 
						|
	defer ln2.Close()
 | 
						|
 | 
						|
	client, err := NewClient(config2)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Set the token manually
 | 
						|
	client.SetToken("foo")
 | 
						|
 | 
						|
	// Do a raw "/" request
 | 
						|
	resp, err := client.RawRequest(client.NewRequest("PUT", "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// Copy the response
 | 
						|
	var buf bytes.Buffer
 | 
						|
	io.Copy(&buf, resp.Body)
 | 
						|
 | 
						|
	// Verify we got the response from the primary
 | 
						|
	if buf.String() != "test" {
 | 
						|
		t.Fatalf("Bad: %s", buf.String())
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientEnvSettings(t *testing.T) {
 | 
						|
	cwd, _ := os.Getwd()
 | 
						|
	oldCACert := os.Getenv(EnvVaultCACert)
 | 
						|
	oldCAPath := os.Getenv(EnvVaultCAPath)
 | 
						|
	oldClientCert := os.Getenv(EnvVaultClientCert)
 | 
						|
	oldClientKey := os.Getenv(EnvVaultClientKey)
 | 
						|
	oldSkipVerify := os.Getenv(EnvVaultSkipVerify)
 | 
						|
	oldMaxRetries := os.Getenv(EnvVaultMaxRetries)
 | 
						|
	os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem")
 | 
						|
	os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys")
 | 
						|
	os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem")
 | 
						|
	os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem")
 | 
						|
	os.Setenv(EnvVaultSkipVerify, "true")
 | 
						|
	os.Setenv(EnvVaultMaxRetries, "5")
 | 
						|
	defer os.Setenv(EnvVaultCACert, oldCACert)
 | 
						|
	defer os.Setenv(EnvVaultCAPath, oldCAPath)
 | 
						|
	defer os.Setenv(EnvVaultClientCert, oldClientCert)
 | 
						|
	defer os.Setenv(EnvVaultClientKey, oldClientKey)
 | 
						|
	defer os.Setenv(EnvVaultSkipVerify, oldSkipVerify)
 | 
						|
	defer os.Setenv(EnvVaultMaxRetries, oldMaxRetries)
 | 
						|
 | 
						|
	config := DefaultConfig()
 | 
						|
	if err := config.ReadEnvironment(); err != nil {
 | 
						|
		t.Fatalf("error reading environment: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	tlsConfig := config.HttpClient.Transport.(*http.Transport).TLSClientConfig
 | 
						|
	if len(tlsConfig.RootCAs.Subjects()) == 0 {
 | 
						|
		t.Fatalf("bad: expected a cert pool with at least one subject")
 | 
						|
	}
 | 
						|
	if tlsConfig.GetClientCertificate == nil {
 | 
						|
		t.Fatalf("bad: expected client tls config to have a certificate getter")
 | 
						|
	}
 | 
						|
	if tlsConfig.InsecureSkipVerify != true {
 | 
						|
		t.Fatalf("bad: %v", tlsConfig.InsecureSkipVerify)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientEnvNamespace(t *testing.T) {
 | 
						|
	var seenNamespace string
 | 
						|
	handler := func(w http.ResponseWriter, req *http.Request) {
 | 
						|
		seenNamespace = req.Header.Get(consts.NamespaceHeaderName)
 | 
						|
	}
 | 
						|
	config, ln := testHTTPServer(t, http.HandlerFunc(handler))
 | 
						|
	defer ln.Close()
 | 
						|
 | 
						|
	oldVaultNamespace := os.Getenv(EnvVaultNamespace)
 | 
						|
	defer os.Setenv(EnvVaultNamespace, oldVaultNamespace)
 | 
						|
	os.Setenv(EnvVaultNamespace, "test")
 | 
						|
 | 
						|
	client, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	_, err = client.RawRequest(client.NewRequest("GET", "/"))
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if seenNamespace != "test" {
 | 
						|
		t.Fatalf("Bad: %s", seenNamespace)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParsingRateAndBurst(t *testing.T) {
 | 
						|
	var (
 | 
						|
		correctFormat                    = "400:400"
 | 
						|
		observedRate, observedBurst, err = parseRateLimit(correctFormat)
 | 
						|
		expectedRate, expectedBurst      = float64(400), 400
 | 
						|
	)
 | 
						|
	if err != nil {
 | 
						|
		t.Error(err)
 | 
						|
	}
 | 
						|
	if expectedRate != observedRate {
 | 
						|
		t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
 | 
						|
	}
 | 
						|
	if expectedBurst != observedBurst {
 | 
						|
		t.Errorf("Expected burst %v but found %v", expectedBurst, observedBurst)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParsingRateOnly(t *testing.T) {
 | 
						|
	var (
 | 
						|
		correctFormat                    = "400"
 | 
						|
		observedRate, observedBurst, err = parseRateLimit(correctFormat)
 | 
						|
		expectedRate, expectedBurst      = float64(400), 400
 | 
						|
	)
 | 
						|
	if err != nil {
 | 
						|
		t.Error(err)
 | 
						|
	}
 | 
						|
	if expectedRate != observedRate {
 | 
						|
		t.Errorf("Expected rate %v but found %v", expectedRate, observedRate)
 | 
						|
	}
 | 
						|
	if expectedBurst != observedBurst {
 | 
						|
		t.Errorf("Expected burst %v but found %v", expectedBurst, observedBurst)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestParsingErrorCase(t *testing.T) {
 | 
						|
	var incorrectFormat = "foobar"
 | 
						|
	var _, _, err = parseRateLimit(incorrectFormat)
 | 
						|
	if err == nil {
 | 
						|
		t.Error("Expected error, found no error")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClientTimeoutSetting(t *testing.T) {
 | 
						|
	oldClientTimeout := os.Getenv(EnvVaultClientTimeout)
 | 
						|
	os.Setenv(EnvVaultClientTimeout, "10")
 | 
						|
	defer os.Setenv(EnvVaultClientTimeout, oldClientTimeout)
 | 
						|
	config := DefaultConfig()
 | 
						|
	config.ReadEnvironment()
 | 
						|
	_, err := NewClient(config)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// roundTripperFunc adapts a plain function to a type with a RoundTrip
// method, so a function value can be used as an HTTP client transport.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip invokes the wrapped function with req and returns its results
// unchanged.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
 | 
						|
 | 
						|
func TestClientNonTransportRoundTripper(t *testing.T) {
 | 
						|
	client := &http.Client{
 | 
						|
		Transport: roundTripperFunc(http.DefaultTransport.RoundTrip),
 | 
						|
	}
 | 
						|
 | 
						|
	_, err := NewClient(&Config{
 | 
						|
		HttpClient: client,
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestClone(t *testing.T) {
 | 
						|
	client1, err1 := NewClient(nil)
 | 
						|
	if err1 != nil {
 | 
						|
		t.Fatalf("NewClient failed: %v", err1)
 | 
						|
	}
 | 
						|
	client2, err2 := client1.Clone()
 | 
						|
	if err2 != nil {
 | 
						|
		t.Fatalf("Clone failed: %v", err2)
 | 
						|
	}
 | 
						|
 | 
						|
	_ = client2
 | 
						|
}
 |