Add support for cross account management of static roles in AWS Secrets (#29645)

* aws-secrets/add-cross-acc-mgmt-static-roles

* refactor

* add function pointer for tests

* delete commented out code

* update

* update comment

* update func name

* add flag

* remove docs
This commit is contained in:
Milena Zlaticanin
2025-02-14 14:13:00 -07:00
committed by GitHub
parent 64e92ba9fd
commit 6e0c771e57
12 changed files with 177 additions and 50 deletions

View File

@@ -12,6 +12,7 @@ import (
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/queue"
@@ -87,6 +88,10 @@ func Backend(_ *logical.BackendConfig) *backend {
type backend struct {
*framework.Backend
// Function pointer used to override the IAM client creation for mocked testing
// If set, this function will be called instead of creating real IAM clients
nonCachedClientIAMFunc func(context.Context, logical.Storage, hclog.Logger, *staticRoleEntry) (iamiface.IAMAPI, error)
// Mutex to protect access to reading and writing policies
roleMutex sync.RWMutex
@@ -131,8 +136,9 @@ func (b *backend) clearClients() {
}
// clientIAM returns the configured IAM client. If nil, it constructs a new one
// and returns it, setting it the internal variable
func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IAMAPI, error) {
// and returns it, setting it the internal variable.
// entry is only needed when configuring the client to use for role assumption.
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
b.clientMutex.RLock()
if b.iamClient != nil {
b.clientMutex.RUnlock()
@@ -150,10 +156,11 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IA
return b.iamClient, nil
}
iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger())
iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger(), entry)
if err != nil {
return nil, err
}
b.iamClient = iamClient
return b.iamClient, nil
@@ -248,3 +255,13 @@ func (b *backend) initialize(ctx context.Context, request *logical.Initializatio
}
return nil
}
// getNonCachedIAMClient returns an IAM client. In a test env, if a mocked client creation
// function is set (nonCachedClientIAMFunc), it will be used instead of the default client creation function.
// This allows us to mock AWS clients in tests.
func (b *backend) getNonCachedIAMClient(ctx context.Context, storage logical.Storage, cfg staticRoleEntry) (iamiface.IAMAPI, error) {
if b.nonCachedClientIAMFunc != nil {
return b.nonCachedClientIAMFunc(ctx, storage, b.Logger(), &cfg)
}
return b.nonCachedClientIAM(ctx, storage, b.Logger(), &cfg)
}

View File

@@ -148,21 +148,33 @@ func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientT
return configs, nil
}
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := b.getRootConfigs(ctx, s, "iam", logger)
if err != nil {
return nil, err
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (*iam.IAM, error) {
var awsConfig *aws.Config
var err error
if entry != nil && entry.AssumeRoleARN != "" {
awsConfig, err = b.assumeRoleStatic(ctx, s, entry)
if err != nil {
return nil, fmt.Errorf("failed to assume role %q: %w", entry.AssumeRoleARN, err)
}
} else {
configs, err := b.getRootConfigs(ctx, s, "iam", logger)
if err != nil {
return nil, err
}
if len(configs) != 1 {
return nil, errors.New("could not obtain aws config")
}
awsConfig = configs[0]
}
if len(awsConfig) != 1 {
return nil, errors.New("could not obtain aws config")
}
sess, err := session.NewSession(awsConfig[0])
sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, err
}
client := iam.New(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain iam client")
return nil, fmt.Errorf("could not obtain IAM client")
}
return client, nil
}

View File

@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package aws
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/vault/sdk/logical"
)
// assumeRoleStatic assumes an AWS role for cross-account static role management.
// It uses the role ARN and session name provided in the staticRoleEntry configuration
// to generate credentials for the assumed role.
func (b *backend) assumeRoleStatic(ctx context.Context, s logical.Storage, entry *staticRoleEntry) (*aws.Config, error) {
return nil, fmt.Errorf("cross-account static roles are only supported in Vault Enterprise")
}

View File

@@ -66,7 +66,7 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr
return nil, nil, nil
}
iamClient, err = b.clientIAM(ctx, s)
iamClient, err = b.clientIAM(ctx, s, nil)
if err != nil {
return nil, nil, err
}

View File

