Push a lot of logic into Router to make a bunch of it nicer and enable a

lot of cleanup. Plumb config and calls to framework.Backend.Setup() into
logical_system and elsewhere, including tests.
This commit is contained in:
Jeff Mitchell
2015-09-04 16:58:12 -04:00
parent 76c18762aa
commit 3e713c61ac
20 changed files with 268 additions and 231 deletions

2
.gitignore vendored
View File

@@ -49,6 +49,8 @@ Vagrantfile
dist/* dist/*
tags
# Editor backups # Editor backups
*~ *~
*.sw[a-z] *.sw[a-z]

View File

@@ -8,7 +8,7 @@ import (
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
) )
// RemountCommand is a Command that remounts a mounted secret backend // MountTuneCommand is a Command that remounts a mounted secret backend
// to a new endpoint. // to a new endpoint.
type MountTuneCommand struct { type MountTuneCommand struct {
Meta Meta

View File

@@ -66,7 +66,7 @@ func (c *Core) enableCredential(entry *MountEntry) error {
view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/") view := NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
// Create the new backend // Create the new backend
backend, err := c.newCredentialBackend(entry.Type, view, nil) backend, err := c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -81,7 +81,7 @@ func (c *Core) enableCredential(entry *MountEntry) error {
// Mount the backend // Mount the backend
path := credentialRoutePrefix + entry.Path path := credentialRoutePrefix + entry.Path
if err := c.router.Mount(backend, path, entry.UUID, view); err != nil { if err := c.router.Mount(backend, path, entry, view); err != nil {
return err return err
} }
c.logger.Printf("[INFO] core: enabled credential backend '%s' type: %s", c.logger.Printf("[INFO] core: enabled credential backend '%s' type: %s",
@@ -242,7 +242,7 @@ func (c *Core) setupCredentials() error {
view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/") view = NewBarrierView(c.barrier, credentialBarrierPrefix+entry.UUID+"/")
// Initialize the backend // Initialize the backend
backend, err = c.newCredentialBackend(entry.Type, view, nil) backend, err = c.newCredentialBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil { if err != nil {
c.logger.Printf( c.logger.Printf(
"[ERR] core: failed to create credential entry %#v: %v", "[ERR] core: failed to create credential entry %#v: %v",
@@ -252,7 +252,7 @@ func (c *Core) setupCredentials() error {
// Mount the backend // Mount the backend
path := credentialRoutePrefix + entry.Path path := credentialRoutePrefix + entry.Path
err = c.router.Mount(backend, path, entry.UUID, view) err = c.router.Mount(backend, path, entry, view)
if err != nil { if err != nil {
c.logger.Printf("[ERR] core: failed to mount auth entry %#v: %v", entry, err) c.logger.Printf("[ERR] core: failed to mount auth entry %#v: %v", entry, err)
return loadAuthFailed return loadAuthFailed
@@ -281,7 +281,7 @@ func (c *Core) teardownCredentials() error {
// newCredentialBackend is used to create and configure a new credential backend by name // newCredentialBackend is used to create and configure a new credential backend by name
func (c *Core) newCredentialBackend( func (c *Core) newCredentialBackend(
t string, view logical.Storage, conf map[string]string) (logical.Backend, error) { t string, sysView logical.SystemView, view logical.Storage, conf map[string]string) (logical.Backend, error) {
f, ok := c.credentialBackends[t] f, ok := c.credentialBackends[t]
if !ok { if !ok {
return nil, fmt.Errorf("unknown backend type: %s", t) return nil, fmt.Errorf("unknown backend type: %s", t)
@@ -291,12 +291,14 @@ func (c *Core) newCredentialBackend(
View: view, View: view,
Logger: c.logger, Logger: c.logger,
Config: conf, Config: conf,
System: sysView,
} }
b, err := f(config) b, err := f(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return b, nil return b, nil
} }

View File

@@ -220,8 +220,8 @@ type Core struct {
// out into the configured audit backends // out into the configured audit backends
auditBroker *AuditBroker auditBroker *AuditBroker
// systemView is the barrier view for the system backend // systemBarrierView is the barrier view for the system backend
systemView *BarrierView systemBarrierView *BarrierView
// expiration manager is used for managing LeaseIDs, // expiration manager is used for managing LeaseIDs,
// renewal, expiration and revocation // renewal, expiration and revocation
@@ -351,8 +351,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
logicalBackends[k] = f logicalBackends[k] = f
} }
logicalBackends["generic"] = PassthroughBackendFactory logicalBackends["generic"] = PassthroughBackendFactory
logicalBackends["system"] = func(*logical.BackendConfig) (logical.Backend, error) { logicalBackends["system"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return NewSystemBackend(c), nil return NewSystemBackend(c, config), nil
} }
c.logicalBackends = logicalBackends c.logicalBackends = logicalBackends
@@ -360,8 +360,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
for k, f := range conf.CredentialBackends { for k, f := range conf.CredentialBackends {
credentialBackends[k] = f credentialBackends[k] = f
} }
credentialBackends["token"] = func(*logical.BackendConfig) (logical.Backend, error) { credentialBackends["token"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return NewTokenStore(c) return NewTokenStore(c, config)
} }
c.credentialBackends = credentialBackends c.credentialBackends = credentialBackends
@@ -478,9 +478,9 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r
// We exclude renewal of a lease, since it does not need to be re-registered // We exclude renewal of a lease, since it does not need to be re-registered
if resp != nil && resp.Secret != nil && !strings.HasPrefix(req.Path, "sys/renew/") { if resp != nil && resp.Secret != nil && !strings.HasPrefix(req.Path, "sys/renew/") {
// Get the SystemView for the mount // Get the SystemView for the mount
sysView, err := c.sysViewByPath(req.Path) sysView := c.router.MatchingSystemView(req.Path)
if err != nil { if sysView == nil {
c.logger.Println(err) c.logger.Println("[ERR] core: unable to retrieve system view from router")
return nil, auth, ErrInternalError return nil, auth, ErrInternalError
} }

View File

@@ -0,0 +1,54 @@
package vault
import (
"fmt"
"strings"
"time"
)
type dynamicSystemView struct {
core *Core
path string
}
func (d dynamicSystemView) DefaultLeaseTTL() (time.Duration, error) {
def, _, err := d.fetchTTLs()
if err != nil {
return 0, err
}
return def, nil
}
func (d dynamicSystemView) MaxLeaseTTL() (time.Duration, error) {
_, max, err := d.fetchTTLs()
if err != nil {
return 0, err
}
return max, nil
}
// TTLsByPath returns the default and max TTLs corresponding to a particular
// mount point, or the system default
func (d dynamicSystemView) fetchTTLs() (def, max time.Duration, retErr error) {
// Ensure we end the path in a slash
if !strings.HasSuffix(d.path, "/") {
d.path += "/"
}
me := d.core.router.MatchingMountEntry(d.path)
if me == nil {
return 0, 0, fmt.Errorf("[ERR] core: failed to get mount entry for %s", d.path)
}
def = d.core.defaultLeaseTTL
max = d.core.maxLeaseTTL
if me.Config.DefaultLeaseTTL != nil && *me.Config.DefaultLeaseTTL != 0 {
def = *me.Config.DefaultLeaseTTL
}
if me.Config.MaxLeaseTTL != nil && *me.Config.MaxLeaseTTL != 0 {
max = *me.Config.MaxLeaseTTL
}
return
}

View File

@@ -78,7 +78,7 @@ func NewExpirationManager(router *Router, view *BarrierView, ts *TokenStore, log
// initialize the expiration manager // initialize the expiration manager
func (c *Core) setupExpiration() error { func (c *Core) setupExpiration() error {
// Create a sub-view // Create a sub-view
view := c.systemView.SubView(expirationSubPath) view := c.systemBarrierView.SubView(expirationSubPath)
// Create the manager // Create the manager
mgr := NewExpirationManager(c.router, view, c.tokenStore, c.logger) mgr := NewExpirationManager(c.router, view, c.tokenStore, c.logger)

View File

@@ -22,7 +22,7 @@ func TestExpiration_Restore(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
paths := []string{ paths := []string{
"prod/aws/foo", "prod/aws/foo",
@@ -175,7 +175,7 @@ func TestExpiration_Revoke(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{ req := &logical.Request{
Operation: logical.ReadOperation, Operation: logical.ReadOperation,
@@ -213,7 +213,7 @@ func TestExpiration_RevokeOnExpire(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{ req := &logical.Request{
Operation: logical.ReadOperation, Operation: logical.ReadOperation,
@@ -262,7 +262,7 @@ func TestExpiration_RevokePrefix(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
paths := []string{ paths := []string{
"prod/aws/foo", "prod/aws/foo",
@@ -322,7 +322,7 @@ func TestExpiration_RevokeByToken(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
paths := []string{ paths := []string{
"prod/aws/foo", "prod/aws/foo",
@@ -441,7 +441,7 @@ func TestExpiration_Renew(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{ req := &logical.Request{
Operation: logical.ReadOperation, Operation: logical.ReadOperation,
@@ -503,7 +503,7 @@ func TestExpiration_Renew_NotRenewable(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{ req := &logical.Request{
Operation: logical.ReadOperation, Operation: logical.ReadOperation,
@@ -545,7 +545,7 @@ func TestExpiration_Renew_RevokeOnExpire(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "prod/aws/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
req := &logical.Request{ req := &logical.Request{
Operation: logical.ReadOperation, Operation: logical.ReadOperation,
@@ -613,7 +613,7 @@ func TestExpiration_revokeEntry(t *testing.T) {
noop := &NoopBackend{} noop := &NoopBackend{}
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "", uuid.GenerateUUID(), view) exp.router.Mount(noop, "", &MountEntry{UUID: uuid.GenerateUUID()}, view)
le := &leaseEntry{ le := &leaseEntry{
LeaseID: "foo/bar/1234", LeaseID: "foo/bar/1234",
@@ -702,7 +702,7 @@ func TestExpiration_renewEntry(t *testing.T) {
} }
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
exp.router.Mount(noop, "", uuid.GenerateUUID(), view) exp.router.Mount(noop, "", &MountEntry{UUID: uuid.GenerateUUID()}, view)
le := &leaseEntry{ le := &leaseEntry{
LeaseID: "foo/bar/1234", LeaseID: "foo/bar/1234",
@@ -764,7 +764,7 @@ func TestExpiration_renewAuthEntry(t *testing.T) {
} }
_, barrier, _ := mockBarrier(t) _, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "auth/foo/") view := NewBarrierView(barrier, "auth/foo/")
exp.router.Mount(noop, "auth/foo/", uuid.GenerateUUID(), view) exp.router.Mount(noop, "auth/foo/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
le := &leaseEntry{ le := &leaseEntry{
LeaseID: "auth/foo/1234", LeaseID: "auth/foo/1234",

View File

@@ -11,7 +11,7 @@ import (
) )
// logical.Factory // logical.Factory
func PassthroughBackendFactory(*logical.BackendConfig) (logical.Backend, error) { func PassthroughBackendFactory(conf *logical.BackendConfig) (logical.Backend, error) {
var b PassthroughBackend var b PassthroughBackend
b.Backend = &framework.Backend{ b.Backend = &framework.Backend{
Help: strings.TrimSpace(passthroughHelp), Help: strings.TrimSpace(passthroughHelp),
@@ -53,6 +53,11 @@ func PassthroughBackendFactory(*logical.BackendConfig) (logical.Backend, error)
}, },
} }
if conf == nil {
return nil, fmt.Errorf("Configuation passed into backend is nil")
}
b.Backend.Setup(conf)
return b, nil return b, nil
} }

View File

@@ -176,6 +176,12 @@ func TestPassthroughBackend_List(t *testing.T) {
} }
func testPassthroughBackend() logical.Backend { func testPassthroughBackend() logical.Backend {
b, _ := PassthroughBackendFactory(nil) b, _ := PassthroughBackendFactory(&logical.BackendConfig{
Logger: nil,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 30,
},
})
return b return b
} }

View File

@@ -20,10 +20,11 @@ var (
} }
) )
func NewSystemBackend(core *Core) logical.Backend { func NewSystemBackend(core *Core, config *logical.BackendConfig) logical.Backend {
b := &SystemBackend{ b := &SystemBackend{
Core: core, Core: core,
} }
b.Backend = &framework.Backend{ b.Backend = &framework.Backend{
Help: strings.TrimSpace(sysHelpRoot), Help: strings.TrimSpace(sysHelpRoot),
@@ -346,6 +347,9 @@ func NewSystemBackend(core *Core) logical.Backend {
}, },
}, },
} }
b.Backend.Setup(config)
return b.Backend return b.Backend
} }
@@ -486,9 +490,26 @@ func (b *SystemBackend) handleMountConfig(
logical.ErrInvalidRequest logical.ErrInvalidRequest
} }
def, max, err := b.Core.TTLsByPath(path) if !strings.HasSuffix(path, "/") {
path += "/"
}
sysView := b.Core.router.MatchingSystemView(path)
if sysView == nil {
err := fmt.Errorf("[ERR] sys: cannot fetch sysview for path %s", path)
b.Backend.Logger().Print(err)
return handleError(err)
}
def, err := sysView.DefaultLeaseTTL()
if err != nil { if err != nil {
b.Backend.Logger().Printf("[ERR] sys: fetching config of path '%s' failed: %v", path, err) b.Backend.Logger().Printf("[ERR] sys: fetching config default TTL of path '%s' failed: %v", path, err)
return handleError(err)
}
max, err := sysView.MaxLeaseTTL()
if err != nil {
b.Backend.Logger().Printf("[ERR] sys: fetching config max TTL of path '%s' failed: %v", path, err)
return handleError(err) return handleError(err)
} }
@@ -516,6 +537,10 @@ func (b *SystemBackend) handleMountTune(
logical.ErrInvalidRequest logical.ErrInvalidRequest
} }
if !strings.HasSuffix(path, "/") {
path += "/"
}
var config MountConfig var config MountConfig
configMap := data.Get("config").(map[string]interface{}) configMap := data.Get("config").(map[string]interface{})
if configMap == nil || len(configMap) == 0 { if configMap == nil || len(configMap) == 0 {

View File

@@ -760,10 +760,24 @@ func TestSystemBackend_rotate(t *testing.T) {
func testSystemBackend(t *testing.T) logical.Backend { func testSystemBackend(t *testing.T) logical.Backend {
c, _, _ := TestCoreUnsealed(t) c, _, _ := TestCoreUnsealed(t)
return NewSystemBackend(c) bc := &logical.BackendConfig{
Logger: c.logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 30,
},
}
return NewSystemBackend(c, bc)
} }
func testCoreSystemBackend(t *testing.T) (*Core, logical.Backend, string) { func testCoreSystemBackend(t *testing.T) (*Core, logical.Backend, string) {
c, _, root := TestCoreUnsealed(t) c, _, root := TestCoreUnsealed(t)
return c, NewSystemBackend(c), root bc := &logical.BackendConfig{
Logger: c.logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 30,
},
}
return c, NewSystemBackend(c, bc), root
} }

View File

@@ -40,27 +40,6 @@ var (
} }
) )
type dynamicSystemView struct {
core *Core
path string
}
func (d dynamicSystemView) DefaultLeaseTTL() (time.Duration, error) {
def, _, err := d.core.TTLsByPath(d.path)
if err != nil {
return 0, err
}
return def, nil
}
func (d dynamicSystemView) MaxLeaseTTL() (time.Duration, error) {
_, max, err := d.core.TTLsByPath(d.path)
if err != nil {
return 0, err
}
return max, nil
}
// MountTable is used to represent the internal mount table // MountTable is used to represent the internal mount table
type MountTable struct { type MountTable struct {
// This lock should be held whenever modifying the Entries field. // This lock should be held whenever modifying the Entries field.
@@ -185,12 +164,7 @@ func (c *Core) mount(me *MountEntry) error {
me.UUID = uuid.GenerateUUID() me.UUID = uuid.GenerateUUID()
view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/") view := NewBarrierView(c.barrier, backendBarrierPrefix+me.UUID+"/")
// Create the new backend backend, err := c.newLogicalBackend(me.Type, c.mountEntrySysView(me), view, nil)
sysView, err := c.mountEntrySysView(me)
if err != nil {
return err
}
backend, err := c.newLogicalBackend(me.Type, sysView, view, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -204,7 +178,7 @@ func (c *Core) mount(me *MountEntry) error {
c.mounts = newTable c.mounts = newTable
// Mount the backend // Mount the backend
if err := c.router.Mount(backend, me.Path, me.UUID, view); err != nil { if err := c.router.Mount(backend, me.Path, me, view); err != nil {
return err return err
} }
c.logger.Printf("[INFO] core: mounted '%s' type: %s", me.Path, me.Type) c.logger.Printf("[INFO] core: mounted '%s' type: %s", me.Path, me.Type)
@@ -394,51 +368,44 @@ func (c *Core) tuneMount(path string, config MountConfig) error {
// Prevent protected paths from being changed // Prevent protected paths from being changed
for _, p := range protectedMounts { for _, p := range protectedMounts {
if strings.HasPrefix(path, p) { if strings.HasPrefix(path, p) {
return fmt.Errorf("cannot tune '%s'", path) return fmt.Errorf("[ERR] core: cannot tune '%s'", path)
} }
} }
// Verify exact match of the route me := c.router.MatchingMountEntry(path)
match := c.router.MatchingMount(path) if me == nil {
if match == "" || path != match { return fmt.Errorf("[ERR] core: no matching mount at '%s'", path)
return fmt.Errorf("no matching mount at '%s'", path)
} }
// Find and modify mount
for _, ent := range c.mounts.Entries {
if ent.Path == path {
if config.MaxLeaseTTL != nil { if config.MaxLeaseTTL != nil {
if *ent.Config.DefaultLeaseTTL != 0 { if *me.Config.DefaultLeaseTTL != 0 {
if *config.MaxLeaseTTL < *ent.Config.DefaultLeaseTTL { if *config.MaxLeaseTTL < *me.Config.DefaultLeaseTTL {
return fmt.Errorf("Given backend max lease TTL of %d less than backend default lease TTL of %d", return fmt.Errorf("Given backend max lease TTL of %d less than backend default lease TTL of %d",
*config.MaxLeaseTTL, *ent.Config.DefaultLeaseTTL) *config.MaxLeaseTTL, *me.Config.DefaultLeaseTTL)
} }
} }
if *config.MaxLeaseTTL == 0 { if *config.MaxLeaseTTL == 0 {
*ent.Config.MaxLeaseTTL = 0 *me.Config.MaxLeaseTTL = 0
} else { } else {
ent.Config.MaxLeaseTTL = config.MaxLeaseTTL me.Config.MaxLeaseTTL = config.MaxLeaseTTL
} }
} }
if config.DefaultLeaseTTL != nil { if config.DefaultLeaseTTL != nil {
if *ent.Config.MaxLeaseTTL == 0 { if *me.Config.MaxLeaseTTL == 0 {
if *config.DefaultLeaseTTL > c.maxLeaseTTL { if *config.DefaultLeaseTTL > c.maxLeaseTTL {
return fmt.Errorf("Given default lease TTL of %d greater than system default lease TTL of %d", return fmt.Errorf("Given default lease TTL of %d greater than system default lease TTL of %d",
*config.DefaultLeaseTTL, c.maxLeaseTTL) *config.DefaultLeaseTTL, c.maxLeaseTTL)
} }
} else { } else {
if *ent.Config.MaxLeaseTTL != 0 && *ent.Config.MaxLeaseTTL < *config.DefaultLeaseTTL { if *me.Config.MaxLeaseTTL != 0 && *me.Config.MaxLeaseTTL < *config.DefaultLeaseTTL {
return fmt.Errorf("Given default lease TTL of %d greater than backend max lease TTL of %d", return fmt.Errorf("Given default lease TTL of %d greater than backend max lease TTL of %d",
*config.DefaultLeaseTTL, *ent.Config.MaxLeaseTTL) *config.DefaultLeaseTTL, *me.Config.MaxLeaseTTL)
} }
} }
if *config.DefaultLeaseTTL == 0 { if *config.DefaultLeaseTTL == 0 {
*ent.Config.DefaultLeaseTTL = 0 *me.Config.DefaultLeaseTTL = 0
} else { } else {
ent.Config.DefaultLeaseTTL = config.DefaultLeaseTTL me.Config.DefaultLeaseTTL = config.DefaultLeaseTTL
}
}
break
} }
} }
@@ -508,6 +475,7 @@ func (c *Core) persistMounts(table *MountTable) error {
func (c *Core) setupMounts() error { func (c *Core) setupMounts() error {
var backend logical.Backend var backend logical.Backend
var view *BarrierView var view *BarrierView
var err error
for _, entry := range c.mounts.Entries { for _, entry := range c.mounts.Entries {
// Initialize the backend, special casing for system // Initialize the backend, special casing for system
barrierPath := backendBarrierPrefix + entry.UUID + "/" barrierPath := backendBarrierPrefix + entry.UUID + "/"
@@ -520,11 +488,7 @@ func (c *Core) setupMounts() error {
// Initialize the backend // Initialize the backend
// Create the new backend // Create the new backend
sysView, err := c.mountEntrySysView(entry) backend, err = c.newLogicalBackend(entry.Type, c.mountEntrySysView(entry), view, nil)
if err != nil {
return err
}
backend, err = c.newLogicalBackend(entry.Type, sysView, view, nil)
if err != nil { if err != nil {
c.logger.Printf( c.logger.Printf(
"[ERR] core: failed to create mount entry %#v: %v", "[ERR] core: failed to create mount entry %#v: %v",
@@ -533,11 +497,11 @@ func (c *Core) setupMounts() error {
} }
if entry.Type == "system" { if entry.Type == "system" {
c.systemView = view c.systemBarrierView = view
} }
// Mount the backend // Mount the backend
err = c.router.Mount(backend, entry.Path, entry.UUID, view) err = c.router.Mount(backend, entry.Path, entry, view)
if err != nil { if err != nil {
c.logger.Printf("[ERR] core: failed to mount entry %#v: %v", entry, err) c.logger.Printf("[ERR] core: failed to mount entry %#v: %v", entry, err)
return errLoadMountsFailed return errLoadMountsFailed
@@ -556,7 +520,7 @@ func (c *Core) setupMounts() error {
func (c *Core) unloadMounts() error { func (c *Core) unloadMounts() error {
c.mounts = nil c.mounts = nil
c.router = NewRouter() c.router = NewRouter()
c.systemView = nil c.systemBarrierView = nil
return nil return nil
} }
@@ -582,82 +546,13 @@ func (c *Core) newLogicalBackend(t string, sysView logical.SystemView, view logi
} }
// mountEntrySysView creates a logical.SystemView from global and // mountEntrySysView creates a logical.SystemView from global and
// mount-specific entries // mount-specific entries; because this should be called when setting
func (c *Core) mountEntrySysView(me *MountEntry) (logical.SystemView, error) { // up a mountEntry, it doesn't check to ensure that me is not nil
if me == nil { func (c *Core) mountEntrySysView(me *MountEntry) logical.SystemView {
return nil, fmt.Errorf("[ERR] core: nil MountEntry when generating SystemView") return dynamicSystemView{
}
sysView := dynamicSystemView{
core: c, core: c,
path: me.Path, path: me.Path,
} }
return sysView, nil
}
// sysViewByPath is a simple helper for MountEntrySysView
func (c *Core) sysViewByPath(path string) (logical.SystemView, error) {
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
}
me, err := c.mountEntryByPath(path)
if err != nil {
return nil, err
}
return c.mountEntrySysView(me)
}
// mountEntryByPath searches across all tables to find the MountEntry
func (c *Core) mountEntryByPath(path string) (*MountEntry, error) {
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
}
pathSep := strings.IndexRune(path, '/')
if pathSep == -1 {
return nil, fmt.Errorf("[ERR] core: failed to find separator for path %s", path)
}
me := c.mounts.Find(path[0 : pathSep+1])
if me == nil {
me = c.auth.Find(path[0 : pathSep+1])
}
if me == nil {
me = c.audit.Find(path[0 : pathSep+1])
}
if me == nil {
return nil, fmt.Errorf("[ERR] core: failed to find mount entry for path %s", path)
}
return me, nil
}
// TTLsByPath returns the default and max TTLs corresponding to a particular
// mount point, or the system default
func (c *Core) TTLsByPath(path string) (def, max time.Duration, retErr error) {
// Ensure we end the path in a slash
if !strings.HasSuffix(path, "/") {
path += "/"
}
me, err := c.mountEntryByPath(path)
if err != nil {
return 0, 0, err
}
def = c.defaultLeaseTTL
max = c.maxLeaseTTL
if me.Config.DefaultLeaseTTL != nil && *me.Config.DefaultLeaseTTL != 0 {
def = *me.Config.DefaultLeaseTTL
}
if me.Config.MaxLeaseTTL != nil && *me.Config.MaxLeaseTTL != 0 {
max = *me.Config.MaxLeaseTTL
}
return
} }
// defaultMountTable creates a default mount table // defaultMountTable creates a default mount table

View File

@@ -192,7 +192,7 @@ func TestCore_Unmount_Cleanup(t *testing.T) {
func TestCore_Remount(t *testing.T) { func TestCore_Remount(t *testing.T) {
c, key, _ := TestCoreUnsealed(t) c, key, _ := TestCoreUnsealed(t)
err := c.remount("secret", "foo", MountConfig{}) err := c.remount("secret", "foo")
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -280,7 +280,7 @@ func TestCore_Remount_Cleanup(t *testing.T) {
} }
// Remount, this should cleanup // Remount, this should cleanup
if err := c.remount("test/", "new/", MountConfig{}); err != nil { if err := c.remount("test/", "new/"); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -309,7 +309,7 @@ func TestCore_Remount_Cleanup(t *testing.T) {
func TestCore_Remount_Protected(t *testing.T) { func TestCore_Remount_Protected(t *testing.T) {
c, _, _ := TestCoreUnsealed(t) c, _, _ := TestCoreUnsealed(t)
err := c.remount("sys", "foo", MountConfig{}) err := c.remount("sys", "foo")
if err.Error() != "cannot remount 'sys/'" { if err.Error() != "cannot remount 'sys/'" {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@@ -46,7 +46,7 @@ func NewPolicyStore(view *BarrierView) *PolicyStore {
// when the vault is being unsealed. // when the vault is being unsealed.
func (c *Core) setupPolicyStore() error { func (c *Core) setupPolicyStore() error {
// Create a sub-view // Create a sub-view
view := c.systemView.SubView(policySubPath) view := c.systemBarrierView.SubView(policySubPath)
// Create the policy store // Create the policy store
c.policy = NewPolicyStore(view) c.policy = NewPolicyStore(view)

View File

@@ -21,7 +21,7 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
Path: "foo", Path: "foo",
}, },
} }
if err := router.Mount(backend, "foo", uuid.GenerateUUID(), nil); err != nil { if err := router.Mount(backend, "foo", &MountEntry{UUID: uuid.GenerateUUID()}, nil); err != nil {
t.Fatalf("err: %s", err) t.Fatalf("err: %s", err)
} }

View File

@@ -26,24 +26,24 @@ func NewRouter() *Router {
return r return r
} }
// mountEntry is used to represent a mount point // routeEntry is used to represent a mount point in the router
type mountEntry struct { type routeEntry struct {
tainted bool tainted bool
salt string
backend logical.Backend backend logical.Backend
mountEntry *MountEntry
view *BarrierView view *BarrierView
rootPaths *radix.Tree rootPaths *radix.Tree
loginPaths *radix.Tree loginPaths *radix.Tree
} }
// SaltID is used to apply a salt and hash to an ID to make sure its not reversable // SaltID is used to apply a salt and hash to an ID to make sure its not reversable
func (me *mountEntry) SaltID(id string) string { func (re *routeEntry) SaltID(id string) string {
return salt.SaltID(me.salt, id, salt.SHA1Hash) return salt.SaltID(re.mountEntry.UUID, id, salt.SHA1Hash)
} }
// Mount is used to expose a logical backend at a given prefix, using a unique salt, // Mount is used to expose a logical backend at a given prefix, using a unique salt,
// and the barrier view for that path. // and the barrier view for that path.
func (r *Router) Mount(backend logical.Backend, prefix, salt string, view *BarrierView) error { func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *MountEntry, view *BarrierView) error {
r.l.Lock() r.l.Lock()
defer r.l.Unlock() defer r.l.Unlock()
@@ -59,14 +59,15 @@ func (r *Router) Mount(backend logical.Backend, prefix, salt string, view *Barri
} }
// Create a mount entry // Create a mount entry
me := &mountEntry{ re := &routeEntry{
tainted: false, tainted: false,
backend: backend, backend: backend,
mountEntry: mountEntry,
view: view, view: view,
rootPaths: pathsToRadix(paths.Root), rootPaths: pathsToRadix(paths.Root),
loginPaths: pathsToRadix(paths.Unauthenticated), loginPaths: pathsToRadix(paths.Unauthenticated),
} }
r.root.Insert(prefix, me) r.root.Insert(prefix, re)
return nil return nil
} }
@@ -91,12 +92,8 @@ func (r *Router) Remount(src, dst string) error {
// Update the mount point // Update the mount point
r.root.Delete(src) r.root.Delete(src)
mountEntry, ok := raw.(*mountEntry) routeEntry := raw.(*routeEntry)
if !ok { dynSysView, ok := routeEntry.backend.System().(dynamicSystemView)
return fmt.Errorf("Unable to retrieve mount entry at '%s'", src)
}
sysView := mountEntry.backend.System()
dynSysView, ok := sysView.(dynamicSystemView)
if ok { if ok {
dynSysView.path = dst dynSysView.path = dst
} }
@@ -111,7 +108,7 @@ func (r *Router) Taint(path string) error {
defer r.l.Unlock() defer r.l.Unlock()
_, raw, ok := r.root.LongestPrefix(path) _, raw, ok := r.root.LongestPrefix(path)
if ok { if ok {
raw.(*mountEntry).tainted = true raw.(*routeEntry).tainted = true
} }
return nil return nil
} }
@@ -122,7 +119,7 @@ func (r *Router) Untaint(path string) error {
defer r.l.Unlock() defer r.l.Unlock()
_, raw, ok := r.root.LongestPrefix(path) _, raw, ok := r.root.LongestPrefix(path)
if ok { if ok {
raw.(*mountEntry).tainted = false raw.(*routeEntry).tainted = false
} }
return nil return nil
} }
@@ -146,7 +143,29 @@ func (r *Router) MatchingView(path string) *BarrierView {
if !ok { if !ok {
return nil return nil
} }
return raw.(*mountEntry).view return raw.(*routeEntry).view
}
// MatchingMountEntry returns the MountEntry used for a path
func (r *Router) MatchingMountEntry(path string) *MountEntry {
r.l.RLock()
_, raw, ok := r.root.LongestPrefix(path)
r.l.RUnlock()
if !ok {
return nil
}
return raw.(*routeEntry).mountEntry
}
// MatchingSystemView returns the SystemView used for a path
func (r *Router) MatchingSystemView(path string) logical.SystemView {
r.l.RLock()
_, raw, ok := r.root.LongestPrefix(path)
r.l.RUnlock()
if !ok {
return nil
}
return raw.(*routeEntry).backend.System()
} }
// Route is used to route a given request // Route is used to route a given request
@@ -166,11 +185,11 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
} }
defer metrics.MeasureSince([]string{"route", string(req.Operation), defer metrics.MeasureSince([]string{"route", string(req.Operation),
strings.Replace(mount, "/", "-", -1)}, time.Now()) strings.Replace(mount, "/", "-", -1)}, time.Now())
me := raw.(*mountEntry) re := raw.(*routeEntry)
// If the path is tainted, we reject any operation except for // If the path is tainted, we reject any operation except for
// Rollback and Revoke // Rollback and Revoke
if me.tainted { if re.tainted {
switch req.Operation { switch req.Operation {
case logical.RevokeOperation, logical.RollbackOperation: case logical.RevokeOperation, logical.RollbackOperation:
default: default:
@@ -190,12 +209,12 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
} }
// Attach the storage view for the request // Attach the storage view for the request
req.Storage = me.view req.Storage = re.view
// Hash the request token unless this is the token backend // Hash the request token unless this is the token backend
clientToken := req.ClientToken clientToken := req.ClientToken
if !strings.HasPrefix(original, "auth/token/") { if !strings.HasPrefix(original, "auth/token/") {
req.ClientToken = me.SaltID(req.ClientToken) req.ClientToken = re.SaltID(req.ClientToken)
} }
// If the request is not a login path, then clear the connection // If the request is not a login path, then clear the connection
@@ -214,7 +233,7 @@ func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
}() }()
// Invoke the backend // Invoke the backend
return me.backend.HandleRequest(req) return re.backend.HandleRequest(req)
} }
// RootPath checks if the given path requires root privileges // RootPath checks if the given path requires root privileges
@@ -225,13 +244,13 @@ func (r *Router) RootPath(path string) bool {
if !ok { if !ok {
return false return false
} }
me := raw.(*mountEntry) re := raw.(*routeEntry)
// Trim to get remaining path // Trim to get remaining path
remain := strings.TrimPrefix(path, mount) remain := strings.TrimPrefix(path, mount)
// Check the rootPaths of this backend // Check the rootPaths of this backend
match, raw, ok := me.rootPaths.LongestPrefix(remain) match, raw, ok := re.rootPaths.LongestPrefix(remain)
if !ok { if !ok {
return false return false
} }
@@ -254,13 +273,13 @@ func (r *Router) LoginPath(path string) bool {
if !ok { if !ok {
return false return false
} }
me := raw.(*mountEntry) re := raw.(*routeEntry)
// Trim to get remaining path // Trim to get remaining path
remain := strings.TrimPrefix(path, mount) remain := strings.TrimPrefix(path, mount)
// Check the loginPaths of this backend // Check the loginPaths of this backend
match, raw, ok := me.loginPaths.LongestPrefix(remain) match, raw, ok := re.loginPaths.LongestPrefix(remain)
if !ok { if !ok {
return false return false
} }

View File

@@ -55,12 +55,12 @@ func TestRouter_Mount(t *testing.T) {
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{} n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
err = r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) err = r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if !strings.Contains(err.Error(), "cannot mount under existing mount") { if !strings.Contains(err.Error(), "cannot mount under existing mount") {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -104,7 +104,7 @@ func TestRouter_Unmount(t *testing.T) {
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{} n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -129,7 +129,7 @@ func TestRouter_Remount(t *testing.T) {
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{} n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -177,7 +177,7 @@ func TestRouter_RootPath(t *testing.T) {
"policy/*", "policy/*",
}, },
} }
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -215,7 +215,7 @@ func TestRouter_LoginPath(t *testing.T) {
"oauth/*", "oauth/*",
}, },
} }
err := r.Mount(n, "auth/foo/", uuid.GenerateUUID(), view) err := r.Mount(n, "auth/foo/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -246,7 +246,7 @@ func TestRouter_Taint(t *testing.T) {
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{} n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -285,7 +285,7 @@ func TestRouter_Untaint(t *testing.T) {
view := NewBarrierView(barrier, "logical/") view := NewBarrierView(barrier, "logical/")
n := &NoopBackend{} n := &NoopBackend{}
err := r.Mount(n, "prod/aws/", uuid.GenerateUUID(), view) err := r.Mount(n, "prod/aws/", &MountEntry{UUID: uuid.GenerateUUID()}, view)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@@ -62,10 +62,12 @@ func TestCore(t *testing.T) *Core {
}, },
} }
noopBackends := make(map[string]logical.Factory) noopBackends := make(map[string]logical.Factory)
noopBackends["noop"] = func(*logical.BackendConfig) (logical.Backend, error) { noopBackends["noop"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return new(framework.Backend), nil b := new(framework.Backend)
b.Setup(config)
return b, nil
} }
noopBackends["http"] = func(*logical.BackendConfig) (logical.Backend, error) { noopBackends["http"] = func(config *logical.BackendConfig) (logical.Backend, error) {
return new(rawHTTP), nil return new(rawHTTP), nil
} }
logicalBackends := make(map[string]logical.Factory) logicalBackends := make(map[string]logical.Factory)

View File

@@ -48,9 +48,9 @@ type TokenStore struct {
// NewTokenStore is used to construct a token store that is // NewTokenStore is used to construct a token store that is
// backed by the given barrier view. // backed by the given barrier view.
func NewTokenStore(c *Core) (*TokenStore, error) { func NewTokenStore(c *Core, config *logical.BackendConfig) (*TokenStore, error) {
// Create a sub-view // Create a sub-view
view := c.systemView.SubView(tokenSubPath) view := c.systemBarrierView.SubView(tokenSubPath)
// Initialize the store // Initialize the store
t := &TokenStore{ t := &TokenStore{
@@ -203,6 +203,8 @@ func NewTokenStore(c *Core) (*TokenStore, error) {
}, },
} }
t.Backend.Setup(config)
return t, nil return t, nil
} }

View File

@@ -10,19 +10,30 @@ import (
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
func getBackendConfig(c *Core) *logical.BackendConfig {
return &logical.BackendConfig{
Logger: c.logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 30,
},
}
}
func mockTokenStore(t *testing.T) (*Core, *TokenStore, string) { func mockTokenStore(t *testing.T) (*Core, *TokenStore, string) {
logger := log.New(os.Stderr, "", log.LstdFlags) logger := log.New(os.Stderr, "", log.LstdFlags)
c, _, root := TestCoreUnsealed(t) c, _, root := TestCoreUnsealed(t)
ts, err := NewTokenStore(c)
ts, err := NewTokenStore(c, getBackendConfig(c))
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
router := NewRouter() router := NewRouter()
router.Mount(ts, "auth/token/", "", ts.view) router.Mount(ts, "auth/token/", &MountEntry{UUID: ""}, ts.view)
view := c.systemView.SubView(expirationSubPath) view := c.systemBarrierView.SubView(expirationSubPath)
exp := NewExpirationManager(router, view, ts, logger) exp := NewExpirationManager(router, view, ts, logger)
ts.SetExpirationManager(exp) ts.SetExpirationManager(exp)
return c, ts, root return c, ts, root
@@ -68,7 +79,7 @@ func TestTokenStore_CreateLookup(t *testing.T) {
} }
// New store should share the salt // New store should share the salt
ts2, err := NewTokenStore(c) ts2, err := NewTokenStore(c, getBackendConfig(c))
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -107,7 +118,7 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) {
} }
// New store should share the salt // New store should share the salt
ts2, err := NewTokenStore(c) ts2, err := NewTokenStore(c, getBackendConfig(c))
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@@ -219,7 +230,7 @@ func TestTokenStore_Revoke_Leases(t *testing.T) {
// Mount a noop backend // Mount a noop backend
noop := &NoopBackend{} noop := &NoopBackend{}
ts.expiration.router.Mount(noop, "", "", nil) ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, nil)
ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}} ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}}
if err := ts.Create(ent); err != nil { if err := ts.Create(ent); err != nil {