Refactor convergent encryption to make specifying a nonce in addition to context possible

This commit is contained in:
Jeff Mitchell
2016-08-05 17:52:44 -04:00
parent 0a386d48a2
commit c7bf73f924
9 changed files with 152 additions and 48 deletions

View File

@@ -601,6 +601,7 @@ func TestConvergentEncryption(t *testing.T) {
Data: map[string]interface{}{ Data: map[string]interface{}{
"derived": false, "derived": false,
"convergent_encryption": true, "convergent_encryption": true,
"context_as_nonce": true,
}, },
} }
@@ -619,6 +620,7 @@ func TestConvergentEncryption(t *testing.T) {
req.Data = map[string]interface{}{ req.Data = map[string]interface{}{
"derived": true, "derived": true,
"convergent_encryption": true, "convergent_encryption": true,
"context_as_nonce": true,
} }
resp, err = b.HandleRequest(req) resp, err = b.HandleRequest(req)

View File

@@ -105,42 +105,42 @@ func (lm *lockManager) UnlockPolicy(lock *sync.RWMutex, lockType bool) {
// is needed (for instance, for an upgrade/migration), give up the read lock, // is needed (for instance, for an upgrade/migration), give up the read lock,
// call again with an exclusive lock, then swap back out for a read lock. // call again with an exclusive lock, then swap back out for a read lock.
func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) { func (lm *lockManager) GetPolicyShared(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, shared) p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, false, shared)
if err == nil || if err == nil ||
(err != nil && err != errNeedExclusiveLock) { (err != nil && err != errNeedExclusiveLock) {
return p, lock, err return p, lock, err
} }
// Try again while asking for an exlusive lock // Try again while asking for an exlusive lock
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, exclusive) p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, false, exclusive)
if err != nil || p == nil || lock == nil { if err != nil || p == nil || lock == nil {
return p, lock, err return p, lock, err
} }
lock.Unlock() lock.Unlock()
p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, shared) p, lock, _, err = lm.getPolicyCommon(storage, name, false, false, false, false, shared)
return p, lock, err return p, lock, err
} }
// Get the policy with an exclusive lock // Get the policy with an exclusive lock
func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) { func (lm *lockManager) GetPolicyExclusive(storage logical.Storage, name string) (*Policy, *sync.RWMutex, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, exclusive) p, lock, _, err := lm.getPolicyCommon(storage, name, false, false, false, false, exclusive)
return p, lock, err return p, lock, err
} }
// Get the policy with a read lock; if it returns that an exclusive lock is // Get the policy with a read lock; if it returns that an exclusive lock is
// needed, retry. If successful, call one more time to get a read lock and // needed, retry. If successful, call one more time to get a read lock and
// return the value. // return the value.
func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived bool, convergent bool) (*Policy, *sync.RWMutex, bool, error) { func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, derived, convergent, contextAsNonce bool) (*Policy, *sync.RWMutex, bool, error) {
p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, convergent, shared) p, lock, _, err := lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, shared)
if err == nil || if err == nil ||
(err != nil && err != errNeedExclusiveLock) { (err != nil && err != errNeedExclusiveLock) {
return p, lock, false, err return p, lock, false, err
} }
// Try again while asking for an exlusive lock // Try again while asking for an exlusive lock
p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, convergent, exclusive) p, lock, upserted, err := lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, exclusive)
if err != nil || p == nil || lock == nil { if err != nil || p == nil || lock == nil {
return p, lock, upserted, err return p, lock, upserted, err
} }
@@ -148,14 +148,14 @@ func (lm *lockManager) GetPolicyUpsert(storage logical.Storage, name string, der
lock.Unlock() lock.Unlock()
// Now get a shared lock for the return, but preserve the value of upsert // Now get a shared lock for the return, but preserve the value of upsert
p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, convergent, shared) p, lock, _, err = lm.getPolicyCommon(storage, name, true, derived, convergent, contextAsNonce, shared)
return p, lock, upserted, err return p, lock, upserted, err
} }
// When the function returns, a lock will be held on the policy if err == nil. // When the function returns, a lock will be held on the policy if err == nil.
// It is the caller's responsibility to unlock. // It is the caller's responsibility to unlock.
func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, convergent, lockType bool) (*Policy, *sync.RWMutex, bool, error) { func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, upsert, derived, convergent, contextAsNonce, lockType bool) (*Policy, *sync.RWMutex, bool, error) {
lock := lm.policyLock(name, lockType) lock := lm.policyLock(name, lockType)
var p *Policy var p *Policy
@@ -204,6 +204,8 @@ func (lm *lockManager) getPolicyCommon(storage logical.Storage, name string, ups
if derived { if derived {
p.KDFMode = kdfMode p.KDFMode = kdfMode
p.ConvergentEncryption = convergent p.ConvergentEncryption = convergent
p.ContextAsNonce = new(bool)
*p.ContextAsNonce = contextAsNonce
} }
err = p.rotate(storage) err = p.rotate(storage)

View File

@@ -100,7 +100,7 @@ func (b *backend) pathDatakeyWrite(
return nil, err return nil, err
} }
ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) ciphertext, err := p.Encrypt(context, nil, base64.StdEncoding.EncodeToString(newKey))
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case errutil.UserError: case errutil.UserError:

