mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 02:02:43 +00:00 
			
		
		
		
	 cfff8d420e
			
		
	
	cfff8d420e
	
	
	
		
			
			* auth/aws: use cancelable context with aws calls * secrets/aws: use cancelable context with aws calls
		
			
				
	
	
		
			307 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			307 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) HashiCorp, Inc.
 | |
| // SPDX-License-Identifier: MPL-2.0
 | |
| 
 | |
| package awsauth
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 
 | |
| 	"github.com/aws/aws-sdk-go/aws"
 | |
| 	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
 | |
| 	"github.com/aws/aws-sdk-go/aws/session"
 | |
| 	"github.com/aws/aws-sdk-go/service/ec2"
 | |
| 	"github.com/aws/aws-sdk-go/service/iam"
 | |
| 	"github.com/aws/aws-sdk-go/service/sts"
 | |
| 	cleanhttp "github.com/hashicorp/go-cleanhttp"
 | |
| 	"github.com/hashicorp/go-secure-stdlib/awsutil"
 | |
| 	"github.com/hashicorp/vault/sdk/logical"
 | |
| )
 | |
| 
 | |
| // getRawClientConfig creates a aws-sdk-go config, which is used to create client
 | |
| // that can interact with AWS API. This builds credentials in the following
 | |
| // order of preference:
 | |
| //
 | |
| // * Static credentials from 'config/client'
 | |
| // * Environment variables
 | |
| // * Instance metadata role
 | |
