From 2cbce548697f52549178dd475509877b554bc0ac Mon Sep 17 00:00:00 2001 From: Josh Black Date: Fri, 4 Mar 2022 14:16:51 -0800 Subject: [PATCH] Only create new batch tokens if we're on at least 1.10.0 (#14370) --- vault/core_test.go | 106 ++++++++++++++++++++++++++++++++++++ vault/token_store.go | 20 ++++++- vault/version_store.go | 20 ++++++- vault/version_store_test.go | 42 ++++++++++++++ 4 files changed, 184 insertions(+), 4 deletions(-) diff --git a/vault/core_test.go b/vault/core_test.go index b970b4c67f..7f22247509 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -402,6 +402,112 @@ func TestCore_Seal_BadToken(t *testing.T) { } } +func TestCore_PreOneTen_BatchTokens(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + // load up some versions and ensure that 1.9 is the most recent one by timestamp (even though this isn't realistic) + upgradeTimePlusEpsilon := time.Now().UTC() + + versionEntries := []struct { + version string + ts time.Time + }{ + {"1.10.1", upgradeTimePlusEpsilon.Add(-4 * time.Hour)}, + {"1.9.2", upgradeTimePlusEpsilon.Add(2 * time.Hour)}, + } + + for _, entry := range versionEntries { + _, err := c.storeVersionTimestamp(context.Background(), entry.version, entry.ts, false) + if err != nil { + t.Fatalf("failed to write version entry %#v, err: %s", entry, err.Error()) + } + } + + err := c.loadVersionTimestamps(c.activeContext) + if err != nil { + t.Fatalf("failed to populate version history cache, err: %s", err.Error()) + } + + // double check that we're working with 1.9 + v, _, err := c.FindNewestVersionTimestamp() + if err != nil { + t.Fatal(err) + } + if v != "1.9.2" { + t.Fatalf("expected 1.9.2, found: %s", v) + } + + // generate a batch token + te := &logical.TokenEntry{ + NumUses: 1, + Policies: []string{"root"}, + NamespaceID: namespace.RootNamespaceID, + Type: logical.TokenTypeBatch, + } + err = c.tokenStore.create(namespace.RootContext(nil), te) + if err != nil { + t.Fatal(err) + } + + // verify it uses the legacy prefix + if !strings.HasPrefix(te.ID, consts.LegacyBatchTokenPrefix) { + t.Fatalf("expected 1.9 batch token IDs to start with b. but it didn't: %s", te.ID) + } +} + +func TestCore_OneTenPlus_BatchTokens(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + // load up some versions and ensure that 1.10 is the most recent version + upgradeTimePlusEpsilon := time.Now().UTC() + + versionEntries := []struct { + version string + ts time.Time + }{ + {"1.9.2", upgradeTimePlusEpsilon.Add(-4 * time.Hour)}, + {"1.10.1", upgradeTimePlusEpsilon.Add(2 * time.Hour)}, + } + + for _, entry := range versionEntries { + _, err := c.storeVersionTimestamp(context.Background(), entry.version, entry.ts, false) + if err != nil { + t.Fatalf("failed to write version entry %#v, err: %s", entry, err.Error()) + } + } + + err := c.loadVersionTimestamps(c.activeContext) + if err != nil { + t.Fatalf("failed to populate version history cache, err: %s", err.Error()) + } + + // double check that we're working with 1.10 + v, _, err := c.FindNewestVersionTimestamp() + if err != nil { + t.Fatal(err) + } + if v != "1.10.1" { + t.Fatalf("expected 1.10.1, found: %s", v) + } + + // generate a batch token + te := &logical.TokenEntry{ + NumUses: 1, + Policies: []string{"root"}, + NamespaceID: namespace.RootNamespaceID, + Type: logical.TokenTypeBatch, + } + err = c.tokenStore.create(namespace.RootContext(nil), te) + if err != nil { + t.Fatal(err) + } + + // verify it uses the legacy prefix + if !strings.HasPrefix(te.ID, consts.BatchTokenPrefix) { + t.Fatalf("expected 1.10 batch token IDs to start with hvb. but it didn't: %s", te.ID) + } +} + // GH-3497 func TestCore_Seal_SingleUse(t *testing.T) { c, keys, _ := TestCoreUnsealed(t) diff --git a/vault/token_store.go b/vault/token_store.go index b3308d7316..9cfc1d654f 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -24,6 +24,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-sockaddr" + "github.com/hashicorp/go-version" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" @@ -991,10 +992,23 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err } bEntry := base64.RawURLEncoding.EncodeToString(eEntry) - if ts.core.DisableSSCTokens() { - entry.ID = fmt.Sprintf("b.%s", bEntry) + ver, _, err := ts.core.FindNewestVersionTimestamp() + if err != nil { + return err + } + newestVersion, err := version.NewVersion(ver) + if err != nil { + return err + } + oneTen, err := version.NewVersion("1.10.0") + if err != nil { + return err + } + + if ts.core.DisableSSCTokens() || newestVersion.LessThan(oneTen) { + entry.ID = consts.LegacyBatchTokenPrefix + bEntry } else { - entry.ID = fmt.Sprintf("hvb.%s", bEntry) + entry.ID = consts.BatchTokenPrefix + bEntry } if tokenNS.ID != namespace.RootNamespaceID { diff --git a/vault/version_store.go b/vault/version_store.go index 6890bcf187..fdf452585d 100644 --- a/vault/version_store.go +++ b/vault/version_store.go @@ -70,7 +70,7 @@ func (c *Core) storeVersionTimestamp(ctx context.Context, version string, timest // FindOldestVersionTimestamp searches for the vault version with the oldest // upgrade timestamp from storage. The earliest version this can be is 1.9.0. func (c *Core) FindOldestVersionTimestamp() (string, time.Time, error) { - if c.versionTimestamps == nil || len(c.versionTimestamps) == 0 { + if len(c.versionTimestamps) == 0 { return "", time.Time{}, fmt.Errorf("version timestamps are not initialized") } @@ -86,6 +86,24 @@ func (c *Core) FindOldestVersionTimestamp() (string, time.Time, error) { return oldestVersion, oldestUpgradeTime, nil } +func (c *Core) FindNewestVersionTimestamp() (string, time.Time, error) { + if len(c.versionTimestamps) == 0 { + return "", time.Time{}, fmt.Errorf("version timestamps are not initialized") + } + + var newestUpgradeTime time.Time + var newestVersion string + + for version, upgradeTime := range c.versionTimestamps { + if upgradeTime.After(newestUpgradeTime) { + newestVersion = version + newestUpgradeTime = upgradeTime + } + } + + return newestVersion, newestUpgradeTime, nil +} + // loadVersionTimestamps loads all the vault versions and associated upgrade // timestamps from storage. Version timestamps were originally stored in local // time. A timestamp that is not in UTC will be rewritten to storage as UTC. diff --git a/vault/version_store_test.go b/vault/version_store_test.go index 92c010d446..28dc8668d7 100644 --- a/vault/version_store_test.go +++ b/vault/version_store_test.go @@ -68,6 +68,48 @@ func TestVersionStore_GetOldestVersion(t *testing.T) { } } +// TestVersionStore_GetNewestVersion verifies that FindNewestVersionTimestamp finds the newest +// (in time) vault version stored. +func TestVersionStore_GetNewestVersion(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + upgradeTimePlusEpsilon := time.Now().UTC() + + // 1.6.1 is stored after 1.6.2, so even though it is a lower number, it should be returned. + versionEntries := []struct { + version string + ts time.Time + }{ + {"1.6.2", upgradeTimePlusEpsilon.Add(-4 * time.Hour)}, + {"1.6.1", upgradeTimePlusEpsilon.Add(2 * time.Hour)}, + } + + for _, entry := range versionEntries { + _, err := c.storeVersionTimestamp(context.Background(), entry.version, entry.ts, false) + if err != nil { + t.Fatalf("failed to write version entry %#v, err: %s", entry, err.Error()) + } + } + + err := c.loadVersionTimestamps(c.activeContext) + if err != nil { + t.Fatalf("failed to populate version history cache, err: %s", err.Error()) + } + + if len(c.versionTimestamps) != 3 { + t.Fatalf("expected 3 entries in timestamps map after refresh, found: %d", len(c.versionTimestamps)) + } + v, tm, err := c.FindNewestVersionTimestamp() + if err != nil { + t.Fatal(err) + } + if v != "1.6.1" { + t.Fatalf("expected 1.6.1, found: %s", v) + } + if tm.Before(upgradeTimePlusEpsilon.Add(1*time.Hour)) || tm.After(upgradeTimePlusEpsilon.Add(3*time.Hour)) { + t.Fatalf("incorrect upgrade time logged: %v", tm) + } +} + func TestVersionStore_SelfHealUTC(t *testing.T) { c, _, _ := TestCoreUnsealed(t) estLoc, err := time.LoadLocation("EST")