mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	 10c0adad72
			
		
	
	10c0adad72
	
	
	
		
			
			Adds debug and warn logging around AWS credential chain generation, specifically to help users debugging auto-unseal problems on AWS, by logging which role is being used in the case of a webidentity token. Adds a deferred call to flush the log output as well, to ensure logs are output in the event of an initialization failure.
		
			
				
	
	
		
			106 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			106 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package aws
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"os"
 | |
| 
 | |
| 	"github.com/aws/aws-sdk-go/aws"
 | |
| 	"github.com/aws/aws-sdk-go/aws/session"
 | |
| 	"github.com/aws/aws-sdk-go/service/iam"
 | |
| 	"github.com/aws/aws-sdk-go/service/sts"
 | |
| 	"github.com/hashicorp/errwrap"
 | |
| 	cleanhttp "github.com/hashicorp/go-cleanhttp"
 | |
| 	"github.com/hashicorp/go-hclog"
 | |
| 	"github.com/hashicorp/vault/sdk/helper/awsutil"
 | |
| 	"github.com/hashicorp/vault/sdk/logical"
 | |
| )
 | |
| 
 | |
| // NOTE: The caller is required to ensure that b.clientMutex is at least read locked
 | |
| func getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) {
 | |
| 	credsConfig := &awsutil.CredentialsConfig{}
 | |
| 	var endpoint string
 | |
| 	var maxRetries int = aws.UseServiceDefaultRetries
 | |
| 
 | |
| 	entry, err := s.Get(ctx, "config/root")
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if entry != nil {
 | |
| 		var config rootConfig
 | |
| 		if err := entry.DecodeJSON(&config); err != nil {
 | |
| 			return nil, errwrap.Wrapf("error reading root configuration: {{err}}", err)
 | |
| 		}
 | |
| 
 | |
| 		credsConfig.AccessKey = config.AccessKey
 | |
| 		credsConfig.SecretKey = config.SecretKey
 | |
| 		credsConfig.Region = config.Region
 | |
| 		maxRetries = config.MaxRetries
 | |
| 		switch {
 | |
| 		case clientType == "iam" && config.IAMEndpoint != "":
 | |
| 			endpoint = *aws.String(config.IAMEndpoint)
 | |
| 		case clientType == "sts" && config.STSEndpoint != "":
 | |
| 			endpoint = *aws.String(config.STSEndpoint)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if credsConfig.Region == "" {
 | |
| 		credsConfig.Region = os.Getenv("AWS_REGION")
 | |
| 		if credsConfig.Region == "" {
 | |
| 			credsConfig.Region = os.Getenv("AWS_DEFAULT_REGION")
 | |
| 			if credsConfig.Region == "" {
 | |
| 				credsConfig.Region = "us-east-1"
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	credsConfig.HTTPClient = cleanhttp.DefaultClient()
 | |
| 
 | |
| 	credsConfig.Logger = logger
 | |
| 
 | |
| 	creds, err := credsConfig.GenerateCredentialChain()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &aws.Config{
 | |
| 		Credentials: creds,
 | |
| 		Region:      aws.String(credsConfig.Region),
 | |
| 		Endpoint:    &endpoint,
 | |
| 		HTTPClient:  cleanhttp.DefaultClient(),
 | |
| 		MaxRetries:  aws.Int(maxRetries),
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
 | |
| 	awsConfig, err := getRootConfig(ctx, s, "iam", logger)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	sess, err := session.NewSession(awsConfig)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	client := iam.New(sess)
 | |
| 	if client == nil {
 | |
| 		return nil, fmt.Errorf("could not obtain iam client")
 | |
| 	}
 | |
| 	return client, nil
 | |
| }
 | |
| 
 | |
| func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
 | |
| 	awsConfig, err := getRootConfig(ctx, s, "sts", logger)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	sess, err := session.NewSession(awsConfig)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	client := sts.New(sess)
 | |
| 	if client == nil {
 | |
| 		return nil, fmt.Errorf("could not obtain sts client")
 | |
| 	}
 | |
| 	return client, nil
 | |
| }
 |