Support event subscriptions with glob wildcards (#19205)

This commit is contained in:
Tom Proctor
2023-02-16 17:22:56 +00:00
committed by GitHub
parent add3659f39
commit 184939e90a
4 changed files with 147 additions and 42 deletions

View File

@@ -20,13 +20,13 @@ import (
) )
type eventSubscribeArgs struct { type eventSubscribeArgs struct {
ctx context.Context ctx context.Context
logger hclog.Logger logger hclog.Logger
events *eventbus.EventBus events *eventbus.EventBus
ns *namespace.Namespace ns *namespace.Namespace
eventType logical.EventType pattern string
conn *websocket.Conn conn *websocket.Conn
json bool json bool
} }
// handleEventsSubscribeWebsocket runs forever, returning a websocket error code and reason // handleEventsSubscribeWebsocket runs forever, returning a websocket error code and reason
@@ -34,7 +34,7 @@ type eventSubscribeArgs struct {
func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) { func handleEventsSubscribeWebsocket(args eventSubscribeArgs) (websocket.StatusCode, string, error) {
ctx := args.ctx ctx := args.ctx
logger := args.logger logger := args.logger
ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.eventType) ch, cancel, err := args.events.Subscribe(ctx, args.ns, args.pattern)
if err != nil { if err != nil {
logger.Info("Error subscribing", "error", err) logger.Info("Error subscribing", "error", err)
return websocket.StatusUnsupportedData, "Error subscribing", nil return websocket.StatusUnsupportedData, "Error subscribing", nil
@@ -97,12 +97,11 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
if ns.ID != namespace.RootNamespaceID { if ns.ID != namespace.RootNamespaceID {
prefix = fmt.Sprintf("/v1/%ssys/events/subscribe/", ns.Path) prefix = fmt.Sprintf("/v1/%ssys/events/subscribe/", ns.Path)
} }
eventTypeStr := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix)) pattern := strings.TrimSpace(strings.TrimPrefix(r.URL.Path, prefix))
if eventTypeStr == "" { if pattern == "" {
respondError(w, http.StatusBadRequest, fmt.Errorf("did not specify eventType to subscribe to")) respondError(w, http.StatusBadRequest, fmt.Errorf("did not specify eventType to subscribe to"))
return return
} }
eventType := logical.EventType(eventTypeStr)
json := false json := false
jsonRaw := r.URL.Query().Get("json") jsonRaw := r.URL.Query().Get("json")
@@ -135,7 +134,7 @@ func handleEventsSubscribe(core *vault.Core, req *logical.Request) http.Handler
} }
}() }()
closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, eventType, conn, json}) closeStatus, closeReason, err := handleEventsSubscribeWebsocket(eventSubscribeArgs{ctx, logger, core.Events(), ns, pattern, conn, json})
if err != nil { if err != nil {
closeStatus = websocket.CloseStatus(err) closeStatus = websocket.CloseStatus(err)
if closeStatus == -1 { if closeStatus == -1 {

View File

@@ -16,10 +16,17 @@ import (
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/ryanuber/go-glob"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
) )
const defaultTimeout = 60 * time.Second const (
// eventTypeAll is purely internal to the event bus. We use it to send all
// events down one big firehose, and pipelines define their own filtering
// based on what each subscriber is interested in.
eventTypeAll = "*"
defaultTimeout = 60 * time.Second
)
var ( var (
ErrNotStarted = errors.New("event broker has not been started") ErrNotStarted = errors.New("event broker has not been started")
@@ -45,16 +52,14 @@ type pluginEventBus struct {
type asyncChanNode struct { type asyncChanNode struct {
// TODO: add bounded deque buffer of *EventReceived // TODO: add bounded deque buffer of *EventReceived
ctx context.Context ctx context.Context
ch chan *logical.EventReceived ch chan *logical.EventReceived
namespace *namespace.Namespace logger hclog.Logger
logger hclog.Logger
// used to close the connection // used to close the connection
closeOnce sync.Once closeOnce sync.Once
cancelFunc context.CancelFunc cancelFunc context.CancelFunc
pipelineID eventlogger.PipelineID pipelineID eventlogger.PipelineID
eventType eventlogger.EventType
broker *eventlogger.Broker broker *eventlogger.Broker
} }
@@ -97,7 +102,7 @@ func (bus *EventBus) SendInternal(ctx context.Context, ns *namespace.Namespace,
// We can't easily know when the Send is complete, so we can't call the cancel function. // We can't easily know when the Send is complete, so we can't call the cancel function.
// But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long. // But, it is called automatically after bus.timeout, so there won't be any leak as long as bus.timeout is not too long.
ctx, _ = context.WithTimeout(ctx, bus.timeout) ctx, _ = context.WithTimeout(ctx, bus.timeout)
_, err := bus.broker.Send(ctx, eventlogger.EventType(eventType), eventReceived) _, err := bus.broker.Send(ctx, eventTypeAll, eventReceived)
if err != nil { if err != nil {
// if no listeners for this event type are registered, that's okay, the event // if no listeners for this event type are registered, that's okay, the event
// will just not be sent anywhere // will just not be sent anywhere
@@ -164,32 +169,42 @@ func NewEventBus(logger hclog.Logger) (*EventBus, error) {
}, nil }, nil
} }
func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eventType logical.EventType) (<-chan *logical.EventReceived, context.CancelFunc, error) { func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, pattern string) (<-chan *logical.EventReceived, context.CancelFunc, error) {
// subscriptions are still stored even if the bus has not been started // subscriptions are still stored even if the bus has not been started
pipelineID, err := uuid.GenerateUUID() pipelineID, err := uuid.GenerateUUID()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
nodeID, err := uuid.GenerateUUID() filterNodeID, err := uuid.GenerateUUID()
if err != nil {
return nil, nil, err
}
filterNode := newFilterNode(ns, pattern)
err = bus.broker.RegisterNode(eventlogger.NodeID(filterNodeID), filterNode)
if err != nil {
return nil, nil, err
}
sinkNodeID, err := uuid.GenerateUUID()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// TODO: should we have just one node per namespace, and handle all the routing ourselves?
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
asyncNode := newAsyncNode(ctx, ns, bus.logger) asyncNode := newAsyncNode(ctx, ns, bus.logger)
err = bus.broker.RegisterNode(eventlogger.NodeID(nodeID), asyncNode) err = bus.broker.RegisterNode(eventlogger.NodeID(sinkNodeID), asyncNode)
if err != nil { if err != nil {
defer cancel() defer cancel()
return nil, nil, err return nil, nil, err
} }
nodes := []eventlogger.NodeID{bus.formatterNodeID, eventlogger.NodeID(nodeID)} nodes := []eventlogger.NodeID{eventlogger.NodeID(filterNodeID), bus.formatterNodeID, eventlogger.NodeID(sinkNodeID)}
pipeline := eventlogger.Pipeline{ pipeline := eventlogger.Pipeline{
PipelineID: eventlogger.PipelineID(pipelineID), PipelineID: eventlogger.PipelineID(pipelineID),
EventType: eventlogger.EventType(eventType), EventType: eventTypeAll,
NodeIDs: nodes, NodeIDs: nodes,
} }
err = bus.broker.RegisterPipeline(pipeline) err = bus.broker.RegisterPipeline(pipeline)
@@ -197,10 +212,10 @@ func (bus *EventBus) Subscribe(ctx context.Context, ns *namespace.Namespace, eve
defer cancel() defer cancel()
return nil, nil, err return nil, nil, err
} }
addSubscriptions(1) addSubscriptions(1)
// add info needed to cancel the subscription // add info needed to cancel the subscription
asyncNode.pipelineID = eventlogger.PipelineID(pipelineID) asyncNode.pipelineID = eventlogger.PipelineID(pipelineID)
asyncNode.eventType = eventlogger.EventType(eventType)
asyncNode.cancelFunc = cancel asyncNode.cancelFunc = cancel
return asyncNode.ch, asyncNode.Close, nil return asyncNode.ch, asyncNode.Close, nil
} }
@@ -211,12 +226,32 @@ func (bus *EventBus) SetSendTimeout(timeout time.Duration) {
bus.timeout = timeout bus.timeout = timeout
} }
func newFilterNode(ns *namespace.Namespace, pattern string) *eventlogger.Filter {
return &eventlogger.Filter{
Predicate: func(e *eventlogger.Event) (bool, error) {
eventRecv := e.Payload.(*logical.EventReceived)
// Drop if event is not in our namespace.
// TODO: add wildcard/child namespace processing here in some cases?
if eventRecv.Namespace != ns.Path {
return false, nil
}
// Filter for correct event type, including wildcards.
if !glob.Glob(pattern, eventRecv.EventType) {
return false, nil
}
return true, nil
},
}
}
func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode { func newAsyncNode(ctx context.Context, namespace *namespace.Namespace, logger hclog.Logger) *asyncChanNode {
return &asyncChanNode{ return &asyncChanNode{
ctx: ctx, ctx: ctx,
ch: make(chan *logical.EventReceived), ch: make(chan *logical.EventReceived),
namespace: namespace, logger: logger,
logger: logger,
} }
} }
@@ -225,7 +260,7 @@ func (node *asyncChanNode) Close() {
node.closeOnce.Do(func() { node.closeOnce.Do(func() {
defer node.cancelFunc() defer node.cancelFunc()
if node.broker != nil { if node.broker != nil {
err := node.broker.RemovePipeline(node.eventType, node.pipelineID) err := node.broker.RemovePipeline(eventTypeAll, node.pipelineID)
if err != nil { if err != nil {
node.logger.Warn("Error removing pipeline for closing node", "error", err) node.logger.Warn("Error removing pipeline for closing node", "error", err)
} }
@@ -238,11 +273,6 @@ func (node *asyncChanNode) Process(ctx context.Context, e *eventlogger.Event) (*
// sends to the channel async in another goroutine // sends to the channel async in another goroutine
go func() { go func() {
eventRecv := e.Payload.(*logical.EventReceived) eventRecv := e.Payload.(*logical.EventReceived)
// drop if event is not in our namespace
// TODO: add wildcard processing here in some cases?
if eventRecv.Namespace != node.namespace.Path {
return
}
var timeout bool var timeout bool
select { select {
case node.ch <- eventRecv: case node.ch <- eventRecv:

View File

@@ -7,6 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -38,7 +39,7 @@ func TestBusBasics(t *testing.T) {
t.Errorf("Expected no error sending: %v", err) t.Errorf("Expected no error sending: %v", err)
} }
ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -81,7 +82,7 @@ func TestNamespaceFiltering(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) ch, cancel, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -137,13 +138,13 @@ func TestBus2Subscriptions(t *testing.T) {
eventType2 := logical.EventType("someType2") eventType2 := logical.EventType("someType2")
bus.Start() bus.Start()
ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType1) ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType1))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer cancel1() defer cancel1()
ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType2) ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType2))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -222,7 +223,7 @@ func TestBusSubscriptionsCancel(t *testing.T) {
received := atomic.Int32{} received := atomic.Int32{}
for i := 0; i < create; i++ { for i := 0; i < create; i++ {
ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, eventType) ch, cancelFunc, err := bus.Subscribe(ctx, namespace.RootNamespace, string(eventType))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -297,3 +298,78 @@ func waitFor(t *testing.T, maxWait time.Duration, f func() bool) {
} }
t.Error("Timeout waiting for condition") t.Error("Timeout waiting for condition")
} }
// TestBusWildcardSubscriptions tests that a single subscription can receive
// multiple event types using * for glob patterns.
func TestBusWildcardSubscriptions(t *testing.T) {
bus, err := NewEventBus(nil)
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
fooEventType := logical.EventType("kv/foo")
barEventType := logical.EventType("kv/bar")
bus.Start()
ch1, cancel1, err := bus.Subscribe(ctx, namespace.RootNamespace, "kv/*")
if err != nil {
t.Fatal(err)
}
defer cancel1()
ch2, cancel2, err := bus.Subscribe(ctx, namespace.RootNamespace, "*/bar")
if err != nil {
t.Fatal(err)
}
defer cancel2()
event1, err := logical.NewEvent()
if err != nil {
t.Fatal(err)
}
event2, err := logical.NewEvent()
if err != nil {
t.Fatal(err)
}
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, barEventType, event2)
if err != nil {
t.Error(err)
}
err = bus.SendInternal(ctx, namespace.RootNamespace, nil, fooEventType, event1)
if err != nil {
t.Error(err)
}
timeout := time.After(1 * time.Second)
// Expect to receive both events on ch1, which subscribed to kv/*
var ch1Seen []string
for i := 0; i < 2; i++ {
select {
case message := <-ch1:
ch1Seen = append(ch1Seen, message.Event.ID())
case <-timeout:
t.Error("Timeout waiting for event1")
}
}
if len(ch1Seen) != 2 {
t.Errorf("Expected 2 events but got: %v", ch1Seen)
} else {
if !strutil.StrListContains(ch1Seen, event1.ID()) {
t.Errorf("Did not find %s event1 ID in ch1seen", event1.ID())
}
if !strutil.StrListContains(ch1Seen, event2.ID()) {
t.Errorf("Did not find %s event2 ID in ch1seen", event2.ID())
}
}
// Expect to receive just kv/bar on ch2, which subscribed to */bar
select {
case message := <-ch2:
if message.Event.ID() != event2.ID() {
t.Errorf("Got unexpected message: %v", message)
}
case <-timeout:
t.Error("Timeout waiting for event2")
}
}

View File

@@ -19,7 +19,7 @@ func TestCanSendEventsFromBuiltinPlugin(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, logical.EventType(eventType)) ch, cancel, err := c.events.Subscribe(ctx, namespace.RootNamespace, eventType)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }