VAULT-30877: Repopulate AWS static creds queue in initialize (#28775)

* populate rotation queue in initialize

* docs, changelog

* add t.Helper()
This commit is contained in:
miagilepner
2024-11-04 16:32:14 +01:00
committed by GitHub
parent e489631e87
commit 10bd15f956
9 changed files with 293 additions and 30 deletions

View File

@@ -5,6 +5,7 @@ package aws
import (
"context"
"fmt"
"strings"
"sync"
"time"
@@ -33,6 +34,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,
func Backend(_ *logical.BackendConfig) *backend {
var b backend
b.minAllowableRotationPeriod = minAllowableRotationPeriod
b.credRotationQueue = queue.New()
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
@@ -62,6 +64,7 @@ func Backend(_ *logical.BackendConfig) *backend {
secretAccessKeys(&b),
},
InitializeFunc: b.initialize,
Invalidate: b.invalidate,
WALRollback: b.walRollback,
WALRollbackMinAge: minAwsUserRollbackAge,
@@ -94,6 +97,8 @@ type backend struct {
// the age of a static role's credential is tracked by a priority queue and handled
// by the PeriodicFunc
credRotationQueue *queue.PriorityQueue
minAllowableRotationPeriod time.Duration
}
const backendHelp = `
@@ -176,3 +181,66 @@ func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.ST
return b.stsClient, nil
}
// initialize is registered as the backend's InitializeFunc and rebuilds the
// in-memory credential rotation queue from storage.
//
// When the replication state is not write-safe the queue is left empty and
// nil is returned. Otherwise every entry listed under pathStaticCreds is
// loaded together with its role configuration and pushed onto
// b.credRotationQueue, prioritised by the credential's expiration. Entries
// whose credentials or role configuration are missing from storage are
// skipped.
//
// Credentials that carry no expiration are assigned one (now + the role's
// rotation period) and the updated credentials are written back to storage so
// the same deadline is used the next time the queue is rebuilt.
//
// Any storage, decoding or queue error aborts initialization and is returned.
func (b *backend) initialize(ctx context.Context, request *logical.InitializationRequest) error {
	if !b.WriteSafeReplicationState() {
		b.Logger().Info("skipping populating rotation queue")
		return nil
	}
	b.Logger().Info("populating rotation queue")

	creds, err := request.Storage.List(ctx, pathStaticCreds+"/")
	if err != nil {
		return err
	}
	b.Logger().Debug(fmt.Sprintf("Adding %d items to the rotation queue", len(creds)))

	for _, roleName := range creds {
		if roleName == "" {
			continue
		}

		credPath := formatCredsStoragePath(roleName)
		credsEntry, err := request.Storage.Get(ctx, credPath)
		if err != nil {
			return fmt.Errorf("could not read credentials: %w", err)
		}
		if credsEntry == nil {
			continue
		}
		credentials := awsCredentials{}
		if err := credsEntry.DecodeJSON(&credentials); err != nil {
			return fmt.Errorf("failed to decode credentials: %w", err)
		}

		configEntry, err := request.Storage.Get(ctx, formatRoleStoragePath(roleName))
		if err != nil {
			return fmt.Errorf("could not read role: %w", err)
		}
		if configEntry == nil {
			continue
		}
		config := staticRoleEntry{}
		if err := configEntry.DecodeJSON(&config); err != nil {
			return fmt.Errorf("failed to decode role config: %w", err)
		}

		if credentials.Expiration == nil {
			expiration := time.Now().UTC().Add(config.RotationPeriod)
			credentials.Expiration = &expiration

			// Persist the credentials themselves (not the list of role names)
			// so the freshly assigned expiration survives the next restart or
			// leadership change instead of being pushed out again.
			// NOTE(review): no role lock is taken here; confirm that
			// initialize cannot race with request handling for this mount.
			updatedEntry, err := logical.StorageEntryJSON(credPath, &credentials)
			if err != nil {
				return fmt.Errorf("failed to marshal object to JSON: %w", err)
			}
			if err := request.Storage.Put(ctx, updatedEntry); err != nil {
				return fmt.Errorf("failed to save object in storage: %w", err)
			}
			b.Logger().Debug("no known expiration time for credentials so resetting the expiration", "role", roleName, "new expiration", expiration)
		}

		err = b.credRotationQueue.Push(&queue.Item{
			Key:      config.Name,
			Value:    config,
			Priority: credentials.priority(config),
		})
		if err != nil {
			return fmt.Errorf("failed to add creds for role %s to queue: %w", roleName, err)
		}
	}
	return nil
}

View File

@@ -7,6 +7,7 @@ import (
"context"
"fmt"
"net/http"
"time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/sdk/framework"
@@ -21,8 +22,9 @@ const (
)
type awsCredentials struct {
AccessKeyID string `json:"access_key" structs:"access_key" mapstructure:"access_key"`
SecretAccessKey string `json:"secret_key" structs:"secret_key" mapstructure:"secret_key"`
AccessKeyID string `json:"access_key" structs:"access_key" mapstructure:"access_key"`
Expiration *time.Time `json:"expiration,omitempty" structs:"expiration" mapstructure:"expiration"`
SecretAccessKey string `json:"secret_key" structs:"secret_key" mapstructure:"secret_key"`
}
func pathStaticCredentials(b *backend) *framework.Path {
@@ -89,6 +91,13 @@ func formatCredsStoragePath(roleName string) string {
return fmt.Sprintf("%s/%s", pathStaticCreds, roleName)
}
// priority returns the rotation-queue priority for these credentials as a
// Unix timestamp in seconds. The stored expiration is used when one is set;
// otherwise the priority is the current time plus the role's rotation period.
func (a *awsCredentials) priority(role staticRoleEntry) int64 {
	if a.Expiration == nil {
		fallback := time.Now().Add(role.RotationPeriod)
		return fallback.Unix()
	}
	return a.Expiration.Unix()
}
const pathStaticCredsHelpSyn = `Retrieve static credentials from the named role.`
const pathStaticCredsHelpDesc = `

View File

@@ -7,10 +7,12 @@ import (
"context"
"reflect"
"testing"
"time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestStaticCredsRead verifies that we can correctly read a cred that exists, and correctly _not read_
@@ -91,3 +93,22 @@ func staticCredsFieldData(data map[string]interface{}) *framework.FieldData {
Schema: schema,
}
}
// Test_awsCredentials_priority checks both branches of awsCredentials.priority:
// credentials with an expiration report that expiration (as Unix seconds),
// while credentials without one report roughly now + the role's rotation
// period.
func Test_awsCredentials_priority(t *testing.T) {
	role := staticRoleEntry{RotationPeriod: time.Hour}
	fixedExpiry := time.Date(2023, 10, 24, 15, 21, 0, 0, time.UTC)

	t.Run("use credential value", func(t *testing.T) {
		withExpiry := &awsCredentials{Expiration: &fixedExpiry}
		got := withExpiry.priority(role)
		require.Equal(t, fixedExpiry.Unix(), got)
	})

	t.Run("use role value", func(t *testing.T) {
		want := time.Now().Add(time.Hour).Unix()
		withoutExpiry := &awsCredentials{}
		// priority reads the clock itself, so allow up to a minute of drift
		// (expressed in seconds) between the two readings.
		toleranceSeconds := float64(time.Minute / time.Second)
		require.InDelta(t, want, withoutExpiry.priority(role), toleranceSeconds)
	})
}

View File

@@ -194,12 +194,13 @@ func (b *backend) pathStaticRolesWrite(ctx context.Context, req *logical.Request
// Bootstrap initial set of keys if they did not exist before. AWS Secret Access Keys can only be obtained on creation,
// so we need to bootstrap new roles with a new initial set of keys to be able to serve valid credentials to Vault clients.
existingCreds, err := req.Storage.Get(ctx, formatCredsStoragePath(config.Name))
credsPath := formatCredsStoragePath(config.Name)
existingCredsEntry, err := req.Storage.Get(ctx, credsPath)
if err != nil {
return nil, fmt.Errorf("unable to verify if credentials already exist for role %q: %w", config.Name, err)
}
if existingCreds == nil {
err := b.createCredential(ctx, req.Storage, config, false)
if existingCredsEntry == nil {
creds, err := b.createCredential(ctx, req.Storage, config, false)
if err != nil {
return nil, fmt.Errorf("failed to create new credentials for role %q: %w", config.Name, err)
}
@@ -207,12 +208,17 @@ func (b *backend) pathStaticRolesWrite(ctx context.Context, req *logical.Request
err = b.credRotationQueue.Push(&queue.Item{
Key: config.Name,
Value: config,
Priority: time.Now().Add(config.RotationPeriod).Unix(),
Priority: creds.priority(config),
})
if err != nil {
return nil, fmt.Errorf("failed to add item into the rotation queue for role %q: %w", config.Name, err)
}
} else {
var existingCreds awsCredentials
err := existingCredsEntry.DecodeJSON(&existingCreds)
if err != nil {
return nil, fmt.Errorf("unable to decode existing credentials for role %s: %w", config.Name, err)
}
// creds already exist, so all we need to do is update the rotation
// what here stays the same and what changes? Can we change the name?
i, err := b.credRotationQueue.PopByKey(config.Name)
@@ -227,7 +233,14 @@ func (b *backend) pathStaticRolesWrite(ctx context.Context, req *logical.Request
}
i.Value = config
// update the next rotation to occur at now + the new rotation period
i.Priority = time.Now().Add(config.RotationPeriod).Unix()
newExpiration := time.Now().Add(config.RotationPeriod)
existingCreds.Expiration = &newExpiration
_, err = logical.StorageEntryJSON(credsPath, &existingCreds)
if err != nil {
return nil, fmt.Errorf("error updating credentials for role %s: %w", config.Name, err)
}
i.Priority = existingCreds.priority(config)
err = b.credRotationQueue.Push(i)
if err != nil {
return nil, fmt.Errorf("failed to add updated item into the rotation queue for role %q: %w", config.Name, err)
@@ -318,8 +331,8 @@ const (
)
func (b *backend) validateRotationPeriod(period time.Duration) error {
if period < minAllowableRotationPeriod {
return fmt.Errorf("role rotation period out of range: must be greater than %.2f seconds", minAllowableRotationPeriod.Seconds())
if period < b.minAllowableRotationPeriod {
return fmt.Errorf("role rotation period out of range: must be greater than %.2f seconds", b.minAllowableRotationPeriod.Seconds())
}
return nil
}

View File

@@ -61,8 +61,9 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)
cfg := item.Value.(staticRoleEntry)
err = b.createCredential(ctx, storage, cfg, true)
creds, err := b.createCredential(ctx, storage, cfg, true)
if err != nil {
b.Logger().Error("failed to create credential, re-queueing", "error", err)
// put it back in the queue with a backoff
item.Priority = time.Now().Add(10 * time.Second).Unix()
innerErr := b.credRotationQueue.Push(item)
@@ -74,7 +75,7 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)
}
// set new priority and re-queue
item.Priority = time.Now().Add(cfg.RotationPeriod).Unix()
item.Priority = creds.priority(cfg)
err = b.credRotationQueue.Push(item)
if err != nil {
return true, fmt.Errorf("failed to add item into the rotation queue for role %q: %w", cfg.Name, err)
@@ -84,10 +85,10 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)
}
// createCredential will create a new iam credential, deleting the oldest one if necessary.
func (b *backend) createCredential(ctx context.Context, storage logical.Storage, cfg staticRoleEntry, shouldLockStorage bool) error {
func (b *backend) createCredential(ctx context.Context, storage logical.Storage, cfg staticRoleEntry, shouldLockStorage bool) (*awsCredentials, error) {
iamClient, err := b.clientIAM(ctx, storage)
if err != nil {
return fmt.Errorf("unable to get the AWS IAM client: %w", err)
return nil, fmt.Errorf("unable to get the AWS IAM client: %w", err)
}
// IAM users can have at most 2 sets of keys at a time.
@@ -97,14 +98,14 @@ func (b *backend) createCredential(ctx context.Context, storage logical.Storage,
err = b.validateIAMUserExists(ctx, storage, &cfg, false)
if err != nil {
return fmt.Errorf("iam user didn't exist, or username/userid didn't match: %w", err)
return nil, fmt.Errorf("iam user didn't exist, or username/userid didn't match: %w", err)
}
accessKeys, err := iamClient.ListAccessKeys(&iam.ListAccessKeysInput{
UserName: aws.String(cfg.Username),
})
if err != nil {
return fmt.Errorf("unable to list existing access keys for IAM user %q: %w", cfg.Username, err)
return nil, fmt.Errorf("unable to list existing access keys for IAM user %q: %w", cfg.Username, err)
}
// If we have the maximum number of keys, we have to delete one to make another (so we can get the credentials).
@@ -127,7 +128,7 @@ func (b *backend) createCredential(ctx context.Context, storage logical.Storage,
UserName: oldestKey.UserName,
})
if err != nil {
return fmt.Errorf("unable to delete oldest access keys for user %q: %w", cfg.Username, err)
return nil, fmt.Errorf("unable to delete oldest access keys for user %q: %w", cfg.Username, err)
}
}
@@ -136,16 +137,19 @@ func (b *backend) createCredential(ctx context.Context, storage logical.Storage,
UserName: aws.String(cfg.Username),
})
if err != nil {
return fmt.Errorf("unable to create new access keys for user %q: %w", cfg.Username, err)
return nil, fmt.Errorf("unable to create new access keys for user %q: %w", cfg.Username, err)
}
expiration := time.Now().UTC().Add(cfg.RotationPeriod)
// Persist new keys
entry, err := logical.StorageEntryJSON(formatCredsStoragePath(cfg.Name), &awsCredentials{
creds := &awsCredentials{
AccessKeyID: *out.AccessKey.AccessKeyId,
SecretAccessKey: *out.AccessKey.SecretAccessKey,
})
Expiration: &expiration,
}
// Persist new keys
entry, err := logical.StorageEntryJSON(formatCredsStoragePath(cfg.Name), creds)
if err != nil {
return fmt.Errorf("failed to marshal object to JSON: %w", err)
return nil, fmt.Errorf("failed to marshal object to JSON: %w", err)
}
if shouldLockStorage {
b.roleMutex.Lock()
@@ -153,10 +157,10 @@ func (b *backend) createCredential(ctx context.Context, storage logical.Storage,
}
err = storage.Put(ctx, entry)
if err != nil {
return fmt.Errorf("failed to save object in storage: %w", err)
return nil, fmt.Errorf("failed to save object in storage: %w", err)
}
return nil
return creds, nil
}
// delete credential will remove the credential associated with the role from storage.

View File

@@ -6,6 +6,8 @@ package aws
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
@@ -13,8 +15,13 @@ import (
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/testhelpers"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/queue"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require"
)
// TestRotation verifies that the rotation code and priority queue correctly selects and rotates credentials
@@ -109,6 +116,7 @@ func TestRotation(t *testing.T) {
b := Backend(config)
expirations := make([]*time.Time, len(c.creds))
// insert all our creds
for i, cred := range c.creds {
@@ -140,11 +148,12 @@ func TestRotation(t *testing.T) {
}
b.iamClient = miam
err = b.createCredential(bgCTX, config.StorageView, cred.config, true)
c, err := b.createCredential(bgCTX, config.StorageView, cred.config, true)
if err != nil {
t.Fatalf("couldn't insert credential %d: %s", i, err)
}
expirations[i] = c.Expiration
item := &queue.Item{
Key: cred.config.Name,
Value: cred.config,
@@ -205,10 +214,12 @@ func TestRotation(t *testing.T) {
t.Fatalf("could not unmarshal storage view entry for cred %d to an aws credential: %s", i, err)
}
if cred.changed && out.SecretAccessKey != newSecret {
t.Fatalf("expected the key for cred %d to have changed, but it hasn't", i)
} else if !cred.changed && out.SecretAccessKey != oldSecret {
t.Fatalf("expected the key for cred %d to have stayed the same, but it changed", i)
if cred.changed {
require.Equal(t, out.SecretAccessKey, newSecret, "expected the key for cred %d to have changed, but it hasn't", i)
require.NotEqual(t, out.Expiration.UTC(), expirations[i].UTC(), "expected the expiration for cred %d to have changed, but it hasn't", i)
} else {
require.Equal(t, out.SecretAccessKey, oldSecret, "expected the key for cred %d to have stayed the same, but it changed", i)
require.Equal(t, out.Expiration.UTC(), expirations[i].UTC(), "expected the expiration for cred %d to have changed, but it hasn't", i)
}
}
})
@@ -331,7 +342,7 @@ func TestCreateCredential(t *testing.T) {
b := Backend(config)
b.iamClient = fiam
err = b.createCredential(context.Background(), config.StorageView, staticRoleEntry{Username: c.username, ID: c.id}, true)
_, err = b.createCredential(context.Background(), config.StorageView, staticRoleEntry{Username: c.username, ID: c.id}, true)
if err != nil {
t.Fatalf("got an error we didn't expect: %q", err)
}
@@ -394,7 +405,7 @@ func TestRequeueOnError(t *testing.T) {
b.iamClient = miam
err = b.createCredential(bgCTX, config.StorageView, cred, true)
_, err = b.createCredential(bgCTX, config.StorageView, cred, true)
if err != nil {
t.Fatalf("couldn't insert credential: %s", err)
}
@@ -437,3 +448,135 @@ func TestRequeueOnError(t *testing.T) {
t.Fatalf("priority should be within 5 seconds of our backoff interval")
}
}
// mockIAM is a test double for the IAM API. It embeds iamiface.IAMAPI and
// overrides GetUser, ListAccessKeys and CreateAccessKey with canned,
// in-memory behaviour.
type mockIAM struct {
iamiface.IAMAPI
// mapping username -> number of times CreateAccessKey has been queried
// for this user
newKeys map[string]int
// l is held by CreateAccessKey while it updates and reads newKeys.
l sync.Mutex
}
// GetUser always succeeds, returning a user that echoes the requested user
// name and carries an empty user ID.
func (m *mockIAM) GetUser(input *iam.GetUserInput) (*iam.GetUserOutput, error) {
	user := &iam.User{
		UserId:   aws.String(""),
		UserName: input.UserName,
	}
	return &iam.GetUserOutput{User: user}, nil
}
// ListAccessKeys always reports exactly one access key for the requested
// user, with the ID "<username>-key".
func (m *mockIAM) ListAccessKeys(input *iam.ListAccessKeysInput) (*iam.ListAccessKeysOutput, error) {
	keyID := fmt.Sprintf("%s-key", *input.UserName)
	metadata := []*iam.AccessKeyMetadata{
		{AccessKeyId: aws.String(keyID)},
	}
	return &iam.ListAccessKeysOutput{AccessKeyMetadata: metadata}, nil
}
// CreateAccessKey increments the per-user creation counter and returns a key
// whose ID is "<username>-key" and whose secret is "<username>-<count>", so
// every call for the same user yields a different secret.
func (m *mockIAM) CreateAccessKey(input *iam.CreateAccessKeyInput) (*iam.CreateAccessKeyOutput, error) {
	m.l.Lock()
	defer m.l.Unlock()

	user := *input.UserName
	m.newKeys[user]++
	generation := m.newKeys[user]

	key := &iam.AccessKey{
		AccessKeyId:     aws.String(fmt.Sprintf("%s-key", user)),
		SecretAccessKey: aws.String(fmt.Sprintf("%s-%d", user, generation)),
	}
	return &iam.CreateAccessKeyOutput{AccessKey: key}, nil
}
// Test_RotationQueueInitialized creates a 2 node cluster and sets up the AWS
// credentials backend. The test creates 3 sets of static credentials. Two of
// those have a low rotation period and should get rotated during the test. The
// third has a high rotation period and should not be rotated. The test verifies
// that the correct secrets are rotated, then transfers leadership to the other
// node. The test verifies that credentials are once again rotated on the new
// active node.
func Test_RotationQueueInitialized(t *testing.T) {
	mockClient := &mockIAM{
		newKeys: make(map[string]int),
	}
	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"aws": func(ctx context.Context, config *logical.BackendConfig) (logical.Backend, error) {
				b := Backend(config)
				b.iamClient = mockClient
				// allow rotation periods short enough for the test to observe
				b.minAllowableRotationPeriod = 1 * time.Second
				err := b.Setup(ctx, config)
				return b, err
			},
		},
		RollbackPeriod: 1 * time.Second,
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
		NumCores:    2,
	})
	cluster.Start()
	defer cluster.Cleanup()

	cores := cluster.Cores
	vault.TestWaitActive(t, cores[0].Core)
	client := cores[0].Client

	err := client.Sys().Mount("aws", &api.MountInput{
		Type: "aws",
	})
	require.NoError(t, err)

	// create 3 static roles with different rotation periods
	roles := []struct {
		name, username, rotationPeriod string
	}{
		{name: "role1", username: "user1", rotationPeriod: "2s"},
		{name: "role2", username: "user2", rotationPeriod: "1s"},
		{name: "role3", username: "user3", rotationPeriod: "5m"},
	}
	for _, role := range roles {
		_, err = client.Logical().Write("aws/static-roles/"+role.name, map[string]interface{}{
			"username":        role.username,
			"rotation_period": role.rotationPeriod,
		})
		require.NoError(t, err)
	}

	// getSecret reads the current secret key for the role and fails the test
	// with a clear message (rather than panicking) if the response is missing
	// or malformed.
	getSecret := func(c *api.Client, role string) string {
		r, err := c.Logical().Read("aws/static-creds/" + role)
		require.NoError(t, err)
		require.NotNil(t, r, "no credentials returned for role %s", role)
		secret, ok := r.Data["secret_key"].(string)
		require.True(t, ok, "secret_key for role %s is not a string", role)
		return secret
	}
	role1Secret := getSecret(client, "role1")
	role2Secret := getSecret(client, "role2")
	role3Secret := getSecret(client, "role3")

	// verifySecretsRotated waits until both short-period secrets differ from
	// the supplied originals while the long-period secret is unchanged, and
	// returns the new short-period secrets.
	verifySecretsRotated := func(c *api.Client, originalRole1Secret, originalRole2Secret, originalRole3Secret string) (updatedRole1Secret, updatedRole2Secret string) {
		testhelpers.RetryUntil(t, 5*time.Second, func() error {
			// verify that both secrets with a low rotation period get rotated;
			// keep retrying while either one still has its original value
			updatedRole1Secret = getSecret(c, "role1")
			updatedRole2Secret = getSecret(c, "role2")
			if originalRole1Secret == updatedRole1Secret || originalRole2Secret == updatedRole2Secret {
				return fmt.Errorf("secrets haven't been rotated")
			}
			// verify that the secret with a high rotation period doesn't get
			// rotated
			updatedRole3Secret := getSecret(c, "role3")
			if updatedRole3Secret != originalRole3Secret {
				return fmt.Errorf("secret has been rotated but should not have been")
			}
			return nil
		})
		return
	}
	role1Secret, role2Secret = verifySecretsRotated(client, role1Secret, role2Secret, role3Secret)

	// seal core 0 so that core 1 becomes the active node
	cores[0].Seal(t)

	// verify that the correct secrets get rotated again
	vault.TestWaitActive(t, cores[1].Core)
	verifySecretsRotated(cores[1].Client, role1Secret, role2Secret, role3Secret)
}

3
changelog/28775.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:bug
secrets/aws: Fix issue with static credentials not rotating after restart or leadership change.
```

View File

@@ -723,6 +723,7 @@ func SetNonRootToken(client *api.Client) error {
func RetryUntilAtCadence(t testing.TB, timeout, sleepTime time.Duration, f func() error) {
t.Helper()
fail := func(err error) {
t.Helper()
t.Fatalf("did not complete before deadline, err: %v", err)
}
RetryUntilAtCadenceWithHandler(t, timeout, sleepTime, fail, f)

View File

@@ -791,6 +791,7 @@ $ curl \
```json
{
"access_key": "AKIA...",
"expiration": "2024-10-25T15:02:10Z",
"secret_key": "..."
}
```