diff --git a/api/sys_auth.go b/api/sys_auth.go index f9f3c8c2cc..ff479a20e1 100644 --- a/api/sys_auth.go +++ b/api/sys_auth.go @@ -90,6 +90,7 @@ type EnableAuthOptions struct { type AuthMount struct { Type string `json:"type" structs:"type" mapstructure:"type"` Description string `json:"description" structs:"description" mapstructure:"description"` + Accessor string `json:"accessor" structs:"accessor" mapstructure:"accessor"` Config AuthConfigOutput `json:"config" structs:"config" mapstructure:"config"` Local bool `json:"local" structs:"local" mapstructure:"local"` } diff --git a/api/sys_mounts.go b/api/sys_mounts.go index 907fddb704..d358f8d871 100644 --- a/api/sys_mounts.go +++ b/api/sys_mounts.go @@ -135,6 +135,7 @@ type MountConfigInput struct { type MountOutput struct { Type string `json:"type" structs:"type"` Description string `json:"description" structs:"description"` + Accessor string `json:"accessor" structs:"accessor"` Config MountConfigOutput `json:"config" structs:"config"` Local bool `json:"local" structs:"local"` } diff --git a/command/auth.go b/command/auth.go index fc614a10b1..e8ef7e2837 100644 --- a/command/auth.go +++ b/command/auth.go @@ -316,7 +316,7 @@ func (c *AuthCommand) listMethods() int { } sort.Strings(paths) - columns := []string{"Path | Type | Default TTL | Max TTL | Replication Behavior | Description"} + columns := []string{"Path | Type | Accessor | Default TTL | Max TTL | Replication Behavior | Description"} for _, path := range paths { auth := auth[path] defTTL := "system" @@ -332,7 +332,7 @@ func (c *AuthCommand) listMethods() int { replicatedBehavior = "local" } columns = append(columns, fmt.Sprintf( - "%s | %s | %s | %s | %s | %s", path, auth.Type, defTTL, maxTTL, replicatedBehavior, auth.Description)) + "%s | %s | %s | %s | %s | %s | %s", path, auth.Type, auth.Accessor, defTTL, maxTTL, replicatedBehavior, auth.Description)) } c.Ui.Output(columnize.SimpleFormat(columns)) diff --git a/command/mounts.go b/command/mounts.go index d918d67124..2ee1665e12 100644 --- a/command/mounts.go +++ b/command/mounts.go @@ -42,7 +42,7 @@ func (c *MountsCommand) Run(args []string) int { } sort.Strings(paths) - columns := []string{"Path | Type | Default TTL | Max TTL | Force No Cache | Replication Behavior | Description"} + columns := []string{"Path | Type | Accessor | Default TTL | Max TTL | Force No Cache | Replication Behavior | Description"} for _, path := range paths { mount := mounts[path] defTTL := "system" @@ -68,7 +68,7 @@ func (c *MountsCommand) Run(args []string) int { replicatedBehavior = "local" } columns = append(columns, fmt.Sprintf( - "%s | %s | %s | %s | %v | %s | %s", path, mount.Type, defTTL, maxTTL, + "%s | %s | %s | %s | %s | %v | %s | %s", path, mount.Type, mount.Accessor, defTTL, maxTTL, mount.Config.ForceNoCache, replicatedBehavior, mount.Description)) } diff --git a/http/handler_test.go b/http/handler_test.go index 8450a8b6be..41a7a69c7a 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -221,6 +221,13 @@ func TestSysMounts_headerAuth(t *testing.T) { testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad:\nExpected: %#v\nActual: %#v\n", expected, actual) diff --git a/http/sys_auth_test.go b/http/sys_auth_test.go index 9e193916f0..fa3c692b3d 100644 --- a/http/sys_auth_test.go +++ b/http/sys_auth_test.go @@ -49,6 +49,13 @@ func TestSysAuth(t *testing.T) { testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: expected:%#v\nactual:%#v", expected, actual) @@ -120,6 +127,13 @@ func TestSysEnableAuth(t *testing.T) { testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: expected:%#v\nactual:%#v", expected, actual) @@ -176,6 +190,13 @@ func TestSysDisableAuth(t *testing.T) { testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: expected:%#v\nactual:%#v", expected, actual) diff --git a/http/sys_mount_test.go b/http/sys_mount_test.go index 2e12f0f798..d36160c50c 100644 --- a/http/sys_mount_test.go +++ b/http/sys_mount_test.go @@ -91,6 +91,14 @@ func TestSysMounts(t *testing.T) { testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } + if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: %#v", actual) } @@ -204,6 +212,14 @@ func TestSysMount(t *testing.T) { testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } + if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: %#v", actual) } @@ -339,6 +355,14 @@ func TestSysRemount(t *testing.T) { testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } + if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: %#v", actual) } @@ -435,6 +459,14 @@ func TestSysUnmount(t *testing.T) { testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } + if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: %#v", actual) } @@ -548,6 +580,14 @@ func TestSysTuneMount(t *testing.T) { testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } + if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad: %#v", actual) } @@ -683,6 +723,14 @@ func TestSysTuneMount(t *testing.T) { testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) expected["request_id"] = actual["request_id"] + for k, v := range actual["data"].(map[string]interface{}) { + if v.(map[string]interface{})["accessor"] == "" { + t.Fatalf("no accessor from %s", k) + } + expected[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + expected["data"].(map[string]interface{})[k].(map[string]interface{})["accessor"] = v.(map[string]interface{})["accessor"] + } + if !reflect.DeepEqual(actual, expected) { t.Fatalf("bad:\nExpected: %#v\nActual:%#v", expected, actual) } diff --git a/vault/audit.go b/vault/audit.go index 71e2f60aef..829ebba14c 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -79,6 +79,13 @@ func (c *Core) enableAudit(entry *MountEntry) error { } entry.UUID = entryUUID } + if entry.Accessor == "" { + accessor, err := c.generateMountAccessor("audit_" + entry.Type) + if err != nil { + return err + } + entry.Accessor = accessor + } viewPath := auditBarrierPrefix + entry.UUID + "/" view := NewBarrierView(c.barrier, viewPath) @@ -201,6 +208,14 @@ func (c *Core) loadAudits() error { entry.Table = c.audit.Type needPersist = true } + if entry.Accessor == "" { + accessor, err := c.generateMountAccessor("audit_" + entry.Type) + if err != nil { + return err + } + entry.Accessor = accessor + needPersist = true + } } if !needPersist { diff --git a/vault/audit_test.go b/vault/audit_test.go index 76eba563ae..a91298d7ae 100644 --- a/vault/audit_test.go +++ b/vault/audit_test.go @@ -220,16 +220,18 @@ func TestCore_EnableAudit_Local(t *testing.T) { Type: auditTableType, Entries: []*MountEntry{ &MountEntry{ - Table: auditTableType, - Path: "noop/", - Type: "noop", - UUID: "abcd", + Table: auditTableType, + Path: "noop/", + Type: "noop", + UUID: "abcd", + Accessor: "noop-abcd", }, &MountEntry{ - Table: auditTableType, - Path: "noop2/", - Type: "noop", - UUID: "bcde", + Table: auditTableType, + Path: "noop2/", + Type: "noop", + UUID: "bcde", + Accessor: "noop-bcde", }, }, } diff --git a/vault/auth.go b/vault/auth.go index 5a5e68b2f9..e6acaae0b7 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -85,7 +85,13 @@ func (c *Core) enableCredential(entry *MountEntry) error { } entry.UUID = entryUUID } - + if entry.Accessor == "" { + accessor, err := c.generateMountAccessor("auth_" + entry.Type) + if err != nil { + return err + } + entry.Accessor = accessor + } viewPath := credentialBarrierPrefix + entry.UUID + "/" view := NewBarrierView(c.barrier, viewPath) sysView := c.mountEntrySysView(entry) @@ -283,13 +289,21 @@ func (c *Core) loadCredentials() error { entry.Table = c.auth.Type needPersist = true } + if entry.Accessor == "" { + accessor, err := c.generateMountAccessor("auth_" + entry.Type) + if err != nil { + return err + } + entry.Accessor = accessor + needPersist = true + } } if !needPersist { return nil } } else { - c.auth = defaultAuthTable() + c.auth = c.defaultAuthTable() } if err := c.persistAuth(c.auth, false); err != nil { @@ -485,7 +499,7 @@ func (c *Core) newCredentialBackend( } // defaultAuthTable creates a default auth table -func defaultAuthTable() *MountTable { +func (c *Core) defaultAuthTable() *MountTable { table := &MountTable{ Type: credentialTableType, } @@ -493,12 +507,17 @@ func defaultAuthTable() *MountTable { if err != nil { panic(fmt.Sprintf("could not generate UUID for default auth table token entry: %v", err)) } + tokenAccessor, err := c.generateMountAccessor("auth_token") + if err != nil { + panic(fmt.Sprintf("could not generate accessor for default auth table token entry: %v", err)) + } tokenAuth := &MountEntry{ Table: credentialTableType, Path: "token/", Type: "token", Description: "token based credentials", UUID: tokenUUID, + Accessor: tokenAccessor, } table.Entries = append(table.Entries, tokenAuth) return table diff --git a/vault/auth_test.go b/vault/auth_test.go index bc150e965c..d45091defd 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -99,16 +99,18 @@ func TestCore_EnableCredential_Local(t *testing.T) { Type: credentialTableType, Entries: []*MountEntry{ &MountEntry{ - Table: credentialTableType, - Path: "noop/", - Type: "noop", - UUID: "abcd", + Table: credentialTableType, + Path: "noop/", + Type: "noop", + UUID: "abcd", + Accessor: "noop-abcd", }, &MountEntry{ - Table: credentialTableType, - Path: "noop2/", - Type: "noop", - UUID: "bcde", + Table: credentialTableType, + Path: "noop2/", + Type: "noop", + UUID: "bcde", + Accessor: "noop-bcde", }, }, } @@ -347,7 +349,8 @@ func TestCore_DisableCredential_Cleanup(t *testing.T) { } func TestDefaultAuthTable(t *testing.T) { - table := defaultAuthTable() + c, _, _ := TestCoreUnsealed(t) + table := c.defaultAuthTable() verifyDefaultAuthTable(t, table) } diff --git a/vault/logical_system.go b/vault/logical_system.go index 39b69c4c35..0d95605cc3 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1160,6 +1160,7 @@ func (b *SystemBackend) handleMountTable( info := map[string]interface{}{ "type": entry.Type, "description": entry.Description, + "accessor": entry.Accessor, "config": map[string]interface{}{ "default_lease_ttl": int64(entry.Config.DefaultLeaseTTL.Seconds()), "max_lease_ttl": int64(entry.Config.MaxLeaseTTL.Seconds()), @@ -1656,6 +1657,7 @@ func (b *SystemBackend) handleAuthTable( info := map[string]interface{}{ "type": entry.Type, "description": entry.Description, + "accessor": entry.Accessor, "config": map[string]interface{}{ "default_lease_ttl": int64(entry.Config.DefaultLeaseTTL.Seconds()), "max_lease_ttl": int64(entry.Config.MaxLeaseTTL.Seconds()), diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 12aca538ab..3301962e11 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -115,6 +115,7 @@ func TestSystemBackend_mounts(t *testing.T) { "secret/": map[string]interface{}{ "type": "generic", "description": "generic secret storage", + "accessor": resp.Data["secret/"].(map[string]interface{})["accessor"], "config": map[string]interface{}{ "default_lease_ttl": resp.Data["secret/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["secret/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), @@ -125,6 +126,7 @@ func TestSystemBackend_mounts(t *testing.T) { "sys/": map[string]interface{}{ "type": "system", "description": "system endpoints used for control, policy and debugging", + "accessor": resp.Data["sys/"].(map[string]interface{})["accessor"], "config": map[string]interface{}{ "default_lease_ttl": resp.Data["sys/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["sys/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), @@ -135,6 +137,7 @@ func TestSystemBackend_mounts(t *testing.T) { "cubbyhole/": map[string]interface{}{ "description": "per-token private secret storage", "type": "cubbyhole", + "accessor": resp.Data["cubbyhole/"].(map[string]interface{})["accessor"], "config": map[string]interface{}{ "default_lease_ttl": resp.Data["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["default_lease_ttl"].(int64), "max_lease_ttl": resp.Data["cubbyhole/"].(map[string]interface{})["config"].(map[string]interface{})["max_lease_ttl"].(int64), @@ -1113,6 +1116,7 @@ func TestSystemBackend_authTable(t *testing.T) { "token/": map[string]interface{}{ "type": "token", "description": "token based credentials", + "accessor": resp.Data["token/"].(map[string]interface{})["accessor"], "config": map[string]interface{}{ "default_lease_ttl": int64(0), "max_lease_ttl": int64(0), diff --git a/vault/mount.go b/vault/mount.go index 1c921001c1..1983b43894 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -66,6 +66,22 @@ var ( } ) +func (c *Core) generateMountAccessor(entryType string) (string, error) { + var accessor string + for { + randBytes, err := uuid.GenerateRandomBytes(4) + if err != nil { + return "", err + } + accessor = fmt.Sprintf("%s_%s", entryType, fmt.Sprintf("%08x", randBytes[0:4])) + if entry := c.router.MatchingMountByAccessor(accessor); entry == nil { + break + } + } + + return accessor, nil +} + // MountTable is used to represent the internal mount table type MountTable struct { Type string `json:"type"` @@ -139,6 +155,7 @@ type MountEntry struct { Type string `json:"type"` // Logical backend Type Description string `json:"description"` // User-provided description UUID string `json:"uuid"` // Barrier view UUID + Accessor string `json:"accessor"` // Unique but more human-friendly ID. Does not change, not used for any sensitive things (like as a salt, which the UUID sometimes is). Config MountConfig `json:"config"` // Configuration related to this mount (but not backend-derived) Options map[string]string `json:"options"` // Backend options Local bool `json:"local"` // Local mounts are not replicated or affected by replication @@ -164,6 +181,7 @@ func (e *MountEntry) Clone() *MountEntry { Type: e.Type, Description: e.Description, UUID: e.UUID, + Accessor: e.Accessor, Config: e.Config, Options: optClone, Local: e.Local, @@ -208,6 +226,13 @@ func (c *Core) mount(entry *MountEntry) error { } entry.UUID = entryUUID } + if entry.Accessor == "" { + accessor, err := c.generateMountAccessor(entry.Type) + if err != nil { + return err + } + entry.Accessor = accessor + } viewPath := backendBarrierPrefix + entry.UUID + "/" view := NewBarrierView(c.barrier, viewPath) sysView := c.mountEntrySysView(entry) @@ -504,7 +529,7 @@ func (c *Core) loadMounts() error { needPersist = true } - for _, requiredMount := range requiredMountTable().Entries { + for _, requiredMount := range c.requiredMountTable().Entries { foundRequired := false for _, coreMount := range c.mounts.Entries { if coreMount.Type == requiredMount.Type { @@ -535,6 +560,14 @@ func (c *Core) loadMounts() error { entry.Table = c.mounts.Type needPersist = true } + if entry.Accessor == "" { + accessor, err := c.generateMountAccessor(entry.Type) + if err != nil { + return err + } + entry.Accessor = accessor + needPersist = true + } } // Done if we have restored the mount table and we don't need @@ -544,7 +577,7 @@ func (c *Core) loadMounts() error { } } else { // Create and persist the default mount table - c.mounts = defaultMountTable() + c.mounts = c.defaultMountTable() } if err := c.persistMounts(c.mounts, false); err != nil { @@ -745,13 +778,17 @@ func (c *Core) mountEntrySysView(entry *MountEntry) logical.SystemView { } // defaultMountTable creates a default mount table -func defaultMountTable() *MountTable { +func (c *Core) defaultMountTable() *MountTable { table := &MountTable{ Type: mountTableType, } mountUUID, err := uuid.GenerateUUID() if err != nil { - panic(fmt.Sprintf("could not create default mount table UUID: %v", err)) + panic(fmt.Sprintf("could not create default secret mount UUID: %v", err)) + } + mountAccessor, err := c.generateMountAccessor("generic") + if err != nil { + panic(fmt.Sprintf("could not generate default secret mount accessor: %v", err)) } genericMount := &MountEntry{ Table: mountTableType, @@ -759,15 +796,16 @@ func defaultMountTable() *MountTable { Type: "generic", Description: "generic secret storage", UUID: mountUUID, + Accessor: mountAccessor, } table.Entries = append(table.Entries, genericMount) - table.Entries = append(table.Entries, requiredMountTable().Entries...) + table.Entries = append(table.Entries, c.requiredMountTable().Entries...) return table } // requiredMountTable() creates a mount table with entries required // to be available -func requiredMountTable() *MountTable { +func (c *Core) requiredMountTable() *MountTable { table := &MountTable{ Type: mountTableType, } @@ -775,12 +813,17 @@ func requiredMountTable() *MountTable { if err != nil { panic(fmt.Sprintf("could not create cubbyhole UUID: %v", err)) } + cubbyholeAccessor, err := c.generateMountAccessor("cubbyhole") + if err != nil { + panic(fmt.Sprintf("could not generate cubbyhole accessor: %v", err)) + } cubbyholeMount := &MountEntry{ Table: mountTableType, Path: "cubbyhole/", Type: "cubbyhole", Description: "per-token private secret storage", UUID: cubbyholeUUID, + Accessor: cubbyholeAccessor, Local: true, } @@ -788,12 +831,17 @@ func requiredMountTable() *MountTable { if err != nil { panic(fmt.Sprintf("could not create sys UUID: %v", err)) } + sysAccessor, err := c.generateMountAccessor("system") + if err != nil { + panic(fmt.Sprintf("could not generate sys accessor: %v", err)) + } sysMount := &MountEntry{ Table: mountTableType, Path: "sys/", Type: "system", Description: "system endpoints used for control, policy and debugging", UUID: sysUUID, + Accessor: sysAccessor, } table.Entries = append(table.Entries, cubbyholeMount) table.Entries = append(table.Entries, sysMount) diff --git a/vault/mount_test.go b/vault/mount_test.go index e0b751954a..acb0ae4b52 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -93,16 +93,18 @@ func TestCore_Mount_Local(t *testing.T) { Type: mountTableType, Entries: []*MountEntry{ &MountEntry{ - Table: mountTableType, - Path: "noop/", - Type: "generic", - UUID: "abcd", + Table: mountTableType, + Path: "noop/", + Type: "generic", + UUID: "abcd", + Accessor: "generic-abcd", }, &MountEntry{ - Table: mountTableType, - Path: "noop2/", - Type: "generic", - UUID: "bcde", + Table: mountTableType, + Path: "noop2/", + Type: "generic", + UUID: "bcde", + Accessor: "generic-bcde", }, }, } @@ -426,7 +428,8 @@ func TestCore_Remount_Protected(t *testing.T) { } func TestDefaultMountTable(t *testing.T) { - table := defaultMountTable() + c, _, _ := TestCoreUnsealed(t) + table := c.defaultMountTable() verifyDefaultTable(t, table) } diff --git a/vault/router.go b/vault/router.go index 6b3c190cce..e6dab3b8a1 100644 --- a/vault/router.go +++ b/vault/router.go @@ -14,10 +14,11 @@ import ( // Router is used to do prefix based routing of a request to a logical backend type Router struct { - l sync.RWMutex - root *radix.Tree - mountUUIDCache *radix.Tree - tokenStoreSalt *salt.Salt + l sync.RWMutex + root *radix.Tree + mountUUIDCache *radix.Tree + mountAccessorCache *radix.Tree + tokenStoreSalt *salt.Salt // storagePrefix maps the prefix used for storage (ala the BarrierView) // to the backend. This is used to map a key back into the backend that owns it. @@ -28,9 +29,10 @@ type Router struct { // NewRouter returns a new router func NewRouter() *Router { r := &Router{ - root: radix.New(), - storagePrefix: radix.New(), - mountUUIDCache: radix.New(), + root: radix.New(), + storagePrefix: radix.New(), + mountUUIDCache: radix.New(), + mountAccessorCache: radix.New(), } return r } @@ -80,6 +82,7 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount r.root.Insert(prefix, re) r.storagePrefix.Insert(storageView.prefix, re) r.mountUUIDCache.Insert(re.mountEntry.UUID, re.mountEntry) + r.mountAccessorCache.Insert(re.mountEntry.Accessor, re.mountEntry) return nil } @@ -103,6 +106,7 @@ func (r *Router) Unmount(prefix string) error { r.root.Delete(prefix) r.storagePrefix.Delete(re.storageView.prefix) r.mountUUIDCache.Delete(re.mountEntry.UUID) + r.mountAccessorCache.Delete(re.mountEntry.Accessor) return nil } @@ -163,6 +167,22 @@ func (r *Router) MatchingMountByUUID(mountID string) *MountEntry { return raw.(*MountEntry) } +func (r *Router) MatchingMountByAccessor(mountAccessor string) *MountEntry { + if mountAccessor == "" { + return nil + } + + r.l.RLock() + defer r.l.RUnlock() + + _, raw, ok := r.mountAccessorCache.LongestPrefix(mountAccessor) + if !ok { + return nil + } + + return raw.(*MountEntry) +} + // MatchingMount returns the mount prefix that would be used for a path func (r *Router) MatchingMount(path string) string { r.l.RLock()