VAULT-32804: Add STS Fallback parameters to secrets-aws engine (#29051)

Co-authored-by: Sarah Chavis <62406755+schavis@users.noreply.github.com>

---------

Co-authored-by: Robert <17119716+robmonte@users.noreply.github.com>
Co-authored-by: Sarah Chavis <62406755+schavis@users.noreply.github.com>
This commit is contained in:
kpcraig
2024-12-05 16:22:21 -05:00
committed by GitHub
parent d515cd33b0
commit d8482b008a
5 changed files with 337 additions and 96 deletions

View File

@@ -5,6 +5,7 @@ package aws
import (
"context"
"errors"
"fmt"
"os"
"strconv"
@@ -23,91 +24,139 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)
// Return a slice of *aws.Config, based on descending configuration priority. STS endpoints are the only place this is used.
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{}
var endpoint string
var maxRetries int = aws.UseServiceDefaultRetries
func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) ([]*aws.Config, error) {
// set fallback region (we can overwrite later)
fallbackRegion := os.Getenv("AWS_REGION")
if fallbackRegion == "" {
fallbackRegion = os.Getenv("AWS_DEFAULT_REGION")
}
if fallbackRegion == "" {
fallbackRegion = "us-east-1"
}
maxRetries := aws.UseServiceDefaultRetries
entry, err := s.Get(ctx, "config/root")
if err != nil {
return nil, err
}
if entry != nil {
var config rootConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, fmt.Errorf("error reading root configuration: %w", err)
var configs []*aws.Config
// ensure the nil case uses defaults
if entry == nil {
ccfg := awsutil.CredentialsConfig{
HTTPClient: cleanhttp.DefaultClient(),
Logger: logger,
Region: fallbackRegion,
}
credsConfig.AccessKey = config.AccessKey
credsConfig.SecretKey = config.SecretKey
credsConfig.Region = config.Region
maxRetries = config.MaxRetries
switch {
case clientType == "iam" && config.IAMEndpoint != "":
endpoint = *aws.String(config.IAMEndpoint)
case clientType == "sts" && config.STSEndpoint != "":
endpoint = *aws.String(config.STSEndpoint)
if config.STSRegion != "" {
credsConfig.Region = config.STSRegion
}
creds, err := ccfg.GenerateCredentialChain()
if err != nil {
return nil, err
}
configs = append(configs, &aws.Config{
Credentials: creds,
Region: aws.String(fallbackRegion),
Endpoint: aws.String(""),
MaxRetries: aws.Int(maxRetries),
})
if config.IdentityTokenAudience != "" {
ns, err := namespace.FromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
}
fetcher := &PluginIdentityTokenFetcher{
sys: b.System(),
logger: b.Logger(),
ns: ns,
audience: config.IdentityTokenAudience,
ttl: config.IdentityTokenTTL,
}
sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
credsConfig.WebIdentityTokenFetcher = fetcher
credsConfig.RoleARN = config.RoleARN
}
return configs, nil
}
if credsConfig.Region == "" {
credsConfig.Region = os.Getenv("AWS_REGION")
if credsConfig.Region == "" {
credsConfig.Region = os.Getenv("AWS_DEFAULT_REGION")
if credsConfig.Region == "" {
credsConfig.Region = "us-east-1"
}
}
var config rootConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, fmt.Errorf("error reading root configuration: %w", err)
}
var endpoints []string
var regions []string
credsConfig := &awsutil.CredentialsConfig{}
credsConfig.AccessKey = config.AccessKey
credsConfig.SecretKey = config.SecretKey
credsConfig.HTTPClient = cleanhttp.DefaultClient()
credsConfig.Logger = logger
creds, err := credsConfig.GenerateCredentialChain()
if err != nil {
return nil, err
maxRetries = config.MaxRetries
if clientType == "iam" && config.IAMEndpoint != "" {
endpoints = append(endpoints, config.IAMEndpoint)
} else if clientType == "sts" && config.STSEndpoint != "" {
endpoints = append(endpoints, config.STSEndpoint)
if config.STSRegion != "" {
regions = append(regions, config.STSRegion)
}
if len(config.STSFallbackEndpoints) > 0 {
endpoints = append(endpoints, config.STSFallbackEndpoints...)
}
if len(config.STSFallbackRegions) > 0 {
regions = append(regions, config.STSFallbackRegions...)
}
}
return &aws.Config{
Credentials: creds,
Region: aws.String(credsConfig.Region),
Endpoint: &endpoint,
HTTPClient: cleanhttp.DefaultClient(),
MaxRetries: aws.Int(maxRetries),
}, nil
if config.IdentityTokenAudience != "" {
ns, err := namespace.FromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
}
fetcher := &PluginIdentityTokenFetcher{
sys: b.System(),
logger: b.Logger(),
ns: ns,
audience: config.IdentityTokenAudience,
ttl: config.IdentityTokenTTL,
}
sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
credsConfig.WebIdentityTokenFetcher = fetcher
credsConfig.RoleARN = config.RoleARN
}
if len(regions) == 0 {
regions = append(regions, fallbackRegion)
}
if len(regions) != len(endpoints) {
// this probably can't happen, if the input was checked correctly
return nil, errors.New("number of regions does not match number of endpoints")
}
for i := 0; i < len(endpoints); i++ {
if len(regions) > i {
credsConfig.Region = regions[i]
} else {
credsConfig.Region = fallbackRegion
}
creds, err := credsConfig.GenerateCredentialChain()
if err != nil {
return nil, err
}
configs = append(configs, &aws.Config{
Credentials: creds,
Region: aws.String(credsConfig.Region),
Endpoint: aws.String(endpoints[i]),
MaxRetries: aws.Int(maxRetries),
HTTPClient: cleanhttp.DefaultClient(),
})
}
return configs, nil
}
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := b.getRootConfig(ctx, s, "iam", logger)
awsConfig, err := b.getRootConfigs(ctx, s, "iam", logger)
if err != nil {
return nil, err
}
sess, err := session.NewSession(awsConfig)
if len(awsConfig) != 1 {
return nil, errors.New("could not obtain aws config")
}
sess, err := session.NewSession(awsConfig[0])
if err != nil {
return nil, err
}
@@ -119,19 +168,33 @@ func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, log
}
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
awsConfig, err := b.getRootConfig(ctx, s, "sts", logger)
awsConfig, err := b.getRootConfigs(ctx, s, "sts", logger)
if err != nil {
return nil, err
}
sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, err
var client *sts.STS
for _, cfg := range awsConfig {
sess, err := session.NewSession(cfg)
if err != nil {
return nil, err
}
client = sts.New(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain sts client")
}
// ping the client - we only care about errors
_, err = client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return client, nil
} else {
b.Logger().Debug("couldn't connect with config trying next", "failed endpoint", cfg.Endpoint, "failed region", cfg.Region)
}
}
client := sts.New(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain sts client")
}
return client, nil
return nil, fmt.Errorf("could not obtain sts client")
}
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided

