diff --git a/builtin/logical/aws/backend.go b/builtin/logical/aws/backend.go index ed8ac00c9d..b33fb1b4d6 100644 --- a/builtin/logical/aws/backend.go +++ b/builtin/logical/aws/backend.go @@ -141,7 +141,7 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IA return b.iamClient, nil } - iamClient, err := nonCachedClientIAM(ctx, s, b.Logger()) + iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger()) if err != nil { return nil, err } @@ -168,7 +168,7 @@ func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.ST return b.stsClient, nil } - stsClient, err := nonCachedClientSTS(ctx, s, b.Logger()) + stsClient, err := b.nonCachedClientSTS(ctx, s, b.Logger()) if err != nil { return nil, err } diff --git a/builtin/logical/aws/client.go b/builtin/logical/aws/client.go index 33dc86c517..fbc8e80c87 100644 --- a/builtin/logical/aws/client.go +++ b/builtin/logical/aws/client.go @@ -7,19 +7,24 @@ import ( "context" "fmt" "os" + "strconv" + "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/sts" - cleanhttp "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/awsutil" + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" ) // NOTE: The caller is required to ensure that b.clientMutex is at least read locked -func getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) { +func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) { credsConfig := &awsutil.CredentialsConfig{} var endpoint string var maxRetries int = aws.UseServiceDefaultRetries @@ -44,6 +49,27 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo case clientType == "sts" && config.STSEndpoint != "": endpoint = *aws.String(config.STSEndpoint) } + + if config.IdentityTokenAudience != "" { + ns, err := namespace.FromContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get namespace from context: %w", err) + } + + fetcher := &PluginIdentityTokenFetcher{ + sys: b.System(), + logger: b.Logger(), + key: config.IdentityTokenKey, + audience: config.IdentityTokenAudience, + ns: ns, + ttl: time.Duration(config.IdentityTokenTTLSeconds) * time.Second, + } + + sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10) + credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix) + credsConfig.WebIdentityTokenFetcher = fetcher + credsConfig.RoleARN = config.IdentityTokenRoleARN + } } if credsConfig.Region == "" { @@ -74,8 +100,8 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo }, nil } -func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) { - awsConfig, err := getRootConfig(ctx, s, "iam", logger) +func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) { + awsConfig, err := b.getRootConfig(ctx, s, "iam", logger) if err != nil { return nil, err } @@ -90,8 +116,8 @@ func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Log return client, nil } -func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) { - awsConfig, err := getRootConfig(ctx, s, "sts", logger) +func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) { + awsConfig, err := b.getRootConfig(ctx, s, "sts", logger) if err != nil { return nil, err } @@ -105,3 +131,38 @@ func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Log } return client, nil } + +// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided +// to the AWS SDK client to keep assumed role credentials refreshed through expiration. +// When the client's STS credentials expire, it will use this interface to fetch a new +// plugin identity token and exchange it for new STS credentials. +type PluginIdentityTokenFetcher struct { + sys logical.SystemView + logger hclog.Logger + key string + audience string + ns *namespace.Namespace + ttl time.Duration +} + +var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil) + +func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) { + nsCtx := namespace.ContextWithNamespace(ctx, f.ns) + resp, err := f.sys.GenerateIdentityToken(nsCtx, pluginutil.IdentityTokenRequest{ + Key: f.key, + Audience: f.audience, + TTL: f.ttl, + }) + if err != nil { + return nil, fmt.Errorf("failed to generate plugin identity token: %w", err) + } + f.logger.Info("fetched new plugin identity token") + + if resp.TTL < f.ttl { + f.logger.Debug("generated plugin identity token has shorter TTL than requested", + "requested", f.ttl.Seconds(), "actual", resp.TTL) + } + + return []byte(resp.Token), nil +} diff --git a/builtin/logical/aws/path_config_root.go b/builtin/logical/aws/path_config_root.go index 5b5e3f1ce6..947f27d073 100644 --- a/builtin/logical/aws/path_config_root.go +++ b/builtin/logical/aws/path_config_root.go @@ -54,6 +54,38 @@ func pathConfigRoot(b *backend) *framework.Path { Type: framework.TypeString, Description: "Template to generate custom IAM usernames", }, + "identity_token_audience": { + Type: framework.TypeString, + Description: "", + Default: "", + DisplayAttrs: &framework.DisplayAttributes{ + Name: "", + }, + }, + "identity_token_key": { + Type: framework.TypeString, + Description: "", + Default: "", + DisplayAttrs: &framework.DisplayAttributes{ + Name: "", + }, + }, + "identity_token_role_arn": { + Type: framework.TypeString, + Description: "", + Default: "", + DisplayAttrs: &framework.DisplayAttributes{ + Name: "", + }, + }, + "identity_token_ttl": { + Type: framework.TypeDurationSecond, + Description: "", + DisplayAttrs: &framework.DisplayAttributes{ + Name: "", + }, + Default: 3600, + }, }, Operations: map[logical.Operation]framework.OperationHandler{ @@ -118,17 +150,26 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, usernameTemplate = defaultUserNameTemplate } + identityTokenAudience := data.Get("identity_token_audience").(string) + identityTokenKey := data.Get("identity_token_key").(string) + identityTokenTTL := data.Get("identity_token_ttl").(int) + identityTokenRoleARN := data.Get("identity_token_role_arn").(string) + b.clientMutex.Lock() defer b.clientMutex.Unlock() entry, err := logical.StorageEntryJSON("config/root", rootConfig{ - AccessKey: data.Get("access_key").(string), - SecretKey: data.Get("secret_key").(string), - IAMEndpoint: iamendpoint, - STSEndpoint: stsendpoint, - Region: region, - MaxRetries: maxretries, - UsernameTemplate: usernameTemplate, + AccessKey: data.Get("access_key").(string), + SecretKey: data.Get("secret_key").(string), + IAMEndpoint: iamendpoint, + STSEndpoint: stsendpoint, + Region: region, + MaxRetries: maxretries, + UsernameTemplate: usernameTemplate, + IdentityTokenRoleARN: identityTokenRoleARN, + IdentityTokenAudience: identityTokenAudience, + IdentityTokenKey: identityTokenKey, + IdentityTokenTTLSeconds: identityTokenTTL, }) if err != nil { return nil, err @@ -147,13 +188,17 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, } type rootConfig struct { - AccessKey string `json:"access_key"` - SecretKey string `json:"secret_key"` - IAMEndpoint string `json:"iam_endpoint"` - STSEndpoint string `json:"sts_endpoint"` - Region string `json:"region"` - MaxRetries int `json:"max_retries"` - UsernameTemplate string `json:"username_template"` + AccessKey string `json:"access_key"` + SecretKey string `json:"secret_key"` + IAMEndpoint string `json:"iam_endpoint"` + STSEndpoint string `json:"sts_endpoint"` + Region string `json:"region"` + MaxRetries int `json:"max_retries"` + UsernameTemplate string `json:"username_template"` + IdentityTokenKey string `json:"identity_token_key"` + IdentityTokenTTLSeconds int `json:"identity_token_ttl_seconds"` + IdentityTokenAudience string `json:"identity_token_audience"` + IdentityTokenRoleARN string `json:"identity_token_role_arn"` } const pathConfigRootHelpSyn = ` diff --git a/go.mod b/go.mod index 082004c95a..5b3a6e16a7 100644 --- a/go.mod +++ b/go.mod @@ -100,7 +100,7 @@ require ( github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a github.com/hashicorp/go-retryablehttp v0.7.4 github.com/hashicorp/go-rootcerts v1.0.2 - github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3 + github.com/hashicorp/go-secure-stdlib/awsutil v0.2.4-0.20231108055638-37911e265025 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-secure-stdlib/gatedwriter v0.1.1 github.com/hashicorp/go-secure-stdlib/kv-builder v0.1.2 diff --git a/go.sum b/go.sum index 7114846a0f..db4b9fb913 100644 --- a/go.sum +++ b/go.sum @@ -2026,8 +2026,8 @@ github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5 github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= -github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3 h1:AAQ6Vmo/ncfrZYtbpjhO+g0Qt+iNpYtl3UWT1NLmbYY= -github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg= +github.com/hashicorp/go-secure-stdlib/awsutil v0.2.4-0.20231108055638-37911e265025 h1:aBt1QxQZxZ8bfXlvpYSaL96sho9okKcCt8WtooD1ONk= +github.com/hashicorp/go-secure-stdlib/awsutil v0.2.4-0.20231108055638-37911e265025/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg= github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng= github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw= diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index d7a85c9df6..88813f716d 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1818,6 +1818,15 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag } } + // Always add the default key + defaultKeyIDs, err := i.keyIDsByName(ctx, s, defaultKeyName) + if err != nil { + return nil, fmt.Errorf("failed to load default key IDs: %w", err) + } + for _, id := range defaultKeyIDs { + keyIDs[id] = struct{}{} + } + jwks := &jose.JSONWebKeySet{ Keys: make([]jose.JSONWebKey, 0, len(keyIDs)), }