mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 19:17:58 +00:00
Implement locking in the transit backend.
This ensures that we can safely rotate and modify configuration parameters with multiple requests in flight. As a side effect we also get a cache, which should provide a nice speedup since we don't need to decrypt/deserialize constantly, which would happen even with the physical LRU.
This commit is contained in:
@@ -6,36 +6,46 @@ import (
|
||||
)
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
return Backend().Setup(conf)
|
||||
b := Backend()
|
||||
be, err := b.Backend.Setup(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = b.policies.loadStoredPolicies(conf.StorageView)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return be, nil
|
||||
}
|
||||
|
||||
func Backend() *framework.Backend {
|
||||
func Backend() *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
PathsSpecial: &logical.Paths{
|
||||
Root: []string{
|
||||
"keys/*",
|
||||
},
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
// Rotate/Config needs to come before Keys
|
||||
// as the handler is greedy
|
||||
pathConfig(),
|
||||
pathRotate(),
|
||||
pathRewrap(),
|
||||
pathKeys(),
|
||||
pathEncrypt(),
|
||||
pathDecrypt(),
|
||||
pathDatakey(),
|
||||
b.pathConfig(),
|
||||
b.pathRotate(),
|
||||
b.pathRewrap(),
|
||||
b.pathKeys(),
|
||||
b.pathEncrypt(),
|
||||
b.pathDecrypt(),
|
||||
b.pathDatakey(),
|
||||
},
|
||||
|
||||
Secrets: []*framework.Secret{},
|
||||
}
|
||||
|
||||
return b.Backend
|
||||
b.policies = &policyCache{
|
||||
cache: map[string]*lockingPolicy{},
|
||||
}
|
||||
|
||||
return &b
|
||||
}
|
||||
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
policies *policyCache
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ const (
|
||||
func TestBackend_basic(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Factory: Factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
@@ -42,7 +42,7 @@ func TestBackend_basic(t *testing.T) {
|
||||
func TestBackend_datakey(t *testing.T) {
|
||||
dataKeyInfo := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Factory: Factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
@@ -57,7 +57,7 @@ func TestBackend_rotation(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
encryptHistory := make(map[int]map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Factory: Factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test", false),
|
||||
testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory),
|
||||
@@ -111,26 +111,10 @@ func TestBackend_rotation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_upsert(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepReadPolicy(t, "test", true, false),
|
||||
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepReadPolicy(t, "test", false, false),
|
||||
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
|
||||
testAccStepEnableDeletion(t, "test"),
|
||||
testAccStepDeletePolicy(t, "test"),
|
||||
testAccStepReadPolicy(t, "test", true, false),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackend_basic_derived(t *testing.T) {
|
||||
decryptData := make(map[string]interface{})
|
||||
logicaltest.Test(t, logicaltest.TestCase{
|
||||
Backend: Backend(),
|
||||
Factory: Factory,
|
||||
Steps: []logicaltest.TestStep{
|
||||
testAccStepWritePolicy(t, "test", true),
|
||||
testAccStepReadPolicy(t, "test", false, true),
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathConfig() *framework.Path {
|
||||
func (b *backend) pathConfig() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "keys/" + framework.GenericNameRegex("name") + "/config",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
@@ -29,7 +29,7 @@ to be decrypted.`,
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: pathConfigWrite,
|
||||
logical.UpdateOperation: b.pathConfigWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathConfigHelpSyn,
|
||||
@@ -37,21 +37,29 @@ to be decrypted.`,
|
||||
}
|
||||
}
|
||||
|
||||
func pathConfigWrite(
|
||||
func (b *backend) pathConfigWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
// Check if the policy already exists
|
||||
policy, err := getPolicy(req, name)
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if policy == nil {
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse(
|
||||
fmt.Sprintf("no existing role named %s could be found", name)),
|
||||
logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.lock.Lock()
|
||||
defer lp.lock.Unlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
|
||||
}
|
||||
|
||||
resp := &logical.Response{}
|
||||
|
||||
persistNeeded := false
|
||||
@@ -70,12 +78,12 @@ func pathConfigWrite(
|
||||
}
|
||||
|
||||
if minDecryptionVersion > 0 &&
|
||||
minDecryptionVersion != policy.MinDecryptionVersion {
|
||||
if minDecryptionVersion > policy.LatestVersion {
|
||||
minDecryptionVersion != lp.policy.MinDecryptionVersion {
|
||||
if minDecryptionVersion > lp.policy.LatestVersion {
|
||||
return logical.ErrorResponse(
|
||||
fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, policy.LatestVersion)), nil
|
||||
fmt.Sprintf("cannot set min decryption version of %d, latest key version is %d", minDecryptionVersion, lp.policy.LatestVersion)), nil
|
||||
}
|
||||
policy.MinDecryptionVersion = minDecryptionVersion
|
||||
lp.policy.MinDecryptionVersion = minDecryptionVersion
|
||||
persistNeeded = true
|
||||
}
|
||||
}
|
||||
@@ -83,8 +91,8 @@ func pathConfigWrite(
|
||||
allowDeletionInt, ok := d.GetOk("deletion_allowed")
|
||||
if ok {
|
||||
allowDeletion := allowDeletionInt.(bool)
|
||||
if allowDeletion != policy.DeletionAllowed {
|
||||
policy.DeletionAllowed = allowDeletion
|
||||
if allowDeletion != lp.policy.DeletionAllowed {
|
||||
lp.policy.DeletionAllowed = allowDeletion
|
||||
persistNeeded = true
|
||||
}
|
||||
}
|
||||
@@ -92,8 +100,8 @@ func pathConfigWrite(
|
||||
// Add this as a guard here before persisting since we now require the min
|
||||
// decryption version to start at 1; even if it's not explicitly set here,
|
||||
// force the upgrade
|
||||
if policy.MinDecryptionVersion == 0 {
|
||||
policy.MinDecryptionVersion = 1
|
||||
if lp.policy.MinDecryptionVersion == 0 {
|
||||
lp.policy.MinDecryptionVersion = 1
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
@@ -101,7 +109,7 @@ func pathConfigWrite(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return resp, policy.Persist(req.Storage)
|
||||
return resp, lp.policy.Persist(req.Storage)
|
||||
}
|
||||
|
||||
const pathConfigHelpSyn = `Configure a named encryption key`
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathDatakey() *framework.Path {
|
||||
func (b *backend) pathDatakey() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "datakey/" + framework.GenericNameRegex("plaintext") + "/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
@@ -39,7 +39,7 @@ and 512 bits are supported. Defaults to 256.`,
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: pathDatakeyWrite,
|
||||
logical.UpdateOperation: b.pathDatakeyWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathDatakeyHelpSyn,
|
||||
@@ -47,7 +47,7 @@ and 512 bits are supported. Defaults to 256.`,
|
||||
}
|
||||
}
|
||||
|
||||
func pathDatakeyWrite(
|
||||
func (b *backend) pathDatakeyWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
@@ -73,16 +73,24 @@ func pathDatakeyWrite(
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.lock.RLock()
|
||||
defer lp.lock.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
|
||||
}
|
||||
|
||||
newKey := make([]byte, 32)
|
||||
bits := d.Get("bits").(int)
|
||||
switch bits {
|
||||
@@ -99,7 +107,7 @@ func pathDatakeyWrite(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
|
||||
ciphertext, err := lp.policy.Encrypt(context, base64.StdEncoding.EncodeToString(newKey))
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathDecrypt() *framework.Path {
|
||||
func (b *backend) pathDecrypt() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "decrypt/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
@@ -30,7 +30,7 @@ func pathDecrypt() *framework.Path {
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: pathDecryptWrite,
|
||||
logical.UpdateOperation: b.pathDecryptWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathDecryptHelpSyn,
|
||||
@@ -38,7 +38,7 @@ func pathDecrypt() *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func pathDecryptWrite(
|
||||
func (b *backend) pathDecryptWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
ciphertext := d.Get("ciphertext").(string)
|
||||
@@ -58,17 +58,25 @@ func pathDecryptWrite(
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
plaintext, err := p.Decrypt(context, ciphertext)
|
||||
lp.lock.RLock()
|
||||
defer lp.lock.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
|
||||
}
|
||||
|
||||
plaintext, err := lp.policy.Decrypt(context, ciphertext)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathEncrypt() *framework.Path {
|
||||
func (b *backend) pathEncrypt() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "encrypt/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
@@ -30,7 +30,7 @@ func pathEncrypt() *framework.Path {
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: pathEncryptWrite,
|
||||
logical.UpdateOperation: b.pathEncryptWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathEncryptHelpSyn,
|
||||
@@ -38,7 +38,7 @@ func pathEncrypt() *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func pathEncryptWrite(
|
||||
func (b *backend) pathEncryptWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
value := d.Get("plaintext").(string)
|
||||
@@ -46,12 +46,6 @@ func pathEncryptWrite(
|
||||
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decode the context if any
|
||||
contextRaw := d.Get("context").(string)
|
||||
var context []byte
|
||||
@@ -63,16 +57,26 @@ func pathEncryptWrite(
|
||||
}
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
isDerived := len(context) != 0
|
||||
p, err = generatePolicy(req.Storage, name, isDerived)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("failed to upsert policy: %v", err)), logical.ErrInvalidRequest
|
||||
}
|
||||
// Get the policy
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext, err := p.Encrypt(context, value)
|
||||
// Error if invalid policy
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.lock.RLock()
|
||||
defer lp.lock.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
|
||||
}
|
||||
|
||||
ciphertext, err := lp.policy.Encrypt(context, value)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathKeys() *framework.Path {
|
||||
func (b *backend) pathKeys() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "keys/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
@@ -24,9 +24,9 @@ func pathKeys() *framework.Path {
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: pathPolicyWrite,
|
||||
logical.DeleteOperation: pathPolicyDelete,
|
||||
logical.ReadOperation: pathPolicyRead,
|
||||
logical.UpdateOperation: b.pathPolicyWrite,
|
||||
logical.DeleteOperation: b.pathPolicyDelete,
|
||||
logical.ReadOperation: b.pathPolicyRead,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathPolicyHelpSyn,
|
||||
@@ -34,13 +34,13 @@ func pathKeys() *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func pathPolicyWrite(
|
||||
func (b *backend) pathPolicyWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
derived := d.Get("derived").(bool)
|
||||
|
||||
// Check if the policy already exists
|
||||
existing, err := getPolicy(req, name)
|
||||
existing, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -49,39 +49,47 @@ func pathPolicyWrite(
|
||||
}
|
||||
|
||||
// Generate the policy
|
||||
_, err = generatePolicy(req.Storage, name, derived)
|
||||
_, err = b.policies.generatePolicy(req.Storage, name, derived)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func pathPolicyRead(
|
||||
func (b *backend) pathPolicyRead(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
p, err := getPolicy(req, name)
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p == nil {
|
||||
if lp == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
lp.lock.RLock()
|
||||
defer lp.lock.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
|
||||
}
|
||||
|
||||
// Return the response
|
||||
resp := &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"name": p.Name,
|
||||
"cipher_mode": p.CipherMode,
|
||||
"derived": p.Derived,
|
||||
"deletion_allowed": p.DeletionAllowed,
|
||||
"min_decryption_version": p.MinDecryptionVersion,
|
||||
"latest_version": p.LatestVersion,
|
||||
"name": lp.policy.Name,
|
||||
"cipher_mode": lp.policy.CipherMode,
|
||||
"derived": lp.policy.Derived,
|
||||
"deletion_allowed": lp.policy.DeletionAllowed,
|
||||
"min_decryption_version": lp.policy.MinDecryptionVersion,
|
||||
"latest_version": lp.policy.LatestVersion,
|
||||
},
|
||||
}
|
||||
if p.Derived {
|
||||
resp.Data["kdf_mode"] = p.KDFMode
|
||||
if lp.policy.Derived {
|
||||
resp.Data["kdf_mode"] = lp.policy.KDFMode
|
||||
}
|
||||
|
||||
retKeys := map[string]int64{}
|
||||
for k, v := range p.Keys {
|
||||
for k, v := range lp.policy.Keys {
|
||||
retKeys[strconv.Itoa(k)] = v.CreationTime
|
||||
}
|
||||
resp.Data["keys"] = retKeys
|
||||
@@ -89,32 +97,40 @@ func pathPolicyRead(
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func pathPolicyDelete(
|
||||
func (b *backend) pathPolicyDelete(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
p, err := getPolicy(req, name)
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err
|
||||
}
|
||||
if p == nil {
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
if !p.DeletionAllowed {
|
||||
// We don't defer here because deletePolicy also needs to grab the lock
|
||||
lp.lock.RLock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
lp.lock.RUnlock()
|
||||
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
|
||||
}
|
||||
|
||||
if !lp.policy.DeletionAllowed {
|
||||
lp.lock.RUnlock()
|
||||
return logical.ErrorResponse(fmt.Sprintf("'allow_deletion' config value is not set")), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
err = req.Storage.Delete("policy/" + name)
|
||||
// Let deletePolicy grab the lock
|
||||
lp.lock.RUnlock()
|
||||
|
||||
err = b.policies.deletePolicy(req.Storage, name)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err
|
||||
}
|
||||
|
||||
err = req.Storage.Delete("archive/" + name)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("error deleting archive %s: %s", name, err)), err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathRewrap() *framework.Path {
|
||||
func (b *backend) pathRewrap() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "rewrap/" + framework.GenericNameRegex("name"),
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
@@ -30,7 +30,7 @@ func pathRewrap() *framework.Path {
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: pathRewrapWrite,
|
||||
logical.UpdateOperation: b.pathRewrapWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRewrapHelpSyn,
|
||||
@@ -38,7 +38,7 @@ func pathRewrap() *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func pathRewrapWrite(
|
||||
func (b *backend) pathRewrapWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
@@ -59,17 +59,25 @@ func pathRewrapWrite(
|
||||
}
|
||||
|
||||
// Get the policy
|
||||
p, err := getPolicy(req, name)
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Error if invalid policy
|
||||
if p == nil {
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
plaintext, err := p.Decrypt(context, value)
|
||||
lp.lock.RLock()
|
||||
defer lp.lock.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
|
||||
}
|
||||
|
||||
plaintext, err := lp.policy.Decrypt(context, value)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
@@ -85,7 +93,7 @@ func pathRewrapWrite(
|
||||
return nil, fmt.Errorf("empty plaintext returned during rewrap")
|
||||
}
|
||||
|
||||
ciphertext, err := p.Encrypt(context, plaintext)
|
||||
ciphertext, err := lp.policy.Encrypt(context, plaintext)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case certutil.UserError:
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
func pathRotate() *framework.Path {
|
||||
func (b *backend) pathRotate() *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "keys/" + framework.GenericNameRegex("name") + "/rotate",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
@@ -18,7 +18,7 @@ func pathRotate() *framework.Path {
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: pathRotateWrite,
|
||||
logical.UpdateOperation: b.pathRotateWrite,
|
||||
},
|
||||
|
||||
HelpSynopsis: pathRotateHelpSyn,
|
||||
@@ -26,23 +26,31 @@ func pathRotate() *framework.Path {
|
||||
}
|
||||
}
|
||||
|
||||
func pathRotateWrite(
|
||||
func (b *backend) pathRotateWrite(
|
||||
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
name := d.Get("name").(string)
|
||||
|
||||
// Check if the policy already exists
|
||||
policy, err := getPolicy(req, name)
|
||||
// Get the policy
|
||||
lp, err := b.policies.getPolicy(req, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if policy == nil {
|
||||
return logical.ErrorResponse(
|
||||
fmt.Sprintf("no existing role named %s could be found", name)),
|
||||
logical.ErrInvalidRequest
|
||||
|
||||
// Error if invalid policy
|
||||
if lp == nil {
|
||||
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
|
||||
}
|
||||
|
||||
lp.lock.RLock()
|
||||
defer lp.lock.RUnlock()
|
||||
|
||||
// Verify if wasn't deleted before we grabbed the lock
|
||||
if lp.policy == nil {
|
||||
return nil, fmt.Errorf("policy %s found in cache but no longer valid after lock", name)
|
||||
}
|
||||
|
||||
// Generate the policy
|
||||
err = policy.rotate(req.Storage)
|
||||
err = lp.policy.rotate(req.Storage)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
@@ -21,6 +22,200 @@ const (
|
||||
kdfMode = "hmac-sha256-counter"
|
||||
)
|
||||
|
||||
// policyCache implements a simple locking cache of policies
|
||||
type policyCache struct {
|
||||
cache map[string]*lockingPolicy
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// loadStoredPolicies loads stored policies into the cache. This should only be
|
||||
// run at backend initialization time.
|
||||
func (p *policyCache) loadStoredPolicies(storage logical.Storage) error {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
policyNames, err := storage.List("policy/")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// getPolicy will populate the cache
|
||||
for _, name := range policyNames {
|
||||
lp, err := p.getPolicy(&logical.Request{
|
||||
Storage: storage,
|
||||
}, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if lp == nil {
|
||||
return fmt.Errorf("policy %s key was found but value was nil")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPolicy loads a policy into the cache or returns one already in the cache
|
||||
func (p *policyCache) getPolicy(req *logical.Request, name string) (*lockingPolicy, error) {
|
||||
// We don't defer this since we may need to give it up and get a write lock
|
||||
p.lock.RLock()
|
||||
|
||||
// First, see if we're in the cache -- if so, return that
|
||||
if p.cache[name] != nil {
|
||||
defer p.lock.RUnlock()
|
||||
return p.cache[name], nil
|
||||
}
|
||||
|
||||
// If we find anything, we'll need to write into the cache, plus possibly
|
||||
// persist the entry, so lock the cache
|
||||
p.lock.RUnlock()
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
// Check one more time to ensure that another process did not write during
|
||||
// our lock switcheroo.
|
||||
if p.cache[name] != nil {
|
||||
return p.cache[name], nil
|
||||
}
|
||||
|
||||
// Note that we don't need to create the locking entry until the end,
|
||||
// because the policy wasn't in the cache so we don't know about it, and we
|
||||
// hold the cache lock so nothing else can be writing it in right now
|
||||
|
||||
// Check if the policy already exists
|
||||
raw, err := req.Storage.Get("policy/" + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode the policy
|
||||
policy := &Policy{
|
||||
Keys: KeyEntryMap{},
|
||||
}
|
||||
err = json.Unmarshal(raw.Value, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
persistNeeded := false
|
||||
// Ensure we've moved from Key -> Keys
|
||||
if policy.Key != nil && len(policy.Key) > 0 {
|
||||
policy.migrateKeyToKeysMap()
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// With archiving, past assumptions about the length of the keys map are no longer valid
|
||||
if policy.LatestVersion == 0 && len(policy.Keys) != 0 {
|
||||
policy.LatestVersion = len(policy.Keys)
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// We disallow setting the version to 0, since they start at 1 since moving
|
||||
// to rotate-able keys, so update if it's set to 0
|
||||
if policy.MinDecryptionVersion == 0 {
|
||||
policy.MinDecryptionVersion = 1
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// On first load after an upgrade, copy keys to the archive
|
||||
if policy.ArchiveVersion == 0 {
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
if persistNeeded {
|
||||
err = policy.Persist(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
lp := &lockingPolicy{
|
||||
policy: policy,
|
||||
}
|
||||
p.cache[name] = lp
|
||||
|
||||
return lp, nil
|
||||
}
|
||||
|
||||
// generatePolicy is used to create a new named policy with a randomly
|
||||
// generated key
|
||||
func (p *policyCache) generatePolicy(storage logical.Storage, name string, derived bool) (*lockingPolicy, error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
// Ensure one doesn't already exist
|
||||
if lp := p.cache[name]; lp != nil {
|
||||
return nil, fmt.Errorf("policy %s already exists", name)
|
||||
}
|
||||
|
||||
// Create the policy object
|
||||
policy := &Policy{
|
||||
Name: name,
|
||||
CipherMode: "aes-gcm",
|
||||
Derived: derived,
|
||||
}
|
||||
if derived {
|
||||
policy.KDFMode = kdfMode
|
||||
}
|
||||
|
||||
err := policy.rotate(storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lp := &lockingPolicy{
|
||||
policy: policy,
|
||||
}
|
||||
p.cache[name] = lp
|
||||
|
||||
// Return the policy
|
||||
return lp, nil
|
||||
}
|
||||
|
||||
// deletePolicy deletes a policy
|
||||
func (p *policyCache) deletePolicy(storage logical.Storage, name string) error {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
lp := p.cache[name]
|
||||
if lp == nil {
|
||||
return fmt.Errorf("policy %s not found", name)
|
||||
}
|
||||
|
||||
// We need to ensure all other access has stopped
|
||||
lp.lock.Lock()
|
||||
defer lp.lock.Unlock()
|
||||
|
||||
// Verify this hasn't changed
|
||||
if !lp.policy.DeletionAllowed {
|
||||
return fmt.Errorf("deletion not allowed for policy %s", name)
|
||||
}
|
||||
|
||||
err := storage.Delete("policy/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting policy %s: %s", name, err)
|
||||
}
|
||||
|
||||
err = storage.Delete("archive/" + name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting archive %s: %s", name, err)
|
||||
}
|
||||
|
||||
lp.policy = nil
|
||||
delete(p.cache, name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// lockingPolicy holds a Policy guarded by a lock
|
||||
type lockingPolicy struct {
|
||||
policy *Policy
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// KeyEntry stores the key and metadata
|
||||
type KeyEntry struct {
|
||||
Key []byte `json:"key"`
|
||||
@@ -435,87 +630,3 @@ func (p *Policy) migrateKeyToKeysMap() {
|
||||
}
|
||||
p.Key = nil
|
||||
}
|
||||
|
||||
func deserializePolicy(buf []byte) (*Policy, error) {
|
||||
p := &Policy{
|
||||
Keys: KeyEntryMap{},
|
||||
}
|
||||
if err := json.Unmarshal(buf, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func getPolicy(req *logical.Request, name string) (*Policy, error) {
|
||||
// Check if the policy already exists
|
||||
raw, err := req.Storage.Get("policy/" + name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Decode the policy
|
||||
p, err := deserializePolicy(raw.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
persistNeeded := false
|
||||
// Ensure we've moved from Key -> Keys
|
||||
if p.Key != nil && len(p.Key) > 0 {
|
||||
p.migrateKeyToKeysMap()
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// With archiving, past assumptions about the length of the keys map are no longer valid
|
||||
if p.LatestVersion == 0 && len(p.Keys) != 0 {
|
||||
p.LatestVersion = len(p.Keys)
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// We disallow setting the version to 0, since they start at 1 since moving
|
||||
// to rotate-able keys, so update if it's set to 0
|
||||
if p.MinDecryptionVersion == 0 {
|
||||
p.MinDecryptionVersion = 1
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
// On first load after an upgrade, copy keys to the archive
|
||||
if p.ArchiveVersion == 0 {
|
||||
persistNeeded = true
|
||||
}
|
||||
|
||||
if persistNeeded {
|
||||
err = p.Persist(req.Storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// generatePolicy is used to create a new named policy with
|
||||
// a randomly generated key
|
||||
func generatePolicy(storage logical.Storage, name string, derived bool) (*Policy, error) {
|
||||
// Create the policy object
|
||||
p := &Policy{
|
||||
Name: name,
|
||||
CipherMode: "aes-gcm",
|
||||
Derived: derived,
|
||||
}
|
||||
if derived {
|
||||
p.KDFMode = kdfMode
|
||||
}
|
||||
|
||||
err := p.rotate(storage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the policy
|
||||
return p, nil
|
||||
}
|
||||
|
||||
@@ -17,14 +17,19 @@ func resetKeysArchive() {
|
||||
|
||||
func Test_KeyUpgrade(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
policy, err := generatePolicy(storage, "test", false)
|
||||
policies := &policyCache{
|
||||
cache: map[string]*lockingPolicy{},
|
||||
}
|
||||
lp, err := policies.generatePolicy(storage, "test", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if policy == nil {
|
||||
if lp == nil {
|
||||
t.Fatal("nil policy")
|
||||
}
|
||||
|
||||
policy := lp.policy
|
||||
|
||||
testBytes := make([]byte, len(policy.Keys[1].Key))
|
||||
copy(testBytes, policy.Keys[1].Key)
|
||||
|
||||
@@ -51,15 +56,20 @@ func Test_ArchivingUpgrade(t *testing.T) {
|
||||
// zero and latest, respectively
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
policies := &policyCache{
|
||||
cache: map[string]*lockingPolicy{},
|
||||
}
|
||||
|
||||
policy, err := generatePolicy(storage, "test", false)
|
||||
lp, err := policies.generatePolicy(storage, "test", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if policy == nil {
|
||||
if lp == nil {
|
||||
t.Fatal("policy is nil")
|
||||
}
|
||||
|
||||
policy := lp.policy
|
||||
|
||||
// Store the initial key in the archive
|
||||
keysArchive = append(keysArchive, policy.Keys[1])
|
||||
checkKeys(t, policy, storage, "initial", 1, 1, 1)
|
||||
@@ -96,17 +106,22 @@ func Test_ArchivingUpgrade(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Expire from the cache since we modified it under-the-hood
|
||||
delete(policies.cache, "test")
|
||||
|
||||
// Now get the policy again; the upgrade should happen automatically
|
||||
policy, err = getPolicy(&logical.Request{
|
||||
lp, err = policies.getPolicy(&logical.Request{
|
||||
Storage: storage,
|
||||
}, "test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if policy == nil {
|
||||
if lp == nil {
|
||||
t.Fatal("policy is nil")
|
||||
}
|
||||
|
||||
policy = lp.policy
|
||||
|
||||
checkKeys(t, policy, storage, "upgrade", 10, 10, 10)
|
||||
}
|
||||
|
||||
@@ -120,14 +135,20 @@ func Test_Archiving(t *testing.T) {
|
||||
|
||||
storage := &logical.InmemStorage{}
|
||||
|
||||
policy, err := generatePolicy(storage, "test", false)
|
||||
policies := &policyCache{
|
||||
cache: map[string]*lockingPolicy{},
|
||||
}
|
||||
|
||||
lp, err := policies.generatePolicy(storage, "test", false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if policy == nil {
|
||||
if lp == nil {
|
||||
t.Fatal("policy is nil")
|
||||
}
|
||||
|
||||
policy := lp.policy
|
||||
|
||||
// Store the initial key in the archive
|
||||
keysArchive = append(keysArchive, policy.Keys[1])
|
||||
checkKeys(t, policy, storage, "initial", 1, 1, 1)
|
||||
|
||||
Reference in New Issue
Block a user