Add an idle timeout for the server (#4760)

* Add an idle timeout for the server

Because tidy operations can be long-running, this also changes all tidy
operations to behave the same operationally (kick off the process, get a
warning back, log errors to server log) and makes them all run in a
goroutine.

This could mean a sort of hard stop if Vault gets sealed because the
function won't have the read lock. This should generally be okay
(running tidy again should pick back up where it left off), but future
work could use cleanup funcs to trigger the functions to stop.

* Fix up tidy test

* Add deadline to cluster connections and an idle timeout to the cluster server, plus add readheader/read timeout to api server
This commit is contained in:
Jeff Mitchell
2018-06-16 18:21:33 -04:00
committed by GitHub
parent 43e218e5b1
commit f493d2436e
18 changed files with 1212 additions and 1223 deletions

View File

@@ -26,141 +26,153 @@ func pathTidySecretID(b *backend) *framework.Path {
}
// tidySecretID is used to delete entries in the whitelist that are expired.
func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error {
grabbed := atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1)
if grabbed {
defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0)
} else {
return fmt.Errorf("SecretID tidy operation already running")
func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) (*logical.Response, error) {
if !atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1) {
resp := &logical.Response{}
resp.AddWarning("Tidy operation already in progress.")
return resp, nil
}
var result error
go func() {
defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0)
tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error {
roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse)
if err != nil {
return err
}
var result error
// List all the accessors and add them all to a map
accessorHashes, err := s.List(ctx, accessorIDPrefixToUse)
if err != nil {
return err
}
accessorMap := make(map[string]bool, len(accessorHashes))
for _, accessorHash := range accessorHashes {
accessorMap[accessorHash] = true
}
// Don't cancel when the original client request goes away
ctx = context.Background()
secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error {
lock := b.secretIDLock(secretIDHMAC)
lock.Lock()
defer lock.Unlock()
logger := b.Logger().Named("tidy")
entryIndex := fmt.Sprintf("%s%s%s", secretIDPrefixToUse, roleNameHMAC, secretIDHMAC)
secretIDEntry, err := s.Get(ctx, entryIndex)
tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error {
roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error fetching SecretID %q: {{err}}", secretIDHMAC), err)
}
if secretIDEntry == nil {
result = multierror.Append(result, fmt.Errorf("entry for SecretID %q is nil", secretIDHMAC))
return nil
}
if secretIDEntry.Value == nil || len(secretIDEntry.Value) == 0 {
return fmt.Errorf("found entry for SecretID %q but actual SecretID is empty", secretIDHMAC)
}
var result secretIDStorageEntry
if err := secretIDEntry.DecodeJSON(&result); err != nil {
return err
}
// If a secret ID entry does not have a corresponding accessor
// entry, revoke the secret ID immediately
accessorEntry, err := b.secretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
// List all the accessors and add them all to a map
accessorHashes, err := s.List(ctx, accessorIDPrefixToUse)
if err != nil {
return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err)
return err
}
if accessorEntry == nil {
if err := s.Delete(ctx, entryIndex); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err)
}
return nil
accessorMap := make(map[string]bool, len(accessorHashes))
for _, accessorHash := range accessorHashes {
accessorMap[accessorHash] = true
}
// ExpirationTime not being set indicates non-expiring SecretIDs
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
// Clean up the accessor of the secret ID first
err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error {
lock := b.secretIDLock(secretIDHMAC)
lock.Lock()
defer lock.Unlock()
entryIndex := fmt.Sprintf("%s%s%s", secretIDPrefixToUse, roleNameHMAC, secretIDHMAC)
secretIDEntry, err := s.Get(ctx, entryIndex)
if err != nil {
return errwrap.Wrapf("failed to delete secret ID accessor entry: {{err}}", err)
return errwrap.Wrapf(fmt.Sprintf("error fetching SecretID %q: {{err}}", secretIDHMAC), err)
}
if err := s.Delete(ctx, entryIndex); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting SecretID %q from storage: {{err}}", secretIDHMAC), err)
if secretIDEntry == nil {
result = multierror.Append(result, fmt.Errorf("entry for SecretID %q is nil", secretIDHMAC))
return nil
}
if secretIDEntry.Value == nil || len(secretIDEntry.Value) == 0 {
return fmt.Errorf("found entry for SecretID %q but actual SecretID is empty", secretIDHMAC)
}
var result secretIDStorageEntry
if err := secretIDEntry.DecodeJSON(&result); err != nil {
return err
}
// If a secret ID entry does not have a corresponding accessor
// entry, revoke the secret ID immediately
accessorEntry, err := b.secretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
if err != nil {
return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err)
}
if accessorEntry == nil {
if err := s.Delete(ctx, entryIndex); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err)
}
return nil
}
// ExpirationTime not being set indicates non-expiring SecretIDs
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
// Clean up the accessor of the secret ID first
err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse)
if err != nil {
return errwrap.Wrapf("failed to delete secret ID accessor entry: {{err}}", err)
}
if err := s.Delete(ctx, entryIndex); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting SecretID %q from storage: {{err}}", secretIDHMAC), err)
}
return nil
}
// At this point, the secret ID is not expired and is valid. Delete
// the corresponding accessor from the accessorMap. This will leave
// only the dangling accessors in the map which can then be cleaned
// up later.
salt, err := b.Salt(ctx)
if err != nil {
return err
}
delete(accessorMap, salt.SaltID(result.SecretIDAccessor))
return nil
}
// At this point, the secret ID is not expired and is valid. Delete
// the corresponding accessor from the accessorMap. This will leave
// only the dangling accessors in the map which can then be cleaned
// up later.
salt, err := b.Salt(ctx)
if err != nil {
return err
for _, roleNameHMAC := range roleNameHMACs {
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
if err != nil {
return err
}
for _, secretIDHMAC := range secretIDHMACs {
err = secretIDCleanupFunc(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse)
if err != nil {
return err
}
}
}
delete(accessorMap, salt.SaltID(result.SecretIDAccessor))
return nil
}
for _, roleNameHMAC := range roleNameHMACs {
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC))
if err != nil {
return err
}
for _, secretIDHMAC := range secretIDHMACs {
err = secretIDCleanupFunc(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse)
// Accessor indexes were not getting cleaned up until 0.9.3. This is a fix
// to clean up the dangling accessor entries.
for accessorHash, _ := range accessorMap {
// Ideally, locking should be performed here. But for that, accessors
// are required in plaintext, which are not available. Hence performing
// a racy cleanup.
err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash)
if err != nil {
return err
}
}
return nil
}
// Accessor indexes were not getting cleaned up until 0.9.3. This is a fix
// to clean up the dangling accessor entries.
for accessorHash, _ := range accessorMap {
// Ideally, locking should be performed here. But for that, accessors
// are required in plaintext, which are not available. Hence performing
// a racy cleanup.
err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash)
if err != nil {
return err
}
err := tidyFunc(secretIDPrefix, secretIDAccessorPrefix)
if err != nil {
logger.Error("error tidying global secret IDs", "error", err)
return
}
err = tidyFunc(secretIDLocalPrefix, secretIDAccessorLocalPrefix)
if err != nil {
logger.Error("error tidying local secret IDs", "error", err)
return
}
}()
return nil
}
err := tidyFunc(secretIDPrefix, secretIDAccessorPrefix)
if err != nil {
return err
}
err = tidyFunc(secretIDLocalPrefix, secretIDAccessorLocalPrefix)
if err != nil {
return err
}
return result
resp := &logical.Response{}
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
return resp, nil
}
// pathTidySecretIDUpdate is used to delete the expired SecretID entries
func (b *backend) pathTidySecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidySecretID(ctx, req.Storage)
return b.tidySecretID(ctx, req.Storage)
}
const pathTidySecretIDSyn = "Trigger the clean-up of expired SecretID entries."

View File

