Add stubs for plugin WIF (#25657)

* Add stubs for plugin wif

* add header to sdk file

* drop changelog to move it

* fix test
This commit is contained in:
Austin Gebauer
2024-02-27 12:10:43 -08:00
committed by GitHub
parent 1fd5c34c1c
commit df57ff46ff
7 changed files with 75 additions and 359 deletions

View File

@@ -5,10 +5,12 @@ package aws
import (
"context"
"errors"
"github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/pluginidentityutil"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)
@@ -155,6 +157,18 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("missing required 'role_arn' when 'identity_token_audience' is set"), nil
}
if rc.IdentityTokenAudience != "" {
_, err := b.System().GenerateIdentityToken(ctx, &pluginutil.IdentityTokenRequest{
Audience: rc.IdentityTokenAudience,
})
if err != nil {
if errors.Is(err, pluginidentityutil.ErrPluginWorkloadIdentityUnsupported) {
return logical.ErrorResponse(err.Error()), nil
}
return nil, err
}
}
entry, err := logical.StorageEntryJSON("config/root", rc)
if err != nil {
return nil, err

View File

@@ -6,10 +6,12 @@ package aws
import (
"context"
"reflect"
"strings"
"testing"
"github.com/hashicorp/vault/sdk/helper/pluginidentityutil"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -63,12 +65,12 @@ func TestBackend_PathConfigRoot(t *testing.T) {
}
}
// TestBackend_PathConfigRoot_PluginIdentityToken tests parsing and validation of
// configuration used to set the secret engine up for web identity federation using
// plugin identity tokens.
// TestBackend_PathConfigRoot_PluginIdentityToken tests that configuration
// of plugin WIF returns an immediate error.
func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) {
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = &testSystemView{}
b := Backend(config)
if err := b.Setup(context.Background(), config); err != nil {
@@ -89,70 +91,15 @@ func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) {
}
resp, err := b.HandleRequest(context.Background(), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err)
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Storage: config.StorageView,
Path: "config/root",
})
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err)
}
// Grab the subset of fields from the response we care to look at for this case
got := map[string]interface{}{
"identity_token_ttl": resp.Data["identity_token_ttl"],
"identity_token_audience": resp.Data["identity_token_audience"],
"role_arn": resp.Data["role_arn"],
}
if !reflect.DeepEqual(got, configData) {
t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data)
}
// mutually exclusive fields must result in an error
configData = map[string]interface{}{
"identity_token_audience": "test-aud",
"access_key": "ASIAIO10230XVB",
}
configReq = &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if !resp.IsError() {
t.Fatalf("expected an error but got nil")
}
expectedError := "only one of 'access_key' or 'identity_token_audience' can be set"
if !strings.Contains(resp.Error().Error(), expectedError) {
t.Fatalf("expected err %s, got %s", expectedError, resp.Error())
}
// missing role arn with audience must result in an error
configData = map[string]interface{}{
"identity_token_audience": "test-aud",
}
configReq = &logical.Request{
Operation: logical.UpdateOperation,
Storage: config.StorageView,
Path: "config/root",
Data: configData,
}
resp, err = b.HandleRequest(context.Background(), configReq)
if !resp.IsError() {
t.Fatalf("expected an error but got nil")
}
expectedError = "missing required 'role_arn' when 'identity_token_audience' is set"
if !strings.Contains(resp.Error().Error(), expectedError) {
t.Fatalf("expected err %s, got %s", expectedError, resp.Error())
}
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.ErrorContains(t, resp.Error(), pluginidentityutil.ErrPluginWorkloadIdentityUnsupported.Error())
}
type testSystemView struct {
logical.StaticSystemView
}
func (d testSystemView) GenerateIdentityToken(_ context.Context, _ *pluginutil.IdentityTokenRequest) (*pluginutil.IdentityTokenResponse, error) {
return nil, pluginidentityutil.ErrPluginWorkloadIdentityUnsupported
}

View File

@@ -1,3 +0,0 @@
```release-note:feature
**Plugin Workload Identity**: Vault can generate identity tokens for plugins to use in workload identity federation auth flows.
```

View File

@@ -0,0 +1,8 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package pluginidentityutil
import "errors"
var ErrPluginWorkloadIdentityUnsupported = errors.New("plugin workload identity not supported in Vault community edition")

View File

