Forward websocket event subscription requests (#22446)

For now, only the leader of a cluster can handle subscription requests,
so we forward the connection request otherwise.

We forward using a 307 temporary redirect (the fallback way).
Forwarding a request over gRPC currently only supports a single request
and response, but a websocket connection is long-lived with potentially
many messages back and forth.

We modified the `vault events subscribe` command to honor those
redirects. `wscat` supports them with the `-L` flag.

In the future, we may add a gRPC method to handle forwarding WebSocket
requests, but doing so adds quite a bit of complexity (even over
normal request forwarding) due to the intricate nature of the `http` /
`vault.Core` interactions required. (I initially went down this path.)

I added tests for the forwarding header, and also tested manually.
(Testing with `-dev-three-node` is a little clumsy since it does not
properly support experiments, for some reason.)

Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
Christopher Swenson
2023-08-22 14:33:31 -07:00
committed by GitHub
parent b14b0aba25
commit 4a5cde0afb
3 changed files with 143 additions and 9 deletions

View File

@@ -102,15 +102,32 @@ func (c *EventsSubscribeCommands) subscribeRequest(client *api.Client, path stri
client.AddHeader("X-Vault-Token", client.Token())
client.AddHeader("X-Vault-Namesapce", client.Namespace())
ctx := context.Background()
conn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
HTTPClient: client.CloneConfig().HttpClient,
HTTPHeader: client.Headers(),
})
if err != nil {
if resp != nil && resp.StatusCode == http.StatusNotFound {
return fmt.Errorf("events endpoint not found; check `vault read sys/experiments` to see if an events experiment is available but disabled")
// Follow redirects in case our request if our request is forwarded to the leader.
url := u.String()
var conn *websocket.Conn
var err error
for attempt := 0; attempt < 10; attempt++ {
var resp *http.Response
conn, resp, err = websocket.Dial(ctx, url, &websocket.DialOptions{
HTTPClient: client.CloneConfig().HttpClient,
HTTPHeader: client.Headers(),
})
if err != nil {
if resp != nil {
if resp.StatusCode == http.StatusNotFound {
return fmt.Errorf("events endpoint not found; check `vault read sys/experiments` to see if an events experiment is available but disabled")
} else if resp.StatusCode == http.StatusTemporaryRedirect {
url = resp.Header.Get("Location")
continue
}
}
return err
}
return err
break
}
if conn == nil {
return fmt.Errorf("too many redirects")
}
defer conn.Close(websocket.StatusNormalClosure, "")

View File

@@ -8,17 +8,24 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/experiments"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/cluster"
"nhooyr.io/websocket"
)
@@ -201,3 +208,97 @@ func TestEventsSubscribeAuth(t *testing.T) {
t.Errorf("Expected 403 but got %+v", resp)
}
}
func TestCanForwardEventConnections(t *testing.T) {
// Run again with in-memory network
inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", 3, hclog.New(&hclog.LoggerOptions{
Mutex: &sync.Mutex{},
Level: hclog.Trace,
Name: "inmem-cluster",
}))
if err != nil {
t.Fatal(err)
}
testCluster := vault.NewTestCluster(t, &vault.CoreConfig{
Experiments: []string{experiments.VaultExperimentEventsAlpha1},
AuditBackends: map[string]audit.Factory{
"nop": corehelpers.NoopAuditFactory(nil),
},
}, &vault.TestClusterOptions{
ClusterLayers: inmemCluster,
})
cores := testCluster.Cores
testCluster.Start()
defer testCluster.Cleanup()
rootToken := testCluster.RootToken
// Wait for core to become active
vault.TestWaitActiveForwardingReady(t, cores[0].Core)
// Test forwarding a request. Since we're going directly from core to core
// with no fallback we know that if it worked, request handling is working
c := cores[1]
standby, err := c.Standby()
if err != nil {
t.Fatal(err)
}
if !standby {
t.Fatal("expected core to be standby")
}
// We need to call Leader as that refreshes the connection info
isLeader, _, _, err := c.Leader()
if err != nil {
t.Fatal(err)
}
if isLeader {
t.Fatal("core should not be leader")
}
corehelpers.RetryUntil(t, 5*time.Second, func() error {
state := c.ActiveNodeReplicationState()
if state == 0 {
return fmt.Errorf("heartbeats have not yet returned a valid active node replication state: %d", state)
}
return nil
})
req, err := http.NewRequest("GET", "https://pushit.real.good:9281/v1/sys/events/subscribe/xyz?json=true", nil)
if err != nil {
t.Fatal(err)
}
req = req.WithContext(namespace.RootContext(req.Context()))
req.Header.Add(consts.AuthHeaderName, rootToken)
resp := httptest.NewRecorder()
forwardRequest(cores[1].Core, resp, req)
header := resp.Header()
if header == nil {
t.Fatal("err: expected at least a Location header")
}
if !strings.HasPrefix(header.Get("Location"), "wss://") {
t.Fatalf("bad location: %s", header.Get("Location"))
}
// test forwarding requests to each core
handled := 0
forwarded := 0
for _, c := range cores {
resp := httptest.NewRecorder()
fakeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handled++
})
handleRequestForwarding(c.Core, fakeHandler).ServeHTTP(resp, req)
header := resp.Header()
if header == nil {
continue
}
if strings.HasPrefix(header.Get("Location"), "wss://") {
forwarded++
}
}
if handled != 1 && forwarded != 2 {
t.Fatalf("Expected 1 core to handle the request and 2 to forward")
}
}

View File

@@ -84,6 +84,7 @@ var (
// the always forward list
perfStandbyAlwaysForwardPaths = pathmanager.New()
alwaysRedirectPaths = pathmanager.New()
websocketPaths = pathmanager.New()
injectDataIntoTopRoutes = []string{
"/v1/sys/audit",
@@ -109,7 +110,9 @@ var (
"/v1/sys/rotate",
"/v1/sys/wrapping/wrap",
}
websocketRawPaths = []string{
"/v1/sys/events/subscribe",
}
oidcProtectedPathRegex = regexp.MustCompile(`^identity/oidc/provider/\w(([\w-.]+)?\w)?/userinfo$`)
)
@@ -119,6 +122,10 @@ func init() {
"sys/storage/raft/snapshot-force",
"!sys/storage/raft/snapshot-auto/config",
})
websocketPaths.AddPaths(websocketRawPaths)
for _, path := range websocketRawPaths {
alwaysRedirectPaths.AddPaths([]string{strings.TrimPrefix(path, "/v1/")})
}
}
type HandlerAnchor struct{}
@@ -1017,6 +1024,15 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) {
RawQuery: reqURL.RawQuery,
}
// WebSockets schemas are ws or wss
if websocketPaths.HasPath(reqURL.Path) {
if finalURL.Scheme == "http" {
finalURL.Scheme = "ws"
} else {
finalURL.Scheme = "wss"
}
}
// Ensure there is a scheme, default to https
if finalURL.Scheme == "" {
finalURL.Scheme = "https"