AWS auth login with multi region STS support (#21960)

This commit is contained in:
Raymond Ho
2023-07-28 08:42:22 -07:00
committed by GitHub
parent 194e8cdb02
commit 4f7a8fb494
5 changed files with 137 additions and 0 deletions

View File

@@ -12,6 +12,7 @@ import (
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -61,6 +62,12 @@ func (b *backend) pathConfigClient() *framework.Path {
Description: "The region ID for the sts_endpoint, if set.", Description: "The region ID for the sts_endpoint, if set.",
}, },
"use_sts_region_from_client": {
Type: framework.TypeBool,
Default: false,
Description: "Uses the STS region from client requests for making AWS STS API calls.",
},
"iam_server_id_header_value": { "iam_server_id_header_value": {
Type: framework.TypeString, Type: framework.TypeString,
Default: "", Default: "",
@@ -168,6 +175,7 @@ func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request
"iam_endpoint": clientConfig.IAMEndpoint, "iam_endpoint": clientConfig.IAMEndpoint,
"sts_endpoint": clientConfig.STSEndpoint, "sts_endpoint": clientConfig.STSEndpoint,
"sts_region": clientConfig.STSRegion, "sts_region": clientConfig.STSRegion,
"use_sts_region_from_client": clientConfig.UseSTSRegionFromClient,
"iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue, "iam_server_id_header_value": clientConfig.IAMServerIdHeaderValue,
"max_retries": clientConfig.MaxRetries, "max_retries": clientConfig.MaxRetries,
"allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues, "allowed_sts_header_values": clientConfig.AllowedSTSHeaderValues,
@@ -281,6 +289,14 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
} }
} }
useSTSRegionFromClientRaw, ok := data.GetOk("use_sts_region_from_client")
if ok {
if configEntry.UseSTSRegionFromClient != useSTSRegionFromClientRaw.(bool) {
changedCreds = true
configEntry.UseSTSRegionFromClient = useSTSRegionFromClientRaw.(bool)
}
}
headerValStr, ok := data.GetOk("iam_server_id_header_value") headerValStr, ok := data.GetOk("iam_server_id_header_value")
if ok { if ok {
if configEntry.IAMServerIdHeaderValue != headerValStr.(string) { if configEntry.IAMServerIdHeaderValue != headerValStr.(string) {
@@ -363,6 +379,7 @@ type clientConfig struct {
IAMEndpoint string `json:"iam_endpoint"` IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"` STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"` STSRegion string `json:"sts_region"`
UseSTSRegionFromClient bool `json:"use_sts_region_from_client"`
IAMServerIdHeaderValue string `json:"iam_server_id_header_value"` IAMServerIdHeaderValue string `json:"iam_server_id_header_value"`
AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"` AllowedSTSHeaderValues []string `json:"allowed_sts_header_values"`
MaxRetries int `json:"max_retries"` MaxRetries int `json:"max_retries"`

View File

@@ -21,8 +21,10 @@ import (
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
awsClient "github.com/aws/aws-sdk-go/aws/client" awsClient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
cleanhttp "github.com/hashicorp/go-cleanhttp" cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/go-retryablehttp"
@@ -30,6 +32,7 @@ import (
"github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-secure-stdlib/strutil"
uuid "github.com/hashicorp/go-uuid" uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/credential/aws/pkcs7" "github.com/hashicorp/vault/builtin/credential/aws/pkcs7"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/cidrutil" "github.com/hashicorp/vault/sdk/helper/cidrutil"
@@ -318,6 +321,24 @@ func (b *backend) pathLoginIamGetRoleNameCallerIdAndEntity(ctx context.Context,
} }
} }
// Extract and use a regional STS endpoint
// based on the region set in the Authorization header.
if config.UseSTSRegionFromClient {
clientSpecifiedRegion, err := awsRegionFromHeader(headers.Get("Authorization"))
if err != nil {
return "", nil, nil, logical.ErrorResponse("region missing from Authorization header"), nil
}
url, err := stsRegionalEndpoint(clientSpecifiedRegion)
if err != nil {
return "", nil, nil, logical.ErrorResponse(err.Error()), nil
}
b.Logger().Debug("use_sts_region_from_client set; using region specified from header", "region", clientSpecifiedRegion)
endpoint = url
}
b.Logger().Debug("submitting caller identity request", "endpoint", endpoint)
callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers) callerID, err := submitCallerIdentityRequest(ctx, maxRetries, method, endpoint, parsedUrl, body, headers)
if err != nil { if err != nil {
return "", nil, nil, logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil return "", nil, nil, logical.ErrorResponse(fmt.Sprintf("error making upstream request: %v", err)), nil
@@ -1884,6 +1905,43 @@ func getMetadataValue(fromAuth *logical.Auth, forKey string) (string, error) {
return "", fmt.Errorf("%q not found in auth metadata", forKey) return "", fmt.Errorf("%q not found in auth metadata", forKey)
} }
func awsRegionFromHeader(authorizationHeader string) (string, error) {
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html
// The Authorization header takes the following form.
// Authorization: AWS4-HMAC-SHA256
// Credential=AKIAIOSFODNN7EXAMPLE/20230719/us-east-1/sts/aws4_request,
// SignedHeaders=content-length;content-type;host;x-amz-date,
// Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024
//
// The credential is in the form of "<your-access-key-id>/<date>/<aws-region>/<aws-service>/aws4_request"
fields := strings.Split(authorizationHeader, " ")
for _, field := range fields {
if strings.HasPrefix(field, "Credential=") {
fields := strings.Split(field, "/")
if len(fields) < 3 {
return "", fmt.Errorf("invalid header format")
}
region := fields[2]
return region, nil
}
}
return "", fmt.Errorf("invalid header format")
}
func stsRegionalEndpoint(region string) (string, error) {
stsService := sts.EndpointsID
resolver := endpoints.DefaultResolver()
resolvedEndpoint, err := resolver.EndpointFor(stsService, region,
endpoints.STSRegionalEndpointOption,
endpoints.StrictMatchingOption)
if err != nil {
return "", fmt.Errorf("unable to get regional STS endpoint for region: %v", region)
}
return resolvedEndpoint.URL, nil
}
const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID" const iamServerIdHeader = "X-Vault-AWS-IAM-Server-ID"
const pathLoginSyn = ` const pathLoginSyn = `

View File

@@ -16,6 +16,8 @@ import (
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -625,6 +627,58 @@ func TestBackend_defaultAliasMetadata(t *testing.T) {
} }
} }
func TestRegionFromHeader(t *testing.T) {
tcs := map[string]struct {
header string
expectedRegion string
expectedSTSEndpoint string
}{
"us-east-1": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-east-1",
expectedSTSEndpoint: "https://sts.us-east-1.amazonaws.com",
},
"us-west-2": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-west-2/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-west-2",
expectedSTSEndpoint: "https://sts.us-west-2.amazonaws.com",
},
"ap-northeast-3": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/ap-northeast-3/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "ap-northeast-3",
expectedSTSEndpoint: "https://sts.ap-northeast-3.amazonaws.com",
},
"us-gov-east-1": {
header: "AWS4-HMAC-SHA256 Credential=AAAAAAAAAAAAAAAAAAAA/20230719/us-gov-east-1/sts/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
expectedRegion: "us-gov-east-1",
expectedSTSEndpoint: "https://sts.us-gov-east-1.amazonaws.com",
},
}
for name, tc := range tcs {
t.Run(name, func(t *testing.T) {
region, err := awsRegionFromHeader(tc.header)
assert.NoError(t, err)
assert.Equal(t, tc.expectedRegion, region)
stsEndpoint, err := stsRegionalEndpoint(region)
assert.NoError(t, err)
assert.Equal(t, tc.expectedSTSEndpoint, stsEndpoint)
})
}
t.Run("invalid-header", func(t *testing.T) {
region, err := awsRegionFromHeader("this-is-an-invalid-header/foobar")
assert.EqualError(t, err, "invalid header format")
assert.Empty(t, region)
})
t.Run("invalid-region", func(t *testing.T) {
endpoint, err := stsRegionalEndpoint("fake-region-1")
assert.EqualError(t, err, "unable to get regional STS endpoint for region: fake-region-1")
assert.Empty(t, endpoint)
})
}
func defaultLoginData() (map[string]interface{}, error) { func defaultLoginData() (map[string]interface{}, error) {
awsSession, err := session.NewSession() awsSession, err := session.NewSession()
if err != nil { if err != nil {

3
changelog/21960.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:improvement
aws/auth: Adds a new config field `use_sts_region_from_client` which allows for using dynamic regional sts endpoints based on Authorization header when using IAM-based authentication.
```

View File

@@ -65,6 +65,10 @@ capabilities, the credentials are fetched automatically.
- `sts_region` `(string: "")` - Region to override the default region for making - `sts_region` `(string: "")` - Region to override the default region for making
AWS STS API calls. Should only be set if `sts_endpoint` is set. If so, should AWS STS API calls. Should only be set if `sts_endpoint` is set. If so, should
be set to the region in which the custom `sts_endpoint` resides. be set to the region in which the custom `sts_endpoint` resides.
- `use_sts_region_from_client` `(boolean: false)` - If set, overrides both `sts_endpoint`
and `sts_region` to instead use the region specified in the client request headers for
IAM-based authentication . This can be useful when you have client requests coming from
different regions and want flexibility in which regional STS API is used.
- `iam_server_id_header_value` `(string: "")` - The value to require in the - `iam_server_id_header_value` `(string: "")` - The value to require in the
`X-Vault-AWS-IAM-Server-ID` header as part of GetCallerIdentity requests that `X-Vault-AWS-IAM-Server-ID` header as part of GetCallerIdentity requests that
are used in the iam auth method. If not set, then no value is required or are used in the iam auth method. If not set, then no value is required or
@@ -123,6 +127,7 @@ $ curl \
"iam_endpoint": "", "iam_endpoint": "",
"sts_endpoint": "", "sts_endpoint": "",
"sts_region": "", "sts_region": "",
"use_sts_region_from_client": false,
"iam_server_id_header_value": "" "iam_server_id_header_value": ""
} }
} }