diff --git a/changelog/24352.txt b/changelog/24352.txt new file mode 100644 index 0000000000..c6cf651dae --- /dev/null +++ b/changelog/24352.txt @@ -0,0 +1,3 @@ +```release-note:improvement +events: Add support for event subscription plugins, including SQS +``` diff --git a/command/commands.go b/command/commands.go index 2d9047d55d..b56ec2e8f0 100644 --- a/command/commands.go +++ b/command/commands.go @@ -11,6 +11,8 @@ import ( "github.com/hashicorp/cli" "github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/builtin/plugin" + "github.com/hashicorp/vault/plugins/event" + "github.com/hashicorp/vault/plugins/event/sqs" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/version" @@ -185,6 +187,10 @@ var ( "plugin": plugin.Factory, } + eventBackends = map[string]event.Factory{ + "sqs": sqs.New, + } + logicalBackends = map[string]logical.Factory{ "plugin": plugin.Factory, "database": logicalDb.Factory, @@ -742,6 +748,7 @@ func initCommands(ui, serverCmdUi cli.Ui, runOpts *RunOptions) map[string]cli.Co }, AuditBackends: auditBackends, CredentialBackends: credentialBackends, + EventBackends: eventBackends, LogicalBackends: logicalBackends, PhysicalBackends: physicalBackends, diff --git a/command/server.go b/command/server.go index c002749540..614616557f 100644 --- a/command/server.go +++ b/command/server.go @@ -52,6 +52,7 @@ import ( vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/listenerutil" + "github.com/hashicorp/vault/plugins/event" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/strutil" @@ -96,6 +97,7 @@ type ServerCommand struct { CredentialBackends map[string]logical.Factory LogicalBackends map[string]logical.Factory PhysicalBackends map[string]physical.Factory + EventBackends map[string]event.Factory ServiceRegistrations map[string]sr.Factory @@ -3079,6 +3081,7 @@ func createCoreConfig(c *ServerCommand, config *server.Config, backend physical. AuditBackends: c.AuditBackends, CredentialBackends: c.CredentialBackends, LogicalBackends: c.LogicalBackends, + EventBackends: c.EventBackends, LogLevel: config.LogLevel, Logger: c.logger, DetectDeadlocks: config.DetectDeadlocks, diff --git a/go.mod b/go.mod index 13bb22be8b..2bc46cd536 100644 --- a/go.mod +++ b/go.mod @@ -277,13 +277,13 @@ require ( github.com/andybalholm/brotli v1.0.5 // indirect github.com/apache/arrow/go/v14 v14.0.2 // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect - github.com/aws/aws-sdk-go-v2 v1.17.7 // indirect + github.com/aws/aws-sdk-go-v2 v1.23.4 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.13.18 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.1 // indirect github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.59 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.31 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.7 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.7 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.3.32 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.23 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 // indirect @@ -291,10 +291,11 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.25 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.0 // indirect github.com/aws/aws-sdk-go-v2/service/s3 v1.31.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sqs v1.29.1 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.12.6 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.6 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.18.7 // indirect - github.com/aws/smithy-go v1.13.5 // indirect + github.com/aws/smithy-go v1.18.1 // indirect github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect github.com/benbjohnson/immutable v0.4.0 // indirect github.com/beorn7/perks v1.0.1 // indirect diff --git a/go.sum b/go.sum index 8e907ed169..e21c21b1de 100644 --- a/go.sum +++ b/go.sum @@ -1177,6 +1177,8 @@ github.com/aws/aws-sdk-go v1.49.22 h1:r01+cQJ3cORQI1PJxG8af0jzrZpUOL9L+/3kU2x1ge github.com/aws/aws-sdk-go v1.49.22/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/aws/aws-sdk-go-v2 v1.17.7 h1:CLSjnhJSTSogvqUGhIC6LqFKATMRexcxLZ0i/Nzk9Eg= github.com/aws/aws-sdk-go-v2 v1.17.7/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2 v1.23.4 h1:2P20ZjH0ouSAu/6yZep8oCmTReathLuEu6dwoqEgjts= +github.com/aws/aws-sdk-go-v2 v1.23.4/go.mod h1:t3szzKfP0NeRU27uBFczDivYJjsmSnqI8kIvKyWb9ds= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 h1:dK82zF6kkPeCo8J1e+tGx4JdvDIQzj7ygIoLg8WMuGs= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10/go.mod h1:VeTZetY5KRJLuD/7fkQXMU6Mw7H5m/KP2J5Iy9osMno= github.com/aws/aws-sdk-go-v2/config v1.18.19 h1:AqFK6zFNtq4i1EYu+eC7lcKHYnZagMn6SW171la0bGw= @@ -1189,8 +1191,12 @@ github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.59 h1:E3Y+OfzOK1+rmRo/K2G0 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.59/go.mod h1:1M4PLSBUVfBI0aP+C9XI7SM6kZPCGYyI6izWz0TGprE= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.31 h1:sJLYcS+eZn5EeNINGHSCRAwUJMFVqklwkH36Vbyai7M= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.31/go.mod h1:QT0BqUvX1Bh2ABdTGnjqEjvjzrCfIniM9Sc8zn9Yndo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.7 h1:eMqD7ku6WGdmcWWXPYun9m6yk6feSULLhJlAtN6rYG4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.7/go.mod h1:0oBIfcDV6LScxEW0VgOqxT3e4aqKRp+SYhB9wAd5E3Q= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.25 h1:1mnRASEKnkqsntcxHaysxwgVoUUp5dkiB+l3llKnqyg= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.25/go.mod h1:zBHOPwhBc3FlQjQJE/D3IfPWiWaQmT06Vq9aNukDo0k= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.7 h1:+XYhWhgWs5F3Zx8oa49CXzNvfXrItaDjZB/M172fcHQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.7/go.mod h1:L6tcSRyCGxcKfDWUrmv2jv8G1cLDU7d0FUpEFpG9bVE= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.32 h1:p5luUImdIqywn6JpQsW3tq5GNOxKmOnEpybzPx+d1lk= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.32/go.mod h1:XGhIBZDEgfqmFIugclZ6FU7v75nHhBDtzuB4xB/tEi4= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.23 h1:DWYZIsyqagnWL00f8M/SOr9fN063OEQWn9LLTbdYXsk= @@ -1205,6 +1211,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.0 h1:e2ooMhpYGhDnBf github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.0/go.mod h1:bh2E0CXKZsQN+faiKVqC40vfNMAWheoULBCnEgO9K+8= github.com/aws/aws-sdk-go-v2/service/s3 v1.31.0 h1:B1G2pSPvbAtQjilPq+Y7jLIzCOwKzuVEl+aBBaNG0AQ= github.com/aws/aws-sdk-go-v2/service/s3 v1.31.0/go.mod h1:ncltU6n4Nof5uJttDtcNQ537uNuwYqsZZQcpkd2/GUQ= +github.com/aws/aws-sdk-go-v2/service/sqs v1.29.1 h1:OZI2aJxnfOZzB0uhyTaYIW6MeRMb1Qd2eLMjh0bFsRg= +github.com/aws/aws-sdk-go-v2/service/sqs v1.29.1/go.mod h1:GiU88YWgOho2cyEyS2YZo3GYz/j4etRYKWbJdcYgpuQ= github.com/aws/aws-sdk-go-v2/service/sso v1.12.6 h1:5V7DWLBd7wTELVz5bPpwzYy/sikk0gsgZfj40X+l5OI= github.com/aws/aws-sdk-go-v2/service/sso v1.12.6/go.mod h1:Y1VOmit/Fn6Tz1uFAeCO6Q7M2fmfXSCLeL5INVYsLuY= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.6 h1:B8cauxOH1W1v7rd8RdI/MWnoR4Ze0wIHWrb90qczxj4= @@ -1213,6 +1221,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.18.7 h1:bWNgNdRko2x6gqa0blfATqAZKZok github.com/aws/aws-sdk-go-v2/service/sts v1.18.7/go.mod h1:JuTnSoeePXmMVe9G8NcjjwgOKEfZ4cOjMuT2IBT/2eI= github.com/aws/smithy-go v1.13.5 h1:hgz0X/DX0dGqTYpGALqXJoRKRj5oQ7150i5FdTePzO8= github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= +github.com/aws/smithy-go v1.18.1 h1:pOdBTUfXNazOlxLrgeYalVnuTpKreACHtc62xLwIB3c= +github.com/aws/smithy-go v1.18.1/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a h1:eqjiAL3qooftPm8b9C1GsSSRcmlw7iOva8vdBTmV2PY= github.com/axiomhq/hyperloglog v0.0.0-20220105174342-98591331716a/go.mod h1:2stgcRjl6QmW+gU2h5E7BQXg4HU0gzxKWDuT5HviN9s= github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f h1:ZNv7On9kyUzm7fvRZumSyy/IUiSC7AzL0I1jKKtwooA= diff --git a/plugins/event/event_subscription_plugin.go b/plugins/event/event_subscription_plugin.go new file mode 100644 index 0000000000..4309866320 --- /dev/null +++ b/plugins/event/event_subscription_plugin.go @@ -0,0 +1,95 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package event + +import ( + "context" + "time" + + "github.com/hashicorp/vault/sdk/helper/backoff" +) + +type Factory func(context.Context) (SubscriptionPlugin, error) + +// SubscriptionPlugin is the interface implemented by plugins that can subscribe to and receive events. +type SubscriptionPlugin interface { + // Subscribe is used to set up a new connection. + Subscribe(context.Context, *SubscribeRequest) error + // Send is used to send events to a connection. + Send(context.Context, *SendRequest) error + // Unsubscribe is used to teardown a connection. + Unsubscribe(context.Context, *UnsubscribeRequest) error + // PluginMetadata returns the name and version for the particular event subscription plugin. + // The name is usually set as a constant the backend, e.g., "sqs" for the + // AWS SQS backend. + PluginMetadata() *PluginMetadata + // Close closes all connections. + Close(ctx context.Context) error +} + +type Request struct { + Subscribe *SubscribeRequest + Unsubscribe *UnsubscribeRequest + Event *SendRequest +} + +type SubscribeRequest struct { + SubscriptionID string + Config map[string]interface{} + VerifyConnection bool +} + +type UnsubscribeRequest struct { + SubscriptionID string +} + +type SendRequest struct { + SubscriptionID string + EventJSON string +} + +type PluginMetadata struct { + Name string + Version string +} + +// SubscribeConfigDefaults defines configuration map keys for common default options. +// Embed this in your own config struct to pick up these default options. +type SubscribeConfigDefaults struct { + Retries *int `mapstructure:"retries"` + RetryMinBackoff *time.Duration `mapstructure:"retry_min_backoff"` + RetryMaxBackoff *time.Duration `mapstructure:"retry_max_backoff"` +} + +// default values for common configuration keys +const ( + DefaultRetries = 3 + DefaultRetryMinBackoff = 100 * time.Millisecond + DefaultRetryMaxBackoff = 5 * time.Second +) + +func (c *SubscribeConfigDefaults) GetRetries() int { + if c.Retries == nil { + return DefaultRetries + } + return *c.Retries +} + +func (c *SubscribeConfigDefaults) GetRetryMinBackoff() time.Duration { + if c.RetryMinBackoff == nil { + return DefaultRetryMinBackoff + } + return *c.RetryMinBackoff +} + +func (c *SubscribeConfigDefaults) GetRetryMaxBackoff() time.Duration { + if c.RetryMaxBackoff == nil { + return DefaultRetryMaxBackoff + } + return *c.RetryMaxBackoff +} + +func (c *SubscribeConfigDefaults) NewRetryBackoff() *backoff.Backoff { + return backoff.NewBackoff(c.GetRetries(), c.GetRetryMinBackoff(), c.GetRetryMaxBackoff()) +} diff --git a/plugins/event/sqs/sqs.go b/plugins/event/sqs/sqs.go new file mode 100644 index 0000000000..e5537caa3e --- /dev/null +++ b/plugins/event/sqs/sqs.go @@ -0,0 +1,239 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package sqs + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/hashicorp/go-secure-stdlib/awsutil" + "github.com/hashicorp/vault/plugins/event" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/version" + "github.com/mitchellh/mapstructure" +) + +var ( + _ event.Factory = New + _ event.SubscriptionPlugin = (*sqsBackend)(nil) +) + +const pluginName = "sqs" + +// ErrQueueRequired is returned if the required queue parameters are not present. +var ErrQueueRequired = errors.New("queue_name or queue_url must be specified") + +// New returns a new instance of the SQS plugin backend. +func New(_ context.Context) (event.SubscriptionPlugin, error) { + return &sqsBackend{ + connections: map[string]*sqsConnection{}, + }, nil +} + +type sqsBackend struct { + connections map[string]*sqsConnection + clientLock sync.RWMutex +} + +type sqsConnection struct { + client *sqs.SQS + config *sqsConfig + queueURL string +} + +type sqsConfig struct { + event.SubscribeConfigDefaults + CreateQueue bool `mapstructure:"create_queue"` + AccessKeyID string `mapstructure:"access_key_id"` + SecretAccessKey string `mapstructure:"secret_access_key"` + Region string `mapstructure:"region"` + QueueName string `mapstructure:"queue_name"` + QueueURL string `mapstructure:"queue_url"` +} + +func newClient(sconfig *sqsConfig) (*sqs.SQS, error) { + var options []awsutil.Option + if sconfig.AccessKeyID != "" && sconfig.SecretAccessKey != "" { + options = append(options, awsutil.WithAccessKey(sconfig.AccessKeyID)) + options = append(options, awsutil.WithSecretKey(sconfig.SecretAccessKey)) + } + if sconfig.Region != "" { + options = append(options, awsutil.WithRegion(sconfig.Region)) + } + options = append(options, awsutil.WithEnvironmentCredentials(true)) + options = append(options, awsutil.WithSharedCredentials(true)) + credConfig, err := awsutil.NewCredentialsConfig(options...) + if err != nil { + return nil, err + } + session, err := credConfig.GetSession() + if err != nil { + return nil, err + } + return sqs.New(session), nil +} + +func (s *sqsBackend) Subscribe(_ context.Context, request *event.SubscribeRequest) error { + var sconfig sqsConfig + err := mapstructure.Decode(request.Config, &sconfig) + if err != nil { + return err + } + if sconfig.QueueName == "" && sconfig.QueueURL == "" { + return ErrQueueRequired + } + client, err := newClient(&sconfig) + if err != nil { + return err + } + var queueURL string + if sconfig.CreateQueue && sconfig.QueueName != "" { + resp, err := client.CreateQueue(&sqs.CreateQueueInput{ + QueueName: &sconfig.QueueName, + }) + var aerr awserr.Error + if errors.As(err, &aerr) { + if aerr.Code() == sqs.ErrCodeQueueNameExists { + // that's okay + err = nil + } + } + if err != nil { + return err + } + if resp == nil || resp.QueueUrl == nil { + return fmt.Errorf("invalid response from AWS: missing queue URL") + } + queueURL = *resp.QueueUrl + } else if sconfig.QueueURL != "" { + queueURL = sconfig.QueueURL + } else { + resp, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{ + QueueName: &sconfig.QueueName, + }) + if err != nil { + return err + } + if resp == nil || resp.QueueUrl == nil { + return fmt.Errorf("invalid response from AWS: missing queue URL") + } + queueURL = *resp.QueueUrl + } + + conn := &sqsConnection{ + client: client, + config: &sconfig, + queueURL: queueURL, + } + s.clientLock.Lock() + defer s.clientLock.Unlock() + if _, ok := s.connections[request.SubscriptionID]; ok { + s.killConnectionWithLock(request.SubscriptionID) + } + s.connections[request.SubscriptionID] = conn + return nil +} + +func (s *sqsBackend) killConnection(subscriptionID string) { + s.clientLock.Lock() + defer s.clientLock.Unlock() + s.killConnectionWithLock(subscriptionID) +} + +func (s *sqsBackend) killConnectionWithLock(subscriptionID string) { + delete(s.connections, subscriptionID) +} + +func (s *sqsBackend) getConn(subscriptionID string) (*sqsConnection, error) { + s.clientLock.RLock() + defer s.clientLock.RUnlock() + conn, ok := s.connections[subscriptionID] + if !ok { + return nil, fmt.Errorf("invalid subscription_id") + } + return conn, nil +} + +func (s *sqsBackend) Send(_ context.Context, send *event.SendRequest) error { + return s.sendSubscriptionEventInternal(send.SubscriptionID, send.EventJSON, false) +} + +func (s *sqsBackend) refreshClient(subscriptionID string) error { + conn, err := s.getConn(subscriptionID) + if err != nil { + return err + } + client, err := newClient(conn.config) + if err != nil { + return err + } + s.clientLock.Lock() + defer s.clientLock.Unlock() + conn.client = client + // probably not necessary, but just in case + s.connections[subscriptionID] = conn + return nil +} + +func (s *sqsBackend) sendSubscriptionEventInternal(subscriptionID string, eventJson string, isRetry bool) error { + conn, err := s.getConn(subscriptionID) + if err != nil { + return err + } + backoff := conn.config.NewRetryBackoff() + err = backoff.Retry(func() error { + _, err = conn.client.SendMessage(&sqs.SendMessageInput{ + MessageBody: &eventJson, + QueueUrl: &conn.queueURL, + }) + return err + }) + if err != nil && !isRetry { + // refresh client and try again, once + err2 := s.refreshClient(subscriptionID) + if err2 != nil { + return errors.Join(err, err2) + } + return s.sendSubscriptionEventInternal(subscriptionID, eventJson, true) + } else if err != nil && isRetry { + s.killConnection(subscriptionID) + return err + } + return nil +} + +func (s *sqsBackend) Unsubscribe(_ context.Context, request *event.UnsubscribeRequest) error { + s.killConnection(request.SubscriptionID) + return nil +} + +func (s *sqsBackend) PluginMetadata() *event.PluginMetadata { + return &event.PluginMetadata{ + Name: pluginName, + Version: version.GetVersion().Version, + } +} + +func (s *sqsBackend) PluginVersion() logical.PluginVersion { + return logical.PluginVersion{ + Version: version.GetVersion().Version, + } +} + +func (s *sqsBackend) Close(_ context.Context) error { + s.clientLock.Lock() + defer s.clientLock.Unlock() + var subscriptions []string + for k := range s.connections { + subscriptions = append(subscriptions, k) + } + for _, subscription := range subscriptions { + s.killConnectionWithLock(subscription) + } + return nil +} diff --git a/plugins/event/sqs/sqs_test.go b/plugins/event/sqs/sqs_test.go new file mode 100644 index 0000000000..35d91b22f1 --- /dev/null +++ b/plugins/event/sqs/sqs_test.go @@ -0,0 +1,118 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package sqs + +import ( + "context" + "os" + "testing" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sqs" + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/plugins/event" + "github.com/stretchr/testify/assert" +) + +func getTestClient(t *testing.T) *sqs.Client { + awsConfig, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(os.Getenv("AWS_REGION"))) + if err != nil { + t.Fatal(err) + } + return sqs.NewFromConfig(awsConfig) +} + +func createQueue(t *testing.T, client *sqs.Client, queueName string) string { + resp, err := client.CreateQueue(context.Background(), &sqs.CreateQueueInput{ + QueueName: &queueName, + }) + if err != nil { + t.Fatal(err) + } + return *resp.QueueUrl +} + +func deleteQueue(t *testing.T, client *sqs.Client, queueURL string) { + _, err := client.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{ + QueueUrl: &queueURL, + }) + if err != nil { + t.Fatal(err) + } +} + +func receiveMessage(t *testing.T, client *sqs.Client, queueURL string) string { + resp, err := client.ReceiveMessage(context.Background(), &sqs.ReceiveMessageInput{ + QueueUrl: &queueURL, + WaitTimeSeconds: 5, + }) + if err != nil { + t.Fatal(err) + } + assert.Len(t, resp.Messages, 1) + msg := resp.Messages[0] + _, err = client.DeleteMessage(context.Background(), &sqs.DeleteMessageInput{ + QueueUrl: &queueURL, + ReceiptHandle: msg.ReceiptHandle, + }) + if err != nil { + t.Fatal(err) + } + return *msg.Body +} + +// TestSQS_SendOneMessage tests that the plugin basic flow of subscribe/sendevent/unsubscribe will send a message to SQS. +func TestSQS_SendOneMessage(t *testing.T) { + region := os.Getenv("AWS_REGION") + if region == "" { + t.Skip("Must set AWS_REGION") + } + sqsClient := getTestClient(t) + temp, err := uuid.GenerateUUID() + assert.Nil(t, err) + tempQueueName := "event-sqs-test-" + temp + tempQueueURL := createQueue(t, sqsClient, tempQueueName) + t.Cleanup(func() { + deleteQueue(t, sqsClient, tempQueueURL) + }) + + backend, _ := New(nil) + subID, err := uuid.GenerateUUID() + assert.Nil(t, err) + + err = backend.Subscribe(nil, &event.SubscribeRequest{ + SubscriptionID: subID, + Config: map[string]interface{}{ + "queue_name": tempQueueName, + "region": os.Getenv("AWS_REGION"), + "create_queue": true, + }, + VerifyConnection: false, + }) + assert.Nil(t, err) + + // create another subscription with the same queue to make sure we are okay with using an existing queue + err = backend.Subscribe(nil, &event.SubscribeRequest{ + SubscriptionID: subID + "2", + Config: map[string]interface{}{ + "queue_name": tempQueueName, + "region": os.Getenv("AWS_REGION"), + "create_queue": true, + }, + VerifyConnection: false, + }) + assert.Nil(t, err) + + err = backend.Send(nil, &event.SendRequest{ + SubscriptionID: subID, + EventJSON: "{}", + }) + assert.Nil(t, err) + + msg := receiveMessage(t, sqsClient, tempQueueURL) + assert.Equal(t, "{}", msg) + + err = backend.Unsubscribe(nil, &event.UnsubscribeRequest{SubscriptionID: subID}) + assert.Nil(t, err) +} diff --git a/sdk/helper/backoff/backoff.go b/sdk/helper/backoff/backoff.go index 35fb059538..ebf9aaa7e7 100644 --- a/sdk/helper/backoff/backoff.go +++ b/sdk/helper/backoff/backoff.go @@ -88,3 +88,20 @@ func jitter(t time.Duration) time.Duration { f := float64(t) * (1.0 - maxJitter*rand.Float64()) return time.Duration(math.Floor(f)) } + +// Retry calls the given function until it does not return an error, at least once and up to max_retries + 1 times. +// If the number of retries is exceeded, Retry() will return the last error seen joined with ErrMaxRetry. +func (b *Backoff) Retry(f func() error) error { + for { + err := f() + if err == nil { + return nil + } + + maxRetryErr := b.NextSleep() + if maxRetryErr != nil { + return errors.Join(maxRetryErr, err) + } + } + return nil // unreachable +} diff --git a/vault/core.go b/vault/core.go index c511a66ba7..20c5430473 100644 --- a/vault/core.go +++ b/vault/core.go @@ -51,6 +51,7 @@ import ( "github.com/hashicorp/vault/helper/osutil" "github.com/hashicorp/vault/limits" "github.com/hashicorp/vault/physical/raft" + "github.com/hashicorp/vault/plugins/event" "github.com/hashicorp/vault/sdk/helper/certutil" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" @@ -319,6 +320,9 @@ type Core struct { // auditBackends is the mapping of backends to use for this core auditBackends map[string]audit.Factory + // eventBackends is the mapping of event plugins to use for this core + eventBackends map[string]event.Factory + // stateLock protects mutable state stateLock locking.RWMutex sealed *uint32 @@ -763,6 +767,8 @@ type CoreConfig struct { AuditBackends map[string]audit.Factory + EventBackends map[string]event.Factory + Physical physical.Backend StorageType string @@ -1282,6 +1288,9 @@ func NewCore(conf *CoreConfig) (*Core, error) { // Audit backends c.configureAuditBackends(conf.AuditBackends) + // Event plugins + c.configureEventBackends(conf.EventBackends) + // UI uiStoragePrefix := systemBarrierPrefix + "ui" c.uiConfig = NewUIConfig(conf.EnableUI, physical.NewView(c.physical, uiStoragePrefix), NewBarrierView(c.barrier, uiStoragePrefix)) @@ -1452,6 +1461,19 @@ func (c *Core) configureLogicalBackends(backends map[string]logical.Factory, log c.addExtraLogicalBackends(adminNamespacePath) } +// configureEventBackends configures the Core with the ability to create +// event backends for various types. +func (c *Core) configureEventBackends(backends map[string]event.Factory) { + eventBackends := make(map[string]event.Factory, len(backends)) + for k, f := range backends { + eventBackends[k] = f + } + + c.eventBackends = eventBackends + + c.addExtraEventBackends() +} + // handleVersionTimeStamps stores the current version at the current time to // storage, and then loads all versions and upgrade timestamps out from storage. func (c *Core) handleVersionTimeStamps(ctx context.Context) error { diff --git a/vault/core_util.go b/vault/core_util.go index b9f858e25b..b2bac92c78 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -89,6 +89,8 @@ func (c *Core) createSecondaries(_ hclog.Logger) {} func (c *Core) addExtraLogicalBackends(_ string) {} +func (c *Core) addExtraEventBackends() {} + func (c *Core) addExtraCredentialBackends() {} func preUnsealInternal(context.Context, *Core) error { return nil } diff --git a/vault/events_test.go b/vault/event_plugins_test.go similarity index 100% rename from vault/events_test.go rename to vault/event_plugins_test.go diff --git a/vault/logical_system_events.go b/vault/logical_system_events.go new file mode 100644 index 0000000000..087be992df --- /dev/null +++ b/vault/logical_system_events.go @@ -0,0 +1,30 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package vault + +import ( + "context" + "net/http" + + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" +) + +// handleEventsSubscribe +func (b *SystemBackend) handleEventsSubscribe(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // TODO + return logical.RespondWithStatusCode(nil, req, http.StatusNoContent) +} + +// handleEventsUnsubscribe +func (b *SystemBackend) handleEventsUnsubscribe(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // TODO + return logical.RespondWithStatusCode(nil, req, http.StatusNoContent) +} + +// handleEventsListSubscriptions +func (b *SystemBackend) handleEventsListSubscriptions(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // TODO + return logical.RespondWithStatusCode(nil, req, http.StatusNoContent) +} diff --git a/vault/logical_system_paths.go b/vault/logical_system_paths.go index e82c4ef6f1..2aee94009b 100644 --- a/vault/logical_system_paths.go +++ b/vault/logical_system_paths.go @@ -5095,3 +5095,108 @@ func (b *SystemBackend) lockedUserPaths() []*framework.Path { }, } } + +func (b *SystemBackend) eventPaths() []*framework.Path { + return []*framework.Path{ + { + Pattern: "events/subscriptions$", + + DisplayAttrs: &framework.DisplayAttributes{ + OperationPrefix: "subscriptions", + OperationVerb: "create", + }, + + Fields: map[string]*framework.FieldSchema{ + "config": { + Type: framework.TypeMap, + Required: true, + // Description: strings.TrimSpace(sysHelp["mount_accessor"][0]), + }, + "plugin": { + Type: framework.TypeString, + Required: true, + }, + //"alias_identifier": { + // Type: framework.TypeString, + // // Description: strings.TrimSpace(sysHelp["alias_identifier"][0]), + //}, + }, + + Operations: map[logical.Operation]framework.OperationHandler{ + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.handleEventsSubscribe, + Summary: "", + Responses: map[int][]framework.Response{ + http.StatusOK: {{ + Description: "OK", + Fields: map[string]*framework.FieldSchema{ + "id": { + Type: framework.TypeString, + Required: true, + }, + }, + }}, + }, + }, + }, + // HelpSynopsis: strings.TrimSpace(sysHelp["unlock_user"][0]), + // HelpDescription: strings.TrimSpace(sysHelp["unlock_user"][1]), + }, + { + Pattern: "events/subscriptions/(?P.+)/(?P.+)", + + DisplayAttrs: &framework.DisplayAttributes{ + OperationPrefix: "subscriptions", + OperationVerb: "create", + }, + + Fields: map[string]*framework.FieldSchema{ + "plugin": { + Type: framework.TypeString, + }, + "id": { + Type: framework.TypeString, + // Description: strings.TrimSpace(sysHelp["mount_accessor"][0]), + }, + "list": { + Type: framework.TypeBool, + }, + //"alias_identifier": { + // Type: framework.TypeString, + // // Description: strings.TrimSpace(sysHelp["alias_identifier"][0]), + //}, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ListOperation: &framework.PathOperation{ + Callback: b.handleEventsListSubscriptions, + Summary: "", + Responses: map[int][]framework.Response{ + http.StatusNoContent: {{ + Description: "OK", + }}, + }, + }, + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handleEventsListSubscriptions, + Summary: "", + Responses: map[int][]framework.Response{ + http.StatusNoContent: {{ + Description: "OK", + }}, + }, + }, + logical.DeleteOperation: &framework.PathOperation{ + Callback: b.handleEventsUnsubscribe, + Summary: "", + Responses: map[int][]framework.Response{ + http.StatusNoContent: {{ + Description: "OK", + }}, + }, + }, + }, + // HelpSynopsis: strings.TrimSpace(sysHelp["unlock_user"][0]), + // HelpDescription: strings.TrimSpace(sysHelp["unlock_user"][1]), + }, + } +}