View File

@@ -27,6 +27,11 @@ func (b *backend) pathDecrypt() *framework.Path {
Type: framework.TypeString, Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.", Description: "Context for key derivation. Required for derived keys.",
}, },
"nonce": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
},
}, },
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -46,17 +51,28 @@ func (b *backend) pathDecryptWrite(
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
} }
var err error
// Decode the context if any // Decode the context if any
contextRaw := d.Get("context").(string) contextRaw := d.Get("context").(string)
var context []byte var context []byte
if len(contextRaw) != 0 { if len(contextRaw) != 0 {
var err error
context, err = base64.StdEncoding.DecodeString(contextRaw) context, err = base64.StdEncoding.DecodeString(contextRaw)
if err != nil { if err != nil {
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
} }
} }
// Decode the nonce if any
nonceRaw := d.Get("nonce").(string)
var nonce []byte
if len(nonceRaw) != 0 {
nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
if err != nil {
return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
}
}
// Get the policy // Get the policy
p, lock, err := b.lm.GetPolicyShared(req.Storage, name) p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil { if lock != nil {
@@ -69,7 +85,7 @@ func (b *backend) pathDecryptWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
} }
plaintext, err := p.Decrypt(context, ciphertext) plaintext, err := p.Decrypt(context, nonce, ciphertext)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case errutil.UserError: case errutil.UserError:

View File

@@ -28,6 +28,11 @@ func (b *backend) pathEncrypt() *framework.Path {
Type: framework.TypeString, Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.", Description: "Context for key derivation. Required for derived keys.",
}, },
"nonce": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
},
}, },
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -63,10 +68,11 @@ func (b *backend) pathEncryptWrite(
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
} }
var err error
// Decode the context if any // Decode the context if any
contextRaw := d.Get("context").(string) contextRaw := d.Get("context").(string)
var context []byte var context []byte
var err error
if len(contextRaw) != 0 { if len(contextRaw) != 0 {
context, err = base64.StdEncoding.DecodeString(contextRaw) context, err = base64.StdEncoding.DecodeString(contextRaw)
if err != nil { if err != nil {
@@ -74,12 +80,22 @@ func (b *backend) pathEncryptWrite(
} }
} }
// Decode the nonce if any
nonceRaw := d.Get("nonce").(string)
var nonce []byte
if len(nonceRaw) != 0 {
nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
if err != nil {
return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
}
}
// Get the policy // Get the policy
var p *Policy var p *Policy
var lock *sync.RWMutex var lock *sync.RWMutex
var upserted bool var upserted bool
if req.Operation == logical.CreateOperation { if req.Operation == logical.CreateOperation {
p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0, false) p, lock, upserted, err = b.lm.GetPolicyUpsert(req.Storage, name, len(context) != 0, false, false)
} else { } else {
p, lock, err = b.lm.GetPolicyShared(req.Storage, name) p, lock, err = b.lm.GetPolicyShared(req.Storage, name)
} }
@@ -93,7 +109,7 @@ func (b *backend) pathEncryptWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
} }
ciphertext, err := p.Encrypt(context, value) ciphertext, err := p.Encrypt(context, nonce, value)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case errutil.UserError: case errutil.UserError:

