diff --git a/http/events.go b/http/events.go index 384a6719cd..fe64593ce8 100644 --- a/http/events.go +++ b/http/events.go @@ -31,36 +31,92 @@ var webSocketRevalidationTime = 5 * time.Minute type eventSubscriber struct { ctx context.Context + cancelCtx context.CancelFunc clientToken string - capabilitiesFunc func(ctx context.Context, token, path string) ([]string, []string, error) logger hclog.Logger events *eventbus.EventBus namespacePatterns []string pattern string bexprFilter string - conn *websocket.Conn json bool checkCache *cache.Cache isRootToken bool + core *vault.Core + w http.ResponseWriter + r *http.Request + req *logical.Request } -// handleEventsSubscribeWebsocket runs forever serving events to the websocket connection, returning a websocket -// error code and reason only if the connection closes or there was an error. -func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCode, string, error) { +// handleEventsSubscribeWebsocket subscribes to the events, accepts the websocket connection, and then runs forever, +// serving events to the websocket connection. +func (sub *eventSubscriber) handleEventsSubscribeWebsocket() { ctx := sub.ctx logger := sub.logger + // subscribe before accept to avoid race conditions ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern, sub.bexprFilter) if err != nil { logger.Info("Error subscribing", "error", err) - return websocket.StatusUnsupportedData, "Error subscribing", nil + sub.w.WriteHeader(400) + sub.w.Write([]byte("Error subscribing")) + return } defer cancel() + logger.Debug("WebSocket is subscribed to messages", "namespaces", sub.namespacePatterns, "event_types", sub.pattern, "bexpr_filter", sub.bexprFilter) + + conn, err := websocket.Accept(sub.w, sub.r, nil) + if err != nil { + logger.Info("Could not accept as websocket", "error", err) + respondError(sub.w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket")) + return + } + + // continually validate subscribe access while the websocket is running + // this has to be done after accepting the websocket to avoid a race condition + go sub.validateSubscribeAccessLoop() + + // make sure to close the websocket + closeStatus := websocket.StatusNormalClosure + closeReason := "" + var closeErr error = nil + + defer func() { + if closeErr != nil { + closeStatus = websocket.CloseStatus(err) + if closeStatus == -1 { + closeStatus = websocket.StatusInternalError + } + closeReason = fmt.Sprintf("Internal error: %v", err) + logger.Debug("Error from websocket handler", "error", err) + } + // Close() will panic if the reason is greater than this length + if len(closeReason) > 123 { + logger.Debug("Truncated close reason", "closeReason", closeReason) + closeReason = closeReason[:123] + } + err = conn.Close(closeStatus, closeReason) + if err != nil { + logger.Debug("Error closing websocket", "error", err) + } + }() + + // we don't expect any incoming messages + ctx = conn.CloseRead(ctx) + // start the pinger + go func() { + for { + time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive + err := conn.Ping(ctx) + if err != nil { + return + } + } + }() for { select { case <-ctx.Done(): logger.Info("Websocket context is done, closing the connection") - return websocket.StatusNormalClosure, "", nil + return case message := <-ch: // Perform one last check that the message is allowed to be received. // For example, if a new namespace was created that matches the namespace patterns, @@ -78,7 +134,8 @@ func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCo messageBytes, ok = message.Format("cloudevents-json") if !ok { logger.Warn("Could not get cloudevents JSON format") - return 0, "", errors.New("could not get cloudevents JSON format") + closeErr = errors.New("could not get cloudevents JSON format") + return } messageType = websocket.MessageText } else { @@ -87,11 +144,13 @@ func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCo } if err != nil { logger.Warn("Could not serialize websocket event", "error", err) - return 0, "", err + closeErr = err + return } - err = sub.conn.Write(ctx, messageType, messageBytes) + err = conn.Write(ctx, messageType, messageBytes) if err != nil { - return 0, "", err + closeErr = err + return } } } @@ -154,7 +213,7 @@ func (sub *eventSubscriber) allowMessage(eventNs, dataPath, eventType string) bo if eventNs != "" { nsDataPath = path.Join(eventNs, dataPath) } - capabilities, allowedEventTypes, err := sub.capabilitiesFunc(sub.ctx, sub.clientToken, nsDataPath) + capabilities, allowedEventTypes, err := sub.core.CapabilitiesAndSubscribeEventTypes(sub.ctx, sub.clientToken, nsDataPath) if err != nil { sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs) return false @@ -221,64 +280,28 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler bexprFilter := strings.TrimSpace(r.URL.Query().Get("filter")) namespacePatterns := r.URL.Query()["namespaces"] namespacePatterns = prependNamespacePatterns(namespacePatterns, ns) - conn, err := websocket.Accept(w, r, nil) - if err != nil { - logger.Info("Could not accept as websocket", "error", err) - respondError(w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket")) - return - } - - // we don't expect any incoming messages - ctx = conn.CloseRead(ctx) - // start the pinger - go func() { - for { - time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive - err := conn.Ping(ctx) - if err != nil { - return - } - } - }() - - // continually validate subscribe access while the websocket is running + isRoot := entry.IsRoot() ctx, cancelCtx := context.WithCancel(ctx) defer cancelCtx() - isRoot := entry.IsRoot() - go validateSubscribeAccessLoop(core, ctx, cancelCtx, req) sub := &eventSubscriber{ ctx: ctx, - capabilitiesFunc: core.CapabilitiesAndSubscribeEventTypes, + cancelCtx: cancelCtx, logger: logger, events: core.Events(), namespacePatterns: namespacePatterns, pattern: pattern, bexprFilter: bexprFilter, - conn: conn, json: json, checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime), clientToken: auth.ClientToken, isRootToken: isRoot, + core: core, + w: w, + r: r, + req: req, } - closeStatus, closeReason, err := sub.handleEventsSubscribeWebsocket() - if err != nil { - closeStatus = websocket.CloseStatus(err) - if closeStatus == -1 { - closeStatus = websocket.StatusInternalError - } - closeReason = fmt.Sprintf("Internal error: %v", err) - logger.Debug("Error from websocket handler", "error", err) - } - // Close() will panic if the reason is greater than this length - if len(closeReason) > 123 { - logger.Debug("Truncated close reason", "closeReason", closeReason) - closeReason = closeReason[:123] - } - err = conn.Close(closeStatus, closeReason) - if err != nil { - logger.Debug("Error closing websocket", "error", err) - } + sub.handleEventsSubscribeWebsocket() }) } @@ -298,13 +321,13 @@ func prependNamespacePatterns(patterns []string, requestNamespace *namespace.Nam // validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in // its namespace. If the access check ever fails, then the cancel function is called and the function returns. -func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel context.CancelFunc, req *logical.Request) { +func (sub *eventSubscriber) validateSubscribeAccessLoop() { // if something breaks, default to canceling the websocket - defer cancel() + defer sub.cancelCtx() for { - _, _, err := core.CheckTokenWithLock(ctx, req, false) + _, _, err := sub.core.CheckTokenWithLock(sub.ctx, sub.req, false) if err != nil { - core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", req.Path, "error", err) + sub.core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", sub.req.Path, "error", err) return } // wait a while and try again, but quit the loop if the context finishes early @@ -312,7 +335,7 @@ func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel c ticker := time.NewTicker(webSocketRevalidationTime) defer ticker.Stop() select { - case <-ctx.Done(): + case <-sub.ctx.Done(): return true case <-ticker.C: return false diff --git a/http/events_test.go b/http/events_test.go index 4eea2134eb..1ae014364b 100644 --- a/http/events_test.go +++ b/http/events_test.go @@ -12,7 +12,6 @@ import ( "net/url" "strings" "sync" - "sync/atomic" "testing" "time" @@ -46,38 +45,30 @@ func TestEventsSubscribe(t *testing.T) { } } - stop := atomic.Bool{} - const eventType = "abc" ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // send some events - go func() { - for !stop.Load() { - id, err := uuid.GenerateUUID() - if err != nil { - core.Logger().Info("Error generating UUID, exiting sender", "error", err) - } - pluginInfo := &logical.EventPluginInfo{ - MountPath: "secret", - } - err = core.Events().SendEventInternal(namespace.RootContext(ctx), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{ - Id: id, - Metadata: nil, - EntityIds: nil, - Note: "testing", - }) - if err != nil { - core.Logger().Info("Error sending event, exiting sender", "error", err) - } - time.Sleep(100 * time.Millisecond) + sendEvents := func() error { + id, err := uuid.GenerateUUID() + if err != nil { + return err } - }() - - t.Cleanup(func() { - stop.Store(true) - }) + pluginInfo := &logical.EventPluginInfo{ + MountPath: "secret", + } + err = core.Events().SendEventInternal(namespace.RootContext(ctx), namespace.RootNamespace, pluginInfo, logical.EventType(eventType), &logical.EventData{ + Id: id, + Metadata: nil, + EntityIds: nil, + Note: "testing", + }) + if err != nil { + return err + } + return nil + } wsAddr := strings.Replace(addr, "http", "ws", 1) @@ -97,6 +88,10 @@ func TestEventsSubscribe(t *testing.T) { conn.Close(websocket.StatusNormalClosure, "") }) + err = sendEvents() + if err != nil { + t.Fatal(err) + } _, msg, err := conn.Read(ctx) if err != nil { t.Fatal(err) @@ -147,28 +142,24 @@ func TestBexprFilters(t *testing.T) { t.Fatal(err) } } - - // send duplicates to help avoid flaky tests in CI - sendEvents := func(ctx context.Context, eventTypes ...string) { - for i := 0; i < 10; i++ { - time.Sleep(10 * time.Millisecond) - for _, eventType := range eventTypes { - pluginInfo := &logical.EventPluginInfo{ - MountPath: "secret", - } - ns := namespace.RootNamespace - id := eventType - err := core.Events().SendEventInternal(namespace.RootContext(ctx), ns, pluginInfo, logical.EventType(eventType), &logical.EventData{ - Id: id, - Metadata: nil, - EntityIds: nil, - Note: "testing", - }) - if err != nil { - return - } + sendEvents := func(ctx context.Context, eventTypes ...string) error { + for _, eventType := range eventTypes { + pluginInfo := &logical.EventPluginInfo{ + MountPath: "secret", + } + ns := namespace.RootNamespace + id := eventType + err := core.Events().SendEventInternal(namespace.RootContext(ctx), ns, pluginInfo, logical.EventType(eventType), &logical.EventData{ + Id: id, + Metadata: nil, + EntityIds: nil, + Note: "testing", + }) + if err != nil { + return err } } + return nil } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -185,7 +176,10 @@ func TestBexprFilters(t *testing.T) { } defer conn.Close(websocket.StatusNormalClosure, "") - go sendEvents(ctx, "abc", "def", "xyz") + err = sendEvents(ctx, "abc", "def", "xyz") + if err != nil { + t.Fatal(err) + } // read until we time out seen := map[string]bool{} done := false