Remove SQS plugin (#26524)

To be moved to Enterprise.

The paths and plugin itself were not activated.
This commit is contained in:
Christopher Swenson
2024-04-18 13:50:11 -07:00
committed by GitHub
parent dd939d9a7e
commit 1e36019f1c
10 changed files with 0 additions and 619 deletions

View File

@@ -1,95 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package event
import (
"context"
"time"
"github.com/hashicorp/vault/sdk/helper/backoff"
)
type Factory func(context.Context) (SubscriptionPlugin, error)
// SubscriptionPlugin is the interface implemented by plugins that can subscribe to and receive events.
type SubscriptionPlugin interface {
// Subscribe is used to set up a new connection.
Subscribe(context.Context, *SubscribeRequest) error
// Send is used to send events to a connection.
Send(context.Context, *SendRequest) error
// Unsubscribe is used to teardown a connection.
Unsubscribe(context.Context, *UnsubscribeRequest) error
// PluginMetadata returns the name and version for the particular event subscription plugin.
// The name is usually set as a constant the backend, e.g., "sqs" for the
// AWS SQS backend.
PluginMetadata() *PluginMetadata
// Close closes all connections.
Close(ctx context.Context) error
}
type Request struct {
Subscribe *SubscribeRequest
Unsubscribe *UnsubscribeRequest
Event *SendRequest
}
type SubscribeRequest struct {
SubscriptionID string
Config map[string]interface{}
VerifyConnection bool
}
type UnsubscribeRequest struct {
SubscriptionID string
}
type SendRequest struct {
SubscriptionID string
EventJSON string
}
type PluginMetadata struct {
Name string
Version string
}
// SubscribeConfigDefaults defines configuration map keys for common default options.
// Embed this in your own config struct to pick up these default options.
type SubscribeConfigDefaults struct {
Retries *int `mapstructure:"retries"`
RetryMinBackoff *time.Duration `mapstructure:"retry_min_backoff"`
RetryMaxBackoff *time.Duration `mapstructure:"retry_max_backoff"`
}
// default values for common configuration keys
const (
DefaultRetries = 3
DefaultRetryMinBackoff = 100 * time.Millisecond
DefaultRetryMaxBackoff = 5 * time.Second
)
func (c *SubscribeConfigDefaults) GetRetries() int {
if c.Retries == nil {
return DefaultRetries
}
return *c.Retries
}
func (c *SubscribeConfigDefaults) GetRetryMinBackoff() time.Duration {
if c.RetryMinBackoff == nil {
return DefaultRetryMinBackoff
}
return *c.RetryMinBackoff
}
func (c *SubscribeConfigDefaults) GetRetryMaxBackoff() time.Duration {
if c.RetryMaxBackoff == nil {
return DefaultRetryMaxBackoff
}
return *c.RetryMaxBackoff
}
func (c *SubscribeConfigDefaults) NewRetryBackoff() *backoff.Backoff {
return backoff.NewBackoff(c.GetRetries(), c.GetRetryMinBackoff(), c.GetRetryMaxBackoff())
}

View File

@@ -1,239 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package sqs
import (
"context"
"errors"
"fmt"
"sync"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/vault/plugins/event"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/version"
"github.com/mitchellh/mapstructure"
)
var (
_ event.Factory = New
_ event.SubscriptionPlugin = (*sqsBackend)(nil)
)
const pluginName = "sqs"
// ErrQueueRequired is returned if the required queue parameters are not present.
var ErrQueueRequired = errors.New("queue_name or queue_url must be specified")
// New returns a new instance of the SQS plugin backend.
func New(_ context.Context) (event.SubscriptionPlugin, error) {
return &sqsBackend{
connections: map[string]*sqsConnection{},
}, nil
}
type sqsBackend struct {
connections map[string]*sqsConnection
clientLock sync.RWMutex
}
type sqsConnection struct {
client *sqs.SQS
config *sqsConfig
queueURL string
}
type sqsConfig struct {
event.SubscribeConfigDefaults
CreateQueue bool `mapstructure:"create_queue"`
AccessKeyID string `mapstructure:"access_key_id"`
SecretAccessKey string `mapstructure:"secret_access_key"`
Region string `mapstructure:"region"`
QueueName string `mapstructure:"queue_name"`
QueueURL string `mapstructure:"queue_url"`
}
func newClient(sconfig *sqsConfig) (*sqs.SQS, error) {
var options []awsutil.Option
if sconfig.AccessKeyID != "" && sconfig.SecretAccessKey != "" {
options = append(options, awsutil.WithAccessKey(sconfig.AccessKeyID))
options = append(options, awsutil.WithSecretKey(sconfig.SecretAccessKey))
}
if sconfig.Region != "" {
options = append(options, awsutil.WithRegion(sconfig.Region))
}
options = append(options, awsutil.WithEnvironmentCredentials(true))
options = append(options, awsutil.WithSharedCredentials(true))
credConfig, err := awsutil.NewCredentialsConfig(options...)
if err != nil {
return nil, err
}
session, err := credConfig.GetSession()
if err != nil {
return nil, err
}
return sqs.New(session), nil
}
func (s *sqsBackend) Subscribe(_ context.Context, request *event.SubscribeRequest) error {
var sconfig sqsConfig
err := mapstructure.Decode(request.Config, &sconfig)
if err != nil {
return err
}
if sconfig.QueueName == "" && sconfig.QueueURL == "" {
return ErrQueueRequired
}
client, err := newClient(&sconfig)
if err != nil {
return err
}
var queueURL string
if sconfig.CreateQueue && sconfig.QueueName != "" {
resp, err := client.CreateQueue(&sqs.CreateQueueInput{
QueueName: &sconfig.QueueName,
})
var aerr awserr.Error
if errors.As(err, &aerr) {
if aerr.Code() == sqs.ErrCodeQueueNameExists {
// that's okay
err = nil
}
}
if err != nil {
return err
}
if resp == nil || resp.QueueUrl == nil {
return fmt.Errorf("invalid response from AWS: missing queue URL")
}
queueURL = *resp.QueueUrl
} else if sconfig.QueueURL != "" {
queueURL = sconfig.QueueURL
} else {
resp, err := client.GetQueueUrl(&sqs.GetQueueUrlInput{
QueueName: &sconfig.QueueName,
})
if err != nil {
return err
}
if resp == nil || resp.QueueUrl == nil {
return fmt.Errorf("invalid response from AWS: missing queue URL")
}
queueURL = *resp.QueueUrl
}
conn := &sqsConnection{
client: client,
config: &sconfig,
queueURL: queueURL,
}
s.clientLock.Lock()
defer s.clientLock.Unlock()
if _, ok := s.connections[request.SubscriptionID]; ok {
s.killConnectionWithLock(request.SubscriptionID)
}
s.connections[request.SubscriptionID] = conn
return nil
}
func (s *sqsBackend) killConnection(subscriptionID string) {
s.clientLock.Lock()
defer s.clientLock.Unlock()
s.killConnectionWithLock(subscriptionID)
}
func (s *sqsBackend) killConnectionWithLock(subscriptionID string) {
delete(s.connections, subscriptionID)
}
func (s *sqsBackend) getConn(subscriptionID string) (*sqsConnection, error) {
s.clientLock.RLock()
defer s.clientLock.RUnlock()
conn, ok := s.connections[subscriptionID]
if !ok {
return nil, fmt.Errorf("invalid subscription_id")
}
return conn, nil
}
func (s *sqsBackend) Send(_ context.Context, send *event.SendRequest) error {
return s.sendSubscriptionEventInternal(send.SubscriptionID, send.EventJSON, false)
}
func (s *sqsBackend) refreshClient(subscriptionID string) error {
conn, err := s.getConn(subscriptionID)
if err != nil {
return err
}
client, err := newClient(conn.config)
if err != nil {
return err
}
s.clientLock.Lock()
defer s.clientLock.Unlock()
conn.client = client
// probably not necessary, but just in case
s.connections[subscriptionID] = conn
return nil
}
func (s *sqsBackend) sendSubscriptionEventInternal(subscriptionID string, eventJson string, isRetry bool) error {
conn, err := s.getConn(subscriptionID)
if err != nil {
return err
}
backoff := conn.config.NewRetryBackoff()
err = backoff.Retry(func() error {
_, err = conn.client.SendMessage(&sqs.SendMessageInput{
MessageBody: &eventJson,
QueueUrl: &conn.queueURL,
})
return err
})
if err != nil && !isRetry {
// refresh client and try again, once
err2 := s.refreshClient(subscriptionID)
if err2 != nil {
return errors.Join(err, err2)
}
return s.sendSubscriptionEventInternal(subscriptionID, eventJson, true)
} else if err != nil && isRetry {
s.killConnection(subscriptionID)
return err
}
return nil
}
func (s *sqsBackend) Unsubscribe(_ context.Context, request *event.UnsubscribeRequest) error {
s.killConnection(request.SubscriptionID)
return nil
}
func (s *sqsBackend) PluginMetadata() *event.PluginMetadata {
return &event.PluginMetadata{
Name: pluginName,
Version: version.GetVersion().Version,
}
}
func (s *sqsBackend) PluginVersion() logical.PluginVersion {
return logical.PluginVersion{
Version: version.GetVersion().Version,
}
}
func (s *sqsBackend) Close(_ context.Context) error {
s.clientLock.Lock()
defer s.clientLock.Unlock()
var subscriptions []string
for k := range s.connections {
subscriptions = append(subscriptions, k)
}
for _, subscription := range subscriptions {
s.killConnectionWithLock(subscription)
}
return nil
}

View File

@@ -1,118 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package sqs
import (
"context"
"os"
"testing"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/plugins/event"
"github.com/stretchr/testify/assert"
)
func getTestClient(t *testing.T) *sqs.Client {
awsConfig, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(os.Getenv("AWS_REGION")))
if err != nil {
t.Fatal(err)
}
return sqs.NewFromConfig(awsConfig)
}
func createQueue(t *testing.T, client *sqs.Client, queueName string) string {
resp, err := client.CreateQueue(context.Background(), &sqs.CreateQueueInput{
QueueName: &queueName,
})
if err != nil {
t.Fatal(err)
}
return *resp.QueueUrl
}
func deleteQueue(t *testing.T, client *sqs.Client, queueURL string) {
_, err := client.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{
QueueUrl: &queueURL,
})
if err != nil {
t.Fatal(err)
}
}
func receiveMessage(t *testing.T, client *sqs.Client, queueURL string) string {
resp, err := client.ReceiveMessage(context.Background(), &sqs.ReceiveMessageInput{
QueueUrl: &queueURL,
WaitTimeSeconds: 5,
})
if err != nil {
t.Fatal(err)
}
assert.Len(t, resp.Messages, 1)
msg := resp.Messages[0]
_, err = client.DeleteMessage(context.Background(), &sqs.DeleteMessageInput{
QueueUrl: &queueURL,
ReceiptHandle: msg.ReceiptHandle,
})
if err != nil {
t.Fatal(err)
}
return *msg.Body
}
// TestSQS_SendOneMessage tests that the plugin basic flow of subscribe/sendevent/unsubscribe will send a message to SQS.
func TestSQS_SendOneMessage(t *testing.T) {
region := os.Getenv("AWS_REGION")
if region == "" {
t.Skip("Must set AWS_REGION")
}
sqsClient := getTestClient(t)
temp, err := uuid.GenerateUUID()
assert.Nil(t, err)
tempQueueName := "event-sqs-test-" + temp
tempQueueURL := createQueue(t, sqsClient, tempQueueName)
t.Cleanup(func() {
deleteQueue(t, sqsClient, tempQueueURL)
})
backend, _ := New(nil)
subID, err := uuid.GenerateUUID()
assert.Nil(t, err)
err = backend.Subscribe(nil, &event.SubscribeRequest{
SubscriptionID: subID,
Config: map[string]interface{}{
"queue_name": tempQueueName,
"region": os.Getenv("AWS_REGION"),
"create_queue": true,
},
VerifyConnection: false,
})
assert.Nil(t, err)
// create another subscription with the same queue to make sure we are okay with using an existing queue
err = backend.Subscribe(nil, &event.SubscribeRequest{
SubscriptionID: subID + "2",
Config: map[string]interface{}{
"queue_name": tempQueueName,
"region": os.Getenv("AWS_REGION"),
"create_queue": true,
},
VerifyConnection: false,
})
assert.Nil(t, err)
err = backend.Send(nil, &event.SendRequest{
SubscriptionID: subID,
EventJSON: "{}",
})
assert.Nil(t, err)
msg := receiveMessage(t, sqsClient, tempQueueURL)
assert.Equal(t, "{}", msg)
err = backend.Unsubscribe(nil, &event.UnsubscribeRequest{SubscriptionID: subID})
assert.Nil(t, err)
}