View File

@@ -18,22 +18,41 @@ func (b *backend) pathKeys() *framework.Path {
}, },
"derived": &framework.FieldSchema{ "derived": &framework.FieldSchema{
Type: framework.TypeBool, Type: framework.TypeBool,
Description: "Enables key derivation mode. This allows for per-transaction unique keys", Description: `Enables key derivation mode. This
allows for per-transaction unique keys.`,
}, },
"convergent_encryption": &framework.FieldSchema{ "convergent_encryption": &framework.FieldSchema{
Type: framework.TypeBool, Type: framework.TypeBool,
Description: `Whether to use convergent encryption. Description: `Whether to support convergent encryption.
This is only supported when using a key with This is only supported when using a key with
key derivation enabled and will require all key derivation enabled and will require all
context values to be 96 bits (12 bytes) when requests to carry both a context and 96-bit
base64-decoded. This mode ensures that when (12-byte) nonce, unless the "context_as_nonce"
the same context is supplied, the same feature is also enabled. The given nonce will
ciphertext is emitted from the encryption be used in place of a randomly generated nonce.
function. It is *very important* when using As a result, when the same context and nonce
this mode that you ensure that all contexts (or context, if "context_as_nonce" is enabled)
are *globally unique*. Failing to do so will are supplied, the same ciphertext is emitted
from the encryption function. It is *very
important* when using this mode that you ensure
that all nonces are unique for a given context,
or, when using "context_as_nonce", that all
contexts are unique for a given key. Failing to
do so will severely impact the ciphertext's
security.`,
},
"context_as_nonce": &framework.FieldSchema{
Type: framework.TypeBool,
Description: `Whether to use the context value as the
nonce in the convergent encryption operation
mode. If set true, the user will have to
supply a 96-bit (12-byte) context value.
It is *very important* when using this
mode that you ensure that all contexts are
*globally unique*. Failing to do so will
severely impact the security of the key.`, severely impact the security of the key.`,
}, },
}, },
@@ -54,12 +73,13 @@ func (b *backend) pathPolicyWrite(
name := d.Get("name").(string) name := d.Get("name").(string)
derived := d.Get("derived").(bool) derived := d.Get("derived").(bool)
convergent := d.Get("convergent_encryption").(bool) convergent := d.Get("convergent_encryption").(bool)
contextAsNonce := d.Get("context_as_nonce").(bool)
if !derived && convergent { if !derived && convergent {
return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil return logical.ErrorResponse("convergent encryption requires derivation to be enabled"), nil
} }
p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived, convergent) p, lock, upserted, err := b.lm.GetPolicyUpsert(req.Storage, name, derived, convergent, contextAsNonce)
if lock != nil { if lock != nil {
defer lock.RUnlock() defer lock.RUnlock()
} }
@@ -107,6 +127,9 @@ func (b *backend) pathPolicyRead(
if p.Derived { if p.Derived {
resp.Data["kdf_mode"] = p.KDFMode resp.Data["kdf_mode"] = p.KDFMode
resp.Data["convergent_encryption"] = p.ConvergentEncryption resp.Data["convergent_encryption"] = p.ConvergentEncryption
if p.ContextAsNonce != nil {
resp.Data["context_as_nonce"] = *p.ContextAsNonce
}
} }
retKeys := map[string]int64{} retKeys := map[string]int64{}

View File

@@ -27,6 +27,11 @@ func (b *backend) pathRewrap() *framework.Path {
Type: framework.TypeString, Type: framework.TypeString,
Description: "Context for key derivation. Required for derived keys.", Description: "Context for key derivation. Required for derived keys.",
}, },
"nonce": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Nonce for when convergent encryption is used and the context is not used as the nonce",
},
}, },
Callbacks: map[logical.Operation]framework.OperationFunc{ Callbacks: map[logical.Operation]framework.OperationFunc{
@@ -47,17 +52,28 @@ func (b *backend) pathRewrapWrite(
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
} }
var err error
// Decode the context if any // Decode the context if any
contextRaw := d.Get("context").(string) contextRaw := d.Get("context").(string)
var context []byte var context []byte
if len(contextRaw) != 0 { if len(contextRaw) != 0 {
var err error
context, err = base64.StdEncoding.DecodeString(contextRaw) context, err = base64.StdEncoding.DecodeString(contextRaw)
if err != nil { if err != nil {
return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest
} }
} }
// Decode the nonce if any
nonceRaw := d.Get("nonce").(string)
var nonce []byte
if len(nonceRaw) != 0 {
nonce, err = base64.StdEncoding.DecodeString(nonceRaw)
if err != nil {
return logical.ErrorResponse("failed to decode nonce as base64"), logical.ErrInvalidRequest
}
}
// Get the policy // Get the policy
p, lock, err := b.lm.GetPolicyShared(req.Storage, name) p, lock, err := b.lm.GetPolicyShared(req.Storage, name)
if lock != nil { if lock != nil {
@@ -71,7 +87,7 @@ func (b *backend) pathRewrapWrite(
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
} }
plaintext, err := p.Decrypt(context, value) plaintext, err := p.Decrypt(context, nonce, value)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case errutil.UserError: case errutil.UserError:
@@ -87,7 +103,7 @@ func (b *backend) pathRewrapWrite(
return nil, fmt.Errorf("empty plaintext returned during rewrap") return nil, fmt.Errorf("empty plaintext returned during rewrap")
} }
ciphertext, err := p.Encrypt(context, plaintext) ciphertext, err := p.Encrypt(context, nonce, plaintext)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case errutil.UserError: case errutil.UserError:

View File

@@ -72,6 +72,7 @@ type Policy struct {
Derived bool `json:"derived"` Derived bool `json:"derived"`
KDFMode string `json:"kdf_mode"` KDFMode string `json:"kdf_mode"`
ConvergentEncryption bool `json:"convergent_encryption"` ConvergentEncryption bool `json:"convergent_encryption"`
ContextAsNonce *bool `json:"context_as_nonce"`
// The minimum version of the key allowed to be used // The minimum version of the key allowed to be used
// for decryption // for decryption
@@ -259,6 +260,10 @@ func (p *Policy) needsUpgrade() bool {
return true return true
} }
if p.ConvergentEncryption && p.ContextAsNonce == nil {
return true
}
return false return false
} }
@@ -288,6 +293,14 @@ func (p *Policy) upgrade(storage logical.Storage) error {
persistNeeded = true persistNeeded = true
} }
// Originally the context-as-nonce mode was the only mode, so keep that
// behavior if convergent encryption is already in use
if p.ConvergentEncryption && p.ContextAsNonce == nil {
p.ContextAsNonce = new(bool)
*p.ContextAsNonce = true
persistNeeded = true
}
if persistNeeded { if persistNeeded {
err := p.Persist(storage) err := p.Persist(storage)
if err != nil { if err != nil {
@@ -307,10 +320,6 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"} return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
} }
if p.LatestVersion == 0 {
return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
}
if ver <= 0 || ver > p.LatestVersion { if ver <= 0 || ver > p.LatestVersion {
return nil, errutil.UserError{Err: "invalid key version"} return nil, errutil.UserError{Err: "invalid key version"}
} }
@@ -335,7 +344,7 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
} }
} }
func (p *Policy) Encrypt(context []byte, value string) (string, error) { func (p *Policy) Encrypt(context, nonce []byte, value string) (string, error) {
// Decode the plaintext value // Decode the plaintext value
plaintext, err := base64.StdEncoding.DecodeString(value) plaintext, err := base64.StdEncoding.DecodeString(value)
if err != nil { if err != nil {
@@ -367,15 +376,20 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
return "", errutil.InternalError{Err: err.Error()} return "", errutil.InternalError{Err: err.Error()}
} }
if p.ConvergentEncryption && len(context) != gcm.NonceSize() {
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
}
// Compute random nonce
var nonce []byte
if p.ConvergentEncryption { if p.ConvergentEncryption {
nonce = context
if *p.ContextAsNonce {
if len(context) != gcm.NonceSize() {
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded context must be %d bytes long when using convergent encryption with context-as-nonce with this key", gcm.NonceSize())}
}
nonce = context
} else if len(nonce) != gcm.NonceSize() {
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
}
} else { } else {
// Compute random nonce
nonce = make([]byte, gcm.NonceSize()) nonce = make([]byte, gcm.NonceSize())
_, err = rand.Read(nonce) _, err = rand.Read(nonce)
if err != nil { if err != nil {
@@ -387,7 +401,10 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
out := gcm.Seal(nil, nonce, plaintext, nil) out := gcm.Seal(nil, nonce, plaintext, nil)
// Place the encrypted data after the nonce // Place the encrypted data after the nonce
full := append(nonce, out...) full := out
if !p.ConvergentEncryption {
full = append(nonce, out...)
}
// Convert to base64 // Convert to base64
encoded := base64.StdEncoding.EncodeToString(full) encoded := base64.StdEncoding.EncodeToString(full)
@@ -398,12 +415,16 @@ func (p *Policy) Encrypt(context []byte, value string) (string, error) {
return encoded, nil return encoded, nil
} }
func (p *Policy) Decrypt(context []byte, value string) (string, error) { func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
// Verify the prefix // Verify the prefix
if !strings.HasPrefix(value, "vault:v") { if !strings.HasPrefix(value, "vault:v") {
return "", errutil.UserError{Err: "invalid ciphertext: no prefix"} return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
} }
if p.ConvergentEncryption && !*p.ContextAsNonce && (nonce == nil || len(nonce) == 0) {
return "", errutil.UserError{Err: "invalid convergent nonce supplied"}
}
splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2) splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2)
if len(splitVerCiphertext) != 2 { if len(splitVerCiphertext) != 2 {
return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"} return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
@@ -460,8 +481,16 @@ func (p *Policy) Decrypt(context []byte, value string) (string, error) {
} }
// Extract the nonce and ciphertext // Extract the nonce and ciphertext
nonce := decoded[:gcm.NonceSize()] var ciphertext []byte
ciphertext := decoded[gcm.NonceSize():] if p.ConvergentEncryption {
if *p.ContextAsNonce {
nonce = context
}
ciphertext = decoded
} else {
nonce = decoded[:gcm.NonceSize()]
ciphertext = decoded[gcm.NonceSize():]
}
// Verify and Decrypt // Verify and Decrypt
plain, err := gcm.Open(nil, nonce, ciphertext, nil) plain, err := gcm.Open(nil, nonce, ciphertext, nil)

View File

@@ -22,7 +22,7 @@ func Test_KeyUpgrade(t *testing.T) {
func testKeyUpgradeCommon(t *testing.T, lm *lockManager) { func testKeyUpgradeCommon(t *testing.T, lm *lockManager) {
storage := &logical.InmemStorage{} storage := &logical.InmemStorage{}
p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false, false) p, lock, upserted, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
if lock != nil { if lock != nil {
defer lock.RUnlock() defer lock.RUnlock()
} }
@@ -68,7 +68,7 @@ func testArchivingUpgradeCommon(t *testing.T, lm *lockManager) {
storage := &logical.InmemStorage{} storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false) p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -198,7 +198,7 @@ func testArchivingCommon(t *testing.T, lm *lockManager) {
storage := &logical.InmemStorage{} storage := &logical.InmemStorage{}
p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false) p, lock, _, err := lm.GetPolicyUpsert(storage, "test", false, false, false)
if lock != nil { if lock != nil {
defer lock.RUnlock() defer lock.RUnlock()
} }