aws: pass cancelable context with aws calls (#19365)

* auth/aws: use cancelable context with aws calls

* secrets/aws: use cancelable context with aws calls
This commit is contained in:
Mason Foster
2023-03-23 13:02:24 -04:00
committed by GitHub
parent 85c3eab989
commit cfff8d420e
12 changed files with 68 additions and 44 deletions

View File

@@ -312,7 +312,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
switch entity.Type { switch entity.Type {
case "user": case "user":
userInfo, err := iamClient.GetUser(&iam.GetUserInput{UserName: &entity.FriendlyName}) userInfo, err := iamClient.GetUserWithContext(ctx, &iam.GetUserInput{UserName: &entity.FriendlyName})
if err != nil { if err != nil {
return "", awsutil.AppendAWSError(err) return "", awsutil.AppendAWSError(err)
} }
@@ -321,7 +321,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
} }
return *userInfo.User.UserId, nil return *userInfo.User.UserId, nil
case "role": case "role":
roleInfo, err := iamClient.GetRole(&iam.GetRoleInput{RoleName: &entity.FriendlyName}) roleInfo, err := iamClient.GetRoleWithContext(ctx, &iam.GetRoleInput{RoleName: &entity.FriendlyName})
if err != nil { if err != nil {
return "", awsutil.AppendAWSError(err) return "", awsutil.AppendAWSError(err)
} }
@@ -330,7 +330,7 @@ func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storag
} }
return *roleInfo.Role.RoleId, nil return *roleInfo.Role.RoleId, nil
case "instance-profile": case "instance-profile":
profileInfo, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName}) profileInfo, err := iamClient.GetInstanceProfileWithContext(ctx, &iam.GetInstanceProfileInput{InstanceProfileName: &entity.FriendlyName})
if err != nil { if err != nil {
return "", awsutil.AppendAWSError(err) return "", awsutil.AppendAWSError(err)
} }

View File

@@ -122,7 +122,7 @@ func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region
return nil, fmt.Errorf("could not obtain sts client: %w", err) return nil, fmt.Errorf("could not obtain sts client: %w", err)
} }
inputParams := &sts.GetCallerIdentityInput{} inputParams := &sts.GetCallerIdentityInput{}
identity, err := client.GetCallerIdentity(inputParams) identity, err := client.GetCallerIdentityWithContext(ctx, inputParams)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to fetch current caller: %w", err) return nil, fmt.Errorf("unable to fetch current caller: %w", err)
} }

View File

@@ -100,7 +100,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
// Get the current user's name since it's required to create an access key. // Get the current user's name since it's required to create an access key.
// Empty input means get the current user. // Empty input means get the current user.
var getUserInput iam.GetUserInput var getUserInput iam.GetUserInput
getUserRes, err := iamClient.GetUser(&getUserInput) getUserRes, err := iamClient.GetUserWithContext(ctx, &getUserInput)
if err != nil { if err != nil {
return nil, fmt.Errorf("error calling GetUser: %w", err) return nil, fmt.Errorf("error calling GetUser: %w", err)
} }
@@ -118,7 +118,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
createAccessKeyInput := iam.CreateAccessKeyInput{ createAccessKeyInput := iam.CreateAccessKeyInput{
UserName: getUserRes.User.UserName, UserName: getUserRes.User.UserName,
} }
createAccessKeyRes, err := iamClient.CreateAccessKey(&createAccessKeyInput) createAccessKeyRes, err := iamClient.CreateAccessKeyWithContext(ctx, &createAccessKeyInput)
if err != nil { if err != nil {
return nil, fmt.Errorf("error calling CreateAccessKey: %w", err) return nil, fmt.Errorf("error calling CreateAccessKey: %w", err)
} }
@@ -142,7 +142,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
AccessKeyId: createAccessKeyRes.AccessKey.AccessKeyId, AccessKeyId: createAccessKeyRes.AccessKey.AccessKeyId,
UserName: getUserRes.User.UserName, UserName: getUserRes.User.UserName,
} }
if _, err := iamClient.DeleteAccessKey(&deleteAccessKeyInput); err != nil { if _, err := iamClient.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput); err != nil {
// Include this error in the errs returned by this method. // Include this error in the errs returned by this method.
errs = multierror.Append(errs, fmt.Errorf("error deleting newly created but unstored access key ID %s: %s", *createAccessKeyRes.AccessKey.AccessKeyId, err)) errs = multierror.Append(errs, fmt.Errorf("error deleting newly created but unstored access key ID %s: %s", *createAccessKeyRes.AccessKey.AccessKeyId, err))
} }
@@ -179,7 +179,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
AccessKeyId: aws.String(oldAccessKey), AccessKeyId: aws.String(oldAccessKey),
UserName: getUserRes.User.UserName, UserName: getUserRes.User.UserName,
} }
if _, err = iamClient.DeleteAccessKey(&deleteAccessKeyInput); err != nil { if _, err = iamClient.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput); err != nil {
errs = multierror.Append(errs, fmt.Errorf("error deleting old access key ID %s: %w", oldAccessKey, err)) errs = multierror.Append(errs, fmt.Errorf("error deleting old access key ID %s: %w", oldAccessKey, err))
return nil, errs return nil, errs
} }

