VAULT-32804: Add STS Fallback parameters to secrets-aws engine (#29051)

Co-authored-by: Sarah Chavis <62406755+schavis@users.noreply.github.com>

---------

Co-authored-by: Robert <17119716+robmonte@users.noreply.github.com>
Co-authored-by: Sarah Chavis <62406755+schavis@users.noreply.github.com>
This commit is contained in:
kpcraig
2024-12-05 16:22:21 -05:00
committed by GitHub
parent d515cd33b0
commit d8482b008a
5 changed files with 337 additions and 96 deletions

View File

@@ -5,6 +5,7 @@ package aws
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
@@ -23,33 +24,76 @@ import (
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
// Return a slice of *aws.Config, based on descending configuration priority. STS endpoints are the only place this is used.
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked // NOTE: The caller is required to ensure that b.clientMutex is at least read locked
func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) { func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) ([]*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{} // set fallback region (we can overwrite later)
var endpoint string fallbackRegion := os.Getenv("AWS_REGION")
var maxRetries int = aws.UseServiceDefaultRetries if fallbackRegion == "" {
fallbackRegion = os.Getenv("AWS_DEFAULT_REGION")
}
if fallbackRegion == "" {
fallbackRegion = "us-east-1"
}
maxRetries := aws.UseServiceDefaultRetries
entry, err := s.Get(ctx, "config/root") entry, err := s.Get(ctx, "config/root")
if err != nil { if err != nil {
return nil, err return nil, err
} }
if entry != nil { var configs []*aws.Config
// ensure the nil case uses defaults
if entry == nil {
ccfg := awsutil.CredentialsConfig{
HTTPClient: cleanhttp.DefaultClient(),
Logger: logger,
Region: fallbackRegion,
}
creds, err := ccfg.GenerateCredentialChain()
if err != nil {
return nil, err
}
configs = append(configs, &aws.Config{
Credentials: creds,
Region: aws.String(fallbackRegion),
Endpoint: aws.String(""),
MaxRetries: aws.Int(maxRetries),
})
return configs, nil
}
var config rootConfig var config rootConfig
if err := entry.DecodeJSON(&config); err != nil { if err := entry.DecodeJSON(&config); err != nil {
return nil, fmt.Errorf("error reading root configuration: %w", err) return nil, fmt.Errorf("error reading root configuration: %w", err)
} }
var endpoints []string
var regions []string
credsConfig := &awsutil.CredentialsConfig{}
credsConfig.AccessKey = config.AccessKey credsConfig.AccessKey = config.AccessKey
credsConfig.SecretKey = config.SecretKey credsConfig.SecretKey = config.SecretKey
credsConfig.Region = config.Region credsConfig.HTTPClient = cleanhttp.DefaultClient()
credsConfig.Logger = logger
maxRetries = config.MaxRetries maxRetries = config.MaxRetries
switch { if clientType == "iam" && config.IAMEndpoint != "" {
case clientType == "iam" && config.IAMEndpoint != "": endpoints = append(endpoints, config.IAMEndpoint)
endpoint = *aws.String(config.IAMEndpoint) } else if clientType == "sts" && config.STSEndpoint != "" {
case clientType == "sts" && config.STSEndpoint != "": endpoints = append(endpoints, config.STSEndpoint)
endpoint = *aws.String(config.STSEndpoint)
if config.STSRegion != "" { if config.STSRegion != "" {
credsConfig.Region = config.STSRegion regions = append(regions, config.STSRegion)
}
if len(config.STSFallbackEndpoints) > 0 {
endpoints = append(endpoints, config.STSFallbackEndpoints...)
}
if len(config.STSFallbackRegions) > 0 {
regions = append(regions, config.STSFallbackRegions...)
} }
} }
@@ -72,42 +116,47 @@ func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientTy
credsConfig.WebIdentityTokenFetcher = fetcher credsConfig.WebIdentityTokenFetcher = fetcher
credsConfig.RoleARN = config.RoleARN credsConfig.RoleARN = config.RoleARN
} }
if len(regions) == 0 {
regions = append(regions, fallbackRegion)
} }
if credsConfig.Region == "" { if len(regions) != len(endpoints) {
credsConfig.Region = os.Getenv("AWS_REGION") // this probably can't happen, if the input was checked correctly
if credsConfig.Region == "" { return nil, errors.New("number of regions does not match number of endpoints")
credsConfig.Region = os.Getenv("AWS_DEFAULT_REGION")
if credsConfig.Region == "" {
credsConfig.Region = "us-east-1"
}
}
} }
credsConfig.HTTPClient = cleanhttp.DefaultClient() for i := 0; i < len(endpoints); i++ {
if len(regions) > i {
credsConfig.Logger = logger credsConfig.Region = regions[i]
} else {
credsConfig.Region = fallbackRegion
}
creds, err := credsConfig.GenerateCredentialChain() creds, err := credsConfig.GenerateCredentialChain()
if err != nil { if err != nil {
return nil, err return nil, err
} }
configs = append(configs, &aws.Config{
return &aws.Config{
Credentials: creds, Credentials: creds,
Region: aws.String(credsConfig.Region), Region: aws.String(credsConfig.Region),
Endpoint: &endpoint, Endpoint: aws.String(endpoints[i]),
HTTPClient: cleanhttp.DefaultClient(),
MaxRetries: aws.Int(maxRetries), MaxRetries: aws.Int(maxRetries),
}, nil HTTPClient: cleanhttp.DefaultClient(),
})
}
return configs, nil
} }
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) { func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := b.getRootConfig(ctx, s, "iam", logger) awsConfig, err := b.getRootConfigs(ctx, s, "iam", logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sess, err := session.NewSession(awsConfig) if len(awsConfig) != 1 {
return nil, errors.New("could not obtain aws config")
}
sess, err := session.NewSession(awsConfig[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -119,19 +168,33 @@ func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, log
} }
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) { func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
awsConfig, err := b.getRootConfig(ctx, s, "sts", logger) awsConfig, err := b.getRootConfigs(ctx, s, "sts", logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sess, err := session.NewSession(awsConfig)
var client *sts.STS
for _, cfg := range awsConfig {
sess, err := session.NewSession(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
client := sts.New(sess) client = sts.New(sess)
if client == nil { if client == nil {
return nil, fmt.Errorf("could not obtain sts client") return nil, fmt.Errorf("could not obtain sts client")
} }
// ping the client - we only care about errors
_, err = client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return client, nil return client, nil
} else {
b.Logger().Debug("couldn't connect with config trying next", "failed endpoint", cfg.Endpoint, "failed region", cfg.Region)
}
}
return nil, fmt.Errorf("could not obtain sts client")
} }
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided // PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided

View File

@@ -52,6 +52,14 @@ func pathConfigRoot(b *backend) *framework.Path {
Type: framework.TypeString, Type: framework.TypeString,
Description: "Specific region for STS API calls.", Description: "Specific region for STS API calls.",
}, },
"sts_fallback_endpoints": {
Type: framework.TypeCommaStringSlice,
Description: "Fallback endpoints if sts_endpoint is unreachable",
},
"sts_fallback_regions": {
Type: framework.TypeCommaStringSlice,
Description: "Fallback regions if sts_region is unreachable",
},
"max_retries": { "max_retries": {
Type: framework.TypeInt, Type: framework.TypeInt,
Default: aws.UseServiceDefaultRetries, Default: aws.UseServiceDefaultRetries,
@@ -115,6 +123,8 @@ func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request,
"iam_endpoint": config.IAMEndpoint, "iam_endpoint": config.IAMEndpoint,
"sts_endpoint": config.STSEndpoint, "sts_endpoint": config.STSEndpoint,
"sts_region": config.STSRegion, "sts_region": config.STSRegion,
"sts_fallback_endpoints": config.STSFallbackEndpoints,
"sts_fallback_regions": config.STSFallbackRegions,
"max_retries": config.MaxRetries, "max_retries": config.MaxRetries,
"username_template": config.UsernameTemplate, "username_template": config.UsernameTemplate,
"role_arn": config.RoleARN, "role_arn": config.RoleARN,
@@ -138,6 +148,13 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
usernameTemplate = defaultUserNameTemplate usernameTemplate = defaultUserNameTemplate
} }
stsFallbackEndpoints := data.Get("sts_fallback_endpoints").([]string)
stsFallbackRegions := data.Get("sts_fallback_regions").([]string)
if len(stsFallbackEndpoints) != len(stsFallbackRegions) {
return logical.ErrorResponse("fallback endpoints and fallback regions must be the same length"), nil
}
b.clientMutex.Lock() b.clientMutex.Lock()
defer b.clientMutex.Unlock() defer b.clientMutex.Unlock()
@@ -147,6 +164,8 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
IAMEndpoint: iamendpoint, IAMEndpoint: iamendpoint,
STSEndpoint: stsendpoint, STSEndpoint: stsendpoint,
STSRegion: stsregion, STSRegion: stsregion,
STSFallbackEndpoints: stsFallbackEndpoints,
STSFallbackRegions: stsFallbackRegions,
Region: region, Region: region,
MaxRetries: maxretries, MaxRetries: maxretries,
UsernameTemplate: usernameTemplate, UsernameTemplate: usernameTemplate,
@@ -201,6 +220,8 @@ type rootConfig struct {
IAMEndpoint string `json:"iam_endpoint"` IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"` STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"` STSRegion string `json:"sts_region"`
STSFallbackEndpoints []string `json:"sts_fallback_endpoints"`
STSFallbackRegions []string `json:"sts_fallback_regions"`
Region string `json:"region"` Region string `json:"region"`
MaxRetries int `json:"max_retries"` MaxRetries int `json:"max_retries"`
UsernameTemplate string `json:"username_template"` UsernameTemplate string `json:"username_template"`

View File

@@ -31,6 +31,8 @@ func TestBackend_PathConfigRoot(t *testing.T) {
"iam_endpoint": "https://iam.amazonaws.com", "iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com", "sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"sts_region": "", "sts_region": "",
"sts_fallback_endpoints": []string{},
"sts_fallback_regions": []string{},
"max_retries": 10, "max_retries": 10,
"username_template": defaultUserNameTemplate, "username_template": defaultUserNameTemplate,
"role_arn": "", "role_arn": "",
@@ -66,6 +68,152 @@ func TestBackend_PathConfigRoot(t *testing.T) {
} }
} }
// TestBackend_PathConfigRoot_STSFallback tests valid versions of STS fallback parameters - slice and csv
func TestBackend_PathConfigRoot_STSFallback(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = &testSystemView{}
b := Backend(config)
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"sts_region": "",
"sts_fallback_endpoints": []string{"192.168.1.1", "127.0.0.1"},
"sts_fallback_regions": []string{"my-house-1", "my-house-2"},
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
"identity_token_audience": "",
"identity_token_ttl": int64(0),
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err := b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Storage: config.StorageView,
Path: "config/root",
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err)
}
delete(configData, "secret_key")
require.Equal(t, configData, resp.Data)
if !reflect.DeepEqual(resp.Data, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}
// test we can handle comma separated strings, per CommaStringSlice
configData = map[string]interface{}{
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"sts_region": "",
"sts_fallback_endpoints": "1.1.1.1,8.8.8.8",
"sts_fallback_regions": "zone-1,zone-2",
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
"identity_token_audience": "",
"identity_token_ttl": int64(0),
}
configReq = &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Storage: config.StorageView,
Path: "config/root",
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err)
}
delete(configData, "secret_key")
configData["sts_fallback_endpoints"] = []string{"1.1.1.1", "8.8.8.8"}
configData["sts_fallback_regions"] = []string{"zone-1", "zone-2"}
require.Equal(t, configData, resp.Data)
if !reflect.DeepEqual(resp.Data, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}
}
// TestBackend_PathConfigRoot_STSFallback_mismatchedfallback ensures configuration writing will fail if the
// region/endpoint entries are different lengths
func TestBackend_PathConfigRoot_STSFallback_mismatchedfallback(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = &testSystemView{}
b := Backend(config)
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}
// test we can handle comma separated strings, per CommaStringSlice
configData := map[string]interface{}{
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"sts_region": "",
"sts_fallback_endpoints": "1.1.1.1,8.8.8.8",
"sts_fallback_regions": "zone-1,zone-2",
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
"identity_token_audience": "",
"identity_token_ttl": int64(0),
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err := b.HandleRequest(context.Background(), configReq)
if err != nil {
t.Fatalf("bad: config writing failed: err: %v", err)
}
if resp != nil && !resp.IsError() {
t.Fatalf("expected an error, but it successfully wrote")
}
}
// TestBackend_PathConfigRoot_PluginIdentityToken tests that configuration // TestBackend_PathConfigRoot_PluginIdentityToken tests that configuration
// of plugin WIF returns an immediate error. // of plugin WIF returns an immediate error.
func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) { func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) {

3
changelog/29051.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:improvement
secrets/aws: add fallback endpoint and region parameters to sts configuration
```

View File

@@ -79,6 +79,12 @@ valid AWS credentials with proper permissions.
- `sts_endpoint` `(string: <optional>)`  Specifies a custom HTTP STS endpoint to use. - `sts_endpoint` `(string: <optional>)`  Specifies a custom HTTP STS endpoint to use.
- `sts_region` `(string: <optional>)` - Specifies a custom STS region to use (should match `sts_endpoint`)
- `sts_fallback_endpoints` `(list: <optional>)` - Specifies an ordered list of fallback STS endpoints to use
- `sts_fallback_regions` `(list: <optional>)` - Specifies an ordered list of fallback STS regions to use (should match fallback endpoints)
- `username_template` `(string: <optional>)` - [Template](/vault/docs/concepts/username-templating) describing how - `username_template` `(string: <optional>)` - [Template](/vault/docs/concepts/username-templating) describing how
dynamic usernames are generated. The username template is used to generate both IAM usernames (capped at 64 characters) dynamic usernames are generated. The username template is used to generate both IAM usernames (capped at 64 characters)
and STS usernames (capped at 32 characters). Longer usernames result in a 500 error. and STS usernames (capped at 32 characters). Longer usernames result in a 500 error.