mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	Refactor convergent encryption to make specifying a nonce in addition to context possible
This commit is contained in:
		@@ -601,6 +601,7 @@ func TestConvergentEncryption(t *testing.T) {
 | 
				
			|||||||
		Data: map[string]interface{}{
 | 
							Data: map[string]interface{}{
 | 
				
			||||||
			"derived":               false,
 | 
								"derived":               false,
 | 
				
			||||||
			"convergent_encryption": true,
 | 
								"convergent_encryption": true,
 | 
				
			||||||
 | 
								"context_as_nonce":      true,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -619,6 +620,7 @@ func TestConvergentEncryption(t *testing.T) {
 | 
				
			|||||||
	req.Data = map[string]interface{}{
 | 
						req.Data = map[string]interface{}{
 | 
				
			||||||
		"derived":               true,
 | 
							"derived":               true,
 | 
				
			||||||
		"convergent_encryption": true,
 | 
							"convergent_encryption": true,
 | 
				
			||||||
 | 
							"context_as_nonce":      true,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp, err = b.HandleRequest(req)
 | 
						resp, err = b.HandleRequest(req)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -105,42 +105,42 @@ func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
 | 
				
			|||||||
// is needed (for instance, for an upgrade/migration), give up the read lock,
 | 
					// is needed (for instance, for an upgrade/migration), give up the read lock,
 | 
				
			||||||
// call again with an exclusive lock, then swap back out for a read lock.
 | 
					// call again with an exclusive lock, then swap back out for a read lock.
 | 
				
			||||||
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
 | 
					func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
 | 
				
			||||||
	p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, shared)
 | 
						p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, false, shared)
 | 
				
			||||||
	if err == nil ||
 | 
						if err == nil ||
 | 
				
			||||||
		(err != nil && err != errNeedExclusiveLock) {
 | 
							(err != nil && err != errNeedExclusiveLock) {
 | 
				
			||||||
		return p, lock, err
 | 
							return p, lock, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Try again while asking for an exlusive lock
 | 
						// Try again while asking for an exlusive lock
 | 
				
			||||||
	p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, exclusive)
 | 
						p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, false, exclusive)
 | 
				
			||||||
	if err != nil || p == nil || lock == nil {
 | 
						if err != nil || p == nil || lock == nil {
 | 
				
			||||||
		return p, lock, err
 | 
							return p, lock, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	lock.Unlock()
 | 
						lock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, shared)
 | 
						p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, false, shared)
 | 
				
			||||||
	return p, lock, err
 | 
						return p, lock, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Get the policy with an exclusive lock
 | 
					// Get the policy with an exclusive lock
 | 
				
			||||||
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
 | 
					func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
 | 
				
			||||||
	p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, exclusive)
 | 
						p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, false, exclusive)
 | 
				
			||||||
	return p, lock, err
 | 
						return p, lock, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Get the policy with a read lock; if it returns that an exclusive lock is
 | 
					// Get the policy with a read lock; if it returns that an exclusive lock is
 | 
				
			||||||
// needed, retry. If successful, call one more time to get a read lock and
 | 
					// needed, retry. If successful, call one more time to get a read lock and
 | 
				
			||||||
// return the value.
 | 
					// return the value.
 | 
				
			||||||
func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool, convergent bool) (*Policy, *sync.RWMutex, bool, error) {
 | 
					func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived, convergent, contextAsNonce bool) (*Policy, *sync.RWMutex, bool, error) {
 | 
				
			||||||
	p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, convergent, shared)
 | 
						p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, shared)
 | 
				
			||||||
	if err == nil ||
 | 
						if err == nil ||
 | 
				
			||||||
		(err != nil && err != errNeedExclusiveLock) {
 | 
							(err != nil && err != errNeedExclusiveLock) {
 | 
				
			||||||
		return p, lock, false, err
 | 
							return p, lock, false, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Try again while asking for an exlusive lock
 | 
						// Try again while asking for an exlusive lock
 | 
				
			||||||
	p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, convergent, exclusive)
 | 
						p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, exclusive)
 | 
				
			||||||
	if err != nil || p == nil || lock == nil {
 | 
						if err != nil || p == nil || lock == nil {
 | 
				
			||||||
		return p, lock, upserted, err
 | 
							return p, lock, upserted, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -148,14 +148,14 @@ func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, der
 | 
				
			|||||||
	lock.Unlock()
 | 
						lock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Now get a shared lock for the return, but preserve the value of upsert
 | 
						// Now get a shared lock for the return, but preserve the value of upsert
 | 
				
			||||||
	p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, convergent, shared)
 | 
						p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, shared)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return p, lock, upserted, err
 | 
						return p, lock, upserted, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// When the function returns, a lock will be held on the policy if err == nil.
 | 
					// When the function returns, a lock will be held on the policy if err == nil.
 | 
				
			||||||