@@ -42,7 +42,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
func (b *backend) rotateRoot(ctx context.Context, req *logical.Request) (*logical.Response, error) {
// have to get the client config first because that takes out a read lock
client, err := b.clientIAM(ctx, req.Storage)
client, err := b.clientIAM(ctx, req.Storage, nil)
if err != nil {
return nil, err
}

View File

@@ -21,16 +21,22 @@ import (
const (
pathStaticRole = "static-roles"
paramRoleName = "name"
paramUsername = "username"
paramRotationPeriod = "rotation_period"
paramRoleName = "name"
paramUsername = "username"
paramRotationPeriod = "rotation_period"
paramAssumeRoleARN = "assume_role_arn"
paramRoleSessionName = "assume_role_session_name"
paramExternalID = "external_id"
)
type staticRoleEntry struct {
Name string `json:"name" structs:"name" mapstructure:"name"`
ID string `json:"id" structs:"id" mapstructure:"id"`
Username string `json:"username" structs:"username" mapstructure:"username"`
RotationPeriod time.Duration `json:"rotation_period" structs:"rotation_period" mapstructure:"rotation_period"`
Name string `json:"name" structs:"name" mapstructure:"name"`
ID string `json:"id" structs:"id" mapstructure:"id"`
Username string `json:"username" structs:"username" mapstructure:"username"`
RotationPeriod time.Duration `json:"rotation_period" structs:"rotation_period" mapstructure:"rotation_period"`
AssumeRoleARN string `json:"assume_role_arn" structs:"assume_role_arn" mapstructure:"assume_role_arn"`
AssumeRoleSessionName string `json:"assume_role_session_name" structs:"assume_role_session_name" mapstructure:"assume_role_session_name"`
ExternalID string `json:"external_id" structs:"external_id" mapstructure:"external_id"`
}
func pathStaticRoles(b *backend) *framework.Path {
@@ -53,23 +59,12 @@ func pathStaticRoles(b *backend) *framework.Path {
},
}},
}
fields := roleResponse[http.StatusOK][0].Fields
AddStaticAssumeRoleFieldsEnt(fields)
return &framework.Path{
Pattern: fmt.Sprintf("%s/%s", pathStaticRole, framework.GenericNameWithAtRegex(paramRoleName)),
Fields: map[string]*framework.FieldSchema{
paramRoleName: {
Type: framework.TypeString,
Description: descRoleName,
},
paramUsername: {
Type: framework.TypeString,
Description: descUsername,
},
paramRotationPeriod: {
Type: framework.TypeDurationSecond,
Description: descRotationPeriod,
},
},
Fields: fields,
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
@@ -159,6 +154,11 @@ func (b *backend) pathStaticRolesWrite(ctx context.Context, req *logical.Request
// other params are optional if we're not Creating
err = validateAssumeRoleFields(data, &config)
if err != nil {
return nil, err
}
if rawUsername, ok := data.GetOk(paramUsername); ok {
config.Username = rawUsername.(string)
@@ -299,10 +299,11 @@ func (b *backend) validateRoleName(name string) error {
// validateIAMUser checks the user information we have for the role against the information on AWS. On a create, it uses the username
// to retrieve the user information and _sets_ the userID. On update, it validates the userID and username.
func (b *backend) validateIAMUserExists(ctx context.Context, storage logical.Storage, entry *staticRoleEntry, isCreate bool) error {
c, err := b.clientIAM(ctx, storage)
c, err := b.getNonCachedIAMClient(ctx, storage, *entry)
if err != nil {
return fmt.Errorf("unable to validate username %q: %w", entry.Username, err)
return fmt.Errorf("unable to get client to validate username %q: %w", entry.Username, err)
}
b.iamClient = c
// we don't really care about the content of the result, just that it's not an error
out, err := c.GetUser(&iam.GetUserInput{
@@ -364,4 +365,7 @@ const (
descUsername = "The IAM user to adopt as a static role."
descRotationPeriod = `Period by which to rotate the backing credential of the adopted user.
This can be a Go duration (e.g, '1m', 24h'), or an integer number of seconds.`
descAssumeRoleARN = `The AWS ARN for the role to be assumed when interacting with the account specified.`
descRoleSessionName = `An identifier for the assumed role session.`
descExternalID = `An external ID to be passed to the assumed role session.`
)

View File

@@ -0,0 +1,29 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package aws
import (
"fmt"
"github.com/hashicorp/vault/sdk/framework"
)
// AddStaticAssumeRoleFieldsEnt is a no-op for community edition
func AddStaticAssumeRoleFieldsEnt(fields map[string]*framework.FieldSchema) {
// no-op
}
func validateAssumeRoleFields(data *framework.FieldData, config *staticRoleEntry) error {
_, hasAssumeRoleARN := data.GetOk(paramAssumeRoleARN)
_, hasRoleSessionName := data.GetOk(paramRoleSessionName)
_, hasExternalID := data.GetOk(paramExternalID)
if hasAssumeRoleARN || hasRoleSessionName || hasExternalID {
return fmt.Errorf("cross-account static roles are only supported in Vault Enterprise")
}
return nil
}

View File

@@ -11,6 +11,8 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
@@ -97,7 +99,10 @@ func TestStaticRolesValidation(t *testing.T) {
if err != nil {
t.Fatal(err)
}
b.iamClient = miam
// Used to override the real IAM client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}
if err := b.Setup(bgCTX, config); err != nil {
t.Fatal(err)
}
@@ -241,7 +246,10 @@ func TestStaticRolesWrite(t *testing.T) {
}
b := Backend(config)
b.iamClient = miam
// Used to override the real IAM client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}
if err := b.Setup(bgCTX, config); err != nil {
t.Fatal(err)
}
@@ -454,7 +462,10 @@ func TestStaticRoleDelete(t *testing.T) {
}
b := Backend(config)
b.iamClient = miam
// Used to override the real IAM client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}
// put in storage
staticRole := staticRoleEntry{

View File

@@ -175,7 +175,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
username := entry.UserName
// Get the client
client, err := b.clientIAM(ctx, req.Storage)
client, err := b.clientIAM(ctx, req.Storage, nil)
if err != nil {
return err
}

View File

@@ -59,6 +59,7 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)
return false, nil
}
b.Logger().Debug("rotating credential", "role", item.Key)
cfg := item.Value.(staticRoleEntry)
creds, err := b.createCredential(ctx, storage, cfg, true)
@@ -86,9 +87,10 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)
// createCredential will create a new iam credential, deleting the oldest one if necessary.
func (b *backend) createCredential(ctx context.Context, storage logical.Storage, cfg staticRoleEntry, shouldLockStorage bool) (*awsCredentials, error) {
iamClient, err := b.clientIAM(ctx, storage)
// Always create a fresh client
iamClient, err := b.getNonCachedIAMClient(ctx, storage, cfg)
if err != nil {
return nil, fmt.Errorf("unable to get the AWS IAM client: %w", err)
return nil, fmt.Errorf("failed to get IAM client for role %q: %w", cfg.Name, err)
}
// IAM users can have a most 2 sets of keys at a time.
@@ -190,8 +192,13 @@ func (b *backend) deleteCredential(ctx context.Context, storage logical.Storage,
return fmt.Errorf("couldn't delete from storage: %w", err)
}
iamClient, err := b.nonCachedClientIAM(ctx, storage, b.Logger(), &cfg)
if err != nil {
return fmt.Errorf("failed to get IAM client for role %q while deleting: %w", cfg.Name, err)
}
// because we have the information, this is the one we created, so it's safe for us to delete.
_, err = b.iamClient.DeleteAccessKey(&iam.DeleteAccessKeyInput{
_, err = iamClient.DeleteAccessKey(&iam.DeleteAccessKeyInput{
AccessKeyId: aws.String(creds.AccessKeyID),
UserName: aws.String(cfg.Username),
})

View File

@@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/testhelpers"
@@ -146,7 +147,14 @@ func TestRotation(t *testing.T) {
if err != nil {
t.Fatalf("couldn't initialze mock IAM handler: %s", err)
}
b.iamClient = miam
// Used to override the IAM client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, storage logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
if entry.Username == cred.config.Username && entry.ID == cred.config.ID {
return miam, nil
}
return nil, fmt.Errorf("unexpected IAM client creation for user %q", entry.Username)
}
c, err := b.createCredential(bgCTX, config.StorageView, cred.config, true)
if err != nil {
@@ -192,7 +200,11 @@ func TestRotation(t *testing.T) {
if err != nil {
t.Fatalf("couldn't initialze mock IAM handler: %s", err)
}
b.iamClient = miam
// Set the IAM mock client to be used in the rotation
b.nonCachedClientIAMFunc = func(ctx context.Context, storage logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}
req := &logical.Request{
Storage: config.StorageView,
@@ -340,7 +352,10 @@ func TestCreateCredential(t *testing.T) {
}
b := Backend(config)
b.iamClient = fiam
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return fiam, nil
}
_, err = b.createCredential(context.Background(), config.StorageView, staticRoleEntry{Username: c.username, ID: c.id}, true)
if err != nil {
@@ -403,7 +418,10 @@ func TestRequeueOnError(t *testing.T) {
t.Fail()
}
b.iamClient = miam
// Used to override the IAM real client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}
_, err = b.createCredential(bgCTX, config.StorageView, cred, true)
if err != nil {
@@ -428,7 +446,9 @@ func TestRequeueOnError(t *testing.T) {
if err != nil {
t.Fatalf("couldn't initialize the mock iam: %s", err)
}
b.iamClient = miam
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}
// now rotate, but it will fail
r, e := b.rotateCredential(bgCTX, config.StorageView)
@@ -501,6 +521,12 @@ func Test_RotationQueueInitialized(t *testing.T) {
b := Backend(config)
b.iamClient = mockClient
b.minAllowableRotationPeriod = 1 * time.Second
// Used to override the IAM real client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, storage logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return mockClient, nil
}
err := b.Setup(ctx, config)
return b, err
},

View File

@@ -361,7 +361,7 @@ func (b *backend) secretAccessKeysCreate(
displayName, policyName string,
role *awsRoleEntry,
) (*logical.Response, error) {
iamClient, err := b.clientIAM(ctx, s)
iamClient, err := b.clientIAM(ctx, s, nil)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}