backport of commit f97822da31 (#22796)

Co-authored-by: Victor Rodriguez <vrizo@hashicorp.com>
This commit is contained in:
hc-github-team-secure-vault-core
2023-09-06 12:09:00 -04:00
committed by GitHub
parent 9c1b75cc72
commit 9b6d2acb88
3 changed files with 286 additions and 127 deletions

View File

@@ -161,6 +161,16 @@ func (sgi *SealGenerationInfo) UnmarshalJSON(b []byte) error {
return nil
}
// OldKey is used as a return value from Decrypt to indicate that the old
// key was used for decryption and that the value should be re-encrypted
// with the new key and saved. It is not returned as an error by any
// function.
var OldKey = errors.New("decrypted with old key")
func IsOldKeyError(err error) bool {
return errors.Is(err, OldKey)
}
// Access is the embedded implementation of autoSeal that contains logic
// specific to encrypting and decrypting data, or in this case keys.
type Access interface {
@@ -270,33 +280,38 @@ func NewAccessFromWrapper(logger hclog.Logger, wrapper wrapping.Wrapper, sealCon
}
func (a *access) GetAllSealWrappersByPriority() []*SealWrapper {
return copySealWrappers(a.wrappersByPriority, false)
return a.filterSealWrappers(enabledAndDisabled, healthyAndUnhealthy)
}
func (a *access) GetEnabledSealWrappersByPriority() []*SealWrapper {
return copySealWrappers(a.wrappersByPriority, true)
return a.filterSealWrappers(enabledOnly, healthyAndUnhealthy)
}
func (a *access) AllSealWrappersHealthy() bool {
for _, sw := range a.wrappersByPriority {
// Ignore disabled seals
if sw.Disabled {
continue
}
if !sw.IsHealthy() {
return false
}
}
return true
return len(a.wrappersByPriority) == len(a.filterSealWrappers(enabledAndDisabled, healthyOnly))
}
func copySealWrappers(sealWrappers []*SealWrapper, enabledOnly bool) []*SealWrapper {
ret := make([]*SealWrapper, 0, len(sealWrappers))
for _, sw := range sealWrappers {
if enabledOnly && sw.Disabled {
type enabledFilter bool
type healthyFilter bool
const (
enabledOnly = enabledFilter(true)
enabledAndDisabled = !enabledOnly
healthyOnly = healthyFilter(true)
healthyAndUnhealthy = !healthyOnly
)
func (a *access) filterSealWrappers(enabled enabledFilter, healthy healthyFilter) []*SealWrapper {
ret := make([]*SealWrapper, 0, len(a.wrappersByPriority))
for _, sw := range a.wrappersByPriority {
switch {
case enabled == enabledOnly && sw.Disabled:
continue
case healthy == healthyOnly && !sw.IsHealthy():
continue
default:
ret = append(ret, sw)
}
ret = append(ret, sw)
}
return ret
}
@@ -348,11 +363,10 @@ func (a *access) IsUpToDate(ctx context.Context, value *MultiWrapValue, forceKey
a.logger.Error("error refreshing seal key IDs")
return false, JoinSealWrapErrors("cannot determine key IDs of Access wrappers", errs)
}
// TODO(SEALHA): What to do if there are partial failures?
if len(errs) > 0 {
msg := "could not determine key IDs of some Access wrappers"
a.logger.Warn(msg)
a.logger.Trace("partial failure refreshing seal key IDs", "err", JoinSealWrapErrors(msg, errs))
a.logger.Error("partial failure refreshing seal key IDs", "err", JoinSealWrapErrors(msg, errs))
return false, JoinSealWrapErrors(msg, errs)
}
a.keyIdSet.set(test)
}
@@ -360,43 +374,89 @@ func (a *access) IsUpToDate(ctx context.Context, value *MultiWrapValue, forceKey
return a.keyIdSet.equal(value), nil
}
const (
// wrapperEncryptTimeout is the duration we will wait for seal wrappers to return from an encrypt call.
// After the timeout, we return any successful results and errors for the rest of the wrappers, so
// that a partial seal wrap entry can be recorded.
wrapperEncryptTimeout = 10 * time.Second
// wrapperDecryptHighPriorityHeadStart is the duration we wait for the highest priority wrapper
// to return from a decrypt call before we try decrypting with any additional wrappers.
wrapperDecryptHighPriorityHeadStart = 2 * time.Second
)
// Encrypt uses the underlying seal to encrypt the plaintext and returns it.
func (a *access) Encrypt(ctx context.Context, plaintext []byte, options ...wrapping.Option) (*MultiWrapValue, map[string]error) {
// Note that we do not encrypt with disabled wrappers. Disabled wrappers are only used to decrypt.
enabledWrappersByPriority := a.filterSealWrappers(enabledOnly, healthyOnly)
if len(enabledWrappersByPriority) == 0 {
// If all seals are unhealthy, try any way since a seal may have recovered
enabledWrappersByPriority = a.filterSealWrappers(enabledOnly, healthyAndUnhealthy)
}
type result struct {
name string
ciphertext *wrapping.BlobInfo
err error
}
resultCh := make(chan *result)
encryptCtx, cancelEncryptCtx := context.WithTimeout(ctx, wrapperEncryptTimeout)
defer cancelEncryptCtx()
// Start goroutines to encrypt the value using each of the wrappers.
for _, sealWrapper := range enabledWrappersByPriority {
go func(sealWrapper *SealWrapper) {
ciphertext, err := a.tryEncrypt(encryptCtx, sealWrapper, plaintext, options...)
resultCh <- &result{
name: sealWrapper.Name,
ciphertext: ciphertext,
err: err,
}
}(sealWrapper)
}
results := make(map[string]*result)
GATHER_RESULTS:
for {
select {
case result := <-resultCh:
results[result.name] = result
if len(results) == len(enabledWrappersByPriority) {
break GATHER_RESULTS
}
case <-encryptCtx.Done():
break GATHER_RESULTS
case <-ctx.Done():
cancelEncryptCtx()
break GATHER_RESULTS
}
}
// Sort out the successful results from the errors
var slots []*wrapping.BlobInfo
errs := make(map[string]error)
for _, sealWrapper := range a.GetEnabledSealWrappersByPriority() {
now := time.Now()
var encryptErr error
defer func(now time.Time) {
metrics.MeasureSince([]string{"seal", "encrypt", "time"}, now)
metrics.MeasureSince([]string{"seal", sealWrapper.Name, "encrypt", "time"}, now)
if encryptErr != nil {
metrics.IncrCounter([]string{"seal", "encrypt", "error"}, 1)
metrics.IncrCounter([]string{"seal", sealWrapper.Name, "encrypt", "error"}, 1)
for _, sealWrapper := range enabledWrappersByPriority {
if result, ok := results[sealWrapper.Name]; ok {
if result.err != nil {
errs[sealWrapper.Name] = result.err
} else {
slots = append(slots, result.ciphertext)
}
}(now)
metrics.IncrCounter([]string{"seal", "encrypt"}, 1)
metrics.IncrCounter([]string{"seal", sealWrapper.Name, "encrypt"}, 1)
ciphertext, encryptErr := sealWrapper.Wrapper.Encrypt(ctx, plaintext, options...)
if encryptErr != nil {
a.logger.Warn("error encrypting with seal", "seal", sealWrapper.Name)
a.logger.Trace("error encrypting with seal", "seal", sealWrapper.Name, "err", encryptErr)
errs[sealWrapper.Name] = encryptErr
sealWrapper.SetHealthy(false, now)
} else {
a.logger.Trace("encrypted value using seal", "seal", sealWrapper.Name, "keyId", ciphertext.KeyInfo.KeyId)
slots = append(slots, ciphertext)
if encryptCtx.Err() != nil {
errs[sealWrapper.Name] = encryptCtx.Err()
} else {
// Just being paranoid, encryptCtx.Err() should never be nil in this case
errs[sealWrapper.Name] = errors.New("context timeout exceeded")
}
// This failure did not happen on tryDecrypt, so we must log it here
a.logger.Trace("error encrypting with seal", "seal", sealWrapper.Name, "err", errs[sealWrapper.Name])
}
}
if len(slots) == 0 {
a.logger.Error("all seals failed to encrypt value")
a.logger.Error("failed to encrypt value using any seal wrappers")
return nil, errs
}
@@ -407,12 +467,44 @@ func (a *access) Encrypt(ctx context.Context, plaintext []byte, options ...wrapp
Slots: slots,
}
// cache key IDs
a.keyIdSet.set(ret)
if len(errs) == 0 {
// cache key IDs
a.keyIdSet.set(ret)
}
return ret, errs
}
func (a *access) tryEncrypt(ctx context.Context, sealWrapper *SealWrapper, plaintext []byte, options ...wrapping.Option) (*wrapping.BlobInfo, error) {
now := time.Now()
var encryptErr error
defer func(now time.Time) {
metrics.MeasureSince([]string{"seal", "encrypt", "time"}, now)
metrics.MeasureSince([]string{"seal", sealWrapper.Name, "encrypt", "time"}, now)
if encryptErr != nil {
metrics.IncrCounter([]string{"seal", "encrypt", "error"}, 1)
metrics.IncrCounter([]string{"seal", sealWrapper.Name, "encrypt", "error"}, 1)
}
}(now)
metrics.IncrCounter([]string{"seal", "encrypt"}, 1)
metrics.IncrCounter([]string{"seal", sealWrapper.Name, "encrypt"}, 1)
ciphertext, encryptErr := sealWrapper.Wrapper.Encrypt(ctx, plaintext, options...)
if encryptErr != nil {
a.logger.Warn("error encrypting with seal", "seal", sealWrapper.Name)
a.logger.Trace("error encrypting with seal", "seal", sealWrapper.Name, "err", encryptErr)
sealWrapper.SetHealthy(false, now)
return nil, encryptErr
}
a.logger.Trace("encrypted value using seal", "seal", sealWrapper.Name, "keyId", ciphertext.KeyInfo.KeyId)
sealWrapper.SetHealthy(true, now)
return ciphertext, nil
}
// Decrypt uses the underlying seal to decrypt the ciphertext and returns it.
// Note that it is possible depending on the wrapper used that both pt and err
// are populated.
@@ -426,46 +518,85 @@ func (a *access) Decrypt(ctx context.Context, ciphertext *MultiWrapValue, option
return nil, false, err
}
// First, lets try the wrappers in order of priority and look for an exact key ID match
for _, sealWrapper := range a.GetAllSealWrappersByPriority() {
if keyId, err := sealWrapper.Wrapper.KeyId(ctx); err == nil {
if blobInfo, ok := blobInfoMap[keyId]; ok {
pt, oldKey, err := a.tryDecrypt(ctx, sealWrapper, blobInfo, options)
if oldKey {
a.logger.Trace("decrypted using OldKey", "seal", sealWrapper.Name)
return pt, false, err
}
if err == nil {
a.logger.Trace("decrypted value using seal", "seal", sealWrapper.Name)
return pt, isUpToDate, nil
}
// If there is an error, keep trying with the other wrappers
a.logger.Trace("error decrypting with seal, will try other seals", "seal", sealWrapper.Name, "keyId", keyId, "err", err)
}
wrappersByPriority := a.filterSealWrappers(enabledAndDisabled, healthyOnly)
if len(wrappersByPriority) == 0 {
// If all seals are unhealthy, try any way since a seal may have recovered
wrappersByPriority = a.filterSealWrappers(enabledAndDisabled, healthyAndUnhealthy)
}
type result struct {
name string
pt []byte
oldKey bool
err error
}
resultCh := make(chan *result)
decrypt := func(sealWrapper *SealWrapper) {
pt, oldKey, err := a.tryDecrypt(ctx, sealWrapper, blobInfoMap, options)
resultCh <- &result{
name: sealWrapper.Name,
pt: pt,
oldKey: oldKey,
err: err,
}
}
// No key ID match, so try each wrapper with all slots
// Start goroutines to decrypt the value
for i, sealWrapper := range wrappersByPriority {
sealWrapper := sealWrapper
if i == 0 {
// start the highest priority wrapper right away
go decrypt(sealWrapper)
} else {
timer := time.AfterFunc(wrapperDecryptHighPriorityHeadStart, func() {
decrypt(sealWrapper)
})
defer timer.Stop()
}
}
// Gathering failures, but return right away if there is a succesful result
errs := make(map[string]error)
for _, sealWrapper := range a.GetAllSealWrappersByPriority() {
for _, blobInfo := range ciphertext.Slots {
pt, oldKey, err := a.tryDecrypt(ctx, sealWrapper, blobInfo, options)
if oldKey {
a.logger.Trace("decrypted using OldKey", "seal", sealWrapper.Name)
return pt, false, err
GATHER_RESULTS:
for {
select {
case result := <-resultCh:
switch {
case result.err != nil:
errs[result.name] = result.err
if len(errs) == len(wrappersByPriority) {
break GATHER_RESULTS
}
case result.oldKey:
return result.pt, false, OldKey
default:
return result.pt, isUpToDate, nil
}
if err == nil {
a.logger.Trace("decrypted value using seal", "seal", sealWrapper.Name)
return pt, isUpToDate, nil
}
errs[sealWrapper.Name] = err
case <-ctx.Done():
break GATHER_RESULTS
}
}
return nil, false, JoinSealWrapErrors("error decrypting seal wrapped value", errs)
// No wrapper was able to decrypt the value, return an error
if len(errs) > 0 {
return nil, false, JoinSealWrapErrors("error decrypting seal wrapped value", errs)
}
if ctx.Err() != nil {
return nil, false, ctx.Err()
}
// Just being paranoid, ctx.Err() should never be nil in this case
return nil, false, errors.New("context timeout exceeded")
}
func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphertext *wrapping.BlobInfo, options []wrapping.Option) ([]byte, bool, error) {
// tryDecrypt returns the plaintext and a flad indicating whether the decryption was done by the "unwrapSeal" (see
// sealWrapMigration.Decrypt).
func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphertextByKeyId map[string]*wrapping.BlobInfo, options []wrapping.Option) ([]byte, bool, error) {
now := time.Now()
var decryptErr error
defer func(now time.Time) {
metrics.MeasureSince([]string{"seal", "decrypt", "time"}, now)
@@ -475,19 +606,52 @@ func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphe
metrics.IncrCounter([]string{"seal", "decrypt", "error"}, 1)
metrics.IncrCounter([]string{"seal", sealWrapper.Name, "decrypt", "error"}, 1)
}
// TODO (multiseal): log an error?
}(time.Now())
}(now)
metrics.IncrCounter([]string{"seal", "decrypt"}, 1)
metrics.IncrCounter([]string{"seal", sealWrapper.Name, "decrypt"}, 1)
pt, err := sealWrapper.Wrapper.Decrypt(ctx, ciphertext, options...)
isOldKey := false
if err != nil && err.Error() == "decrypted with old key" {
// This is for compatibility with sealWrapMigration
isOldKey = true
var pt []byte
// First, let's look for an exact key ID match
var keyId string
if id, err := sealWrapper.Wrapper.KeyId(ctx); err == nil {
keyId = id
if ciphertext, ok := ciphertextByKeyId[keyId]; ok {
pt, decryptErr = sealWrapper.Wrapper.Decrypt(ctx, ciphertext, options...)
sealWrapper.SetHealthy(decryptErr == nil || IsOldKeyError(decryptErr), now)
}
}
// If we don't get a result, try all the slots
if pt == nil && decryptErr == nil {
for _, ciphertext := range ciphertextByKeyId {
pt, decryptErr = sealWrapper.Wrapper.Decrypt(ctx, ciphertext, options...)
if decryptErr == nil {
// Note that we only update wrapper health for failures on exact key ID match,
// otherwise we would have false negatives.
sealWrapper.SetHealthy(true, now)
break
}
}
}
switch {
case decryptErr != nil && IsOldKeyError(decryptErr):
// an OldKey error is not an actual error, it just means that the decryption was done
// by the "unwrapSeal" of a seal migration (see sealWrapMigration.Decrypt).
a.logger.Trace("decrypted using OldKey", "seal_name", sealWrapper.Name)
return pt, true, nil
case decryptErr != nil:
// Note that if there are more than one ciphertext, the error may be misleading...
a.logger.Trace("error decrypting with seal, this may be a harmless mismatch between wrapper and ciphertext", "seal_name", sealWrapper.Name, "keyId", keyId, "err", decryptErr)
return nil, false, decryptErr
default:
a.logger.Trace("decrypted value using seal", "seal_name", sealWrapper.Name)
return pt, false, nil
}
return pt, isOldKey, err
}
func JoinSealWrapErrors(msg string, errorMap map[string]error) error {

View File

@@ -26,7 +26,9 @@ type SealWrapper struct {
// Disabled indicates, when true indicates that this wrapper should only be used for decryption.
Disabled bool
// hcLock protects lastHealthy, lastSeenHealthy, and healthy. Do not modify those fields directly, use setHealth instead.
// hcLock protects lastHealthy, lastSeenHealthy, and healthy.
// Do not modify those fields directly, use setHealth instead.
// Do not access these fields directly, use getHealth instead.
hcLock sync.RWMutex
lastHealthCheck time.Time
lastSeenHealthy time.Time
@@ -42,53 +44,36 @@ func NewSealWrapper(wrapper wrapping.Wrapper, priority int, name string, sealCon
Disabled: disabled,
}
ret.setHealth(true, time.Now(), ret.lastHealthCheck)
setHealth(ret, true, time.Now(), ret.lastHealthCheck)
return ret
}
func (sw *SealWrapper) rlock() func() {
sw.hcLock.RLock()
return sw.hcLock.RUnlock
}
func (sw *SealWrapper) lock() func() {
sw.hcLock.Lock()
return sw.hcLock.Unlock
}
func (sw *SealWrapper) SetHealthy(healthy bool, checkTime time.Time) {
unlock := sw.lock()
defer unlock()
wasHealthy := sw.healthy
lastHealthy := sw.lastSeenHealthy
if !wasHealthy && healthy {
lastHealthy = checkTime
if healthy {
setHealth(sw, true, checkTime, checkTime)
} else {
// do not update lastSeenHealthy
setHealth(sw, false, sw.lastHealthCheck, checkTime)
}
sw.setHealth(healthy, lastHealthy, checkTime)
}
func (sw *SealWrapper) IsHealthy() bool {
unlock := sw.rlock()
defer unlock()
healthy, _, _ := getHealth(sw)
return sw.healthy
return healthy
}
func (sw *SealWrapper) LastSeenHealthy() time.Time {
unlock := sw.rlock()
defer unlock()
_, lastSeenHealthy, _ := getHealth(sw)
return sw.lastSeenHealthy
return lastSeenHealthy
}
func (sw *SealWrapper) LastHealthCheck() time.Time {
unlock := sw.rlock()
defer unlock()
_, _, lastHealthCheck := getHealth(sw)
return sw.lastHealthCheck
return lastHealthCheck
}
var (
@@ -99,35 +84,43 @@ var (
)
func (sw *SealWrapper) CheckHealth(ctx context.Context, checkTime time.Time) error {
unlock := sw.lock()
defer unlock()
// Assume the wrapper is unhealthy, if we make it to the end we'll set it to true
sw.setHealth(false, sw.lastSeenHealthy, checkTime)
testVal := fmt.Sprintf("Heartbeat %d", mathrand.Intn(1000))
ciphertext, err := sw.Wrapper.Encrypt(ctx, []byte(testVal), nil)
if err != nil {
sw.SetHealthy(false, checkTime)
return fmt.Errorf("failed to encrypt test value, seal wrapper may be unreachable: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, HealthTestTimeout)
defer cancel()
plaintext, err := sw.Wrapper.Decrypt(ctx, ciphertext, nil)
if err != nil {
if err != nil && !IsOldKeyError(err) {
sw.SetHealthy(false, checkTime)
return fmt.Errorf("failed to decrypt test value, seal wrapper may be unreachable: %w", err)
}
if !bytes.Equal([]byte(testVal), plaintext) {
sw.SetHealthy(false, checkTime)
return errors.New("failed to decrypt health test value to expected result")
}
sw.setHealth(true, checkTime, checkTime)
sw.SetHealthy(true, checkTime)
return nil
}
// setHealth sets the fields protected by sw.hcLock, callers *must* hold the write lock.
func (sw *SealWrapper) setHealth(healthy bool, lastSeenHealthy, lastHealthCheck time.Time) {
// getHealth is the only function allowed to inspect the health fields directly
func getHealth(sw *SealWrapper) (healthy bool, lastSeenHealthy time.Time, lastHealthCheck time.Time) {
sw.hcLock.RLock()
defer sw.hcLock.RUnlock()
return sw.healthy, sw.lastSeenHealthy, sw.lastHealthCheck
}
// setHealth is the only function allowed to mutate the health fields
func setHealth(sw *SealWrapper, healthy bool, lastSeenHealthy, lastHealthCheck time.Time) {
sw.hcLock.Lock()
defer sw.hcLock.Unlock()
sw.healthy = healthy
sw.lastSeenHealthy = lastSeenHealthy
sw.lastHealthCheck = lastHealthCheck

View File

@@ -470,6 +470,8 @@ func (d *autoSeal) StartHealthCheck() {
ctx, cancel := context.WithTimeout(ctx, seal.HealthTestTimeout)
defer cancel()
d.logger.Trace("performing a seal health check")
allHealthy := true
allUnhealthy := true
for _, sealWrapper := range d.Access.GetAllSealWrappersByPriority() {