oidc: check for nil signing key on rotation (#13716)

* check for nil signing key on rotation

* add changelog

* Update nil signing key handling

- bypass setting ExpireAt if signing key is nil in rotate
- return err if singing key is nil in signPayload

* add comment; update error msg on signPayload; refactor UT
This commit is contained in:
John-Michael Faircloth
2022-01-24 12:05:49 -06:00
committed by GitHub
parent f7a25fcf4c
commit be80ddedf1
3 changed files with 222 additions and 70 deletions

3
changelog/13716.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:bug
identity/oidc: Check for a nil signing key on rotation to prevent panics.
```

View File

@@ -548,19 +548,11 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
// generate current and next keys if creating a new key or changing algorithms // generate current and next keys if creating a new key or changing algorithms
if key.Algorithm != prevAlgorithm { if key.Algorithm != prevAlgorithm {
signingKey, err := generateKeys(key.Algorithm) err = key.generateAndSetKey(ctx, i.Logger(), req.Storage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key.SigningKey = signingKey
key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID})
if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)
err = key.generateAndSetNextKey(ctx, i.Logger(), req.Storage) err = key.generateAndSetNextKey(ctx, i.Logger(), req.Storage)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1013,6 +1005,24 @@ func mergeJSONTemplates(logger hclog.Logger, output map[string]interface{}, temp
return nil return nil
} }
// generateAndSetKey will generate new signing and public key pairs and set
// them as the SigningKey.
func (k *namedKey) generateAndSetKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error {
signingKey, err := generateKeys(k.Algorithm)
if err != nil {
return err
}
k.SigningKey = signingKey
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID})
if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil {
return err
}
logger.Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)
return nil
}
// generateAndSetNextKey will generate new signing and public key pairs and set // generateAndSetNextKey will generate new signing and public key pairs and set
// them as the NextSigningKey. // them as the NextSigningKey.
func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error { func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error {
@@ -1032,6 +1042,9 @@ func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logge
} }
func (k *namedKey) signPayload(payload []byte) (string, error) { func (k *namedKey) signPayload(payload []byte) (string, error) {
if k.SigningKey == nil {
return "", fmt.Errorf("signing key is nil; rotate the key and try again")
}
signingKey := jose.SigningKey{Key: k.SigningKey, Algorithm: jose.SignatureAlgorithm(k.Algorithm)} signingKey := jose.SigningKey{Key: k.SigningKey, Algorithm: jose.SignatureAlgorithm(k.Algorithm)}
signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{})
if err != nil { if err != nil {
@@ -1482,21 +1495,27 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0 // verification_ttl can be overridden with an overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error { func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
verificationTTL := k.VerificationTTL verificationTTL := k.VerificationTTL
if overrideVerificationTTL >= 0 { if overrideVerificationTTL >= 0 {
verificationTTL = overrideVerificationTTL verificationTTL = overrideVerificationTTL
} }
now := time.Now() now := time.Now()
// set the previous public key's expiry time if k.SigningKey != nil {
for _, key := range k.KeyRing { // set the previous public key's expiry time
if key.KeyID == k.SigningKey.KeyID { for _, key := range k.KeyRing {
key.ExpireAt = now.Add(verificationTTL) if key.KeyID == k.SigningKey.KeyID {
break key.ExpireAt = now.Add(verificationTTL)
break
}
} }
} else {
// this can occur for keys generated before vault 1.9.0 but rotated on
// vault 1.9.0
logger.Debug("nil signing key detected on rotation")
} }
if k.NextSigningKey == nil { if k.NextSigningKey == nil {
logger.Debug("nil next signing key detected on rotation")
// keys will not have a NextSigningKey if they were generated before // keys will not have a NextSigningKey if they were generated before
// vault 1.9 // vault 1.9
err := k.generateAndSetNextKey(ctx, logger, s) err := k.generateAndSetNextKey(ctx, logger, s)
@@ -1504,6 +1523,7 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St
return err return err
} }
} }
// do the rotation // do the rotation
k.SigningKey = k.NextSigningKey k.SigningKey = k.NextSigningKey
k.NextRotation = now.Add(k.RotationPeriod) k.NextRotation = now.Add(k.RotationPeriod)
@@ -1695,21 +1715,21 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
return now, err return now, err
} }
namedKeys, err := s.List(ctx, namedKeyConfigPath) keyNames, err := s.List(ctx, namedKeyConfigPath)
if err != nil { if err != nil {
return now, err return now, err
} }
usedKeys := make([]string, 0) usedKeys := make([]string, 0)
for _, k := range namedKeys { for _, keyName := range keyNames {
entry, err := s.Get(ctx, namedKeyConfigPath+k) entry, err := s.Get(ctx, namedKeyConfigPath+keyName)
if err != nil { if err != nil {
return now, err return now, err
} }
if entry == nil { if entry == nil {
i.Logger().Warn("could not find key to update", "key", k) i.Logger().Warn("could not find key to update", "key", keyName)
continue continue
} }
@@ -1722,14 +1742,14 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
keyRing := key.KeyRing keyRing := key.KeyRing
var keyringUpdated bool var keyringUpdated bool
for i := 0; i < len(keyRing); i++ { for j := 0; j < len(keyRing); j++ {
k := keyRing[i] k := keyRing[j]
if !k.ExpireAt.IsZero() && k.ExpireAt.Before(now) { if !k.ExpireAt.IsZero() && k.ExpireAt.Before(now) {
keyRing[i] = keyRing[len(keyRing)-1] keyRing[j] = keyRing[len(keyRing)-1]
keyRing = keyRing[:len(keyRing)-1] keyRing = keyRing[:len(keyRing)-1]
keyringUpdated = true keyringUpdated = true
i-- j--
continue continue
} }

View File

@@ -2,8 +2,6 @@ package vault
import ( import (
"context" "context"
"crypto/rand"
"crypto/rsa"
"encoding/json" "encoding/json"
"strconv" "strconv"
"strings" "strings"
@@ -11,7 +9,7 @@ import (
"time" "time"
"github.com/go-test/deep" "github.com/go-test/deep"
uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
@@ -893,6 +891,79 @@ func TestOIDC_SignIDToken(t *testing.T) {
} }
} }
// TestOIDC_SignIDToken_NilSigningKey tests that an error is returned when
// attempting to sign an ID token with a nil signing key
func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
ctx := namespace.RootContext(nil)
// Create and load an entity, an entity is required to generate an ID token
testEntity := &identity.Entity{
Name: "test-entity-name",
ID: "test-entity-id",
BucketKey: "test-entity-bucket-key",
}
txn := c.identityStore.db.Txn(true)
defer txn.Abort()
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
if err != nil {
t.Fatal(err)
}
txn.Commit()
// Create a test key "test-key" with a nil SigningKey
namedKey := &namedKey{
name: "test-key",
AllowedClientIDs: []string{"*"},
Algorithm: "RS256",
VerificationTTL: 60 * time.Second,
RotationPeriod: 60 * time.Second,
KeyRing: nil,
SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(),
}
s := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc")
if err := namedKey.generateAndSetNextKey(ctx, hclog.NewNullLogger(), s); err != nil {
t.Fatalf("failed to set next signing key")
}
// Store namedKey
entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+namedKey.name, namedKey)
if err := s.Put(ctx, entry); err != nil {
t.Fatalf("writing to in mem storage failed")
}
// Create a test role "test-role" -- expect no warning
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/role/test-role",
Operation: logical.CreateOperation,
Data: map[string]interface{}{
"key": "test-key",
"ttl": "1m",
},
Storage: s,
})
expectSuccess(t, resp, err)
if resp != nil {
t.Fatalf("was expecting a nil response but instead got: %#v", resp)
}
// Generate a token against the role "test-role" -- should fail
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/token/test-role",
Operation: logical.ReadOperation,
Storage: s,
EntityID: "test-entity-id",
})
expectError(t, resp, err)
// validate error message
expectedStrings := map[string]interface{}{
"error signing OIDC token: signing key is nil; rotate the key and try again": true,
}
expectStrings(t, []string{err.Error()}, expectedStrings)
}
// TestOIDC_PeriodicFunc tests timing logic for running key // TestOIDC_PeriodicFunc tests timing logic for running key
// rotations and expiration actions. // rotations and expiration actions.
func TestOIDC_PeriodicFunc(t *testing.T) { func TestOIDC_PeriodicFunc(t *testing.T) {
@@ -900,72 +971,111 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
c, _, _ := TestCoreUnsealed(t) c, _, _ := TestCoreUnsealed(t)
ctx := namespace.RootContext(nil) ctx := namespace.RootContext(nil)
// Prepare a dummy signing key
key, _ := rsa.GenerateKey(rand.Reader, 2048)
id, _ := uuid.GenerateUUID()
jwk := &jose.JSONWebKey{
Key: key,
KeyID: id,
Algorithm: "RS256",
Use: "sig",
}
cyclePeriod := 2 * time.Second cyclePeriod := 2 * time.Second
testSets := []struct { testSets := []struct {
namedKey *namedKey namedKey *namedKey
testCases []struct { expectedKeyCount int
cycle int setSigningKey bool
numKeys int setNextSigningKey bool
numPublicKeys int cycles int
}
}{ }{
{ {
// don't set NextSigningKey to ensure its non-existence can be handled namedKey: &namedKey{
&namedKey{
name: "test-key", name: "test-key",
Algorithm: "RS256", Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod, VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod, RotationPeriod: 1 * cyclePeriod,
KeyRing: nil, KeyRing: nil,
SigningKey: jwk, SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(), NextRotation: time.Now(),
}, },
[]struct { expectedKeyCount: 3,
cycle int setSigningKey: true,
numKeys int setNextSigningKey: true,
numPublicKeys int cycles: 4,
}{ },
{1, 2, 2}, {
{2, 3, 3}, // don't set SigningKey to ensure its non-existence can be handled
{3, 3, 3}, namedKey: &namedKey{
{4, 3, 3}, name: "test-key-nil-signing-key",
{5, 3, 3}, Algorithm: "RS256",
{6, 3, 3}, VerificationTTL: 1 * cyclePeriod,
{7, 3, 3}, RotationPeriod: 1 * cyclePeriod,
KeyRing: nil,
SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(),
}, },
expectedKeyCount: 2,
setSigningKey: false,
setNextSigningKey: true,
cycles: 2,
},
{
// don't set NextSigningKey to ensure its non-existence can be handled
namedKey: &namedKey{
name: "test-key-nil-next-signing-key",
Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod,
KeyRing: nil,
SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(),
},
expectedKeyCount: 2,
setSigningKey: true,
setNextSigningKey: false,
cycles: 2,
},
{
// don't set keys to ensure non-existence can be handled
namedKey: &namedKey{
name: "test-key-nil-signing-and-next-signing-key",
Algorithm: "RS256",
VerificationTTL: 1 * cyclePeriod,
RotationPeriod: 1 * cyclePeriod,
KeyRing: nil,
SigningKey: nil,
NextSigningKey: nil,
NextRotation: time.Now(),
},
expectedKeyCount: 2,
setSigningKey: false,
setNextSigningKey: false,
cycles: 2,
}, },
} }
for _, testSet := range testSets { for _, testSet := range testSets {
// Store namedKey
storage := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc") storage := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc")
if testSet.setSigningKey {
if err := testSet.namedKey.generateAndSetKey(ctx, hclog.NewNullLogger(), storage); err != nil {
t.Fatalf("failed to set signing key")
}
}
if testSet.setNextSigningKey {
if err := testSet.namedKey.generateAndSetNextKey(ctx, hclog.NewNullLogger(), storage); err != nil {
t.Fatalf("failed to set next signing key")
}
}
// Store namedKey
entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+testSet.namedKey.name, testSet.namedKey) entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+testSet.namedKey.name, testSet.namedKey)
if err := storage.Put(ctx, entry); err != nil { if err := storage.Put(ctx, entry); err != nil {
t.Fatalf("writing to in mem storage failed") t.Fatalf("writing to in mem storage failed")
} }
currentCycle := 1 currentCycle := 0
numCases := len(testSet.testCases) lastCycle := testSet.cycles - 1
lastCycle := testSet.testCases[numCases-1].cycle namedKeySamples := make([]*logical.StorageEntry, testSet.cycles)
namedKeySamples := make([]*logical.StorageEntry, numCases) publicKeysSamples := make([][]string, testSet.cycles)
publicKeysSamples := make([][]string, numCases)
i := 0 i := 0
// var start time.Time
for currentCycle <= lastCycle { for currentCycle <= lastCycle {
c.identityStore.oidcPeriodicFunc(ctx) c.identityStore.oidcPeriodicFunc(ctx)
if currentCycle == testSet.testCases[i].cycle { if currentCycle == i {
namedKeyEntry, _ := storage.Get(ctx, namedKeyConfigPath+testSet.namedKey.name) namedKeyEntry, _ := storage.Get(ctx, namedKeyConfigPath+testSet.namedKey.name)
publicKeysEntry, _ := storage.List(ctx, publicKeysConfigPath) publicKeysEntry, _ := storage.List(ctx, publicKeysConfigPath)
namedKeySamples[i] = namedKeyEntry namedKeySamples[i] = namedKeyEntry
@@ -985,15 +1095,34 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
} }
// measure collected samples // measure collected samples
for i := range testSet.testCases { for i := 0; i < testSet.cycles; i++ {
cycle := i + 1
namedKeySamples[i].DecodeJSON(&testSet.namedKey) namedKeySamples[i].DecodeJSON(&testSet.namedKey)
if len(testSet.namedKey.KeyRing) != testSet.testCases[i].numKeys { actualKeyRingLen := len(testSet.namedKey.KeyRing)
t.Fatalf("At cycle: %d expected namedKey's KeyRing to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numKeys, len(testSet.namedKey.KeyRing)) if actualKeyRingLen < testSet.expectedKeyCount {
t.Errorf(
"For key: %s at cycle: %d expected namedKey's KeyRing to be at least of length %d but was: %d",
testSet.namedKey.name,
cycle,
testSet.expectedKeyCount,
actualKeyRingLen,
)
} }
if len(publicKeysSamples[i]) != testSet.testCases[i].numPublicKeys { actualPubKeysLen := len(publicKeysSamples[i])
t.Fatalf("At cycle: %d expected public keys to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numPublicKeys, len(publicKeysSamples[i])) if actualPubKeysLen < testSet.expectedKeyCount {
t.Errorf(
"For key: %s at cycle: %d expected public keys to be at least of length %d but was: %d",
testSet.namedKey.name,
cycle,
testSet.expectedKeyCount,
actualPubKeysLen,
)
} }
} }
if err := storage.Delete(ctx, namedKeyConfigPath+testSet.namedKey.name); err != nil {
t.Fatalf("deleting from in mem storage failed")
}
} }
} }