WIF support for AWS secrets engine (#24987)

* add new plugin wif fields to AWS Secrets Engine

* add changelog

* go get awsutil v0.3.0

* fix up changelog

* fix test and field parsing helper

* godoc on new test

* require role arn when audience set

* make fmt

---------

Co-authored-by: Austin Gebauer <agebauer@hashicorp.com>
Co-authored-by: Austin Gebauer <34121980+austingebauer@users.noreply.github.com>
This commit is contained in:
vinay-gopalan
2024-01-29 11:34:57 -08:00
committed by GitHub
parent 2acac70160
commit fcf7cf6c22
10 changed files with 224 additions and 35 deletions

View File

@@ -141,7 +141,7 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IA
return b.iamClient, nil
}
iamClient, err := nonCachedClientIAM(ctx, s, b.Logger())
iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger())
if err != nil {
return nil, err
}
@@ -168,7 +168,7 @@ func (b *backend) clientSTS(ctx context.Context, s logical.Storage) (stsiface.ST
return b.stsClient, nil
}
stsClient, err := nonCachedClientSTS(ctx, s, b.Logger())
stsClient, err := b.nonCachedClientSTS(ctx, s, b.Logger())
if err != nil {
return nil, err
}

View File

@@ -7,19 +7,25 @@ import (
"context"
"fmt"
"os"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
// NOTE: The caller is required to ensure that b.clientMutex is at least read locked
func getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) {
func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{}
var endpoint string
var maxRetries int = aws.UseServiceDefaultRetries
@@ -44,6 +50,26 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo
case clientType == "sts" && config.STSEndpoint != "":
endpoint = *aws.String(config.STSEndpoint)
}
if config.IdentityTokenAudience != "" {
ns, err := namespace.FromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get namespace from context: %w", err)
}
fetcher := &PluginIdentityTokenFetcher{
sys: b.System(),
logger: b.Logger(),
ns: ns,
audience: config.IdentityTokenAudience,
ttl: config.IdentityTokenTTL,
}
sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10)
credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix)
credsConfig.WebIdentityTokenFetcher = fetcher
credsConfig.RoleARN = config.RoleARN
}
}
if credsConfig.Region == "" {
@@ -74,8 +100,8 @@ func getRootConfig(ctx context.Context, s logical.Storage, clientType string, lo
}, nil
}
func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := getRootConfig(ctx, s, "iam", logger)
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := b.getRootConfig(ctx, s, "iam", logger)
if err != nil {
return nil, err
}
@@ -90,8 +116,8 @@ func nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Log
return client, nil
}
func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
awsConfig, err := getRootConfig(ctx, s, "sts", logger)
func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) {
awsConfig, err := b.getRootConfig(ctx, s, "sts", logger)
if err != nil {
return nil, err
}
@@ -105,3 +131,36 @@ func nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Log
}
return client, nil
}
// PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided
// to the AWS SDK client to keep assumed role credentials refreshed through expiration.
// When the client's STS credentials expire, it will use this interface to fetch a new
// plugin identity token and exchange it for new STS credentials.
type PluginIdentityTokenFetcher struct {
sys logical.SystemView
logger hclog.Logger
audience string
ns *namespace.Namespace
ttl time.Duration
}
var _ stscreds.TokenFetcher = (*PluginIdentityTokenFetcher)(nil)
func (f PluginIdentityTokenFetcher) FetchToken(ctx aws.Context) ([]byte, error) {
nsCtx := namespace.ContextWithNamespace(ctx, f.ns)
resp, err := f.sys.GenerateIdentityToken(nsCtx, &pluginutil.IdentityTokenRequest{
Audience: f.audience,
TTL: f.ttl,
})
if err != nil {
return nil, fmt.Errorf("failed to generate plugin identity token: %w", err)
}
f.logger.Info("fetched new plugin identity token")
if resp.TTL < f.ttl {
f.logger.Debug("generated plugin identity token has shorter TTL than requested",
"requested", f.ttl, "actual", resp.TTL)
}
return []byte(resp.Token.Token()), nil
}

View File

