mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	Support event subscriptions with glob wildcards (#19205)
This commit is contained in:
		| @@ -24,7 +24,7 @@ type eventSubscribeArgs struct { | ||||
| 	logger  hclog.Logger | ||||
| 	events  *eventbus.EventBus | ||||
| 	ns      *namespace.Namespace | ||||
| 	eventType logical.EventType | ||||
| 	pattern string | ||||
| 	conn    *websocket.Conn | ||||
| 	json    bool | ||||
| } | ||||
| @@ -34,7 +34,7 @@ type eventSubscribeArgs struct { | ||||
| func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) { | ||||
| 	ctx := args.ctx | ||||
| 	logger := args.logger | ||||
| 	ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.eventType) | ||||
| 	ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.pattern) | ||||
| 	if err != nil { | ||||
| 		logger.Info("Error subscribing", "error", err) | ||||
| 		return websocket.StatusUnsupportedData, "Error subscribing", nil | ||||
| @@ -97,12 +97,11 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler | ||||
| 		if ns.ID != namespace.RootNamespaceID { | ||||
| 			prefix = fmt.Sprintf("/v1/%ssys/events/subscribe/", ns.Path) | ||||
| 		} | ||||
| 		eventTypeStr := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix)) | ||||
| 		if eventTypeStr == "" { | ||||
| 		pattern := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix)) | ||||
| 		if pattern == "" { | ||||
| 			respondError(w, http.StatusBadRequest, fmt.Errorf("did not specify eventType to subscribe to")) | ||||
| 			return | ||||
| 		} | ||||
| 		eventType := logical.EventType(eventTypeStr) | ||||
|  | ||||
| 		json := false | ||||
| 		jsonRaw := r.URL.Query().Get("json") | ||||
| @@ -135,7 +134,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler | ||||
| 			} | ||||
| 		}() | ||||
|  | ||||
| 		closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, eventType, conn, json}) | ||||
| 		closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, pattern, conn, json}) | ||||
| 		if err != nil { | ||||
| 			closeStatus = websocket.CloseStatus(err) | ||||
| 			if closeStatus == -1 { | ||||
|   | ||||
| @@ -16,10 +16,17 @@ import ( | ||||
| 	"github.com/hashicorp/go-uuid" | ||||
| 	"github.com/hashicorp/vault/helper/namespace" | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| 	"github.com/ryanuber/go-glob" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| ) | ||||
|  | ||||
| const defaultTimeout = 60 * time.Second | ||||
| const ( | ||||
| 	// eventTypeAll is purely internal to the event bus. We use it to send all | ||||
| 	// events down one big firehose, and pipelines define their own filtering | ||||
| 	// based on what each subscriber is interested in. | ||||
| 	eventTypeAll   = "*" | ||||
| 	defaultTimeout = 60 * time.Second | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	ErrNotStarted              = errors.New("event broker has not been started") | ||||
| @@ -47,14 +54,12 @@ type asyncChanNode struct { | ||||
| 	// TODO: add bounded deque buffer of *EventReceived | ||||
| 	ctx    context.Context | ||||
| 	ch     chan *logical.EventReceived | ||||
| 	namespace *namespace.Namespace | ||||
| 	logger hclog.Logger | ||||
|  | ||||
| 	// used to close the connection | ||||
| 	closeOnce  sync.Once | ||||
| 	cancelFunc context.CancelFunc | ||||
| 	pipelineID eventlogger.PipelineID | ||||
| 	eventType  eventlogger.EventType | ||||
| 	broker     *eventlogger.Broker | ||||
| } | ||||
|  | ||||
| @@ -97,7 +102,7 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace, | ||||
| 	// We can't easily know when the Send is complete, so we can't call the cancel function. | ||||
| 	// But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long. | ||||
| 	ctx, _ = context.WithTimeout(ctx, bus.timeout) | ||||
| 	_, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived) | ||||
| 	_, err := bus.broker.Send(ctx, eventTypeAll, eventReceived) | ||||
| 	if err != nil { | ||||
| 		// if no listeners for this event type are registered, that's okay, the event | ||||
| 		// will just not be sent anywhere | ||||
| @@ -164,32 +169,42 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) { | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eventType logical.EventType) (<-chan *logical.EventReceived, context.CancelFunc, error) { | ||||
| func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, pattern string) (<-chan *logical.EventReceived, context.CancelFunc, error) { | ||||
| 	// subscriptions are still stored even if the bus has not been started | ||||
| 	pipelineID, err := uuid.GenerateUUID() | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	nodeID, err := uuid.GenerateUUID() | ||||
| 	filterNodeID, err := uuid.GenerateUUID() | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	filterNode := newFilterNode(ns, pattern) | ||||
| 	err = bus.broker.RegisterNode(eventlogger.NodeID(filterNodeID), filterNode) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	sinkNodeID, err := uuid.GenerateUUID() | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	// TODO: should we have just one node per namespace, and handle all the routing ourselves? | ||||
| 	ctx, cancel := context.WithCancel(ctx) | ||||
| 	asyncNode := newAsyncNode(ctx, ns, bus.logger) | ||||
| 	err = bus.broker.RegisterNode(eventlogger.NodeID(nodeID), asyncNode) | ||||
| 	err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode) | ||||
| 	if err != nil { | ||||
| 		defer cancel() | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	nodes := []eventlogger.NodeID{bus.formatterNodeID, eventlogger.NodeID(nodeID)} | ||||
| 	nodes := []eventlogger.NodeID{eventlogger.NodeID(filterNodeID), bus.formatterNodeID, eventlogger.NodeID(sinkNodeID)} | ||||
|  | ||||
| 	pipeline := eventlogger.Pipeline{ | ||||
| 		PipelineID: eventlogger.PipelineID(pipelineID), | ||||
| 		EventType:  eventlogger.EventType(eventType), | ||||
| 		EventType:  eventTypeAll, | ||||
| 		NodeIDs:    nodes, | ||||
| 	} | ||||
| 	err = bus.broker.RegisterPipeline(pipeline) | ||||
| @@ -197,10 +212,10 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eve | ||||
| 		defer cancel() | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
|  | ||||
| 	addSubscriptions(1) | ||||
| 	// add info needed to cancel the subscription | ||||
| 	asyncNode.pipelineID = eventlogger.PipelineID(pipelineID) | ||||
| 	asyncNode.eventType = eventlogger.EventType(eventType) | ||||
| 	asyncNode.cancelFunc = cancel | ||||
| 	return asyncNode.ch, asyncNode.Close, nil | ||||
| } | ||||
| @@ -211,11 +226,31 @@ func (bus *EventBus) SetSendTimeout(timeout time.Duration) { | ||||
| 	bus.timeout = timeout | ||||
| } | ||||
|  | ||||
| func newFilterNode(ns *namespace.Namespace, pattern string) *eventlogger.Filter { | ||||
| 	return &eventlogger.Filter{ | ||||
| 		Predicate: func(e *eventlogger.Event) (bool, error) { | ||||
| 			eventRecv := e.Payload.(*logical.EventReceived) | ||||
|  | ||||
| 			// Drop if event is not in our namespace. | ||||
| 			// TODO: add wildcard/child namespace processing here in some cases? | ||||
| 			if eventRecv.Namespace != ns.Path { | ||||
| 				return false, nil | ||||
| 			} | ||||
|  | ||||
| 			// Filter for correct event type, including wildcards. | ||||
| 			if !glob.Glob(pattern, eventRecv.EventType) { | ||||
| 				return false, nil | ||||
| 			} | ||||
|  | ||||
| 			return true, nil | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode { | ||||
| 	return &asyncChanNode{ | ||||
| 		ctx:    ctx, | ||||
| 		ch:     make(chan *logical.EventReceived), | ||||
| 		namespace: namespace, | ||||
| 		logger: logger, | ||||
| 	} | ||||
| } | ||||
| @@ -225,7 +260,7 @@ func (node *asyncChanNode) Close() { | ||||
| 	node.closeOnce.Do(func() { | ||||
| 		defer node.cancelFunc() | ||||
| 		if node.broker != nil { | ||||
| 			err := node.broker.RemovePipeline(node.eventType, node.pipelineID) | ||||
| 			err := node.broker.RemovePipeline(eventTypeAll, node.pipelineID) | ||||
| 			if err != nil { | ||||
| 				node.logger.Warn("Error removing pipeline for closing node", "error", err) | ||||
| 			} | ||||
| @@ -238,11 +273,6 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (* | ||||
| 	// sends to the channel async in another goroutine | ||||
| 	go func() { | ||||
| 		eventRecv := e.Payload.(*logical.EventReceived) | ||||
| 		// drop if event is not in our namespace | ||||
| 		// TODO: add wildcard processing here in some cases? | ||||
| 		if eventRecv.Namespace != node.namespace.Path { | ||||
| 			return | ||||
| 		} | ||||
| 		var timeout bool | ||||
| 		select { | ||||
| 		case node.ch <- eventRecv: | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/go-secure-stdlib/strutil" | ||||
| 	"github.com/hashicorp/vault/helper/namespace" | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
| @@ -38,7 +39,7 @@ func TestBusBasics(t *testing.T) { | ||||
| 		t.Errorf("Expected no error sending: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) | ||||
| 	ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -81,7 +82,7 @@ func TestNamespaceFiltering(t *testing.T) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) | ||||
| 	ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -137,13 +138,13 @@ func TestBus2Subscriptions(t *testing.T) { | ||||
| 	eventType2 := logical.EventType("someType2") | ||||
| 	bus.Start() | ||||
|  | ||||
| 	ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType1) | ||||
| 	ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType1)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer cancel1() | ||||
|  | ||||
| 	ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType2) | ||||
| 	ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType2)) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -222,7 +223,7 @@ func TestBusSubscriptionsCancel(t *testing.T) { | ||||
| 			received := atomic.Int32{} | ||||
|  | ||||
| 			for i := 0; i < create; i++ { | ||||
| 				ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) | ||||
| 				ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType)) | ||||
| 				if err != nil { | ||||
| 					t.Fatal(err) | ||||
| 				} | ||||
| @@ -297,3 +298,78 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) { | ||||
| 	} | ||||
| 	t.Error("Timeout waiting for condition") | ||||
| } | ||||
|  | ||||
| // TestBusWildcardSubscriptions tests that a single subscription can receive | ||||
| // multiple event types using * for glob patterns. | ||||
| func TestBusWildcardSubscriptions(t *testing.T) { | ||||
| 	bus, err := NewEventBus(nil) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	ctx := context.Background() | ||||
|  | ||||
| 	fooEventType := logical.EventType("kv/foo") | ||||
| 	barEventType := logical.EventType("kv/bar") | ||||
| 	bus.Start() | ||||
|  | ||||
| 	ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, "kv/*") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer cancel1() | ||||
|  | ||||
| 	ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, "*/bar") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer cancel2() | ||||
|  | ||||
| 	event1, err := logical.NewEvent() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	event2, err := logical.NewEvent() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	err = bus.SendInternal(ctx, namespace.RootNamespace, nil, barEventType, event2) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 	err = bus.SendInternal(ctx, namespace.RootNamespace, nil, fooEventType, event1) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
|  | ||||
| 	timeout := time.After(1 * time.Second) | ||||
| 	// Expect to receive both events on ch1, which subscribed to kv/* | ||||
| 	var ch1Seen []string | ||||
| 	for i := 0; i < 2; i++ { | ||||
| 		select { | ||||
| 		case message := <-ch1: | ||||
| 			ch1Seen = append(ch1Seen, message.Event.ID()) | ||||
| 		case <-timeout: | ||||
| 			t.Error("Timeout waiting for event1") | ||||
| 		} | ||||
| 	} | ||||
| 	if len(ch1Seen) != 2 { | ||||
| 		t.Errorf("Expected 2 events but got: %v", ch1Seen) | ||||
| 	} else { | ||||
| 		if !strutil.StrListContains(ch1Seen, event1.ID()) { | ||||
| 			t.Errorf("Did not find %s event1 ID in ch1seen", event1.ID()) | ||||
| 		} | ||||
| 		if !strutil.StrListContains(ch1Seen, event2.ID()) { | ||||
| 			t.Errorf("Did not find %s event2 ID in ch1seen", event2.ID()) | ||||
| 		} | ||||
| 	} | ||||
| 	// Expect to receive just kv/bar on ch2, which subscribed to */bar | ||||
| 	select { | ||||
| 	case message := <-ch2: | ||||
| 		if message.Event.ID() != event2.ID() { | ||||
| 			t.Errorf("Got unexpected message: %v", message) | ||||
| 		} | ||||
| 	case <-timeout: | ||||
| 		t.Error("Timeout waiting for event2") | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -19,7 +19,7 @@ func TestCanSendEventsFromBuiltinPlugin(t *testing.T) { | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, logical.EventType(eventType)) | ||||
| 	ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, eventType) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Tom Proctor
					Tom Proctor