mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-11-04 04:08:16 +00:00 
			
		
		
		
	Merge pull request #115814 from aramase/kms-cryptographic-wearout
[KMSv2] Implement local KEK generation and rotation
This commit is contained in:
		@@ -17,6 +17,7 @@ limitations under the License.
 | 
			
		||||
package encryption
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/aes"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
@@ -24,16 +25,29 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"k8s.io/apimachinery/pkg/util/uuid"
 | 
			
		||||
	"k8s.io/apimachinery/pkg/util/wait"
 | 
			
		||||
	"k8s.io/klog/v2"
 | 
			
		||||
	aestransformer "k8s.io/kms/pkg/encrypt/aes"
 | 
			
		||||
	"k8s.io/kms/pkg/value"
 | 
			
		||||
	"k8s.io/kms/service"
 | 
			
		||||
	"k8s.io/utils/clock"
 | 
			
		||||
	"k8s.io/utils/lru"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// localKEK is a struct that holds the local KEK and the remote KMS response.
 | 
			
		||||
type localKEK struct {
 | 
			
		||||
	encKEK            []byte
 | 
			
		||||
	usage             atomic.Uint64
 | 
			
		||||
	expiry            time.Time
 | 
			
		||||
	transformer       value.Transformer
 | 
			
		||||
	remoteKMSResponse *service.EncryptResponse
 | 
			
		||||
	generatedAt       time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// emptyContext is an empty slice of bytes. This is passed as value.Context to the
 | 
			
		||||
	// GCM transformer. The grpc interface does not provide any additional authenticated data
 | 
			
		||||
@@ -41,10 +55,6 @@ var (
 | 
			
		||||
	emptyContext = value.DefaultContext([]byte{})
 | 
			
		||||
	// errInvalidKMSAnnotationKeySuffix is returned when the annotation key suffix is not allowed.
 | 
			
		||||
	errInvalidKMSAnnotationKeySuffix = fmt.Errorf("annotation keys are not allowed to use %s", referenceSuffix)
 | 
			
		||||
 | 
			
		||||
	// these are var instead of const so that we can set them during tests
 | 
			
		||||
	localKEKGenerationPollInterval = 1 * time.Second
 | 
			
		||||
	localKEKGenerationPollTimeout  = 5 * time.Minute
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -53,121 +63,166 @@ const (
 | 
			
		||||
	referenceKEKAnnotationKey = "encrypted-kek" + referenceSuffix
 | 
			
		||||
	numAnnotations            = 1
 | 
			
		||||
	cacheSize                 = 1_000
 | 
			
		||||
 | 
			
		||||
	// localKEKGenerationPollInterval is the interval at which the local KEK is checked for rotation.
 | 
			
		||||
	localKEKGenerationPollInterval = 1 * time.Minute
 | 
			
		||||
 | 
			
		||||
	// keyLength is the length of the local KEK in bytes.
 | 
			
		||||
	// This is the same length used for the DEKs generated in kube-apiserver.
 | 
			
		||||
	keyLength = 32
 | 
			
		||||
	// keyMaxUsage is the maximum number of times an AES GCM key can be used
 | 
			
		||||
	// with a random nonce: 2^32. The local KEK is a transformer that hold an
 | 
			
		||||
	// AES GCM key. It is based on recommendations from
 | 
			
		||||
	// https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf.
 | 
			
		||||
	// It is reduced by one to be comparable with a atomic.Uint32.
 | 
			
		||||
	// We picked a arbitrary number that is less than the max usage of the local KEK.
 | 
			
		||||
	keyMaxUsage = 1<<22 - 1
 | 
			
		||||
	// keySuggestedUsage is a threshold that triggers the rotation of a new local KEK. It means that half
 | 
			
		||||
	// the number of times a local KEK can be used has been reached.
 | 
			
		||||
	keySuggestedUsage = 1 << 21
 | 
			
		||||
	// keyMaxAge is the maximum age of a local KEK. It is not a cryptographic necessity.
 | 
			
		||||
	keyMaxAge = 7 * 24 * time.Hour
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ service.Service = &LocalKEKService{}
 | 
			
		||||
 | 
			
		||||
// LocalKEKService adds an additional KEK layer to reduce calls to the remote
 | 
			
		||||
// KMS.
 | 
			
		||||
// The local KEK is generated once and stored in the LocalKEKService. This KEK
 | 
			
		||||
// is used for all encryption operations. For the decrypt operation, if the encrypted
 | 
			
		||||
// local KEK is not found in the cache, the remote KMS is used to decrypt the local KEK.
 | 
			
		||||
// LocalKEKService adds an additional KEK layer to reduce calls to the remote KMS.
 | 
			
		||||
// The local KEK is generated at startup in a controller loop and stored in the
 | 
			
		||||
// LocalKEKService. This KEK is used for all encryption operations. For the decrypt
 | 
			
		||||
// operation, if the encrypted local KEK is not found in the cache, the remote KMS
 | 
			
		||||
// is used to decrypt the local KEK.
 | 
			
		||||
type LocalKEKService struct {
 | 
			
		||||
	mu sync.Mutex
 | 
			
		||||
	// remoteKMS is the remote kms that is used to encrypt and decrypt the local KEKs.
 | 
			
		||||
	remoteKMS  service.Service
 | 
			
		||||
	remoteOnce sync.Once
 | 
			
		||||
 | 
			
		||||
	remoteKMS service.Service
 | 
			
		||||
	// localKEKTracker is a atomic pointer to avoid locks. This is used to store the local KEK.
 | 
			
		||||
	localKEKTracker atomic.Pointer[localKEK]
 | 
			
		||||
	// transformers is a thread-safe LRU cache which caches decrypted DEKs indexed by their encrypted form.
 | 
			
		||||
	// The cache is only used for the decrypt operation.
 | 
			
		||||
	transformers *lru.Cache
 | 
			
		||||
	// isReady is an atomic boolean that indicates if the localKEK service is ready for encryption.
 | 
			
		||||
	isReady atomic.Bool
 | 
			
		||||
 | 
			
		||||
	remoteKMSResponse   *service.EncryptResponse
 | 
			
		||||
	localTransformer    value.Transformer
 | 
			
		||||
	localTransformerErr error
 | 
			
		||||
	keyMaxUsage       uint64
 | 
			
		||||
	keySuggestedUsage uint64
 | 
			
		||||
	keyMaxAge         time.Duration
 | 
			
		||||
 | 
			
		||||
	pollInterval time.Duration
 | 
			
		||||
 | 
			
		||||
	clock clock.Clock
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewLocalKEKService is being initialized with a remote KMS service.
 | 
			
		||||
// In the current implementation, the localKEK Service needs to be
 | 
			
		||||
// restarted by the caller after security thresholds are met.
 | 
			
		||||
// TODO(aramase): handle rotation of local KEKs
 | 
			
		||||
//   - when the keyID in Status() no longer matches the keyID used during encryption
 | 
			
		||||
//   - when the local KEK has been used for a certain number of times
 | 
			
		||||
func NewLocalKEKService(remoteService service.Service) *LocalKEKService {
 | 
			
		||||
	return &LocalKEKService{
 | 
			
		||||
		remoteKMS:    remoteService,
 | 
			
		||||
		transformers: lru.New(cacheSize),
 | 
			
		||||
	}
 | 
			
		||||
// The local KEK is generated in a controller loop. The local KEK is used for all
 | 
			
		||||
// encryption operations.
 | 
			
		||||
func NewLocalKEKService(ctx context.Context, remoteService service.Service) *LocalKEKService {
 | 
			
		||||
	return newLocalKEKService(ctx, remoteService, keyMaxUsage, keySuggestedUsage, keyMaxAge, localKEKGenerationPollInterval, clock.RealClock{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newLocalKEKService(ctx context.Context, remoteService service.Service, maxUsage, suggestedUsage uint64, maxAge, pollInterval time.Duration, clock clock.Clock) *LocalKEKService {
 | 
			
		||||
	localKEKService := &LocalKEKService{
 | 
			
		||||
		remoteKMS:         remoteService,
 | 
			
		||||
		transformers:      lru.New(cacheSize),
 | 
			
		||||
		keyMaxUsage:       maxUsage,
 | 
			
		||||
		keySuggestedUsage: suggestedUsage,
 | 
			
		||||
		keyMaxAge:         maxAge,
 | 
			
		||||
		pollInterval:      pollInterval,
 | 
			
		||||
		clock:             clock,
 | 
			
		||||
	}
 | 
			
		||||
	go localKEKService.run(ctx)
 | 
			
		||||
	return localKEKService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Run method creates a new local KEK  when the following thresholds are met:
 | 
			
		||||
//   - the local KEK is used more often than keySuggestedUsage times or
 | 
			
		||||
//   - the local KEK is older than a localExpiry.
 | 
			
		||||
//
 | 
			
		||||
// this method starts the controller and blocks until the context is cancelled.
 | 
			
		||||
func (m *LocalKEKService) run(ctx context.Context) {
 | 
			
		||||
	// same as wait.UntilWithContext but with a custom clock
 | 
			
		||||
	wait.BackoffUntil(func() {
 | 
			
		||||
		lk := m.getLocalKEK()
 | 
			
		||||
		// this is the first time the local KEK is generated
 | 
			
		||||
		localKEKNotGenerated := lk.transformer == nil
 | 
			
		||||
		// the local KEK is used more often than keySuggestedUsage times
 | 
			
		||||
		localKEKUsageThresholdReached := lk.usage.Load() > m.keySuggestedUsage
 | 
			
		||||
		// the local KEK is older than the expiry
 | 
			
		||||
		localKEKExpired := m.clock.Now().After(lk.expiry)
 | 
			
		||||
 | 
			
		||||
		if localKEKNotGenerated || localKEKUsageThresholdReached || localKEKExpired {
 | 
			
		||||
			uid := string(uuid.NewUUID())
 | 
			
		||||
			err := m.generateLocalKEK(ctx, uid, "")
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				m.isReady.Store(true)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			klog.V(2).ErrorS(err, "failed to generate local KEK", "uid", uid)
 | 
			
		||||
			// if the local KEK is expired and we cannot generate a new one, we set
 | 
			
		||||
			// isReady to false because we can no longer encrypt new data.
 | 
			
		||||
			if localKEKExpired {
 | 
			
		||||
				m.isReady.Store(false)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}, wait.NewJitteredBackoffManager(m.pollInterval, 0, m.clock), true, ctx.Done())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getTransformerForEncryption returns the local KEK as localTransformer, the corresponding
 | 
			
		||||
// remoteKMSResponse and a potential error.
 | 
			
		||||
// On every use the localUsage is incremented by one.
 | 
			
		||||
// It is assumed that only one encryption will happen with the returned transformer.
 | 
			
		||||
func (m *LocalKEKService) getTransformerForEncryption(uid string) (value.Transformer, *service.EncryptResponse, error) {
 | 
			
		||||
	// Check if we have a local KEK
 | 
			
		||||
	//	- If exists, use the local KEK for encryption and return
 | 
			
		||||
	//  - Not exists, generate local KEK, encrypt with remote KEK,
 | 
			
		||||
	//	store it in cache encrypt the data and return. This can be
 | 
			
		||||
	// 	expensive but only 1 in N calls will incur this additional latency,
 | 
			
		||||
	// 	N being number of times local KEK is reused)
 | 
			
		||||
	m.remoteOnce.Do(func() {
 | 
			
		||||
		m.localTransformerErr = wait.PollImmediateWithContext(context.Background(), localKEKGenerationPollInterval, localKEKGenerationPollTimeout,
 | 
			
		||||
			func(ctx context.Context) (done bool, err error) {
 | 
			
		||||
				key, err := generateKey(keyLength)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return false, fmt.Errorf("failed to generate local KEK: %w", err)
 | 
			
		||||
				}
 | 
			
		||||
				block, err := aes.NewCipher(key)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return false, fmt.Errorf("failed to create cipher block: %w", err)
 | 
			
		||||
				}
 | 
			
		||||
				transformer := aestransformer.NewGCMTransformer(block)
 | 
			
		||||
 | 
			
		||||
				resp, err := m.remoteKMS.Encrypt(ctx, uid, key)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					klog.ErrorS(err, "failed to encrypt local KEK with remote KMS", "uid", uid)
 | 
			
		||||
					return false, nil
 | 
			
		||||
				}
 | 
			
		||||
				if err = validateRemoteKMSResponse(resp); err != nil {
 | 
			
		||||
					return false, fmt.Errorf("response annotations failed validation: %w", err)
 | 
			
		||||
				}
 | 
			
		||||
				m.remoteKMSResponse = copyResponseAndAddLocalKEKAnnotation(resp)
 | 
			
		||||
				m.localTransformer = transformer
 | 
			
		||||
				m.transformers.Add(base64.StdEncoding.EncodeToString(resp.Ciphertext), transformer)
 | 
			
		||||
				return true, nil
 | 
			
		||||
			})
 | 
			
		||||
	})
 | 
			
		||||
	return m.localTransformer, m.remoteKMSResponse, m.localTransformerErr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func copyResponseAndAddLocalKEKAnnotation(resp *service.EncryptResponse) *service.EncryptResponse {
 | 
			
		||||
	annotations := make(map[string][]byte, len(resp.Annotations)+numAnnotations)
 | 
			
		||||
	for s, bytes := range resp.Annotations {
 | 
			
		||||
		s := s
 | 
			
		||||
		bytes := bytes
 | 
			
		||||
		annotations[s] = bytes
 | 
			
		||||
	lk := m.getLocalKEK()
 | 
			
		||||
	// localKEK is not initialized yet
 | 
			
		||||
	if lk.transformer == nil {
 | 
			
		||||
		return nil, nil, fmt.Errorf("local KEK is not initialized")
 | 
			
		||||
	}
 | 
			
		||||
	annotations[referenceKEKAnnotationKey] = resp.Ciphertext
 | 
			
		||||
 | 
			
		||||
	return &service.EncryptResponse{
 | 
			
		||||
		// Ciphertext is not set on purpose - it is different per Encrypt call
 | 
			
		||||
		KeyID:       resp.KeyID,
 | 
			
		||||
		Annotations: annotations,
 | 
			
		||||
	if m.clock.Now().After(lk.expiry) {
 | 
			
		||||
		return nil, nil, fmt.Errorf("local KEK has expired at %v", lk.expiry)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if counter := lk.usage.Add(1); counter >= m.keyMaxUsage {
 | 
			
		||||
		return nil, nil, fmt.Errorf("local KEK has reached maximum usage of %d", keyMaxUsage)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return lk.transformer, lk.remoteKMSResponse, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Encrypt encrypts the plaintext with the localKEK.
 | 
			
		||||
func (m *LocalKEKService) Encrypt(ctx context.Context, uid string, pt []byte) (*service.EncryptResponse, error) {
 | 
			
		||||
	transformer, resp, err := m.getTransformerForEncryption(uid)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		klog.V(2).InfoS("encrypt plaintext", "uid", uid, "err", err)
 | 
			
		||||
		klog.V(2).ErrorS(err, "failed to get transformer for encryption", "uid", uid)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ct, err := transformer.TransformToStorage(ctx, pt, emptyContext)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		klog.V(2).InfoS("encrypt plaintext", "uid", uid, "err", err)
 | 
			
		||||
		klog.V(2).ErrorS(err, "failed to encrypt data", "uid", uid)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &service.EncryptResponse{
 | 
			
		||||
		Ciphertext:  ct,
 | 
			
		||||
		KeyID:       resp.KeyID, // TODO what about rotation ??
 | 
			
		||||
		KeyID:       resp.KeyID,
 | 
			
		||||
		Annotations: resp.Annotations,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getTransformerForDecryption returns the transformer for the given encryptedKEK.
 | 
			
		||||
// - If the encryptedKEK is the current localKEK, the localKEK is returned.
 | 
			
		||||
// - If the encryptedKEK is not the current localKEK, the cache is checked.
 | 
			
		||||
// - If the encryptedKEK is not found in the cache, the remote KMS is used to decrypt the encryptedKEK.
 | 
			
		||||
func (m *LocalKEKService) getTransformerForDecryption(ctx context.Context, uid string, req *service.DecryptRequest) (value.Transformer, error) {
 | 
			
		||||
	encKEK := req.Annotations[referenceKEKAnnotationKey]
 | 
			
		||||
 | 
			
		||||
	// check if the key required for decryption is the current local KEK
 | 
			
		||||
	// that's being used for encryption
 | 
			
		||||
	lk := m.getLocalKEK()
 | 
			
		||||
	if lk.transformer != nil && bytes.Equal(lk.encKEK, encKEK) {
 | 
			
		||||
		return lk.transformer, nil
 | 
			
		||||
	}
 | 
			
		||||
	// check if the key required for decryption is already in the cache
 | 
			
		||||
	if _transformer, found := m.transformers.Get(base64.StdEncoding.EncodeToString(encKEK)); found {
 | 
			
		||||
		return _transformer.(value.Transformer), nil
 | 
			
		||||
	}
 | 
			
		||||
@@ -190,7 +245,7 @@ func (m *LocalKEKService) getTransformerForDecryption(ctx context.Context, uid s
 | 
			
		||||
	// Overwrite the plain key with 0s.
 | 
			
		||||
	copy(key, make([]byte, len(key)))
 | 
			
		||||
 | 
			
		||||
	m.transformers.Add(encKEK, transformer)
 | 
			
		||||
	m.transformers.Add(base64.StdEncoding.EncodeToString(encKEK), transformer)
 | 
			
		||||
 | 
			
		||||
	return transformer, nil
 | 
			
		||||
}
 | 
			
		||||
@@ -204,13 +259,13 @@ func (m *LocalKEKService) Decrypt(ctx context.Context, uid string, req *service.
 | 
			
		||||
 | 
			
		||||
	transformer, err := m.getTransformerForDecryption(ctx, uid, req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		klog.V(2).InfoS("decrypt ciphertext", "uid", uid, "err", err)
 | 
			
		||||
		klog.V(2).ErrorS(err, "failed to get transformer for decryption", "uid", uid)
 | 
			
		||||
		return nil, fmt.Errorf("failed to get transformer for decryption: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pt, _, err := transformer.TransformFromStorage(ctx, req.Ciphertext, emptyContext)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		klog.V(2).InfoS("decrypt ciphertext with pulled key", "uid", uid, "err", err)
 | 
			
		||||
		klog.V(2).ErrorS(err, "failed to decrypt data", "uid", uid)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -219,12 +274,142 @@ func (m *LocalKEKService) Decrypt(ctx context.Context, uid string, req *service.
 | 
			
		||||
 | 
			
		||||
// Status returns the status of the remote KMS.
 | 
			
		||||
func (m *LocalKEKService) Status(ctx context.Context) (*service.StatusResponse, error) {
 | 
			
		||||
	// TODO(aramase): the response from the remote KMS is funneled through without any validation/action.
 | 
			
		||||
	// This needs to handle the case when remote KEK has changed. The local KEK needs to be rotated and
 | 
			
		||||
	// re-encrypted with the new remote KEK.
 | 
			
		||||
	return m.remoteKMS.Status(ctx)
 | 
			
		||||
	resp, err := m.remoteKMS.Status(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if err := validateRemoteKMSStatusResponse(resp); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r := copyStatusResponse(resp)
 | 
			
		||||
	// if the remote KMS KeyID has changed, we need to rotate the local KEK
 | 
			
		||||
	lk := m.getLocalKEK()
 | 
			
		||||
	if lk.transformer != nil && r.KeyID != lk.remoteKMSResponse.KeyID {
 | 
			
		||||
		if err := m.rotateLocalKEK(ctx, r.KeyID); err != nil {
 | 
			
		||||
			klog.ErrorS(err, "failed to rotate local KEK", "expectedKeyID", r.KeyID, "currentKeyID", lk.remoteKMSResponse.KeyID)
 | 
			
		||||
			// if rotation fails, we will overwrite the keyID to the one we are currently using
 | 
			
		||||
			// for encryption as localKEKService is the authoritative source for the keyID.
 | 
			
		||||
			r.KeyID = lk.remoteKMSResponse.KeyID
 | 
			
		||||
			// TODO(aramase): we are currently not returning the error if rotation fails. We should
 | 
			
		||||
			// allow the failed state for an arbitrary time period and return the error if the state
 | 
			
		||||
			// is not eventually fixed.
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var aggregateHealthz []string
 | 
			
		||||
	if r.Healthz != "ok" {
 | 
			
		||||
		aggregateHealthz = append(aggregateHealthz, r.Healthz)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !m.isReady.Load() {
 | 
			
		||||
		// if the localKEKService is not ready, we will set the healthz status to not ready
 | 
			
		||||
		klog.V(2).InfoS("localKEKService is not ready", "keyID", r.KeyID)
 | 
			
		||||
		aggregateHealthz = append(aggregateHealthz, "localKEKService is not ready")
 | 
			
		||||
	}
 | 
			
		||||
	if len(aggregateHealthz) > 0 {
 | 
			
		||||
		r.Healthz = strings.Join(aggregateHealthz, "; ")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return r, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// rotateLocalKEK rotates the local KEK by generating a new local KEK and encrypting it with the
 | 
			
		||||
// remote KMS.
 | 
			
		||||
func (m *LocalKEKService) rotateLocalKEK(ctx context.Context, expectedKeyID string) error {
 | 
			
		||||
	uid := string(uuid.NewUUID())
 | 
			
		||||
	if err := m.generateLocalKEK(ctx, uid, expectedKeyID); err != nil {
 | 
			
		||||
		klog.V(2).ErrorS(err, "failed to generate local KEK as part of rotation", "uid", uid)
 | 
			
		||||
		return fmt.Errorf("[uid=%s] failed to generate local KEK as part of rotation: %w", uid, err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// generateLocalKEK generates a new local KEK and encrypts it with the remote KMS.
 | 
			
		||||
// if expectedKeyID is not empty, it will check if the keyID returned from the remote KMS matches
 | 
			
		||||
// the expected keyID. If the keyID does not match, it will continue using the existing local KEK
 | 
			
		||||
// and return an error.
 | 
			
		||||
func (m *LocalKEKService) generateLocalKEK(ctx context.Context, uid, expectedKeyID string) error {
 | 
			
		||||
	m.mu.Lock()
 | 
			
		||||
	defer m.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	lk := m.getLocalKEK()
 | 
			
		||||
	// if the localKEK was generated in the last pollInterval duration, we will not generate a new
 | 
			
		||||
	// localKEK. This is to avoid regenerating a new localKEK for queued requests.
 | 
			
		||||
	if lk.transformer != nil && m.clock.Since(lk.generatedAt) < m.pollInterval {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	key, err := generateKey(keyLength)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to generate local KEK: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	block, err := aes.NewCipher(key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to create cipher block: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp, err := m.remoteKMS.Encrypt(ctx, uid, key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to encrypt local KEK: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	if err = validateRemoteKMSEncryptResponse(resp); err != nil {
 | 
			
		||||
		return fmt.Errorf("invalid response from remote KMS: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	if expectedKeyID != "" && resp.KeyID != expectedKeyID {
 | 
			
		||||
		return fmt.Errorf("keyID returned from remote KMS does not match expected keyID")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := m.clock.Now()
 | 
			
		||||
	m.localKEKTracker.Store(&localKEK{
 | 
			
		||||
		encKEK:            resp.Ciphertext,
 | 
			
		||||
		expiry:            now.Add(m.keyMaxAge),
 | 
			
		||||
		usage:             atomic.Uint64{},
 | 
			
		||||
		transformer:       aestransformer.NewGCMTransformer(block),
 | 
			
		||||
		remoteKMSResponse: copyResponseAndAddLocalKEKAnnotation(resp),
 | 
			
		||||
		generatedAt:       now,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *LocalKEKService) getLocalKEK() *localKEK {
 | 
			
		||||
	lk := m.localKEKTracker.Load()
 | 
			
		||||
	if lk == nil {
 | 
			
		||||
		return &localKEK{}
 | 
			
		||||
	}
 | 
			
		||||
	return lk
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// copyResponseAndAddLocalKEKAnnotation returns a copy of the remoteKMSResponse with the
 | 
			
		||||
// referenceKEKAnnotationKey added to the annotations.
 | 
			
		||||
func copyResponseAndAddLocalKEKAnnotation(resp *service.EncryptResponse) *service.EncryptResponse {
 | 
			
		||||
	annotations := make(map[string][]byte, len(resp.Annotations)+numAnnotations)
 | 
			
		||||
	for s, bytes := range resp.Annotations {
 | 
			
		||||
		s := s
 | 
			
		||||
		bytes := bytes
 | 
			
		||||
		annotations[s] = bytes
 | 
			
		||||
	}
 | 
			
		||||
	annotations[referenceKEKAnnotationKey] = resp.Ciphertext
 | 
			
		||||
 | 
			
		||||
	return &service.EncryptResponse{
 | 
			
		||||
		// Ciphertext is not set on purpose - it is different per Encrypt call
 | 
			
		||||
		KeyID:       resp.KeyID,
 | 
			
		||||
		Annotations: annotations,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// copyStatusResponse returns a copy of the remote KMS status response.
 | 
			
		||||
func copyStatusResponse(resp *service.StatusResponse) *service.StatusResponse {
 | 
			
		||||
	return &service.StatusResponse{
 | 
			
		||||
		Healthz: resp.Healthz,
 | 
			
		||||
		Version: resp.Version,
 | 
			
		||||
		KeyID:   resp.KeyID,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// annotationsWithoutReferenceKeys returns a copy of the annotations without the reference implementation
 | 
			
		||||
// annotations.
 | 
			
		||||
func annotationsWithoutReferenceKeys(annotations map[string][]byte) map[string][]byte {
 | 
			
		||||
	if len(annotations) <= numAnnotations {
 | 
			
		||||
		return nil
 | 
			
		||||
@@ -241,7 +426,8 @@ func annotationsWithoutReferenceKeys(annotations map[string][]byte) map[string][
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateRemoteKMSResponse(resp *service.EncryptResponse) error {
 | 
			
		||||
// validateRemoteKMSEncryptResponse validates the EncryptResponse from the remote KMS.
 | 
			
		||||
func validateRemoteKMSEncryptResponse(resp *service.EncryptResponse) error {
 | 
			
		||||
	// validate annotations don't contain the reference implementation annotations
 | 
			
		||||
	for k := range resp.Annotations {
 | 
			
		||||
		if strings.HasSuffix(k, referenceSuffix) {
 | 
			
		||||
@@ -251,6 +437,14 @@ func validateRemoteKMSResponse(resp *service.EncryptResponse) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// validateRemoteKMSStatusResponse validates the StatusResponse from the remote KMS.
 | 
			
		||||
func validateRemoteKMSStatusResponse(resp *service.StatusResponse) error {
 | 
			
		||||
	if len(resp.KeyID) == 0 {
 | 
			
		||||
		return fmt.Errorf("keyID is empty")
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// generateKey generates a random key using system randomness.
 | 
			
		||||
func generateKey(length int) (key []byte, err error) {
 | 
			
		||||
	key = make([]byte, length)
 | 
			
		||||
 
 | 
			
		||||
@@ -17,18 +17,24 @@ limitations under the License.
 | 
			
		||||
package encryption
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"k8s.io/apimachinery/pkg/util/rand"
 | 
			
		||||
	"k8s.io/apimachinery/pkg/util/wait"
 | 
			
		||||
	"k8s.io/kms/service"
 | 
			
		||||
	testingclock "k8s.io/utils/clock/testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) {
 | 
			
		||||
	t.Parallel()
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		name  string
 | 
			
		||||
		input *service.EncryptResponse
 | 
			
		||||
@@ -89,6 +95,7 @@ func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) {
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		tc := tc
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			t.Parallel()
 | 
			
		||||
			got := copyResponseAndAddLocalKEKAnnotation(tc.input)
 | 
			
		||||
			if !reflect.DeepEqual(got, tc.want) {
 | 
			
		||||
				t.Errorf("copyResponseAndAddLocalKEKAnnotation(%v) = %v, want %v", tc.input, got, tc.want)
 | 
			
		||||
@@ -98,6 +105,7 @@ func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestAnnotationsWithoutReferenceKeys(t *testing.T) {
 | 
			
		||||
	t.Parallel()
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		name  string
 | 
			
		||||
		input map[string][]byte
 | 
			
		||||
@@ -135,6 +143,7 @@ func TestAnnotationsWithoutReferenceKeys(t *testing.T) {
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		tc := tc
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			t.Parallel()
 | 
			
		||||
			got := annotationsWithoutReferenceKeys(tc.input)
 | 
			
		||||
			if !reflect.DeepEqual(got, tc.want) {
 | 
			
		||||
				t.Errorf("annotationsWithoutReferenceKeys(%v) = %v, want %v", tc.input, got, tc.want)
 | 
			
		||||
@@ -143,7 +152,8 @@ func TestAnnotationsWithoutReferenceKeys(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValidateRemoteKMSResponse(t *testing.T) {
 | 
			
		||||
func TestValidateRemoteKMSEncryptResponse(t *testing.T) {
 | 
			
		||||
	t.Parallel()
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		name  string
 | 
			
		||||
		input *service.EncryptResponse
 | 
			
		||||
@@ -178,7 +188,8 @@ func TestValidateRemoteKMSResponse(t *testing.T) {
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		tc := tc
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			got := validateRemoteKMSResponse(tc.input)
 | 
			
		||||
			t.Parallel()
 | 
			
		||||
			got := validateRemoteKMSEncryptResponse(tc.input)
 | 
			
		||||
			if got != tc.want {
 | 
			
		||||
				t.Errorf("validateRemoteKMSResponse(%v) = %v, want %v", tc.input, got, tc.want)
 | 
			
		||||
			}
 | 
			
		||||
@@ -186,19 +197,66 @@ func TestValidateRemoteKMSResponse(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestValidateRemoteKMSStatusResponse(t *testing.T) {
 | 
			
		||||
	t.Parallel()
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		name    string
 | 
			
		||||
		input   *service.StatusResponse
 | 
			
		||||
		wantErr string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "keyID is empty",
 | 
			
		||||
			input: &service.StatusResponse{
 | 
			
		||||
				KeyID: "",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "keyID is empty",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "no error",
 | 
			
		||||
			input: &service.StatusResponse{
 | 
			
		||||
				KeyID: "keyID",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range testCases {
 | 
			
		||||
		tc := tc
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			t.Parallel()
 | 
			
		||||
			got := validateRemoteKMSStatusResponse(tc.input)
 | 
			
		||||
			if tc.wantErr != "" {
 | 
			
		||||
				if got == nil {
 | 
			
		||||
					t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
 | 
			
		||||
				}
 | 
			
		||||
				if !strings.Contains(got.Error(), tc.wantErr) {
 | 
			
		||||
					t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				if got != nil {
 | 
			
		||||
					t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ service.Service = &testRemoteService{}
 | 
			
		||||
 | 
			
		||||
type testRemoteService struct {
 | 
			
		||||
	mu sync.Mutex
 | 
			
		||||
 | 
			
		||||
	keyID    string
 | 
			
		||||
	disabled bool
 | 
			
		||||
	keyID            string
 | 
			
		||||
	disabled         bool
 | 
			
		||||
	encryptCallCount int
 | 
			
		||||
	decryptCallCount int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *testRemoteService) Encrypt(ctx context.Context, uid string, plaintext []byte) (*service.EncryptResponse, error) {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	s.encryptCallCount++
 | 
			
		||||
	if s.disabled {
 | 
			
		||||
		return nil, errors.New("failed to encrypt")
 | 
			
		||||
	}
 | 
			
		||||
@@ -215,6 +273,7 @@ func (s *testRemoteService) Decrypt(ctx context.Context, uid string, req *servic
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	s.decryptCallCount++
 | 
			
		||||
	if s.disabled {
 | 
			
		||||
		return nil, errors.New("failed to decrypt")
 | 
			
		||||
	}
 | 
			
		||||
@@ -231,82 +290,88 @@ func (s *testRemoteService) Status(ctx context.Context) (*service.StatusResponse
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if s.disabled {
 | 
			
		||||
		return nil, errors.New("failed to get status")
 | 
			
		||||
	}
 | 
			
		||||
	return &service.StatusResponse{
 | 
			
		||||
	resp := &service.StatusResponse{
 | 
			
		||||
		Version: "v2alpha1",
 | 
			
		||||
		Healthz: "ok",
 | 
			
		||||
		KeyID:   s.keyID,
 | 
			
		||||
	}, nil
 | 
			
		||||
	}
 | 
			
		||||
	if s.disabled {
 | 
			
		||||
		resp.Healthz = "remote KMS is disabled"
 | 
			
		||||
	}
 | 
			
		||||
	return resp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *testRemoteService) SetDisabledStatus(disabled bool) {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	s.disabled = true
 | 
			
		||||
	s.disabled = disabled
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *testRemoteService) SetKeyID(keyID string) {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	s.keyID = keyID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *testRemoteService) EncryptCallCount() int {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	return s.encryptCallCount
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *testRemoteService) DecryptCallCount() int {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	return s.decryptCallCount
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestEncrypt(t *testing.T) {
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(remoteKMS)
 | 
			
		||||
 | 
			
		||||
	validateResponse := func(got *service.EncryptResponse, t *testing.T) {
 | 
			
		||||
		if len(got.Annotations) != 2 {
 | 
			
		||||
			t.Fatalf("Encrypt() annotations = %v, want 2 annotations", got.Annotations)
 | 
			
		||||
		}
 | 
			
		||||
		if _, ok := got.Annotations[referenceKEKAnnotationKey]; !ok {
 | 
			
		||||
			t.Fatalf("Encrypt() annotations = %v, want %v", got.Annotations, referenceKEKAnnotationKey)
 | 
			
		||||
		}
 | 
			
		||||
		if got.KeyID != remoteKMS.keyID {
 | 
			
		||||
			t.Fatalf("Encrypt() keyID = %v, want %v", got.KeyID, remoteKMS.keyID)
 | 
			
		||||
		}
 | 
			
		||||
		if localKEKService.localTransformer == nil {
 | 
			
		||||
			t.Fatalf("Encrypt() localTransformer = %v, want non-nil", localKEKService.localTransformer)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx := testContext(t)
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(ctx, remoteKMS)
 | 
			
		||||
 | 
			
		||||
	waitUntilReady(t, localKEKService)
 | 
			
		||||
 | 
			
		||||
	// local KEK is generated and encryption is successful
 | 
			
		||||
	got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Encrypt() error = %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	validateResponse(got, t)
 | 
			
		||||
	validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService)
 | 
			
		||||
 | 
			
		||||
	// local KEK is used for encryption even when remote KMS is failing
 | 
			
		||||
	remoteKMS.SetDisabledStatus(true)
 | 
			
		||||
	if got, err = localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")); err != nil {
 | 
			
		||||
		t.Fatalf("Encrypt() error = %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	validateResponse(got, t)
 | 
			
		||||
	validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestEncryptError(t *testing.T) {
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(remoteKMS)
 | 
			
		||||
 | 
			
		||||
	ctx := testContext(t)
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(ctx, remoteKMS)
 | 
			
		||||
 | 
			
		||||
	localKEKGenerationPollTimeout = 5 * time.Second
 | 
			
		||||
	// first time local KEK generation fails because of remote KMS
 | 
			
		||||
	remoteKMS.SetDisabledStatus(true)
 | 
			
		||||
	_, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Fatalf("Encrypt() error = %v, want non-nil", err)
 | 
			
		||||
	}
 | 
			
		||||
	if localKEKService.localTransformer != nil {
 | 
			
		||||
		t.Fatalf("Encrypt() localTransformer = %v, want nil", localKEKService.localTransformer)
 | 
			
		||||
	lk := localKEKService.getLocalKEK()
 | 
			
		||||
	if lk.transformer != nil {
 | 
			
		||||
		t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", lk)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	remoteKMS.SetDisabledStatus(false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDecrypt(t *testing.T) {
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(remoteKMS)
 | 
			
		||||
 | 
			
		||||
	ctx := testContext(t)
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(ctx, remoteKMS)
 | 
			
		||||
 | 
			
		||||
	waitUntilReady(t, localKEKService)
 | 
			
		||||
 | 
			
		||||
	// local KEK is generated and encryption/decryption is successful
 | 
			
		||||
	got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
 | 
			
		||||
@@ -337,10 +402,11 @@ func TestDecrypt(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDecryptError(t *testing.T) {
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(remoteKMS)
 | 
			
		||||
 | 
			
		||||
	ctx := testContext(t)
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(ctx, remoteKMS)
 | 
			
		||||
 | 
			
		||||
	waitUntilReady(t, localKEKService)
 | 
			
		||||
 | 
			
		||||
	got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -353,6 +419,10 @@ func TestDecryptError(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
	// local KEK for decryption not in cache and remote KMS is failing
 | 
			
		||||
	remoteKMS.SetDisabledStatus(true)
 | 
			
		||||
	lk := localKEKService.localKEKTracker.Load()
 | 
			
		||||
	lk.transformer = nil
 | 
			
		||||
	localKEKService.localKEKTracker.Store(lk)
 | 
			
		||||
 | 
			
		||||
	// clear the cache
 | 
			
		||||
	localKEKService.transformers.Clear()
 | 
			
		||||
	if _, err = localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err == nil {
 | 
			
		||||
@@ -361,30 +431,318 @@ func TestDecryptError(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStatus(t *testing.T) {
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := NewLocalKEKService(remoteKMS)
 | 
			
		||||
 | 
			
		||||
	ctx := testContext(t)
 | 
			
		||||
	fakeClock := testingclock.NewFakeClock(time.Now())
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Second, 100*time.Millisecond, fakeClock)
 | 
			
		||||
 | 
			
		||||
	waitUntilReady(t, localKEKService)
 | 
			
		||||
 | 
			
		||||
	got, err := localKEKService.Status(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Status() error = %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if got.Version != "v2alpha1" {
 | 
			
		||||
		t.Fatalf("Status() version = %v, want %v", got.Version, "v2alpha1")
 | 
			
		||||
	}
 | 
			
		||||
	if got.Healthz != "ok" {
 | 
			
		||||
		t.Fatalf("Status() healthz = %v, want %v", got.Healthz, "ok")
 | 
			
		||||
	}
 | 
			
		||||
	if got.KeyID != "test-key-id" {
 | 
			
		||||
		t.Fatalf("Status() keyID = %v, want %v", got.KeyID, "test-key-id")
 | 
			
		||||
	}
 | 
			
		||||
	validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id")
 | 
			
		||||
 | 
			
		||||
	fakeClock.Step(1 * time.Second)
 | 
			
		||||
	// remote KMS is failing
 | 
			
		||||
	remoteKMS.SetDisabledStatus(true)
 | 
			
		||||
	if _, err = localKEKService.Status(ctx); err == nil {
 | 
			
		||||
		t.Fatalf("Status() error = %v, want non-nil", err)
 | 
			
		||||
	// remote KMS keyID changed but local KEK not rotated because of remote KMS failure
 | 
			
		||||
	// the keyID in status should be the old keyID
 | 
			
		||||
	// the error should still be nil
 | 
			
		||||
	remoteKMS.SetKeyID("test-key-id-2")
 | 
			
		||||
 | 
			
		||||
	if got, err = localKEKService.Status(ctx); err != nil {
 | 
			
		||||
		t.Fatalf("Status() error = %v, want nil", err)
 | 
			
		||||
	}
 | 
			
		||||
	validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled", "test-key-id")
 | 
			
		||||
 | 
			
		||||
	fakeClock.Step(1 * time.Second)
 | 
			
		||||
	// wait for local KEK to expire and local KEK service ready to be false
 | 
			
		||||
	wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
 | 
			
		||||
		return !localKEKService.isReady.Load(), nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// status response should include the localKEK unhealthy status
 | 
			
		||||
	if got, err = localKEKService.Status(ctx); err != nil {
 | 
			
		||||
		t.Fatalf("Status() error = %v, want nil", err)
 | 
			
		||||
	}
 | 
			
		||||
	validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled; localKEKService is not ready", "test-key-id")
 | 
			
		||||
 | 
			
		||||
	// remote KMS is functional again, local KEK is rotated
 | 
			
		||||
	remoteKMS.SetDisabledStatus(false)
 | 
			
		||||
	fakeClock.Step(1 * time.Second)
 | 
			
		||||
	waitUntilReady(t, localKEKService)
 | 
			
		||||
	if got, err = localKEKService.Status(ctx); err != nil {
 | 
			
		||||
		t.Fatalf("Status() error = %v, want nil", err)
 | 
			
		||||
	}
 | 
			
		||||
	validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id-2")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRotationKeyUsage(t *testing.T) {
 | 
			
		||||
	ctx := testContext(t)
 | 
			
		||||
 | 
			
		||||
	var record sync.Map
 | 
			
		||||
 | 
			
		||||
	fakeClock := testingclock.NewFakeClock(time.Now())
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock)
 | 
			
		||||
	waitUntilReady(t, localKEKService)
 | 
			
		||||
	lk := localKEKService.localKEKTracker.Load()
 | 
			
		||||
	encLocalKEK := lk.encKEK
 | 
			
		||||
 | 
			
		||||
	// check only single call for Encrypt to remote KMS
 | 
			
		||||
	if remoteKMS.EncryptCallCount() != 1 {
 | 
			
		||||
		t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	for i := 0; i < 6; i++ {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32)))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatalf("Encrypt() error = %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
 | 
			
		||||
				t.Fatalf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
 | 
			
		||||
			}
 | 
			
		||||
			record.Store(resp, nil)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	fakeClock.Step(30 * time.Second)
 | 
			
		||||
	rotated := false
 | 
			
		||||
	// wait for the local KEK to be rotated
 | 
			
		||||
	wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
 | 
			
		||||
		// local KEK must have been rotated after 5 usages
 | 
			
		||||
		lk = localKEKService.localKEKTracker.Load()
 | 
			
		||||
		rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
 | 
			
		||||
		return rotated, nil
 | 
			
		||||
	})
 | 
			
		||||
	if !rotated {
 | 
			
		||||
		t.Fatalf("local KEK must have been rotated")
 | 
			
		||||
	}
 | 
			
		||||
	if remoteKMS.EncryptCallCount() != 2 {
 | 
			
		||||
		t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 2)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// new local KEK must be used for encryption
 | 
			
		||||
	for i := 0; i < 5; i++ {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32)))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatalf("Encrypt() error = %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
 | 
			
		||||
				t.Fatalf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
 | 
			
		||||
			}
 | 
			
		||||
			record.Store(resp, nil)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	// check we can decrypt data encrypted with the old and new local KEKs
 | 
			
		||||
	record.Range(func(key, _ any) bool {
 | 
			
		||||
		k := key.(*service.EncryptResponse)
 | 
			
		||||
		decryptRequest := &service.DecryptRequest{
 | 
			
		||||
			Ciphertext:  k.Ciphertext,
 | 
			
		||||
			Annotations: k.Annotations,
 | 
			
		||||
			KeyID:       k.KeyID,
 | 
			
		||||
		}
 | 
			
		||||
		if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
 | 
			
		||||
			t.Fatalf("Decrypt() error = %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// Out of the 11 calls to Decrypt:
 | 
			
		||||
	// - 5 should be using the current local KEK
 | 
			
		||||
	// - 1 out of the 6 should generate a decrypt call to the remote KMS as the local KEK not in cache
 | 
			
		||||
	// - 5 out of the 6 should use the cached local KEK after 1st decrypt call to the remote KMS
 | 
			
		||||
	assertCallCount(t, remoteKMS, localKEKService)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRotationKeyExpiry(t *testing.T) {
 | 
			
		||||
	ctx := testContext(t)
 | 
			
		||||
 | 
			
		||||
	var record sync.Map
 | 
			
		||||
 | 
			
		||||
	fakeClock := testingclock.NewFakeClock(time.Now())
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 5*time.Second, 100*time.Millisecond, fakeClock)
 | 
			
		||||
	waitUntilReady(t, localKEKService)
 | 
			
		||||
	lk := localKEKService.localKEKTracker.Load()
 | 
			
		||||
	encLocalKEK := lk.encKEK
 | 
			
		||||
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	for i := 0; i < 3; i++ {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatalf("Encrypt() error = %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
 | 
			
		||||
				t.Fatalf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
 | 
			
		||||
			}
 | 
			
		||||
			record.Store(resp, nil)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	// check local KEK has only been used 3 times and still under the suggested usage
 | 
			
		||||
	if lk.usage.Load() != 3 {
 | 
			
		||||
		t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// advance the clock to trigger key expiry
 | 
			
		||||
	fakeClock.Step(6 * time.Second)
 | 
			
		||||
 | 
			
		||||
	rotated := false
 | 
			
		||||
	// wait for the local KEK to be rotated due to key expiry
 | 
			
		||||
	wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
 | 
			
		||||
		// local KEK must have been rotated after the key max age
 | 
			
		||||
		t.Log("waiting for local KEK to be rotated")
 | 
			
		||||
		lk = localKEKService.localKEKTracker.Load()
 | 
			
		||||
		rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
 | 
			
		||||
		return rotated, nil
 | 
			
		||||
	})
 | 
			
		||||
	if !rotated {
 | 
			
		||||
		t.Fatalf("local KEK must have been rotated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// new local KEK must be used for encryption
 | 
			
		||||
	for i := 0; i < 5; i++ {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatalf("Encrypt() error = %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
 | 
			
		||||
				t.Fatalf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
 | 
			
		||||
			}
 | 
			
		||||
			record.Store(resp, nil)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	// check we can decrypt data encrypted with the old and new local KEKs
 | 
			
		||||
	record.Range(func(key, _ any) bool {
 | 
			
		||||
		k := key.(*service.EncryptResponse)
 | 
			
		||||
		decryptRequest := &service.DecryptRequest{
 | 
			
		||||
			Ciphertext:  k.Ciphertext,
 | 
			
		||||
			Annotations: k.Annotations,
 | 
			
		||||
			KeyID:       k.KeyID,
 | 
			
		||||
		}
 | 
			
		||||
		if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
 | 
			
		||||
			t.Fatalf("Decrypt() error = %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// Out of the 8 calls to Decrypt:
 | 
			
		||||
	// - 5 should be using the current local KEK
 | 
			
		||||
	// - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache
 | 
			
		||||
	// - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS
 | 
			
		||||
	assertCallCount(t, remoteKMS, localKEKService)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRotationRemoteKeyIDChanged(t *testing.T) {
 | 
			
		||||
	ctx := testContext(t)
 | 
			
		||||
 | 
			
		||||
	var record sync.Map
 | 
			
		||||
 | 
			
		||||
	fakeClock := testingclock.NewFakeClock(time.Now())
 | 
			
		||||
	remoteKMS := &testRemoteService{keyID: "test-key-id"}
 | 
			
		||||
	localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock)
 | 
			
		||||
	waitUntilReady(t, localKEKService)
 | 
			
		||||
	lk := localKEKService.localKEKTracker.Load()
 | 
			
		||||
	encLocalKEK := lk.encKEK
 | 
			
		||||
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	for i := 0; i < 3; i++ {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatalf("Encrypt() error = %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
 | 
			
		||||
				t.Fatalf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
 | 
			
		||||
			}
 | 
			
		||||
			record.Store(resp, nil)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	// check local KEK has only been used 3 times and still under the suggested usage
 | 
			
		||||
	if lk.usage.Load() != 3 {
 | 
			
		||||
		t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fakeClock.Step(30 * time.Second)
 | 
			
		||||
	// change the remote key ID
 | 
			
		||||
	remoteKMS.SetKeyID("test-key-id-2")
 | 
			
		||||
	if _, err := localKEKService.Status(ctx); err != nil {
 | 
			
		||||
		t.Fatalf("Status() error = %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rotated := false
 | 
			
		||||
	// wait for the local KEK to be rotated due to remote key ID change
 | 
			
		||||
	wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
 | 
			
		||||
		lk = localKEKService.localKEKTracker.Load()
 | 
			
		||||
		rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
 | 
			
		||||
		return rotated, nil
 | 
			
		||||
	})
 | 
			
		||||
	if !rotated {
 | 
			
		||||
		t.Fatalf("local KEK must have been rotated")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// new local KEK must be used for encryption
 | 
			
		||||
	for i := 0; i < 5; i++ {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatalf("Encrypt() error = %v", err)
 | 
			
		||||
			}
 | 
			
		||||
			if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
 | 
			
		||||
				t.Fatalf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
 | 
			
		||||
			}
 | 
			
		||||
			record.Store(resp, nil)
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	// check we can decrypt data encrypted with the old and new local KEKs
 | 
			
		||||
	record.Range(func(key, _ any) bool {
 | 
			
		||||
		k := key.(*service.EncryptResponse)
 | 
			
		||||
		decryptRequest := &service.DecryptRequest{
 | 
			
		||||
			Ciphertext:  k.Ciphertext,
 | 
			
		||||
			Annotations: k.Annotations,
 | 
			
		||||
			KeyID:       k.KeyID,
 | 
			
		||||
		}
 | 
			
		||||
		if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
 | 
			
		||||
			t.Fatalf("Decrypt() error = %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// Out of the 8 calls to Decrypt:
 | 
			
		||||
	// - 5 should be using the current local KEK
 | 
			
		||||
	// - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache
 | 
			
		||||
	// - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS
 | 
			
		||||
	assertCallCount(t, remoteKMS, localKEKService)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testContext(t *testing.T) context.Context {
 | 
			
		||||
@@ -392,3 +750,49 @@ func testContext(t *testing.T) context.Context {
 | 
			
		||||
	t.Cleanup(cancel)
 | 
			
		||||
	return ctx
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func waitUntilReady(t *testing.T, s *LocalKEKService) {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
 | 
			
		||||
		return s.isReady.Load(), nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateEncryptResponse(t *testing.T, got *service.EncryptResponse, wantKeyID string, localKEKService *LocalKEKService) {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	if len(got.Annotations) != 2 {
 | 
			
		||||
		t.Fatalf("Encrypt() annotations = %v, want 2 annotations", got.Annotations)
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := got.Annotations[referenceKEKAnnotationKey]; !ok {
 | 
			
		||||
		t.Fatalf("Encrypt() annotations = %v, want %v", got.Annotations, referenceKEKAnnotationKey)
 | 
			
		||||
	}
 | 
			
		||||
	if got.KeyID != wantKeyID {
 | 
			
		||||
		t.Fatalf("Encrypt() keyID = %v, want %v", got.KeyID, wantKeyID)
 | 
			
		||||
	}
 | 
			
		||||
	if localKEKService.localKEKTracker.Load() == nil {
 | 
			
		||||
		t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", localKEKService.localKEKTracker.Load())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateStatusResponse(t *testing.T, got *service.StatusResponse, wantVersion, wantHealthz, wantKeyID string) {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	if got.Version != wantVersion {
 | 
			
		||||
		t.Fatalf("Status() version = %v, want %v", got.Version, wantVersion)
 | 
			
		||||
	}
 | 
			
		||||
	if !strings.EqualFold(got.Healthz, wantHealthz) {
 | 
			
		||||
		t.Fatalf("Status() healthz = %v, want %v", got.Healthz, wantHealthz)
 | 
			
		||||
	}
 | 
			
		||||
	if got.KeyID != wantKeyID {
 | 
			
		||||
		t.Fatalf("Status() keyID = %v, want %v", got.KeyID, wantKeyID)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func assertCallCount(t *testing.T, remoteKMS *testRemoteService, localKEKService *LocalKEKService) {
 | 
			
		||||
	t.Helper()
 | 
			
		||||
	if remoteKMS.DecryptCallCount() != 1 {
 | 
			
		||||
		t.Fatalf("Decrypt() remoteKMS.DecryptCallCount() = %v, want %v", remoteKMS.DecryptCallCount(), 1)
 | 
			
		||||
	}
 | 
			
		||||
	if localKEKService.transformers.Len() != 1 {
 | 
			
		||||
		t.Fatalf("Decrypt() localKEKService.transformers.Len() = %v, want %v", localKEKService.transformers.Len(), 1)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,7 @@ require (
 | 
			
		||||
require (
 | 
			
		||||
	github.com/go-logr/logr v1.2.3 // indirect
 | 
			
		||||
	github.com/golang/protobuf v1.5.2 // indirect
 | 
			
		||||
	github.com/google/uuid v1.3.0 // indirect
 | 
			
		||||
	golang.org/x/net v0.7.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.5.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.7.0 // indirect
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								staging/src/k8s.io/kms/go.sum
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2
									
								
								staging/src/k8s.io/kms/go.sum
									
									
									
										generated
									
									
									
								
							@@ -49,6 +49,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 | 
			
		||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
 | 
			
		||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 | 
			
		||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
 | 
			
		||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 | 
			
		||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user