View File

@@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface" "github.com/aws/aws-sdk-go/service/iam/iamiface"
@@ -15,9 +16,23 @@ import (
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
type mockIAMClient awsutil.MockIAM
func (m *mockIAMClient) GetUserWithContext(_ aws.Context, input *iam.GetUserInput, _ ...request.Option) (*iam.GetUserOutput, error) {
return (*awsutil.MockIAM)(m).GetUser(input)
}
func (m *mockIAMClient) CreateAccessKeyWithContext(_ aws.Context, input *iam.CreateAccessKeyInput, _ ...request.Option) (*iam.CreateAccessKeyOutput, error) {
return (*awsutil.MockIAM)(m).CreateAccessKey(input)
}
func (m *mockIAMClient) DeleteAccessKeyWithContext(_ aws.Context, input *iam.DeleteAccessKeyInput, _ ...request.Option) (*iam.DeleteAccessKeyOutput, error) {
return (*awsutil.MockIAM)(m).DeleteAccessKey(input)
}
func TestPathConfigRotateRoot(t *testing.T) { func TestPathConfigRotateRoot(t *testing.T) {
getIAMClient = func(sess *session.Session) iamiface.IAMAPI { getIAMClient = func(sess *session.Session) iamiface.IAMAPI {
return &awsutil.MockIAM{ return &mockIAMClient{
CreateAccessKeyOutput: &iam.CreateAccessKeyOutput{ CreateAccessKeyOutput: &iam.CreateAccessKeyOutput{
AccessKey: &iam.AccessKey{ AccessKey: &iam.AccessKey{
AccessKeyId: aws.String("fizz2"), AccessKeyId: aws.String("fizz2"),

View File

@@ -106,8 +106,8 @@ This must match the request body included in the signature.`,
"iam_request_headers": { "iam_request_headers": {
Type: framework.TypeHeader, Type: framework.TypeHeader,
Description: `Key/value pairs of headers for use in the Description: `Key/value pairs of headers for use in the
sts:GetCallerIdentity HTTP requests headers when auth_type is iam. Can be either sts:GetCallerIdentity HTTP requests headers when auth_type is iam. Can be either
a Base64-encoded, JSON-serialized string, or a JSON object of key/value pairs. a Base64-encoded, JSON-serialized string, or a JSON object of key/value pairs.
This must at a minimum include the headers over which AWS has included a signature.`, This must at a minimum include the headers over which AWS has included a signature.`,
}, },
"identity": { "identity": {
@@ -340,7 +340,7 @@ func (b *backend) pathLoginResolveRoleIam(ctx context.Context, req *logical.Requ
// instanceIamRoleARN fetches the IAM role ARN associated with the given // instanceIamRoleARN fetches the IAM role ARN associated with the given
// instance profile name // instance profile name
func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName string) (string, error) { func (b *backend) instanceIamRoleARN(ctx context.Context, iamClient *iam.IAM, instanceProfileName string) (string, error) {
if iamClient == nil { if iamClient == nil {
return "", fmt.Errorf("nil iamClient") return "", fmt.Errorf("nil iamClient")
} }
@@ -348,7 +348,7 @@ func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName str
return "", fmt.Errorf("missing instance profile name") return "", fmt.Errorf("missing instance profile name")
} }
profile, err := iamClient.GetInstanceProfile(&iam.GetInstanceProfileInput{ profile, err := iamClient.GetInstanceProfileWithContext(ctx, &iam.GetInstanceProfileInput{
InstanceProfileName: aws.String(instanceProfileName), InstanceProfileName: aws.String(instanceProfileName),
}) })
if err != nil { if err != nil {
@@ -382,7 +382,7 @@ func (b *backend) validateInstance(ctx context.Context, s logical.Storage, insta
return nil, err return nil, err
} }
status, err := ec2Client.DescribeInstances(&ec2.DescribeInstancesInput{ status, err := ec2Client.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{
InstanceIds: []*string{ InstanceIds: []*string{
aws.String(instanceID), aws.String(instanceID),
}, },
@@ -724,7 +724,7 @@ func (b *backend) verifyInstanceMeetsRoleRequirements(ctx context.Context,
} else if iamClient == nil { } else if iamClient == nil {
return nil, fmt.Errorf("received a nil iamClient") return nil, fmt.Errorf("received a nil iamClient")
} }
iamRoleARN, err := b.instanceIamRoleARN(iamClient, iamInstanceProfileEntity.FriendlyName) iamRoleARN, err := b.instanceIamRoleARN(ctx, iamClient, iamInstanceProfileEntity.FriendlyName)
if err != nil { if err != nil {
return nil, fmt.Errorf("IAM role ARN could not be fetched: %w", err) return nil, fmt.Errorf("IAM role ARN could not be fetched: %w", err)
} }
@@ -1835,7 +1835,7 @@ func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage)
input := iam.GetUserInput{ input := iam.GetUserInput{
UserName: aws.String(e.FriendlyName), UserName: aws.String(e.FriendlyName),
} }
resp, err := client.GetUser(&input) resp, err := client.GetUserWithContext(ctx, &input)
if err != nil { if err != nil {
return "", fmt.Errorf("error fetching user %q: %w", e.FriendlyName, err) return "", fmt.Errorf("error fetching user %q: %w", e.FriendlyName, err)
} }
@@ -1849,7 +1849,7 @@ func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage)
input := iam.GetRoleInput{ input := iam.GetRoleInput{
RoleName: aws.String(e.FriendlyName), RoleName: aws.String(e.FriendlyName),
} }
resp, err := client.GetRole(&input) resp, err := client.GetRoleWithContext(ctx, &input)
if err != nil { if err != nil {
return "", fmt.Errorf("error fetching role %q: %w", e.FriendlyName, err) return "", fmt.Errorf("error fetching role %q: %w", e.FriendlyName, err)
} }

View File

@@ -19,6 +19,7 @@ import (
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2"
@@ -39,7 +40,7 @@ type mockIAMClient struct {
iamiface.IAMAPI iamiface.IAMAPI
} }
func (m *mockIAMClient) CreateUser(input *iam.CreateUserInput) (*iam.CreateUserOutput, error) { func (m *mockIAMClient) CreateUserWithContext(_ aws.Context, input *iam.CreateUserInput, _ ...request.Option) (*iam.CreateUserOutput, error) {
return nil, awserr.New("Throttling", "", nil) return nil, awserr.New("Throttling", "", nil)
} }

View File

@@ -73,7 +73,7 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr
for _, g := range iamGroups { for _, g := range iamGroups {
// Collect managed policy ARNs from the IAM Group // Collect managed policy ARNs from the IAM Group
agp, err = iamClient.ListAttachedGroupPolicies(&iam.ListAttachedGroupPoliciesInput{ agp, err = iamClient.ListAttachedGroupPoliciesWithContext(ctx, &iam.ListAttachedGroupPoliciesInput{
GroupName: aws.String(g), GroupName: aws.String(g),
}) })
if err != nil { if err != nil {
@@ -84,14 +84,14 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr
} }
// Collect inline policy names from the IAM Group // Collect inline policy names from the IAM Group
inlinePolicies, err = iamClient.ListGroupPolicies(&iam.ListGroupPoliciesInput{ inlinePolicies, err = iamClient.ListGroupPoliciesWithContext(ctx, &iam.ListGroupPoliciesInput{
GroupName: aws.String(g), GroupName: aws.String(g),
}) })
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
for _, iP := range inlinePolicies.PolicyNames { for _, iP := range inlinePolicies.PolicyNames {
inlinePolicyDoc, err = iamClient.GetGroupPolicy(&iam.GetGroupPolicyInput{ inlinePolicyDoc, err = iamClient.GetGroupPolicyWithContext(ctx, &iam.GetGroupPolicyInput{
GroupName: &g, GroupName: &g,
PolicyName: iP, PolicyName: iP,
}) })

View File

@@ -8,6 +8,7 @@ import (
"testing" "testing"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface" "github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@@ -29,15 +30,15 @@ type mockGroupIAMClient struct {
GetGroupPolicyResp iam.GetGroupPolicyOutput GetGroupPolicyResp iam.GetGroupPolicyOutput
} }
func (m mockGroupIAMClient) ListAttachedGroupPolicies(in *iam.ListAttachedGroupPoliciesInput) (*iam.ListAttachedGroupPoliciesOutput, error) { func (m mockGroupIAMClient) ListAttachedGroupPoliciesWithContext(_ aws.Context, in *iam.ListAttachedGroupPoliciesInput, _ ...request.Option) (*iam.ListAttachedGroupPoliciesOutput, error) {
return &m.ListAttachedGroupPoliciesResp, nil return &m.ListAttachedGroupPoliciesResp, nil
} }
func (m mockGroupIAMClient) ListGroupPolicies(in *iam.ListGroupPoliciesInput) (*iam.ListGroupPoliciesOutput, error) { func (m mockGroupIAMClient) ListGroupPoliciesWithContext(_ aws.Context, in *iam.ListGroupPoliciesInput, _ ...request.Option) (*iam.ListGroupPoliciesOutput, error) {
return &m.ListGroupPoliciesResp, nil return &m.ListGroupPoliciesResp, nil
} }
func (m mockGroupIAMClient) GetGroupPolicy(in *iam.GetGroupPolicyInput) (*iam.GetGroupPolicyOutput, error) { func (m mockGroupIAMClient) GetGroupPolicyWithContext(_ aws.Context, in *iam.GetGroupPolicyInput, _ ...request.Option) (*iam.GetGroupPolicyOutput, error) {
return &m.GetGroupPolicyResp, nil return &m.GetGroupPolicyResp, nil
} }

View File

@@ -59,7 +59,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
} }
var getUserInput iam.GetUserInput // empty input means get current user var getUserInput iam.GetUserInput // empty input means get current user
getUserRes, err := client.GetUser(&getUserInput) getUserRes, err := client.GetUserWithContext(ctx, &getUserInput)
if err != nil { if err != nil {
return nil, fmt.Errorf("error calling GetUser: %w", err) return nil, fmt.Errorf("error calling GetUser: %w", err)
} }
@@ -76,7 +76,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
createAccessKeyInput := iam.CreateAccessKeyInput{ createAccessKeyInput := iam.CreateAccessKeyInput{
UserName: getUserRes.User.UserName, UserName: getUserRes.User.UserName,
} }
createAccessKeyRes, err := client.CreateAccessKey(&createAccessKeyInput) createAccessKeyRes, err := client.CreateAccessKeyWithContext(ctx, &createAccessKeyInput)
if err != nil { if err != nil {
return nil, fmt.Errorf("error calling CreateAccessKey: %w", err) return nil, fmt.Errorf("error calling CreateAccessKey: %w", err)
} }
@@ -107,7 +107,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R
AccessKeyId: aws.String(oldAccessKey), AccessKeyId: aws.String(oldAccessKey),
UserName: getUserRes.User.UserName, UserName: getUserRes.User.UserName,
} }
_, err = client.DeleteAccessKey(&deleteAccessKeyInput) _, err = client.DeleteAccessKeyWithContext(ctx, &deleteAccessKeyInput)
if err != nil { if err != nil {
return nil, fmt.Errorf("error deleting old access key: %w", err) return nil, fmt.Errorf("error deleting old access key: %w", err)
} }

View File

@@ -155,7 +155,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
} }
// Get information about this user // Get information about this user
groupsResp, err := client.ListGroupsForUser(&iam.ListGroupsForUserInput{ groupsResp, err := client.ListGroupsForUserWithContext(ctx, &iam.ListGroupsForUserInput{
UserName: aws.String(username), UserName: aws.String(username),
MaxItems: aws.Int64(1000), MaxItems: aws.Int64(1000),
}) })
@@ -194,7 +194,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
groups := groupsResp.Groups groups := groupsResp.Groups
// Inline (user) policies // Inline (user) policies
policiesResp, err := client.ListUserPolicies(&iam.ListUserPoliciesInput{ policiesResp, err := client.ListUserPoliciesWithContext(ctx, &iam.ListUserPoliciesInput{
UserName: aws.String(username), UserName: aws.String(username),
MaxItems: aws.Int64(1000), MaxItems: aws.Int64(1000),
}) })
@@ -204,7 +204,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
policies := policiesResp.PolicyNames policies := policiesResp.PolicyNames
// Attached managed policies // Attached managed policies
manPoliciesResp, err := client.ListAttachedUserPolicies(&iam.ListAttachedUserPoliciesInput{ manPoliciesResp, err := client.ListAttachedUserPoliciesWithContext(ctx, &iam.ListAttachedUserPoliciesInput{
UserName: aws.String(username), UserName: aws.String(username),
MaxItems: aws.Int64(1000), MaxItems: aws.Int64(1000),
}) })
@@ -213,7 +213,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
} }
manPolicies := manPoliciesResp.AttachedPolicies manPolicies := manPoliciesResp.AttachedPolicies
keysResp, err := client.ListAccessKeys(&iam.ListAccessKeysInput{ keysResp, err := client.ListAccessKeysWithContext(ctx, &iam.ListAccessKeysInput{
UserName: aws.String(username), UserName: aws.String(username),
MaxItems: aws.Int64(1000), MaxItems: aws.Int64(1000),
}) })
@@ -224,7 +224,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
// Revoke all keys // Revoke all keys
for _, k := range keys { for _, k := range keys {
_, err = client.DeleteAccessKey(&iam.DeleteAccessKeyInput{ _, err = client.DeleteAccessKeyWithContext(ctx, &iam.DeleteAccessKeyInput{
AccessKeyId: k.AccessKeyId, AccessKeyId: k.AccessKeyId,
UserName: aws.String(username), UserName: aws.String(username),
}) })
@@ -235,7 +235,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
// Detach managed policies // Detach managed policies
for _, p := range manPolicies { for _, p := range manPolicies {
_, err = client.DetachUserPolicy(&iam.DetachUserPolicyInput{ _, err = client.DetachUserPolicyWithContext(ctx, &iam.DetachUserPolicyInput{
UserName: aws.String(username), UserName: aws.String(username),
PolicyArn: p.PolicyArn, PolicyArn: p.PolicyArn,
}) })
@@ -246,7 +246,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
// Delete any inline (user) policies // Delete any inline (user) policies
for _, p := range policies { for _, p := range policies {
_, err = client.DeleteUserPolicy(&iam.DeleteUserPolicyInput{ _, err = client.DeleteUserPolicyWithContext(ctx, &iam.DeleteUserPolicyInput{
UserName: aws.String(username), UserName: aws.String(username),
PolicyName: p, PolicyName: p,
}) })
@@ -257,7 +257,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
// Remove the user from all their groups // Remove the user from all their groups
for _, g := range groups { for _, g := range groups {
_, err = client.RemoveUserFromGroup(&iam.RemoveUserFromGroupInput{ _, err = client.RemoveUserFromGroupWithContext(ctx, &iam.RemoveUserFromGroupInput{
GroupName: g.GroupName, GroupName: g.GroupName,
UserName: aws.String(username), UserName: aws.String(username),
}) })
@@ -267,7 +267,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
} }
// Delete the user // Delete the user
_, err = client.DeleteUser(&iam.DeleteUserInput{ _, err = client.DeleteUserWithContext(ctx, &iam.DeleteUserInput{
UserName: aws.String(username), UserName: aws.String(username),
}) })
if err != nil { if err != nil {

View File

@@ -153,7 +153,7 @@ func (b *backend) getFederationToken(ctx context.Context, s logical.Storage,
return logical.ErrorResponse("must specify at least one of policy_arns or policy_document with %s credential_type", federationTokenCred), nil return logical.ErrorResponse("must specify at least one of policy_arns or policy_document with %s credential_type", federationTokenCred), nil
} }
tokenResp, err := stsClient.GetFederationToken(getTokenInput) tokenResp, err := stsClient.GetFederationTokenWithContext(ctx, getTokenInput)
if err != nil { if err != nil {
return logical.ErrorResponse("Error generating STS keys: %s", err), awsutil.CheckAWSError(err) return logical.ErrorResponse("Error generating STS keys: %s", err), awsutil.CheckAWSError(err)
} }
@@ -228,7 +228,7 @@ func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
if len(policyARNs) > 0 { if len(policyARNs) > 0 {
assumeRoleInput.SetPolicyArns(convertPolicyARNs(policyARNs)) assumeRoleInput.SetPolicyArns(convertPolicyARNs(policyARNs))
} }
tokenResp, err := stsClient.AssumeRole(assumeRoleInput) tokenResp, err := stsClient.AssumeRoleWithContext(ctx, assumeRoleInput)
if err != nil { if err != nil {
return logical.ErrorResponse("Error assuming role: %s", err), awsutil.CheckAWSError(err) return logical.ErrorResponse("Error assuming role: %s", err), awsutil.CheckAWSError(err)
} }
@@ -314,7 +314,7 @@ func (b *backend) secretAccessKeysCreate(
} }
// Create the user // Create the user
_, err = iamClient.CreateUser(createUserRequest) _, err = iamClient.CreateUserWithContext(ctx, createUserRequest)
if err != nil { if err != nil {
if walErr := framework.DeleteWAL(ctx, s, walID); walErr != nil { if walErr := framework.DeleteWAL(ctx, s, walID); walErr != nil {
iamErr := fmt.Errorf("error creating IAM user: %w", err) iamErr := fmt.Errorf("error creating IAM user: %w", err)
@@ -325,7 +325,7 @@ func (b *backend) secretAccessKeysCreate(
for _, arn := range role.PolicyArns { for _, arn := range role.PolicyArns {
// Attach existing policy against user // Attach existing policy against user
_, err = iamClient.AttachUserPolicy(&iam.AttachUserPolicyInput{ _, err = iamClient.AttachUserPolicyWithContext(ctx, &iam.AttachUserPolicyInput{
UserName: aws.String(username), UserName: aws.String(username),
PolicyArn: aws.String(arn), PolicyArn: aws.String(arn),
}) })
@@ -336,7 +336,7 @@ func (b *backend) secretAccessKeysCreate(
} }
if role.PolicyDocument != "" { if role.PolicyDocument != "" {
// Add new inline user policy against user // Add new inline user policy against user
_, err = iamClient.PutUserPolicy(&iam.PutUserPolicyInput{ _, err = iamClient.PutUserPolicyWithContext(ctx, &iam.PutUserPolicyInput{
UserName: aws.String(username), UserName: aws.String(username),
PolicyName: aws.String(policyName), PolicyName: aws.String(policyName),
PolicyDocument: aws.String(role.PolicyDocument), PolicyDocument: aws.String(role.PolicyDocument),
@@ -348,7 +348,7 @@ func (b *backend) secretAccessKeysCreate(
for _, group := range role.IAMGroups { for _, group := range role.IAMGroups {
// Add user to IAM groups // Add user to IAM groups
_, err = iamClient.AddUserToGroup(&iam.AddUserToGroupInput{ _, err = iamClient.AddUserToGroupWithContext(ctx, &iam.AddUserToGroupInput{
UserName: aws.String(username), UserName: aws.String(username),
GroupName: aws.String(group), GroupName: aws.String(group),
}) })
@@ -367,7 +367,7 @@ func (b *backend) secretAccessKeysCreate(
} }
if len(tags) > 0 { if len(tags) > 0 {
_, err = iamClient.TagUser(&iam.TagUserInput{ _, err = iamClient.TagUserWithContext(ctx, &iam.TagUserInput{
Tags: tags, Tags: tags,
UserName: &username, UserName: &username,
}) })
@@ -378,7 +378,7 @@ func (b *backend) secretAccessKeysCreate(
} }
// Create the keys // Create the keys
keyResp, err := iamClient.CreateAccessKey(&iam.CreateAccessKeyInput{ keyResp, err := iamClient.CreateAccessKeyWithContext(ctx, &iam.CreateAccessKeyInput{
UserName: aws.String(username), UserName: aws.String(username),
}) })
if err != nil { if err != nil {

7
changelog/19365.txt Normal file
View File

@@ -0,0 +1,7 @@
```release-note: enhancement
auth/aws: Support request cancellation with AWS requests
```
```release-note: enhancement
secrets/aws: Support request cancellation with AWS requests
```