events: Don't accept websocket connection until subscription is active (#23024)

The WebSocket tests have been very flaky because we weren't able to tell when a WebSocket was fully connected and subscribed to events.

We reworked the WebSocket subscription code so that the server accepts the WebSocket connection only after the event subscription is active.

This should eliminate all flakiness in these tests. 🤞 (We can follow up in an enterprise PR to simplify some of the tests after this fix is merged.)

I ran this locally many times, including with data race detection enabled, and did not see any failures.

Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
Christopher Swenson
2023-09-13 14:28:17 -07:00
committed by GitHub
parent 8e7c6e80d5
commit 82e9b610df
2 changed files with 125 additions and 108 deletions

View File

@@ -31,36 +31,92 @@ var webSocketRevalidationTime = 5 * time.Minute
type eventSubscriber struct { type eventSubscriber struct {
ctx context.Context ctx context.Context
cancelCtx context.CancelFunc
clientToken string clientToken string
capabilitiesFunc func(ctx context.Context, token, path string) ([]string, []string, error)
logger hclog.Logger logger hclog.Logger
events *eventbus.EventBus events *eventbus.EventBus
namespacePatterns []string namespacePatterns []string
pattern string pattern string
bexprFilter string bexprFilter string
conn *websocket.Conn
json bool json bool
checkCache *cache.Cache checkCache *cache.Cache
isRootToken bool isRootToken bool
core *vault.Core
w http.ResponseWriter
r *http.Request
req *logical.Request
} }
// handleEventsSubscribeWebsocket runs forever serving events to the websocket connection, returning a websocket // handleEventsSubscribeWebsocket subscribes to the events, accepts the websocket connection, and then runs forever,
// error code and reason only if the connection closes or there was an error. // serving events to the websocket connection.
func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCode, string, error) { func (sub *eventSubscriber) handleEventsSubscribeWebsocket() {
ctx := sub.ctx ctx := sub.ctx
logger := sub.logger logger := sub.logger
// subscribe before accept to avoid race conditions
ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern, sub.bexprFilter) ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern, sub.bexprFilter)
if err != nil { if err != nil {
logger.Info("Error subscribing", "error", err) logger.Info("Error subscribing", "error", err)
return websocket.StatusUnsupportedData, "Error subscribing", nil sub.w.WriteHeader(400)
sub.w.Write([]byte("Error subscribing"))
return
} }
defer cancel() defer cancel()
logger.Debug("WebSocket is subscribed to messages", "namespaces", sub.namespacePatterns, "event_types", sub.pattern, "bexpr_filter", sub.bexprFilter)
conn, err := websocket.Accept(sub.w, sub.r, nil)
if err != nil {
logger.Info("Could not accept as websocket", "error", err)
respondError(sub.w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket"))
return
}
// continually validate subscribe access while the websocket is running
// this has to be done after accepting the websocket to avoid a race condition
go sub.validateSubscribeAccessLoop()
// make sure to close the websocket
closeStatus := websocket.StatusNormalClosure
closeReason := ""
var closeErr error = nil
defer func() {
if closeErr != nil {
closeStatus = websocket.CloseStatus(err)
if closeStatus == -1 {
closeStatus = websocket.StatusInternalError
}
closeReason = fmt.Sprintf("Internal error: %v", err)
logger.Debug("Error from websocket handler", "error", err)
}
// Close() will panic if the reason is greater than this length
if len(closeReason) > 123 {
logger.Debug("Truncated close reason", "closeReason", closeReason)
closeReason = closeReason[:123]
}
err = conn.Close(closeStatus, closeReason)
if err != nil {
logger.Debug("Error closing websocket", "error", err)
}
}()
// we don't expect any incoming messages
ctx = conn.CloseRead(ctx)
// start the pinger
go func() {
for {
time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive
err := conn.Ping(ctx)
if err != nil {
return
}
}
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
logger.Info("Websocket context is done, closing the connection") logger.Info("Websocket context is done, closing the connection")
return websocket.StatusNormalClosure, "", nil return
case message := <-ch: case message := <-ch:
// Perform one last check that the message is allowed to be received. // Perform one last check that the message is allowed to be received.
// For example, if a new namespace was created that matches the namespace patterns, // For example, if a new namespace was created that matches the namespace patterns,
@@ -78,7 +134,8 @@ func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCo
messageBytes, ok = message.Format("cloudevents-json") messageBytes, ok = message.Format("cloudevents-json")
if !ok { if !ok {
logger.Warn("Could not get cloudevents JSON format") logger.Warn("Could not get cloudevents JSON format")
return 0, "", errors.New("could not get cloudevents JSON format") closeErr = errors.New("could not get cloudevents JSON format")
return
} }
messageType = websocket.MessageText messageType = websocket.MessageText
} else { } else {
@@ -87,11 +144,13 @@ func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCo
} }
if err != nil { if err != nil {
logger.Warn("Could not serialize websocket event", "error", err) logger.Warn("Could not serialize websocket event", "error", err)
return 0, "", err closeErr = err
return
} }
err = sub.conn.Write(ctx, messageType, messageBytes) err = conn.Write(ctx, messageType, messageBytes)
if err != nil { if err != nil {
return 0, "", err closeErr = err
return
} }
} }
} }
@@ -154,7 +213,7 @@ func (sub *eventSubscriber) allowMessage(eventNs, dataPath, eventType string) bo
if eventNs != "" { if eventNs != "" {
nsDataPath = path.Join(eventNs, dataPath) nsDataPath = path.Join(eventNs, dataPath)
} }
capabilities, allowedEventTypes, err := sub.capabilitiesFunc(sub.ctx, sub.clientToken, nsDataPath) capabilities, allowedEventTypes, err := sub.core.CapabilitiesAndSubscribeEventTypes(sub.ctx, sub.clientToken, nsDataPath)
if err != nil { if err != nil {
sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs) sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs)
return false return false
@@ -221,64 +280,28 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
bexprFilter := strings.TrimSpace(r.URL.Query().Get("filter")) bexprFilter := strings.TrimSpace(r.URL.Query().Get("filter"))
namespacePatterns := r.URL.Query()["namespaces"] namespacePatterns := r.URL.Query()["namespaces"]
namespacePatterns = prependNamespacePatterns(namespacePatterns, ns) namespacePatterns = prependNamespacePatterns(namespacePatterns, ns)
conn, err := websocket.Accept(w, r, nil) isRoot := entry.IsRoot()
if err != nil {
logger.Info("Could not accept as websocket", "error", err)
respondError(w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket"))
return
}
// we don't expect any incoming messages
ctx = conn.CloseRead(ctx)
// start the pinger
go func() {
for {
time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive
err := conn.Ping(ctx)
if err != nil {
return
}
}
}()
// continually validate subscribe access while the websocket is running
ctx, cancelCtx := context.WithCancel(ctx) ctx, cancelCtx := context.WithCancel(ctx)
defer cancelCtx() defer cancelCtx()
isRoot := entry.IsRoot()
go validateSubscribeAccessLoop(core, ctx, cancelCtx, req)
sub := &eventSubscriber{ sub := &eventSubscriber{
ctx: ctx, ctx: ctx,
capabilitiesFunc: core.CapabilitiesAndSubscribeEventTypes, cancelCtx: cancelCtx,
logger: logger, logger: logger,
events: core.Events(), events: core.Events(),
namespacePatterns: namespacePatterns, namespacePatterns: namespacePatterns,
pattern: pattern, pattern: pattern,
bexprFilter: bexprFilter, bexprFilter: bexprFilter,
conn: conn,
json: json, json: json,
checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime), checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime),
clientToken: auth.ClientToken, clientToken: auth.ClientToken,
isRootToken: isRoot, isRootToken: isRoot,
core: core,
w: w,
r: r,
req: req,
} }
closeStatus, closeReason, err := sub.handleEventsSubscribeWebsocket() sub.handleEventsSubscribeWebsocket()
if err != nil {
closeStatus = websocket.CloseStatus(err)
if closeStatus == -1 {
closeStatus = websocket.StatusInternalError
}
closeReason = fmt.Sprintf("Internal error: %v", err)
logger.Debug("Error from websocket handler", "error", err)
}
// Close() will panic if the reason is greater than this length
if len(closeReason) > 123 {
logger.Debug("Truncated close reason", "closeReason", closeReason)
closeReason = closeReason[:123]
}
err = conn.Close(closeStatus, closeReason)
if err != nil {
logger.Debug("Error closing websocket", "error", err)
}
}) })
} }
@@ -298,13 +321,13 @@ func prependNamespacePatterns(patterns []string, requestNamespace *namespace.Nam
// validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in // validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in
// its namespace. If the access check ever fails, then the cancel function is called and the function returns. // its namespace. If the access check ever fails, then the cancel function is called and the function returns.
func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel context.CancelFunc, req *logical.Request) { func (sub *eventSubscriber) validateSubscribeAccessLoop() {
// if something breaks, default to canceling the websocket // if something breaks, default to canceling the websocket
defer cancel() defer sub.cancelCtx()
for { for {
_, _, err := core.CheckTokenWithLock(ctx, req, false) _, _, err := sub.core.CheckTokenWithLock(sub.ctx, sub.req, false)
if err != nil { if err != nil {
core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", req.Path, "error", err) sub.core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", sub.req.Path, "error", err)
return return
} }
// wait a while and try again, but quit the loop if the context finishes early // wait a while and try again, but quit the loop if the context finishes early
@@ -312,7 +335,7 @@ func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel c
ticker := time.NewTicker(webSocketRevalidationTime) ticker := time.NewTicker(webSocketRevalidationTime)
defer ticker.Stop() defer ticker.Stop()
select { select {
case <-ctx.Done(): case <-sub.ctx.Done():
return true return true
case <-ticker.C: case <-ticker.C:
return false return false

View File

@@ -12,7 +12,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -46,18 +45,15 @@ func TestEventsSubscribe(t *testing.T) {
} }
} }
stop := atomic.Bool{}
const eventType = "abc" const eventType = "abc"
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
// send some events // send some events
go func() { sendEvents := func() error {
for !stop.Load() {
id, err := uuid.GenerateUUID() id, err := uuid.GenerateUUID()
if err != nil { if err != nil {
core.Logger().Info("Error generating UUID, exiting sender", "error", err) return err
} }
pluginInfo := &logical.EventPluginInfo{ pluginInfo := &logical.EventPluginInfo{
MountPath: "secret", MountPath: "secret",
@@ -69,15 +65,10 @@ func TestEventsSubscribe(t *testing.T) {
Note: "testing", Note: "testing",
}) })
if err != nil { if err != nil {
core.Logger().Info("Error sending event, exiting sender", "error", err) return err
} }
time.Sleep(100 * time.Millisecond) return nil
} }
}()
t.Cleanup(func() {
stop.Store(true)
})
wsAddr := strings.Replace(addr, "http", "ws", 1) wsAddr := strings.Replace(addr, "http", "ws", 1)
@@ -97,6 +88,10 @@ func TestEventsSubscribe(t *testing.T) {
conn.Close(websocket.StatusNormalClosure, "") conn.Close(websocket.StatusNormalClosure, "")
}) })
err = sendEvents()
if err != nil {
t.Fatal(err)
}
_, msg, err := conn.Read(ctx) _, msg, err := conn.Read(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -147,11 +142,7 @@ func TestBexprFilters(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
sendEvents := func(ctx context.Context, eventTypes ...string) error {
// send duplicates to help avoid flaky tests in CI
sendEvents := func(ctx context.Context, eventTypes ...string) {
for i := 0; i < 10; i++ {
time.Sleep(10 * time.Millisecond)
for _, eventType := range eventTypes { for _, eventType := range eventTypes {
pluginInfo := &logical.EventPluginInfo{ pluginInfo := &logical.EventPluginInfo{
MountPath: "secret", MountPath: "secret",
@@ -165,10 +156,10 @@ func TestBexprFilters(t *testing.T) {
Note: "testing", Note: "testing",
}) })
if err != nil { if err != nil {
return return err
}
} }
} }
return nil
} }
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
@@ -185,7 +176,10 @@ func TestBexprFilters(t *testing.T) {
} }
defer conn.Close(websocket.StatusNormalClosure, "") defer conn.Close(websocket.StatusNormalClosure, "")
go sendEvents(ctx, "abc", "def", "xyz") err = sendEvents(ctx, "abc", "def", "xyz")
if err != nil {
t.Fatal(err)
}
// read until we time out // read until we time out
seen := map[string]bool{} seen := map[string]bool{}
done := false done := false