events: Don't accept websocket connection until subscription is active (#23024)

The WebSocket tests have been very flaky because we weren't able to tell when a WebSocket was fully connected and subscribed to events.

We reworked the WebSocket subscription code to accept the WebSocket connection only after the event subscription is active.

This should eliminate all flakiness in these tests. 🤞 (We can follow up in an enterprise PR to simplify some of the tests after this fix is merged.)

I ran this locally a bunch of times, including with data race detection enabled, and did not see any failures.

Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
Christopher Swenson
2023-09-13 14:28:17 -07:00
committed by GitHub
parent 8e7c6e80d5
commit 82e9b610df
2 changed files with 125 additions and 108 deletions

View File

@@ -31,36 +31,92 @@ var webSocketRevalidationTime = 5 * time.Minute
type eventSubscriber struct {
ctx context.Context
cancelCtx context.CancelFunc
clientToken string
capabilitiesFunc func(ctx context.Context, token, path string) ([]string, []string, error)
logger hclog.Logger
events *eventbus.EventBus
namespacePatterns []string
pattern string
bexprFilter string
conn *websocket.Conn
json bool
checkCache *cache.Cache
isRootToken bool
core *vault.Core
w http.ResponseWriter
r *http.Request
req *logical.Request
}
// handleEventsSubscribeWebsocket runs forever serving events to the websocket connection, returning a websocket
// error code and reason only if the connection closes or there was an error.
func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCode, string, error) {
// handleEventsSubscribeWebsocket subscribes to the events, accepts the websocket connection, and then runs forever,
// serving events to the websocket connection.
func (sub *eventSubscriber) handleEventsSubscribeWebsocket() {
ctx := sub.ctx
logger := sub.logger
// subscribe before accept to avoid race conditions
ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern, sub.bexprFilter)
if err != nil {
logger.Info("Error subscribing", "error", err)
return websocket.StatusUnsupportedData, "Error subscribing", nil
sub.w.WriteHeader(400)
sub.w.Write([]byte("Error subscribing"))
return
}
defer cancel()
logger.Debug("WebSocket is subscribed to messages", "namespaces", sub.namespacePatterns, "event_types", sub.pattern, "bexpr_filter", sub.bexprFilter)
conn, err := websocket.Accept(sub.w, sub.r, nil)
if err != nil {
logger.Info("Could not accept as websocket", "error", err)
respondError(sub.w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket"))
return
}
// continually validate subscribe access while the websocket is running
// this has to be done after accepting the websocket to avoid a race condition
go sub.validateSubscribeAccessLoop()
// make sure to close the websocket
closeStatus := websocket.StatusNormalClosure
closeReason := ""
var closeErr error = nil
defer func() {
if closeErr != nil {
closeStatus = websocket.CloseStatus(err)
if closeStatus == -1 {
closeStatus = websocket.StatusInternalError
}
closeReason = fmt.Sprintf("Internal error: %v", err)
logger.Debug("Error from websocket handler", "error", err)
}
// Close() will panic if the reason is greater than this length
if len(closeReason) > 123 {
logger.Debug("Truncated close reason", "closeReason", closeReason)
closeReason = closeReason[:123]
}
err = conn.Close(closeStatus, closeReason)
if err != nil {
logger.Debug("Error closing websocket", "error", err)
}
}()
// we don't expect any incoming messages
ctx = conn.CloseRead(ctx)
// start the pinger
go func() {
for {
time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive
err := conn.Ping(ctx)
if err != nil {
return
}
}
}()
for {
select {
case <-ctx.Done():
logger.Info("Websocket context is done, closing the connection")
return websocket.StatusNormalClosure, "", nil
return
case message := <-ch:
// Perform one last check that the message is allowed to be received.
// For example, if a new namespace was created that matches the namespace patterns,
@@ -78,7 +134,8 @@ func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCo
messageBytes, ok = message.Format("cloudevents-json")
if !ok {
logger.Warn("Could not get cloudevents JSON format")
return 0, "", errors.New("could not get cloudevents JSON format")
closeErr = errors.New("could not get cloudevents JSON format")
return
}
messageType = websocket.MessageText
} else {
@@ -87,11 +144,13 @@ func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCo
}
if err != nil {
logger.Warn("Could not serialize websocket event", "error", err)
return 0, "", err
closeErr = err
return
}
err = sub.conn.Write(ctx, messageType, messageBytes)
err = conn.Write(ctx, messageType, messageBytes)
if err != nil {
return 0, "", err
closeErr = err
return
}
}
}
@@ -154,7 +213,7 @@ func (sub *eventSubscriber) allowMessage(eventNs, dataPath, eventType string) bo
if eventNs != "" {
nsDataPath = path.Join(eventNs, dataPath)
}
capabilities, allowedEventTypes, err := sub.capabilitiesFunc(sub.ctx, sub.clientToken, nsDataPath)
capabilities, allowedEventTypes, err := sub.core.CapabilitiesAndSubscribeEventTypes(sub.ctx, sub.clientToken, nsDataPath)
if err != nil {
sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs)
return false
@@ -221,64 +280,28 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
bexprFilter := strings.TrimSpace(r.URL.Query().Get("filter"))
namespacePatterns := r.URL.Query()["namespaces"]
namespacePatterns = prependNamespacePatterns(namespacePatterns, ns)
conn, err := websocket.Accept(w, r, nil)
if err != nil {
logger.Info("Could not accept as websocket", "error", err)
respondError(w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket"))
return
}
// we don't expect any incoming messages
ctx = conn.CloseRead(ctx)
// start the pinger
go func() {
for {
time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive
err := conn.Ping(ctx)
if err != nil {
return
}
}
}()
// continually validate subscribe access while the websocket is running
isRoot := entry.IsRoot()
ctx, cancelCtx := context.WithCancel(ctx)
defer cancelCtx()
isRoot := entry.IsRoot()
go validateSubscribeAccessLoop(core, ctx, cancelCtx, req)
sub := &eventSubscriber{
ctx: ctx,
capabilitiesFunc: core.CapabilitiesAndSubscribeEventTypes,
cancelCtx: cancelCtx,
logger: logger,
events: core.Events(),
namespacePatterns: namespacePatterns,
pattern: pattern,
bexprFilter: bexprFilter,
conn: conn,
json: json,
checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime),
clientToken: auth.ClientToken,
isRootToken: isRoot,
core: core,
w: w,
r: r,
req: req,
}
closeStatus, closeReason, err := sub.handleEventsSubscribeWebsocket()
if err != nil {
closeStatus = websocket.CloseStatus(err)
if closeStatus == -1 {
closeStatus = websocket.StatusInternalError
}
closeReason = fmt.Sprintf("Internal error: %v", err)
logger.Debug("Error from websocket handler", "error", err)
}
// Close() will panic if the reason is greater than this length
if len(closeReason) > 123 {
logger.Debug("Truncated close reason", "closeReason", closeReason)
closeReason = closeReason[:123]
}
err = conn.Close(closeStatus, closeReason)
if err != nil {
logger.Debug("Error closing websocket", "error", err)
}
sub.handleEventsSubscribeWebsocket()
})
}
@@ -298,13 +321,13 @@ func prependNamespacePatterns(patterns []string, requestNamespace *namespace.Nam
// validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in
// its namespace. If the access check ever fails, then the cancel function is called and the function returns.
func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel context.CancelFunc, req *logical.Request) {
func (sub *eventSubscriber) validateSubscribeAccessLoop() {
// if something breaks, default to canceling the websocket
defer cancel()
defer sub.cancelCtx()
for {
_, _, err := core.CheckTokenWithLock(ctx, req, false)
_, _, err := sub.core.CheckTokenWithLock(sub.ctx, sub.req, false)
if err != nil {
core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", req.Path, "error", err)
sub.core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", sub.req.Path, "error", err)
return
}
// wait a while and try again, but quit the loop if the context finishes early
@@ -312,7 +335,7 @@ func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel c
ticker := time.NewTicker(webSocketRevalidationTime)
defer ticker.Stop()
select {
case <-ctx.Done():
case <-sub.ctx.Done():
return true
case <-ticker.C:
return false

View File

@@ -12,7 +12,6 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -46,38 +45,30 @@ func TestEventsSubscribe(t *testing.T) {
}
}
stop := atomic.Bool{}
const eventType = "abc"
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// send some events
go func() {
for !stop.Load() {
id, err := uuid.GenerateUUID()
if err != nil {
core.Logger().Info("Error generating UUID, exiting sender", "error", err)
}
pluginInfo := &logical.EventPluginInfo{
MountPath: "secret",
}
err = core.Events().SendEventInternal(namespace.RootContext(ctx), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{
Id: id,
Metadata: nil,
EntityIds: nil,
Note: "testing",
})
if err != nil {
core.Logger().Info("Error sending event, exiting sender", "error", err)
}
time.Sleep(100 * time.Millisecond)
sendEvents := func() error {
id, err := uuid.GenerateUUID()
if err != nil {
return err
}
}()
t.Cleanup(func() {
stop.Store(true)
})
pluginInfo := &logical.EventPluginInfo{
MountPath: "secret",
}
err = core.Events().SendEventInternal(namespace.RootContext(ctx), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{
Id: id,
Metadata: nil,
EntityIds: nil,
Note: "testing",
})
if err != nil {
return err
}
return nil
}
wsAddr := strings.Replace(addr, "http", "ws", 1)
@@ -97,6 +88,10 @@ func TestEventsSubscribe(t *testing.T) {
conn.Close(websocket.StatusNormalClosure, "")
})
err = sendEvents()
if err != nil {
t.Fatal(err)
}
_, msg, err := conn.Read(ctx)
if err != nil {
t.Fatal(err)
@@ -147,28 +142,24 @@ func TestBexprFilters(t *testing.T) {
t.Fatal(err)
}
}
// send duplicates to help avoid flaky tests in CI
sendEvents := func(ctx context.Context, eventTypes ...string) {
for i := 0; i < 10; i++ {
time.Sleep(10 * time.Millisecond)
for _, eventType := range eventTypes {
pluginInfo := &logical.EventPluginInfo{
MountPath: "secret",
}
ns := namespace.RootNamespace
id := eventType
err := core.Events().SendEventInternal(namespace.RootContext(ctx), ns, pluginInfo, logical.EventType(eventType), &logical.EventData{
Id: id,
Metadata: nil,
EntityIds: nil,
Note: "testing",
})
if err != nil {
return
}
sendEvents := func(ctx context.Context, eventTypes ...string) error {
for _, eventType := range eventTypes {
pluginInfo := &logical.EventPluginInfo{
MountPath: "secret",
}
ns := namespace.RootNamespace
id := eventType
err := core.Events().SendEventInternal(namespace.RootContext(ctx), ns, pluginInfo, logical.EventType(eventType), &logical.EventData{
Id: id,
Metadata: nil,
EntityIds: nil,
Note: "testing",
})
if err != nil {
return err
}
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -185,7 +176,10 @@ func TestBexprFilters(t *testing.T) {
}
defer conn.Close(websocket.StatusNormalClosure, "")
go sendEvents(ctx, "abc", "def", "xyz")
err = sendEvents(ctx, "abc", "def", "xyz")
if err != nil {
t.Fatal(err)
}
// read until we time out
seen := map[string]bool{}
done := false