mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
Modify approle tidy to validate dangling accessors (#4981)
This commit is contained in:
committed by
Brian Kassouf
parent
8d2d9fd8bd
commit
77e61243d0
@@ -3,6 +3,7 @@ package approle
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/consts"
|
||||
"github.com/hashicorp/vault/helper/locksutil"
|
||||
@@ -56,6 +57,8 @@ type backend struct {
|
||||
// secretIDListingLock is a dedicated lock for listing SecretIDAccessors
|
||||
// for all the SecretIDs issued against an approle
|
||||
secretIDListingLock sync.RWMutex
|
||||
|
||||
testTidyDelay time.Duration
|
||||
}
|
||||
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
|
||||
@@ -38,17 +38,29 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
|
||||
go func() {
|
||||
defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0)
|
||||
|
||||
logger := b.Logger().Named("tidy")
|
||||
|
||||
checkCount := 0
|
||||
|
||||
defer func() {
|
||||
if b.testTidyDelay > 0 {
|
||||
logger.Trace("done checking entries", "num_entries", checkCount)
|
||||
}
|
||||
}()
|
||||
|
||||
// Don't cancel when the original client request goes away
|
||||
ctx = context.Background()
|
||||
|
||||
logger := b.Logger().Named("tidy")
|
||||
|
||||
tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error {
|
||||
logger.Trace("listing role HMACs", "prefix", secretIDPrefixToUse)
|
||||
|
||||
roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Trace("listing accessors", "prefix", accessorIDPrefixToUse)
|
||||
|
||||
// List all the accessors and add them all to a map
|
||||
accessorHashes, err := s.List(ctx, accessorIDPrefixToUse)
|
||||
if err != nil {
|
||||
@@ -59,7 +71,10 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
|
||||
accessorMap[accessorHash] = true
|
||||
}
|
||||
|
||||
time.Sleep(b.testTidyDelay)
|
||||
|
||||
secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error {
|
||||
checkCount++
|
||||
lock := b.secretIDLock(secretIDHMAC)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
@@ -91,6 +106,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
|
||||
return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err)
|
||||
}
|
||||
if accessorEntry == nil {
|
||||
logger.Trace("found nil accessor")
|
||||
if err := s.Delete(ctx, entryIndex); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err)
|
||||
}
|
||||
@@ -99,6 +115,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
|
||||
|
||||
// ExpirationTime not being set indicates non-expiring SecretIDs
|
||||
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
|
||||
logger.Trace("found expired secret ID")
|
||||
// Clean up the accessor of the secret ID first
|
||||
err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
|
||||
if err != nil {
|
||||
@@ -126,6 +143,7 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
|
||||
}
|
||||
|
||||
for _, roleNameHMAC := range roleNameHMACs {
|
||||
logger.Trace("listing secret ID HMACs", "role_hmac", roleNameHMAC)
|
||||
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -140,13 +158,60 @@ func (b *backend) tidySecretID(ctx context.Context, req *logical.Request) (*logi
|
||||
|
||||
// Accessor indexes were not getting cleaned up until 0.9.3. This is a fix
|
||||
// to clean up the dangling accessor entries.
|
||||
for accessorHash, _ := range accessorMap {
|
||||
// Ideally, locking should be performed here. But for that, accessors
|
||||
// are required in plaintext, which are not available. Hence performing
|
||||
// a racy cleanup.
|
||||
err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash)
|
||||
if err != nil {
|
||||
return err
|
||||
if len(accessorMap) > 0 {
|
||||
for _, lock := range b.secretIDLocks {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
}
|
||||
for accessorHash, _ := range accessorMap {
|
||||
logger.Trace("found dangling accessor, verifying")
|
||||
// Ideally, locking on accessors should be performed here too
|
||||
// but for that, accessors are required in plaintext, which are
|
||||
// not available. The code above helps but it may still be
|
||||
// racy.
|
||||
// ...
|
||||
// Look up the secret again now that we have all the locks. The
|
||||
// lock is held when writing accessor/secret so if we have the
|
||||
// lock we know we're not in a
|
||||
// wrote-accessor-but-not-yet-secret case, which can be racy.
|
||||
var entry secretIDAccessorStorageEntry
|
||||
entryIndex := accessorIDPrefixToUse + accessorHash
|
||||
se, err := s.Get(ctx, entryIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if se != nil {
|
||||
err = se.DecodeJSON(&entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The storage entry doesn't store the role ID, so we have
|
||||
// to go about this the long way; fortunately we shouldn't
|
||||
// actually hit this very often
|
||||
var found bool
|
||||
searchloop:
|
||||
for _, roleNameHMAC := range roleNameHMACs {
|
||||
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, v := range secretIDHMACs {
|
||||
if v == entry.SecretIDHMAC {
|
||||
found = true
|
||||
logger.Trace("accessor verified, not removing")
|
||||
break searchloop
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
logger.Trace("could not verify dangling accessor, removing")
|
||||
err = s.Delete(ctx, entryIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,15 @@ package approle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
func TestAppRole_TidyDanglingAccessors(t *testing.T) {
|
||||
func TestAppRole_TidyDanglingAccessors_Normal(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
b, storage := createBackendWithStorage(t)
|
||||
@@ -83,3 +85,93 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) {
|
||||
t.Fatalf("bad: len(accessorHashes); expect 1, got %d", len(accessorHashes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppRole_TidyDanglingAccessors_RaceTest(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
b, storage := createBackendWithStorage(t)
|
||||
|
||||
b.testTidyDelay = 300 * time.Millisecond
|
||||
|
||||
// Create a role
|
||||
createRole(t, b, storage, "role1", "a,b,c")
|
||||
|
||||
// Create an initial entry
|
||||
roleSecretIDReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "role/role1/secret-id",
|
||||
Storage: storage,
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), roleSecretIDReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
count := 1
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
now := time.Now()
|
||||
started := false
|
||||
for {
|
||||
if time.Now().Sub(now) > 700*time.Millisecond {
|
||||
break
|
||||
}
|
||||
if time.Now().Sub(now) > 100*time.Millisecond && !started {
|
||||
started = true
|
||||
_, err = b.tidySecretID(context.Background(), &logical.Request{
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
roleSecretIDReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "role/role1/secret-id",
|
||||
Storage: storage,
|
||||
}
|
||||
resp, err = b.HandleRequest(context.Background(), roleSecretIDReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
t.Fatalf("err:%v resp:%#v", err, resp)
|
||||
}
|
||||
}()
|
||||
count++
|
||||
}
|
||||
|
||||
t.Logf("wrote %d entries", count)
|
||||
|
||||
wg.Wait()
|
||||
// Let tidy finish
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Run tidy again
|
||||
_, err = b.tidySecretID(context.Background(), &logical.Request{
|
||||
Storage: storage,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
accessorHashes, err := storage.List(context.Background(), "accessor/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(accessorHashes) != count {
|
||||
t.Fatalf("bad: len(accessorHashes); expect %d, got %d", count, len(accessorHashes))
|
||||
}
|
||||
|
||||
roleHMACs, err := storage.List(context.Background(), secretIDPrefix)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
secretIDs, err := storage.List(context.Background(), fmt.Sprintf("%s%s", secretIDPrefix, roleHMACs[0]))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(secretIDs) != count {
|
||||
t.Fatalf("bad: len(secretIDs); expect %d, got %d", count, len(secretIDs))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user