mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 17:52:32 +00:00
Add an idle timeout for the server (#4760)
* Add an idle timeout for the server Because tidy operations can be long-running, this also changes all tidy operations to behave the same operationally (kick off the process, get a warning back, log errors to server log) and makes them all run in a goroutine. This could mean a sort of hard stop if Vault gets sealed because the function won't have the read lock. This should generally be okay (running tidy again should pick back up where it left off), but future work could use cleanup funcs to trigger the functions to stop. * Fix up tidy test * Add deadline to cluster connections and an idle timeout to the cluster server, plus add readheader/read timeout to api server
This commit is contained in:
@@ -26,141 +26,153 @@ func pathTidySecretID(b *backend) *framework.Path {
|
||||
}
|
||||
|
||||
// tidySecretID is used to delete entries in the whitelist that are expired.
|
||||
func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error {
|
||||
grabbed := atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1)
|
||||
if grabbed {
|
||||
defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0)
|
||||
} else {
|
||||
return fmt.Errorf("SecretID tidy operation already running")
|
||||
func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) (*logical.Response, error) {
|
||||
if !atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1) {
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation already in progress.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
var result error
|
||||
go func() {
|
||||
defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0)
|
||||
|
||||
tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error {
|
||||
roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var result error
|
||||
|
||||
// List all the accessors and add them all to a map
|
||||
accessorHashes, err := s.List(ctx, accessorIDPrefixToUse)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
accessorMap := make(map[string]bool, len(accessorHashes))
|
||||
for _, accessorHash := range accessorHashes {
|
||||
accessorMap[accessorHash] = true
|
||||
}
|
||||
// Don't cancel when the original client request goes away
|
||||
ctx = context.Background()
|
||||
|
||||
secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error {
|
||||
lock := b.secretIDLock(secretIDHMAC)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
logger := b.Logger().Named("tidy")
|
||||
|
||||
entryIndex := fmt.Sprintf("%s%s%s", secretIDPrefixToUse, roleNameHMAC, secretIDHMAC)
|
||||
secretIDEntry, err := s.Get(ctx, entryIndex)
|
||||
tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error {
|
||||
roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error fetching SecretID %q: {{err}}", secretIDHMAC), err)
|
||||
}
|
||||
|
||||
if secretIDEntry == nil {
|
||||
result = multierror.Append(result, fmt.Errorf("entry for SecretID %q is nil", secretIDHMAC))
|
||||
return nil
|
||||
}
|
||||
|
||||
if secretIDEntry.Value == nil || len(secretIDEntry.Value) == 0 {
|
||||
return fmt.Errorf("found entry for SecretID %q but actual SecretID is empty", secretIDHMAC)
|
||||
}
|
||||
|
||||
var result secretIDStorageEntry
|
||||
if err := secretIDEntry.DecodeJSON(&result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If a secret ID entry does not have a corresponding accessor
|
||||
// entry, revoke the secret ID immediately
|
||||
accessorEntry, err := b.secretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
|
||||
// List all the accessors and add them all to a map
|
||||
accessorHashes, err := s.List(ctx, accessorIDPrefixToUse)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err)
|
||||
return err
|
||||
}
|
||||
if accessorEntry == nil {
|
||||
if err := s.Delete(ctx, entryIndex); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err)
|
||||
}
|
||||
return nil
|
||||
accessorMap := make(map[string]bool, len(accessorHashes))
|
||||
for _, accessorHash := range accessorHashes {
|
||||
accessorMap[accessorHash] = true
|
||||
}
|
||||
|
||||
// ExpirationTime not being set indicates non-expiring SecretIDs
|
||||
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
|
||||
// Clean up the accessor of the secret ID first
|
||||
err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
|
||||
secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error {
|
||||
lock := b.secretIDLock(secretIDHMAC)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
entryIndex := fmt.Sprintf("%s%s%s", secretIDPrefixToUse, roleNameHMAC, secretIDHMAC)
|
||||
secretIDEntry, err := s.Get(ctx, entryIndex)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to delete secret ID accessor entry: {{err}}", err)
|
||||
return errwrap.Wrapf(fmt.Sprintf("error fetching SecretID %q: {{err}}", secretIDHMAC), err)
|
||||
}
|
||||
|
||||
if err := s.Delete(ctx, entryIndex); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting SecretID %q from storage: {{err}}", secretIDHMAC), err)
|
||||
if secretIDEntry == nil {
|
||||
result = multierror.Append(result, fmt.Errorf("entry for SecretID %q is nil", secretIDHMAC))
|
||||
return nil
|
||||
}
|
||||
|
||||
if secretIDEntry.Value == nil || len(secretIDEntry.Value) == 0 {
|
||||
return fmt.Errorf("found entry for SecretID %q but actual SecretID is empty", secretIDHMAC)
|
||||
}
|
||||
|
||||
var result secretIDStorageEntry
|
||||
if err := secretIDEntry.DecodeJSON(&result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If a secret ID entry does not have a corresponding accessor
|
||||
// entry, revoke the secret ID immediately
|
||||
accessorEntry, err := b.secretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err)
|
||||
}
|
||||
if accessorEntry == nil {
|
||||
if err := s.Delete(ctx, entryIndex); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExpirationTime not being set indicates non-expiring SecretIDs
|
||||
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
|
||||
// Clean up the accessor of the secret ID first
|
||||
err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to delete secret ID accessor entry: {{err}}", err)
|
||||
}
|
||||
|
||||
if err := s.Delete(ctx, entryIndex); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting SecretID %q from storage: {{err}}", secretIDHMAC), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// At this point, the secret ID is not expired and is valid. Delete
|
||||
// the corresponding accessor from the accessorMap. This will leave
|
||||
// only the dangling accessors in the map which can then be cleaned
|
||||
// up later.
|
||||
salt, err := b.Salt(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(accessorMap, salt.SaltID(result.SecretIDAccessor))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// At this point, the secret ID is not expired and is valid. Delete
|
||||
// the corresponding accessor from the accessorMap. This will leave
|
||||
// only the dangling accessors in the map which can then be cleaned
|
||||
// up later.
|
||||
salt, err := b.Salt(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
for _, roleNameHMAC := range roleNameHMACs {
|
||||
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, secretIDHMAC := range secretIDHMACs {
|
||||
err = secretIDCleanupFunc(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(accessorMap, salt.SaltID(result.SecretIDAccessor))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, roleNameHMAC := range roleNameHMACs {
|
||||
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, secretIDHMAC := range secretIDHMACs {
|
||||
err = secretIDCleanupFunc(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse)
|
||||
// Accessor indexes were not getting cleaned up until 0.9.3. This is a fix
|
||||
// to clean up the dangling accessor entries.
|
||||
for accessorHash, _ := range accessorMap {
|
||||
// Ideally, locking should be performed here. But for that, accessors
|
||||
// are required in plaintext, which are not available. Hence performing
|
||||
// a racy cleanup.
|
||||
err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Accessor indexes were not getting cleaned up until 0.9.3. This is a fix
|
||||
// to clean up the dangling accessor entries.
|
||||
for accessorHash, _ := range accessorMap {
|
||||
// Ideally, locking should be performed here. But for that, accessors
|
||||
// are required in plaintext, which are not available. Hence performing
|
||||
// a racy cleanup.
|
||||
err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := tidyFunc(secretIDPrefix, secretIDAccessorPrefix)
|
||||
if err != nil {
|
||||
logger.Error("error tidying global secret IDs", "error", err)
|
||||
return
|
||||
}
|
||||
err = tidyFunc(secretIDLocalPrefix, secretIDAccessorLocalPrefix)
|
||||
if err != nil {
|
||||
logger.Error("error tidying local secret IDs", "error", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := tidyFunc(secretIDPrefix, secretIDAccessorPrefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tidyFunc(secretIDLocalPrefix, secretIDAccessorLocalPrefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return result
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// pathTidySecretIDUpdate is used to delete the expired SecretID entries
|
||||
func (b *backend) pathTidySecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
return nil, b.tidySecretID(ctx, req.Storage)
|
||||
return b.tidySecretID(ctx, req.Storage)
|
||||
}
|
||||
|
||||
const pathTidySecretIDSyn = "Trigger the clean-up of expired SecretID entries."
|
||||
|
||||
@@ -3,6 +3,7 @@ package approle
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
@@ -64,11 +65,14 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) {
|
||||
t.Fatalf("bad: len(accessorHashes); expect 3, got %d", len(accessorHashes))
|
||||
}
|
||||
|
||||
err = b.tidySecretID(context.Background(), storage)
|
||||
_, err = b.tidySecretID(context.Background(), storage)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// It runs async so we give it a bit of time to run
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
accessorHashes, err = storage.List(context.Background(), "accessor/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -33,53 +33,72 @@ expiration, before it is removed from the backend storage.`,
|
||||
}
|
||||
|
||||
// tidyWhitelistIdentity is used to delete entries in the whitelist that are expired.
|
||||
func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) error {
|
||||
grabbed := atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1)
|
||||
if grabbed {
|
||||
func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) (*logical.Response, error) {
|
||||
if !atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1) {
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation already in progress.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer atomic.StoreUint32(b.tidyWhitelistCASGuard, 0)
|
||||
} else {
|
||||
return fmt.Errorf("identity whitelist tidy operation already running")
|
||||
}
|
||||
|
||||
bufferDuration := time.Duration(safety_buffer) * time.Second
|
||||
// Don't cancel when the original client request goes away
|
||||
ctx = context.Background()
|
||||
|
||||
identities, err := s.List(ctx, "whitelist/identity/")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger := b.Logger().Named("wltidy")
|
||||
|
||||
for _, instanceID := range identities {
|
||||
identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error fetching identity of instanceID %q: {{err}}", instanceID), err)
|
||||
}
|
||||
bufferDuration := time.Duration(safety_buffer) * time.Second
|
||||
|
||||
if identityEntry == nil {
|
||||
return fmt.Errorf("identity entry for instanceID %q is nil", instanceID)
|
||||
}
|
||||
|
||||
if identityEntry.Value == nil || len(identityEntry.Value) == 0 {
|
||||
return fmt.Errorf("found identity entry for instanceID %q but actual identity is empty", instanceID)
|
||||
}
|
||||
|
||||
var result whitelistIdentity
|
||||
if err := identityEntry.DecodeJSON(&result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
|
||||
if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting identity of instanceID %q from storage: {{err}}", instanceID), err)
|
||||
doTidy := func() error {
|
||||
identities, err := s.List(ctx, "whitelist/identity/")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
for _, instanceID := range identities {
|
||||
identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error fetching identity of instanceID %q: {{err}}", instanceID), err)
|
||||
}
|
||||
|
||||
if identityEntry == nil {
|
||||
return fmt.Errorf("identity entry for instanceID %q is nil", instanceID)
|
||||
}
|
||||
|
||||
if identityEntry.Value == nil || len(identityEntry.Value) == 0 {
|
||||
return fmt.Errorf("found identity entry for instanceID %q but actual identity is empty", instanceID)
|
||||
}
|
||||
|
||||
var result whitelistIdentity
|
||||
if err := identityEntry.DecodeJSON(&result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
|
||||
if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting identity of instanceID %q from storage: {{err}}", instanceID), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := doTidy(); err != nil {
|
||||
logger.Error("error running whitelist tidy", "error", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// pathTidyIdentityWhitelistUpdate is used to delete entries in the whitelist that are expired.
|
||||
func (b *backend) pathTidyIdentityWhitelistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
return nil, b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int))
|
||||
return b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int))
|
||||
}
|
||||
|
||||
const pathTidyIdentityWhitelistSyn = `
|
||||
|
||||
@@ -33,52 +33,72 @@ expiration, before it is removed from the backend storage.`,
|
||||
}
|
||||
|
||||
// tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist.
|
||||
func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) error {
|
||||
grabbed := atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1)
|
||||
if grabbed {
|
||||
func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) (*logical.Response, error) {
|
||||
if !atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1) {
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation already in progress.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer atomic.StoreUint32(b.tidyBlacklistCASGuard, 0)
|
||||
} else {
|
||||
return fmt.Errorf("roletag blacklist tidy operation already running")
|
||||
}
|
||||
|
||||
bufferDuration := time.Duration(safety_buffer) * time.Second
|
||||
tags, err := s.List(ctx, "blacklist/roletag/")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Don't cancel when the original client request goes away
|
||||
ctx = context.Background()
|
||||
|
||||
for _, tag := range tags {
|
||||
tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error fetching tag %q: {{err}}", tag), err)
|
||||
}
|
||||
logger := b.Logger().Named("bltidy")
|
||||
|
||||
if tagEntry == nil {
|
||||
return fmt.Errorf("tag entry for tag %q is nil", tag)
|
||||
}
|
||||
bufferDuration := time.Duration(safety_buffer) * time.Second
|
||||
|
||||
if tagEntry.Value == nil || len(tagEntry.Value) == 0 {
|
||||
return fmt.Errorf("found entry for tag %q but actual tag is empty", tag)
|
||||
}
|
||||
|
||||
var result roleTagBlacklistEntry
|
||||
if err := tagEntry.DecodeJSON(&result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
|
||||
if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting tag %q from storage: {{err}}", tag), err)
|
||||
doTidy := func() error {
|
||||
tags, err := s.List(ctx, "blacklist/roletag/")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
for _, tag := range tags {
|
||||
tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error fetching tag %q: {{err}}", tag), err)
|
||||
}
|
||||
|
||||
if tagEntry == nil {
|
||||
return fmt.Errorf("tag entry for tag %q is nil", tag)
|
||||
}
|
||||
|
||||
if tagEntry.Value == nil || len(tagEntry.Value) == 0 {
|
||||
return fmt.Errorf("found entry for tag %q but actual tag is empty", tag)
|
||||
}
|
||||
|
||||
var result roleTagBlacklistEntry
|
||||
if err := tagEntry.DecodeJSON(&result); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
|
||||
if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting tag %q from storage: {{err}}", tag), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := doTidy(); err != nil {
|
||||
logger.Error("error running blacklist tidy", "error", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// pathTidyRoletagBlacklistUpdate is used to clean-up the entries in the role tag blacklist.
|
||||
func (b *backend) pathTidyRoletagBlacklistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
return nil, b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int))
|
||||
return b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int))
|
||||
}
|
||||
|
||||
const pathTidyRoletagBlacklistSyn = `
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
// Factory creates a new backend implementing the logical.Backend interface
|
||||
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
b := Backend()
|
||||
b := Backend(conf)
|
||||
if err := b.Setup(ctx, conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -20,7 +20,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,
|
||||
}
|
||||
|
||||
// Backend returns a new Backend framework struct
|
||||
func Backend() *backend {
|
||||
func Backend(conf *logical.BackendConfig) *backend {
|
||||
var b backend
|
||||
b.Backend = &framework.Backend{
|
||||
Help: strings.TrimSpace(backendHelp),
|
||||
@@ -85,6 +85,8 @@ func Backend() *backend {
|
||||
}
|
||||
|
||||
b.crlLifetime = time.Hour * 72
|
||||
b.tidyCASGuard = new(uint32)
|
||||
b.storage = conf.StorageView
|
||||
|
||||
return &b
|
||||
}
|
||||
@@ -92,8 +94,10 @@ func Backend() *backend {
|
||||
type backend struct {
|
||||
*framework.Backend
|
||||
|
||||
storage logical.Storage
|
||||
crlLifetime time.Duration
|
||||
revokeStorageLock sync.RWMutex
|
||||
tidyCASGuard *uint32
|
||||
}
|
||||
|
||||
const backendHelp = `
|
||||
|
||||
@@ -135,64 +135,6 @@ func TestPKI_RequireCN(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Performs basic tests on CA functionality
|
||||
// Uses the RSA CA key
|
||||
func TestBackend_RSAKey(t *testing.T) {
|
||||
initTest.Do(setCerts)
|
||||
defaultLeaseTTLVal := time.Hour * 24
|
||||
maxLeaseTTLVal := time.Hour * 24 * 32
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: defaultLeaseTTLVal,
|
||||
MaxLeaseTTLVal: maxLeaseTTLVal,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create backend: %s", err)
|
||||
}
|
||||
|
||||
testCase := logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{},
|
||||
}
|
||||
|
||||
intdata := map[string]interface{}{}
|
||||
reqdata := map[string]interface{}{}
|
||||
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, ecCACert, intdata, reqdata)...)
|
||||
|
||||
logicaltest.Test(t, testCase)
|
||||
}
|
||||
|
||||
// Performs basic tests on CA functionality
|
||||
// Uses the EC CA key
|
||||
func TestBackend_ECKey(t *testing.T) {
|
||||
initTest.Do(setCerts)
|
||||
defaultLeaseTTLVal := time.Hour * 24
|
||||
maxLeaseTTLVal := time.Hour * 24 * 32
|
||||
b, err := Factory(context.Background(), &logical.BackendConfig{
|
||||
Logger: nil,
|
||||
System: &logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: defaultLeaseTTLVal,
|
||||
MaxLeaseTTLVal: maxLeaseTTLVal,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create backend: %s", err)
|
||||
}
|
||||
|
||||
testCase := logicaltest.TestCase{
|
||||
Backend: b,
|
||||
Steps: []logicaltest.TestStep{},
|
||||
}
|
||||
|
||||
intdata := map[string]interface{}{}
|
||||
reqdata := map[string]interface{}{}
|
||||
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, rsaCACert, intdata, reqdata)...)
|
||||
|
||||
logicaltest.Test(t, testCase)
|
||||
}
|
||||
|
||||
func TestBackend_CSRValues(t *testing.T) {
|
||||
initTest.Do(setCerts)
|
||||
defaultLeaseTTLVal := time.Hour * 24
|
||||
@@ -806,685 +748,6 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s
|
||||
return ret
|
||||
}
|
||||
|
||||
// Generates steps to test out CA configuration -- certificates + CRL expiry,
|
||||
// and ensure that the certificates are readable after storing them
|
||||
func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
|
||||
setSerialUnderTest := func(req *logical.Request) error {
|
||||
req.Path = serialUnderTest
|
||||
return nil
|
||||
}
|
||||
|
||||
ret := []logicaltest.TestStep{
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/ca",
|
||||
Data: map[string]interface{}{
|
||||
"pem_bundle": strings.Join([]string{caKey, caCert}, "\n"),
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/crl",
|
||||
Data: map[string]interface{}{
|
||||
"expiry": "16h",
|
||||
},
|
||||
},
|
||||
|
||||
// Ensure we can fetch it back via unauthenticated means, in various formats
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "cert/ca",
|
||||
Unauthenticated: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["certificate"].(string) != caCert {
|
||||
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", resp.Data["certificate"].(string), caCert)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "ca/pem",
|
||||
Unauthenticated: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
rawBytes := resp.Data["http_raw_body"].([]byte)
|
||||
if !reflect.DeepEqual(rawBytes, []byte(caCert)) {
|
||||
return fmt.Errorf("CA certificate:\n%#v\ndoes not match original:\n%#v\n", rawBytes, []byte(caCert))
|
||||
}
|
||||
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
|
||||
return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "ca",
|
||||
Unauthenticated: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
rawBytes := resp.Data["http_raw_body"].([]byte)
|
||||
pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: rawBytes,
|
||||
})))
|
||||
if pemBytes != caCert {
|
||||
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", pemBytes, caCert)
|
||||
}
|
||||
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
|
||||
return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "config/crl",
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["expiry"].(string) != "16h" {
|
||||
return fmt.Errorf("CRL lifetimes do not match (got %s)", resp.Data["expiry"].(string))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Ensure that both parts of the PEM bundle are required
|
||||
// Here, just the cert
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/ca",
|
||||
Data: map[string]interface{}{
|
||||
"pem_bundle": caCert,
|
||||
},
|
||||
ErrorOk: true,
|
||||
},
|
||||
|
||||
// Here, just the key
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/ca",
|
||||
Data: map[string]interface{}{
|
||||
"pem_bundle": caKey,
|
||||
},
|
||||
ErrorOk: true,
|
||||
},
|
||||
|
||||
// Ensure we can fetch it back via unauthenticated means, in various formats
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "cert/ca",
|
||||
Unauthenticated: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["certificate"].(string) != caCert {
|
||||
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", resp.Data["certificate"].(string), caCert)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "ca/pem",
|
||||
Unauthenticated: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
rawBytes := resp.Data["http_raw_body"].([]byte)
|
||||
if string(rawBytes) != caCert {
|
||||
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", string(rawBytes), caCert)
|
||||
}
|
||||
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
|
||||
return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "ca",
|
||||
Unauthenticated: true,
|
||||
Check: func(resp *logical.Response) error {
|
||||
rawBytes := resp.Data["http_raw_body"].([]byte)
|
||||
pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: rawBytes,
|
||||
})))
|
||||
if pemBytes != caCert {
|
||||
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", pemBytes, caCert)
|
||||
}
|
||||
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
|
||||
return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Test a bunch of generation stuff
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "root",
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "root/generate/exported",
|
||||
Data: map[string]interface{}{
|
||||
"common_name": "Root Cert",
|
||||
"ttl": "180h",
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
intdata["root"] = resp.Data["certificate"].(string)
|
||||
intdata["rootkey"] = resp.Data["private_key"].(string)
|
||||
reqdata["pem_bundle"] = strings.Join([]string{intdata["root"].(string), intdata["rootkey"].(string)}, "\n")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "intermediate/generate/exported",
|
||||
Data: map[string]interface{}{
|
||||
"common_name": "intermediate.cert.com",
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
intdata["intermediatecsr"] = resp.Data["csr"].(string)
|
||||
intdata["intermediatekey"] = resp.Data["private_key"].(string)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Re-load the root key in so we can sign it
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/ca",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
delete(reqdata, "pem_bundle")
|
||||
delete(reqdata, "ttl")
|
||||
reqdata["csr"] = intdata["intermediatecsr"].(string)
|
||||
reqdata["common_name"] = "intermediate.cert.com"
|
||||
reqdata["ttl"] = "10s"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "root/sign-intermediate",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
delete(reqdata, "csr")
|
||||
delete(reqdata, "common_name")
|
||||
delete(reqdata, "ttl")
|
||||
intdata["intermediatecert"] = resp.Data["certificate"].(string)
|
||||
reqdata["serial_number"] = resp.Data["serial_number"].(string)
|
||||
reqdata["rsa_int_serial_number"] = resp.Data["serial_number"].(string)
|
||||
reqdata["certificate"] = resp.Data["certificate"].(string)
|
||||
reqdata["pem_bundle"] = strings.Join([]string{intdata["intermediatekey"].(string), resp.Data["certificate"].(string)}, "\n")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// First load in this way to populate the private key
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/ca",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
delete(reqdata, "pem_bundle")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Now test setting the intermediate, signed CA cert
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "intermediate/set-signed",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
delete(reqdata, "certificate")
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// We expect to find a zero revocation time
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
|
||||
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
if resp.Data["revocation_time"].(int64) != 0 {
|
||||
return fmt.Errorf("expected a zero revocation time")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "revoke",
|
||||
Data: reqdata,
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "crl",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
crlBytes := resp.Data["http_raw_body"].([]byte)
|
||||
certList, err := x509.ParseCRL(crlBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
revokedList := certList.TBSCertList.RevokedCertificates
|
||||
if len(revokedList) != 1 {
|
||||
t.Fatalf("length of revoked list not 1; %d", len(revokedList))
|
||||
}
|
||||
revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":")
|
||||
if revokedString != reqdata["serial_number"].(string) {
|
||||
t.Fatalf("got serial %s, expecting %s", revokedString, reqdata["serial_number"].(string))
|
||||
}
|
||||
delete(reqdata, "serial_number")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Do it all again, with EC keys and DER format
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.DeleteOperation,
|
||||
Path: "root",
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "root/generate/exported",
|
||||
Data: map[string]interface{}{
|
||||
"common_name": "Root Cert",
|
||||
"ttl": "180h",
|
||||
"key_type": "ec",
|
||||
"key_bits": 384,
|
||||
"format": "der",
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
certBytes, _ := base64.StdEncoding.DecodeString(resp.Data["certificate"].(string))
|
||||
certPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certBytes,
|
||||
})))
|
||||
keyBytes, _ := base64.StdEncoding.DecodeString(resp.Data["private_key"].(string))
|
||||
keyPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: keyBytes,
|
||||
})))
|
||||
intdata["root"] = certPem
|
||||
intdata["rootkey"] = keyPem
|
||||
reqdata["pem_bundle"] = strings.Join([]string{certPem, keyPem}, "\n")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "intermediate/generate/exported",
|
||||
Data: map[string]interface{}{
|
||||
"format": "der",
|
||||
"key_type": "ec",
|
||||
"key_bits": 384,
|
||||
"common_name": "intermediate.cert.com",
|
||||
},
|
||||
Check: func(resp *logical.Response) error {
|
||||
csrBytes, _ := base64.StdEncoding.DecodeString(resp.Data["csr"].(string))
|
||||
csrPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE REQUEST",
|
||||
Bytes: csrBytes,
|
||||
})))
|
||||
keyBytes, _ := base64.StdEncoding.DecodeString(resp.Data["private_key"].(string))
|
||||
keyPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: keyBytes,
|
||||
})))
|
||||
intdata["intermediatecsr"] = csrPem
|
||||
intdata["intermediatekey"] = keyPem
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/ca",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
delete(reqdata, "pem_bundle")
|
||||
delete(reqdata, "ttl")
|
||||
reqdata["csr"] = intdata["intermediatecsr"].(string)
|
||||
reqdata["common_name"] = "intermediate.cert.com"
|
||||
reqdata["ttl"] = "10s"
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "root/sign-intermediate",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
delete(reqdata, "csr")
|
||||
delete(reqdata, "common_name")
|
||||
delete(reqdata, "ttl")
|
||||
intdata["intermediatecert"] = resp.Data["certificate"].(string)
|
||||
reqdata["serial_number"] = resp.Data["serial_number"].(string)
|
||||
reqdata["ec_int_serial_number"] = resp.Data["serial_number"].(string)
|
||||
reqdata["certificate"] = resp.Data["certificate"].(string)
|
||||
reqdata["pem_bundle"] = strings.Join([]string{intdata["intermediatekey"].(string), resp.Data["certificate"].(string)}, "\n")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// First load in this way to populate the private key
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "config/ca",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
delete(reqdata, "pem_bundle")
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Now test setting the intermediate, signed CA cert
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "intermediate/set-signed",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
delete(reqdata, "certificate")
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// We expect to find a zero revocation time
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
|
||||
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
if resp.Data["revocation_time"].(int64) != 0 {
|
||||
return fmt.Errorf("expected a zero revocation time")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "revoke",
|
||||
Data: reqdata,
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "crl",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
crlBytes := resp.Data["http_raw_body"].([]byte)
|
||||
certList, err := x509.ParseCRL(crlBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
revokedList := certList.TBSCertList.RevokedCertificates
|
||||
if len(revokedList) != 2 {
|
||||
t.Fatalf("length of revoked list not 2; %d", len(revokedList))
|
||||
}
|
||||
found := false
|
||||
for _, revEntry := range revokedList {
|
||||
revokedString := certutil.GetHexFormatted(revEntry.SerialNumber.Bytes(), ":")
|
||||
if revokedString == reqdata["serial_number"].(string) {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("did not find %s in CRL", reqdata["serial_number"].(string))
|
||||
}
|
||||
delete(reqdata, "serial_number")
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Make sure both serial numbers we expect to find are found
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
|
||||
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
if resp.Data["revocation_time"].(int64) == 0 {
|
||||
return fmt.Errorf("expected a non-zero revocation time")
|
||||
}
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
|
||||
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
if resp.Data["revocation_time"].(int64) == 0 {
|
||||
return fmt.Errorf("expected a non-zero revocation time")
|
||||
}
|
||||
|
||||
// Give time for the certificates to pass the safety buffer
|
||||
t.Logf("Sleeping for 15 seconds to allow safety buffer time to pass before testing tidying")
|
||||
time.Sleep(15 * time.Second)
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// This shouldn't do anything since the safety buffer is too long
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "tidy",
|
||||
Data: map[string]interface{}{
|
||||
"safety_buffer": "3h",
|
||||
"tidy_cert_store": true,
|
||||
"tidy_revocation_list": true,
|
||||
},
|
||||
},
|
||||
|
||||
// We still expect to find these
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
|
||||
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
|
||||
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Both should appear in the CRL
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "crl",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
crlBytes := resp.Data["http_raw_body"].([]byte)
|
||||
certList, err := x509.ParseCRL(crlBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
revokedList := certList.TBSCertList.RevokedCertificates
|
||||
if len(revokedList) != 2 {
|
||||
t.Fatalf("length of revoked list not 2; %d", len(revokedList))
|
||||
}
|
||||
foundRsa := false
|
||||
foundEc := false
|
||||
for _, revEntry := range revokedList {
|
||||
revokedString := certutil.GetHexFormatted(revEntry.SerialNumber.Bytes(), ":")
|
||||
if revokedString == reqdata["rsa_int_serial_number"].(string) {
|
||||
foundRsa = true
|
||||
}
|
||||
if revokedString == reqdata["ec_int_serial_number"].(string) {
|
||||
foundEc = true
|
||||
}
|
||||
}
|
||||
if !foundRsa || !foundEc {
|
||||
t.Fatalf("did not find an expected entry in CRL")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// This shouldn't do anything since the boolean values default to false
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "tidy",
|
||||
Data: map[string]interface{}{
|
||||
"safety_buffer": "1s",
|
||||
},
|
||||
},
|
||||
|
||||
// We still expect to find these
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
|
||||
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
|
||||
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
|
||||
}
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// This should remove the values since the safety buffer is short
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "tidy",
|
||||
Data: map[string]interface{}{
|
||||
"safety_buffer": "1s",
|
||||
"tidy_cert_store": true,
|
||||
"tidy_revocation_list": true,
|
||||
},
|
||||
},
|
||||
|
||||
// We do *not* expect to find these
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp != nil {
|
||||
return fmt.Errorf("expected no response")
|
||||
}
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
PreFlight: setSerialUnderTest,
|
||||
Check: func(resp *logical.Response) error {
|
||||
if resp != nil {
|
||||
return fmt.Errorf("expected no response")
|
||||
}
|
||||
|
||||
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
||||
// Both should be gone from the CRL
|
||||
logicaltest.TestStep{
|
||||
Operation: logical.ReadOperation,
|
||||
Path: "crl",
|
||||
Data: reqdata,
|
||||
Check: func(resp *logical.Response) error {
|
||||
crlBytes := resp.Data["http_raw_body"].([]byte)
|
||||
certList, err := x509.ParseCRL(crlBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
revokedList := certList.TBSCertList.RevokedCertificates
|
||||
if len(revokedList) != 0 {
|
||||
t.Fatalf("length of revoked list not 0; %d", len(revokedList))
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// Generates steps to test out various role permutations
|
||||
func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
|
||||
roleVals := roleEntry{
|
||||
@@ -2141,7 +1404,7 @@ func TestBackend_PathFetchCertList(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b := Backend()
|
||||
b := Backend(config)
|
||||
err := b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -2268,7 +1531,7 @@ func TestBackend_SignVerbatim(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b := Backend()
|
||||
b := Backend(config)
|
||||
err := b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -2824,7 +2087,7 @@ func TestBackend_SignSelfIssued(t *testing.T) {
|
||||
storage := &logical.InmemStorage{}
|
||||
config.StorageView = storage
|
||||
|
||||
b := Backend()
|
||||
b := Backend(config)
|
||||
err := b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
570
builtin/logical/pki/ca_test.go
Normal file
570
builtin/logical/pki/ca_test.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package pki
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func TestBackend_CA_Steps(t *testing.T) {
|
||||
var b *backend
|
||||
|
||||
factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
be, err := Factory(ctx, conf)
|
||||
if err == nil {
|
||||
b = be.(*backend)
|
||||
}
|
||||
return be, err
|
||||
}
|
||||
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"pki": factory,
|
||||
},
|
||||
}
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
// Set RSA/EC CA certificates
|
||||
var rsaCAKey, rsaCACert, ecCAKey, ecCACert string
|
||||
{
|
||||
cak, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
marshaledKey, err := x509.MarshalECPrivateKey(cak)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
keyPEMBlock := &pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: marshaledKey,
|
||||
}
|
||||
ecCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
subjKeyID, err := certutil.GetSubjKeyID(cak)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caCertTemplate := &x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "root.localhost",
|
||||
},
|
||||
SubjectKeyId: subjKeyID,
|
||||
DNSNames: []string{"root.localhost"},
|
||||
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
|
||||
SerialNumber: big.NewInt(mathrand.Int63()),
|
||||
NotBefore: time.Now().Add(-30 * time.Second),
|
||||
NotAfter: time.Now().Add(262980 * time.Hour),
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, cak.Public(), cak)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caCertPEMBlock := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: caBytes,
|
||||
}
|
||||
ecCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock)))
|
||||
|
||||
rak, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
marshaledKey = x509.MarshalPKCS1PrivateKey(rak)
|
||||
keyPEMBlock = &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: marshaledKey,
|
||||
}
|
||||
rsaCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
subjKeyID, err = certutil.GetSubjKeyID(rak)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, rak.Public(), rak)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
caCertPEMBlock = &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: caBytes,
|
||||
}
|
||||
rsaCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock)))
|
||||
}
|
||||
|
||||
// Setup backends
|
||||
var rsaRoot, rsaInt, ecRoot, ecInt *backend
|
||||
{
|
||||
if err := client.Sys().Mount("rsaroot", &api.MountInput{
|
||||
Type: "pki",
|
||||
Config: api.MountConfigInput{
|
||||
DefaultLeaseTTL: "16h",
|
||||
MaxLeaseTTL: "60h",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rsaRoot = b
|
||||
|
||||
if err := client.Sys().Mount("rsaint", &api.MountInput{
|
||||
Type: "pki",
|
||||
Config: api.MountConfigInput{
|
||||
DefaultLeaseTTL: "16h",
|
||||
MaxLeaseTTL: "60h",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rsaInt = b
|
||||
|
||||
if err := client.Sys().Mount("ecroot", &api.MountInput{
|
||||
Type: "pki",
|
||||
Config: api.MountConfigInput{
|
||||
DefaultLeaseTTL: "16h",
|
||||
MaxLeaseTTL: "60h",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ecRoot = b
|
||||
|
||||
if err := client.Sys().Mount("ecint", &api.MountInput{
|
||||
Type: "pki",
|
||||
Config: api.MountConfigInput{
|
||||
DefaultLeaseTTL: "16h",
|
||||
MaxLeaseTTL: "60h",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ecInt = b
|
||||
}
|
||||
|
||||
t.Run("teststeps", func(t *testing.T) {
|
||||
t.Run("rsa", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
subClient, err := client.Clone()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
subClient.SetToken(client.Token())
|
||||
runSteps(t, rsaRoot, rsaInt, subClient, "rsaroot/", "rsaint/", rsaCACert, rsaCAKey)
|
||||
})
|
||||
t.Run("ec", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
subClient, err := client.Clone()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
subClient.SetToken(client.Token())
|
||||
runSteps(t, ecRoot, ecInt, subClient, "ecroot/", "ecint/", ecCACert, ecCAKey)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func runSteps(t *testing.T, rootB, intB *backend, client *api.Client, rootName, intName, caCert, caKey string) {
|
||||
// Load CA cert/key in and ensure we can fetch it back in various formats,
|
||||
// unauthenticated
|
||||
{
|
||||
// Attempt import but only provide one the cert
|
||||
{
|
||||
_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
|
||||
"pem_bundle": caCert,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
// Same but with only the key
|
||||
{
|
||||
_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
|
||||
"pem_bundle": caKey,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
// Import CA bundle
|
||||
{
|
||||
_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
|
||||
"pem_bundle": strings.Join([]string{caKey, caCert}, "\n"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
prevToken := client.Token()
|
||||
client.SetToken("")
|
||||
|
||||
// cert/ca path
|
||||
{
|
||||
resp, err := client.Logical().Read(rootName + "cert/ca")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if diff := deep.Equal(resp.Data["certificate"].(string), caCert); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
// ca/pem path (raw string)
|
||||
{
|
||||
req := &logical.Request{
|
||||
Path: "ca/pem",
|
||||
Operation: logical.ReadOperation,
|
||||
Storage: rootB.storage,
|
||||
}
|
||||
resp, err := rootB.HandleRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if diff := deep.Equal(resp.Data["http_raw_body"].([]byte), []byte(caCert)); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
|
||||
t.Fatal("wrong content type")
|
||||
}
|
||||
}
|
||||
|
||||
// ca (raw DER bytes)
|
||||
{
|
||||
req := &logical.Request{
|
||||
Path: "ca",
|
||||
Operation: logical.ReadOperation,
|
||||
Storage: rootB.storage,
|
||||
}
|
||||
resp, err := rootB.HandleRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
rawBytes := resp.Data["http_raw_body"].([]byte)
|
||||
pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: rawBytes,
|
||||
})))
|
||||
if diff := deep.Equal(pemBytes, caCert); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
|
||||
t.Fatal("wrong content type")
|
||||
}
|
||||
}
|
||||
|
||||
client.SetToken(prevToken)
|
||||
}
|
||||
|
||||
// Configure an expiry on the CRL and verify what comes back
|
||||
{
|
||||
// Set CRL config
|
||||
{
|
||||
_, err := client.Logical().Write(rootName+"config/crl", map[string]interface{}{
|
||||
"expiry": "16h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify it
|
||||
{
|
||||
resp, err := client.Logical().Read(rootName + "config/crl")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if resp.Data["expiry"].(string) != "16h" {
|
||||
t.Fatal("expected a 16 hour expiry")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test generating a root, an intermediate, signing it, setting signed, and
|
||||
// revoking it
|
||||
|
||||
// We'll need this later
|
||||
var intSerialNumber string
|
||||
{
|
||||
// First, delete the existing CA info
|
||||
{
|
||||
_, err := client.Logical().Delete(rootName + "root")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var rootPEM, rootKey, rootPEMBundle string
|
||||
// Test exported root generation
|
||||
{
|
||||
resp, err := client.Logical().Write(rootName+"root/generate/exported", map[string]interface{}{
|
||||
"common_name": "Root Cert",
|
||||
"ttl": "180h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
rootPEM = resp.Data["certificate"].(string)
|
||||
rootKey = resp.Data["private_key"].(string)
|
||||
rootPEMBundle = strings.Join([]string{rootPEM, rootKey}, "\n")
|
||||
// This is really here to keep the use checker happy
|
||||
if rootPEMBundle == "" {
|
||||
t.Fatal("bad root pem bundle")
|
||||
}
|
||||
}
|
||||
|
||||
var intPEM, intCSR, intKey string
|
||||
// Test exported intermediate CSR generation
|
||||
{
|
||||
resp, err := client.Logical().Write(intName+"intermediate/generate/exported", map[string]interface{}{
|
||||
"common_name": "intermediate.cert.com",
|
||||
"ttl": "180h",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
intCSR = resp.Data["csr"].(string)
|
||||
intKey = resp.Data["private_key"].(string)
|
||||
// This is really here to keep the use checker happy
|
||||
if intCSR == "" || intKey == "" {
|
||||
t.Fatal("int csr or key empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Test signing
|
||||
{
|
||||
resp, err := client.Logical().Write(rootName+"root/sign-intermediate", map[string]interface{}{
|
||||
"common_name": "intermediate.cert.com",
|
||||
"ttl": "10s",
|
||||
"csr": intCSR,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
intPEM = resp.Data["certificate"].(string)
|
||||
intSerialNumber = resp.Data["serial_number"].(string)
|
||||
}
|
||||
|
||||
// Test setting signed
|
||||
{
|
||||
resp, err := client.Logical().Write(intName+"intermediate/set-signed", map[string]interface{}{
|
||||
"certificate": intPEM,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we can find it via the root
|
||||
{
|
||||
resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if resp.Data["revocation_time"].(json.Number).String() != "0" {
|
||||
t.Fatal("expected a zero revocation time")
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke the intermediate
|
||||
{
|
||||
resp, err := client.Logical().Write(rootName+"revoke", map[string]interface{}{
|
||||
"serial_number": intSerialNumber,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verifyRevocation := func(t *testing.T, serial string, shouldFind bool) {
|
||||
// Verify it is now revoked
|
||||
{
|
||||
resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
switch shouldFind {
|
||||
case true:
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
if resp.Data["revocation_time"].(json.Number).String() == "0" {
|
||||
t.Fatal("expected a non-zero revocation time")
|
||||
}
|
||||
default:
|
||||
if resp != nil {
|
||||
t.Fatalf("expected nil response, got %#v", *resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch the CRL and make sure it shows up
|
||||
{
|
||||
req := &logical.Request{
|
||||
Path: "crl",
|
||||
Operation: logical.ReadOperation,
|
||||
Storage: rootB.storage,
|
||||
}
|
||||
resp, err := rootB.HandleRequest(context.Background(), req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("nil response")
|
||||
}
|
||||
crlBytes := resp.Data["http_raw_body"].([]byte)
|
||||
certList, err := x509.ParseCRL(crlBytes)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
switch shouldFind {
|
||||
case true:
|
||||
revokedList := certList.TBSCertList.RevokedCertificates
|
||||
if len(revokedList) != 1 {
|
||||
t.Fatalf("bad length of revoked list: %d", len(revokedList))
|
||||
}
|
||||
revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":")
|
||||
if revokedString != intSerialNumber {
|
||||
t.Fatalf("bad revoked serial: %s", revokedString)
|
||||
}
|
||||
default:
|
||||
revokedList := certList.TBSCertList.RevokedCertificates
|
||||
if len(revokedList) != 0 {
|
||||
t.Fatalf("bad length of revoked list: %d", len(revokedList))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate current state of revoked certificates
|
||||
verifyRevocation(t, intSerialNumber, true)
|
||||
|
||||
// Give time for the safety buffer to pass before tidying
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
// Test tidying
|
||||
{
|
||||
// Run with a high safety buffer, nothing should happen
|
||||
{
|
||||
resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
|
||||
"safety_buffer": "3h",
|
||||
"tidy_cert_store": true,
|
||||
"tidy_revocation_list": true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected warnings")
|
||||
}
|
||||
|
||||
// Wait a few seconds as it runs in a goroutine
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Check to make sure we still find the cert and see it on the CRL
|
||||
verifyRevocation(t, intSerialNumber, true)
|
||||
}
|
||||
|
||||
// Run with both values set false, nothing should happen
|
||||
{
|
||||
resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
|
||||
"safety_buffer": "1s",
|
||||
"tidy_cert_store": false,
|
||||
"tidy_revocation_list": false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected warnings")
|
||||
}
|
||||
|
||||
// Wait a few seconds as it runs in a goroutine
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Check to make sure we still find the cert and see it on the CRL
|
||||
verifyRevocation(t, intSerialNumber, true)
|
||||
}
|
||||
|
||||
// Run with a short safety buffer and both set to true, both should be cleared
|
||||
{
|
||||
resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
|
||||
"safety_buffer": "1s",
|
||||
"tidy_cert_store": true,
|
||||
"tidy_revocation_list": true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected warnings")
|
||||
}
|
||||
|
||||
// Wait a few seconds as it runs in a goroutine
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
// Check to make sure we still find the cert and see it on the CRL
|
||||
verifyRevocation(t, intSerialNumber, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
|
||||
config.StorageView = &logical.InmemStorage{}
|
||||
|
||||
var err error
|
||||
b := Backend()
|
||||
b := Backend(config)
|
||||
err = b.Setup(context.Background(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/errwrap"
|
||||
@@ -59,116 +60,134 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr
|
||||
|
||||
bufferDuration := time.Duration(safetyBuffer) * time.Second
|
||||
|
||||
var resp *logical.Response
|
||||
|
||||
if tidyCertStore {
|
||||
serials, err := req.Storage.List(ctx, "certs/")
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error fetching list of certs: {{err}}", err)
|
||||
}
|
||||
|
||||
for _, serial := range serials {
|
||||
certEntry, err := req.Storage.Get(ctx, "certs/"+serial)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error fetching certificate %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
if certEntry == nil {
|
||||
if resp == nil {
|
||||
resp = &logical.Response{}
|
||||
}
|
||||
resp.AddWarning(fmt.Sprintf("Certificate entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial))
|
||||
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting nil entry with serial %s: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
|
||||
if certEntry.Value == nil || len(certEntry.Value) == 0 {
|
||||
if resp == nil {
|
||||
resp = &logical.Response{}
|
||||
}
|
||||
resp.AddWarning(fmt.Sprintf("Certificate entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial))
|
||||
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting entry with nil value with serial %s: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certEntry.Value)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("unable to parse stored certificate with serial %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
if time.Now().After(cert.NotAfter.Add(bufferDuration)) {
|
||||
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from storage: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !atomic.CompareAndSwapUint32(b.tidyCASGuard, 0, 1) {
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation already in progress.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if tidyRevocationList {
|
||||
b.revokeStorageLock.Lock()
|
||||
defer b.revokeStorageLock.Unlock()
|
||||
|
||||
tidiedRevoked := false
|
||||
|
||||
revokedSerials, err := req.Storage.List(ctx, "revoked/")
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("error fetching list of revoked certs: {{err}}", err)
|
||||
}
|
||||
|
||||
var revInfo revocationInfo
|
||||
for _, serial := range revokedSerials {
|
||||
revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("unable to fetch revoked cert with serial %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
if revokedEntry == nil {
|
||||
if resp == nil {
|
||||
resp = &logical.Response{}
|
||||
}
|
||||
resp.AddWarning(fmt.Sprintf("Revoked entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial))
|
||||
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting nil revoked entry with serial %s: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
|
||||
if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 {
|
||||
if resp == nil {
|
||||
resp = &logical.Response{}
|
||||
}
|
||||
resp.AddWarning(fmt.Sprintf("Revoked entry for serial %s has nil value; tidying up since it is no longer useful for any server operations", serial))
|
||||
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting revoked entry with nil value with serial %s: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
|
||||
err = revokedEntry.DecodeJSON(&revInfo)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error decoding revocation entry for serial %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("unable to parse stored revoked certificate with serial %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) {
|
||||
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
|
||||
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from revoked list: {{err}}", serial), err)
|
||||
}
|
||||
tidiedRevoked = true
|
||||
}
|
||||
}
|
||||
|
||||
if tidiedRevoked {
|
||||
if err := buildCRL(ctx, b, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Tests using framework will screw up the storage so make a locally
|
||||
// scoped req to hold a reference
|
||||
req = &logical.Request{
|
||||
Storage: req.Storage,
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer atomic.StoreUint32(b.tidyCASGuard, 0)
|
||||
|
||||
// Don't cancel when the original client request goes away
|
||||
ctx = context.Background()
|
||||
|
||||
logger := b.Logger().Named("tidy")
|
||||
|
||||
doTidy := func() error {
|
||||
if tidyCertStore {
|
||||
serials, err := req.Storage.List(ctx, "certs/")
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error fetching list of certs: {{err}}", err)
|
||||
}
|
||||
|
||||
for _, serial := range serials {
|
||||
certEntry, err := req.Storage.Get(ctx, "certs/"+serial)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error fetching certificate %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
if certEntry == nil {
|
||||
logger.Warn("certificate entry is nil; tidying up since it is no longer useful for any server operations", "serial", serial)
|
||||
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting nil entry with serial %s: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
|
||||
if certEntry.Value == nil || len(certEntry.Value) == 0 {
|
||||
logger.Warn("certificate entry has no value; tidying up since it is no longer useful for any server operations", "serial", serial)
|
||||
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting entry with nil value with serial %s: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certEntry.Value)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("unable to parse stored certificate with serial %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
if time.Now().After(cert.NotAfter.Add(bufferDuration)) {
|
||||
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from storage: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tidyRevocationList {
|
||||
b.revokeStorageLock.Lock()
|
||||
defer b.revokeStorageLock.Unlock()
|
||||
|
||||
tidiedRevoked := false
|
||||
|
||||
revokedSerials, err := req.Storage.List(ctx, "revoked/")
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("error fetching list of revoked certs: {{err}}", err)
|
||||
}
|
||||
|
||||
var revInfo revocationInfo
|
||||
for _, serial := range revokedSerials {
|
||||
revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("unable to fetch revoked cert with serial %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
if revokedEntry == nil {
|
||||
logger.Warn("revoked entry is nil; tidying up since it is no longer useful for any server operations", "serial", serial)
|
||||
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting nil revoked entry with serial %s: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
|
||||
if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 {
|
||||
logger.Warn("revoked entry has nil value; tidying up since it is no longer useful for any server operations", "serial", serial)
|
||||
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting revoked entry with nil value with serial %s: {{err}}", serial), err)
|
||||
}
|
||||
}
|
||||
|
||||
err = revokedEntry.DecodeJSON(&revInfo)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error decoding revocation entry for serial %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("unable to parse stored revoked certificate with serial %q: {{err}}", serial), err)
|
||||
}
|
||||
|
||||
if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) {
|
||||
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
|
||||
return errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from revoked list: {{err}}", serial), err)
|
||||
}
|
||||
tidiedRevoked = true
|
||||
}
|
||||
}
|
||||
|
||||
if tidiedRevoked {
|
||||
if err := buildCRL(ctx, b, req); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := doTidy(); err != nil {
|
||||
logger.Error("error running tidy", "error", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -939,7 +939,10 @@ CLUSTER_SYNTHESIS_COMPLETE:
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: handler,
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
}
|
||||
go server.Serve(ln.Listener)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
proxyproto "github.com/armon/go-proxyproto"
|
||||
"github.com/hashicorp/errwrap"
|
||||
@@ -41,12 +42,14 @@ func WrapInProxyProto(listener net.Listener, config *ProxyProtoConfig) (net.List
|
||||
switch config.Behavior {
|
||||
case "use_always":
|
||||
newLn = &proxyproto.Listener{
|
||||
Listener: listener,
|
||||
Listener: listener,
|
||||
ProxyHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
case "allow_authorized", "deny_unauthorized":
|
||||
newLn = &proxyproto.Listener{
|
||||
Listener: listener,
|
||||
Listener: listener,
|
||||
ProxyHeaderTimeout: 10 * time.Second,
|
||||
SourceCheck: func(addr net.Addr) (bool, error) {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-test/deep"
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/vault/helper/logging"
|
||||
"github.com/hashicorp/vault/physical"
|
||||
@@ -106,8 +106,8 @@ func TestDynamoDBBackend(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
if !reflect.DeepEqual(inputEntry, entry) {
|
||||
t.Fatalf("exp: %#v, act: %#v", inputEntry, entry)
|
||||
if diff := deep.Equal(inputEntry, entry); diff != nil {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -285,7 +285,7 @@ func prepareDynamoDBTestContainer(t *testing.T) (cleanup func(), retAddress stri
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("deangiberson/aws-dynamodb-local", "latest", []string{})
|
||||
resource, err := pool.Run("cnadiminti/dynamodb-local", "latest", []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local DynamoDB: %s", err)
|
||||
}
|
||||
|
||||
@@ -191,15 +191,17 @@ func (m *ExpirationManager) Tidy() error {
|
||||
|
||||
var tidyErrors *multierror.Error
|
||||
|
||||
logger := m.logger.Named("tidy")
|
||||
|
||||
if !atomic.CompareAndSwapInt32(m.tidyLock, 0, 1) {
|
||||
m.logger.Warn("tidy operation on leases is already in progress")
|
||||
return fmt.Errorf("tidy operation on leases is already in progress")
|
||||
logger.Warn("tidy operation on leases is already in progress")
|
||||
return nil
|
||||
}
|
||||
|
||||
defer atomic.CompareAndSwapInt32(m.tidyLock, 1, 0)
|
||||
|
||||
m.logger.Info("beginning tidy operation on leases")
|
||||
defer m.logger.Info("finished tidy operation on leases")
|
||||
logger.Info("beginning tidy operation on leases")
|
||||
defer logger.Info("finished tidy operation on leases")
|
||||
|
||||
// Create a cache to keep track of looked up tokens
|
||||
tokenCache := make(map[string]bool)
|
||||
@@ -208,7 +210,7 @@ func (m *ExpirationManager) Tidy() error {
|
||||
tidyFunc := func(leaseID string) {
|
||||
countLease++
|
||||
if countLease%500 == 0 {
|
||||
m.logger.Info("tidying leases", "progress", countLease)
|
||||
logger.Info("tidying leases", "progress", countLease)
|
||||
}
|
||||
|
||||
le, err := m.loadEntry(leaseID)
|
||||
@@ -225,7 +227,7 @@ func (m *ExpirationManager) Tidy() error {
|
||||
var isValid, ok bool
|
||||
revokeLease := false
|
||||
if le.ClientToken == "" {
|
||||
m.logger.Debug("revoking lease which has an empty token", "lease_id", leaseID)
|
||||
logger.Debug("revoking lease which has an empty token", "lease_id", leaseID)
|
||||
revokeLease = true
|
||||
deletedCountEmptyToken++
|
||||
goto REVOKE_CHECK
|
||||
@@ -249,7 +251,7 @@ func (m *ExpirationManager) Tidy() error {
|
||||
}
|
||||
|
||||
if te == nil {
|
||||
m.logger.Debug("revoking lease which holds an invalid token", "lease_id", leaseID)
|
||||
logger.Debug("revoking lease which holds an invalid token", "lease_id", leaseID)
|
||||
revokeLease = true
|
||||
deletedCountInvalidToken++
|
||||
tokenCache[le.ClientToken] = false
|
||||
@@ -262,7 +264,7 @@ func (m *ExpirationManager) Tidy() error {
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Debug("revoking lease which contains an invalid token", "lease_id", leaseID)
|
||||
logger.Debug("revoking lease which contains an invalid token", "lease_id", leaseID)
|
||||
revokeLease = true
|
||||
deletedCountInvalidToken++
|
||||
goto REVOKE_CHECK
|
||||
@@ -285,10 +287,10 @@ func (m *ExpirationManager) Tidy() error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.logger.Info("number of leases scanned", "count", countLease)
|
||||
m.logger.Info("number of leases which had empty tokens", "count", deletedCountEmptyToken)
|
||||
m.logger.Info("number of leases which had invalid tokens", "count", deletedCountInvalidToken)
|
||||
m.logger.Info("number of leases successfully revoked", "count", revokedCount)
|
||||
logger.Info("number of leases scanned", "count", countLease)
|
||||
logger.Info("number of leases which had empty tokens", "count", deletedCountEmptyToken)
|
||||
logger.Info("number of leases which had invalid tokens", "count", deletedCountInvalidToken)
|
||||
logger.Info("number of leases successfully revoked", "count", revokedCount)
|
||||
|
||||
return tidyErrors.ErrorOrNil()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
@@ -38,6 +39,14 @@ func TestExpiration_Tidy(t *testing.T) {
|
||||
var err error
|
||||
|
||||
exp := mockExpiration(t)
|
||||
|
||||
// We use this later for tidy testing where we need to check the output
|
||||
logOut := new(bytes.Buffer)
|
||||
logger := log.New(&log.LoggerOptions{
|
||||
Output: logOut,
|
||||
})
|
||||
exp.logger = logger
|
||||
|
||||
if err := exp.Restore(nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -212,9 +221,11 @@ func TestExpiration_Tidy(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if !(err1 != nil && err1.Error() == "tidy operation on leases is already in progress") &&
|
||||
!(err2 != nil && err2.Error() == "tidy operation on leases is already in progress") {
|
||||
t.Fatalf("expected at least one of err1 or err2 to be set; err1: %#v\n err2:%#v\n", err1, err2)
|
||||
if err1 != nil || err2 != nil {
|
||||
t.Fatalf("got an error: err1: %v; err2: %v", err1, err2)
|
||||
}
|
||||
if !strings.Contains(logOut.String(), "tidy operation on leases is already in progress") {
|
||||
t.Fatalf("expected to see a warning saying operation in progress, output is %s", logOut.String())
|
||||
}
|
||||
|
||||
root, err := exp.tokenStore.rootToken(context.Background())
|
||||
|
||||
@@ -1182,12 +1182,17 @@ func (b *SystemBackend) handleCORSDelete(ctx context.Context, req *logical.Reque
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handleTidyLeases(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
err := b.Core.expiration.Tidy()
|
||||
if err != nil {
|
||||
b.Backend.Logger().Error("failed to tidy leases", "error", err)
|
||||
return handleErrorNoReadOnlyForward(err)
|
||||
}
|
||||
return nil, err
|
||||
go func() {
|
||||
err := b.Core.expiration.Tidy()
|
||||
if err != nil {
|
||||
b.Backend.Logger().Error("failed to tidy leases", "error", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) invalidate(ctx context.Context, key string) {
|
||||
|
||||
@@ -76,7 +76,11 @@ func (c *Core) startForwarding(ctx context.Context) error {
|
||||
// duties. Doing it this way instead of listening via the server and gRPC
|
||||
// allows us to re-use the same port via ALPN. We can just tell the server
|
||||
// to serve a given conn and which handler to use.
|
||||
fws := &http2.Server{}
|
||||
fws := &http2.Server{
|
||||
// Our forwarding connections heartbeat regularly so anything else we
|
||||
// want to go away/get cleaned up pretty rapidly
|
||||
IdleTimeout: 5 * HeartbeatInterval,
|
||||
}
|
||||
|
||||
// Shutdown coordination logic
|
||||
shutdown := new(uint32)
|
||||
@@ -147,6 +151,20 @@ func (c *Core) startForwarding(ctx context.Context) error {
|
||||
// Type assert to TLS connection and handshake to populate the
|
||||
// connection state
|
||||
tlsConn := conn.(*tls.Conn)
|
||||
|
||||
// Set a deadline for the handshake. This will cause clients
|
||||
// that don't successfully auth to be kicked out quickly.
|
||||
// Cluster connections should be reliable so being marginally
|
||||
// aggressive here is fine.
|
||||
err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
|
||||
if err != nil {
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debug("error setting deadline for cluster connection", "error", err)
|
||||
}
|
||||
tlsConn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
err = tlsConn.Handshake()
|
||||
if err != nil {
|
||||
if c.logger.IsDebug() {
|
||||
@@ -156,6 +174,16 @@ func (c *Core) startForwarding(ctx context.Context) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Now, set it back to unlimited
|
||||
err = tlsConn.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debug("error setting deadline for cluster connection", "error", err)
|
||||
}
|
||||
tlsConn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
switch tlsConn.ConnectionState().NegotiatedProtocol {
|
||||
case requestForwardingALPN:
|
||||
if !ha {
|
||||
|
||||
@@ -130,7 +130,7 @@ type TokenStore struct {
|
||||
saltLock sync.RWMutex
|
||||
salt *salt.Salt
|
||||
|
||||
tidyLock *int32
|
||||
tidyLock *uint32
|
||||
|
||||
identityPoliciesDeriverFunc func(string) (*identity.Entity, []string, error)
|
||||
}
|
||||
@@ -150,7 +150,7 @@ func NewTokenStore(ctx context.Context, logger log.Logger, c *Core, config *logi
|
||||
tokensPendingDeletion: &sync.Map{},
|
||||
saltLock: sync.RWMutex{},
|
||||
identityPoliciesDeriverFunc: c.fetchEntityAndDerivedPolicies,
|
||||
tidyLock: new(int32),
|
||||
tidyLock: new(uint32),
|
||||
}
|
||||
|
||||
if c.policyStore != nil {
|
||||
@@ -1290,204 +1290,224 @@ func (ts *TokenStore) lookupBySaltedAccessor(ctx context.Context, saltedAccessor
|
||||
// handleTidy handles the cleaning up of leaked accessor storage entries and
|
||||
// cleaning up of leases that are associated to tokens that are expired.
|
||||
func (ts *TokenStore) handleTidy(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
var tidyErrors *multierror.Error
|
||||
|
||||
if !atomic.CompareAndSwapInt32(ts.tidyLock, 0, 1) {
|
||||
ts.logger.Warn("tidy operation on tokens is already in progress")
|
||||
return nil, fmt.Errorf("tidy operation on tokens is already in progress")
|
||||
if !atomic.CompareAndSwapUint32(ts.tidyLock, 0, 1) {
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation already in progress.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
defer atomic.CompareAndSwapInt32(ts.tidyLock, 1, 0)
|
||||
go func() {
|
||||
defer atomic.StoreUint32(ts.tidyLock, 0)
|
||||
|
||||
ts.logger.Info("beginning tidy operation on tokens")
|
||||
defer ts.logger.Info("finished tidy operation on tokens")
|
||||
// Don't cancel when the original client request goes away
|
||||
ctx = context.Background()
|
||||
|
||||
// List out all the accessors
|
||||
saltedAccessorList, err := ts.view.List(ctx, accessorPrefix)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed to fetch accessor index entries: {{err}}", err)
|
||||
}
|
||||
logger := ts.logger.Named("tidy")
|
||||
|
||||
// First, clean up secondary index entries that are no longer valid
|
||||
parentList, err := ts.view.List(ctx, parentPrefix)
|
||||
if err != nil {
|
||||
return nil, errwrap.Wrapf("failed to fetch secondary index entries: {{err}}", err)
|
||||
}
|
||||
var tidyErrors *multierror.Error
|
||||
|
||||
var countParentEntries, deletedCountParentEntries, countParentList, deletedCountParentList int64
|
||||
doTidy := func() error {
|
||||
|
||||
// Scan through the secondary index entries; if there is an entry
|
||||
// with the token's salt ID at the end, remove it
|
||||
for _, parent := range parentList {
|
||||
countParentEntries++
|
||||
ts.logger.Info("beginning tidy operation on tokens")
|
||||
defer ts.logger.Info("finished tidy operation on tokens")
|
||||
|
||||
// Get the children
|
||||
children, err := ts.view.List(ctx, parentPrefix+parent)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read secondary index: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
|
||||
// First check if the salt ID of the parent exists, and if not mark this so
|
||||
// that deletion of children later with this loop below applies to all
|
||||
// children
|
||||
originalChildrenCount := int64(len(children))
|
||||
exists, _ := ts.lookupSalted(ctx, strings.TrimSuffix(parent, "/"), true)
|
||||
if exists == nil {
|
||||
ts.logger.Debug("deleting invalid parent prefix entry", "index", parentPrefix+parent)
|
||||
}
|
||||
|
||||
var deletedChildrenCount int64
|
||||
for _, child := range children {
|
||||
countParentList++
|
||||
if countParentList%500 == 0 {
|
||||
ts.logger.Info("checking validity of tokens in secondary index list", "progress", countParentList)
|
||||
// List out all the accessors
|
||||
saltedAccessorList, err := ts.view.List(ctx, accessorPrefix)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to fetch accessor index entries: {{err}}", err)
|
||||
}
|
||||
|
||||
// Look up tainted entries so we can be sure that if this isn't
|
||||
// found, it doesn't exist. Doing the following without locking
|
||||
// since appropriate locks cannot be held with salted token IDs.
|
||||
// Also perform deletion if the parent doesn't exist any more.
|
||||
te, _ := ts.lookupSalted(ctx, child, true)
|
||||
// If the child entry is not nil, but the parent doesn't exist, then turn
|
||||
// that child token into an orphan token. Theres no deletion in this case.
|
||||
if te != nil && exists == nil {
|
||||
lock := locksutil.LockForKey(ts.tokenLocks, te.ID)
|
||||
lock.Lock()
|
||||
|
||||
te.Parent = ""
|
||||
err = ts.store(ctx, te)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to convert child token into an orphan token: {{err}}", err))
|
||||
}
|
||||
lock.Unlock()
|
||||
continue
|
||||
// First, clean up secondary index entries that are no longer valid
|
||||
parentList, err := ts.view.List(ctx, parentPrefix)
|
||||
if err != nil {
|
||||
return errwrap.Wrapf("failed to fetch secondary index entries: {{err}}", err)
|
||||
}
|
||||
// Otherwise, if the entry doesn't exist, or if the parent doesn't exist go
|
||||
// on with the delete on the secondary index
|
||||
if te == nil || exists == nil {
|
||||
index := parentPrefix + parent + child
|
||||
ts.logger.Debug("deleting invalid secondary index", "index", index)
|
||||
err = ts.view.Delete(ctx, index)
|
||||
|
||||
var countParentEntries, deletedCountParentEntries, countParentList, deletedCountParentList int64
|
||||
|
||||
// Scan through the secondary index entries; if there is an entry
|
||||
// with the token's salt ID at the end, remove it
|
||||
for _, parent := range parentList {
|
||||
countParentEntries++
|
||||
|
||||
// Get the children
|
||||
children, err := ts.view.List(ctx, parentPrefix+parent)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete secondary index: {{err}}", err))
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read secondary index: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
deletedChildrenCount++
|
||||
}
|
||||
}
|
||||
// Add current children deleted count to the total count
|
||||
deletedCountParentList += deletedChildrenCount
|
||||
// N.B.: We don't call delete on the parent prefix since physical.Backend.Delete
|
||||
// implementations should be in charge of deleting empty prefixes.
|
||||
// If we deleted all the children, then add that to our deleted parent entries count.
|
||||
if originalChildrenCount == deletedChildrenCount {
|
||||
deletedCountParentEntries++
|
||||
}
|
||||
}
|
||||
|
||||
var countAccessorList,
|
||||
deletedCountAccessorEmptyToken,
|
||||
deletedCountAccessorInvalidToken,
|
||||
deletedCountInvalidTokenInAccessor int64
|
||||
// First check if the salt ID of the parent exists, and if not mark this so
|
||||
// that deletion of children later with this loop below applies to all
|
||||
// children
|
||||
originalChildrenCount := int64(len(children))
|
||||
exists, _ := ts.lookupSalted(ctx, strings.TrimSuffix(parent, "/"), true)
|
||||
if exists == nil {
|
||||
ts.logger.Debug("deleting invalid parent prefix entry", "index", parentPrefix+parent)
|
||||
}
|
||||
|
||||
// For each of the accessor, see if the token ID associated with it is
|
||||
// a valid one. If not, delete the leases associated with that token
|
||||
// and delete the accessor as well.
|
||||
for _, saltedAccessor := range saltedAccessorList {
|
||||
countAccessorList++
|
||||
if countAccessorList%500 == 0 {
|
||||
ts.logger.Info("checking if accessors contain valid tokens", "progress", countAccessorList)
|
||||
}
|
||||
var deletedChildrenCount int64
|
||||
for _, child := range children {
|
||||
countParentList++
|
||||
if countParentList%500 == 0 {
|
||||
ts.logger.Info("checking validity of tokens in secondary index list", "progress", countParentList)
|
||||
}
|
||||
|
||||
accessorEntry, err := ts.lookupBySaltedAccessor(ctx, saltedAccessor, true)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read the accessor index: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
// Look up tainted entries so we can be sure that if this isn't
|
||||
// found, it doesn't exist. Doing the following without locking
|
||||
// since appropriate locks cannot be held with salted token IDs.
|
||||
// Also perform deletion if the parent doesn't exist any more.
|
||||
te, _ := ts.lookupSalted(ctx, child, true)
|
||||
// If the child entry is not nil, but the parent doesn't exist, then turn
|
||||
// that child token into an orphan token. Theres no deletion in this case.
|
||||
if te != nil && exists == nil {
|
||||
lock := locksutil.LockForKey(ts.tokenLocks, te.ID)
|
||||
lock.Lock()
|
||||
|
||||
// A valid accessor storage entry should always have a token ID
|
||||
// in it. If not, it is an invalid accessor entry and needs to
|
||||
// be deleted.
|
||||
if accessorEntry.TokenID == "" {
|
||||
index := accessorPrefix + saltedAccessor
|
||||
// If deletion of accessor fails, move on to the next
|
||||
// item since this is just a best-effort operation
|
||||
err = ts.view.Delete(ctx, index)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete the accessor index: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
deletedCountAccessorEmptyToken++
|
||||
}
|
||||
|
||||
lock := locksutil.LockForKey(ts.tokenLocks, accessorEntry.TokenID)
|
||||
lock.RLock()
|
||||
|
||||
// Look up tainted variants so we only find entries that truly don't
|
||||
// exist
|
||||
saltedID, err := ts.SaltID(ctx, accessorEntry.TokenID)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read salt id: {{err}}", err))
|
||||
lock.RUnlock()
|
||||
continue
|
||||
}
|
||||
te, err := ts.lookupSalted(ctx, saltedID, true)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to lookup tainted ID: {{err}}", err))
|
||||
lock.RUnlock()
|
||||
continue
|
||||
}
|
||||
|
||||
lock.RUnlock()
|
||||
|
||||
// If token entry is not found assume that the token is not valid any
|
||||
// more and conclude that accessor, leases, and secondary index entries
|
||||
// for this token should not exist as well.
|
||||
if te == nil {
|
||||
ts.logger.Info("deleting token with nil entry", "salted_token", saltedID)
|
||||
|
||||
// RevokeByToken expects a '*logical.TokenEntry'. For the
|
||||
// purposes of tidying, it is sufficient if the token
|
||||
// entry only has ID set.
|
||||
tokenEntry := &logical.TokenEntry{
|
||||
ID: accessorEntry.TokenID,
|
||||
te.Parent = ""
|
||||
err = ts.store(ctx, te)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to convert child token into an orphan token: {{err}}", err))
|
||||
}
|
||||
lock.Unlock()
|
||||
continue
|
||||
}
|
||||
// Otherwise, if the entry doesn't exist, or if the parent doesn't exist go
|
||||
// on with the delete on the secondary index
|
||||
if te == nil || exists == nil {
|
||||
index := parentPrefix + parent + child
|
||||
ts.logger.Debug("deleting invalid secondary index", "index", index)
|
||||
err = ts.view.Delete(ctx, index)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete secondary index: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
deletedChildrenCount++
|
||||
}
|
||||
}
|
||||
// Add current children deleted count to the total count
|
||||
deletedCountParentList += deletedChildrenCount
|
||||
// N.B.: We don't call delete on the parent prefix since physical.Backend.Delete
|
||||
// implementations should be in charge of deleting empty prefixes.
|
||||
// If we deleted all the children, then add that to our deleted parent entries count.
|
||||
if originalChildrenCount == deletedChildrenCount {
|
||||
deletedCountParentEntries++
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to revoke the token. This will also revoke
|
||||
// the leases associated with the token.
|
||||
err := ts.expiration.RevokeByToken(tokenEntry)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to revoke leases of expired token: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
deletedCountInvalidTokenInAccessor++
|
||||
var countAccessorList,
|
||||
deletedCountAccessorEmptyToken,
|
||||
deletedCountAccessorInvalidToken,
|
||||
deletedCountInvalidTokenInAccessor int64
|
||||
|
||||
index := accessorPrefix + saltedAccessor
|
||||
// For each of the accessor, see if the token ID associated with it is
|
||||
// a valid one. If not, delete the leases associated with that token
|
||||
// and delete the accessor as well.
|
||||
for _, saltedAccessor := range saltedAccessorList {
|
||||
countAccessorList++
|
||||
if countAccessorList%500 == 0 {
|
||||
ts.logger.Info("checking if accessors contain valid tokens", "progress", countAccessorList)
|
||||
}
|
||||
|
||||
// If deletion of accessor fails, move on to the next item since
|
||||
// this is just a best-effort operation. We do this last so that on
|
||||
// next run if something above failed we still have the accessor
|
||||
// entry to try again.
|
||||
err = ts.view.Delete(ctx, index)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete accessor entry: {{err}}", err))
|
||||
continue
|
||||
accessorEntry, err := ts.lookupBySaltedAccessor(ctx, saltedAccessor, true)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read the accessor index: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
|
||||
// A valid accessor storage entry should always have a token ID
|
||||
// in it. If not, it is an invalid accessor entry and needs to
|
||||
// be deleted.
|
||||
if accessorEntry.TokenID == "" {
|
||||
index := accessorPrefix + saltedAccessor
|
||||
// If deletion of accessor fails, move on to the next
|
||||
// item since this is just a best-effort operation
|
||||
err = ts.view.Delete(ctx, index)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete the accessor index: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
deletedCountAccessorEmptyToken++
|
||||
}
|
||||
|
||||
lock := locksutil.LockForKey(ts.tokenLocks, accessorEntry.TokenID)
|
||||
lock.RLock()
|
||||
|
||||
// Look up tainted variants so we only find entries that truly don't
|
||||
// exist
|
||||
saltedID, err := ts.SaltID(ctx, accessorEntry.TokenID)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read salt id: {{err}}", err))
|
||||
lock.RUnlock()
|
||||
continue
|
||||
}
|
||||
te, err := ts.lookupSalted(ctx, saltedID, true)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to lookup tainted ID: {{err}}", err))
|
||||
lock.RUnlock()
|
||||
continue
|
||||
}
|
||||
|
||||
lock.RUnlock()
|
||||
|
||||
// If token entry is not found assume that the token is not valid any
|
||||
// more and conclude that accessor, leases, and secondary index entries
|
||||
// for this token should not exist as well.
|
||||
if te == nil {
|
||||
ts.logger.Info("deleting token with nil entry", "salted_token", saltedID)
|
||||
|
||||
// RevokeByToken expects a '*logical.TokenEntry'. For the
|
||||
// purposes of tidying, it is sufficient if the token
|
||||
// entry only has ID set.
|
||||
tokenEntry := &logical.TokenEntry{
|
||||
ID: accessorEntry.TokenID,
|
||||
}
|
||||
|
||||
// Attempt to revoke the token. This will also revoke
|
||||
// the leases associated with the token.
|
||||
err := ts.expiration.RevokeByToken(tokenEntry)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to revoke leases of expired token: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
deletedCountInvalidTokenInAccessor++
|
||||
|
||||
index := accessorPrefix + saltedAccessor
|
||||
|
||||
// If deletion of accessor fails, move on to the next item since
|
||||
// this is just a best-effort operation. We do this last so that on
|
||||
// next run if something above failed we still have the accessor
|
||||
// entry to try again.
|
||||
err = ts.view.Delete(ctx, index)
|
||||
if err != nil {
|
||||
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete accessor entry: {{err}}", err))
|
||||
continue
|
||||
}
|
||||
deletedCountAccessorInvalidToken++
|
||||
}
|
||||
}
|
||||
deletedCountAccessorInvalidToken++
|
||||
|
||||
ts.logger.Info("number of entries scanned in parent prefix", "count", countParentEntries)
|
||||
ts.logger.Info("number of entries deleted in parent prefix", "count", deletedCountParentEntries)
|
||||
ts.logger.Info("number of tokens scanned in parent index list", "count", countParentList)
|
||||
ts.logger.Info("number of tokens revoked in parent index list", "count", deletedCountParentList)
|
||||
ts.logger.Info("number of accessors scanned", "count", countAccessorList)
|
||||
ts.logger.Info("number of deleted accessors which had empty tokens", "count", deletedCountAccessorEmptyToken)
|
||||
ts.logger.Info("number of revoked tokens which were invalid but present in accessors", "count", deletedCountInvalidTokenInAccessor)
|
||||
ts.logger.Info("number of deleted accessors which had invalid tokens", "count", deletedCountAccessorInvalidToken)
|
||||
|
||||
return tidyErrors.ErrorOrNil()
|
||||
}
|
||||
}
|
||||
|
||||
ts.logger.Info("number of entries scanned in parent prefix", "count", countParentEntries)
|
||||
ts.logger.Info("number of entries deleted in parent prefix", "count", deletedCountParentEntries)
|
||||
ts.logger.Info("number of tokens scanned in parent index list", "count", countParentList)
|
||||
ts.logger.Info("number of tokens revoked in parent index list", "count", deletedCountParentList)
|
||||
ts.logger.Info("number of accessors scanned", "count", countAccessorList)
|
||||
ts.logger.Info("number of deleted accessors which had empty tokens", "count", deletedCountAccessorEmptyToken)
|
||||
ts.logger.Info("number of revoked tokens which were invalid but present in accessors", "count", deletedCountInvalidTokenInAccessor)
|
||||
ts.logger.Info("number of deleted accessors which had invalid tokens", "count", deletedCountAccessorInvalidToken)
|
||||
if err := doTidy(); err != nil {
|
||||
logger.Error("error running tidy", "error", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return nil, tidyErrors.ErrorOrNil()
|
||||
resp := &logical.Response{}
|
||||
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// handleUpdateLookupAccessor handles the auth/token/lookup-accessor path for returning
|
||||
|
||||
@@ -3777,6 +3777,9 @@ func TestTokenStore_HandleTidyCase1(t *testing.T) {
|
||||
t.Fatalf("err:%v resp:%v", err, resp)
|
||||
}
|
||||
|
||||
// Tidy runs async so give it time
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
// Tidy should have removed all the dangling accessor entries
|
||||
resp, err = ts.HandleRequest(context.Background(), accessorListReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
@@ -3909,6 +3912,9 @@ func TestTokenStore_HandleTidy_parentCleanup(t *testing.T) {
|
||||
t.Fatalf("err:%v resp:%v", err, resp)
|
||||
}
|
||||
|
||||
// Tidy runs async so give it time
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
// Tidy should have removed all the dangling accessor entries
|
||||
resp, err = ts.HandleRequest(context.Background(), accessorListReq)
|
||||
if err != nil || (resp != nil && resp.IsError()) {
|
||||
|
||||
Reference in New Issue
Block a user