mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-30 02:02:43 +00:00
Remove SQS plugin (#26524)
To be moved to Enterprise. The paths and plugin itself were not activated.
This commit is contained in:
committed by
GitHub
parent
dd939d9a7e
commit
1e36019f1c
@@ -1,95 +0,0 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: MPL-2.0
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/helper/backoff"
|
||||
)
|
||||
|
||||
type Factory func(context.Context) (SubscriptionPlugin, error)
|
||||
|
||||
// SubscriptionPlugin is the interface implemented by plugins that can subscribe to and receive events.
|
||||
type SubscriptionPlugin interface {
|
||||
// Subscribe is used to set up a new connection.
|
||||
Subscribe(context.Context, *SubscribeRequest) error
|
||||
// Send is used to send events to a connection.
|
||||
Send(context.Context, *SendRequest) error
|
||||
// Unsubscribe is used to teardown a connection.
|
||||
Unsubscribe(context.Context, *UnsubscribeRequest) error
|
||||
// PluginMetadata returns the name and version for the particular event subscription plugin.
|
||||
// The name is usually set as a constant the backend, e.g., "sqs" for the
|
||||
// AWS SQS backend.
|
||||
PluginMetadata() *PluginMetadata
|
||||
// Close closes all connections.
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Subscribe *SubscribeRequest
|
||||
Unsubscribe *UnsubscribeRequest
|
||||
Event *SendRequest
|
||||
}
|
||||
|
||||
type SubscribeRequest struct {
|
||||
SubscriptionID string
|
||||
Config map[string]interface{}
|
||||
VerifyConnection bool
|
||||
}
|
||||
|
||||
type UnsubscribeRequest struct {
|
||||
SubscriptionID string
|
||||
}
|
||||
|
||||
type SendRequest struct {
|
||||
SubscriptionID string
|
||||
EventJSON string
|
||||
}
|
||||
|
||||
type PluginMetadata struct {
|
||||
Name string
|
||||
Version string
|
||||
}
|
||||
|
||||
// SubscribeConfigDefaults defines configuration map keys for common default options.
|
||||
// Embed this in your own config struct to pick up these default options.
|
||||
type SubscribeConfigDefaults struct {
|
||||
Retries *int `mapstructure:"retries"`
|
||||
RetryMinBackoff *time.Duration `mapstructure:"retry_min_backoff"`
|
||||
RetryMaxBackoff *time.Duration `mapstructure:"retry_max_backoff"`
|
||||
}
|
||||
|
||||
// default values for common configuration keys
|
||||
const (
|
||||
DefaultRetries = 3
|
||||
DefaultRetryMinBackoff = 100 * time.Millisecond
|
||||
DefaultRetryMaxBackoff = 5 * time.Second
|
||||
)
|
||||
|
||||
func (c *SubscribeConfigDefaults) GetRetries() int {
|
||||
if c.Retries == nil {
|
||||
return DefaultRetries
|
||||
}
|
||||
return *c.Retries
|
||||
}
|
||||
|
||||
func (c *SubscribeConfigDefaults) GetRetryMinBackoff() time.Duration {
|
||||
if c.RetryMinBackoff == nil {
|
||||
return DefaultRetryMinBackoff
|
||||
}
|
||||
return *c.RetryMinBackoff
|
||||
}
|
||||
|
||||
func (c *SubscribeConfigDefaults) GetRetryMaxBackoff() time.Duration {
|
||||
if c.RetryMaxBackoff == nil {
|
||||
return DefaultRetryMaxBackoff
|
||||
}
|
||||
return *c.RetryMaxBackoff
|
||||
}
|
||||
|
||||
func (c *SubscribeConfigDefaults) NewRetryBackoff() *backoff.Backoff {
|
||||
return backoff.NewBackoff(c.GetRetries(), c.GetRetryMinBackoff(), c.GetRetryMaxBackoff())
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package sqs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/service/sqs"
|
||||
"github.com/hashicorp/go-secure-stdlib/awsutil"
|
||||
"github.com/hashicorp/vault/plugins/event"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/version"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
var (
|
||||
_ event.Factory = New
|
||||
_ event.SubscriptionPlugin = (*sqsBackend)(nil)
|
||||
)
|
||||
|
||||
const pluginName = "sqs"
|
||||
|
||||
// ErrQueueRequired is returned if the required queue parameters are not present.
|
||||
var ErrQueueRequired = errors.New("queue_name or queue_url must be specified")
|
||||
|
||||
// New returns a new instance of the SQS plugin backend.
|
||||
func New(_ context.Context) (event.SubscriptionPlugin, error) {
|
||||
return &sqsBackend{
|
||||
connections: map[string]*sqsConnection{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type sqsBackend struct {
|
||||
connections map[string]*sqsConnection
|
||||
clientLock sync.RWMutex
|
||||
}
|
||||
|
||||
type sqsConnection struct {
|
||||
client *sqs.SQS
|
||||
config *sqsConfig
|
||||
queueURL string
|
||||
}
|
||||
|
||||
type sqsConfig struct {
|
||||
event.SubscribeConfigDefaults
|
||||
CreateQueue bool `mapstructure:"create_queue"`
|
||||
AccessKeyID string `mapstructure:"access_key_id"`
|
||||
SecretAccessKey string `mapstructure:"secret_access_key"`
|
||||
Region string `mapstructure:"region"`
|
||||
QueueName string `mapstructure:"queue_name"`
|
||||
QueueURL string `mapstructure:"queue_url"`
|
||||
}
|
||||
|
||||
func newClient(sconfig *sqsConfig) (*sqs.SQS, error) {
|
||||
var options []awsutil.Option
|
||||
if sconfig.AccessKeyID != "" && sconfig.SecretAccessKey != "" {
|
||||
options = append(options, awsutil.WithAccessKey(sconfig.AccessKeyID))
|
||||
options = append(options, awsutil.WithSecretKey(sconfig.SecretAccessKey))
|
||||
}
|
||||
if sconfig.Region != "" {
|
||||
options = append(options, awsutil.WithRegion(sconfig.Region))
|
||||
}
|
||||
options = append(options, awsutil.WithEnvironmentCredentials(true))
|
||||
options = append(options, awsutil.WithSharedCredentials(true))
|
||||
credConfig, err := awsutil.NewCredentialsConfig(options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session, err := credConfig.GetSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sqs.New(session), nil
|
||||
}
|
||||
|
||||
func (s *sqsBackend) Subscribe(_ context.Context, request *event.SubscribeRequest) error {
|
||||
var sconfig sqsConfig
|
||||
err := mapstructure.Decode(request.Config, &sconfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sconfig.QueueName == "" && sconfig.QueueURL == "" {
|
||||
return ErrQueueRequired
|
||||
}
|
||||
client, err := newClient(&sconfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var queueURL string
|
||||
if sconfig.CreateQueue && sconfig.QueueName != "" {
|
||||
resp, err := client.CreateQueue(&sqs.CreateQueueInput{
|
||||
QueueName: &sconfig.QueueName,
|
||||
})
|
||||
var aerr awserr.Error
|
||||
if errors.As(err, &aerr) {
|
||||
if aerr.Code() == sqs.ErrCodeQueueNameExists {
|
||||
// that's okay
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp == nil || resp.QueueUrl == nil {
|
||||
return fmt.Errorf("invalid response from AWS: missing queue URL")
|
||||
}
|
||||
queueURL = *resp.QueueUrl
|
||||
} else if sconfig.QueueURL != "" {
|
||||
queueURL = sconfig.QueueURL
|
||||
} else {
|
||||
resp, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{
|
||||
QueueName: &sconfig.QueueName,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp == nil || resp.QueueUrl == nil {
|
||||
return fmt.Errorf("invalid response from AWS: missing queue URL")
|
||||
}
|
||||
queueURL = *resp.QueueUrl
|
||||
}
|
||||
|
||||
conn := &sqsConnection{
|
||||
client: client,
|
||||
config: &sconfig,
|
||||
queueURL: queueURL,
|
||||
}
|
||||
s.clientLock.Lock()
|
||||
defer s.clientLock.Unlock()
|
||||
if _, ok := s.connections[request.SubscriptionID]; ok {
|
||||
s.killConnectionWithLock(request.SubscriptionID)
|
||||
}
|
||||
s.connections[request.SubscriptionID] = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqsBackend) killConnection(subscriptionID string) {
|
||||
s.clientLock.Lock()
|
||||
defer s.clientLock.Unlock()
|
||||
s.killConnectionWithLock(subscriptionID)
|
||||
}
|
||||
|
||||
func (s *sqsBackend) killConnectionWithLock(subscriptionID string) {
|
||||
delete(s.connections, subscriptionID)
|
||||
}
|
||||
|
||||
func (s *sqsBackend) getConn(subscriptionID string) (*sqsConnection, error) {
|
||||
s.clientLock.RLock()
|
||||
defer s.clientLock.RUnlock()
|
||||
conn, ok := s.connections[subscriptionID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid subscription_id")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *sqsBackend) Send(_ context.Context, send *event.SendRequest) error {
|
||||
return s.sendSubscriptionEventInternal(send.SubscriptionID, send.EventJSON, false)
|
||||
}
|
||||
|
||||
func (s *sqsBackend) refreshClient(subscriptionID string) error {
|
||||
conn, err := s.getConn(subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client, err := newClient(conn.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.clientLock.Lock()
|
||||
defer s.clientLock.Unlock()
|
||||
conn.client = client
|
||||
// probably not necessary, but just in case
|
||||
s.connections[subscriptionID] = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqsBackend) sendSubscriptionEventInternal(subscriptionID string, eventJson string, isRetry bool) error {
|
||||
conn, err := s.getConn(subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
backoff := conn.config.NewRetryBackoff()
|
||||
err = backoff.Retry(func() error {
|
||||
_, err = conn.client.SendMessage(&sqs.SendMessageInput{
|
||||
MessageBody: &eventJson,
|
||||
QueueUrl: &conn.queueURL,
|
||||
})
|
||||
return err
|
||||
})
|
||||
if err != nil && !isRetry {
|
||||
// refresh client and try again, once
|
||||
err2 := s.refreshClient(subscriptionID)
|
||||
if err2 != nil {
|
||||
return errors.Join(err, err2)
|
||||
}
|
||||
return s.sendSubscriptionEventInternal(subscriptionID, eventJson, true)
|
||||
} else if err != nil && isRetry {
|
||||
s.killConnection(subscriptionID)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqsBackend) Unsubscribe(_ context.Context, request *event.UnsubscribeRequest) error {
|
||||
s.killConnection(request.SubscriptionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sqsBackend) PluginMetadata() *event.PluginMetadata {
|
||||
return &event.PluginMetadata{
|
||||
Name: pluginName,
|
||||
Version: version.GetVersion().Version,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sqsBackend) PluginVersion() logical.PluginVersion {
|
||||
return logical.PluginVersion{
|
||||
Version: version.GetVersion().Version,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sqsBackend) Close(_ context.Context) error {
|
||||
s.clientLock.Lock()
|
||||
defer s.clientLock.Unlock()
|
||||
var subscriptions []string
|
||||
for k := range s.connections {
|
||||
subscriptions = append(subscriptions, k)
|
||||
}
|
||||
for _, subscription := range subscriptions {
|
||||
s.killConnectionWithLock(subscription)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,118 +0,0 @@
|
||||
// Copyright (c) HashiCorp, Inc.
|
||||
// SPDX-License-Identifier: BUSL-1.1
|
||||
|
||||
package sqs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sqs"
|
||||
"github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/plugins/event"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func getTestClient(t *testing.T) *sqs.Client {
|
||||
awsConfig, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(os.Getenv("AWS_REGION")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return sqs.NewFromConfig(awsConfig)
|
||||
}
|
||||
|
||||
func createQueue(t *testing.T, client *sqs.Client, queueName string) string {
|
||||
resp, err := client.CreateQueue(context.Background(), &sqs.CreateQueueInput{
|
||||
QueueName: &queueName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return *resp.QueueUrl
|
||||
}
|
||||
|
||||
func deleteQueue(t *testing.T, client *sqs.Client, queueURL string) {
|
||||
_, err := client.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{
|
||||
QueueUrl: &queueURL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func receiveMessage(t *testing.T, client *sqs.Client, queueURL string) string {
|
||||
resp, err := client.ReceiveMessage(context.Background(), &sqs.ReceiveMessageInput{
|
||||
QueueUrl: &queueURL,
|
||||
WaitTimeSeconds: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Len(t, resp.Messages, 1)
|
||||
msg := resp.Messages[0]
|
||||
_, err = client.DeleteMessage(context.Background(), &sqs.DeleteMessageInput{
|
||||
QueueUrl: &queueURL,
|
||||
ReceiptHandle: msg.ReceiptHandle,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return *msg.Body
|
||||
}
|
||||
|
||||
// TestSQS_SendOneMessage tests that the plugin basic flow of subscribe/sendevent/unsubscribe will send a message to SQS.
|
||||
func TestSQS_SendOneMessage(t *testing.T) {
|
||||
region := os.Getenv("AWS_REGION")
|
||||
if region == "" {
|
||||
t.Skip("Must set AWS_REGION")
|
||||
}
|
||||
sqsClient := getTestClient(t)
|
||||
temp, err := uuid.GenerateUUID()
|
||||
assert.Nil(t, err)
|
||||
tempQueueName := "event-sqs-test-" + temp
|
||||
tempQueueURL := createQueue(t, sqsClient, tempQueueName)
|
||||
t.Cleanup(func() {
|
||||
deleteQueue(t, sqsClient, tempQueueURL)
|
||||
})
|
||||
|
||||
backend, _ := New(nil)
|
||||
subID, err := uuid.GenerateUUID()
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = backend.Subscribe(nil, &event.SubscribeRequest{
|
||||
SubscriptionID: subID,
|
||||
Config: map[string]interface{}{
|
||||
"queue_name": tempQueueName,
|
||||
"region": os.Getenv("AWS_REGION"),
|
||||
"create_queue": true,
|
||||
},
|
||||
VerifyConnection: false,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// create another subscription with the same queue to make sure we are okay with using an existing queue
|
||||
err = backend.Subscribe(nil, &event.SubscribeRequest{
|
||||
SubscriptionID: subID + "2",
|
||||
Config: map[string]interface{}{
|
||||
"queue_name": tempQueueName,
|
||||
"region": os.Getenv("AWS_REGION"),
|
||||
"create_queue": true,
|
||||
},
|
||||
VerifyConnection: false,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = backend.Send(nil, &event.SendRequest{
|
||||
SubscriptionID: subID,
|
||||
EventJSON: "{}",
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
msg := receiveMessage(t, sqsClient, tempQueueURL)
|
||||
assert.Equal(t, "{}", msg)
|
||||
|
||||
err = backend.Unsubscribe(nil, &event.UnsubscribeRequest{SubscriptionID: subID})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
Reference in New Issue
Block a user