mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
521
http/forwarding_test.go
Normal file
521
http/forwarding_test.go
Normal file
@@ -0,0 +1,521 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/vault/api"
|
||||
credCert "github.com/hashicorp/vault/builtin/credential/cert"
|
||||
"github.com/hashicorp/vault/builtin/logical/transit"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func TestHTTP_Fallback_Bad_Address(t *testing.T) {
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"transit": transit.Factory,
|
||||
},
|
||||
ClusterAddr: "https://127.3.4.1:8382",
|
||||
}
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
// make it easy to get access to the active
|
||||
core := cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
addrs := []string{
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
config := api.DefaultConfig()
|
||||
config.Address = addr
|
||||
config.HttpClient = cleanhttp.DefaultClient()
|
||||
config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client.SetToken(root)
|
||||
|
||||
secret, err := client.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret == nil {
|
||||
t.Fatal("secret is nil")
|
||||
}
|
||||
if secret.Data["id"].(string) != root {
|
||||
t.Fatal("token mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP_Fallback_Disabled(t *testing.T) {
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"transit": transit.Factory,
|
||||
},
|
||||
ClusterAddr: "empty",
|
||||
}
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
// make it easy to get access to the active
|
||||
core := cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
addrs := []string{
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
config := api.DefaultConfig()
|
||||
config.Address = addr
|
||||
config.HttpClient = cleanhttp.DefaultClient()
|
||||
config.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client.SetToken(root)
|
||||
|
||||
secret, err := client.Auth().Token().LookupSelf()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret == nil {
|
||||
t.Fatal("secret is nil")
|
||||
}
|
||||
if secret.Data["id"].(string) != root {
|
||||
t.Fatal("token mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function recreates the fuzzy testing from transit to pipe a large
|
||||
// number of requests from the standbys to the active node.
|
||||
func TestHTTP_Forwarding_Stress(t *testing.T) {
|
||||
testPlaintext := "the quick brown fox"
|
||||
testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"transit": transit.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
// make it easy to get access to the active
|
||||
core := cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
|
||||
keys := []string{"test1", "test2", "test3"}
|
||||
|
||||
hosts := []string{
|
||||
fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[1].Listeners[0].Address.Port),
|
||||
fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[2].Listeners[0].Address.Port),
|
||||
}
|
||||
|
||||
transport := cleanhttp.DefaultPooledTransport()
|
||||
transport.TLSClientConfig = cores[0].TLSConfig
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error {
|
||||
return fmt.Errorf("redirects not allowed in this test")
|
||||
},
|
||||
}
|
||||
|
||||
//core.Logger().Printf("[TRACE] mounting transit")
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/mounts/transit", cores[0].Listeners[0].Address.Port),
|
||||
bytes.NewBuffer([]byte("{\"type\": \"transit\"}")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, root)
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//core.Logger().Printf("[TRACE] done mounting transit")
|
||||
|
||||
var totalOps int64
|
||||
var successfulOps int64
|
||||
var key1ver int64 = 1
|
||||
var key2ver int64 = 1
|
||||
var key3ver int64 = 1
|
||||
|
||||
// This is the goroutine loop
|
||||
doFuzzy := func(id int) {
|
||||
// Check for panics, otherwise notify we're done
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
core.Logger().Printf("[ERR] got a panic: %v", err)
|
||||
t.Fail()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Holds the latest encrypted value for each key
|
||||
latestEncryptedText := map[string]string{}
|
||||
|
||||
startTime := time.Now()
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
var chosenFunc, chosenKey, chosenHost string
|
||||
|
||||
doReq := func(method, url string, body io.Reader) (*http.Response, error) {
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, root)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
doResp := func(resp *http.Response) (*api.Secret, error) {
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("nil response")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Make sure we weren't redirected
|
||||
if resp.StatusCode > 300 && resp.StatusCode < 400 {
|
||||
return nil, fmt.Errorf("got status code %d, resp was %#v", resp.StatusCode, *resp)
|
||||
}
|
||||
|
||||
result := &api.Response{Response: resp}
|
||||
err = result.Error()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
secret, err := api.ParseSecret(result.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
for _, chosenHost := range hosts {
|
||||
for _, chosenKey := range keys {
|
||||
// Try to write the key to make sure it exists
|
||||
_, err := doReq("POST", chosenHost+"keys/"+chosenKey, bytes.NewBuffer([]byte("{}")))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//core.Logger().Printf("[TRACE] Starting %d", id)
|
||||
for {
|
||||
// Stop after 10 seconds
|
||||
if time.Now().Sub(startTime) > 10*time.Second {
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddInt64(&totalOps, 1)
|
||||
|
||||
// Pick a function and a key
|
||||
chosenFunc = funcs[rand.Int()%len(funcs)]
|
||||
chosenKey = keys[rand.Int()%len(keys)]
|
||||
chosenHost = hosts[rand.Int()%len(hosts)]
|
||||
|
||||
switch chosenFunc {
|
||||
// Encrypt our plaintext and store the result
|
||||
case "encrypt":
|
||||
//core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
||||
resp, err := doReq("POST", chosenHost+"encrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"plaintext\": \"%s\"}", testPlaintextB64))))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
secret, err := doResp(resp)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
latest := secret.Data["ciphertext"].(string)
|
||||
if latest == "" {
|
||||
panic(fmt.Errorf("bad ciphertext"))
|
||||
}
|
||||
latestEncryptedText[chosenKey] = secret.Data["ciphertext"].(string)
|
||||
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
|
||||
// Decrypt the ciphertext and compare the result
|
||||
case "decrypt":
|
||||
ct := latestEncryptedText[chosenKey]
|
||||
if ct == "" {
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
continue
|
||||
}
|
||||
|
||||
//core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
||||
resp, err := doReq("POST", chosenHost+"decrypt/"+chosenKey, bytes.NewBuffer([]byte(fmt.Sprintf("{\"ciphertext\": \"%s\"}", ct))))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
secret, err := doResp(resp)
|
||||
if err != nil {
|
||||
// This could well happen since the min version is jumping around
|
||||
if strings.Contains(err.Error(), transit.ErrTooOld) {
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
continue
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ptb64 := secret.Data["plaintext"].(string)
|
||||
pt, err := base64.StdEncoding.DecodeString(ptb64)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("got an error decoding base64 plaintext: %v", err))
|
||||
}
|
||||
if string(pt) != testPlaintext {
|
||||
panic(fmt.Errorf("got bad plaintext back: %s", pt))
|
||||
}
|
||||
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
|
||||
// Rotate to a new key version
|
||||
case "rotate":
|
||||
//core.Logger().Printf("[TRACE] %s, %s, %d", chosenFunc, chosenKey, id)
|
||||
_, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/rotate", bytes.NewBuffer([]byte("{}")))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
switch chosenKey {
|
||||
case "test1":
|
||||
atomic.AddInt64(&key1ver, 1)
|
||||
case "test2":
|
||||
atomic.AddInt64(&key2ver, 1)
|
||||
case "test3":
|
||||
atomic.AddInt64(&key3ver, 1)
|
||||
}
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
|
||||
// Change the min version, which also tests the archive functionality
|
||||
case "change_min_version":
|
||||
var latestVersion int64
|
||||
switch chosenKey {
|
||||
case "test1":
|
||||
latestVersion = atomic.LoadInt64(&key1ver)
|
||||
case "test2":
|
||||
latestVersion = atomic.LoadInt64(&key2ver)
|
||||
case "test3":
|
||||
latestVersion = atomic.LoadInt64(&key3ver)
|
||||
}
|
||||
|
||||
setVersion := (rand.Int63() % latestVersion) + 1
|
||||
|
||||
//core.Logger().Printf("[TRACE] %s, %s, %d, new min version %d", chosenFunc, chosenKey, id, setVersion)
|
||||
|
||||
_, err := doReq("POST", chosenHost+"keys/"+chosenKey+"/config", bytes.NewBuffer([]byte(fmt.Sprintf("{\"min_decryption_version\": %d}", setVersion))))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&successfulOps, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn 20 of these workers for 10 seconds
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
//core.Logger().Printf("[TRACE] spawning %d", i)
|
||||
go doFuzzy(i)
|
||||
}
|
||||
|
||||
// Wait for them all to finish
|
||||
wg.Wait()
|
||||
|
||||
core.Logger().Printf("[TRACE] total operations tried: %d, total successful: %d", totalOps, successfulOps)
|
||||
if totalOps != successfulOps {
|
||||
t.Fatalf("total/successful ops mismatch: %d/%d", totalOps, successfulOps)
|
||||
}
|
||||
}
|
||||
|
||||
// This tests TLS connection state forwarding by ensuring that we can use a
|
||||
// client TLS to authenticate against the cert backend
|
||||
func TestHTTP_Forwarding_ClientTLS(t *testing.T) {
|
||||
handler1 := http.NewServeMux()
|
||||
handler2 := http.NewServeMux()
|
||||
handler3 := http.NewServeMux()
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"cert": credCert.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
// Chicken-and-egg: Handler needs a core. So we create handlers first, then
|
||||
// add routes chained to a Handler-created handler.
|
||||
cores := vault.TestCluster(t, []http.Handler{handler1, handler2, handler3}, coreConfig, true)
|
||||
for _, core := range cores {
|
||||
defer core.CloseListeners()
|
||||
}
|
||||
handler1.Handle("/", Handler(cores[0].Core))
|
||||
handler2.Handle("/", Handler(cores[1].Core))
|
||||
handler3.Handle("/", Handler(cores[2].Core))
|
||||
|
||||
// make it easy to get access to the active
|
||||
core := cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
root := cores[0].Root
|
||||
|
||||
transport := cleanhttp.DefaultTransport()
|
||||
transport.TLSClientConfig = cores[0].TLSConfig
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/auth/cert", cores[0].Listeners[0].Address.Port),
|
||||
bytes.NewBuffer([]byte("{\"type\": \"cert\"}")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, root)
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type certConfig struct {
|
||||
Certificate string `json:"certificate"`
|
||||
Policies string `json:"policies"`
|
||||
}
|
||||
encodedCertConfig, err := json.Marshal(&certConfig{
|
||||
Certificate: vault.TestClusterCACert,
|
||||
Policies: "default",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req, err = http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/auth/cert/certs/test", cores[0].Listeners[0].Address.Port),
|
||||
bytes.NewBuffer(encodedCertConfig))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set(AuthHeaderName, root)
|
||||
_, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addrs := []string{
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
|
||||
fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
|
||||
}
|
||||
|
||||
// Ensure we can't possibly use lingering connections even though it should be to a different address
|
||||
|
||||
transport = cleanhttp.DefaultTransport()
|
||||
transport.TLSClientConfig = cores[0].TLSConfig
|
||||
|
||||
client = &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error {
|
||||
return fmt.Errorf("redirects not allowed in this test")
|
||||
},
|
||||
}
|
||||
|
||||
//cores[0].Logger().Printf("root token is %s", root)
|
||||
//time.Sleep(4 * time.Hour)
|
||||
|
||||
for _, addr := range addrs {
|
||||
config := api.DefaultConfig()
|
||||
config.Address = addr
|
||||
config.HttpClient = client
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secret, err := client.Logical().Write("auth/cert/login", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if secret == nil {
|
||||
t.Fatal("secret is nil")
|
||||
}
|
||||
if secret.Auth == nil {
|
||||
t.Fatal("auth is nil")
|
||||
}
|
||||
if secret.Auth.Policies == nil || len(secret.Auth.Policies) == 0 || secret.Auth.Policies[0] != "default" {
|
||||
t.Fatalf("bad policies: %#v", secret.Auth.Policies)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user