Implement locking in the transit backend.

This ensures that we can safely rotate and modify configuration
parameters with multiple requests in flight.

As a side effect we also get a cache, which should provide a nice
speedup since we don't need to decrypt/deserialize constantly, which
would happen even with the physical LRU.
This commit is contained in:
Jeff Mitchell
2016-01-27 16:24:11 -05:00
parent ba03981739
commit 46514e01fa
11 changed files with 403 additions and 217 deletions

View File

@@ -6,36 +6,46 @@ import (
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return Backend().Setup(conf)
b := Backend()
be, err := b.Backend.Setup(conf)
if err != nil {
return nil, err
}
err = b.policies.loadStoredPolicies(conf.StorageView)
if err != nil {
return nil, err
}
return be, nil
}
func Backend() *framework.Backend {
func Backend() *backend {
var b backend
b.Backend = &framework.Backend{
PathsSpecial: &logical.Paths{
Root: []string{
"keys/*",
},
},
Paths: []*framework.Path{
// Rotate/Config needs to come before Keys
// as the handler is greedy
pathConfig(),
pathRotate(),
pathRewrap(),
pathKeys(),
pathEncrypt(),
pathDecrypt(),
pathDatakey(),
b.pathConfig(),
b.pathRotate(),
b.pathRewrap(),
b.pathKeys(),
b.pathEncrypt(),
b.pathDecrypt(),
b.pathDatakey(),
},
Secrets: []*framework.Secret{},
}
return b.Backend
b.policies = &policyCache{
cache: map[string]*lockingPolicy{},
}
return &b
}
type backend struct {
*framework.Backend
policies *policyCache
}

View File

@@ -19,7 +19,7 @@ const (
func TestBackend_basic(t *testing.T) {
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
@@ -42,7 +42,7 @@ func TestBackend_basic(t *testing.T) {
func TestBackend_datakey(t *testing.T) {
dataKeyInfo := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test", false),
testAccStepReadPolicy(t, "test", false, false),
@@ -57,7 +57,7 @@ func TestBackend_rotation(t *testing.T) {
decryptData := make(map[string]interface{})
encryptHistory := make(map[int]map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test", false),
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory),
@@ -111,26 +111,10 @@ func TestBackend_rotation(t *testing.T) {
})
}
func TestBackend_upsert(t *testing.T) {
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Steps: []logicaltest.TestStep{
testAccStepReadPolicy(t, "test", true, false),
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
testAccStepReadPolicy(t, "test", false, false),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepEnableDeletion(t, "test"),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true, false),
},
})
}
func TestBackend_basic_derived(t *testing.T) {
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Factory: Factory,
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test", true),
testAccStepReadPolicy(t, "test", false, true),

View File

@@ -7,7 +7,7 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathConfig() *framework.Path {
func (b *backend) pathConfig() *framework.Path {
return &framework.Path{
Pattern: "keys/" + framework.GenericNameRegex("name") + "/config",
Fields: map[string]*framework.FieldSchema{
@@ -29,7 +29,7 @@ to be decrypted.`,
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: pathConfigWrite,
logical.UpdateOperation: b.pathConfigWrite,
},
HelpSynopsis: pathConfigHelpSyn,
@@ -37,21 +37,29 @@ to be decrypted.`,
}
}
func pathConfigWrite(
func (b *backend) pathConfigWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
// Check if the policy already exists
policy, err := getPolicy(req, name)
lp, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
if policy == nil {
if lp == nil {
return logical.ErrorResponse(
fmt.Sprintf("no existing role named %s could be found", name)),
logical.ErrInvalidRequest
}
lp.lock.Lock()
defer lp.lock.Unlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
}
resp := &logical.Response{}
persistNeeded := false
@@ -70,12 +78,12 @@ func pathConfigWrite(
}
if minDecryptionVersion > 0 &&
minDecryptionVersion != policy.MinDecryptionVersion {
if minDecryptionVersion > policy.LatestVersion {
minDecryptionVersion != lp.policy.MinDecryptionVersion {
if minDecryptionVersion > lp.policy.LatestVersion {
return logical.ErrorResponse(
fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, policy.LatestVersion)), nil
fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.policy.LatestVersion)), nil
}
policy.MinDecryptionVersion = minDecryptionVersion
lp.policy.MinDecryptionVersion = minDecryptionVersion
persistNeeded = true
}
}
@@ -83,8 +91,8 @@ func pathConfigWrite(
allowDeletionInt, ok := d.GetOk("deletion_allowed")
if ok {
allowDeletion := allowDeletionInt.(bool)
if allowDeletion != policy.DeletionAllowed {
policy.DeletionAllowed = allowDeletion
if allowDeletion != lp.policy.DeletionAllowed {
lp.policy.DeletionAllowed = allowDeletion
persistNeeded = true
}
}
@@ -92,8 +100,8 @@ func pathConfigWrite(
// Add this as a guard here before persisting since we now require the min
// decryption version to start at 1; even if it's not explicitly set here,
// force the upgrade
if policy.MinDecryptionVersion == 0 {
policy.MinDecryptionVersion = 1
if lp.policy.MinDecryptionVersion == 0 {
lp.policy.MinDecryptionVersion = 1
persistNeeded = true
}
@@ -101,7 +109,7 @@ func pathConfigWrite(
return nil, nil
}
return resp, policy.Persist(req.Storage)
return resp, lp.policy.Persist(req.Storage)
}
const pathConfigHelpSyn = `Configure a named encryption key`

View File

@@ -10,7 +10,7 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathDatakey() *framework.Path {
func (b *backend) pathDatakey() *framework.Path {
return &framework.Path{
Pattern: "datakey/" + framework.GenericNameRegex("plaintext") + "/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
@@ -39,7 +39,7 @@ and 512 bits are supported. Defaults to 256.`,
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: pathDatakeyWrite,
logical.UpdateOperation: b.pathDatakeyWrite,
},
HelpSynopsis: pathDatakeyHelpSyn,
@@ -47,7 +47,7 @@ and 512 bits are supported. Defaults to 256.`,
}
}
func pathDatakeyWrite(
func (b *backend) pathDatakeyWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
@@ -73,16 +73,24 @@ func pathDatakeyWrite(
}
// Get the policy
p, err := getPolicy(req, name)
lp, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
// Error if invalid policy
if p == nil {
if lp == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
lp.lock.RLock()
defer lp.lock.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
}
newKey := make([]byte, 32)
bits := d.Get("bits").(int)
switch bits {
@@ -99,7 +107,7 @@ func pathDatakeyWrite(
return nil, err
}
ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
ciphertext, err := lp.policy.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
if err != nil {
switch err.(type) {
case certutil.UserError:

View File

@@ -9,7 +9,7 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathDecrypt() *framework.Path {
func (b *backend) pathDecrypt() *framework.Path {
return &framework.Path{
Pattern: "decrypt/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
@@ -30,7 +30,7 @@ func pathDecrypt() *framework.Path {
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: pathDecryptWrite,
logical.UpdateOperation: b.pathDecryptWrite,
},
HelpSynopsis: pathDecryptHelpSyn,
@@ -38,7 +38,7 @@ func pathDecrypt() *framework.Path {
}
}
func pathDecryptWrite(
func (b *backend) pathDecryptWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
ciphertext := d.Get("ciphertext").(string)
@@ -58,17 +58,25 @@ func pathDecryptWrite(
}
// Get the policy
p, err := getPolicy(req, name)
lp, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
// Error if invalid policy
if p == nil {
if lp == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
plaintext, err := p.Decrypt(context, ciphertext)
lp.lock.RLock()
defer lp.lock.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
}
plaintext, err := lp.policy.Decrypt(context, ciphertext)
if err != nil {
switch err.(type) {
case certutil.UserError:

View File

@@ -9,7 +9,7 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathEncrypt() *framework.Path {
func (b *backend) pathEncrypt() *framework.Path {
return &framework.Path{
Pattern: "encrypt/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
@@ -30,7 +30,7 @@ func pathEncrypt() *framework.Path {
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: pathEncryptWrite,
logical.UpdateOperation: b.pathEncryptWrite,
},
HelpSynopsis: pathEncryptHelpSyn,
@@ -38,7 +38,7 @@ func pathEncrypt() *framework.Path {
}
}
func pathEncryptWrite(
func (b *backend) pathEncryptWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
value := d.Get("plaintext").(string)
@@ -46,12 +46,6 @@ func pathEncryptWrite(
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
}
// Get the policy
p, err := getPolicy(req, name)
if err != nil {
return nil, err
}
// Decode the context if any
contextRaw := d.Get("context").(string)
var context []byte
@@ -63,16 +57,26 @@ func pathEncryptWrite(
}
}
// Error if invalid policy
if p == nil {
isDerived := len(context) != 0
p, err = generatePolicy(req.Storage, name, isDerived)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to upsert policy: %v", err)), logical.ErrInvalidRequest
}
// Get the policy
lp, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
ciphertext, err := p.Encrypt(context, value)
// Error if invalid policy
if lp == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
lp.lock.RLock()
defer lp.lock.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
}
ciphertext, err := lp.policy.Encrypt(context, value)
if err != nil {
switch err.(type) {
case certutil.UserError:

View File

@@ -8,7 +8,7 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathKeys() *framework.Path {
func (b *backend) pathKeys() *framework.Path {
return &framework.Path{
Pattern: "keys/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
@@ -24,9 +24,9 @@ func pathKeys() *framework.Path {
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: pathPolicyWrite,
logical.DeleteOperation: pathPolicyDelete,
logical.ReadOperation: pathPolicyRead,
logical.UpdateOperation: b.pathPolicyWrite,
logical.DeleteOperation: b.pathPolicyDelete,
logical.ReadOperation: b.pathPolicyRead,
},
HelpSynopsis: pathPolicyHelpSyn,
@@ -34,13 +34,13 @@ func pathKeys() *framework.Path {
}
}
func pathPolicyWrite(
func (b *backend) pathPolicyWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
derived := d.Get("derived").(bool)
// Check if the policy already exists
existing, err := getPolicy(req, name)
existing, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
@@ -49,39 +49,47 @@ func pathPolicyWrite(
}
// Generate the policy
_, err = generatePolicy(req.Storage, name, derived)
_, err = b.policies.generatePolicy(req.Storage, name, derived)
return nil, err
}
func pathPolicyRead(
func (b *backend) pathPolicyRead(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
p, err := getPolicy(req, name)
lp, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
if p == nil {
if lp == nil {
return nil, nil
}
lp.lock.RLock()
defer lp.lock.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
}
// Return the response
resp := &logical.Response{
Data: map[string]interface{}{
"name": p.Name,
"cipher_mode": p.CipherMode,
"derived": p.Derived,
"deletion_allowed": p.DeletionAllowed,
"min_decryption_version": p.MinDecryptionVersion,
"latest_version": p.LatestVersion,
"name": lp.policy.Name,
"cipher_mode": lp.policy.CipherMode,
"derived": lp.policy.Derived,
"deletion_allowed": lp.policy.DeletionAllowed,
"min_decryption_version": lp.policy.MinDecryptionVersion,
"latest_version": lp.policy.LatestVersion,
},
}
if p.Derived {
resp.Data["kdf_mode"] = p.KDFMode
if lp.policy.Derived {
resp.Data["kdf_mode"] = lp.policy.KDFMode
}
retKeys := map[string]int64{}
for k, v := range p.Keys {
for k, v := range lp.policy.Keys {
retKeys[strconv.Itoa(k)] = v.CreationTime
}
resp.Data["keys"] = retKeys
@@ -89,32 +97,40 @@ func pathPolicyRead(
return resp, nil
}
func pathPolicyDelete(
func (b *backend) pathPolicyDelete(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
p, err := getPolicy(req, name)
lp, err := b.policies.getPolicy(req, name)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err
}
if p == nil {
if lp == nil {
return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest
}
if !p.DeletionAllowed {
// We don't defer here because deletePolicy also needs to grab the lock
lp.lock.RLock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
lp.lock.RUnlock()
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
}
if !lp.policy.DeletionAllowed {
lp.lock.RUnlock()
return logical.ErrorResponse(fmt.Sprintf("'allow_deletion' config value is not set")), logical.ErrInvalidRequest
}
err = req.Storage.Delete("policy/" + name)
// Let deletePolicy grab the lock
lp.lock.RUnlock()
err = b.policies.deletePolicy(req.Storage, name)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
}
err = req.Storage.Delete("archive/" + name)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error deleting archive %s: %s", name, err)), err
}
return nil, nil
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathRewrap() *framework.Path {
func (b *backend) pathRewrap() *framework.Path {
return &framework.Path{
Pattern: "rewrap/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
@@ -30,7 +30,7 @@ func pathRewrap() *framework.Path {
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: pathRewrapWrite,
logical.UpdateOperation: b.pathRewrapWrite,
},
HelpSynopsis: pathRewrapHelpSyn,
@@ -38,7 +38,7 @@ func pathRewrap() *framework.Path {
}
}
func pathRewrapWrite(
func (b *backend) pathRewrapWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
@@ -59,17 +59,25 @@ func pathRewrapWrite(
}
// Get the policy
p, err := getPolicy(req, name)
lp, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
// Error if invalid policy
if p == nil {
if lp == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
plaintext, err := p.Decrypt(context, value)
lp.lock.RLock()
defer lp.lock.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
}
plaintext, err := lp.policy.Decrypt(context, value)
if err != nil {
switch err.(type) {
case certutil.UserError:
@@ -85,7 +93,7 @@ func pathRewrapWrite(
return nil, fmt.Errorf("empty plaintext returned during rewrap")
}
ciphertext, err := p.Encrypt(context, plaintext)
ciphertext, err := lp.policy.Encrypt(context, plaintext)
if err != nil {
switch err.(type) {
case certutil.UserError:

View File

@@ -7,7 +7,7 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func pathRotate() *framework.Path {
func (b *backend) pathRotate() *framework.Path {
return &framework.Path{
Pattern: "keys/" + framework.GenericNameRegex("name") + "/rotate",
Fields: map[string]*framework.FieldSchema{
@@ -18,7 +18,7 @@ func pathRotate() *framework.Path {
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: pathRotateWrite,
logical.UpdateOperation: b.pathRotateWrite,
},
HelpSynopsis: pathRotateHelpSyn,
@@ -26,23 +26,31 @@ func pathRotate() *framework.Path {
}
}
func pathRotateWrite(
func (b *backend) pathRotateWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
// Check if the policy already exists
policy, err := getPolicy(req, name)
// Get the policy
lp, err := b.policies.getPolicy(req, name)
if err != nil {
return nil, err
}
if policy == nil {
return logical.ErrorResponse(
fmt.Sprintf("no existing role named %s could be found", name)),
logical.ErrInvalidRequest
// Error if invalid policy
if lp == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}
lp.lock.RLock()
defer lp.lock.RUnlock()
// Verify if wasn't deleted before we grabbed the lock
if lp.policy == nil {
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
}
// Generate the policy
err = policy.rotate(req.Storage)
err = lp.policy.rotate(req.Storage)
return nil, err
}

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/vault/helper/certutil"
@@ -21,6 +22,200 @@ const (
kdfMode = "hmac-sha256-counter"
)
// policyCache implements a simple locking cache of policies
type policyCache struct {
cache map[string]*lockingPolicy
lock sync.RWMutex
}
// loadStoredPolicies loads stored policies into the cache. This should only be
// run at backend initialization time.
func (p *policyCache) loadStoredPolicies(storage logical.Storage) error {
p.lock.Lock()
defer p.lock.Unlock()
policyNames, err := storage.List("policy/")
if err != nil {
return err
}
// getPolicy will populate the cache
for _, name := range policyNames {
lp, err := p.getPolicy(&logical.Request{
Storage: storage,
}, name)
if err != nil {
return err
}
if lp == nil {
return fmt.Errorf("policy %s key was found but value was nil")
}
}
return nil
}
// getPolicy loads a policy into the cache or returns one already in the cache
func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPolicy, error) {
// We don't defer this since we may need to give it up and get a write lock
p.lock.RLock()
// First, see if we're in the cache -- if so, return that
if p.cache[name] != nil {
defer p.lock.RUnlock()
return p.cache[name], nil
}
// If we find anything, we'll need to write into the cache, plus possibly
// persist the entry, so lock the cache
p.lock.RUnlock()
p.lock.Lock()
defer p.lock.Unlock()
// Check one more time to ensure that another process did not write during
// our lock switcheroo.
if p.cache[name] != nil {
return p.cache[name], nil
}
// Note that we don't need to create the locking entry until the end,
// because the policy wasn't in the cache so we don't know about it, and we
// hold the cache lock so nothing else can be writing it in right now
// Check if the policy already exists
raw, err := req.Storage.Get("policy/" + name)
if err != nil {
return nil, err
}
if raw == nil {
return nil, nil
}
// Decode the policy
policy := &Policy{
Keys: KeyEntryMap{},
}
err = json.Unmarshal(raw.Value, policy)
if err != nil {
return nil, err
}
persistNeeded := false
// Ensure we've moved from Key -> Keys
if policy.Key != nil && len(policy.Key) > 0 {
policy.migrateKeyToKeysMap()
persistNeeded = true
}
// With archiving, past assumptions about the length of the keys map are no longer valid
if policy.LatestVersion == 0 && len(policy.Keys) != 0 {
policy.LatestVersion = len(policy.Keys)
persistNeeded = true
}
// We disallow setting the version to 0, since they start at 1 since moving
// to rotate-able keys, so update if it's set to 0
if policy.MinDecryptionVersion == 0 {
policy.MinDecryptionVersion = 1
persistNeeded = true
}
// On first load after an upgrade, copy keys to the archive
if policy.ArchiveVersion == 0 {
persistNeeded = true
}
if persistNeeded {
err = policy.Persist(req.Storage)
if err != nil {
return nil, err
}
}
lp := &lockingPolicy{
policy: policy,
}
p.cache[name] = lp
return lp, nil
}
// generatePolicy is used to create a new named policy with a randomly
// generated key
func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) {
p.lock.Lock()
defer p.lock.Unlock()
// Ensure one doesn't already exist
if lp := p.cache[name]; lp != nil {
return nil, fmt.Errorf("policy %s already exists", name)
}
// Create the policy object
policy := &Policy{
Name: name,
CipherMode: "aes-gcm",
Derived: derived,
}
if derived {
policy.KDFMode = kdfMode
}
err := policy.rotate(storage)
if err != nil {
return nil, err
}
lp := &lockingPolicy{
policy: policy,
}
p.cache[name] = lp
// Return the policy
return lp, nil
}
// deletePolicy deletes a policy
func (p *policyCache) deletePolicy(storage logical.Storage, name string) error {
p.lock.Lock()
defer p.lock.Unlock()
lp := p.cache[name]
if lp == nil {
return fmt.Errorf("policy %s not found", name)
}
// We need to ensure all other access has stopped
lp.lock.Lock()
defer lp.lock.Unlock()
// Verify this hasn't changed
if !lp.policy.DeletionAllowed {
return fmt.Errorf("deletion not allowed for policy %s", name)
}
err := storage.Delete("policy/" + name)
if err != nil {
return fmt.Errorf("error deleting policy %s: %s", name, err)
}
err = storage.Delete("archive/" + name)
if err != nil {
return fmt.Errorf("error deleting archive %s: %s", name, err)
}
lp.policy = nil
delete(p.cache, name)
return nil
}
// lockingPolicy holds a Policy guarded by a lock
type lockingPolicy struct {
policy *Policy
lock sync.RWMutex
}
// KeyEntry stores the key and metadata
type KeyEntry struct {
Key []byte `json:"key"`
@@ -435,87 +630,3 @@ func (p *Policy) migrateKeyToKeysMap() {
}
p.Key = nil
}
func deserializePolicy(buf []byte) (*Policy, error) {
p := &Policy{
Keys: KeyEntryMap{},
}
if err := json.Unmarshal(buf, p); err != nil {
return nil, err
}
return p, nil
}
func getPolicy(req *logical.Request, name string) (*Policy, error) {
// Check if the policy already exists
raw, err := req.Storage.Get("policy/" + name)
if err != nil {
return nil, err
}
if raw == nil {
return nil, nil
}
// Decode the policy
p, err := deserializePolicy(raw.Value)
if err != nil {
return nil, err
}
persistNeeded := false
// Ensure we've moved from Key -> Keys
if p.Key != nil && len(p.Key) > 0 {
p.migrateKeyToKeysMap()
persistNeeded = true
}
// With archiving, past assumptions about the length of the keys map are no longer valid
if p.LatestVersion == 0 && len(p.Keys) != 0 {
p.LatestVersion = len(p.Keys)
persistNeeded = true
}
// We disallow setting the version to 0, since they start at 1 since moving
// to rotate-able keys, so update if it's set to 0
if p.MinDecryptionVersion == 0 {
p.MinDecryptionVersion = 1
persistNeeded = true
}
// On first load after an upgrade, copy keys to the archive
if p.ArchiveVersion == 0 {
persistNeeded = true
}
if persistNeeded {
err = p.Persist(req.Storage)
if err != nil {
return nil, err
}
}
return p, nil
}
// generatePolicy is used to create a new named policy with
// a randomly generated key
func generatePolicy(storage logical.Storage, name string, derived bool) (*Policy, error) {
// Create the policy object
p := &Policy{
Name: name,
CipherMode: "aes-gcm",
Derived: derived,
}
if derived {
p.KDFMode = kdfMode
}
err := p.rotate(storage)
if err != nil {
return nil, err
}
// Return the policy
return p, nil
}

View File

@@ -17,14 +17,19 @@ func resetKeysArchive() {
func Test_KeyUpgrade(t *testing.T) {
storage := &logical.InmemStorage{}
policy, err := generatePolicy(storage, "test", false)
policies := &policyCache{
cache: map[string]*lockingPolicy{},
}
lp, err := policies.generatePolicy(storage, "test", false)
if err != nil {
t.Fatal(err)
}
if policy == nil {
if lp == nil {
t.Fatal("nil policy")
}
policy := lp.policy
testBytes := make([]byte, len(policy.Keys[1].Key))
copy(testBytes, policy.Keys[1].Key)
@@ -51,15 +56,20 @@ func Test_ArchivingUpgrade(t *testing.T) {
// zero and latest, respectively
storage := &logical.InmemStorage{}
policies := &policyCache{
cache: map[string]*lockingPolicy{},
}
policy, err := generatePolicy(storage, "test", false)
lp, err := policies.generatePolicy(storage, "test", false)
if err != nil {
t.Fatal(err)
}
if policy == nil {
if lp == nil {
t.Fatal("policy is nil")
}
policy := lp.policy
// Store the initial key in the archive
keysArchive = append(keysArchive, policy.Keys[1])
checkKeys(t, policy, storage, "initial", 1, 1, 1)
@@ -96,17 +106,22 @@ func Test_ArchivingUpgrade(t *testing.T) {
t.Fatal(err)
}
// Expire from the cache since we modified it under-the-hood
delete(policies.cache, "test")
// Now get the policy again; the upgrade should happen automatically
policy, err = getPolicy(&logical.Request{
lp, err = policies.getPolicy(&logical.Request{
Storage: storage,
}, "test")
if err != nil {
t.Fatal(err)
}
if policy == nil {
if lp == nil {
t.Fatal("policy is nil")
}
policy = lp.policy
checkKeys(t, policy, storage, "upgrade", 10, 10, 10)
}
@@ -120,14 +135,20 @@ func Test_Archiving(t *testing.T) {
storage := &logical.InmemStorage{}
policy, err := generatePolicy(storage, "test", false)
policies := &policyCache{
cache: map[string]*lockingPolicy{},
}
lp, err := policies.generatePolicy(storage, "test", false)
if err != nil {
t.Fatal(err)
}
if policy == nil {
if lp == nil {
t.Fatal("policy is nil")
}
policy := lp.policy
// Store the initial key in the archive
keysArchive = append(keysArchive, policy.Keys[1])
checkKeys(t, policy, storage, "initial", 1, 1, 1)