View File

@@ -52,6 +52,14 @@ func pathConfigRoot(b *backend) *framework.Path {
Type: framework.TypeString,
Description: "Specific region for STS API calls.",
},
"sts_fallback_endpoints": {
Type: framework.TypeCommaStringSlice,
Description: "Fallback endpoints if sts_endpoint is unreachable",
},
"sts_fallback_regions": {
Type: framework.TypeCommaStringSlice,
Description: "Fallback regions if sts_region is unreachable",
},
"max_retries": {
Type: framework.TypeInt,
Default: aws.UseServiceDefaultRetries,
@@ -110,14 +118,16 @@ func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request,
}
configData := map[string]interface{}{
"access_key": config.AccessKey,
"region": config.Region,
"iam_endpoint": config.IAMEndpoint,
"sts_endpoint": config.STSEndpoint,
"sts_region": config.STSRegion,
"max_retries": config.MaxRetries,
"username_template": config.UsernameTemplate,
"role_arn": config.RoleARN,
"access_key": config.AccessKey,
"region": config.Region,
"iam_endpoint": config.IAMEndpoint,
"sts_endpoint": config.STSEndpoint,
"sts_region": config.STSRegion,
"sts_fallback_endpoints": config.STSFallbackEndpoints,
"sts_fallback_regions": config.STSFallbackRegions,
"max_retries": config.MaxRetries,
"username_template": config.UsernameTemplate,
"role_arn": config.RoleARN,
}
config.PopulatePluginIdentityTokenData(configData)
@@ -138,19 +148,28 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
usernameTemplate = defaultUserNameTemplate
}
stsFallbackEndpoints := data.Get("sts_fallback_endpoints").([]string)
stsFallbackRegions := data.Get("sts_fallback_regions").([]string)
if len(stsFallbackEndpoints) != len(stsFallbackRegions) {
return logical.ErrorResponse("fallback endpoints and fallback regions must be the same length"), nil
}
b.clientMutex.Lock()
defer b.clientMutex.Unlock()
rc := rootConfig{
AccessKey: data.Get("access_key").(string),
SecretKey: data.Get("secret_key").(string),
IAMEndpoint: iamendpoint,
STSEndpoint: stsendpoint,
STSRegion: stsregion,
Region: region,
MaxRetries: maxretries,
UsernameTemplate: usernameTemplate,
RoleARN: roleARN,
AccessKey: data.Get("access_key").(string),
SecretKey: data.Get("secret_key").(string),
IAMEndpoint: iamendpoint,
STSEndpoint: stsendpoint,
STSRegion: stsregion,
STSFallbackEndpoints: stsFallbackEndpoints,
STSFallbackRegions: stsFallbackRegions,
Region: region,
MaxRetries: maxretries,
UsernameTemplate: usernameTemplate,
RoleARN: roleARN,
}
if err := rc.ParsePluginIdentityTokenFields(data); err != nil {
return logical.ErrorResponse(err.Error()), nil
@@ -196,15 +215,17 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
type rootConfig struct {
pluginidentityutil.PluginIdentityTokenParams
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
Region string `json:"region"`
MaxRetries int `json:"max_retries"`
UsernameTemplate string `json:"username_template"`
RoleARN string `json:"role_arn"`
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
IAMEndpoint string `json:"iam_endpoint"`
STSEndpoint string `json:"sts_endpoint"`
STSRegion string `json:"sts_region"`
STSFallbackEndpoints []string `json:"sts_fallback_endpoints"`
STSFallbackRegions []string `json:"sts_fallback_regions"`
Region string `json:"region"`
MaxRetries int `json:"max_retries"`
UsernameTemplate string `json:"username_template"`
RoleARN string `json:"role_arn"`
}
const pathConfigRootHelpSyn = `

View File

@@ -31,6 +31,8 @@ func TestBackend_PathConfigRoot(t *testing.T) {
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"sts_region": "",
"sts_fallback_endpoints": []string{},
"sts_fallback_regions": []string{},
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
@@ -66,6 +68,152 @@ func TestBackend_PathConfigRoot(t *testing.T) {
}
}
// TestBackend_PathConfigRoot_STSFallback tests valid versions of STS fallback parameters - slice and csv
func TestBackend_PathConfigRoot_STSFallback(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = &testSystemView{}
b := Backend(config)
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"sts_region": "",
"sts_fallback_endpoints": []string{"192.168.1.1", "127.0.0.1"},
"sts_fallback_regions": []string{"my-house-1", "my-house-2"},
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
"identity_token_audience": "",
"identity_token_ttl": int64(0),
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err := b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Storage: config.StorageView,
Path: "config/root",
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err)
}
delete(configData, "secret_key")
require.Equal(t, configData, resp.Data)
if !reflect.DeepEqual(resp.Data, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}
// test we can handle comma separated strings, per CommaStringSlice
configData = map[string]interface{}{
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"sts_region": "",
"sts_fallback_endpoints": "1.1.1.1,8.8.8.8",
"sts_fallback_regions": "zone-1,zone-2",
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
"identity_token_audience": "",
"identity_token_ttl": int64(0),
}
configReq = &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Storage: config.StorageView,
Path: "config/root",
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err)
}
delete(configData, "secret_key")
configData["sts_fallback_endpoints"] = []string{"1.1.1.1", "8.8.8.8"}
configData["sts_fallback_regions"] = []string{"zone-1", "zone-2"}
require.Equal(t, configData, resp.Data)
if !reflect.DeepEqual(resp.Data, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}
}
// TestBackend_PathConfigRoot_STSFallback_mismatchedfallback ensures configuration writing will fail if the
// region/endpoint entries are different lengths
func TestBackend_PathConfigRoot_STSFallback_mismatchedfallback(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = &testSystemView{}
b := Backend(config)
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}
// test we can handle comma separated strings, per CommaStringSlice
configData := map[string]interface{}{
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"sts_region": "",
"sts_fallback_endpoints": "1.1.1.1,8.8.8.8",
"sts_fallback_regions": "zone-1,zone-2",
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
"identity_token_audience": "",
"identity_token_ttl": int64(0),
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err := b.HandleRequest(context.Background(), configReq)
if err != nil {
t.Fatalf("bad: config writing failed: err: %v", err)
}
if resp != nil && !resp.IsError() {
t.Fatalf("expected an error, but it successfully wrote")
}
}
// TestBackend_PathConfigRoot_PluginIdentityToken tests that configuration
// of plugin WIF returns an immediate error.
func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) {

3
changelog/29051.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:improvement
secrets/aws: add fallback endpoint and region parameters to sts configuration
```

View File

@@ -79,6 +79,12 @@ valid AWS credentials with proper permissions.
- `sts_endpoint` `(string: <optional>)`  Specifies a custom HTTP STS endpoint to use.
- `sts_region` `(string: <optional>)` - Specifies a custom STS region to use (should match `sts_endpoint`)
- `sts_fallback_endpoints` `(list: <optional>)` - Specifies an ordered list of fallback STS endpoints to use
- `sts_fallback_regions` `(list: <optional>)` - Specifies an ordered list of fallback STS regions to use (should match fallback endpoints)
- `username_template` `(string: <optional>)` - [Template](/vault/docs/concepts/username-templating) describing how
dynamic usernames are generated. The username template is used to generate both IAM usernames (capped at 64 characters)
and STS usernames (capped at 32 characters). Longer usernames result in a 500 error.