| func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) {
 | |
| 	credsConfig := &awsutil.CredentialsConfig{
 | |
| 		Region: region,
 | |
| 		Logger: b.Logger(),
 | |
| 	}
 | |
| 
 | |
| 	// Read the configured secret key and access key
 | |
| 	config, err := b.nonLockedClientConfigEntry(ctx, s)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	endpoint := aws.String("")
 | |
| 	var maxRetries int = aws.UseServiceDefaultRetries
 | |
| 	if config != nil {
 | |
| 		// Override the defaults with configured values.
 | |
| 		switch {
 | |
| 		case clientType == "ec2" && config.Endpoint != "":
 | |
| 			endpoint = aws.String(config.Endpoint)
 | |
| 		case clientType == "iam" && config.IAMEndpoint != "":
 | |
| 			endpoint = aws.String(config.IAMEndpoint)
 | |
| 		case clientType == "sts":
 | |
| 			if config.STSEndpoint != "" {
 | |
| 				endpoint = aws.String(config.STSEndpoint)
 | |
| 			}
 | |
| 			if config.STSRegion != "" {
 | |
| 				region = config.STSRegion
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		credsConfig.AccessKey = config.AccessKey
 | |
| 		credsConfig.SecretKey = config.SecretKey
 | |
| 		maxRetries = config.MaxRetries
 | |
| 	}
 | |
| 
 | |
| 	credsConfig.HTTPClient = cleanhttp.DefaultClient()
 | |
| 
 | |
| 	creds, err := credsConfig.GenerateCredentialChain()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if creds == nil {
 | |
| 		return nil, fmt.Errorf("could not compile valid credential providers from static config, environment, shared, or instance metadata")
 | |
| 	}
 | |
| 
 | |
| 	// Create a config that can be used to make the API calls.
 | |
| 	return &aws.Config{
 | |
| 		Credentials: creds,
 | |
| 		Region:      aws.String(region),
 | |
| 		HTTPClient:  cleanhttp.DefaultClient(),
 | |
| 		Endpoint:    endpoint,
 | |
| 		MaxRetries:  aws.Int(maxRetries),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // getClientConfig returns an aws-sdk-go config, with optionally assumed credentials
 | |
| // It uses getRawClientConfig to obtain config for the runtime environment, and if
 | |
| // stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
 | |
| // credentials. The credentials will expire after 15 minutes but will auto-refresh.
 | |
| func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
 | |
| 	config, err := b.getRawClientConfig(ctx, s, region, clientType)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if config == nil {
 | |
| 		return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
 | |
| 	}
 | |
| 
 | |
| 	stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts")
 | |
| 	if stsConfig == nil {
 | |
| 		return nil, fmt.Errorf("could not configure STS client")
 | |
| 	}
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if stsRole != "" {
 | |
| 		sess, err := session.NewSession(stsConfig)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		assumedCredentials := stscreds.NewCredentials(sess, stsRole)
 | |
| 		// Test that we actually have permissions to assume the role
 | |
| 		if _, err = assumedCredentials.Get(); err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		config.Credentials = assumedCredentials
 | |
| 	} else {
 | |
| 		if b.defaultAWSAccountID == "" {
 | |
| 			sess, err := session.NewSession(stsConfig)
 | |
| 			if err != nil {
 | |
| 				return nil, err
 | |
| 			}
 | |
| 			client := sts.New(sess)
 | |
| 			if client == nil {
 | |
| 				return nil, fmt.Errorf("could not obtain sts client: %w", err)
 | |
| 			}
 | |
| 			inputParams := &sts.GetCallerIdentityInput{}
 | |
| 			identity, err := client.GetCallerIdentityWithContext(ctx, inputParams)
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf("unable to fetch current caller: %w", err)
 | |
| 			}
 | |
| 			if identity == nil {
 | |
| 				return nil, fmt.Errorf("got nil result from GetCallerIdentity")
 | |
| 			}
 | |
| 			b.defaultAWSAccountID = *identity.Account
 | |
| 		}
 | |
| 		if b.defaultAWSAccountID != accountID {
 | |
| 			return nil, fmt.Errorf("unable to fetch client for account ID %q -- default client is for account %q", accountID, b.defaultAWSAccountID)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return config, nil
 | |
| }
 | |
| 
 | |
| // flushCachedEC2Clients deletes all the cached ec2 client objects from the backend.
 | |
| // If the client credentials configuration is deleted or updated in the backend, all
 | |
| // the cached EC2 client objects will be flushed. Config mutex lock should be
 | |
| // acquired for write operation before calling this method.
 | |
| func (b *backend) flushCachedEC2Clients() {
 | |
| 	// deleting items in map during iteration is safe
 | |
| 	for region := range b.EC2ClientsMap {
 | |
| 		delete(b.EC2ClientsMap, region)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // flushCachedIAMClients deletes all the cached iam client objects from the
 | |
| // backend. If the client credentials configuration is deleted or updated in
 | |
| // the backend, all the cached IAM client objects will be flushed. Config mutex
 | |
| // lock should be acquired for write operation before calling this method.
 | |
| func (b *backend) flushCachedIAMClients() {
 | |
| 	// deleting items in map during iteration is safe
 | |
| 	for region := range b.IAMClientsMap {
 | |
| 		delete(b.IAMClientsMap, region)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Gets an entry out of the user ID cache
 | |
| func (b *backend) getCachedUserId(userId string) string {
 | |
| 	if userId == "" {
 | |
| 		return ""
 | |
| 	}
 | |
| 	if entry, ok := b.iamUserIdToArnCache.Get(userId); ok {
 | |
| 		b.iamUserIdToArnCache.SetDefault(userId, entry)
 | |
| 		return entry.(string)
 | |
| 	}
 | |
| 	return ""
 | |
| }
 | |
| 
 | |
| // Sets an entry in the user ID cache
 | |
| func (b *backend) setCachedUserId(userId, arn string) {
 | |
| 	if userId != "" {
 | |
| 		b.iamUserIdToArnCache.SetDefault(userId, arn)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) {
 | |
| 	// Check if an STS configuration exists for the AWS account
 | |
| 	sts, err := b.lockedAwsStsEntry(ctx, s, accountID)
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("error fetching STS config for account ID %q: %w", accountID, err)
 | |
| 	}
 | |
| 	// An empty STS role signifies the master account
 | |
| 	if sts != nil {
 | |
| 		return sts.StsRole, nil
 | |
| 	}
 | |
| 	return "", nil
 | |
| }
 | |
| 
 | |
| // clientEC2 creates a client to interact with AWS EC2 API
 | |
| func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
 | |
| 	stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	b.configMutex.RLock()
 | |
| 	if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
 | |
| 		defer b.configMutex.RUnlock()
 | |
| 		// If the client object was already created, return it
 | |
| 		return b.EC2ClientsMap[region][stsRole], nil
 | |
| 	}
 | |
| 
 | |
| 	// Release the read lock and acquire the write lock
 | |
| 	b.configMutex.RUnlock()
 | |
| 	b.configMutex.Lock()
 | |
| 	defer b.configMutex.Unlock()
 | |
| 
 | |
| 	// If the client gets created while switching the locks, return it
 | |
| 	if b.EC2ClientsMap[region] != nil && b.EC2ClientsMap[region][stsRole] != nil {
 | |
| 		return b.EC2ClientsMap[region][stsRole], nil
 | |
| 	}
 | |
| 
 | |
| 	// Create an AWS config object using a chain of providers
 | |
| 	var awsConfig *aws.Config
 | |
| 	awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2")
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if awsConfig == nil {
 | |
| 		return nil, fmt.Errorf("could not retrieve valid assumed credentials")
 | |
| 	}
 | |
| 
 | |
| 	// Create a new EC2 client object, cache it and return the same
 | |
| 	sess, err := session.NewSession(awsConfig)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	client := ec2.New(sess)
 | |
| 	if client == nil {
 | |
| 		return nil, fmt.Errorf("could not obtain ec2 client")
 | |
| 	}
 | |
| 	if _, ok := b.EC2ClientsMap[region]; !ok {
 | |
| 		b.EC2ClientsMap[region] = map[string]*ec2.EC2{stsRole: client}
 | |
| 	} else {
 | |
| 		b.EC2ClientsMap[region][stsRole] = client
 | |
| 	}
 | |
| 
 | |
| 	return b.EC2ClientsMap[region][stsRole], nil
 | |
| }
 | |
| 
 | |
| // clientIAM creates a client to interact with AWS IAM API
 | |
| func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
 | |
| 	stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if stsRole == "" {
 | |
| 		b.Logger().Debug(fmt.Sprintf("no stsRole found for %s", accountID))
 | |
| 	} else {
 | |
| 		b.Logger().Debug(fmt.Sprintf("found stsRole %s for account %s", stsRole, accountID))
 | |
| 	}
 | |
| 	b.configMutex.RLock()
 | |
| 	if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
 | |
| 		defer b.configMutex.RUnlock()
 | |
| 		// If the client object was already created, return it
 | |
| 		b.Logger().Debug(fmt.Sprintf("returning cached client for region %s and stsRole %s", region, stsRole))
 | |
| 		return b.IAMClientsMap[region][stsRole], nil
 | |
| 	}
 | |
| 	b.Logger().Debug(fmt.Sprintf("no cached client for region %s and stsRole %s", region, stsRole))
 | |
| 
 | |
| 	// Release the read lock and acquire the write lock
 | |
| 	b.configMutex.RUnlock()
 | |
| 	b.configMutex.Lock()
 | |
| 	defer b.configMutex.Unlock()
 | |
| 
 | |
| 	// If the client gets created while switching the locks, return it
 | |
| 	if b.IAMClientsMap[region] != nil && b.IAMClientsMap[region][stsRole] != nil {
 | |
| 		return b.IAMClientsMap[region][stsRole], nil
 | |
| 	}
 | |
| 
 | |
| 	// Create an AWS config object using a chain of providers
 | |
| 	var awsConfig *aws.Config
 | |
| 	awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam")
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if awsConfig == nil {
 | |
| 		return nil, fmt.Errorf("could not retrieve valid assumed credentials")
 | |
| 	}
 | |
| 
 | |
| 	// Create a new IAM client object, cache it and return the same
 | |
| 	sess, err := session.NewSession(awsConfig)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	client := iam.New(sess)
 | |
| 	if client == nil {
 | |
| 		return nil, fmt.Errorf("could not obtain iam client")
 | |
| 	}
 | |
| 	if _, ok := b.IAMClientsMap[region]; !ok {
 | |
| 		b.IAMClientsMap[region] = map[string]*iam.IAM{stsRole: client}
 | |
| 	} else {
 | |
| 		b.IAMClientsMap[region][stsRole] = client
 | |
| 	}
 | |
| 	return b.IAMClientsMap[region][stsRole], nil
 | |
| }
 |