mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 09:42:25 +00:00
Forward websocket event subscription requests (#22446)
For now, only the leader of a cluster can handle subscription requests, so we forward the connection request otherwise. We forward using a 307 temporary redirect (the fallback way). Forwarding a request over gRPC currently only supports a single request and response, but a websocket connection is long-lived with potentially many messages back and forth. We modified the `vault events subscribe` command to honor those redirects. `wscat` supports them with the `-L` flag. In the future, we may add a gRPC method to handle forwarding WebSocket requests, but doing so adds quite a bit of complexity (even over normal request forwarding) due to the intricate nature of the `http` / `vault.Core` interactions required. (I initially went down this path.) I added tests for the forwarding header, and also tested manually. (Testing with `-dev-three-node` is a little clumsy since it does not properly support experiments, for some reason.) Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b14b0aba25
commit
4a5cde0afb
@@ -102,15 +102,32 @@ func (c *EventsSubscribeCommands) subscribeRequest(client *api.Client, path stri
|
||||
client.AddHeader("X-Vault-Token", client.Token())
|
||||
client.AddHeader("X-Vault-Namesapce", client.Namespace())
|
||||
ctx := context.Background()
|
||||
conn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
|
||||
HTTPClient: client.CloneConfig().HttpClient,
|
||||
HTTPHeader: client.Headers(),
|
||||
})
|
||||
if err != nil {
|
||||
if resp != nil && resp.StatusCode == http.StatusNotFound {
|
||||
return fmt.Errorf("events endpoint not found; check `vault read sys/experiments` to see if an events experiment is available but disabled")
|
||||
|
||||
// Follow redirects in case our request if our request is forwarded to the leader.
|
||||
url := u.String()
|
||||
var conn *websocket.Conn
|
||||
var err error
|
||||
for attempt := 0; attempt < 10; attempt++ {
|
||||
var resp *http.Response
|
||||
conn, resp, err = websocket.Dial(ctx, url, &websocket.DialOptions{
|
||||
HTTPClient: client.CloneConfig().HttpClient,
|
||||
HTTPHeader: client.Headers(),
|
||||
})
|
||||
if err != nil {
|
||||
if resp != nil {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return fmt.Errorf("events endpoint not found; check `vault read sys/experiments` to see if an events experiment is available but disabled")
|
||||
} else if resp.StatusCode == http.StatusTemporaryRedirect {
|
||||
url = resp.Header.Get("Location")
|
||||
continue
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
return err
|
||||
break
|
||||
}
|
||||
if conn == nil {
|
||||
return fmt.Errorf("too many redirects")
|
||||
}
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
|
||||
|
||||
@@ -8,17 +8,24 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/audit"
|
||||
"github.com/hashicorp/vault/helper/experiments"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/hashicorp/vault/vault/cluster"
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
@@ -201,3 +208,97 @@ func TestEventsSubscribeAuth(t *testing.T) {
|
||||
t.Errorf("Expected 403 but got %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanForwardEventConnections(t *testing.T) {
|
||||
// Run again with in-memory network
|
||||
inmemCluster, err := cluster.NewInmemLayerCluster("inmem-cluster", 3, hclog.New(&hclog.LoggerOptions{
|
||||
Mutex: &sync.Mutex{},
|
||||
Level: hclog.Trace,
|
||||
Name: "inmem-cluster",
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testCluster := vault.NewTestCluster(t, &vault.CoreConfig{
|
||||
Experiments: []string{experiments.VaultExperimentEventsAlpha1},
|
||||
AuditBackends: map[string]audit.Factory{
|
||||
"nop": corehelpers.NoopAuditFactory(nil),
|
||||
},
|
||||
}, &vault.TestClusterOptions{
|
||||
ClusterLayers: inmemCluster,
|
||||
})
|
||||
cores := testCluster.Cores
|
||||
testCluster.Start()
|
||||
defer testCluster.Cleanup()
|
||||
|
||||
rootToken := testCluster.RootToken
|
||||
|
||||
// Wait for core to become active
|
||||
vault.TestWaitActiveForwardingReady(t, cores[0].Core)
|
||||
|
||||
// Test forwarding a request. Since we're going directly from core to core
|
||||
// with no fallback we know that if it worked, request handling is working
|
||||
c := cores[1]
|
||||
standby, err := c.Standby()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !standby {
|
||||
t.Fatal("expected core to be standby")
|
||||
}
|
||||
|
||||
// We need to call Leader as that refreshes the connection info
|
||||
isLeader, _, _, err := c.Leader()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if isLeader {
|
||||
t.Fatal("core should not be leader")
|
||||
}
|
||||
corehelpers.RetryUntil(t, 5*time.Second, func() error {
|
||||
state := c.ActiveNodeReplicationState()
|
||||
if state == 0 {
|
||||
return fmt.Errorf("heartbeats have not yet returned a valid active node replication state: %d", state)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("GET", "https://pushit.real.good:9281/v1/sys/events/subscribe/xyz?json=true", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req = req.WithContext(namespace.RootContext(req.Context()))
|
||||
req.Header.Add(consts.AuthHeaderName, rootToken)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
forwardRequest(cores[1].Core, resp, req)
|
||||
|
||||
header := resp.Header()
|
||||
if header == nil {
|
||||
t.Fatal("err: expected at least a Location header")
|
||||
}
|
||||
if !strings.HasPrefix(header.Get("Location"), "wss://") {
|
||||
t.Fatalf("bad location: %s", header.Get("Location"))
|
||||
}
|
||||
|
||||
// test forwarding requests to each core
|
||||
handled := 0
|
||||
forwarded := 0
|
||||
for _, c := range cores {
|
||||
resp := httptest.NewRecorder()
|
||||
fakeHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handled++
|
||||
})
|
||||
handleRequestForwarding(c.Core, fakeHandler).ServeHTTP(resp, req)
|
||||
header := resp.Header()
|
||||
if header == nil {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(header.Get("Location"), "wss://") {
|
||||
forwarded++
|
||||
}
|
||||
}
|
||||
if handled != 1 && forwarded != 2 {
|
||||
t.Fatalf("Expected 1 core to handle the request and 2 to forward")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,6 +84,7 @@ var (
|
||||
// the always forward list
|
||||
perfStandbyAlwaysForwardPaths = pathmanager.New()
|
||||
alwaysRedirectPaths = pathmanager.New()
|
||||
websocketPaths = pathmanager.New()
|
||||
|
||||
injectDataIntoTopRoutes = []string{
|
||||
"/v1/sys/audit",
|
||||
@@ -109,7 +110,9 @@ var (
|
||||
"/v1/sys/rotate",
|
||||
"/v1/sys/wrapping/wrap",
|
||||
}
|
||||
|
||||
websocketRawPaths = []string{
|
||||
"/v1/sys/events/subscribe",
|
||||
}
|
||||
oidcProtectedPathRegex = regexp.MustCompile(`^identity/oidc/provider/\w(([\w-.]+)?\w)?/userinfo$`)
|
||||
)
|
||||
|
||||
@@ -119,6 +122,10 @@ func init() {
|
||||
"sys/storage/raft/snapshot-force",
|
||||
"!sys/storage/raft/snapshot-auto/config",
|
||||
})
|
||||
websocketPaths.AddPaths(websocketRawPaths)
|
||||
for _, path := range websocketRawPaths {
|
||||
alwaysRedirectPaths.AddPaths([]string{strings.TrimPrefix(path, "/v1/")})
|
||||
}
|
||||
}
|
||||
|
||||
type HandlerAnchor struct{}
|
||||
@@ -1017,6 +1024,15 @@ func respondStandby(core *vault.Core, w http.ResponseWriter, reqURL *url.URL) {
|
||||
RawQuery: reqURL.RawQuery,
|
||||
}
|
||||
|
||||
// WebSockets schemas are ws or wss
|
||||
if websocketPaths.HasPath(reqURL.Path) {
|
||||
if finalURL.Scheme == "http" {
|
||||
finalURL.Scheme = "ws"
|
||||
} else {
|
||||
finalURL.Scheme = "wss"
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure there is a scheme, default to https
|
||||
if finalURL.Scheme == "" {
|
||||
finalURL.Scheme = "https"
|
||||
|
||||
Reference in New Issue
Block a user