From 46514e01faefa42de4ff0b4a86be1447f2e06b7c Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 27 Jan 2016 16:24:11 -0500 Subject: [PATCH] Implement locking in the transit backend. This ensures that we can safely rotate and modify configuration parameters with multiple requests in flight. As a side effect we also get a cache, which should provide a nice speedup since we don't need to decrypt/deserialize constantly, which would happen even with the physical LRU. --- builtin/logical/transit/backend.go | 42 ++-- builtin/logical/transit/backend_test.go | 24 +- builtin/logical/transit/path_config.go | 36 +-- builtin/logical/transit/path_datakey.go | 20 +- builtin/logical/transit/path_decrypt.go | 20 +- builtin/logical/transit/path_encrypt.go | 38 ++-- builtin/logical/transit/path_keys.go | 74 ++++--- builtin/logical/transit/path_rewrap.go | 22 +- builtin/logical/transit/path_rotate.go | 28 ++- builtin/logical/transit/policy.go | 279 +++++++++++++++++------- builtin/logical/transit/policy_test.go | 37 +++- 11 files changed, 403 insertions(+), 217 deletions(-) diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index 9ff64896c7..0cfe7116f7 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -6,36 +6,46 @@ import ( ) func Factory(conf *logical.BackendConfig) (logical.Backend, error) { - return Backend().Setup(conf) + b := Backend() + be, err := b.Backend.Setup(conf) + if err != nil { + return nil, err + } + + err = b.policies.loadStoredPolicies(conf.StorageView) + if err != nil { + return nil, err + } + + return be, nil } -func Backend() *framework.Backend { +func Backend() *backend { var b backend b.Backend = &framework.Backend{ - PathsSpecial: &logical.Paths{ - Root: []string{ - "keys/*", - }, - }, - Paths: []*framework.Path{ // Rotate/Config needs to come before Keys // as the handler is greedy - pathConfig(), - pathRotate(), - pathRewrap(), - pathKeys(), - pathEncrypt(), - pathDecrypt(), - pathDatakey(), + b.pathConfig(), + b.pathRotate(), + b.pathRewrap(), + b.pathKeys(), + b.pathEncrypt(), + b.pathDecrypt(), + b.pathDatakey(), }, Secrets: []*framework.Secret{}, } - return b.Backend + b.policies = &policyCache{ + cache: map[string]*lockingPolicy{}, + } + + return &b } type backend struct { *framework.Backend + policies *policyCache } diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 3ed79d72bb..6ea5252248 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -19,7 +19,7 @@ const ( func TestBackend_basic(t *testing.T) { decryptData := make(map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Backend: Backend(), + Factory: Factory, Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test", false), testAccStepReadPolicy(t, "test", false, false), @@ -42,7 +42,7 @@ func TestBackend_basic(t *testing.T) { func TestBackend_datakey(t *testing.T) { dataKeyInfo := make(map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Backend: Backend(), + Factory: Factory, Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test", false), testAccStepReadPolicy(t, "test", false, false), @@ -57,7 +57,7 @@ func TestBackend_rotation(t *testing.T) { decryptData := make(map[string]interface{}) encryptHistory := make(map[int]map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Backend: Backend(), + Factory: Factory, Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test", false), testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory), @@ -111,26 +111,10 @@ func TestBackend_rotation(t *testing.T) { }) } -func TestBackend_upsert(t *testing.T) { - decryptData := make(map[string]interface{}) - logicaltest.Test(t, logicaltest.TestCase{ - Backend: Backend(), - Steps: []logicaltest.TestStep{ - testAccStepReadPolicy(t, "test", true, false), - testAccStepEncrypt(t, "test", testPlaintext, decryptData), - testAccStepReadPolicy(t, "test", false, false), - testAccStepDecrypt(t, "test", testPlaintext, decryptData), - testAccStepEnableDeletion(t, "test"), - testAccStepDeletePolicy(t, "test"), - testAccStepReadPolicy(t, "test", true, false), - }, - }) -} - func TestBackend_basic_derived(t *testing.T) { decryptData := make(map[string]interface{}) logicaltest.Test(t, logicaltest.TestCase{ - Backend: Backend(), + Factory: Factory, Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test", true), testAccStepReadPolicy(t, "test", false, true), diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go index 4636181a6f..d5bff3b238 100644 --- a/builtin/logical/transit/path_config.go +++ b/builtin/logical/transit/path_config.go @@ -7,7 +7,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathConfig() *framework.Path { +func (b *backend) pathConfig() *framework.Path { return &framework.Path{ Pattern: "keys/" + framework.GenericNameRegex("name") + "/config", Fields: map[string]*framework.FieldSchema{ @@ -29,7 +29,7 @@ to be decrypted.`, }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: pathConfigWrite, + logical.UpdateOperation: b.pathConfigWrite, }, HelpSynopsis: pathConfigHelpSyn, @@ -37,21 +37,29 @@ to be decrypted.`, } } -func pathConfigWrite( +func (b *backend) pathConfigWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) // Check if the policy already exists - policy, err := getPolicy(req, name) + lp, err := b.policies.getPolicy(req, name) if err != nil { return nil, err } - if policy == nil { + if lp == nil { return logical.ErrorResponse( fmt.Sprintf("no existing role named %s could be found", name)), logical.ErrInvalidRequest } + lp.lock.Lock() + defer lp.lock.Unlock() + + // Verify if wasn't deleted before we grabbed the lock + if lp.policy == nil { + return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name) + } + resp := &logical.Response{} persistNeeded := false @@ -70,12 +78,12 @@ func pathConfigWrite( } if minDecryptionVersion > 0 && - minDecryptionVersion != policy.MinDecryptionVersion { - if minDecryptionVersion > policy.LatestVersion { + minDecryptionVersion != lp.policy.MinDecryptionVersion { + if minDecryptionVersion > lp.policy.LatestVersion { return logical.ErrorResponse( - fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, policy.LatestVersion)), nil + fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.policy.LatestVersion)), nil } - policy.MinDecryptionVersion = minDecryptionVersion + lp.policy.MinDecryptionVersion = minDecryptionVersion persistNeeded = true } } @@ -83,8 +91,8 @@ func pathConfigWrite( allowDeletionInt, ok := d.GetOk("deletion_allowed") if ok { allowDeletion := allowDeletionInt.(bool) - if allowDeletion != policy.DeletionAllowed { - policy.DeletionAllowed = allowDeletion + if allowDeletion != lp.policy.DeletionAllowed { + lp.policy.DeletionAllowed = allowDeletion persistNeeded = true } } @@ -92,8 +100,8 @@ func pathConfigWrite( // Add this as a guard here before persisting since we now require the min // decryption version to start at 1; even if it's not explicitly set here, // force the upgrade - if policy.MinDecryptionVersion == 0 { - policy.MinDecryptionVersion = 1 + if lp.policy.MinDecryptionVersion == 0 { + lp.policy.MinDecryptionVersion = 1 persistNeeded = true } @@ -101,7 +109,7 @@ func pathConfigWrite( return nil, nil } - return resp, policy.Persist(req.Storage) + return resp, lp.policy.Persist(req.Storage) } const pathConfigHelpSyn = `Configure a named encryption key` diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index 2cec9f3e68..78bc80da6d 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -10,7 +10,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathDatakey() *framework.Path { +func (b *backend) pathDatakey() *framework.Path { return &framework.Path{ Pattern: "datakey/" + framework.GenericNameRegex("plaintext") + "/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ @@ -39,7 +39,7 @@ and 512 bits are supported. Defaults to 256.`, }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: pathDatakeyWrite, + logical.UpdateOperation: b.pathDatakeyWrite, }, HelpSynopsis: pathDatakeyHelpSyn, @@ -47,7 +47,7 @@ and 512 bits are supported. Defaults to 256.`, } } -func pathDatakeyWrite( +func (b *backend) pathDatakeyWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) @@ -73,16 +73,24 @@ func pathDatakeyWrite( } // Get the policy - p, err := getPolicy(req, name) + lp, err := b.policies.getPolicy(req, name) if err != nil { return nil, err } // Error if invalid policy - if p == nil { + if lp == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } + lp.lock.RLock() + defer lp.lock.RUnlock() + + // Verify if wasn't deleted before we grabbed the lock + if lp.policy == nil { + return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name) + } + newKey := make([]byte, 32) bits := d.Get("bits").(int) switch bits { @@ -99,7 +107,7 @@ func pathDatakeyWrite( return nil, err } - ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) + ciphertext, err := lp.policy.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index d084d5d936..f357ffb87d 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathDecrypt() *framework.Path { +func (b *backend) pathDecrypt() *framework.Path { return &framework.Path{ Pattern: "decrypt/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ @@ -30,7 +30,7 @@ func pathDecrypt() *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: pathDecryptWrite, + logical.UpdateOperation: b.pathDecryptWrite, }, HelpSynopsis: pathDecryptHelpSyn, @@ -38,7 +38,7 @@ func pathDecrypt() *framework.Path { } } -func pathDecryptWrite( +func (b *backend) pathDecryptWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) ciphertext := d.Get("ciphertext").(string) @@ -58,17 +58,25 @@ func pathDecryptWrite( } // Get the policy - p, err := getPolicy(req, name) + lp, err := b.policies.getPolicy(req, name) if err != nil { return nil, err } // Error if invalid policy - if p == nil { + if lp == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - plaintext, err := p.Decrypt(context, ciphertext) + lp.lock.RLock() + defer lp.lock.RUnlock() + + // Verify if wasn't deleted before we grabbed the lock + if lp.policy == nil { + return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name) + } + + plaintext, err := lp.policy.Decrypt(context, ciphertext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index e9aa322ea9..91564b0945 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathEncrypt() *framework.Path { +func (b *backend) pathEncrypt() *framework.Path { return &framework.Path{ Pattern: "encrypt/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ @@ -30,7 +30,7 @@ func pathEncrypt() *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: pathEncryptWrite, + logical.UpdateOperation: b.pathEncryptWrite, }, HelpSynopsis: pathEncryptHelpSyn, @@ -38,7 +38,7 @@ func pathEncrypt() *framework.Path { } } -func pathEncryptWrite( +func (b *backend) pathEncryptWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) value := d.Get("plaintext").(string) @@ -46,12 +46,6 @@ func pathEncryptWrite( return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest } - // Get the policy - p, err := getPolicy(req, name) - if err != nil { - return nil, err - } - // Decode the context if any contextRaw := d.Get("context").(string) var context []byte @@ -63,16 +57,26 @@ func pathEncryptWrite( } } - // Error if invalid policy - if p == nil { - isDerived := len(context) != 0 - p, err = generatePolicy(req.Storage, name, isDerived) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("failed to upsert policy: %v", err)), logical.ErrInvalidRequest - } + // Get the policy + lp, err := b.policies.getPolicy(req, name) + if err != nil { + return nil, err } - ciphertext, err := p.Encrypt(context, value) + // Error if invalid policy + if lp == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + } + + lp.lock.RLock() + defer lp.lock.RUnlock() + + // Verify if wasn't deleted before we grabbed the lock + if lp.policy == nil { + return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name) + } + + ciphertext, err := lp.policy.Encrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index b3ee4304f2..4b2774287d 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathKeys() *framework.Path { +func (b *backend) pathKeys() *framework.Path { return &framework.Path{ Pattern: "keys/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ @@ -24,9 +24,9 @@ func pathKeys() *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: pathPolicyWrite, - logical.DeleteOperation: pathPolicyDelete, - logical.ReadOperation: pathPolicyRead, + logical.UpdateOperation: b.pathPolicyWrite, + logical.DeleteOperation: b.pathPolicyDelete, + logical.ReadOperation: b.pathPolicyRead, }, HelpSynopsis: pathPolicyHelpSyn, @@ -34,13 +34,13 @@ func pathKeys() *framework.Path { } } -func pathPolicyWrite( +func (b *backend) pathPolicyWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) derived := d.Get("derived").(bool) // Check if the policy already exists - existing, err := getPolicy(req, name) + existing, err := b.policies.getPolicy(req, name) if err != nil { return nil, err } @@ -49,39 +49,47 @@ func pathPolicyWrite( } // Generate the policy - _, err = generatePolicy(req.Storage, name, derived) + _, err = b.policies.generatePolicy(req.Storage, name, derived) return nil, err } -func pathPolicyRead( +func (b *backend) pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - p, err := getPolicy(req, name) + lp, err := b.policies.getPolicy(req, name) if err != nil { return nil, err } - if p == nil { + if lp == nil { return nil, nil } + lp.lock.RLock() + defer lp.lock.RUnlock() + + // Verify if wasn't deleted before we grabbed the lock + if lp.policy == nil { + return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name) + } + // Return the response resp := &logical.Response{ Data: map[string]interface{}{ - "name": p.Name, - "cipher_mode": p.CipherMode, - "derived": p.Derived, - "deletion_allowed": p.DeletionAllowed, - "min_decryption_version": p.MinDecryptionVersion, - "latest_version": p.LatestVersion, + "name": lp.policy.Name, + "cipher_mode": lp.policy.CipherMode, + "derived": lp.policy.Derived, + "deletion_allowed": lp.policy.DeletionAllowed, + "min_decryption_version": lp.policy.MinDecryptionVersion, + "latest_version": lp.policy.LatestVersion, }, } - if p.Derived { - resp.Data["kdf_mode"] = p.KDFMode + if lp.policy.Derived { + resp.Data["kdf_mode"] = lp.policy.KDFMode } retKeys := map[string]int64{} - for k, v := range p.Keys { + for k, v := range lp.policy.Keys { retKeys[strconv.Itoa(k)] = v.CreationTime } resp.Data["keys"] = retKeys @@ -89,32 +97,40 @@ func pathPolicyRead( return resp, nil } -func pathPolicyDelete( +func (b *backend) pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - p, err := getPolicy(req, name) + lp, err := b.policies.getPolicy(req, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err } - if p == nil { + if lp == nil { return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest } - if !p.DeletionAllowed { + // We don't defer here because deletePolicy also needs to grab the lock + lp.lock.RLock() + + // Verify if wasn't deleted before we grabbed the lock + if lp.policy == nil { + lp.lock.RUnlock() + return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name) + } + + if !lp.policy.DeletionAllowed { + lp.lock.RUnlock() return logical.ErrorResponse(fmt.Sprintf("'allow_deletion' config value is not set")), logical.ErrInvalidRequest } - err = req.Storage.Delete("policy/" + name) + // Let deletePolicy grab the lock + lp.lock.RUnlock() + + err = b.policies.deletePolicy(req.Storage, name) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } - err = req.Storage.Delete("archive/" + name) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("error deleting archive %s: %s", name, err)), err - } - return nil, nil } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index adac996b2a..199de11d5c 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathRewrap() *framework.Path { +func (b *backend) pathRewrap() *framework.Path { return &framework.Path{ Pattern: "rewrap/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ @@ -30,7 +30,7 @@ func pathRewrap() *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: pathRewrapWrite, + logical.UpdateOperation: b.pathRewrapWrite, }, HelpSynopsis: pathRewrapHelpSyn, @@ -38,7 +38,7 @@ func pathRewrap() *framework.Path { } } -func pathRewrapWrite( +func (b *backend) pathRewrapWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) @@ -59,17 +59,25 @@ func pathRewrapWrite( } // Get the policy - p, err := getPolicy(req, name) + lp, err := b.policies.getPolicy(req, name) if err != nil { return nil, err } // Error if invalid policy - if p == nil { + if lp == nil { return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - plaintext, err := p.Decrypt(context, value) + lp.lock.RLock() + defer lp.lock.RUnlock() + + // Verify if wasn't deleted before we grabbed the lock + if lp.policy == nil { + return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name) + } + + plaintext, err := lp.policy.Decrypt(context, value) if err != nil { switch err.(type) { case certutil.UserError: @@ -85,7 +93,7 @@ func pathRewrapWrite( return nil, fmt.Errorf("empty plaintext returned during rewrap") } - ciphertext, err := p.Encrypt(context, plaintext) + ciphertext, err := lp.policy.Encrypt(context, plaintext) if err != nil { switch err.(type) { case certutil.UserError: diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go index f7b42dcfb2..a23c10beb6 100644 --- a/builtin/logical/transit/path_rotate.go +++ b/builtin/logical/transit/path_rotate.go @@ -7,7 +7,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathRotate() *framework.Path { +func (b *backend) pathRotate() *framework.Path { return &framework.Path{ Pattern: "keys/" + framework.GenericNameRegex("name") + "/rotate", Fields: map[string]*framework.FieldSchema{ @@ -18,7 +18,7 @@ func pathRotate() *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: pathRotateWrite, + logical.UpdateOperation: b.pathRotateWrite, }, HelpSynopsis: pathRotateHelpSyn, @@ -26,23 +26,31 @@ func pathRotate() *framework.Path { } } -func pathRotateWrite( +func (b *backend) pathRotateWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - // Check if the policy already exists - policy, err := getPolicy(req, name) + // Get the policy + lp, err := b.policies.getPolicy(req, name) if err != nil { return nil, err } - if policy == nil { - return logical.ErrorResponse( - fmt.Sprintf("no existing role named %s could be found", name)), - logical.ErrInvalidRequest + + // Error if invalid policy + if lp == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + } + + lp.lock.RLock() + defer lp.lock.RUnlock() + + // Verify if wasn't deleted before we grabbed the lock + if lp.policy == nil { + return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name) } // Generate the policy - err = policy.rotate(req.Storage) + err = lp.policy.rotate(req.Storage) return nil, err } diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go index 46a4e0cb1c..e4e08f1f03 100644 --- a/builtin/logical/transit/policy.go +++ b/builtin/logical/transit/policy.go @@ -9,6 +9,7 @@ import ( "fmt" "strconv" "strings" + "sync" "time" "github.com/hashicorp/vault/helper/certutil" @@ -21,6 +22,200 @@ const ( kdfMode = "hmac-sha256-counter" ) +// policyCache implements a simple locking cache of policies +type policyCache struct { + cache map[string]*lockingPolicy + lock sync.RWMutex +} + +// loadStoredPolicies loads stored policies into the cache. This should only be +// run at backend initialization time. +func (p *policyCache) loadStoredPolicies(storage logical.Storage) error { + p.lock.Lock() + defer p.lock.Unlock() + + policyNames, err := storage.List("policy/") + if err != nil { + return err + } + + // getPolicy will populate the cache + for _, name := range policyNames { + lp, err := p.getPolicy(&logical.Request{ + Storage: storage, + }, name) + if err != nil { + return err + } + if lp == nil { + return fmt.Errorf("policy %s key was found but value was nil") + } + } + + return nil +} + +// getPolicy loads a policy into the cache or returns one already in the cache +func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPolicy, error) { + // We don't defer this since we may need to give it up and get a write lock + p.lock.RLock() + + // First, see if we're in the cache -- if so, return that + if p.cache[name] != nil { + defer p.lock.RUnlock() + return p.cache[name], nil + } + + // If we find anything, we'll need to write into the cache, plus possibly + // persist the entry, so lock the cache + p.lock.RUnlock() + p.lock.Lock() + defer p.lock.Unlock() + + // Check one more time to ensure that another process did not write during + // our lock switcheroo. + if p.cache[name] != nil { + return p.cache[name], nil + } + + // Note that we don't need to create the locking entry until the end, + // because the policy wasn't in the cache so we don't know about it, and we + // hold the cache lock so nothing else can be writing it in right now + + // Check if the policy already exists + raw, err := req.Storage.Get("policy/" + name) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + // Decode the policy + policy := &Policy{ + Keys: KeyEntryMap{}, + } + err = json.Unmarshal(raw.Value, policy) + if err != nil { + return nil, err + } + + persistNeeded := false + // Ensure we've moved from Key -> Keys + if policy.Key != nil && len(policy.Key) > 0 { + policy.migrateKeyToKeysMap() + persistNeeded = true + } + + // With archiving, past assumptions about the length of the keys map are no longer valid + if policy.LatestVersion == 0 && len(policy.Keys) != 0 { + policy.LatestVersion = len(policy.Keys) + persistNeeded = true + } + + // We disallow setting the version to 0, since they start at 1 since moving + // to rotate-able keys, so update if it's set to 0 + if policy.MinDecryptionVersion == 0 { + policy.MinDecryptionVersion = 1 + persistNeeded = true + } + + // On first load after an upgrade, copy keys to the archive + if policy.ArchiveVersion == 0 { + persistNeeded = true + } + + if persistNeeded { + err = policy.Persist(req.Storage) + if err != nil { + return nil, err + } + } + + lp := &lockingPolicy{ + policy: policy, + } + p.cache[name] = lp + + return lp, nil +} + +// generatePolicy is used to create a new named policy with a randomly +// generated key +func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) { + p.lock.Lock() + defer p.lock.Unlock() + + // Ensure one doesn't already exist + if lp := p.cache[name]; lp != nil { + return nil, fmt.Errorf("policy %s already exists", name) + } + + // Create the policy object + policy := &Policy{ + Name: name, + CipherMode: "aes-gcm", + Derived: derived, + } + if derived { + policy.KDFMode = kdfMode + } + + err := policy.rotate(storage) + if err != nil { + return nil, err + } + + lp := &lockingPolicy{ + policy: policy, + } + p.cache[name] = lp + + // Return the policy + return lp, nil +} + +// deletePolicy deletes a policy +func (p *policyCache) deletePolicy(storage logical.Storage, name string) error { + p.lock.Lock() + defer p.lock.Unlock() + + lp := p.cache[name] + if lp == nil { + return fmt.Errorf("policy %s not found", name) + } + + // We need to ensure all other access has stopped + lp.lock.Lock() + defer lp.lock.Unlock() + + // Verify this hasn't changed + if !lp.policy.DeletionAllowed { + return fmt.Errorf("deletion not allowed for policy %s", name) + } + + err := storage.Delete("policy/" + name) + if err != nil { + return fmt.Errorf("error deleting policy %s: %s", name, err) + } + + err = storage.Delete("archive/" + name) + if err != nil { + return fmt.Errorf("error deleting archive %s: %s", name, err) + } + + lp.policy = nil + delete(p.cache, name) + + return nil +} + +// lockingPolicy holds a Policy guarded by a lock +type lockingPolicy struct { + policy *Policy + lock sync.RWMutex +} + // KeyEntry stores the key and metadata type KeyEntry struct { Key []byte `json:"key"` @@ -435,87 +630,3 @@ func (p *Policy) migrateKeyToKeysMap() { } p.Key = nil } - -func deserializePolicy(buf []byte) (*Policy, error) { - p := &Policy{ - Keys: KeyEntryMap{}, - } - if err := json.Unmarshal(buf, p); err != nil { - return nil, err - } - - return p, nil -} - -func getPolicy(req *logical.Request, name string) (*Policy, error) { - // Check if the policy already exists - raw, err := req.Storage.Get("policy/" + name) - if err != nil { - return nil, err - } - if raw == nil { - return nil, nil - } - - // Decode the policy - p, err := deserializePolicy(raw.Value) - if err != nil { - return nil, err - } - - persistNeeded := false - // Ensure we've moved from Key -> Keys - if p.Key != nil && len(p.Key) > 0 { - p.migrateKeyToKeysMap() - persistNeeded = true - } - - // With archiving, past assumptions about the length of the keys map are no longer valid - if p.LatestVersion == 0 && len(p.Keys) != 0 { - p.LatestVersion = len(p.Keys) - persistNeeded = true - } - - // We disallow setting the version to 0, since they start at 1 since moving - // to rotate-able keys, so update if it's set to 0 - if p.MinDecryptionVersion == 0 { - p.MinDecryptionVersion = 1 - persistNeeded = true - } - - // On first load after an upgrade, copy keys to the archive - if p.ArchiveVersion == 0 { - persistNeeded = true - } - - if persistNeeded { - err = p.Persist(req.Storage) - if err != nil { - return nil, err - } - } - - return p, nil -} - -// generatePolicy is used to create a new named policy with -// a randomly generated key -func generatePolicy(storage logical.Storage, name string, derived bool) (*Policy, error) { - // Create the policy object - p := &Policy{ - Name: name, - CipherMode: "aes-gcm", - Derived: derived, - } - if derived { - p.KDFMode = kdfMode - } - - err := p.rotate(storage) - if err != nil { - return nil, err - } - - // Return the policy - return p, nil -} diff --git a/builtin/logical/transit/policy_test.go b/builtin/logical/transit/policy_test.go index e134607a01..04b3c3bbb1 100644 --- a/builtin/logical/transit/policy_test.go +++ b/builtin/logical/transit/policy_test.go @@ -17,14 +17,19 @@ func resetKeysArchive() { func Test_KeyUpgrade(t *testing.T) { storage := &logical.InmemStorage{} - policy, err := generatePolicy(storage, "test", false) + policies := &policyCache{ + cache: map[string]*lockingPolicy{}, + } + lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } - if policy == nil { + if lp == nil { t.Fatal("nil policy") } + policy := lp.policy + testBytes := make([]byte, len(policy.Keys[1].Key)) copy(testBytes, policy.Keys[1].Key) @@ -51,15 +56,20 @@ func Test_ArchivingUpgrade(t *testing.T) { // zero and latest, respectively storage := &logical.InmemStorage{} + policies := &policyCache{ + cache: map[string]*lockingPolicy{}, + } - policy, err := generatePolicy(storage, "test", false) + lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } - if policy == nil { + if lp == nil { t.Fatal("policy is nil") } + policy := lp.policy + // Store the initial key in the archive keysArchive = append(keysArchive, policy.Keys[1]) checkKeys(t, policy, storage, "initial", 1, 1, 1) @@ -96,17 +106,22 @@ func Test_ArchivingUpgrade(t *testing.T) { t.Fatal(err) } + // Expire from the cache since we modified it under-the-hood + delete(policies.cache, "test") + // Now get the policy again; the upgrade should happen automatically - policy, err = getPolicy(&logical.Request{ + lp, err = policies.getPolicy(&logical.Request{ Storage: storage, }, "test") if err != nil { t.Fatal(err) } - if policy == nil { + if lp == nil { t.Fatal("policy is nil") } + policy = lp.policy + checkKeys(t, policy, storage, "upgrade", 10, 10, 10) } @@ -120,14 +135,20 @@ func Test_Archiving(t *testing.T) { storage := &logical.InmemStorage{} - policy, err := generatePolicy(storage, "test", false) + policies := &policyCache{ + cache: map[string]*lockingPolicy{}, + } + + lp, err := policies.generatePolicy(storage, "test", false) if err != nil { t.Fatal(err) } - if policy == nil { + if lp == nil { t.Fatal("policy is nil") } + policy := lp.policy + // Store the initial key in the archive keysArchive = append(keysArchive, policy.Keys[1]) checkKeys(t, policy, storage, "initial", 1, 1, 1)