diff --git a/builtin/logical/aws/backend.go b/builtin/logical/aws/backend.go index 521ab5be09..464601aa04 100644 --- a/builtin/logical/aws/backend.go +++ b/builtin/logical/aws/backend.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/service/iam/iamiface" "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/queue" @@ -87,6 +88,10 @@ func Backend(_ *logical.BackendConfig) *backend { type backend struct { *framework.Backend + // Function pointer used to override the IAM client creation for mocked testing + // If set, this function will be called instead of creating real IAM clients + nonCachedClientIAMFunc func(context.Context, logical.Storage, hclog.Logger, *staticRoleEntry) (iamiface.IAMAPI, error) + // Mutex to protect access to reading and writing policies roleMutex sync.RWMutex @@ -131,8 +136,9 @@ func (b *backend) clearClients() { } // clientIAM returns the configured IAM client. If nil, it constructs a new one -// and returns it, setting it the internal variable -func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IAMAPI, error) { +// and returns it, setting it the internal variable. +// entry is only needed when configuring the client to use for role assumption. +func (b *backend) clientIAM(ctx context.Context, s logical.Storage, entry *staticRoleEntry) (iamiface.IAMAPI, error) { b.clientMutex.RLock() if b.iamClient != nil { b.clientMutex.RUnlock() @@ -150,10 +156,11 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IA return b.iamClient, nil } - iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger()) + iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger(), entry) if err != nil { return nil, err } + b.iamClient = iamClient return b.iamClient, nil @@ -248,3 +255,13 @@ func (b *backend) initialize(ctx context.Context, request *logical.Initializatio } return nil } + +// getNonCachedIAMClient returns an IAM client. In a test env, if a mocked client creation +// function is set (nonCachedClientIAMFunc), it will be used instead of the default client creation function. +// This allows us to mock AWS clients in tests. +func (b *backend) getNonCachedIAMClient(ctx context.Context, storage logical.Storage, cfg staticRoleEntry) (iamiface.IAMAPI, error) { + if b.nonCachedClientIAMFunc != nil { + return b.nonCachedClientIAMFunc(ctx, storage, b.Logger(), &cfg) + } + return b.nonCachedClientIAM(ctx, storage, b.Logger(), &cfg) +} diff --git a/builtin/logical/aws/client.go b/builtin/logical/aws/client.go index 4891666eae..d80e294068 100644 --- a/builtin/logical/aws/client.go +++ b/builtin/logical/aws/client.go @@ -148,21 +148,33 @@ func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientT return configs, nil } -func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) { - awsConfig, err := b.getRootConfigs(ctx, s, "iam", logger) - if err != nil { - return nil, err +func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (*iam.IAM, error) { + var awsConfig *aws.Config + var err error + + if entry != nil && entry.AssumeRoleARN != "" { + awsConfig, err = b.assumeRoleStatic(ctx, s, entry) + if err != nil { + return nil, fmt.Errorf("failed to assume role %q: %w", entry.AssumeRoleARN, err) + } + } else { + configs, err := b.getRootConfigs(ctx, s, "iam", logger) + if err != nil { + return nil, err + } + if len(configs) != 1 { + return nil, errors.New("could not obtain aws config") + } + awsConfig = configs[0] } - if len(awsConfig) != 1 { - return nil, errors.New("could not obtain aws config") - } - sess, err := session.NewSession(awsConfig[0]) + + sess, err := session.NewSession(awsConfig) if err != nil { return nil, err } client := iam.New(sess) if client == nil { - return nil, fmt.Errorf("could not obtain iam client") + return nil, fmt.Errorf("could not obtain IAM client") } return client, nil } diff --git a/builtin/logical/aws/client_ce.go b/builtin/logical/aws/client_ce.go new file mode 100644 index 0000000000..0baaeddefb --- /dev/null +++ b/builtin/logical/aws/client_ce.go @@ -0,0 +1,21 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package aws + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go/aws" + "github.com/hashicorp/vault/sdk/logical" +) + +// assumeRoleStatic assumes an AWS role for cross-account static role management. +// It uses the role ARN and session name provided in the staticRoleEntry configuration +// to generate credentials for the assumed role. +func (b *backend) assumeRoleStatic(ctx context.Context, s logical.Storage, entry *staticRoleEntry) (*aws.Config, error) { + return nil, fmt.Errorf("cross-account static roles are only supported in Vault Enterprise") +} diff --git a/builtin/logical/aws/iam_policies.go b/builtin/logical/aws/iam_policies.go index 9735a2af81..c6cbcee99c 100644 --- a/builtin/logical/aws/iam_policies.go +++ b/builtin/logical/aws/iam_policies.go @@ -66,7 +66,7 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr return nil, nil, nil } - iamClient, err = b.clientIAM(ctx, s) + iamClient, err = b.clientIAM(ctx, s, nil) if err != nil { return nil, nil, err } diff --git a/builtin/logical/aws/path_config_rotate_root.go b/builtin/logical/aws/path_config_rotate_root.go index 371b03714b..952b99b47a 100644 --- a/builtin/logical/aws/path_config_rotate_root.go +++ b/builtin/logical/aws/path_config_rotate_root.go @@ -42,7 +42,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R func (b *backend) rotateRoot(ctx context.Context, req *logical.Request) (*logical.Response, error) { // have to get the client config first because that takes out a read lock - client, err := b.clientIAM(ctx, req.Storage) + client, err := b.clientIAM(ctx, req.Storage, nil) if err != nil { return nil, err } diff --git a/builtin/logical/aws/path_static_roles.go b/builtin/logical/aws/path_static_roles.go index 3057fabb3a..6612adc4d9 100644 --- a/builtin/logical/aws/path_static_roles.go +++ b/builtin/logical/aws/path_static_roles.go @@ -21,16 +21,22 @@ import ( const ( pathStaticRole = "static-roles" - paramRoleName = "name" - paramUsername = "username" - paramRotationPeriod = "rotation_period" + paramRoleName = "name" + paramUsername = "username" + paramRotationPeriod = "rotation_period" + paramAssumeRoleARN = "assume_role_arn" + paramRoleSessionName = "assume_role_session_name" + paramExternalID = "external_id" ) type staticRoleEntry struct { - Name string `json:"name" structs:"name" mapstructure:"name"` - ID string `json:"id" structs:"id" mapstructure:"id"` - Username string `json:"username" structs:"username" mapstructure:"username"` - RotationPeriod time.Duration `json:"rotation_period" structs:"rotation_period" mapstructure:"rotation_period"` + Name string `json:"name" structs:"name" mapstructure:"name"` + ID string `json:"id" structs:"id" mapstructure:"id"` + Username string `json:"username" structs:"username" mapstructure:"username"` + RotationPeriod time.Duration `json:"rotation_period" structs:"rotation_period" mapstructure:"rotation_period"` + AssumeRoleARN string `json:"assume_role_arn" structs:"assume_role_arn" mapstructure:"assume_role_arn"` + AssumeRoleSessionName string `json:"assume_role_session_name" structs:"assume_role_session_name" mapstructure:"assume_role_session_name"` + ExternalID string `json:"external_id" structs:"external_id" mapstructure:"external_id"` } func pathStaticRoles(b *backend) *framework.Path { @@ -53,23 +59,12 @@ func pathStaticRoles(b *backend) *framework.Path { }, }}, } + fields := roleResponse[http.StatusOK][0].Fields + AddStaticAssumeRoleFieldsEnt(fields) return &framework.Path{ Pattern: fmt.Sprintf("%s/%s", pathStaticRole, framework.GenericNameWithAtRegex(paramRoleName)), - Fields: map[string]*framework.FieldSchema{ - paramRoleName: { - Type: framework.TypeString, - Description: descRoleName, - }, - paramUsername: { - Type: framework.TypeString, - Description: descUsername, - }, - paramRotationPeriod: { - Type: framework.TypeDurationSecond, - Description: descRotationPeriod, - }, - }, + Fields: fields, Operations: map[logical.Operation]framework.OperationHandler{ logical.ReadOperation: &framework.PathOperation{ @@ -159,6 +154,11 @@ func (b *backend) pathStaticRolesWrite(ctx context.Context, req *logical.Request // other params are optional if we're not Creating + err = validateAssumeRoleFields(data, &config) + if err != nil { + return nil, err + } + if rawUsername, ok := data.GetOk(paramUsername); ok { config.Username = rawUsername.(string) @@ -299,10 +299,11 @@ func (b *backend) validateRoleName(name string) error { // validateIAMUser checks the user information we have for the role against the information on AWS. On a create, it uses the username // to retrieve the user information and _sets_ the userID. On update, it validates the userID and username. func (b *backend) validateIAMUserExists(ctx context.Context, storage logical.Storage, entry *staticRoleEntry, isCreate bool) error { - c, err := b.clientIAM(ctx, storage) + c, err := b.getNonCachedIAMClient(ctx, storage, *entry) if err != nil { - return fmt.Errorf("unable to validate username %q: %w", entry.Username, err) + return fmt.Errorf("unable to get client to validate username %q: %w", entry.Username, err) } + b.iamClient = c // we don't really care about the content of the result, just that it's not an error out, err := c.GetUser(&iam.GetUserInput{ @@ -364,4 +365,7 @@ const ( descUsername = "The IAM user to adopt as a static role." descRotationPeriod = `Period by which to rotate the backing credential of the adopted user. This can be a Go duration (e.g, '1m', 24h'), or an integer number of seconds.` + descAssumeRoleARN = `The AWS ARN for the role to be assumed when interacting with the account specified.` + descRoleSessionName = `An identifier for the assumed role session.` + descExternalID = `An external ID to be passed to the assumed role session.` ) diff --git a/builtin/logical/aws/path_static_roles_ce.go b/builtin/logical/aws/path_static_roles_ce.go new file mode 100644 index 0000000000..7ea873a57b --- /dev/null +++ b/builtin/logical/aws/path_static_roles_ce.go @@ -0,0 +1,29 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package aws + +import ( + "fmt" + + "github.com/hashicorp/vault/sdk/framework" +) + +// AddStaticAssumeRoleFieldsEnt is a no-op for community edition +func AddStaticAssumeRoleFieldsEnt(fields map[string]*framework.FieldSchema) { + // no-op +} + +func validateAssumeRoleFields(data *framework.FieldData, config *staticRoleEntry) error { + _, hasAssumeRoleARN := data.GetOk(paramAssumeRoleARN) + _, hasRoleSessionName := data.GetOk(paramRoleSessionName) + _, hasExternalID := data.GetOk(paramExternalID) + + if hasAssumeRoleARN || hasRoleSessionName || hasExternalID { + return fmt.Errorf("cross-account static roles are only supported in Vault Enterprise") + } + + return nil +} diff --git a/builtin/logical/aws/path_static_roles_test.go b/builtin/logical/aws/path_static_roles_test.go index 0244d6a39c..54a365acb0 100644 --- a/builtin/logical/aws/path_static_roles_test.go +++ b/builtin/logical/aws/path_static_roles_test.go @@ -11,6 +11,8 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/iam/iamiface" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/awsutil" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" @@ -97,7 +99,10 @@ func TestStaticRolesValidation(t *testing.T) { if err != nil { t.Fatal(err) } - b.iamClient = miam + // Used to override the real IAM client creation to return the mocked client + b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + return miam, nil + } if err := b.Setup(bgCTX, config); err != nil { t.Fatal(err) } @@ -241,7 +246,10 @@ func TestStaticRolesWrite(t *testing.T) { } b := Backend(config) - b.iamClient = miam + // Used to override the real IAM client creation to return the mocked client + b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + return miam, nil + } if err := b.Setup(bgCTX, config); err != nil { t.Fatal(err) } @@ -454,7 +462,10 @@ func TestStaticRoleDelete(t *testing.T) { } b := Backend(config) - b.iamClient = miam + // Used to override the real IAM client creation to return the mocked client + b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + return miam, nil + } // put in storage staticRole := staticRoleEntry{ diff --git a/builtin/logical/aws/path_user.go b/builtin/logical/aws/path_user.go index 430f7754ee..fb8c41c5f0 100644 --- a/builtin/logical/aws/path_user.go +++ b/builtin/logical/aws/path_user.go @@ -175,7 +175,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k username := entry.UserName // Get the client - client, err := b.clientIAM(ctx, req.Storage) + client, err := b.clientIAM(ctx, req.Storage, nil) if err != nil { return err } diff --git a/builtin/logical/aws/rotation.go b/builtin/logical/aws/rotation.go index cc81169d6e..61ecdddf79 100644 --- a/builtin/logical/aws/rotation.go +++ b/builtin/logical/aws/rotation.go @@ -59,6 +59,7 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage) return false, nil } + b.Logger().Debug("rotating credential", "role", item.Key) cfg := item.Value.(staticRoleEntry) creds, err := b.createCredential(ctx, storage, cfg, true) @@ -86,9 +87,10 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage) // createCredential will create a new iam credential, deleting the oldest one if necessary. func (b *backend) createCredential(ctx context.Context, storage logical.Storage, cfg staticRoleEntry, shouldLockStorage bool) (*awsCredentials, error) { - iamClient, err := b.clientIAM(ctx, storage) + // Always create a fresh client + iamClient, err := b.getNonCachedIAMClient(ctx, storage, cfg) if err != nil { - return nil, fmt.Errorf("unable to get the AWS IAM client: %w", err) + return nil, fmt.Errorf("failed to get IAM client for role %q: %w", cfg.Name, err) } // IAM users can have a most 2 sets of keys at a time. @@ -190,8 +192,13 @@ func (b *backend) deleteCredential(ctx context.Context, storage logical.Storage, return fmt.Errorf("couldn't delete from storage: %w", err) } + iamClient, err := b.nonCachedClientIAM(ctx, storage, b.Logger(), &cfg) + if err != nil { + return fmt.Errorf("failed to get IAM client for role %q while deleting: %w", cfg.Name, err) + } + // because we have the information, this is the one we created, so it's safe for us to delete. - _, err = b.iamClient.DeleteAccessKey(&iam.DeleteAccessKeyInput{ + _, err = iamClient.DeleteAccessKey(&iam.DeleteAccessKeyInput{ AccessKeyId: aws.String(creds.AccessKeyID), UserName: aws.String(cfg.Username), }) diff --git a/builtin/logical/aws/rotation_test.go b/builtin/logical/aws/rotation_test.go index bff89b2a11..46e73beda4 100644 --- a/builtin/logical/aws/rotation_test.go +++ b/builtin/logical/aws/rotation_test.go @@ -14,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/awsutil" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/testhelpers" @@ -146,7 +147,14 @@ func TestRotation(t *testing.T) { if err != nil { t.Fatalf("couldn't initialze mock IAM handler: %s", err) } - b.iamClient = miam + + // Used to override the IAM client creation to return the mocked client + b.nonCachedClientIAMFunc = func(ctx context.Context, storage logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + if entry.Username == cred.config.Username && entry.ID == cred.config.ID { + return miam, nil + } + return nil, fmt.Errorf("unexpected IAM client creation for user %q", entry.Username) + } c, err := b.createCredential(bgCTX, config.StorageView, cred.config, true) if err != nil { @@ -192,7 +200,11 @@ func TestRotation(t *testing.T) { if err != nil { t.Fatalf("couldn't initialze mock IAM handler: %s", err) } - b.iamClient = miam + + // Set the IAM mock client to be used in the rotation + b.nonCachedClientIAMFunc = func(ctx context.Context, storage logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + return miam, nil + } req := &logical.Request{ Storage: config.StorageView, @@ -340,7 +352,10 @@ func TestCreateCredential(t *testing.T) { } b := Backend(config) - b.iamClient = fiam + + b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + return fiam, nil + } _, err = b.createCredential(context.Background(), config.StorageView, staticRoleEntry{Username: c.username, ID: c.id}, true) if err != nil { @@ -403,7 +418,10 @@ func TestRequeueOnError(t *testing.T) { t.Fail() } - b.iamClient = miam + // Used to override the IAM real client creation to return the mocked client + b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + return miam, nil + } _, err = b.createCredential(bgCTX, config.StorageView, cred, true) if err != nil { @@ -428,7 +446,9 @@ func TestRequeueOnError(t *testing.T) { if err != nil { t.Fatalf("couldn't initialize the mock iam: %s", err) } - b.iamClient = miam + b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + return miam, nil + } // now rotate, but it will fail r, e := b.rotateCredential(bgCTX, config.StorageView) @@ -501,6 +521,12 @@ func Test_RotationQueueInitialized(t *testing.T) { b := Backend(config) b.iamClient = mockClient b.minAllowableRotationPeriod = 1 * time.Second + + // Used to override the IAM real client creation to return the mocked client + b.nonCachedClientIAMFunc = func(ctx context.Context, storage logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) { + return mockClient, nil + } + err := b.Setup(ctx, config) return b, err }, diff --git a/builtin/logical/aws/secret_access_keys.go b/builtin/logical/aws/secret_access_keys.go index a9a9290cc5..93624e1619 100644 --- a/builtin/logical/aws/secret_access_keys.go +++ b/builtin/logical/aws/secret_access_keys.go @@ -361,7 +361,7 @@ func (b *backend) secretAccessKeysCreate( displayName, policyName string, role *awsRoleEntry, ) (*logical.Response, error) { - iamClient, err := b.clientIAM(ctx, s) + iamClient, err := b.clientIAM(ctx, s, nil) if err != nil { return logical.ErrorResponse(err.Error()), nil }