mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			318 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			318 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package dynamodb
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"math/rand"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/go-test/deep"
 | |
| 	log "github.com/hashicorp/go-hclog"
 | |
| 	"github.com/hashicorp/vault/helper/logging"
 | |
| 	"github.com/hashicorp/vault/physical"
 | |
| 	"github.com/ory/dockertest"
 | |
| 
 | |
| 	"github.com/aws/aws-sdk-go/aws"
 | |
| 	"github.com/aws/aws-sdk-go/aws/credentials"
 | |
| 	"github.com/aws/aws-sdk-go/aws/session"
 | |
| 	"github.com/aws/aws-sdk-go/service/dynamodb"
 | |
| 	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
 | |
| )
 | |
| 
 | |
| func TestDynamoDBBackend(t *testing.T) {
 | |
| 	cleanup, endpoint, credsProvider := prepareDynamoDBTestContainer(t)
 | |
| 	defer cleanup()
 | |
| 
 | |
| 	creds, err := credsProvider.Get()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	region := os.Getenv("AWS_DEFAULT_REGION")
 | |
| 	if region == "" {
 | |
| 		region = "us-east-1"
 | |
| 	}
 | |
| 
 | |
| 	awsSession, err := session.NewSession(&aws.Config{
 | |
| 		Credentials: credsProvider,
 | |
| 		Endpoint:    aws.String(endpoint),
 | |
| 		Region:      aws.String(region),
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	conn := dynamodb.New(awsSession)
 | |
| 
 | |
| 	var randInt = rand.New(rand.NewSource(time.Now().UnixNano())).Int()
 | |
| 	table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt)
 | |
| 
 | |
| 	defer func() {
 | |
| 		conn.DeleteTable(&dynamodb.DeleteTableInput{
 | |
| 			TableName: aws.String(table),
 | |
| 		})
 | |
| 	}()
 | |
| 
 | |
| 	logger := logging.NewVaultLogger(log.Debug)
 | |
| 
 | |
| 	b, err := NewDynamoDBBackend(map[string]string{
 | |
| 		"access_key":    creds.AccessKeyID,
 | |
| 		"secret_key":    creds.SecretAccessKey,
 | |
| 		"session_token": creds.SessionToken,
 | |
| 		"table":         table,
 | |
| 		"region":        region,
 | |
| 		"endpoint":      endpoint,
 | |
| 	}, logger)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	physical.ExerciseBackend(t, b)
 | |
| 	physical.ExerciseBackend_ListPrefix(t, b)
 | |
| 
 | |
| 	t.Run("Marshalling upgrade", func(t *testing.T) {
 | |
| 		path := "test_key"
 | |
| 
 | |
| 		// Manually write to DynamoDB using the old ConvertTo function
 | |
| 		// for marshalling data
 | |
| 		inputEntry := &physical.Entry{
 | |
| 			Key:   path,
 | |
| 			Value: []byte{0x0f, 0xcf, 0x4a, 0x0f, 0xba, 0x2b, 0x15, 0xf0, 0xaa, 0x75, 0x09},
 | |
| 		}
 | |
| 
 | |
| 		record := DynamoDBRecord{
 | |
| 			Path:  recordPathForVaultKey(inputEntry.Key),
 | |
| 			Key:   recordKeyForVaultKey(inputEntry.Key),
 | |
| 			Value: inputEntry.Value,
 | |
| 		}
 | |
| 
 | |
| 		item, err := dynamodbattribute.ConvertToMap(record)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("err: %s", err)
 | |
| 		}
 | |
| 
 | |
| 		request := &dynamodb.PutItemInput{
 | |
| 			Item:      item,
 | |
| 			TableName: &table,
 | |
| 		}
 | |
| 		conn.PutItem(request)
 | |
| 
 | |
| 		// Read back the data using the normal interface which should
 | |
| 		// handle the old marshalling format gracefully
 | |
| 		entry, err := b.Get(context.Background(), path)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("err: %s", err)
 | |
| 		}
 | |
| 		if diff := deep.Equal(inputEntry, entry); diff != nil {
 | |
| 			t.Fatal(diff)
 | |
| 		}
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func TestDynamoDBHABackend(t *testing.T) {
 | |
| 	cleanup, endpoint, credsProvider := prepareDynamoDBTestContainer(t)
 | |
| 	defer cleanup()
 | |
| 
 | |
| 	creds, err := credsProvider.Get()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	region := os.Getenv("AWS_DEFAULT_REGION")
 | |
| 	if region == "" {
 | |
| 		region = "us-east-1"
 | |
| 	}
 | |
| 
 | |
| 	awsSession, err := session.NewSession(&aws.Config{
 | |
| 		Credentials: credsProvider,
 | |
| 		Endpoint:    aws.String(endpoint),
 | |
| 		Region:      aws.String(region),
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	conn := dynamodb.New(awsSession)
 | |
| 
 | |
| 	var randInt = rand.New(rand.NewSource(time.Now().UnixNano())).Int()
 | |
| 	table := fmt.Sprintf("vault-dynamodb-testacc-%d", randInt)
 | |
| 
 | |
| 	defer func() {
 | |
| 		conn.DeleteTable(&dynamodb.DeleteTableInput{
 | |
| 			TableName: aws.String(table),
 | |
| 		})
 | |
| 	}()
 | |
| 
 | |
| 	logger := logging.NewVaultLogger(log.Debug)
 | |
| 	b, err := NewDynamoDBBackend(map[string]string{
 | |
| 		"access_key":    creds.AccessKeyID,
 | |
| 		"secret_key":    creds.SecretAccessKey,
 | |
| 		"session_token": creds.SessionToken,
 | |
| 		"table":         table,
 | |
| 		"region":        region,
 | |
| 		"endpoint":      endpoint,
 | |
| 	}, logger)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	ha, ok := b.(physical.HABackend)
 | |
| 	if !ok {
 | |
| 		t.Fatalf("dynamodb does not implement HABackend")
 | |
| 	}
 | |
| 	physical.ExerciseHABackend(t, ha, ha)
 | |
| 	testDynamoDBLockTTL(t, ha)
 | |
| }
 | |
| 
 | |
| // Similar to testHABackend, but using internal implementation details to
 | |
| // trigger the lock failure scenario by setting the lock renew period for one
 | |
| // of the locks to a higher value than the lock TTL.
 | |
| func testDynamoDBLockTTL(t *testing.T, ha physical.HABackend) {
 | |
| 	// Set much smaller lock times to speed up the test.
 | |
| 	lockTTL := time.Second * 3
 | |
| 	renewInterval := time.Second * 1
 | |
| 	watchInterval := time.Second * 1
 | |
| 
 | |
| 	// Get the lock
 | |
| 	origLock, err := ha.LockWith("dynamodbttl", "bar")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	// set the first lock renew period to double the expected TTL.
 | |
| 	lock := origLock.(*DynamoDBLock)
 | |
| 	lock.renewInterval = lockTTL * 2
 | |
| 	lock.ttl = lockTTL
 | |
| 	lock.watchRetryInterval = watchInterval
 | |
| 
 | |
| 	// Attempt to lock
 | |
| 	leaderCh, err := lock.Lock(nil)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if leaderCh == nil {
 | |
| 		t.Fatalf("failed to get leader ch")
 | |
| 	}
 | |
| 
 | |
| 	// Check the value
 | |
| 	held, val, err := lock.Value()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if !held {
 | |
| 		t.Fatalf("should be held")
 | |
| 	}
 | |
| 	if val != "bar" {
 | |
| 		t.Fatalf("bad value: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Second acquisition should succeed because the first lock should
 | |
| 	// not renew within the 3 sec TTL.
 | |
| 	origLock2, err := ha.LockWith("dynamodbttl", "baz")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	lock2 := origLock2.(*DynamoDBLock)
 | |
| 	lock2.renewInterval = renewInterval
 | |
| 	lock2.ttl = lockTTL
 | |
| 	lock2.watchRetryInterval = watchInterval
 | |
| 
 | |
| 	// Cancel attempt in 6 sec so as not to block unit tests forever
 | |
| 	stopCh := make(chan struct{})
 | |
| 	time.AfterFunc(lockTTL*2, func() {
 | |
| 		close(stopCh)
 | |
| 	})
 | |
| 
 | |
| 	// Attempt to lock should work
 | |
| 	leaderCh2, err := lock2.Lock(stopCh)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if leaderCh2 == nil {
 | |
| 		t.Fatalf("should get leader ch")
 | |
| 	}
 | |
| 
 | |
| 	// Check the value
 | |
| 	held, val, err = lock2.Value()
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("err: %v", err)
 | |
| 	}
 | |
| 	if !held {
 | |
| 		t.Fatalf("should be held")
 | |
| 	}
 | |
| 	if val != "baz" {
 | |
| 		t.Fatalf("bad value: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// The first lock should have lost the leader channel
 | |
| 	leaderChClosed := false
 | |
| 	blocking := make(chan struct{})
 | |
| 	// Attempt to read from the leader or the blocking channel, which ever one
 | |
| 	// happens first.
 | |
| 	go func() {
 | |
| 		select {
 | |
| 		case <-time.After(watchInterval * 3):
 | |
| 			return
 | |
| 		case <-leaderCh:
 | |
| 			leaderChClosed = true
 | |
| 			close(blocking)
 | |
| 		case <-blocking:
 | |
| 			return
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	<-blocking
 | |
| 	if !leaderChClosed {
 | |
| 		t.Fatalf("original lock did not have its leader channel closed.")
 | |
| 	}
 | |
| 
 | |
| 	// Cleanup
 | |
| 	lock2.Unlock()
 | |
| }
 | |
| 
 | |
| func prepareDynamoDBTestContainer(t *testing.T) (cleanup func(), retAddress string, creds *credentials.Credentials) {
 | |
| 	// If environment variable is set, assume caller wants to target a real
 | |
| 	// DynamoDB.
 | |
| 	if os.Getenv("AWS_DYNAMODB_ENDPOINT") != "" {
 | |
| 		return func() {}, os.Getenv("AWS_DYNAMODB_ENDPOINT"), credentials.NewEnvCredentials()
 | |
| 	}
 | |
| 
 | |
| 	pool, err := dockertest.NewPool("")
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to connect to docker: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	resource, err := pool.Run("cnadiminti/dynamodb-local", "latest", []string{})
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Could not start local DynamoDB: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	retAddress = "http://localhost:" + resource.GetPort("8000/tcp")
 | |
| 	cleanup = func() {
 | |
| 		err := pool.Purge(resource)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("Failed to cleanup local DynamoDB: %s", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// exponential backoff-retry, because the DynamoDB may not be able to accept
 | |
| 	// connections yet
 | |
| 	if err := pool.Retry(func() error {
 | |
| 		var err error
 | |
| 		resp, err := http.Get(retAddress)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		if resp.StatusCode != 400 {
 | |
| 			return fmt.Errorf("expected DynamoDB to return status code 400, got (%s) instead", resp.Status)
 | |
| 		}
 | |
| 		return nil
 | |
| 	}); err != nil {
 | |
| 		t.Fatalf("Could not connect to docker: %s", err)
 | |
| 	}
 | |
| 	return cleanup, retAddress, credentials.NewStaticCredentials("fake", "fake", "")
 | |
| }
 | 