@@ -7,7 +7,9 @@ import (
"context"
"github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/pluginidentityutil"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -15,7 +17,7 @@ import (
const defaultUserNameTemplate = `{{ if (eq .Type "STS") }}{{ printf "vault-%s-%s" (unix_time) (random 20) | truncate 32 }}{{ else }}{{ printf "vault-%s-%s-%s" (printf "%s-%s" (.DisplayName) (.PolicyName) | truncate 42) (unix_time) (random 20) | truncate 64 }}{{ end }}`
func pathConfigRoot(b *backend) *framework.Path {
return &framework.Path{
p := &framework.Path{
Pattern: "config/root",
DisplayAttrs: &framework.DisplayAttributes{
@@ -54,6 +56,10 @@ func pathConfigRoot(b *backend) *framework.Path {
Type: framework.TypeString,
Description: "Template to generate custom IAM usernames",
},
"role_arn": {
Type: framework.TypeString,
Description: "Role ARN to assume for plugin identity token federation",
},
},
Operations: map[logical.Operation]framework.OperationHandler{
@@ -75,6 +81,9 @@ func pathConfigRoot(b *backend) *framework.Path {
HelpSynopsis: pathConfigRootHelpSyn,
HelpDescription: pathConfigRootHelpDesc,
}
pluginidentityutil.AddPluginIdentityTokenFields(p.Fields)
return p
}
func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@@ -102,7 +111,10 @@ func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request,
"sts_endpoint": config.STSEndpoint,
"max_retries": config.MaxRetries,
"username_template": config.UsernameTemplate,
"role_arn": config.RoleARN,
}
config.PopulatePluginIdentityTokenData(configData)
return &logical.Response{
Data: configData,
}, nil
@@ -113,6 +125,7 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
iamendpoint := data.Get("iam_endpoint").(string)
stsendpoint := data.Get("sts_endpoint").(string)
maxretries := data.Get("max_retries").(int)
roleARN := data.Get("role_arn").(string)
usernameTemplate := data.Get("username_template").(string)
if usernameTemplate == "" {
usernameTemplate = defaultUserNameTemplate
@@ -121,7 +134,7 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
b.clientMutex.Lock()
defer b.clientMutex.Unlock()
entry, err := logical.StorageEntryJSON("config/root", rootConfig{
rc := rootConfig{
AccessKey: data.Get("access_key").(string),
SecretKey: data.Get("secret_key").(string),
IAMEndpoint: iamendpoint,
@@ -129,7 +142,21 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
Region: region,
MaxRetries: maxretries,
UsernameTemplate: usernameTemplate,
})
RoleARN: roleARN,
}
if err := rc.ParsePluginIdentityTokenFields(data); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
if rc.IdentityTokenAudience != "" && rc.AccessKey != "" {
return logical.ErrorResponse("only one of 'access_key' or 'identity_token_audience' can be set"), nil
}
if rc.IdentityTokenAudience != "" && rc.RoleARN == "" {
return logical.ErrorResponse("missing required 'role_arn' when 'identity_token_audience' is set"), nil
}
entry, err := logical.StorageEntryJSON("config/root", rc)
if err != nil {
return nil, err
}
@@ -147,6 +174,8 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
}
type rootConfig struct {
pluginidentityutil.PluginIdentityTokenParams
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
IAMEndpoint string `json:"iam_endpoint"`
@@ -154,6 +183,7 @@ type rootConfig struct {
Region string `json:"region"`
MaxRetries int `json:"max_retries"`
UsernameTemplate string `json:"username_template"`
RoleARN string `json:"role_arn"`
}
const pathConfigRootHelpSyn = `

View File

@@ -6,9 +6,11 @@ package aws
import (
"context"
"reflect"
"strings"
"testing"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
func TestBackend_PathConfigRoot(t *testing.T) {
@@ -21,13 +23,16 @@ func TestBackend_PathConfigRoot(t *testing.T) {
}
configData := map[string]interface{}{
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"access_key": "AKIAEXAMPLE",
"secret_key": "RandomData",
"region": "us-west-2",
"iam_endpoint": "https://iam.amazonaws.com",
"sts_endpoint": "https://sts.us-west-2.amazonaws.com",
"max_retries": 10,
"username_template": defaultUserNameTemplate,
"role_arn": "",
"identity_token_audience": "",
"identity_token_ttl": int64(0),
}
configReq := &logical.Request{
@@ -52,7 +57,102 @@ func TestBackend_PathConfigRoot(t *testing.T) {
}
delete(configData, "secret_key")
require.Equal(t, configData, resp.Data)
if !reflect.DeepEqual(resp.Data, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}
}
// TestBackend_PathConfigRoot_PluginIdentityToken tests parsing and validation of
// configuration used to set the secret engine up for web identity federation using
// plugin identity tokens.
func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b := Backend(config)
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}
configData := map[string]interface{}{
"identity_token_ttl": int64(10),
"identity_token_audience": "test-aud",
"role_arn": "test-role-arn",
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err := b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Storage: config.StorageView,
Path: "config/root",
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err)
}
// Grab the subset of fields from the response we care to look at for this case
got := map[string]interface{}{
"identity_token_ttl": resp.Data["identity_token_ttl"],
"identity_token_audience": resp.Data["identity_token_audience"],
"role_arn": resp.Data["role_arn"],
}
if !reflect.DeepEqual(got, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}
// mutually exclusive fields must result in an error
configData = map[string]interface{}{
"identity_token_audience": "test-aud",
"access_key": "ASIAIO10230XVB",
}
configReq = &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if !resp.IsError() {
t.Fatalf("expected an error but got nil")
}
expectedError := "only one of 'access_key' or 'identity_token_audience' can be set"
if !strings.Contains(resp.Error().Error(), expectedError) {
t.Fatalf("expected err %s, got %s", expectedError, resp.Error())
}
// missing role arn with audience must result in an error
configData = map[string]interface{}{
"identity_token_audience": "test-aud",
}
configReq = &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if !resp.IsError() {
t.Fatalf("expected an error but got nil")
}
expectedError = "missing required 'role_arn' when 'identity_token_audience' is set"
if !strings.Contains(resp.Error().Error(), expectedError) {
t.Fatalf("expected err %s, got %s", expectedError, resp.Error())
}
}

3
changelog/24987.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:feature
**Plugin Identity Tokens**: Adds secret-less configuration of AWS secret engine using web identity federation.
```

2
go.mod
View File

@@ -101,7 +101,7 @@ require (
github.com/hashicorp/go-raftchunking v0.6.3-0.20191002164813-7e9e8525653a
github.com/hashicorp/go-retryablehttp v0.7.4
github.com/hashicorp/go-rootcerts v1.0.2
github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3
github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-secure-stdlib/gatedwriter v0.1.1
github.com/hashicorp/go-secure-stdlib/kv-builder v0.1.2

4
go.sum
View File

@@ -2164,8 +2164,8 @@ github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5
github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU=
github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc=
github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8=
github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3 h1:AAQ6Vmo/ncfrZYtbpjhO+g0Qt+iNpYtl3UWT1NLmbYY=
github.com/hashicorp/go-secure-stdlib/awsutil v0.2.3/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg=
github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0 h1:I8bynUKMh9I7JdwtW9voJ0xmHvBpxQtLjrMFDYmhOxY=
github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0/go.mod h1:oKHSQs4ivIfZ3fbXGQOop1XuDfdSb8RIsWTGaAanSfg=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=

View File

@@ -4,7 +4,6 @@
package pluginidentityutil
import (
"errors"
"fmt"
"time"
@@ -25,16 +24,10 @@ func (p *PluginIdentityTokenParams) ParsePluginIdentityTokenFields(d *framework.
if tokenTTLRaw, ok := d.GetOk("identity_token_ttl"); ok {
p.IdentityTokenTTL = time.Duration(tokenTTLRaw.(int)) * time.Second
}
if p.IdentityTokenTTL == 0 {
p.IdentityTokenTTL = time.Hour
}
if tokenAudienceRaw, ok := d.GetOk("identity_token_audience"); ok {
p.IdentityTokenAudience = tokenAudienceRaw.(string)
}
if p.IdentityTokenAudience == "" {
return errors.New("missing required identity_token_audience")
}
return nil
}

View File

@@ -39,7 +39,7 @@ func TestParsePluginIdentityTokenFields(t *testing.T) {
want map[string]interface{}
}{
{
name: "basic",
name: "all input",
d: identityTokenFieldData(map[string]interface{}{
fieldIDTokenTTL: 10,
fieldIDTokenAudience: "test-aud",
@@ -50,19 +50,24 @@ func TestParsePluginIdentityTokenFields(t *testing.T) {
},
},
{
name: "empty-ttl",
name: "empty ttl",
d: identityTokenFieldData(map[string]interface{}{
fieldIDTokenAudience: "test-aud",
}),
want: map[string]interface{}{
fieldIDTokenTTL: time.Hour,
fieldIDTokenTTL: time.Duration(0),
fieldIDTokenAudience: "test-aud",
},
},
{
name: "empty-audience",
d: identityTokenFieldData(map[string]interface{}{}),
wantErr: true,
name: "empty audience",
d: identityTokenFieldData(map[string]interface{}{
fieldIDTokenTTL: 10,
}),
want: map[string]interface{}{
fieldIDTokenTTL: time.Duration(10) * time.Second,
fieldIDTokenAudience: "",
},
},
}

View File

@@ -82,7 +82,6 @@ func stabilizeAndPromote(t *testing.T, client *api.Client, nodeID string) {
var err error
for time.Now().Before(deadline) {
state, err = client.Sys().RaftAutopilotState()
// If the state endpoint gets called during a leader election, we'll get an error about
// there not being an active cluster node. Rather than erroring out of this loop, just
// ignore the error and keep trying. It should resolve in a few seconds. There's a