mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	adds plugin identity exchange for AWS secrets engine
This commit is contained in:
		@@ -141,7 +141,7 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IA
 | 
			
		||||
		return b.iamClient, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	iamClient, err := nonCachedClientIAM(ctx, s, b.Logger())
 | 
			
		||||
	iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -168,7 +168,7 @@ func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.ST
 | 
			
		||||
		return b.stsClient, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	stsClient, err := nonCachedClientSTS(ctx, s, b.Logger())
 | 
			
		||||
	stsClient, err := b.nonCachedClientSTS(ctx, s, b.Logger())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,19 +7,24 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/aws/aws-sdk-go/aws"
 | 
			
		||||
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
 | 
			
		||||
	"github.com/aws/aws-sdk-go/aws/session"
 | 
			
		||||
	"github.com/aws/aws-sdk-go/service/iam"
 | 
			
		||||
	"github.com/aws/aws-sdk-go/service/sts"
 | 
			
		||||
	cleanhttp "github.com/hashicorp/go-cleanhttp"
 | 
			
		||||
	"github.com/hashicorp/go-cleanhttp"
 | 
			
		||||
	"github.com/hashicorp/go-hclog"
 | 
			
		||||
	"github.com/hashicorp/go-secure-stdlib/awsutil"
 | 
			
		||||
	"github.com/hashicorp/vault/helper/namespace"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/helper/pluginutil"
 | 
			
		||||
	"github.com/hashicorp/vault/sdk/logical"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
 | 
			
		||||
func getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) {
 | 
			
		||||
func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) {
 | 
			
		||||
	credsConfig := &awsutil.CredentialsConfig{}
 | 
			
		||||
	var endpoint string
 | 
			
		||||
	var maxRetries int = aws.UseServiceDefaultRetries
 | 
			
		||||
@@ -44,6 +49,27 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo
 | 
			
		||||
		case clientType == "sts" && config.STSEndpoint != "":
 | 
			
		||||
			endpoint = *aws.String(config.STSEndpoint)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if config.IdentityTokenAudience != "" {
 | 
			
		||||
			ns, err := namespace.FromContext(ctx)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, fmt.Errorf("failed to get namespace from context: %w", err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			fetcher := &PluginIdentityTokenFetcher{
 | 
			
		||||
				sys:      b.System(),
 | 
			
		||||
				logger:   b.Logger(),
 | 
			
		||||
				key:      config.IdentityTokenKey,
 | 
			
		||||
				audience: config.IdentityTokenAudience,
 | 
			
		||||
				ns:       ns,
 | 
			
		||||
				ttl:      time.Duration(config.IdentityTokenTTLSeconds) * time.Second,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
 | 
			
		||||
			credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
 | 
			
		||||
			credsConfig.WebIdentityTokenFetcher = fetcher
 | 
			
		||||
			credsConfig.RoleARN = config.IdentityTokenRoleARN
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if credsConfig.Region == "" {
 | 
			
		||||
@@ -74,8 +100,8 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
 | 
			
		||||
	awsConfig, err := getRootConfig(ctx, s, "iam", logger)
 | 
			
		||||
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
 | 
			
		||||
	awsConfig, err := b.getRootConfig(ctx, s, "iam", logger)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -90,8 +116,8 @@ func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Log
 | 
			
		||||
	return client, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
 | 
			
		||||
	awsConfig, err := getRootConfig(ctx, s, "sts", logger)
 | 
			
		||||
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
 | 
			
		||||
	awsConfig, err := b.getRootConfig(ctx, s, "sts", logger)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -105,3 +131,38 @@ func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Log
 | 
			
		||||
	}
 | 
			
		||||
	return client, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
 | 
			
		||||
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
 | 
			
		||||
// When the client's STS credentials expire, it will use this interface to fetch a new
 | 
			
		||||
// plugin identity token and exchange it for new STS credentials.
 | 
			
		||||
type PluginIdentityTokenFetcher struct {
 | 
			
		||||
	sys      logical.SystemView
 | 
			
		||||
	logger   hclog.Logger
 | 
			
		||||
	key      string
 | 
			
		||||
	audience string
 | 
			
		||||
	ns       *namespace.Namespace
 | 
			
		||||
	ttl      time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil)
 | 
			
		||||
 | 
			
		||||
func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) {
 | 
			
		||||
	nsCtx := namespace.ContextWithNamespace(ctx, f.ns)
 | 
			
		||||
	resp, err := f.sys.GenerateIdentityToken(nsCtx, pluginutil.IdentityTokenRequest{
 | 
			
		||||
		Key:      f.key,
 | 
			
		||||
		Audience: f.audience,
 | 
			
		||||
		TTL:      f.ttl,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to generate plugin identity token: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	f.logger.Info("fetched new plugin identity token")
 | 
			
		||||
 | 
			
		||||
	if resp.TTL < f.ttl {
 | 
			
		||||
		f.logger.Debug("generated plugin identity token has shorter TTL than requested",
 | 
			
		||||
			"requested", f.ttl.Seconds(), "actual", resp.TTL)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return []byte(resp.Token), nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -54,6 +54,38 @@ func pathConfigRoot(b *backend) *framework.Path {
 | 
			
		||||
				Type:        framework.TypeString,
 | 
			
		||||
				Description: "Template to generate custom IAM usernames",
 | 
			
		||||
			},
 | 
			
		||||
			"identity_token_audience": {
 | 
			
		||||
				Type:        framework.TypeString,
 | 
			
		||||
				Description: "",
 | 
			
		||||
				Default:     "",
 | 
			
		||||
				DisplayAttrs: &framework.DisplayAttributes{
 | 
			
		||||
					Name: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"identity_token_key": {
 | 
			
		||||
				Type:        framework.TypeString,
 | 
			
		||||
				Description: "",
 | 
			
		||||
				Default:     "",
 | 
			
		||||
				DisplayAttrs: &framework.DisplayAttributes{
 | 
			
		||||
					Name: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"identity_token_role_arn": {
 | 
			
		||||
				Type:        framework.TypeString,
 | 
			
		||||
				Description: "",
 | 
			
		||||
				Default:     "",
 | 
			
		||||
				DisplayAttrs: &framework.DisplayAttributes{
 | 
			
		||||
					Name: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"identity_token_ttl": {
 | 
			
		||||
				Type:        framework.TypeDurationSecond,
 | 
			
		||||
				Description: "",
 | 
			
		||||
				DisplayAttrs: &framework.DisplayAttributes{
 | 
			
		||||
					Name: "",
 | 
			
		||||
				},
 | 
			
		||||
				Default: 3600,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		Operations: map[logical.Operation]framework.OperationHandler{
 | 
			
		||||
@@ -118,6 +150,11 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
 | 
			
		||||
		usernameTemplate = defaultUserNameTemplate
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	identityTokenAudience := data.Get("identity_token_audience").(string)
 | 
			
		||||
	identityTokenKey := data.Get("identity_token_key").(string)
 | 
			
		||||
	identityTokenTTL := data.Get("identity_token_ttl").(int)
 | 
			
		||||
	identityTokenRoleARN := data.Get("identity_token_role_arn").(string)
 | 
			
		||||
 | 
			
		||||
	b.clientMutex.Lock()
 | 
			
		||||
	defer b.clientMutex.Unlock()
 | 
			
		||||
 | 
			
		||||
@@ -129,6 +166,10 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
 | 
			
		||||
		Region:                  region,
 | 
			
		||||
		MaxRetries:              maxretries,
 | 
			
		||||
		UsernameTemplate:        usernameTemplate,
 | 
			
		||||
		IdentityTokenRoleARN:    identityTokenRoleARN,
 | 
			
		||||
		IdentityTokenAudience:   identityTokenAudience,
 | 
			
		||||
		IdentityTokenKey:        identityTokenKey,
 | 
			
		||||
		IdentityTokenTTLSeconds: identityTokenTTL,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -154,6 +195,10 @@ type rootConfig struct {
 | 
			
		||||
	Region                  string `json:"region"`
 | 
			
		||||
	MaxRetries              int    `json:"max_retries"`
 | 
			
		||||
	UsernameTemplate        string `json:"username_template"`
 | 
			
		||||
	IdentityTokenKey        string `json:"identity_token_key"`
 | 
			
		||||
	IdentityTokenTTLSeconds int    `json:"identity_token_ttl_seconds"`
 | 
			
		||||
	IdentityTokenAudience   string `json:"identity_token_audience"`
 | 
			
		||||
	IdentityTokenRoleARN    string `json:"identity_token_role_arn"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const pathConfigRootHelpSyn = `
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@@ -100,7 +100,7 @@ require (
 | 
			
		||||
	github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a
 | 
			
		||||
	github.com/hashicorp/go-retryablehttp v0.7.4
 | 
			
		||||
	github.com/hashicorp/go-rootcerts v1.0.2
 | 
			
		||||
	github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3
 | 
			
		||||
	github.com/hashicorp/go-secure-stdlib/awsutil v0.2.4-0.20231108055638-37911e265025
 | 
			
		||||
	github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
 | 
			
		||||
	github.com/hashicorp/go-secure-stdlib/gatedwriter v0.1.1
 | 
			
		||||
	github.com/hashicorp/go-secure-stdlib/kv-builder v0.1.2
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@@ -2026,8 +2026,8 @@ github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5
 | 
			
		||||
github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU=
 | 
			
		||||
github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc=
 | 
			
		||||
github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8=
 | 
			
		||||
github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3 h1:AAQ6Vmo/ncfrZYtbpjhO+g0Qt+iNpYtl3UWT1NLmbYY=
 | 
			
		||||
github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg=
 | 
			
		||||
github.com/hashicorp/go-secure-stdlib/awsutil v0.2.4-0.20231108055638-37911e265025 h1:aBt1QxQZxZ8bfXlvpYSaL96sho9okKcCt8WtooD1ONk=
 | 
			
		||||
github.com/hashicorp/go-secure-stdlib/awsutil v0.2.4-0.20231108055638-37911e265025/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg=
 | 
			
		||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
 | 
			
		||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
 | 
			
		||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
 | 
			
		||||
 
 | 
			
		||||
@@ -1818,6 +1818,15 @@ func (i *IdentityStore) generatePublicJWKS(ctx context.Context, s logical.Storag
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Always add the default key
 | 
			
		||||
	defaultKeyIDs, err := i.keyIDsByName(ctx, s, defaultKeyName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to load default key IDs: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	for _, id := range defaultKeyIDs {
 | 
			
		||||
		keyIDs[id] = struct{}{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jwks := &jose.JSONWebKeySet{
 | 
			
		||||
		Keys: make([]jose.JSONWebKey, 0, len(keyIDs)),
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user