mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
events: Don't accept websocket connection until subscription is active (#23024)
The WebSocket tests have been very flaky because we could not tell when a WebSocket was fully connected and subscribed to events. We reworked the WebSocket subscription code to accept the WebSocket connection only after the subscription is active, which should eliminate the flakiness in these tests. 🤞 (We can follow up in an enterprise PR to simplify some of the tests once this fix is merged.) I ran the tests locally many times, including with data race detection enabled, and did not see any failures. Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
8e7c6e80d5
commit
82e9b610df
143
http/events.go
143
http/events.go
@@ -31,36 +31,92 @@ var webSocketRevalidationTime = 5 * time.Minute
|
|||||||
|
|
||||||
type eventSubscriber struct {
|
type eventSubscriber struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
cancelCtx context.CancelFunc
|
||||||
clientToken string
|
clientToken string
|
||||||
capabilitiesFunc func(ctx context.Context, token, path string) ([]string, []string, error)
|
|
||||||
logger hclog.Logger
|
logger hclog.Logger
|
||||||
events *eventbus.EventBus
|
events *eventbus.EventBus
|
||||||
namespacePatterns []string
|
namespacePatterns []string
|
||||||
pattern string
|
pattern string
|
||||||
bexprFilter string
|
bexprFilter string
|
||||||
conn *websocket.Conn
|
|
||||||
json bool
|
json bool
|
||||||
checkCache *cache.Cache
|
checkCache *cache.Cache
|
||||||
isRootToken bool
|
isRootToken bool
|
||||||
|
core *vault.Core
|
||||||
|
w http.ResponseWriter
|
||||||
|
r *http.Request
|
||||||
|
req *logical.Request
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleEventsSubscribeWebsocket runs forever serving events to the websocket connection, returning a websocket
|
// handleEventsSubscribeWebsocket subscribes to the events, accepts the websocket connection, and then runs forever,
|
||||||
// error code and reason only if the connection closes or there was an error.
|
// serving events to the websocket connection.
|
||||||
func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCode, string, error) {
|
func (sub *eventSubscriber) handleEventsSubscribeWebsocket() {
|
||||||
ctx := sub.ctx
|
ctx := sub.ctx
|
||||||
logger := sub.logger
|
logger := sub.logger
|
||||||
|
// subscribe before accept to avoid race conditions
|
||||||
ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern, sub.bexprFilter)
|
ch, cancel, err := sub.events.SubscribeMultipleNamespaces(ctx, sub.namespacePatterns, sub.pattern, sub.bexprFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info("Error subscribing", "error", err)
|
logger.Info("Error subscribing", "error", err)
|
||||||
return websocket.StatusUnsupportedData, "Error subscribing", nil
|
sub.w.WriteHeader(400)
|
||||||
|
sub.w.Write([]byte("Error subscribing"))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
logger.Debug("WebSocket is subscribed to messages", "namespaces", sub.namespacePatterns, "event_types", sub.pattern, "bexpr_filter", sub.bexprFilter)
|
||||||
|
|
||||||
|
conn, err := websocket.Accept(sub.w, sub.r, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Could not accept as websocket", "error", err)
|
||||||
|
respondError(sub.w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// continually validate subscribe access while the websocket is running
|
||||||
|
// this has to be done after accepting the websocket to avoid a race condition
|
||||||
|
go sub.validateSubscribeAccessLoop()
|
||||||
|
|
||||||
|
// make sure to close the websocket
|
||||||
|
closeStatus := websocket.StatusNormalClosure
|
||||||
|
closeReason := ""
|
||||||
|
var closeErr error = nil
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if closeErr != nil {
|
||||||
|
closeStatus = websocket.CloseStatus(err)
|
||||||
|
if closeStatus == -1 {
|
||||||
|
closeStatus = websocket.StatusInternalError
|
||||||
|
}
|
||||||
|
closeReason = fmt.Sprintf("Internal error: %v", err)
|
||||||
|
logger.Debug("Error from websocket handler", "error", err)
|
||||||
|
}
|
||||||
|
// Close() will panic if the reason is greater than this length
|
||||||
|
if len(closeReason) > 123 {
|
||||||
|
logger.Debug("Truncated close reason", "closeReason", closeReason)
|
||||||
|
closeReason = closeReason[:123]
|
||||||
|
}
|
||||||
|
err = conn.Close(closeStatus, closeReason)
|
||||||
|
if err != nil {
|
||||||
|
logger.Debug("Error closing websocket", "error", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// we don't expect any incoming messages
|
||||||
|
ctx = conn.CloseRead(ctx)
|
||||||
|
// start the pinger
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive
|
||||||
|
err := conn.Ping(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.Info("Websocket context is done, closing the connection")
|
logger.Info("Websocket context is done, closing the connection")
|
||||||
return websocket.StatusNormalClosure, "", nil
|
return
|
||||||
case message := <-ch:
|
case message := <-ch:
|
||||||
// Perform one last check that the message is allowed to be received.
|
// Perform one last check that the message is allowed to be received.
|
||||||
// For example, if a new namespace was created that matches the namespace patterns,
|
// For example, if a new namespace was created that matches the namespace patterns,
|
||||||
@@ -78,7 +134,8 @@ func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCo
|
|||||||
messageBytes, ok = message.Format("cloudevents-json")
|
messageBytes, ok = message.Format("cloudevents-json")
|
||||||
if !ok {
|
if !ok {
|
||||||
logger.Warn("Could not get cloudevents JSON format")
|
logger.Warn("Could not get cloudevents JSON format")
|
||||||
return 0, "", errors.New("could not get cloudevents JSON format")
|
closeErr = errors.New("could not get cloudevents JSON format")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
messageType = websocket.MessageText
|
messageType = websocket.MessageText
|
||||||
} else {
|
} else {
|
||||||
@@ -87,11 +144,13 @@ func (sub *eventSubscriber) handleEventsSubscribeWebsocket() (websocket.StatusCo
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Could not serialize websocket event", "error", err)
|
logger.Warn("Could not serialize websocket event", "error", err)
|
||||||
return 0, "", err
|
closeErr = err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
err = sub.conn.Write(ctx, messageType, messageBytes)
|
err = conn.Write(ctx, messageType, messageBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", err
|
closeErr = err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -154,7 +213,7 @@ func (sub *eventSubscriber) allowMessage(eventNs, dataPath, eventType string) bo
|
|||||||
if eventNs != "" {
|
if eventNs != "" {
|
||||||
nsDataPath = path.Join(eventNs, dataPath)
|
nsDataPath = path.Join(eventNs, dataPath)
|
||||||
}
|
}
|
||||||
capabilities, allowedEventTypes, err := sub.capabilitiesFunc(sub.ctx, sub.clientToken, nsDataPath)
|
capabilities, allowedEventTypes, err := sub.core.CapabilitiesAndSubscribeEventTypes(sub.ctx, sub.clientToken, nsDataPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs)
|
sub.logger.Debug("Error checking capabilities and event types for token", "error", err, "namespace", eventNs)
|
||||||
return false
|
return false
|
||||||
@@ -221,64 +280,28 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
|
|||||||
bexprFilter := strings.TrimSpace(r.URL.Query().Get("filter"))
|
bexprFilter := strings.TrimSpace(r.URL.Query().Get("filter"))
|
||||||
namespacePatterns := r.URL.Query()["namespaces"]
|
namespacePatterns := r.URL.Query()["namespaces"]
|
||||||
namespacePatterns = prependNamespacePatterns(namespacePatterns, ns)
|
namespacePatterns = prependNamespacePatterns(namespacePatterns, ns)
|
||||||
conn, err := websocket.Accept(w, r, nil)
|
isRoot := entry.IsRoot()
|
||||||
if err != nil {
|
|
||||||
logger.Info("Could not accept as websocket", "error", err)
|
|
||||||
respondError(w, http.StatusInternalServerError, fmt.Errorf("could not accept as websocket"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// we don't expect any incoming messages
|
|
||||||
ctx = conn.CloseRead(ctx)
|
|
||||||
// start the pinger
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
time.Sleep(30 * time.Second) // not too aggressive, but keep the HTTP connection alive
|
|
||||||
err := conn.Ping(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// continually validate subscribe access while the websocket is running
|
|
||||||
ctx, cancelCtx := context.WithCancel(ctx)
|
ctx, cancelCtx := context.WithCancel(ctx)
|
||||||
defer cancelCtx()
|
defer cancelCtx()
|
||||||
|
|
||||||
isRoot := entry.IsRoot()
|
|
||||||
go validateSubscribeAccessLoop(core, ctx, cancelCtx, req)
|
|
||||||
sub := &eventSubscriber{
|
sub := &eventSubscriber{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
capabilitiesFunc: core.CapabilitiesAndSubscribeEventTypes,
|
cancelCtx: cancelCtx,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
events: core.Events(),
|
events: core.Events(),
|
||||||
namespacePatterns: namespacePatterns,
|
namespacePatterns: namespacePatterns,
|
||||||
pattern: pattern,
|
pattern: pattern,
|
||||||
bexprFilter: bexprFilter,
|
bexprFilter: bexprFilter,
|
||||||
conn: conn,
|
|
||||||
json: json,
|
json: json,
|
||||||
checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime),
|
checkCache: cache.New(webSocketRevalidationTime, webSocketRevalidationTime),
|
||||||
clientToken: auth.ClientToken,
|
clientToken: auth.ClientToken,
|
||||||
isRootToken: isRoot,
|
isRootToken: isRoot,
|
||||||
|
core: core,
|
||||||
|
w: w,
|
||||||
|
r: r,
|
||||||
|
req: req,
|
||||||
}
|
}
|
||||||
closeStatus, closeReason, err := sub.handleEventsSubscribeWebsocket()
|
sub.handleEventsSubscribeWebsocket()
|
||||||
if err != nil {
|
|
||||||
closeStatus = websocket.CloseStatus(err)
|
|
||||||
if closeStatus == -1 {
|
|
||||||
closeStatus = websocket.StatusInternalError
|
|
||||||
}
|
|
||||||
closeReason = fmt.Sprintf("Internal error: %v", err)
|
|
||||||
logger.Debug("Error from websocket handler", "error", err)
|
|
||||||
}
|
|
||||||
// Close() will panic if the reason is greater than this length
|
|
||||||
if len(closeReason) > 123 {
|
|
||||||
logger.Debug("Truncated close reason", "closeReason", closeReason)
|
|
||||||
closeReason = closeReason[:123]
|
|
||||||
}
|
|
||||||
err = conn.Close(closeStatus, closeReason)
|
|
||||||
if err != nil {
|
|
||||||
logger.Debug("Error closing websocket", "error", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,13 +321,13 @@ func prependNamespacePatterns(patterns []string, requestNamespace *namespace.Nam
|
|||||||
|
|
||||||
// validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in
|
// validateSubscribeAccessLoop continually checks if the request has access to the subscribe endpoint in
|
||||||
// its namespace. If the access check ever fails, then the cancel function is called and the function returns.
|
// its namespace. If the access check ever fails, then the cancel function is called and the function returns.
|
||||||
func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel context.CancelFunc, req *logical.Request) {
|
func (sub *eventSubscriber) validateSubscribeAccessLoop() {
|
||||||
// if something breaks, default to canceling the websocket
|
// if something breaks, default to canceling the websocket
|
||||||
defer cancel()
|
defer sub.cancelCtx()
|
||||||
for {
|
for {
|
||||||
_, _, err := core.CheckTokenWithLock(ctx, req, false)
|
_, _, err := sub.core.CheckTokenWithLock(sub.ctx, sub.req, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", req.Path, "error", err)
|
sub.core.Logger().Debug("Token does not have access to subscription path in its own namespace, terminating WebSocket subscription", "path", sub.req.Path, "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// wait a while and try again, but quit the loop if the context finishes early
|
// wait a while and try again, but quit the loop if the context finishes early
|
||||||
@@ -312,7 +335,7 @@ func validateSubscribeAccessLoop(core *vault.Core, ctx context.Context, cancel c
|
|||||||
ticker := time.NewTicker(webSocketRevalidationTime)
|
ticker := time.NewTicker(webSocketRevalidationTime)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-sub.ctx.Done():
|
||||||
return true
|
return true
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -46,18 +45,15 @@ func TestEventsSubscribe(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stop := atomic.Bool{}
|
|
||||||
|
|
||||||
const eventType = "abc"
|
const eventType = "abc"
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// send some events
|
// send some events
|
||||||
go func() {
|
sendEvents := func() error {
|
||||||
for !stop.Load() {
|
|
||||||
id, err := uuid.GenerateUUID()
|
id, err := uuid.GenerateUUID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.Logger().Info("Error generating UUID, exiting sender", "error", err)
|
return err
|
||||||
}
|
}
|
||||||
pluginInfo := &logical.EventPluginInfo{
|
pluginInfo := &logical.EventPluginInfo{
|
||||||
MountPath: "secret",
|
MountPath: "secret",
|
||||||
@@ -69,15 +65,10 @@ func TestEventsSubscribe(t *testing.T) {
|
|||||||
Note: "testing",
|
Note: "testing",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.Logger().Info("Error sending event, exiting sender", "error", err)
|
return err
|
||||||
}
|
}
|
||||||
time.Sleep(100 * time.Millisecond)
|
return nil
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
stop.Store(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
wsAddr := strings.Replace(addr, "http", "ws", 1)
|
wsAddr := strings.Replace(addr, "http", "ws", 1)
|
||||||
|
|
||||||
@@ -97,6 +88,10 @@ func TestEventsSubscribe(t *testing.T) {
|
|||||||
conn.Close(websocket.StatusNormalClosure, "")
|
conn.Close(websocket.StatusNormalClosure, "")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
err = sendEvents()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
_, msg, err := conn.Read(ctx)
|
_, msg, err := conn.Read(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -147,11 +142,7 @@ func TestBexprFilters(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
sendEvents := func(ctx context.Context, eventTypes ...string) error {
|
||||||
// send duplicates to help avoid flaky tests in CI
|
|
||||||
sendEvents := func(ctx context.Context, eventTypes ...string) {
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
for _, eventType := range eventTypes {
|
for _, eventType := range eventTypes {
|
||||||
pluginInfo := &logical.EventPluginInfo{
|
pluginInfo := &logical.EventPluginInfo{
|
||||||
MountPath: "secret",
|
MountPath: "secret",
|
||||||
@@ -165,10 +156,10 @@ func TestBexprFilters(t *testing.T) {
|
|||||||
Note: "testing",
|
Note: "testing",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -185,7 +176,10 @@ func TestBexprFilters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||||
|
|
||||||
go sendEvents(ctx, "abc", "def", "xyz")
|
err = sendEvents(ctx, "abc", "def", "xyz")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
// read until we time out
|
// read until we time out
|
||||||
seen := map[string]bool{}
|
seen := map[string]bool{}
|
||||||
done := false
|
done := false
|
||||||
|
|||||||
Reference in New Issue
Block a user