// It is the caller's responsibility to unlock.
 | 
					// It is the caller's responsibility to unlock.
 | 
				
			||||||
func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, convergent, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
 | 
					func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, convergent, contextAsNonce, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
 | 
				
			||||||
	lock := lm.policyLock(name, lockType)
 | 
						lock := lm.policyLock(name, lockType)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var p *Policy
 | 
						var p *Policy
 | 
				
			||||||
@@ -204,6 +204,8 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups
 | 
				
			|||||||
		if derived {
 | 
							if derived {
 | 
				
			||||||
			p.KDFMode = kdfMode
 | 
								p.KDFMode = kdfMode
 | 
				
			||||||
			p.ConvergentEncryption = convergent
 | 
								p.ConvergentEncryption = convergent
 | 
				
			||||||
 | 
								p.ContextAsNonce = new(bool)
 | 
				
			||||||
 | 
								*p.ContextAsNonce = contextAsNonce
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		err = p.rotate(storage)
 | 
							err = p.rotate(storage)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -100,7 +100,7 @@ func (b *backend) pathDatakeyWrite(
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
 | 
						ciphertext, err := p.Encrypt(context, nil, base64.StdEncoding.EncodeToString(newKey))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		switch err.(type) {
 | 
							switch err.(type) {
 | 
				
			||||||
		case errutil.UserError:
 | 
							case errutil.UserError:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -27,6 +27,11 @@ func (b *backend) pathDecrypt() *framework.Path {
 | 
				
			|||||||
				Type:        framework.TypeString,
 | 
									Type:        framework.TypeString,
 | 
				
			||||||
				Description: "Context for key derivation. Required for derived keys.",
 | 
									Description: "Context for key derivation. Required for derived keys.",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								"nonce": &framework.FieldSchema{
 | 
				
			||||||
 | 
									Type:        framework.TypeString,
 | 
				
			||||||
 | 
									Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Callbacks: map[logical.Operation]framework.OperationFunc{
 | 
							Callbacks: map[logical.Operation]framework.OperationFunc{
 | 
				
			||||||
@@ -46,17 +51,28 @@ func (b *backend) pathDecryptWrite(
 | 
				
			|||||||
		return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
 | 
							return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Decode the context if any
 | 
						// Decode the context if any
 | 
				
			||||||
	contextRaw := d.Get("context").(string)
 | 
						contextRaw := d.Get("context").(string)
 | 
				
			||||||
	var context []byte
 | 
						var context []byte
 | 
				
			||||||
	if len(contextRaw) != 0 {
 | 
						if len(contextRaw) != 0 {
 | 
				
			||||||
		var err error
 | 
					 | 
				
			||||||
		context, err = base64.StdEncoding.DecodeString(contextRaw)
 | 
							context, err = base64.StdEncoding.DecodeString(contextRaw)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
 | 
								return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Decode the nonce if any
 | 
				
			||||||
 | 
						nonceRaw := d.Get("nonce").(string)
 | 
				
			||||||
 | 
						var nonce []byte
 | 
				
			||||||
 | 
						if len(nonceRaw) != 0 {
 | 
				
			||||||
 | 
							nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get the policy
 | 
						// Get the policy
 | 
				
			||||||
	p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
 | 
						p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
 | 
				
			||||||
	if lock != nil {
 | 
						if lock != nil {
 | 
				
			||||||
@@ -69,7 +85,7 @@ func (b *backend) pathDecryptWrite(
 | 
				
			|||||||
		return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
 | 
							return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	plaintext, err := p.Decrypt(context, ciphertext)
 | 
						plaintext, err := p.Decrypt(context, nonce, ciphertext)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		switch err.(type) {
 | 
							switch err.(type) {
 | 
				
			||||||
		case errutil.UserError:
 | 
							case errutil.UserError:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,6 +28,11 @@ func (b *backend) pathEncrypt() *framework.Path {
 | 
				
			|||||||
				Type:        framework.TypeString,
 | 
									Type:        framework.TypeString,
 | 
				
			||||||
				Description: "Context for key derivation. Required for derived keys.",
 | 
									Description: "Context for key derivation. Required for derived keys.",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								"nonce": &framework.FieldSchema{
 | 
				
			||||||
 | 
									Type:        framework.TypeString,
 | 
				
			||||||
 | 
									Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Callbacks: map[logical.Operation]framework.OperationFunc{
 | 
							Callbacks: map[logical.Operation]framework.OperationFunc{
 | 
				
			||||||
@@ -63,10 +68,11 @@ func (b *backend) pathEncryptWrite(
 | 
				
			|||||||
		return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
 | 
							return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Decode the context if any
 | 
						// Decode the context if any
 | 
				
			||||||
	contextRaw := d.Get("context").(string)
 | 
						contextRaw := d.Get("context").(string)
 | 
				
			||||||
	var context []byte
 | 
						var context []byte
 | 
				
			||||||
	var err error
 | 
					 | 
				
			||||||
	if len(contextRaw) != 0 {
 | 
						if len(contextRaw) != 0 {
 | 
				
			||||||
		context, err = base64.StdEncoding.DecodeString(contextRaw)
 | 
							context, err = base64.StdEncoding.DecodeString(contextRaw)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
@@ -74,12 +80,22 @@ func (b *backend) pathEncryptWrite(
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Decode the nonce if any
 | 
				
			||||||
 | 
						nonceRaw := d.Get("nonce").(string)
 | 
				
			||||||
 | 
						var nonce []byte
 | 
				
			||||||
 | 
						if len(nonceRaw) != 0 {
 | 
				
			||||||
 | 
							nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get the policy
 | 
						// Get the policy
 | 
				
			||||||
	var p *Policy
 | 
						var p *Policy
 | 
				
			||||||
	var lock *sync.RWMutex
 | 
						var lock *sync.RWMutex
 | 
				
			||||||
	var upserted bool
 | 
						var upserted bool
 | 
				
			||||||
	if req.Operation == logical.CreateOperation {
 | 
						if req.Operation == logical.CreateOperation {
 | 
				
			||||||
		p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0, false)
 | 
							p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0, false, false)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		p, lock, err = b.lm.GetPolicyShared(req.Storage, name)
 | 
							p, lock, err = b.lm.GetPolicyShared(req.Storage, name)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -93,7 +109,7 @@ func (b *backend) pathEncryptWrite(
 | 
				
			|||||||
		return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
 | 
							return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ciphertext, err := p.Encrypt(context, value)
 | 
						ciphertext, err := p.Encrypt(context, nonce, value)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		switch err.(type) {
 | 
							switch err.(type) {
 | 
				
			||||||
		case errutil.UserError:
 | 
							case errutil.UserError:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,21 +19,40 @@ func (b *backend) pathKeys() *framework.Path {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			"derived": &framework.FieldSchema{
 | 
								"derived": &framework.FieldSchema{
 | 
				
			||||||
				Type: framework.TypeBool,
 | 
									Type: framework.TypeBool,
 | 
				
			||||||
				Description: "Enables key derivation mode. This allows for per-transaction unique keys",
 | 
									Description: `Enables key derivation mode. This
 | 
				
			||||||
 | 
					allows for per-transaction unique keys.`,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			"convergent_encryption": &framework.FieldSchema{
 | 
								"convergent_encryption": &framework.FieldSchema{
 | 
				
			||||||
				Type: framework.TypeBool,
 | 
									Type: framework.TypeBool,
 | 
				
			||||||
				Description: `Whether to use convergent encryption.
 | 
									Description: `Whether to support convergent encryption.
 | 
				
			||||||
This is only supported when using a key with
 | 
					This is only supported when using a key with
 | 
				
			||||||
key derivation enabled and will require all
 | 
					key derivation enabled and will require all
 | 
				
			||||||
context values to be 96 bits (12 bytes) when
 | 
					requests to carry both a context and 96-bit
 | 
				
			||||||
base64-decoded. This mode ensures that when
 | 
					(12-byte) nonce, unless the "context_as_nonce"
 | 
				
			||||||
the same context is supplied, the same
 | 
					feature is also enabled. The given nonce will
 | 
				
			||||||
ciphertext is emitted from the encryption
 | 
					be used in place of a randomly generated nonce.
 | 
				
			||||||
function. It is *very important* when using
 | 
					As a result, when the same context and nonce
 | 
				
			||||||
this mode that you ensure that all contexts
 | 
					(or context, if "context_as_nonce" is enabled)
 | 
				
			||||||
are *globally unique*. Failing to do so will
 | 
					are supplied, the same ciphertext is emitted
 | 
				
			||||||
 | 
					from the encryption function. It is *very
 | 
				
			||||||
 | 
					important* when using this mode that you ensure
 | 
				
			||||||
 | 
					that all nonces are unique for a given context,
 | 
				
			||||||
 | 
					or, when using "context_as_nonce", that all
 | 
				
			||||||
 | 
					contexts are unique for a given key. Failing to
 | 
				
			||||||
 | 
					do so will severely impact the ciphertext's
 | 
				
			||||||
 | 
					security.`,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								"context_as_nonce": &framework.FieldSchema{
 | 
				
			||||||
 | 
									Type: framework.TypeBool,
 | 
				
			||||||
 | 
									Description: `Whether to use the context value as the
 | 
				
			||||||
 | 
					nonce in the convergent encryption operation
 | 
				
			||||||
 | 
					mode. If set true, the user will have to
 | 
				
			||||||
 | 
					supply a 96-bit (12-byte) context value.
 | 
				
			||||||
 | 
					It is *very important* when using this
 | 
				
			||||||
 | 
					mode that you ensure that all contexts are
 | 
				
			||||||
 | 
					*globally unique*. Failing to do so will
 | 
				
			||||||
severely impact the security of the key.`,
 | 
					severely impact the security of the key.`,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
@@ -54,12 +73,13 @@ func (b *backend) pathPolicyWrite(
 | 
				
			|||||||
	name := d.Get("name").(string)
 | 
						name := d.Get("name").(string)
 | 
				
			||||||
	derived := d.Get("derived").(bool)
 | 
						derived := d.Get("derived").(bool)
 | 
				
			||||||
	convergent := d.Get("convergent_encryption").(bool)
 | 
						convergent := d.Get("convergent_encryption").(bool)
 | 
				
			||||||
 | 
						contextAsNonce := d.Get("context_as_nonce").(bool)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !derived && convergent {
 | 
						if !derived && convergent {
 | 
				
			||||||
		return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
 | 
							return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived, convergent)
 | 
						p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived, convergent, contextAsNonce)
 | 
				
			||||||
	if lock != nil {
 | 
						if lock != nil {
 | 
				
			||||||
		defer lock.RUnlock()
 | 
							defer lock.RUnlock()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -107,6 +127,9 @@ func (b *backend) pathPolicyRead(
 | 
				
			|||||||
	if p.Derived {
 | 
						if p.Derived {
 | 
				
			||||||
		resp.Data["kdf_mode"] = p.KDFMode
 | 
							resp.Data["kdf_mode"] = p.KDFMode
 | 
				
			||||||
		resp.Data["convergent_encryption"] = p.ConvergentEncryption
 | 
							resp.Data["convergent_encryption"] = p.ConvergentEncryption
 | 
				
			||||||
 | 
							if p.ContextAsNonce != nil {
 | 
				
			||||||
 | 
								resp.Data["context_as_nonce"] = *p.ContextAsNonce
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	retKeys := map[string]int64{}
 | 
						retKeys := map[string]int64{}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -27,6 +27,11 @@ func (b *backend) pathRewrap() *framework.Path {
 | 
				
			|||||||
				Type:        framework.TypeString,
 | 
									Type:        framework.TypeString,
 | 
				
			||||||
				Description: "Context for key derivation. Required for derived keys.",
 | 
									Description: "Context for key derivation. Required for derived keys.",
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								"nonce": &framework.FieldSchema{
 | 
				
			||||||
 | 
									Type:        framework.TypeString,
 | 
				
			||||||
 | 
									Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Callbacks: map[logical.Operation]framework.OperationFunc{
 | 
							Callbacks: map[logical.Operation]framework.OperationFunc{
 | 
				
			||||||
@@ -47,17 +52,28 @@ func (b *backend) pathRewrapWrite(
 | 
				
			|||||||
		return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
 | 
							return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Decode the context if any
 | 
						// Decode the context if any
 | 
				
			||||||
	contextRaw := d.Get("context").(string)
 | 
						contextRaw := d.Get("context").(string)
 | 
				
			||||||
	var context []byte
 | 
						var context []byte
 | 
				
			||||||
	if len(contextRaw) != 0 {
 | 
						if len(contextRaw) != 0 {
 | 
				
			||||||
		var err error
 | 
					 | 
				
			||||||
		context, err = base64.StdEncoding.DecodeString(contextRaw)
 | 
							context, err = base64.StdEncoding.DecodeString(contextRaw)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
 | 
								return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Decode the nonce if any
 | 
				
			||||||
 | 
						nonceRaw := d.Get("nonce").(string)
 | 
				
			||||||
 | 
						var nonce []byte
 | 
				
			||||||
 | 
						if len(nonceRaw) != 0 {
 | 
				
			||||||
 | 
							nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get the policy
 | 
						// Get the policy
 | 
				
			||||||
	p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
 | 
						p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
 | 
				
			||||||
	if lock != nil {
 | 
						if lock != nil {
 | 
				
			||||||
@@ -71,7 +87,7 @@ func (b *backend) pathRewrapWrite(
 | 
				
			|||||||
		return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
 | 
							return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	plaintext, err := p.Decrypt(context, value)
 | 
						plaintext, err := p.Decrypt(context, nonce, value)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		switch err.(type) {
 | 
							switch err.(type) {
 | 
				
			||||||
		case errutil.UserError:
 | 
							case errutil.UserError:
 | 
				
			||||||
@@ -87,7 +103,7 @@ func (b *backend) pathRewrapWrite(
 | 
				
			|||||||
		return nil, fmt.Errorf("empty plaintext returned during rewrap")
 | 
							return nil, fmt.Errorf("empty plaintext returned during rewrap")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ciphertext, err := p.Encrypt(context, plaintext)
 | 
						ciphertext, err := p.Encrypt(context, nonce, plaintext)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		switch err.(type) {
 | 
							switch err.(type) {
 | 
				
			||||||
		case errutil.UserError:
 | 
							case errutil.UserError:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -72,6 +72,7 @@ type Policy struct {
 | 
				
			|||||||
	Derived              bool   `json:"derived"`
 | 
						Derived              bool   `json:"derived"`
 | 
				
			||||||
	KDFMode              string `json:"kdf_mode"`
 | 
						KDFMode              string `json:"kdf_mode"`
 | 
				
			||||||
	ConvergentEncryption bool   `json:"convergent_encryption"`
 | 
						ConvergentEncryption bool   `json:"convergent_encryption"`
 | 
				
			||||||
 | 
						ContextAsNonce       *bool  `json:"context_as_nonce"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// The minimum version of the key allowed to be used
 | 
						// The minimum version of the key allowed to be used
 | 
				
			||||||
	// for decryption
 | 
						// for decryption
 | 
				
			||||||
@@ -259,6 +260,10 @@ func (p *Policy) needsUpgrade() bool {
 | 
				
			|||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if p.ConvergentEncryption && p.ContextAsNonce == nil {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -288,6 +293,14 @@ func (p *Policy) upgrade(storage logical.Storage) error {
 | 
				
			|||||||
		persistNeeded = true
 | 
							persistNeeded = true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Originally the context-as-nonce mode was the only mode, so keep that
 | 
				
			||||||
 | 
						// behavior if convergent encryption is already in use
 | 
				
			||||||
 | 
						if p.ConvergentEncryption && p.ContextAsNonce == nil {
 | 
				
			||||||
 | 
							p.ContextAsNonce = new(bool)
 | 
				
			||||||
 | 
							*p.ContextAsNonce = true
 | 
				
			||||||
 | 
							persistNeeded = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if persistNeeded {
 | 
						if persistNeeded {
 | 
				
			||||||
		err := p.Persist(storage)
 | 
							err := p.Persist(storage)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
@@ -307,10 +320,6 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
 | 
				
			|||||||
		return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
 | 
							return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if p.LatestVersion == 0 {
 | 
					 | 
				
			||||||
		return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if ver <= 0 || ver > p.LatestVersion {
 | 
						if ver <= 0 || ver > p.LatestVersion {
 | 
				
			||||||
		return nil, errutil.UserError{Err: "invalid key version"}
 | 
							return nil, errutil.UserError{Err: "invalid key version"}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -335,7 +344,7 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *Policy) Encrypt(context []byte, value string) (string, error) {
 | 
					func (p *Policy) Encrypt(context, nonce []byte, value string) (string, error) {
 | 
				
			||||||
	// Decode the plaintext value
 | 
						// Decode the plaintext value
 | 
				
			||||||
	plaintext, err := base64.StdEncoding.DecodeString(value)
 | 
						plaintext, err := base64.StdEncoding.DecodeString(value)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -367,15 +376,20 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
 | 
				
			|||||||
		return "", errutil.InternalError{Err: err.Error()}
 | 
							return "", errutil.InternalError{Err: err.Error()}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if p.ConvergentEncryption && len(context) != gcm.NonceSize() {
 | 
						if p.ConvergentEncryption {
 | 
				
			||||||
		return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
 | 
					
 | 
				
			||||||
 | 
							if *p.ContextAsNonce {
 | 
				
			||||||
 | 
								if len(context) != gcm.NonceSize() {
 | 
				
			||||||
 | 
									return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with context-as-nonce with this key", gcm.NonceSize())}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								nonce = context
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							} else if len(nonce) != gcm.NonceSize() {
 | 
				
			||||||
 | 
								return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Compute random nonce
 | 
					 | 
				
			||||||
	var nonce []byte
 | 
					 | 
				
			||||||
	if p.ConvergentEncryption {
 | 
					 | 
				
			||||||
		nonce = context
 | 
					 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
 | 
							// Compute random nonce
 | 
				
			||||||
		nonce = make([]byte, gcm.NonceSize())
 | 
							nonce = make([]byte, gcm.NonceSize())
 | 
				
			||||||
		_, err = rand.Read(nonce)
 | 
							_, err = rand.Read(nonce)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
@@ -387,7 +401,10 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
 | 
				
			|||||||
	out := gcm.Seal(nil, nonce, plaintext, nil)
 | 
						out := gcm.Seal(nil, nonce, plaintext, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Place the encrypted data after the nonce
 | 
						// Place the encrypted data after the nonce
 | 
				
			||||||
	full := append(nonce, out...)
 | 
						full := out
 | 
				
			||||||
 | 
						if !p.ConvergentEncryption {
 | 
				
			||||||
 | 
							full = append(nonce, out...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Convert to base64
 | 
						// Convert to base64
 | 
				
			||||||
	encoded := base64.StdEncoding.EncodeToString(full)
 | 
						encoded := base64.StdEncoding.EncodeToString(full)
 | 
				
			||||||
@@ -398,12 +415,16 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
 | 
				
			|||||||
	return encoded, nil
 | 
						return encoded, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *Policy) Decrypt(context []byte, value string) (string, error) {
 | 
					func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
 | 
				
			||||||
	// Verify the prefix
 | 
						// Verify the prefix
 | 
				
			||||||
	if !strings.HasPrefix(value, "vault:v") {
 | 
						if !strings.HasPrefix(value, "vault:v") {
 | 
				
			||||||
		return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
 | 
							return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if p.ConvergentEncryption && !*p.ContextAsNonce && (nonce == nil || len(nonce) == 0) {
 | 
				
			||||||
 | 
							return "", errutil.UserError{Err: "invalid convergent nonce supplied"}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
 | 
						splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
 | 
				
			||||||
	if len(splitVerCiphertext) != 2 {
 | 
						if len(splitVerCiphertext) != 2 {
 | 
				
			||||||
		return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
 | 
							return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
 | 
				
			||||||
@@ -460,8 +481,16 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Extract the nonce and ciphertext
 | 
						// Extract the nonce and ciphertext
 | 
				
			||||||
	nonce := decoded[:gcm.NonceSize()]
 | 
						var ciphertext []byte
 | 
				
			||||||
	ciphertext := decoded[gcm.NonceSize():]
 | 
						if p.ConvergentEncryption {
 | 
				
			||||||
 | 
							if *p.ContextAsNonce {
 | 
				
			||||||
 | 
								nonce = context
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							ciphertext = decoded
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							nonce = decoded[:gcm.NonceSize()]
 | 
				
			||||||
 | 
							ciphertext = decoded[gcm.NonceSize():]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Verify and Decrypt
 | 
						// Verify and Decrypt
 | 
				
			||||||
	plain, err := gcm.Open(nil, nonce, ciphertext, nil)
 | 
						plain, err := gcm.Open(nil, nonce, ciphertext, nil)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -22,7 +22,7 @@ func Test_KeyUpgrade(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
 | 
					func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
 | 
				
			||||||
	storage := &logical.InmemStorage{}
 | 
						storage := &logical.InmemStorage{}
 | 
				
			||||||
	p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false, false)
 | 
						p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
 | 
				
			||||||
	if lock != nil {
 | 
						if lock != nil {
 | 
				
			||||||
		defer lock.RUnlock()
 | 
							defer lock.RUnlock()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -68,7 +68,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	storage := &logical.InmemStorage{}
 | 
						storage := &logical.InmemStorage{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false)
 | 
						p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -198,7 +198,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	storage := &logical.InmemStorage{}
 | 
						storage := &logical.InmemStorage{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false)
 | 
						p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
 | 
				
			||||||
	if lock != nil {
 | 
						if lock != nil {
 | 
				
			||||||
		defer lock.RUnlock()
 | 
							defer lock.RUnlock()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user