diff --git a/builtin/credential/aws/backend_test.go b/builtin/credential/aws/backend_test.go index 98446029fd..67c4fb7ce4 100644 --- a/builtin/credential/aws/backend_test.go +++ b/builtin/credential/aws/backend_test.go @@ -1519,7 +1519,7 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) { t.Fatal(err) } if resp == nil || resp.Auth == nil || resp.IsError() { - t.Errorf("bad: expected valid login: resp:%#v", resp) + t.Fatalf("bad: expected valid login: resp:%#v", resp) } renewReq := &logical.Request{ diff --git a/builtin/credential/aws/path_login.go b/builtin/credential/aws/path_login.go index 99edadb74f..66effda567 100644 --- a/builtin/credential/aws/path_login.go +++ b/builtin/credential/aws/path_login.go @@ -10,6 +10,7 @@ import ( "io/ioutil" "net/http" "net/url" + "reflect" "regexp" "strings" "time" @@ -1086,14 +1087,12 @@ func (b *backend) pathLoginUpdateIam( if headersB64 == "" { return logical.ErrorResponse("missing iam_request_headers"), nil } - headersJson, err := base64.StdEncoding.DecodeString(headersB64) + headers, err := parseIamRequestHeaders(headersB64) if err != nil { - return logical.ErrorResponse("failed to base64 decode iam_request_headers"), nil + return logical.ErrorResponse(fmt.Sprintf("Error parsing iam_request_headers: %v", err)), nil } - var headers http.Header - err = jsonutil.DecodeJSON(headersJson, &headers) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("failed to JSON decode iam_request_headers %q: %v", headersJson, err)), nil + if headers == nil { + return logical.ErrorResponse("nil response when parsing iam_request_headers"), nil } config, err := b.lockedClientConfigEntry(req.Storage) @@ -1399,6 +1398,37 @@ func parseGetCallerIdentityResponse(response string) (GetCallerIdentityResponse, return result, err } +func parseIamRequestHeaders(headersB64 string) (http.Header, error) { + headersJson, err := base64.StdEncoding.DecodeString(headersB64) + if err != nil { + return nil, fmt.Errorf("failed to base64 decode iam_request_headers") + } + var headersDecoded map[string]interface{} + err = jsonutil.DecodeJSON(headersJson, &headersDecoded) + if err != nil { + return nil, fmt.Errorf("failed to JSON decode iam_request_headers %q: %v", headersJson, err) + } + headers := make(http.Header) + for k, v := range headersDecoded { + switch typedValue := v.(type) { + case string: + headers.Add(k, typedValue) + case []interface{}: + for _, individualVal := range typedValue { + switch possibleStrVal := individualVal.(type) { + case string: + headers.Add(k, possibleStrVal) + default: + return nil, fmt.Errorf("header %q contains value %q that has type %s, not string", k, individualVal, reflect.TypeOf(individualVal)) + } + } + default: + return nil, fmt.Errorf("header %q value %q has type %s, not string or []interface", k, typedValue, reflect.TypeOf(v)) + } + } + return headers, nil +} + func submitCallerIdentityRequest(method, endpoint string, parsedUrl *url.URL, body string, headers http.Header) (*GetCallerIdentityResult, error) { // NOTE: We need to ensure we're calling STS, instead of acting as an unintended network proxy // The protection against this is that this method will only call the endpoint specified in the diff --git a/builtin/credential/aws/path_login_test.go b/builtin/credential/aws/path_login_test.go index 1a3b32429c..58f253bb16 100644 --- a/builtin/credential/aws/path_login_test.go +++ b/builtin/credential/aws/path_login_test.go @@ -1,8 +1,12 @@ package awsauth import ( + "encoding/base64" + "encoding/json" + "fmt" "net/http" "net/url" + "reflect" "testing" ) @@ -143,3 +147,43 @@ func TestBackend_validateVaultHeaderValue(t *testing.T) { t.Errorf("did NOT validate valid POST request with split Authorization header: %v", err) } } + +func TestBackend_pathLogin_parseIamRequestHeaders(t *testing.T) { + testIamParser := func(headers interface{}, expectedHeaders http.Header) error { + headersJson, err := json.Marshal(headers) + if err != nil { + return fmt.Errorf("unable to JSON encode headers: %v", err) + } + headersB64 := base64.StdEncoding.EncodeToString(headersJson) + + parsedHeaders, err := parseIamRequestHeaders(headersB64) + if err != nil { + return fmt.Errorf("error parsing encoded headers: %v", err) + } + if parsedHeaders == nil { + return fmt.Errorf("nil result from parsing headers") + } + if !reflect.DeepEqual(parsedHeaders, expectedHeaders) { + return fmt.Errorf("parsed headers not equal to input headers") + } + return nil + } + + headersGoStyle := http.Header{ + "Header1": []string{"Value1"}, + "Header2": []string{"Value2"}, + } + headersMixedType := map[string]interface{}{ + "Header1": "Value1", + "Header2": []string{"Value2"}, + } + + err := testIamParser(headersGoStyle, headersGoStyle) + if err != nil { + t.Errorf("error parsing go-style headers: %v", err) + } + err = testIamParser(headersMixedType, headersGoStyle) + if err != nil { + t.Errorf("error parsing mixed-style headers: %v", err) + } +} diff --git a/website/source/docs/auth/aws.html.md b/website/source/docs/auth/aws.html.md index 1c7e260009..19c8c640b9 100644 --- a/website/source/docs/auth/aws.html.md +++ b/website/source/docs/auth/aws.html.md @@ -1872,8 +1872,10 @@ The response will be in JSON. For example:
  • iam_request_headers required - Base64-encoded, JSON-serialized representation of the HTTP request - headers. The JSON serialization assumes that each header key maps to an + Base64-encoded, JSON-serialized representation of the + sts:GetCallerIdentity HTTP request + headers. The JSON serialization assumes that each header key maps to + either a string value or an array of string values (though the length of that array will probably only be one). If the `iam_server_id_header_value` is configured in Vault for the aws auth mount, then the headers must include the