mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 01:32:33 +00:00
AWS auth login with multi region STS support (#21960)
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
@@ -61,6 +62,12 @@ func (b *backend) pathConfigClient() *framework.Path {
|
||||
Description: "The region ID for the sts_endpoint, if set.",
|
||||
},
|
||||
|
||||
"use_sts_region_from_client": {
|
||||
Type: framework.TypeBool,
|
||||
Default: false,
|
||||
Description: "Uses the STS region from client requests for making AWS STS API calls.",
|
||||
},
|
||||
|
||||
"iam_server_id_header_value": {
|
||||
Type: framework.TypeString,
|
||||
Default: "",
|
||||
@@ -168,6 +175,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
|
||||
"iam_endpoint": clientConfig.IAMEndpoint,
|
||||
"sts_endpoint": clientConfig.STSEndpoint,
|
||||
"sts_region": clientConfig.STSRegion,
|
||||
"use_sts_region_from_client": clientConfig.UseSTSRegionFromClient,
|
||||
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
|
||||
"max_retries": clientConfig.MaxRetries,
|
||||
"allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues,
|
||||
@@ -281,6 +289,14 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
|
||||
}
|
||||
}
|
||||
|
||||
useSTSRegionFromClientRaw, ok := data.GetOk("use_sts_region_from_client")
|
||||
if ok {
|
||||
if configEntry.UseSTSRegionFromClient != useSTSRegionFromClientRaw.(bool) {
|
||||
changedCreds = true
|
||||
configEntry.UseSTSRegionFromClient = useSTSRegionFromClientRaw.(bool)
|
||||
}
|
||||
}
|
||||
|
||||
headerValStr, ok := data.GetOk("iam_server_id_header_value")
|
||||
if ok {
|
||||
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
|
||||
@@ -363,6 +379,7 @@ type clientConfig struct {
|
||||
IAMEndpoint string `json:"iam_endpoint"`
|
||||
STSEndpoint string `json:"sts_endpoint"`
|
||||
STSRegion string `json:"sts_region"`
|
||||
UseSTSRegionFromClient bool `json:"use_sts_region_from_client"`
|
||||
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
|
||||
AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
|
||||
@@ -21,8 +21,10 @@ import (
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
awsClient "github.com/aws/aws-sdk-go/aws/client"
|
||||
"github.com/aws/aws-sdk-go/aws/endpoints"
|
||||
"github.com/aws/aws-sdk-go/service/ec2"
|
||||
"github.com/aws/aws-sdk-go/service/iam"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
"github.com/hashicorp/errwrap"
|
||||
cleanhttp "github.com/hashicorp/go-cleanhttp"
|
||||
"github.com/hashicorp/go-retryablehttp"
|
||||
@@ -30,6 +32,7 @@ import (
|
||||
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/credential/aws/pkcs7"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/cidrutil"
|
||||
@@ -318,6 +321,24 @@ func (b *backend) pathLoginIamGetRoleNameCallerIdAndEntity(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and use a regional STS endpoint
|
||||
// based on the region set in the Authorization header.
|
||||
if config.UseSTSRegionFromClient {
|
||||
clientSpecifiedRegion, err := awsRegionFromHeader(headers.Get("Authorization"))
|
||||
if err != nil {
|
||||
return "", nil, nil, logical.ErrorResponse("region missing from Authorization header"), nil
|
||||
}
|
||||
|
||||
url, err := stsRegionalEndpoint(clientSpecifiedRegion)
|
||||
if err != nil {
|
||||
return "", nil, nil, logical.ErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
b.Logger().Debug("use_sts_region_from_client set; using region specified from header", "region", clientSpecifiedRegion)
|
||||
endpoint = url
|
||||
}
|
||||
|
||||
b.Logger().Debug("submitting caller identity request", "endpoint", endpoint)
|
||||
callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers)
|
||||
if err != nil {
|
||||
return "", nil, nil, logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil
|
||||
@@ -1884,6 +1905,43 @@ func getMetadataValue(fromAuth *logical.Auth, forKey string) (string, error) {
|
||||
return "", fmt.Errorf("%q not found in auth metadata", forKey)
|
||||
}
|
||||
|
||||
func awsRegionFromHeader(authorizationHeader string) (string, error) {
|
||||
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html
|
||||
// The Authorization header takes the following form.
|
||||
// Authorization: AWS4-HMAC-SHA256
|
||||
// Credential=AKIAIOSFODNN7EXAMPLE/20230719/us-east-1/sts/aws4_request,
|
||||
// SignedHeaders=content-length;content-type;host;x-amz-date,
|
||||
// Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024
|
||||
//
|
||||
// The credential is in the form of "<your-access-key-id>/<date>/<aws-region>/<aws-service>/aws4_request"
|
||||
fields := strings.Split(authorizationHeader, " ")
|
||||
for _, field := range fields {
|
||||
if strings.HasPrefix(field, "Credential=") {
|
||||
fields := strings.Split(field, "/")
|
||||
if len(fields) < 3 {
|
||||
return "", fmt.Errorf("invalid header format")
|
||||
}
|
||||
|
||||
region := fields[2]
|
||||
return region, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid header format")
|
||||
}
|
||||
|
||||
func stsRegionalEndpoint(region string) (string, error) {
|
||||
stsService := sts.EndpointsID
|
||||
resolver := endpoints.DefaultResolver()
|
||||
resolvedEndpoint, err := resolver.EndpointFor(stsService, region,
|
||||
endpoints.STSRegionalEndpointOption,
|
||||
endpoints.StrictMatchingOption)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to get regional STS endpoint for region: %v", region)
|
||||
}
|
||||
return resolvedEndpoint.URL, nil
|
||||
}
|
||||
|
||||
const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID"
|
||||
|
||||
const pathLoginSyn = `
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
)
|
||||
|
||||
@@ -625,6 +627,58 @@ func TestBackend_defaultAliasMetadata(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegionFromHeader(t *testing.T) {
|
||||
tcs := map[string]struct {
|
||||
header string
|
||||
expectedRegion string
|
||||
expectedSTSEndpoint string
|
||||
}{
|
||||
"us-east-1": {
|
||||
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
expectedRegion: "us-east-1",
|
||||
expectedSTSEndpoint: "https://sts.us-east-1.amazonaws.com",
|
||||
},
|
||||
"us-west-2": {
|
||||
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-west-2/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
expectedRegion: "us-west-2",
|
||||
expectedSTSEndpoint: "https://sts.us-west-2.amazonaws.com",
|
||||
},
|
||||
"ap-northeast-3": {
|
||||
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/ap-northeast-3/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
expectedRegion: "ap-northeast-3",
|
||||
expectedSTSEndpoint: "https://sts.ap-northeast-3.amazonaws.com",
|
||||
},
|
||||
"us-gov-east-1": {
|
||||
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-gov-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
expectedRegion: "us-gov-east-1",
|
||||
expectedSTSEndpoint: "https://sts.us-gov-east-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
for name, tc := range tcs {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
region, err := awsRegionFromHeader(tc.header)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedRegion, region)
|
||||
|
||||
stsEndpoint, err := stsRegionalEndpoint(region)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedSTSEndpoint, stsEndpoint)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("invalid-header", func(t *testing.T) {
|
||||
region, err := awsRegionFromHeader("this-is-an-invalid-header/foobar")
|
||||
assert.EqualError(t, err, "invalid header format")
|
||||
assert.Empty(t, region)
|
||||
})
|
||||
|
||||
t.Run("invalid-region", func(t *testing.T) {
|
||||
endpoint, err := stsRegionalEndpoint("fake-region-1")
|
||||
assert.EqualError(t, err, "unable to get regional STS endpoint for region: fake-region-1")
|
||||
assert.Empty(t, endpoint)
|
||||
})
|
||||
}
|
||||
|
||||
func defaultLoginData() (map[string]interface{}, error) {
|
||||
awsSession, err := session.NewSession()
|
||||
if err != nil {
|
||||
|
||||
3
changelog/21960.txt
Normal file
3
changelog/21960.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
```release-note:improvement
|
||||
aws/auth: Adds a new config field `use_sts_region_from_client` which allows for using dynamic regional sts endpoints based on Authorization header when using IAM-based authentication.
|
||||
```
|
||||
@@ -65,6 +65,10 @@ capabilities, the credentials are fetched automatically.
|
||||
- `sts_region` `(string: "")` - Region to override the default region for making
|
||||
AWS STS API calls. Should only be set if `sts_endpoint` is set. If so, should
|
||||
be set to the region in which the custom `sts_endpoint` resides.
|
||||
- `use_sts_region_from_client` `(boolean: false)` - If set, overrides both `sts_endpoint`
|
||||
and `sts_region` to instead use the region specified in the client request headers for
|
||||
IAM-based authentication . This can be useful when you have client requests coming from
|
||||
different regions and want flexibility in which regional STS API is used.
|
||||
- `iam_server_id_header_value` `(string: "")` - The value to require in the
|
||||
`X-Vault-AWS-IAM-Server-ID` header as part of GetCallerIdentity requests that
|
||||
are used in the iam auth method. If not set, then no value is required or
|
||||
@@ -123,6 +127,7 @@ $ curl \
|
||||
"iam_endpoint": "",
|
||||
"sts_endpoint": "",
|
||||
"sts_region": "",
|
||||
"use_sts_region_from_client": false,
|
||||
"iam_server_id_header_value": ""
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user