diff --git a/changelog/23042.txt b/changelog/23042.txt new file mode 100644 index 0000000000..da73a30753 --- /dev/null +++ b/changelog/23042.txt @@ -0,0 +1,4 @@ +```release-note:bug +events: Ensure subscription resources are cleaned up on close. +``` + diff --git a/vault/eventbus/bus.go b/vault/eventbus/bus.go index 02a512c50d..e04e7ed0b2 100644 --- a/vault/eventbus/bus.go +++ b/vault/eventbus/bus.go @@ -69,10 +69,10 @@ type asyncChanNode struct { logger hclog.Logger // used to close the connection - closeOnce sync.Once - cancelFunc context.CancelFunc - pipelineID eventlogger.PipelineID - broker *eventlogger.Broker + closeOnce sync.Once + cancelFunc context.CancelFunc + pipelineID eventlogger.PipelineID + removePipeline func(ctx context.Context, t eventlogger.EventType, id eventlogger.PipelineID) (bool, error) } var ( @@ -185,10 +185,6 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) { return nil, err } formatterNodeID := eventlogger.NodeID(formatterID) - err = broker.RegisterNode(formatterNodeID, cloudEventsFormatterFilter) - if err != nil { - return nil, err - } if logger == nil { logger = hclog.Default().Named("events") @@ -217,6 +213,11 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP return nil, nil, err } + err = bus.broker.RegisterNode(bus.formatterNodeID, cloudEventsFormatterFilter) + if err != nil { + return nil, nil, err + } + filterNodeID, err := uuid.GenerateUUID() if err != nil { return nil, nil, err @@ -237,7 +238,7 @@ func (bus *EventBus) SubscribeMultipleNamespaces(ctx context.Context, namespaceP } ctx, cancel := context.WithCancel(ctx) - asyncNode := newAsyncNode(ctx, bus.logger) + asyncNode := newAsyncNode(ctx, bus.logger, bus.broker) err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode) if err != nil { defer cancel() @@ -312,11 +313,12 @@ func newFilterNode(namespacePatterns []string, pattern string, bexprFilter strin }, nil } -func newAsyncNode(ctx context.Context, logger hclog.Logger) *asyncChanNode { +func newAsyncNode(ctx context.Context, logger hclog.Logger, broker *eventlogger.Broker) *asyncChanNode { return &asyncChanNode{ - ctx: ctx, - ch: make(chan *eventlogger.Event), - logger: logger, + ctx: ctx, + ch: make(chan *eventlogger.Event), + logger: logger, + removePipeline: broker.RemovePipelineAndNodes, } } @@ -324,17 +326,15 @@ func newAsyncNode(ctx context.Context, logger hclog.Logger) *asyncChanNode { func (node *asyncChanNode) Close(ctx context.Context) { node.closeOnce.Do(func() { defer node.cancelFunc() - if node.broker != nil { - isPipelineRemoved, err := node.broker.RemovePipelineAndNodes(ctx, eventTypeAll, node.pipelineID) + removed, err := node.removePipeline(ctx, eventTypeAll, node.pipelineID) - switch { - case err != nil && isPipelineRemoved: - msg := fmt.Sprintf("Error removing nodes referenced by pipeline %q", node.pipelineID) - node.logger.Warn(msg, err) - case err != nil: - msg := fmt.Sprintf("Error removing pipeline %q", node.pipelineID) - node.logger.Warn(msg, err) - } + switch { + case err != nil && removed: + msg := fmt.Sprintf("Error removing nodes referenced by pipeline %q", node.pipelineID) + node.logger.Warn(msg, err) + case err != nil: + msg := fmt.Sprintf("Error removing pipeline %q", node.pipelineID) + node.logger.Warn(msg, err) } addSubscriptions(-1) }) diff --git a/vault/eventbus/bus_test.go b/vault/eventbus/bus_test.go index e480fecc78..1dbe9170e5 100644 --- a/vault/eventbus/bus_test.go +++ b/vault/eventbus/bus_test.go @@ -626,3 +626,30 @@ func TestBexpr(t *testing.T) { }) } } + +// TestPipelineCleanedUp ensures pipelines are properly cleaned up after +// subscriptions are closed. +func TestPipelineCleanedUp(t *testing.T) { + bus, err := NewEventBus(nil) + if err != nil { + t.Fatal(err) + } + + eventType := logical.EventType("someType") + bus.Start() + + _, cancel, err := bus.Subscribe(context.Background(), namespace.RootNamespace, string(eventType), "") + if err != nil { + t.Fatal(err) + } + if !bus.broker.IsAnyPipelineRegistered(eventTypeAll) { + cancel() + t.Fatal() + } + + cancel() + + if bus.broker.IsAnyPipelineRegistered(eventTypeAll) { + t.Fatal() + } +}