diff --git a/vault/expiration.go b/vault/expiration.go index ecc7639bfc..82088190a0 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "path" + "sync" "time" ) @@ -18,8 +19,11 @@ const ( // If a secret is not renewed in timely manner, it may be expired, and // the ExpirationManager will handle doing automatic revocation. type ExpirationManager struct { - router *Router - view *BarrierView + router *Router + view *BarrierView + doneCh chan struct{} + stopCh chan struct{} + stopLock sync.Mutex } // NewExpirationManager creates a new ExpirationManager that is backed @@ -67,21 +71,55 @@ func (c *Core) stopExpiration() error { // Restore is used to recover the lease states when starting. // This is used after starting the vault. func (m *ExpirationManager) Restore() error { + m.stopLock.Lock() + defer m.stopLock.Unlock() + if m.stopCh != nil { + return fmt.Errorf("cannot restore while running") + } + + // TODO: Restore... return nil } // Start is used to continue automatic revocation. This // should only be called when the Vault is unsealed. func (m *ExpirationManager) Start() error { + m.stopLock.Lock() + defer m.stopLock.Unlock() + if m.stopCh == nil { + m.doneCh = make(chan struct{}) + m.stopCh = make(chan struct{}) + go m.run(m.doneCh, m.stopCh) + } return nil } // Stop is used to prevent further automatic revocations. // This must be called before sealing the view. func (m *ExpirationManager) Stop() error { + m.stopLock.Lock() + defer m.stopLock.Unlock() + if m.stopCh != nil { + doneCh := m.doneCh + close(m.stopCh) + m.stopCh = nil + m.doneCh = nil + <-doneCh // Wait for completion + } return nil } +// run is a long running goroutine that manages background expiration +func (m *ExpirationManager) run(doneCh, stopCh chan struct{}) { + defer close(doneCh) + for { + select { + case <-stopCh: + return + } + } +} + // Revoke is used to revoke a secret named by the given vaultID func (m *ExpirationManager) Revoke(vaultID string) error { return nil diff --git a/vault/expiration_test.go b/vault/expiration_test.go index 6cec022764..d1683fa5d6 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -14,6 +14,24 @@ func mockExpiration(t *testing.T) *ExpirationManager { return NewExpirationManager(router, view) } +func TestExpiration_StartStop(t *testing.T) { + exp := mockExpiration(t) + err := exp.Start() + if err != nil { + t.Fatalf("err: %v", err) + } + + err = exp.Restore() + if err.Error() != "cannot restore while running" { + t.Fatalf("err: %v", err) + } + + err = exp.Stop() + if err != nil { + t.Fatalf("err: %v", err) + } +} + func TestExpiration_Register(t *testing.T) { exp := mockExpiration(t) req := &Request{