diff --git a/command/events.go b/command/events.go index 4aa312b077..e918200913 100644 --- a/command/events.go +++ b/command/events.go @@ -102,15 +102,32 @@ func (c *EventsSubscribeCommands) subscribeRequest(client *api.Client, path stri client.AddHeader("X-Vault-Token", client.Token()) client.AddHeader("X-Vault-Namesapce", client.Namespace()) ctx := context.Background() - conn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{ - HTTPClient: client.CloneConfig().HttpClient, - HTTPHeader: client.Headers(), - }) - if err != nil { - if resp != nil && resp.StatusCode == http.StatusNotFound { - return fmt.Errorf("events endpoint not found; check `vault read sys/experiments` to see if an events experiment is available but disabled") + + // Follow redirects in case our request if our request is forwarded to the leader. + url := u.String() + var conn *websocket.Conn + var err error + for attempt := 0; attempt < 10; attempt++ { + var resp *http.Response + conn, resp, err = websocket.Dial(ctx, url, &websocket.DialOptions{ + HTTPClient: client.CloneConfig().HttpClient, + HTTPHeader: client.Headers(), + }) + if err != nil { + if resp != nil { + if resp.StatusCode == http.StatusNotFound { + return fmt.Errorf("events endpoint not found; check `vault read sys/experiments` to see if an events experiment is available but disabled") + } else if resp.StatusCode == http.StatusTemporaryRedirect { + url = resp.Header.Get("Location") + continue + } + } + return err } - return err + break + } + if conn == nil { + return fmt.Errorf("too many redirects") } defer conn.Close(websocket.StatusNormalClosure, "") diff --git a/http/events_test.go b/http/events_test.go index 850724a57b..7ffed44280 100644 --- a/http/events_test.go +++ b/http/events_test.go @@ -8,17 +8,24 @@ import ( "encoding/json" "fmt" "net/http" + "net/http/httptest" "strings" + "sync" "sync/atomic" "testing" "time" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/helper/experiments" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/helper/testhelpers/corehelpers" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/vault/cluster" "nhooyr.io/websocket" ) @@ -201,3 +208,97 @@ func TestEventsSubscribeAuth(t *testing.T) { t.Errorf("Expected 403 but got %+v", resp) } } + +func TestCanForwardEventConnections(t *testing.T) { + // Run again with in-memory network + inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", 3, hclog.New(&hclog.LoggerOptions{ + Mutex: &sync.Mutex{}, + Level: hclog.Trace, + Name: "inmem-cluster", + })) + if err != nil { + t.Fatal(err) + } + testCluster := vault.NewTestCluster(t, &vault.CoreConfig{ + Experiments: []string{experiments.VaultExperimentEventsAlpha1}, + AuditBackends: map[string]audit.Factory{ + "nop": corehelpers.NoopAuditFactory(nil), + }, + }, &vault.TestClusterOptions{ + ClusterLayers: inmemCluster, + }) + cores := testCluster.Cores + testCluster.Start() + defer testCluster.Cleanup() + + rootToken := testCluster.RootToken + + // Wait for core to become active + vault.TestWaitActiveForwardingReady(t, cores[0].Core) + + // Test forwarding a request. Since we're going directly from core to core + // with no fallback we know that if it worked, request handling is working + c := cores[1] + standby, err := c.Standby() + if err != nil { + t.Fatal(err) + } + if !standby { + t.Fatal("expected core to be standby") + } + + // We need to call Leader as that refreshes the connection info + isLeader, _, _, err := c.Leader() + if err != nil { + t.Fatal(err) + } + if isLeader { + t.Fatal("core should not be leader") + } + corehelpers.RetryUntil(t, 5*time.Second, func() error { + state := c.ActiveNodeReplicationState() + if state == 0 { + return fmt.Errorf("heartbeats have not yet returned a valid active node replication state: %d", state) + } + return nil + }) + + req, err := http.NewRequest("GET", "https://pushit.real.good:9281/v1/sys/events/subscribe/xyz?json=true", nil) + if err != nil { + t.Fatal(err) + } + req = req.WithContext(namespace.RootContext(req.Context())) + req.Header.Add(consts.AuthHeaderName, rootToken) + + resp := httptest.NewRecorder() + forwardRequest(cores[1].Core, resp, req) + + header := resp.Header() + if header == nil { + t.Fatal("err: expected at least a Location header") + } + if !strings.HasPrefix(header.Get("Location"), "wss://") { + t.Fatalf("bad location: %s", header.Get("Location")) + } + + // test forwarding requests to each core + handled := 0 + forwarded := 0 + for _, c := range cores { + resp := httptest.NewRecorder() + fakeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handled++ + }) + handleRequestForwarding(c.Core, fakeHandler).ServeHTTP(resp, req) + header := resp.Header() + if header == nil { + continue + } + if strings.HasPrefix(header.Get("Location"), "wss://") { + forwarded++ + } + } + if handled != 1 && forwarded != 2 { + t.Fatalf("Expected 1 core to handle the request and 2 to forward") + } +} diff --git a/http/handler.go b/http/handler.go index aac805115b..40d57d051b 100644 --- a/http/handler.go +++ b/http/handler.go @@ -84,6 +84,7 @@ var ( // the always forward list perfStandbyAlwaysForwardPaths = pathmanager.New() alwaysRedirectPaths = pathmanager.New() + websocketPaths = pathmanager.New() injectDataIntoTopRoutes = []string{ "/v1/sys/audit", @@ -109,7 +110,9 @@ var ( "/v1/sys/rotate", "/v1/sys/wrapping/wrap", } - + websocketRawPaths = []string{ + "/v1/sys/events/subscribe", + } oidcProtectedPathRegex = regexp.MustCompile(`^identity/oidc/provider/\w(([\w-.]+)?\w)?/userinfo$`) ) @@ -119,6 +122,10 @@ func init() { "sys/storage/raft/snapshot-force", "!sys/storage/raft/snapshot-auto/config", }) + websocketPaths.AddPaths(websocketRawPaths) + for _, path := range websocketRawPaths { + alwaysRedirectPaths.AddPaths([]string{strings.TrimPrefix(path, "/v1/")}) + } } type HandlerAnchor struct{} @@ -1017,6 +1024,15 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) { RawQuery: reqURL.RawQuery, } + // WebSockets schemas are ws or wss + if websocketPaths.HasPath(reqURL.Path) { + if finalURL.Scheme == "http" { + finalURL.Scheme = "ws" + } else { + finalURL.Scheme = "wss" + } + } + // Ensure there is a scheme, default to https if finalURL.Scheme == "" { finalURL.Scheme = "https"