mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	AWS auth login with multi region STS support (#21960)
This commit is contained in:
		| @@ -12,6 +12,7 @@ import ( | ||||
|  | ||||
| 	"github.com/aws/aws-sdk-go/aws" | ||||
| 	"github.com/hashicorp/go-secure-stdlib/strutil" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/sdk/framework" | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
| @@ -61,6 +62,12 @@ func (b *backend) pathConfigClient() *framework.Path { | ||||
| 				Description: "The region ID for the sts_endpoint, if set.", | ||||
| 			}, | ||||
|  | ||||
| 			"use_sts_region_from_client": { | ||||
| 				Type:        framework.TypeBool, | ||||
| 				Default:     false, | ||||
| 				Description: "Uses the STS region from client requests for making AWS STS API calls.", | ||||
| 			}, | ||||
|  | ||||
| 			"iam_server_id_header_value": { | ||||
| 				Type:        framework.TypeString, | ||||
| 				Default:     "", | ||||
| @@ -168,6 +175,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request | ||||
| 			"iam_endpoint":               clientConfig.IAMEndpoint, | ||||
| 			"sts_endpoint":               clientConfig.STSEndpoint, | ||||
| 			"sts_region":                 clientConfig.STSRegion, | ||||
| 			"use_sts_region_from_client": clientConfig.UseSTSRegionFromClient, | ||||
| 			"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue, | ||||
| 			"max_retries":                clientConfig.MaxRetries, | ||||
| 			"allowed_sts_header_values":  clientConfig.AllowedSTSHeaderValues, | ||||
| @@ -281,6 +289,14 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	useSTSRegionFromClientRaw, ok := data.GetOk("use_sts_region_from_client") | ||||
| 	if ok { | ||||
| 		if configEntry.UseSTSRegionFromClient != useSTSRegionFromClientRaw.(bool) { | ||||
| 			changedCreds = true | ||||
| 			configEntry.UseSTSRegionFromClient = useSTSRegionFromClientRaw.(bool) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	headerValStr, ok := data.GetOk("iam_server_id_header_value") | ||||
| 	if ok { | ||||
| 		if configEntry.IAMServerIdHeaderValue != headerValStr.(string) { | ||||
| @@ -363,6 +379,7 @@ type clientConfig struct { | ||||
| 	IAMEndpoint            string   `json:"iam_endpoint"` | ||||
| 	STSEndpoint            string   `json:"sts_endpoint"` | ||||
| 	STSRegion              string   `json:"sts_region"` | ||||
| 	UseSTSRegionFromClient bool     `json:"use_sts_region_from_client"` | ||||
| 	IAMServerIdHeaderValue string   `json:"iam_server_id_header_value"` | ||||
| 	AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"` | ||||
| 	MaxRetries             int      `json:"max_retries"` | ||||
|   | ||||
| @@ -21,8 +21,10 @@ import ( | ||||
|  | ||||
| 	"github.com/aws/aws-sdk-go/aws" | ||||
| 	awsClient "github.com/aws/aws-sdk-go/aws/client" | ||||
| 	"github.com/aws/aws-sdk-go/aws/endpoints" | ||||
| 	"github.com/aws/aws-sdk-go/service/ec2" | ||||
| 	"github.com/aws/aws-sdk-go/service/iam" | ||||
| 	"github.com/aws/aws-sdk-go/service/sts" | ||||
| 	"github.com/hashicorp/errwrap" | ||||
| 	cleanhttp "github.com/hashicorp/go-cleanhttp" | ||||
| 	"github.com/hashicorp/go-retryablehttp" | ||||
| @@ -30,6 +32,7 @@ import ( | ||||
| 	"github.com/hashicorp/go-secure-stdlib/parseutil" | ||||
| 	"github.com/hashicorp/go-secure-stdlib/strutil" | ||||
| 	uuid "github.com/hashicorp/go-uuid" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/builtin/credential/aws/pkcs7" | ||||
| 	"github.com/hashicorp/vault/sdk/framework" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/cidrutil" | ||||
| @@ -318,6 +321,24 @@ func (b *backend) pathLoginIamGetRoleNameCallerIdAndEntity(ctx context.Context, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Extract and use a regional STS endpoint | ||||
| 	// based on the region set in the Authorization header. | ||||
| 	if config.UseSTSRegionFromClient { | ||||
| 		clientSpecifiedRegion, err := awsRegionFromHeader(headers.Get("Authorization")) | ||||
| 		if err != nil { | ||||
| 			return "", nil, nil, logical.ErrorResponse("region missing from Authorization header"), nil | ||||
| 		} | ||||
|  | ||||
| 		url, err := stsRegionalEndpoint(clientSpecifiedRegion) | ||||
| 		if err != nil { | ||||
| 			return "", nil, nil, logical.ErrorResponse(err.Error()), nil | ||||
| 		} | ||||
|  | ||||
| 		b.Logger().Debug("use_sts_region_from_client set; using region specified from header", "region", clientSpecifiedRegion) | ||||
| 		endpoint = url | ||||
| 	} | ||||
|  | ||||
| 	b.Logger().Debug("submitting caller identity request", "endpoint", endpoint) | ||||
| 	callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers) | ||||
| 	if err != nil { | ||||
| 		return "", nil, nil, logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil | ||||
| @@ -1884,6 +1905,43 @@ func getMetadataValue(fromAuth *logical.Auth, forKey string) (string, error) { | ||||
| 	return "", fmt.Errorf("%q not found in auth metadata", forKey) | ||||
| } | ||||
|  | ||||
| func awsRegionFromHeader(authorizationHeader string) (string, error) { | ||||
| 	// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html | ||||
| 	// The Authorization header takes the following form. | ||||
| 	//  Authorization: AWS4-HMAC-SHA256 | ||||
| 	//	Credential=AKIAIOSFODNN7EXAMPLE/20230719/us-east-1/sts/aws4_request, | ||||
| 	// 	SignedHeaders=content-length;content-type;host;x-amz-date, | ||||
| 	//	Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024 | ||||
| 	// | ||||
| 	// The credential is in the form of "<your-access-key-id>/<date>/<aws-region>/<aws-service>/aws4_request" | ||||
| 	fields := strings.Split(authorizationHeader, " ") | ||||
| 	for _, field := range fields { | ||||
| 		if strings.HasPrefix(field, "Credential=") { | ||||
| 			fields := strings.Split(field, "/") | ||||
| 			if len(fields) < 3 { | ||||
| 				return "", fmt.Errorf("invalid header format") | ||||
| 			} | ||||
|  | ||||
| 			region := fields[2] | ||||
| 			return region, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return "", fmt.Errorf("invalid header format") | ||||
| } | ||||
|  | ||||
| func stsRegionalEndpoint(region string) (string, error) { | ||||
| 	stsService := sts.EndpointsID | ||||
| 	resolver := endpoints.DefaultResolver() | ||||
| 	resolvedEndpoint, err := resolver.EndpointFor(stsService, region, | ||||
| 		endpoints.STSRegionalEndpointOption, | ||||
| 		endpoints.StrictMatchingOption) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("unable to get regional STS endpoint for region: %v", region) | ||||
| 	} | ||||
| 	return resolvedEndpoint.URL, nil | ||||
| } | ||||
|  | ||||
| const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID" | ||||
|  | ||||
| const pathLoginSyn = ` | ||||
|   | ||||
| @@ -16,6 +16,8 @@ import ( | ||||
|  | ||||
| 	"github.com/aws/aws-sdk-go/aws/session" | ||||
| 	"github.com/aws/aws-sdk-go/service/sts" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
|  | ||||
| @@ -625,6 +627,58 @@ func TestBackend_defaultAliasMetadata(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRegionFromHeader(t *testing.T) { | ||||
| 	tcs := map[string]struct { | ||||
| 		header              string | ||||
| 		expectedRegion      string | ||||
| 		expectedSTSEndpoint string | ||||
| 	}{ | ||||
| 		"us-east-1": { | ||||
| 			header:              "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", | ||||
| 			expectedRegion:      "us-east-1", | ||||
| 			expectedSTSEndpoint: "https://sts.us-east-1.amazonaws.com", | ||||
| 		}, | ||||
| 		"us-west-2": { | ||||
| 			header:              "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-west-2/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", | ||||
| 			expectedRegion:      "us-west-2", | ||||
| 			expectedSTSEndpoint: "https://sts.us-west-2.amazonaws.com", | ||||
| 		}, | ||||
| 		"ap-northeast-3": { | ||||
| 			header:              "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/ap-northeast-3/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", | ||||
| 			expectedRegion:      "ap-northeast-3", | ||||
| 			expectedSTSEndpoint: "https://sts.ap-northeast-3.amazonaws.com", | ||||
| 		}, | ||||
| 		"us-gov-east-1": { | ||||
| 			header:              "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-gov-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", | ||||
| 			expectedRegion:      "us-gov-east-1", | ||||
| 			expectedSTSEndpoint: "https://sts.us-gov-east-1.amazonaws.com", | ||||
| 		}, | ||||
| 	} | ||||
| 	for name, tc := range tcs { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			region, err := awsRegionFromHeader(tc.header) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, tc.expectedRegion, region) | ||||
|  | ||||
| 			stsEndpoint, err := stsRegionalEndpoint(region) | ||||
| 			assert.NoError(t, err) | ||||
| 			assert.Equal(t, tc.expectedSTSEndpoint, stsEndpoint) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	t.Run("invalid-header", func(t *testing.T) { | ||||
| 		region, err := awsRegionFromHeader("this-is-an-invalid-header/foobar") | ||||
| 		assert.EqualError(t, err, "invalid header format") | ||||
| 		assert.Empty(t, region) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("invalid-region", func(t *testing.T) { | ||||
| 		endpoint, err := stsRegionalEndpoint("fake-region-1") | ||||
| 		assert.EqualError(t, err, "unable to get regional STS endpoint for region: fake-region-1") | ||||
| 		assert.Empty(t, endpoint) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func defaultLoginData() (map[string]interface{}, error) { | ||||
| 	awsSession, err := session.NewSession() | ||||
| 	if err != nil { | ||||
|   | ||||
							
								
								
									
										3
									
								
								changelog/21960.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/21960.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| ```release-note:improvement | ||||
| aws/auth: Adds a new config field `use_sts_region_from_client` which allows for using dynamic regional sts endpoints based on Authorization header when using IAM-based authentication. | ||||
| ``` | ||||
| @@ -65,6 +65,10 @@ capabilities, the credentials are fetched automatically. | ||||
| - `sts_region` `(string: "")` - Region to override the default region for making | ||||
|   AWS STS API calls. Should only be set if `sts_endpoint` is set. If so, should | ||||
|   be set to the region in which the custom `sts_endpoint` resides. | ||||
| - `use_sts_region_from_client` `(boolean: false)` - If set, overrides both `sts_endpoint` | ||||
|   and `sts_region` to instead use the region specified in the client request headers for | ||||
|   IAM-based authentication . This can be useful when you have client requests coming from | ||||
|   different regions and want flexibility in which regional STS API is used. | ||||
| - `iam_server_id_header_value` `(string: "")` - The value to require in the | ||||
|   `X-Vault-AWS-IAM-Server-ID` header as part of GetCallerIdentity requests that | ||||
|   are used in the iam auth method. If not set, then no value is required or | ||||
| @@ -123,6 +127,7 @@ $ curl \ | ||||
|     "iam_endpoint": "", | ||||
|     "sts_endpoint": "", | ||||
|     "sts_region": "", | ||||
|     "use_sts_region_from_client": false, | ||||
|     "iam_server_id_header_value": "" | ||||
|   } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Raymond Ho
					Raymond Ho