@@ -64,10 +64,6 @@ func (c *oidcConfig) fullIssuer(child string) (string, error) {
return issuer, nil
}
func validChildIssuer(child string) bool {
return child == baseIdentityTokenIssuer || child == pluginIdentityTokenIssuer
}
type expireableKey struct {
KeyID string `json:"key_id"`
ExpireAt time.Time `json:"expire_at"`
@@ -150,22 +146,15 @@ var (
)
const (
issuerPath = "identity/oidc"
oidcTokensPrefix = "oidc_tokens/"
namedKeyCachePrefix = "namedKeys/"
oidcConfigStorageKey = oidcTokensPrefix + "config/"
namedKeyConfigPath = oidcTokensPrefix + "named_keys/"
publicKeysConfigPath = oidcTokensPrefix + "public_keys/"
roleConfigPath = oidcTokensPrefix + "roles/"
// Identity tokens have a base issuer and plugin issuer
baseIdentityTokenIssuer = ""
pluginIdentityTokenIssuer = "plugins"
pluginTokenSubjectPrefix = "plugin-identity"
pluginTokenPrivateClaimKey = "vaultproject.io"
secretTableValue = "secret"
deleteKeyErrorFmt = "unable to delete key %q because it is currently referenced by these %s: %s"
issuerPath = "identity/oidc"
oidcTokensPrefix = "oidc_tokens/"
namedKeyCachePrefix = "namedKeys/"
oidcConfigStorageKey = oidcTokensPrefix + "config/"
namedKeyConfigPath = oidcTokensPrefix + "named_keys/"
publicKeysConfigPath = oidcTokensPrefix + "public_keys/"
roleConfigPath = oidcTokensPrefix + "roles/"
baseIdentityTokenIssuer = ""
deleteKeyErrorFmt = "unable to delete key %q because it is currently referenced by these %s: %s"
)
// optionalChildIssuerRegex is a regex for optionally accepting a field in an
@@ -1096,99 +1085,6 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
return retResp, nil
}
func (i *IdentityStore) generatePluginIdentityToken(ctx context.Context, storage logical.Storage, me *MountEntry, audience string, ttl time.Duration) (string, time.Duration, error) {
ns, err := namespace.FromContext(ctx)
if err != nil {
return "", 0, err
}
if me == nil {
i.Logger().Error("unexpected nil mount entry when generating plugin identity token")
return "", 0, errors.New("mount entry must not be nil")
}
key := defaultKeyName
if me.Config.IdentityTokenKey != "" {
key = me.Config.IdentityTokenKey
}
if ttl == 0 {
ttl = time.Hour
}
namedKey, err := i.getNamedKey(ctx, storage, key)
if err != nil {
return "", 0, err
}
if namedKey == nil {
return "", 0, fmt.Errorf("key %q not found", key)
}
// Validate that the role is allowed to sign with its key (the key could have been updated)
if !strutil.StrListContains(namedKey.AllowedClientIDs, "*") && !strutil.StrListContains(namedKey.AllowedClientIDs, audience) {
return "", 0, fmt.Errorf("the key %q does not list %q as an allowed audience", key, audience)
}
config, err := i.getOIDCConfig(ctx, storage)
if err != nil {
return "", 0, err
}
// Cap the TTL to the key's verification TTL. This is the maximum amount of
// time the key will remain in the JWKS after it's been rotated.
if ttl > namedKey.VerificationTTL {
ttl = namedKey.VerificationTTL
}
// Tokens for plugins have a distinct issuer from Vault's identity token issuer
issuer, err := config.fullIssuer(pluginIdentityTokenIssuer)
if err != nil {
return "", 0, err
}
// The subject uniquely identifies the plugin
subject := fmt.Sprintf("%s:%s:%s:%s", pluginTokenSubjectPrefix, ns.ID,
translateTableClaim(me.Table), me.Accessor)
now := time.Now()
claims := map[string]any{
"iss": issuer,
"sub": subject,
"aud": []string{audience},
"nbf": now.Unix(),
"iat": now.Unix(),
"exp": now.Add(ttl).Unix(),
pluginTokenPrivateClaimKey: map[string]any{
"namespace_id": ns.ID,
"namespace_path": ns.Path,
"class": translateTableClaim(me.Table),
"plugin": me.Type,
"version": me.RunningVersion,
"path": me.Path,
"accessor": me.Accessor,
"local": me.Local,
},
}
payload, err := json.Marshal(claims)
if err != nil {
return "", 0, err
}
signedToken, err := namedKey.signPayload(payload)
if err != nil {
return "", 0, fmt.Errorf("error signing plugin identity token: %w", err)
}
return signedToken, ttl, nil
}
func translateTableClaim(table string) string {
switch table {
case mountTableType:
return secretTableValue
default:
return table
}
}
func (i *IdentityStore) getNamedKey(ctx context.Context, s logical.Storage, name string) (*namedKey, error) {
ns, err := namespace.FromContext(ctx)
if err != nil {

View File

@@ -0,0 +1,24 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package vault
import (
"context"
"time"
"github.com/hashicorp/vault/sdk/helper/pluginidentityutil"
"github.com/hashicorp/vault/sdk/logical"
)
//go:generate go run github.com/hashicorp/vault/tools/stubmaker
func (i *IdentityStore) generatePluginIdentityToken(_ context.Context, _ logical.Storage, _ *MountEntry, _ string, _ time.Duration) (string, time.Duration, error) {
return "", 0, pluginidentityutil.ErrPluginWorkloadIdentityUnsupported
}
func validChildIssuer(child string) bool {
return child == baseIdentityTokenIssuer
}

View File

@@ -5,7 +5,6 @@ package vault
import (
"context"
"crypto"
"encoding/json"
"fmt"
"regexp"
@@ -17,7 +16,6 @@ import (
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/go-test/deep"
capjwt "github.com/hashicorp/cap/jwt"
"github.com/hashicorp/go-hclog"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/helper/identity"
@@ -1787,12 +1785,6 @@ func Test_oidcConfig_fullIssuer(t *testing.T) {
child: baseIdentityTokenIssuer,
want: fmt.Sprintf("https://vault.dev/v1/%s", issuerPath),
},
{
name: "issuer with valid plugin child",
issuer: "http://127.0.0.1:8200",
child: pluginIdentityTokenIssuer,
want: fmt.Sprintf("http://127.0.0.1:8200/v1/%s/%s", issuerPath, pluginIdentityTokenIssuer),
},
{
name: "issuer with invalid child",
issuer: "http://127.0.0.1:8200",
@@ -1838,11 +1830,6 @@ func Test_validChildIssuer(t *testing.T) {
child: baseIdentityTokenIssuer,
want: true,
},
{
name: "valid child issuer",
child: pluginIdentityTokenIssuer,
want: true,
},
{
name: "invalid child issuer",
child: "test",
@@ -1869,8 +1856,8 @@ func Test_optionalChildIssuerRegex(t *testing.T) {
{
name: "valid match with capture",
pattern: "oidc" + optionalChildIssuerRegex("child") + "/.well-known/keys",
path: "oidc/plugins/.well-known/keys",
captures: map[string]string{"child": "plugins"},
path: "oidc/test/.well-known/keys",
captures: map[string]string{"child": "test"},
},
{
name: "valid match with capture name, segment, and path change",
@@ -1887,7 +1874,7 @@ func Test_optionalChildIssuerRegex(t *testing.T) {
{
name: "invalid match with multiple path segments",
pattern: "oidc" + optionalChildIssuerRegex("child") + "/.well-known/keys",
path: "oidc/plugins/invalid/.well-known/keys",
path: "oidc/test/invalid/.well-known/keys",
captures: map[string]string{},
},
}
@@ -1906,132 +1893,6 @@ func Test_optionalChildIssuerRegex(t *testing.T) {
}
}
// TestIdentityStore_generatePluginIdentityToken tests generation of plugin identity
// tokens by verifying signatures and validating claims.
func TestIdentityStore_generatePluginIdentityToken(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
core.credentialBackends["userpass"] = credUserpass.Factory
identityStore := core.IdentityStore()
identityStore.redirectAddr = "http://localhost:8200"
ctx := namespace.RootContext(nil)
storage := core.router.MatchingStorageByAPIPath(ctx, mountPathIdentity)
require.NotNil(t, storage)
// Create a key
testKey := "test-key"
testAudience := "allowed-audience"
resp, err := core.identityStore.HandleRequest(ctx, testKeyReq(storage, testKey,
[]string{testAudience}, "RS256"))
expectSuccess(t, resp, err)
// Enable a secret mount using the test key
createMountEntryWithKey(t, ctx, core.systemBackend, "mounts/", "kv/", testKey)
expectSuccess(t, resp, err)
secretMountEntry := core.router.MatchingMountEntry(ctx, "kv/")
require.NotNil(t, secretMountEntry)
// Enable an auth mount using the default key
createMountEntryWithKey(t, ctx, core.systemBackend, "auth/", "userpass/", defaultKeyName)
expectSuccess(t, resp, err)
authMountEntry := core.router.MatchingMountEntry(ctx, "auth/userpass/")
require.NotNil(t, authMountEntry)
tests := []struct {
name string
ctx context.Context
mountEntry *MountEntry
audience string
ttl time.Duration
wantErr bool
}{
{
name: "expect error with nil context",
ctx: nil,
wantErr: true,
},
{
name: "expect error with nil mount entry",
ctx: ctx,
mountEntry: nil,
wantErr: true,
},
{
name: "expect error with key that doesn't exist",
ctx: ctx,
mountEntry: &MountEntry{
Config: MountConfig{
IdentityTokenKey: "does-not-exist",
},
},
wantErr: true,
},
{
name: "expect error with audience that's not allowed by the key",
ctx: ctx,
mountEntry: secretMountEntry,
audience: "not-allowed-audience",
wantErr: true,
},
{
name: "expect valid identity token with secret mount using test key",
ctx: ctx,
mountEntry: secretMountEntry,
audience: testAudience,
},
{
name: "expect valid identity token with auth mount using default key",
ctx: ctx,
mountEntry: authMountEntry,
audience: testAudience,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, _, err := identityStore.generatePluginIdentityToken(tt.ctx, storage, tt.mountEntry,
tt.audience, tt.ttl)
if tt.wantErr {
require.Error(t, err)
require.Empty(t, token)
return
}
require.NoError(t, err)
require.NotEmpty(t, token)
// Verify the signature and claims of the token
key, err := identityStore.getNamedKey(ctx, storage, tt.mountEntry.Config.IdentityTokenKey)
require.NoError(t, err)
keySet, err := capjwt.NewStaticKeySet([]crypto.PublicKey{key.SigningKey.Public()})
require.NoError(t, err)
validator, err := capjwt.NewValidator(keySet)
require.NoError(t, err)
expected := capjwt.Expected{
Issuer: fmt.Sprintf("%s/v1/identity/oidc/plugins", identityStore.redirectAddr),
Subject: fmt.Sprintf("%s:%s:%s:%s", pluginTokenSubjectPrefix, namespace.RootNamespace.ID,
translateTableClaim(tt.mountEntry.Table), tt.mountEntry.Accessor),
Audiences: []string{tt.audience},
SigningAlgorithms: []capjwt.Alg{capjwt.RS256},
}
claims, err := validator.Validate(ctx, token, expected)
require.NoError(t, err)
require.Contains(t, claims, pluginTokenPrivateClaimKey)
require.IsType(t, map[string]interface{}{}, claims[pluginTokenPrivateClaimKey])
vaultSubClaims := claims[pluginTokenPrivateClaimKey].(map[string]interface{})
require.Equal(t, namespace.RootNamespace.ID, vaultSubClaims["namespace_id"])
require.Equal(t, namespace.RootNamespace.Path, vaultSubClaims["namespace_path"])
require.Equal(t, translateTableClaim(tt.mountEntry.Table), vaultSubClaims["class"])
require.Equal(t, tt.mountEntry.Type, vaultSubClaims["plugin"])
require.Equal(t, tt.mountEntry.RunningVersion, vaultSubClaims["version"])
require.Equal(t, tt.mountEntry.Path, vaultSubClaims["path"])
require.Equal(t, tt.mountEntry.Accessor, vaultSubClaims["accessor"])
require.Equal(t, tt.mountEntry.Local, vaultSubClaims["local"])
})
}
}
func createMountEntryWithKey(t *testing.T, ctx context.Context, sys *SystemBackend, mountPrefix, mountType, key string) {
t.Helper()
@@ -2048,34 +1909,3 @@ func createMountEntryWithKey(t *testing.T, ctx context.Context, sys *SystemBacke
})
expectSuccess(t, resp, err)
}
// Test_translateTableClaim tests that we convert mount entry table
// values to expected claim values.
func Test_translateTableClaim(t *testing.T) {
tests := []struct {
name string
table string
want string
}{
{
name: "given mounts table returns secret",
table: mountTableType,
want: secretTableValue,
},
{
name: "given auth table returns auth",
table: "auth",
want: "auth",
},
{
name: "given any value returns itself",
table: "other",
want: "other",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, translateTableClaim(tt.table), "translateTableClaim(%v)", tt.table)
})
}
}