@@ -3,6 +3,7 @@ package approle
import (
"context"
"testing"
"time"
"github.com/hashicorp/vault/logical"
)
@@ -64,11 +65,14 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) {
t.Fatalf("bad: len(accessorHashes); expect 3, got %d", len(accessorHashes))
}
err = b.tidySecretID(context.Background(), storage)
_, err = b.tidySecretID(context.Background(), storage)
if err != nil {
t.Fatal(err)
}
// It runs async so we give it a bit of time to run
time.Sleep(10 * time.Second)
accessorHashes, err = storage.List(context.Background(), "accessor/")
if err != nil {
t.Fatal(err)

View File

@@ -33,53 +33,72 @@ expiration, before it is removed from the backend storage.`,
}
// tidyWhitelistIdentity is used to delete entries in the whitelist that are expired.
func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) error {
grabbed := atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1)
if grabbed {
func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) (*logical.Response, error) {
if !atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1) {
resp := &logical.Response{}
resp.AddWarning("Tidy operation already in progress.")
return resp, nil
}
go func() {
defer atomic.StoreUint32(b.tidyWhitelistCASGuard, 0)
} else {
return fmt.Errorf("identity whitelist tidy operation already running")
}
bufferDuration := time.Duration(safety_buffer) * time.Second
// Don't cancel when the original client request goes away
ctx = context.Background()
identities, err := s.List(ctx, "whitelist/identity/")
if err != nil {
return err
}
logger := b.Logger().Named("wltidy")
for _, instanceID := range identities {
identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error fetching identity of instanceID %q: {{err}}", instanceID), err)
}
bufferDuration := time.Duration(safety_buffer) * time.Second
if identityEntry == nil {
return fmt.Errorf("identity entry for instanceID %q is nil", instanceID)
}
if identityEntry.Value == nil || len(identityEntry.Value) == 0 {
return fmt.Errorf("found identity entry for instanceID %q but actual identity is empty", instanceID)
}
var result whitelistIdentity
if err := identityEntry.DecodeJSON(&result); err != nil {
return err
}
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting identity of instanceID %q from storage: {{err}}", instanceID), err)
doTidy := func() error {
identities, err := s.List(ctx, "whitelist/identity/")
if err != nil {
return err
}
}
}
return nil
for _, instanceID := range identities {
identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error fetching identity of instanceID %q: {{err}}", instanceID), err)
}
if identityEntry == nil {
return fmt.Errorf("identity entry for instanceID %q is nil", instanceID)
}
if identityEntry.Value == nil || len(identityEntry.Value) == 0 {
return fmt.Errorf("found identity entry for instanceID %q but actual identity is empty", instanceID)
}
var result whitelistIdentity
if err := identityEntry.DecodeJSON(&result); err != nil {
return err
}
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting identity of instanceID %q from storage: {{err}}", instanceID), err)
}
}
}
return nil
}
if err := doTidy(); err != nil {
logger.Error("error running whitelist tidy", "error", err)
return
}
}()
resp := &logical.Response{}
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
return resp, nil
}
// pathTidyIdentityWhitelistUpdate is used to delete entries in the whitelist that are expired.
func (b *backend) pathTidyIdentityWhitelistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int))
return b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int))
}
const pathTidyIdentityWhitelistSyn = `

View File

@@ -33,52 +33,72 @@ expiration, before it is removed from the backend storage.`,
}
// tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist.
func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) error {
grabbed := atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1)
if grabbed {
func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) (*logical.Response, error) {
if !atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1) {
resp := &logical.Response{}
resp.AddWarning("Tidy operation already in progress.")
return resp, nil
}
go func() {
defer atomic.StoreUint32(b.tidyBlacklistCASGuard, 0)
} else {
return fmt.Errorf("roletag blacklist tidy operation already running")
}
bufferDuration := time.Duration(safety_buffer) * time.Second
tags, err := s.List(ctx, "blacklist/roletag/")
if err != nil {
return err
}
// Don't cancel when the original client request goes away
ctx = context.Background()
for _, tag := range tags {
tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error fetching tag %q: {{err}}", tag), err)
}
logger := b.Logger().Named("bltidy")
if tagEntry == nil {
return fmt.Errorf("tag entry for tag %q is nil", tag)
}
bufferDuration := time.Duration(safety_buffer) * time.Second
if tagEntry.Value == nil || len(tagEntry.Value) == 0 {
return fmt.Errorf("found entry for tag %q but actual tag is empty", tag)
}
var result roleTagBlacklistEntry
if err := tagEntry.DecodeJSON(&result); err != nil {
return err
}
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting tag %q from storage: {{err}}", tag), err)
doTidy := func() error {
tags, err := s.List(ctx, "blacklist/roletag/")
if err != nil {
return err
}
}
}
return nil
for _, tag := range tags {
tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error fetching tag %q: {{err}}", tag), err)
}
if tagEntry == nil {
return fmt.Errorf("tag entry for tag %q is nil", tag)
}
if tagEntry.Value == nil || len(tagEntry.Value) == 0 {
return fmt.Errorf("found entry for tag %q but actual tag is empty", tag)
}
var result roleTagBlacklistEntry
if err := tagEntry.DecodeJSON(&result); err != nil {
return err
}
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting tag %q from storage: {{err}}", tag), err)
}
}
}
return nil
}
if err := doTidy(); err != nil {
logger.Error("error running blacklist tidy", "error", err)
return
}
}()
resp := &logical.Response{}
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
return resp, nil
}
// pathTidyRoletagBlacklistUpdate is used to clean-up the entries in the role tag blacklist.
func (b *backend) pathTidyRoletagBlacklistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int))
return b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int))
}
const pathTidyRoletagBlacklistSyn = `

View File

@@ -12,7 +12,7 @@ import (
// Factory creates a new backend implementing the logical.Backend interface
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
b := Backend(conf)
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
@@ -20,7 +20,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend,
}
// Backend returns a new Backend framework struct
func Backend() *backend {
func Backend(conf *logical.BackendConfig) *backend {
var b backend
b.Backend = &framework.Backend{
Help: strings.TrimSpace(backendHelp),
@@ -85,6 +85,8 @@ func Backend() *backend {
}
b.crlLifetime = time.Hour * 72
b.tidyCASGuard = new(uint32)
b.storage = conf.StorageView
return &b
}
@@ -92,8 +94,10 @@ func Backend() *backend {
type backend struct {
*framework.Backend
storage logical.Storage
crlLifetime time.Duration
revokeStorageLock sync.RWMutex
tidyCASGuard *uint32
}
const backendHelp = `

View File

@@ -135,64 +135,6 @@ func TestPKI_RequireCN(t *testing.T) {
}
}
// Performs basic tests on CA functionality
// Uses the RSA CA key
func TestBackend_RSAKey(t *testing.T) {
initTest.Do(setCerts)
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
})
if err != nil {
t.Fatalf("Unable to create backend: %s", err)
}
testCase := logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{},
}
intdata := map[string]interface{}{}
reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, ecCACert, intdata, reqdata)...)
logicaltest.Test(t, testCase)
}
// Performs basic tests on CA functionality
// Uses the EC CA key
func TestBackend_ECKey(t *testing.T) {
initTest.Do(setCerts)
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
})
if err != nil {
t.Fatalf("Unable to create backend: %s", err)
}
testCase := logicaltest.TestCase{
Backend: b,
Steps: []logicaltest.TestStep{},
}
intdata := map[string]interface{}{}
reqdata := map[string]interface{}{}
testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, rsaCACert, intdata, reqdata)...)
logicaltest.Test(t, testCase)
}
func TestBackend_CSRValues(t *testing.T) {
initTest.Do(setCerts)
defaultLeaseTTLVal := time.Hour * 24
@@ -806,685 +748,6 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s
return ret
}
// Generates steps to test out CA configuration -- certificates + CRL expiry,
// and ensure that the certificates are readable after storing them
func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep {
setSerialUnderTest := func(req *logical.Request) error {
req.Path = serialUnderTest
return nil
}
ret := []logicaltest.TestStep{
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/ca",
Data: map[string]interface{}{
"pem_bundle": strings.Join([]string{caKey, caCert}, "\n"),
},
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/crl",
Data: map[string]interface{}{
"expiry": "16h",
},
},
// Ensure we can fetch it back via unauthenticated means, in various formats
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "cert/ca",
Unauthenticated: true,
Check: func(resp *logical.Response) error {
if resp.Data["certificate"].(string) != caCert {
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", resp.Data["certificate"].(string), caCert)
}
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "ca/pem",
Unauthenticated: true,
Check: func(resp *logical.Response) error {
rawBytes := resp.Data["http_raw_body"].([]byte)
if !reflect.DeepEqual(rawBytes, []byte(caCert)) {
return fmt.Errorf("CA certificate:\n%#v\ndoes not match original:\n%#v\n", rawBytes, []byte(caCert))
}
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string))
}
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "ca",
Unauthenticated: true,
Check: func(resp *logical.Response) error {
rawBytes := resp.Data["http_raw_body"].([]byte)
pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: rawBytes,
})))
if pemBytes != caCert {
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", pemBytes, caCert)
}
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string))
}
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "config/crl",
Check: func(resp *logical.Response) error {
if resp.Data["expiry"].(string) != "16h" {
return fmt.Errorf("CRL lifetimes do not match (got %s)", resp.Data["expiry"].(string))
}
return nil
},
},
// Ensure that both parts of the PEM bundle are required
// Here, just the cert
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/ca",
Data: map[string]interface{}{
"pem_bundle": caCert,
},
ErrorOk: true,
},
// Here, just the key
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/ca",
Data: map[string]interface{}{
"pem_bundle": caKey,
},
ErrorOk: true,
},
// Ensure we can fetch it back via unauthenticated means, in various formats
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "cert/ca",
Unauthenticated: true,
Check: func(resp *logical.Response) error {
if resp.Data["certificate"].(string) != caCert {
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", resp.Data["certificate"].(string), caCert)
}
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "ca/pem",
Unauthenticated: true,
Check: func(resp *logical.Response) error {
rawBytes := resp.Data["http_raw_body"].([]byte)
if string(rawBytes) != caCert {
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", string(rawBytes), caCert)
}
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string))
}
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "ca",
Unauthenticated: true,
Check: func(resp *logical.Response) error {
rawBytes := resp.Data["http_raw_body"].([]byte)
pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: rawBytes,
})))
if pemBytes != caCert {
return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", pemBytes, caCert)
}
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string))
}
return nil
},
},
// Test a bunch of generation stuff
logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "root",
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "root/generate/exported",
Data: map[string]interface{}{
"common_name": "Root Cert",
"ttl": "180h",
},
Check: func(resp *logical.Response) error {
intdata["root"] = resp.Data["certificate"].(string)
intdata["rootkey"] = resp.Data["private_key"].(string)
reqdata["pem_bundle"] = strings.Join([]string{intdata["root"].(string), intdata["rootkey"].(string)}, "\n")
return nil
},
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "intermediate/generate/exported",
Data: map[string]interface{}{
"common_name": "intermediate.cert.com",
},
Check: func(resp *logical.Response) error {
intdata["intermediatecsr"] = resp.Data["csr"].(string)
intdata["intermediatekey"] = resp.Data["private_key"].(string)
return nil
},
},
// Re-load the root key in so we can sign it
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/ca",
Data: reqdata,
Check: func(resp *logical.Response) error {
delete(reqdata, "pem_bundle")
delete(reqdata, "ttl")
reqdata["csr"] = intdata["intermediatecsr"].(string)
reqdata["common_name"] = "intermediate.cert.com"
reqdata["ttl"] = "10s"
return nil
},
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "root/sign-intermediate",
Data: reqdata,
Check: func(resp *logical.Response) error {
delete(reqdata, "csr")
delete(reqdata, "common_name")
delete(reqdata, "ttl")
intdata["intermediatecert"] = resp.Data["certificate"].(string)
reqdata["serial_number"] = resp.Data["serial_number"].(string)
reqdata["rsa_int_serial_number"] = resp.Data["serial_number"].(string)
reqdata["certificate"] = resp.Data["certificate"].(string)
reqdata["pem_bundle"] = strings.Join([]string{intdata["intermediatekey"].(string), resp.Data["certificate"].(string)}, "\n")
return nil
},
},
// First load in this way to populate the private key
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/ca",
Data: reqdata,
Check: func(resp *logical.Response) error {
delete(reqdata, "pem_bundle")
return nil
},
},
// Now test setting the intermediate, signed CA cert
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "intermediate/set-signed",
Data: reqdata,
Check: func(resp *logical.Response) error {
delete(reqdata, "certificate")
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
return nil
},
},
// We expect to find a zero revocation time
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
}
if resp.Data["revocation_time"].(int64) != 0 {
return fmt.Errorf("expected a zero revocation time")
}
return nil
},
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "revoke",
Data: reqdata,
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "crl",
Data: reqdata,
Check: func(resp *logical.Response) error {
crlBytes := resp.Data["http_raw_body"].([]byte)
certList, err := x509.ParseCRL(crlBytes)
if err != nil {
t.Fatalf("err: %s", err)
}
revokedList := certList.TBSCertList.RevokedCertificates
if len(revokedList) != 1 {
t.Fatalf("length of revoked list not 1; %d", len(revokedList))
}
revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":")
if revokedString != reqdata["serial_number"].(string) {
t.Fatalf("got serial %s, expecting %s", revokedString, reqdata["serial_number"].(string))
}
delete(reqdata, "serial_number")
return nil
},
},
// Do it all again, with EC keys and DER format
logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "root",
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "root/generate/exported",
Data: map[string]interface{}{
"common_name": "Root Cert",
"ttl": "180h",
"key_type": "ec",
"key_bits": 384,
"format": "der",
},
Check: func(resp *logical.Response) error {
certBytes, _ := base64.StdEncoding.DecodeString(resp.Data["certificate"].(string))
certPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})))
keyBytes, _ := base64.StdEncoding.DecodeString(resp.Data["private_key"].(string))
keyPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "EC PRIVATE KEY",
Bytes: keyBytes,
})))
intdata["root"] = certPem
intdata["rootkey"] = keyPem
reqdata["pem_bundle"] = strings.Join([]string{certPem, keyPem}, "\n")
return nil
},
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "intermediate/generate/exported",
Data: map[string]interface{}{
"format": "der",
"key_type": "ec",
"key_bits": 384,
"common_name": "intermediate.cert.com",
},
Check: func(resp *logical.Response) error {
csrBytes, _ := base64.StdEncoding.DecodeString(resp.Data["csr"].(string))
csrPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrBytes,
})))
keyBytes, _ := base64.StdEncoding.DecodeString(resp.Data["private_key"].(string))
keyPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "EC PRIVATE KEY",
Bytes: keyBytes,
})))
intdata["intermediatecsr"] = csrPem
intdata["intermediatekey"] = keyPem
return nil
},
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/ca",
Data: reqdata,
Check: func(resp *logical.Response) error {
delete(reqdata, "pem_bundle")
delete(reqdata, "ttl")
reqdata["csr"] = intdata["intermediatecsr"].(string)
reqdata["common_name"] = "intermediate.cert.com"
reqdata["ttl"] = "10s"
return nil
},
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "root/sign-intermediate",
Data: reqdata,
Check: func(resp *logical.Response) error {
delete(reqdata, "csr")
delete(reqdata, "common_name")
delete(reqdata, "ttl")
intdata["intermediatecert"] = resp.Data["certificate"].(string)
reqdata["serial_number"] = resp.Data["serial_number"].(string)
reqdata["ec_int_serial_number"] = resp.Data["serial_number"].(string)
reqdata["certificate"] = resp.Data["certificate"].(string)
reqdata["pem_bundle"] = strings.Join([]string{intdata["intermediatekey"].(string), resp.Data["certificate"].(string)}, "\n")
return nil
},
},
// First load in this way to populate the private key
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "config/ca",
Data: reqdata,
Check: func(resp *logical.Response) error {
delete(reqdata, "pem_bundle")
return nil
},
},
// Now test setting the intermediate, signed CA cert
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "intermediate/set-signed",
Data: reqdata,
Check: func(resp *logical.Response) error {
delete(reqdata, "certificate")
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
return nil
},
},
// We expect to find a zero revocation time
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
}
if resp.Data["revocation_time"].(int64) != 0 {
return fmt.Errorf("expected a zero revocation time")
}
return nil
},
},
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "revoke",
Data: reqdata,
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "crl",
Data: reqdata,
Check: func(resp *logical.Response) error {
crlBytes := resp.Data["http_raw_body"].([]byte)
certList, err := x509.ParseCRL(crlBytes)
if err != nil {
t.Fatalf("err: %s", err)
}
revokedList := certList.TBSCertList.RevokedCertificates
if len(revokedList) != 2 {
t.Fatalf("length of revoked list not 2; %d", len(revokedList))
}
found := false
for _, revEntry := range revokedList {
revokedString := certutil.GetHexFormatted(revEntry.SerialNumber.Bytes(), ":")
if revokedString == reqdata["serial_number"].(string) {
found = true
}
}
if !found {
t.Fatalf("did not find %s in CRL", reqdata["serial_number"].(string))
}
delete(reqdata, "serial_number")
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
return nil
},
},
// Make sure both serial numbers we expect to find are found
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
}
if resp.Data["revocation_time"].(int64) == 0 {
return fmt.Errorf("expected a non-zero revocation time")
}
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
}
if resp.Data["revocation_time"].(int64) == 0 {
return fmt.Errorf("expected a non-zero revocation time")
}
// Give time for the certificates to pass the safety buffer
t.Logf("Sleeping for 15 seconds to allow safety buffer time to pass before testing tidying")
time.Sleep(15 * time.Second)
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
return nil
},
},
// This shouldn't do anything since the safety buffer is too long
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "tidy",
Data: map[string]interface{}{
"safety_buffer": "3h",
"tidy_cert_store": true,
"tidy_revocation_list": true,
},
},
// We still expect to find these
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
}
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
}
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
return nil
},
},
// Both should appear in the CRL
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "crl",
Data: reqdata,
Check: func(resp *logical.Response) error {
crlBytes := resp.Data["http_raw_body"].([]byte)
certList, err := x509.ParseCRL(crlBytes)
if err != nil {
t.Fatalf("err: %s", err)
}
revokedList := certList.TBSCertList.RevokedCertificates
if len(revokedList) != 2 {
t.Fatalf("length of revoked list not 2; %d", len(revokedList))
}
foundRsa := false
foundEc := false
for _, revEntry := range revokedList {
revokedString := certutil.GetHexFormatted(revEntry.SerialNumber.Bytes(), ":")
if revokedString == reqdata["rsa_int_serial_number"].(string) {
foundRsa = true
}
if revokedString == reqdata["ec_int_serial_number"].(string) {
foundEc = true
}
}
if !foundRsa || !foundEc {
t.Fatalf("did not find an expected entry in CRL")
}
return nil
},
},
// This shouldn't do anything since the boolean values default to false
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "tidy",
Data: map[string]interface{}{
"safety_buffer": "1s",
},
},
// We still expect to find these
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
}
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" {
return fmt.Errorf("got an error: %s", resp.Data["error"].(string))
}
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
return nil
},
},
// This should remove the values since the safety buffer is short
logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "tidy",
Data: map[string]interface{}{
"safety_buffer": "1s",
"tidy_cert_store": true,
"tidy_revocation_list": true,
},
},
// We do *not* expect to find these
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp != nil {
return fmt.Errorf("expected no response")
}
serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string)
return nil
},
},
logicaltest.TestStep{
Operation: logical.ReadOperation,
PreFlight: setSerialUnderTest,
Check: func(resp *logical.Response) error {
if resp != nil {
return fmt.Errorf("expected no response")
}
serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string)
return nil
},
},
// Both should be gone from the CRL
logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "crl",
Data: reqdata,
Check: func(resp *logical.Response) error {
crlBytes := resp.Data["http_raw_body"].([]byte)
certList, err := x509.ParseCRL(crlBytes)
if err != nil {
t.Fatalf("err: %s", err)
}
revokedList := certList.TBSCertList.RevokedCertificates
if len(revokedList) != 0 {
t.Fatalf("length of revoked list not 0; %d", len(revokedList))
}
return nil
},
},
}
return ret
}
// Generates steps to test out various role permutations
func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep {
roleVals := roleEntry{
@@ -2141,7 +1404,7 @@ func TestBackend_PathFetchCertList(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b := Backend()
b := Backend(config)
err := b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
@@ -2268,7 +1531,7 @@ func TestBackend_SignVerbatim(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b := Backend()
b := Backend(config)
err := b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
@@ -2824,7 +2087,7 @@ func TestBackend_SignSelfIssued(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b := Backend()
b := Backend(config)
err := b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)

View File

@@ -0,0 +1,570 @@
package pki
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"math/big"
mathrand "math/rand"
"strings"
"testing"
"time"
"github.com/go-test/deep"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/certutil"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
)
func TestBackend_CA_Steps(t *testing.T) {
var b *backend
factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
be, err := Factory(ctx, conf)
if err == nil {
b = be.(*backend)
}
return be, err
}
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
client := cluster.Cores[0].Client
// Set RSA/EC CA certificates
var rsaCAKey, rsaCACert, ecCAKey, ecCACert string
{
cak, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
marshaledKey, err := x509.MarshalECPrivateKey(cak)
if err != nil {
panic(err)
}
keyPEMBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: marshaledKey,
}
ecCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
if err != nil {
panic(err)
}
subjKeyID, err := certutil.GetSubjKeyID(cak)
if err != nil {
panic(err)
}
caCertTemplate := &x509.Certificate{
Subject: pkix.Name{
CommonName: "root.localhost",
},
SubjectKeyId: subjKeyID,
DNSNames: []string{"root.localhost"},
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
NotAfter: time.Now().Add(262980 * time.Hour),
BasicConstraintsValid: true,
IsCA: true,
}
caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, cak.Public(), cak)
if err != nil {
panic(err)
}
caCertPEMBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
}
ecCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock)))
rak, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
marshaledKey = x509.MarshalPKCS1PrivateKey(rak)
keyPEMBlock = &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: marshaledKey,
}
rsaCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock)))
if err != nil {
panic(err)
}
subjKeyID, err = certutil.GetSubjKeyID(rak)
if err != nil {
panic(err)
}
caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, rak.Public(), rak)
if err != nil {
panic(err)
}
caCertPEMBlock = &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
}
rsaCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock)))
}
// Setup backends
var rsaRoot, rsaInt, ecRoot, ecInt *backend
{
if err := client.Sys().Mount("rsaroot", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
}); err != nil {
t.Fatal(err)
}
rsaRoot = b
if err := client.Sys().Mount("rsaint", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
}); err != nil {
t.Fatal(err)
}
rsaInt = b
if err := client.Sys().Mount("ecroot", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
}); err != nil {
t.Fatal(err)
}
ecRoot = b
if err := client.Sys().Mount("ecint", &api.MountInput{
Type: "pki",
Config: api.MountConfigInput{
DefaultLeaseTTL: "16h",
MaxLeaseTTL: "60h",
},
}); err != nil {
t.Fatal(err)
}
ecInt = b
}
t.Run("teststeps", func(t *testing.T) {
t.Run("rsa", func(t *testing.T) {
t.Parallel()
subClient, err := client.Clone()
if err != nil {
t.Fatal(err)
}
subClient.SetToken(client.Token())
runSteps(t, rsaRoot, rsaInt, subClient, "rsaroot/", "rsaint/", rsaCACert, rsaCAKey)
})
t.Run("ec", func(t *testing.T) {
t.Parallel()
subClient, err := client.Clone()
if err != nil {
t.Fatal(err)
}
subClient.SetToken(client.Token())
runSteps(t, ecRoot, ecInt, subClient, "ecroot/", "ecint/", ecCACert, ecCAKey)
})
})
}
func runSteps(t *testing.T, rootB, intB *backend, client *api.Client, rootName, intName, caCert, caKey string) {
// Load CA cert/key in and ensure we can fetch it back in various formats,
// unauthenticated
{
// Attempt import but only provide one the cert
{
_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
"pem_bundle": caCert,
})
if err == nil {
t.Fatal("expected error")
}
}
// Same but with only the key
{
_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
"pem_bundle": caKey,
})
if err == nil {
t.Fatal("expected error")
}
}
// Import CA bundle
{
_, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{
"pem_bundle": strings.Join([]string{caKey, caCert}, "\n"),
})
if err != nil {
t.Fatal(err)
}
}
prevToken := client.Token()
client.SetToken("")
// cert/ca path
{
resp, err := client.Logical().Read(rootName + "cert/ca")
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
if diff := deep.Equal(resp.Data["certificate"].(string), caCert); diff != nil {
t.Fatal(diff)
}
}
// ca/pem path (raw string)
{
req := &logical.Request{
Path: "ca/pem",
Operation: logical.ReadOperation,
Storage: rootB.storage,
}
resp, err := rootB.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
if diff := deep.Equal(resp.Data["http_raw_body"].([]byte), []byte(caCert)); diff != nil {
t.Fatal(diff)
}
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
t.Fatal("wrong content type")
}
}
// ca (raw DER bytes)
{
req := &logical.Request{
Path: "ca",
Operation: logical.ReadOperation,
Storage: rootB.storage,
}
resp, err := rootB.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
rawBytes := resp.Data["http_raw_body"].([]byte)
pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: rawBytes,
})))
if diff := deep.Equal(pemBytes, caCert); diff != nil {
t.Fatal(diff)
}
if resp.Data["http_content_type"].(string) != "application/pkix-cert" {
t.Fatal("wrong content type")
}
}
client.SetToken(prevToken)
}
// Configure an expiry on the CRL and verify what comes back
{
// Set CRL config
{
_, err := client.Logical().Write(rootName+"config/crl", map[string]interface{}{
"expiry": "16h",
})
if err != nil {
t.Fatal(err)
}
}
// Verify it
{
resp, err := client.Logical().Read(rootName + "config/crl")
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
if resp.Data["expiry"].(string) != "16h" {
t.Fatal("expected a 16 hour expiry")
}
}
}
// Test generating a root, an intermediate, signing it, setting signed, and
// revoking it
// We'll need this later
var intSerialNumber string
{
// First, delete the existing CA info
{
_, err := client.Logical().Delete(rootName + "root")
if err != nil {
t.Fatal(err)
}
}
var rootPEM, rootKey, rootPEMBundle string
// Test exported root generation
{
resp, err := client.Logical().Write(rootName+"root/generate/exported", map[string]interface{}{
"common_name": "Root Cert",
"ttl": "180h",
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
rootPEM = resp.Data["certificate"].(string)
rootKey = resp.Data["private_key"].(string)
rootPEMBundle = strings.Join([]string{rootPEM, rootKey}, "\n")
// This is really here to keep the use checker happy
if rootPEMBundle == "" {
t.Fatal("bad root pem bundle")
}
}
var intPEM, intCSR, intKey string
// Test exported intermediate CSR generation
{
resp, err := client.Logical().Write(intName+"intermediate/generate/exported", map[string]interface{}{
"common_name": "intermediate.cert.com",
"ttl": "180h",
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
intCSR = resp.Data["csr"].(string)
intKey = resp.Data["private_key"].(string)
// This is really here to keep the use checker happy
if intCSR == "" || intKey == "" {
t.Fatal("int csr or key empty")
}
}
// Test signing
{
resp, err := client.Logical().Write(rootName+"root/sign-intermediate", map[string]interface{}{
"common_name": "intermediate.cert.com",
"ttl": "10s",
"csr": intCSR,
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
intPEM = resp.Data["certificate"].(string)
intSerialNumber = resp.Data["serial_number"].(string)
}
// Test setting signed
{
resp, err := client.Logical().Write(intName+"intermediate/set-signed", map[string]interface{}{
"certificate": intPEM,
})
if err != nil {
t.Fatal(err)
}
if resp != nil {
t.Fatal("expected nil response")
}
}
// Verify we can find it via the root
{
resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
if resp.Data["revocation_time"].(json.Number).String() != "0" {
t.Fatal("expected a zero revocation time")
}
}
// Revoke the intermediate
{
resp, err := client.Logical().Write(rootName+"revoke", map[string]interface{}{
"serial_number": intSerialNumber,
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
}
}
verifyRevocation := func(t *testing.T, serial string, shouldFind bool) {
// Verify it is now revoked
{
resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber)
if err != nil {
t.Fatal(err)
}
switch shouldFind {
case true:
if resp == nil {
t.Fatal("nil response")
}
if resp.Data["revocation_time"].(json.Number).String() == "0" {
t.Fatal("expected a non-zero revocation time")
}
default:
if resp != nil {
t.Fatalf("expected nil response, got %#v", *resp)
}
}
}
// Fetch the CRL and make sure it shows up
{
req := &logical.Request{
Path: "crl",
Operation: logical.ReadOperation,
Storage: rootB.storage,
}
resp, err := rootB.HandleRequest(context.Background(), req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("nil response")
}
crlBytes := resp.Data["http_raw_body"].([]byte)
certList, err := x509.ParseCRL(crlBytes)
if err != nil {
t.Fatal(err)
}
switch shouldFind {
case true:
revokedList := certList.TBSCertList.RevokedCertificates
if len(revokedList) != 1 {
t.Fatalf("bad length of revoked list: %d", len(revokedList))
}
revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":")
if revokedString != intSerialNumber {
t.Fatalf("bad revoked serial: %s", revokedString)
}
default:
revokedList := certList.TBSCertList.RevokedCertificates
if len(revokedList) != 0 {
t.Fatalf("bad length of revoked list: %d", len(revokedList))
}
}
}
}
// Validate current state of revoked certificates
verifyRevocation(t, intSerialNumber, true)
// Give time for the safety buffer to pass before tidying
time.Sleep(10 * time.Second)
// Test tidying
{
// Run with a high safety buffer, nothing should happen
{
resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
"safety_buffer": "3h",
"tidy_cert_store": true,
"tidy_revocation_list": true,
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected warnings")
}
// Wait a few seconds as it runs in a goroutine
time.Sleep(5 * time.Second)
// Check to make sure we still find the cert and see it on the CRL
verifyRevocation(t, intSerialNumber, true)
}
// Run with both values set false, nothing should happen
{
resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
"safety_buffer": "1s",
"tidy_cert_store": false,
"tidy_revocation_list": false,
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected warnings")
}
// Wait a few seconds as it runs in a goroutine
time.Sleep(5 * time.Second)
// Check to make sure we still find the cert and see it on the CRL
verifyRevocation(t, intSerialNumber, true)
}
// Run with a short safety buffer and both set to true, both should be cleared
{
resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{
"safety_buffer": "1s",
"tidy_cert_store": true,
"tidy_revocation_list": true,
})
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected warnings")
}
// Wait a few seconds as it runs in a goroutine
time.Sleep(5 * time.Second)
// Check to make sure we still find the cert and see it on the CRL
verifyRevocation(t, intSerialNumber, false)
}
}
}

View File

@@ -15,7 +15,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
config.StorageView = &logical.InmemStorage{}
var err error
b := Backend()
b := Backend(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/x509"
"fmt"
"sync/atomic"
"time"
"github.com/hashicorp/errwrap"
@@ -59,116 +60,134 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr
bufferDuration := time.Duration(safetyBuffer) * time.Second
var resp *logical.Response
if tidyCertStore {
serials, err := req.Storage.List(ctx, "certs/")
if err != nil {
return nil, errwrap.Wrapf("error fetching list of certs: {{err}}", err)
}
for _, serial := range serials {
certEntry, err := req.Storage.Get(ctx, "certs/"+serial)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error fetching certificate %q: {{err}}", serial), err)
}
if certEntry == nil {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning(fmt.Sprintf("Certificate entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial))
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting nil entry with serial %s: {{err}}", serial), err)
}
}
if certEntry.Value == nil || len(certEntry.Value) == 0 {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning(fmt.Sprintf("Certificate entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial))
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting entry with nil value with serial %s: {{err}}", serial), err)
}
}
cert, err := x509.ParseCertificate(certEntry.Value)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("unable to parse stored certificate with serial %q: {{err}}", serial), err)
}
if time.Now().After(cert.NotAfter.Add(bufferDuration)) {
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from storage: {{err}}", serial), err)
}
}
}
if !atomic.CompareAndSwapUint32(b.tidyCASGuard, 0, 1) {
resp := &logical.Response{}
resp.AddWarning("Tidy operation already in progress.")
return resp, nil
}
if tidyRevocationList {
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
tidiedRevoked := false
revokedSerials, err := req.Storage.List(ctx, "revoked/")
if err != nil {
return nil, errwrap.Wrapf("error fetching list of revoked certs: {{err}}", err)
}
var revInfo revocationInfo
for _, serial := range revokedSerials {
revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("unable to fetch revoked cert with serial %q: {{err}}", serial), err)
}
if revokedEntry == nil {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning(fmt.Sprintf("Revoked entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial))
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting nil revoked entry with serial %s: {{err}}", serial), err)
}
}
if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 {
if resp == nil {
resp = &logical.Response{}
}
resp.AddWarning(fmt.Sprintf("Revoked entry for serial %s has nil value; tidying up since it is no longer useful for any server operations", serial))
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting revoked entry with nil value with serial %s: {{err}}", serial), err)
}
}
err = revokedEntry.DecodeJSON(&revInfo)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error decoding revocation entry for serial %q: {{err}}", serial), err)
}
revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes)
if err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("unable to parse stored revoked certificate with serial %q: {{err}}", serial), err)
}
if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) {
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
return nil, errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from revoked list: {{err}}", serial), err)
}
tidiedRevoked = true
}
}
if tidiedRevoked {
if err := buildCRL(ctx, b, req); err != nil {
return nil, err
}
}
// Tests using framework will screw up the storage so make a locally
// scoped req to hold a reference
req = &logical.Request{
Storage: req.Storage,
}
go func() {
defer atomic.StoreUint32(b.tidyCASGuard, 0)
// Don't cancel when the original client request goes away
ctx = context.Background()
logger := b.Logger().Named("tidy")
doTidy := func() error {
if tidyCertStore {
serials, err := req.Storage.List(ctx, "certs/")
if err != nil {
return errwrap.Wrapf("error fetching list of certs: {{err}}", err)
}
for _, serial := range serials {
certEntry, err := req.Storage.Get(ctx, "certs/"+serial)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error fetching certificate %q: {{err}}", serial), err)
}
if certEntry == nil {
logger.Warn("certificate entry is nil; tidying up since it is no longer useful for any server operations", "serial", serial)
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting nil entry with serial %s: {{err}}", serial), err)
}
}
if certEntry.Value == nil || len(certEntry.Value) == 0 {
logger.Warn("certificate entry has no value; tidying up since it is no longer useful for any server operations", "serial", serial)
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting entry with nil value with serial %s: {{err}}", serial), err)
}
}
cert, err := x509.ParseCertificate(certEntry.Value)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("unable to parse stored certificate with serial %q: {{err}}", serial), err)
}
if time.Now().After(cert.NotAfter.Add(bufferDuration)) {
if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from storage: {{err}}", serial), err)
}
}
}
}
if tidyRevocationList {
b.revokeStorageLock.Lock()
defer b.revokeStorageLock.Unlock()
tidiedRevoked := false
revokedSerials, err := req.Storage.List(ctx, "revoked/")
if err != nil {
return errwrap.Wrapf("error fetching list of revoked certs: {{err}}", err)
}
var revInfo revocationInfo
for _, serial := range revokedSerials {
revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("unable to fetch revoked cert with serial %q: {{err}}", serial), err)
}
if revokedEntry == nil {
logger.Warn("revoked entry is nil; tidying up since it is no longer useful for any server operations", "serial", serial)
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting nil revoked entry with serial %s: {{err}}", serial), err)
}
}
if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 {
logger.Warn("revoked entry has nil value; tidying up since it is no longer useful for any server operations", "serial", serial)
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting revoked entry with nil value with serial %s: {{err}}", serial), err)
}
}
err = revokedEntry.DecodeJSON(&revInfo)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("error decoding revocation entry for serial %q: {{err}}", serial), err)
}
revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes)
if err != nil {
return errwrap.Wrapf(fmt.Sprintf("unable to parse stored revoked certificate with serial %q: {{err}}", serial), err)
}
if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) {
if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil {
return errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from revoked list: {{err}}", serial), err)
}
tidiedRevoked = true
}
}
if tidiedRevoked {
if err := buildCRL(ctx, b, req); err != nil {
return err
}
}
}
return nil
}
if err := doTidy(); err != nil {
logger.Error("error running tidy", "error", err)
return
}
}()
resp := &logical.Response{}
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
return resp, nil
}

View File

@@ -939,7 +939,10 @@ CLUSTER_SYNTHESIS_COMPLETE:
}
server := &http.Server{
Handler: handler,
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
IdleTimeout: 5 * time.Minute,
}
go server.Serve(ln.Listener)
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"sync"
"time"
proxyproto "github.com/armon/go-proxyproto"
"github.com/hashicorp/errwrap"
@@ -41,12 +42,14 @@ func WrapInProxyProto(listener net.Listener, config *ProxyProtoConfig) (net.List
switch config.Behavior {
case "use_always":
newLn = &proxyproto.Listener{
Listener: listener,
Listener: listener,
ProxyHeaderTimeout: 10 * time.Second,
}
case "allow_authorized", "deny_unauthorized":
newLn = &proxyproto.Listener{
Listener: listener,
Listener: listener,
ProxyHeaderTimeout: 10 * time.Second,
SourceCheck: func(addr net.Addr) (bool, error) {
config.RLock()
defer config.RUnlock()

View File

@@ -6,10 +6,10 @@ import (
"math/rand"
"net/http"
"os"
"reflect"
"testing"
"time"
"github.com/go-test/deep"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/logging"
"github.com/hashicorp/vault/physical"
@@ -106,8 +106,8 @@ func TestDynamoDBBackend(t *testing.T) {
if err != nil {
t.Fatalf("err: %s", err)
}
if !reflect.DeepEqual(inputEntry, entry) {
t.Fatalf("exp: %#v, act: %#v", inputEntry, entry)
if diff := deep.Equal(inputEntry, entry); diff != nil {
t.Fatal(diff)
}
})
}
@@ -285,7 +285,7 @@ func prepareDynamoDBTestContainer(t *testing.T) (cleanup func(), retAddress stri
t.Fatalf("Failed to connect to docker: %s", err)
}
resource, err := pool.Run("deangiberson/aws-dynamodb-local", "latest", []string{})
resource, err := pool.Run("cnadiminti/dynamodb-local", "latest", []string{})
if err != nil {
t.Fatalf("Could not start local DynamoDB: %s", err)
}

View File

@@ -191,15 +191,17 @@ func (m *ExpirationManager) Tidy() error {
var tidyErrors *multierror.Error
logger := m.logger.Named("tidy")
if !atomic.CompareAndSwapInt32(m.tidyLock, 0, 1) {
m.logger.Warn("tidy operation on leases is already in progress")
return fmt.Errorf("tidy operation on leases is already in progress")
logger.Warn("tidy operation on leases is already in progress")
return nil
}
defer atomic.CompareAndSwapInt32(m.tidyLock, 1, 0)
m.logger.Info("beginning tidy operation on leases")
defer m.logger.Info("finished tidy operation on leases")
logger.Info("beginning tidy operation on leases")
defer logger.Info("finished tidy operation on leases")
// Create a cache to keep track of looked up tokens
tokenCache := make(map[string]bool)
@@ -208,7 +210,7 @@ func (m *ExpirationManager) Tidy() error {
tidyFunc := func(leaseID string) {
countLease++
if countLease%500 == 0 {
m.logger.Info("tidying leases", "progress", countLease)
logger.Info("tidying leases", "progress", countLease)
}
le, err := m.loadEntry(leaseID)
@@ -225,7 +227,7 @@ func (m *ExpirationManager) Tidy() error {
var isValid, ok bool
revokeLease := false
if le.ClientToken == "" {
m.logger.Debug("revoking lease which has an empty token", "lease_id", leaseID)
logger.Debug("revoking lease which has an empty token", "lease_id", leaseID)
revokeLease = true
deletedCountEmptyToken++
goto REVOKE_CHECK
@@ -249,7 +251,7 @@ func (m *ExpirationManager) Tidy() error {
}
if te == nil {
m.logger.Debug("revoking lease which holds an invalid token", "lease_id", leaseID)
logger.Debug("revoking lease which holds an invalid token", "lease_id", leaseID)
revokeLease = true
deletedCountInvalidToken++
tokenCache[le.ClientToken] = false
@@ -262,7 +264,7 @@ func (m *ExpirationManager) Tidy() error {
return
}
m.logger.Debug("revoking lease which contains an invalid token", "lease_id", leaseID)
logger.Debug("revoking lease which contains an invalid token", "lease_id", leaseID)
revokeLease = true
deletedCountInvalidToken++
goto REVOKE_CHECK
@@ -285,10 +287,10 @@ func (m *ExpirationManager) Tidy() error {
return err
}
m.logger.Info("number of leases scanned", "count", countLease)
m.logger.Info("number of leases which had empty tokens", "count", deletedCountEmptyToken)
m.logger.Info("number of leases which had invalid tokens", "count", deletedCountInvalidToken)
m.logger.Info("number of leases successfully revoked", "count", revokedCount)
logger.Info("number of leases scanned", "count", countLease)
logger.Info("number of leases which had empty tokens", "count", deletedCountEmptyToken)
logger.Info("number of leases which had invalid tokens", "count", deletedCountInvalidToken)
logger.Info("number of leases successfully revoked", "count", revokedCount)
return tidyErrors.ErrorOrNil()
}

View File

@@ -1,6 +1,7 @@
package vault
import (
"bytes"
"context"
"fmt"
"reflect"
@@ -38,6 +39,14 @@ func TestExpiration_Tidy(t *testing.T) {
var err error
exp := mockExpiration(t)
// We use this later for tidy testing where we need to check the output
logOut := new(bytes.Buffer)
logger := log.New(&log.LoggerOptions{
Output: logOut,
})
exp.logger = logger
if err := exp.Restore(nil); err != nil {
t.Fatal(err)
}
@@ -212,9 +221,11 @@ func TestExpiration_Tidy(t *testing.T) {
}
}
if !(err1 != nil && err1.Error() == "tidy operation on leases is already in progress") &&
!(err2 != nil && err2.Error() == "tidy operation on leases is already in progress") {
t.Fatalf("expected at least one of err1 or err2 to be set; err1: %#v\n err2:%#v\n", err1, err2)
if err1 != nil || err2 != nil {
t.Fatalf("got an error: err1: %v; err2: %v", err1, err2)
}
if !strings.Contains(logOut.String(), "tidy operation on leases is already in progress") {
t.Fatalf("expected to see a warning saying operation in progress, output is %s", logOut.String())
}
root, err := exp.tokenStore.rootToken(context.Background())

View File

@@ -1182,12 +1182,17 @@ func (b *SystemBackend) handleCORSDelete(ctx context.Context, req *logical.Reque
}
func (b *SystemBackend) handleTidyLeases(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := b.Core.expiration.Tidy()
if err != nil {
b.Backend.Logger().Error("failed to tidy leases", "error", err)
return handleErrorNoReadOnlyForward(err)
}
return nil, err
go func() {
err := b.Core.expiration.Tidy()
if err != nil {
b.Backend.Logger().Error("failed to tidy leases", "error", err)
return
}
}()
resp := &logical.Response{}
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
return resp, nil
}
func (b *SystemBackend) invalidate(ctx context.Context, key string) {

View File

@@ -76,7 +76,11 @@ func (c *Core) startForwarding(ctx context.Context) error {
// duties. Doing it this way instead of listening via the server and gRPC
// allows us to re-use the same port via ALPN. We can just tell the server
// to serve a given conn and which handler to use.
fws := &http2.Server{}
fws := &http2.Server{
// Our forwarding connections heartbeat regularly so anything else we
// want to go away/get cleaned up pretty rapidly
IdleTimeout: 5 * HeartbeatInterval,
}
// Shutdown coordination logic
shutdown := new(uint32)
@@ -147,6 +151,20 @@ func (c *Core) startForwarding(ctx context.Context) error {
// Type assert to TLS connection and handshake to populate the
// connection state
tlsConn := conn.(*tls.Conn)
// Set a deadline for the handshake. This will cause clients
// that don't successfully auth to be kicked out quickly.
// Cluster connections should be reliable so being marginally
// aggressive here is fine.
err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second))
if err != nil {
if c.logger.IsDebug() {
c.logger.Debug("error setting deadline for cluster connection", "error", err)
}
tlsConn.Close()
continue
}
err = tlsConn.Handshake()
if err != nil {
if c.logger.IsDebug() {
@@ -156,6 +174,16 @@ func (c *Core) startForwarding(ctx context.Context) error {
continue
}
// Now, set it back to unlimited
err = tlsConn.SetDeadline(time.Time{})
if err != nil {
if c.logger.IsDebug() {
c.logger.Debug("error setting deadline for cluster connection", "error", err)
}
tlsConn.Close()
continue
}
switch tlsConn.ConnectionState().NegotiatedProtocol {
case requestForwardingALPN:
if !ha {

View File

@@ -130,7 +130,7 @@ type TokenStore struct {
saltLock sync.RWMutex
salt *salt.Salt
tidyLock *int32
tidyLock *uint32
identityPoliciesDeriverFunc func(string) (*identity.Entity, []string, error)
}
@@ -150,7 +150,7 @@ func NewTokenStore(ctx context.Context, logger log.Logger, c *Core, config *logi
tokensPendingDeletion: &sync.Map{},
saltLock: sync.RWMutex{},
identityPoliciesDeriverFunc: c.fetchEntityAndDerivedPolicies,
tidyLock: new(int32),
tidyLock: new(uint32),
}
if c.policyStore != nil {
@@ -1290,204 +1290,224 @@ func (ts *TokenStore) lookupBySaltedAccessor(ctx context.Context, saltedAccessor
// handleTidy handles the cleaning up of leaked accessor storage entries and
// cleaning up of leases that are associated to tokens that are expired.
func (ts *TokenStore) handleTidy(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var tidyErrors *multierror.Error
if !atomic.CompareAndSwapInt32(ts.tidyLock, 0, 1) {
ts.logger.Warn("tidy operation on tokens is already in progress")
return nil, fmt.Errorf("tidy operation on tokens is already in progress")
if !atomic.CompareAndSwapUint32(ts.tidyLock, 0, 1) {
resp := &logical.Response{}
resp.AddWarning("Tidy operation already in progress.")
return resp, nil
}
defer atomic.CompareAndSwapInt32(ts.tidyLock, 1, 0)
go func() {
defer atomic.StoreUint32(ts.tidyLock, 0)
ts.logger.Info("beginning tidy operation on tokens")
defer ts.logger.Info("finished tidy operation on tokens")
// Don't cancel when the original client request goes away
ctx = context.Background()
// List out all the accessors
saltedAccessorList, err := ts.view.List(ctx, accessorPrefix)
if err != nil {
return nil, errwrap.Wrapf("failed to fetch accessor index entries: {{err}}", err)
}
logger := ts.logger.Named("tidy")
// First, clean up secondary index entries that are no longer valid
parentList, err := ts.view.List(ctx, parentPrefix)
if err != nil {
return nil, errwrap.Wrapf("failed to fetch secondary index entries: {{err}}", err)
}
var tidyErrors *multierror.Error
var countParentEntries, deletedCountParentEntries, countParentList, deletedCountParentList int64
doTidy := func() error {
// Scan through the secondary index entries; if there is an entry
// with the token's salt ID at the end, remove it
for _, parent := range parentList {
countParentEntries++
ts.logger.Info("beginning tidy operation on tokens")
defer ts.logger.Info("finished tidy operation on tokens")
// Get the children
children, err := ts.view.List(ctx, parentPrefix+parent)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read secondary index: {{err}}", err))
continue
}
// First check if the salt ID of the parent exists, and if not mark this so
// that deletion of children later with this loop below applies to all
// children
originalChildrenCount := int64(len(children))
exists, _ := ts.lookupSalted(ctx, strings.TrimSuffix(parent, "/"), true)
if exists == nil {
ts.logger.Debug("deleting invalid parent prefix entry", "index", parentPrefix+parent)
}
var deletedChildrenCount int64
for _, child := range children {
countParentList++
if countParentList%500 == 0 {
ts.logger.Info("checking validity of tokens in secondary index list", "progress", countParentList)
// List out all the accessors
saltedAccessorList, err := ts.view.List(ctx, accessorPrefix)
if err != nil {
return errwrap.Wrapf("failed to fetch accessor index entries: {{err}}", err)
}
// Look up tainted entries so we can be sure that if this isn't
// found, it doesn't exist. Doing the following without locking
// since appropriate locks cannot be held with salted token IDs.
// Also perform deletion if the parent doesn't exist any more.
te, _ := ts.lookupSalted(ctx, child, true)
// If the child entry is not nil, but the parent doesn't exist, then turn
// that child token into an orphan token. Theres no deletion in this case.
if te != nil && exists == nil {
lock := locksutil.LockForKey(ts.tokenLocks, te.ID)
lock.Lock()
te.Parent = ""
err = ts.store(ctx, te)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to convert child token into an orphan token: {{err}}", err))
}
lock.Unlock()
continue
// First, clean up secondary index entries that are no longer valid
parentList, err := ts.view.List(ctx, parentPrefix)
if err != nil {
return errwrap.Wrapf("failed to fetch secondary index entries: {{err}}", err)
}
// Otherwise, if the entry doesn't exist, or if the parent doesn't exist go
// on with the delete on the secondary index
if te == nil || exists == nil {
index := parentPrefix + parent + child
ts.logger.Debug("deleting invalid secondary index", "index", index)
err = ts.view.Delete(ctx, index)
var countParentEntries, deletedCountParentEntries, countParentList, deletedCountParentList int64
// Scan through the secondary index entries; if there is an entry
// with the token's salt ID at the end, remove it
for _, parent := range parentList {
countParentEntries++
// Get the children
children, err := ts.view.List(ctx, parentPrefix+parent)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete secondary index: {{err}}", err))
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read secondary index: {{err}}", err))
continue
}
deletedChildrenCount++
}
}
// Add current children deleted count to the total count
deletedCountParentList += deletedChildrenCount
// N.B.: We don't call delete on the parent prefix since physical.Backend.Delete
// implementations should be in charge of deleting empty prefixes.
// If we deleted all the children, then add that to our deleted parent entries count.
if originalChildrenCount == deletedChildrenCount {
deletedCountParentEntries++
}
}
var countAccessorList,
deletedCountAccessorEmptyToken,
deletedCountAccessorInvalidToken,
deletedCountInvalidTokenInAccessor int64
// First check if the salt ID of the parent exists, and if not mark this so
// that deletion of children later with this loop below applies to all
// children
originalChildrenCount := int64(len(children))
exists, _ := ts.lookupSalted(ctx, strings.TrimSuffix(parent, "/"), true)
if exists == nil {
ts.logger.Debug("deleting invalid parent prefix entry", "index", parentPrefix+parent)
}
// For each of the accessor, see if the token ID associated with it is
// a valid one. If not, delete the leases associated with that token
// and delete the accessor as well.
for _, saltedAccessor := range saltedAccessorList {
countAccessorList++
if countAccessorList%500 == 0 {
ts.logger.Info("checking if accessors contain valid tokens", "progress", countAccessorList)
}
var deletedChildrenCount int64
for _, child := range children {
countParentList++
if countParentList%500 == 0 {
ts.logger.Info("checking validity of tokens in secondary index list", "progress", countParentList)
}
accessorEntry, err := ts.lookupBySaltedAccessor(ctx, saltedAccessor, true)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read the accessor index: {{err}}", err))
continue
}
// Look up tainted entries so we can be sure that if this isn't
// found, it doesn't exist. Doing the following without locking
// since appropriate locks cannot be held with salted token IDs.
// Also perform deletion if the parent doesn't exist any more.
te, _ := ts.lookupSalted(ctx, child, true)
// If the child entry is not nil, but the parent doesn't exist, then turn
// that child token into an orphan token. Theres no deletion in this case.
if te != nil && exists == nil {
lock := locksutil.LockForKey(ts.tokenLocks, te.ID)
lock.Lock()
// A valid accessor storage entry should always have a token ID
// in it. If not, it is an invalid accessor entry and needs to
// be deleted.
if accessorEntry.TokenID == "" {
index := accessorPrefix + saltedAccessor
// If deletion of accessor fails, move on to the next
// item since this is just a best-effort operation
err = ts.view.Delete(ctx, index)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete the accessor index: {{err}}", err))
continue
}
deletedCountAccessorEmptyToken++
}
lock := locksutil.LockForKey(ts.tokenLocks, accessorEntry.TokenID)
lock.RLock()
// Look up tainted variants so we only find entries that truly don't
// exist
saltedID, err := ts.SaltID(ctx, accessorEntry.TokenID)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read salt id: {{err}}", err))
lock.RUnlock()
continue
}
te, err := ts.lookupSalted(ctx, saltedID, true)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to lookup tainted ID: {{err}}", err))
lock.RUnlock()
continue
}
lock.RUnlock()
// If token entry is not found assume that the token is not valid any
// more and conclude that accessor, leases, and secondary index entries
// for this token should not exist as well.
if te == nil {
ts.logger.Info("deleting token with nil entry", "salted_token", saltedID)
// RevokeByToken expects a '*logical.TokenEntry'. For the
// purposes of tidying, it is sufficient if the token
// entry only has ID set.
tokenEntry := &logical.TokenEntry{
ID: accessorEntry.TokenID,
te.Parent = ""
err = ts.store(ctx, te)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to convert child token into an orphan token: {{err}}", err))
}
lock.Unlock()
continue
}
// Otherwise, if the entry doesn't exist, or if the parent doesn't exist go
// on with the delete on the secondary index
if te == nil || exists == nil {
index := parentPrefix + parent + child
ts.logger.Debug("deleting invalid secondary index", "index", index)
err = ts.view.Delete(ctx, index)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete secondary index: {{err}}", err))
continue
}
deletedChildrenCount++
}
}
// Add current children deleted count to the total count
deletedCountParentList += deletedChildrenCount
// N.B.: We don't call delete on the parent prefix since physical.Backend.Delete
// implementations should be in charge of deleting empty prefixes.
// If we deleted all the children, then add that to our deleted parent entries count.
if originalChildrenCount == deletedChildrenCount {
deletedCountParentEntries++
}
}
// Attempt to revoke the token. This will also revoke
// the leases associated with the token.
err := ts.expiration.RevokeByToken(tokenEntry)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to revoke leases of expired token: {{err}}", err))
continue
}
deletedCountInvalidTokenInAccessor++
var countAccessorList,
deletedCountAccessorEmptyToken,
deletedCountAccessorInvalidToken,
deletedCountInvalidTokenInAccessor int64
index := accessorPrefix + saltedAccessor
// For each of the accessor, see if the token ID associated with it is
// a valid one. If not, delete the leases associated with that token
// and delete the accessor as well.
for _, saltedAccessor := range saltedAccessorList {
countAccessorList++
if countAccessorList%500 == 0 {
ts.logger.Info("checking if accessors contain valid tokens", "progress", countAccessorList)
}
// If deletion of accessor fails, move on to the next item since
// this is just a best-effort operation. We do this last so that on
// next run if something above failed we still have the accessor
// entry to try again.
err = ts.view.Delete(ctx, index)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete accessor entry: {{err}}", err))
continue
accessorEntry, err := ts.lookupBySaltedAccessor(ctx, saltedAccessor, true)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read the accessor index: {{err}}", err))
continue
}
// A valid accessor storage entry should always have a token ID
// in it. If not, it is an invalid accessor entry and needs to
// be deleted.
if accessorEntry.TokenID == "" {
index := accessorPrefix + saltedAccessor
// If deletion of accessor fails, move on to the next
// item since this is just a best-effort operation
err = ts.view.Delete(ctx, index)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete the accessor index: {{err}}", err))
continue
}
deletedCountAccessorEmptyToken++
}
lock := locksutil.LockForKey(ts.tokenLocks, accessorEntry.TokenID)
lock.RLock()
// Look up tainted variants so we only find entries that truly don't
// exist
saltedID, err := ts.SaltID(ctx, accessorEntry.TokenID)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read salt id: {{err}}", err))
lock.RUnlock()
continue
}
te, err := ts.lookupSalted(ctx, saltedID, true)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to lookup tainted ID: {{err}}", err))
lock.RUnlock()
continue
}
lock.RUnlock()
// If token entry is not found assume that the token is not valid any
// more and conclude that accessor, leases, and secondary index entries
// for this token should not exist as well.
if te == nil {
ts.logger.Info("deleting token with nil entry", "salted_token", saltedID)
// RevokeByToken expects a '*logical.TokenEntry'. For the
// purposes of tidying, it is sufficient if the token
// entry only has ID set.
tokenEntry := &logical.TokenEntry{
ID: accessorEntry.TokenID,
}
// Attempt to revoke the token. This will also revoke
// the leases associated with the token.
err := ts.expiration.RevokeByToken(tokenEntry)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to revoke leases of expired token: {{err}}", err))
continue
}
deletedCountInvalidTokenInAccessor++
index := accessorPrefix + saltedAccessor
// If deletion of accessor fails, move on to the next item since
// this is just a best-effort operation. We do this last so that on
// next run if something above failed we still have the accessor
// entry to try again.
err = ts.view.Delete(ctx, index)
if err != nil {
tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete accessor entry: {{err}}", err))
continue
}
deletedCountAccessorInvalidToken++
}
}
deletedCountAccessorInvalidToken++
ts.logger.Info("number of entries scanned in parent prefix", "count", countParentEntries)
ts.logger.Info("number of entries deleted in parent prefix", "count", deletedCountParentEntries)
ts.logger.Info("number of tokens scanned in parent index list", "count", countParentList)
ts.logger.Info("number of tokens revoked in parent index list", "count", deletedCountParentList)
ts.logger.Info("number of accessors scanned", "count", countAccessorList)
ts.logger.Info("number of deleted accessors which had empty tokens", "count", deletedCountAccessorEmptyToken)
ts.logger.Info("number of revoked tokens which were invalid but present in accessors", "count", deletedCountInvalidTokenInAccessor)
ts.logger.Info("number of deleted accessors which had invalid tokens", "count", deletedCountAccessorInvalidToken)
return tidyErrors.ErrorOrNil()
}
}
ts.logger.Info("number of entries scanned in parent prefix", "count", countParentEntries)
ts.logger.Info("number of entries deleted in parent prefix", "count", deletedCountParentEntries)
ts.logger.Info("number of tokens scanned in parent index list", "count", countParentList)
ts.logger.Info("number of tokens revoked in parent index list", "count", deletedCountParentList)
ts.logger.Info("number of accessors scanned", "count", countAccessorList)
ts.logger.Info("number of deleted accessors which had empty tokens", "count", deletedCountAccessorEmptyToken)
ts.logger.Info("number of revoked tokens which were invalid but present in accessors", "count", deletedCountInvalidTokenInAccessor)
ts.logger.Info("number of deleted accessors which had invalid tokens", "count", deletedCountAccessorInvalidToken)
if err := doTidy(); err != nil {
logger.Error("error running tidy", "error", err)
return
}
}()
return nil, tidyErrors.ErrorOrNil()
resp := &logical.Response{}
resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.")
return resp, nil
}
// handleUpdateLookupAccessor handles the auth/token/lookup-accessor path for returning

View File

@@ -3777,6 +3777,9 @@ func TestTokenStore_HandleTidyCase1(t *testing.T) {
t.Fatalf("err:%v resp:%v", err, resp)
}
// Tidy runs async so give it time
time.Sleep(10 * time.Second)
// Tidy should have removed all the dangling accessor entries
resp, err = ts.HandleRequest(context.Background(), accessorListReq)
if err != nil || (resp != nil && resp.IsError()) {
@@ -3909,6 +3912,9 @@ func TestTokenStore_HandleTidy_parentCleanup(t *testing.T) {
t.Fatalf("err:%v resp:%v", err, resp)
}
// Tidy runs async so give it time
time.Sleep(10 * time.Second)
// Tidy should have removed all the dangling accessor entries
resp, err = ts.HandleRequest(context.Background(), accessorListReq)
if err != nil || (resp != nil && resp.IsError()) {