diff --git a/Makefile b/Makefile index 173ffb3647..4c5dd1eeb2 100644 --- a/Makefile +++ b/Makefile @@ -194,6 +194,7 @@ proto: bootstrap protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/*.proto protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/database/dbplugin/v5/proto/*.proto protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/plugin/pb/*.proto + protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative vault/tokens/token.proto protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative sdk/helper/pluginutil/*.proto # No additional sed expressions should be added to this list. Going forward diff --git a/changelog/14109.txt b/changelog/14109.txt new file mode 100644 index 0000000000..f4a3d5039a --- /dev/null +++ b/changelog/14109.txt @@ -0,0 +1,3 @@ +```release-note:feature +Server Side Consistent Tokens: Service tokens now use SSC token format and token prefixes are updated." +``` \ No newline at end of file diff --git a/command/commands.go b/command/commands.go index cbf1b22a69..a0d82b6872 100644 --- a/command/commands.go +++ b/command/commands.go @@ -81,6 +81,10 @@ const ( // path to a license file on disk EnvVaultLicensePath = "VAULT_LICENSE_PATH" + // DisableSSCTokens is an env var used to disable index bearing + // token functionality + DisableSSCTokens = "VAULT_DISABLE_SERVER_SIDE_CONSISTENT_TOKENS" + // flagNameAddress is the flag used in the base command to read in the // address of the Vault server. flagNameAddress = "address" diff --git a/command/login_test.go b/command/login_test.go index cf8325dca9..fc3e18dbf7 100644 --- a/command/login_test.go +++ b/command/login_test.go @@ -13,6 +13,11 @@ import ( "github.com/hashicorp/vault/vault" ) +// minTokenLengthExternal is the minimum size of SSC +// tokens we are currently handing out to end users, without any +// namespace information +const minTokenLengthExternal = 91 + func testLoginCommand(tb testing.TB) (*cli.MockUi, *LoginCommand) { tb.Helper() @@ -82,7 +87,7 @@ func TestLoginCommand_Run(t *testing.T) { t.Fatal(err) } - if l, exp := len(storedToken), vault.TokenLength+2; l != exp { + if l, exp := len(storedToken), minTokenLengthExternal+vault.TokenPrefixLength; l != exp { t.Errorf("expected token to be %d characters, was %d: %q", exp, l, storedToken) } }) @@ -209,7 +214,7 @@ func TestLoginCommand_Run(t *testing.T) { // Verify only the token was printed token := ui.OutputWriter.String() - if l, exp := len(token), vault.TokenLength+2; l != exp { + if l, exp := len(token), minTokenLengthExternal+vault.TokenPrefixLength; l != exp { t.Errorf("expected token to be %d characters, was %d: %q", exp, l, token) } diff --git a/command/operator_generate_root_test.go b/command/operator_generate_root_test.go index f12ed5a4e0..b4489718ef 100644 --- a/command/operator_generate_root_test.go +++ b/command/operator_generate_root_test.go @@ -435,7 +435,7 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) { t.Fatal(err) } - if l, exp := len(token), vault.TokenLength+2; l != exp { + if l, exp := len(token), vault.TokenLength+vault.TokenPrefixLength; l != exp { t.Errorf("expected %d to be %d: %s", l, exp, token) } }) @@ -521,7 +521,7 @@ func TestOperatorGenerateRootCommand_Run(t *testing.T) { t.Fatal(err) } - if l, exp := len(token), vault.TokenLength+2; l != exp { + if l, exp := len(token), vault.TokenLength+vault.TokenPrefixLength; l != exp { t.Errorf("expected %d to be %d: %s", l, exp, token) } }) diff --git a/command/operator_init_test.go b/command/operator_init_test.go index 423edad763..491d623a14 100644 --- a/command/operator_init_test.go +++ b/command/operator_init_test.go @@ -333,7 +333,7 @@ func TestOperatorInitCommand_Run(t *testing.T) { root := match[0][1] decryptedRoot := testPGPDecrypt(t, pgpkeys.TestPrivKey1, root) - if l, exp := len(decryptedRoot), vault.TokenLength+2; l != exp { + if l, exp := len(decryptedRoot), vault.TokenLength+vault.TokenPrefixLength; l != exp { t.Errorf("expected %d to be %d", l, exp) } }) diff --git a/command/server.go b/command/server.go index 97f7654c10..ed6df3f35f 100644 --- a/command/server.go +++ b/command/server.go @@ -1163,6 +1163,15 @@ func (c *ServerCommand) Run(args []string) int { if envLicense := os.Getenv(EnvVaultLicense); envLicense != "" { config.License = envLicense } + if disableSSC := os.Getenv(DisableSSCTokens); disableSSC != "" { + var err error + config.DisableSSCTokens, err = strconv.ParseBool(disableSSC) + if err != nil { + c.UI.Warn(wrapAtLength("WARNING! failed to parse " + + "VAULT_DISABLE_SERVER_SIDE_CONSISTENT_TOKENS env var: " + + "setting to default value false")) + } + } // If mlockall(2) isn't supported, show a warning. We disable this in dev // because it is quite scary to see when first using Vault. We also disable @@ -2502,6 +2511,7 @@ func createCoreConfig(c *ServerCommand, config *server.Config, backend physical. EnableResponseHeaderRaftNodeID: config.EnableResponseHeaderRaftNodeID, License: config.License, LicensePath: config.LicensePath, + DisableSSCTokens: config.DisableSSCTokens, } if c.flagDev { coreConfig.EnableRaw = true diff --git a/command/server/config.go b/command/server/config.go index 15e1446610..45637d2e58 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -83,8 +83,9 @@ type Config struct { EnableResponseHeaderRaftNodeID bool `hcl:"-"` EnableResponseHeaderRaftNodeIDRaw interface{} `hcl:"enable_response_header_raft_node_id"` - License string `hcl:"-"` - LicensePath string `hcl:"license_path"` + License string `hcl:"-"` + LicensePath string `hcl:"license_path"` + DisableSSCTokens bool `hcl:"-"` } const ( diff --git a/helper/namespace/namespace.go b/helper/namespace/namespace.go index 1b59495cab..b6aba2bc3f 100644 --- a/helper/namespace/namespace.go +++ b/helper/namespace/namespace.go @@ -4,6 +4,8 @@ import ( "context" "errors" "strings" + + "github.com/hashicorp/vault/sdk/helper/consts" ) type contextValues struct{} @@ -105,6 +107,12 @@ func SplitIDFromString(input string) (string, string) { case strings.HasPrefix(input, "s."): prefix = "s." input = input[2:] + case strings.HasPrefix(input, consts.BatchTokenPrefix): + prefix = consts.BatchTokenPrefix + input = input[4:] + case strings.HasPrefix(input, consts.ServiceTokenPrefix): + prefix = consts.ServiceTokenPrefix + input = input[4:] case slashIdx > 0: // Leases will never have a b./s. to start diff --git a/http/sys_generate_root.go b/http/sys_generate_root.go index 4ac3015077..db2da6f7f3 100644 --- a/http/sys_generate_root.go +++ b/http/sys_generate_root.go @@ -64,6 +64,12 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r respondError(w, http.StatusInternalServerError, err) return } + var otpLength int + if core.DisableSSCTokens() { + otpLength = vault.TokenLength + vault.OldTokenPrefixLength + } else { + otpLength = vault.TokenLength + vault.TokenPrefixLength + } // Format the status status := &GenerateRootStatusResponse{ @@ -71,7 +77,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r Progress: progress, Required: sealConfig.SecretThreshold, Complete: false, - OTPLength: vault.TokenLength + 2, + OTPLength: otpLength, OTP: otp, } if generationConfig != nil { @@ -98,7 +104,11 @@ func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r case len(req.PGPKey) > 0, len(req.OTP) > 0: default: genned = true - req.OTP, err = base62.Random(vault.TokenLength + 2) + if core.DisableSSCTokens() { + req.OTP, err = base62.Random(vault.TokenLength + vault.OldTokenPrefixLength) + } else { + req.OTP, err = base62.Random(vault.TokenLength + vault.TokenPrefixLength) + } if err != nil { respondError(w, http.StatusInternalServerError, err) return diff --git a/http/sys_generate_root_test.go b/http/sys_generate_root_test.go index 60e951436d..f226d0042b 100644 --- a/http/sys_generate_root_test.go +++ b/http/sys_generate_root_test.go @@ -19,6 +19,8 @@ import ( "github.com/hashicorp/vault/vault" ) +var tokenLength string = fmt.Sprintf("%d", vault.TokenLength+vault.TokenPrefixLength) + func TestSysGenerateRootAttempt_Status(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) @@ -40,7 +42,7 @@ func TestSysGenerateRootAttempt_Status(t *testing.T) { "encoded_root_token": "", "pgp_fingerprint": "", "nonce": "", - "otp_length": json.Number("26"), + "otp_length": json.Number(tokenLength), } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -68,7 +70,7 @@ func TestSysGenerateRootAttempt_Setup_OTP(t *testing.T) { "encoded_token": "", "encoded_root_token": "", "pgp_fingerprint": "", - "otp_length": json.Number("26"), + "otp_length": json.Number(tokenLength), } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -93,7 +95,7 @@ func TestSysGenerateRootAttempt_Setup_OTP(t *testing.T) { "encoded_root_token": "", "pgp_fingerprint": "", "otp": "", - "otp_length": json.Number("26"), + "otp_length": json.Number(tokenLength), } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -129,7 +131,7 @@ func TestSysGenerateRootAttempt_Setup_PGP(t *testing.T) { "encoded_root_token": "", "pgp_fingerprint": "816938b8a29146fbe245dd29e7cbaf8e011db793", "otp": "", - "otp_length": json.Number("26"), + "otp_length": json.Number(tokenLength), } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -159,7 +161,7 @@ func TestSysGenerateRootAttempt_Cancel(t *testing.T) { "encoded_token": "", "encoded_root_token": "", "pgp_fingerprint": "", - "otp_length": json.Number("26"), + "otp_length": json.Number(tokenLength), } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -191,7 +193,7 @@ func TestSysGenerateRootAttempt_Cancel(t *testing.T) { "pgp_fingerprint": "", "nonce": "", "otp": "", - "otp_length": json.Number("26"), + "otp_length": json.Number(tokenLength), } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) diff --git a/sdk/helper/consts/token_consts.go b/sdk/helper/consts/token_consts.go new file mode 100644 index 0000000000..2b4e0278bf --- /dev/null +++ b/sdk/helper/consts/token_consts.go @@ -0,0 +1,10 @@ +package consts + +const ( + ServiceTokenPrefix = "hvs." + BatchTokenPrefix = "hvb." + RecoveryTokenPrefix = "hvr." + LegacyServiceTokenPrefix = "s." + LegacyBatchTokenPrefix = "b." + LegacyRecoveryTokenPrefix = "r." +) diff --git a/sdk/logical/request.go b/sdk/logical/request.go index f2196aca73..1c400a4cb7 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -220,6 +220,10 @@ type Request struct { // this will be the sha256(sorted policies + namespace) associated with the // client token. ClientID string `json:"client_id" structs:"client_id" mapstructure:"client_id" sentinel:""` + + // InboundSSCToken is the token that arrives on an inbound request, supplied + // by the vault user. + InboundSSCToken string } // Clone returns a deep copy of the request by using copystructure diff --git a/sdk/logical/token.go b/sdk/logical/token.go index b204a4a6c8..ebebd4ad9c 100644 --- a/sdk/logical/token.go +++ b/sdk/logical/token.go @@ -93,6 +93,10 @@ type TokenEntry struct { // ID of this entry, generally a random UUID ID string `json:"id" mapstructure:"id" structs:"id" sentinel:""` + // ExternalID is the ID of a newly created service + // token that will be returned to a user + ExternalID string `json:"-"` + // Accessor for this token, a random UUID Accessor string `json:"accessor" mapstructure:"accessor" structs:"accessor" sentinel:""` diff --git a/vault/audit_broker.go b/vault/audit_broker.go index 9440ec3f61..7389cb5689 100644 --- a/vault/audit_broker.go +++ b/vault/audit_broker.go @@ -91,6 +91,15 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput, head defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now()) a.RLock() defer a.RUnlock() + if in.Request.InboundSSCToken != "" { + if in.Auth != nil { + reqAuthToken := in.Auth.ClientToken + in.Auth.ClientToken = in.Request.InboundSSCToken + defer func() { + in.Auth.ClientToken = reqAuthToken + }() + } + } var retErr *multierror.Error @@ -153,6 +162,15 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput, hea defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now()) a.RLock() defer a.RUnlock() + if in.Request.InboundSSCToken != "" { + if in.Auth != nil { + reqAuthToken := in.Auth.ClientToken + in.Auth.ClientToken = in.Request.InboundSSCToken + defer func() { + in.Auth.ClientToken = reqAuthToken + }() + } + } var retErr *multierror.Error diff --git a/vault/core.go b/vault/core.go index c36fc55f13..cc00c901ab 100644 --- a/vault/core.go +++ b/vault/core.go @@ -74,6 +74,11 @@ const ( coreKeyringCanaryPath = "core/canary-keyring" indexHeaderHMACKeyPath = "core/index-header-hmac-key" + + // ForwardSSCTokenToActive is the value that must be set in the + // forwardToActive to trigger forwarding if a perf standby encounters + // an SSC Token that it does not have the WAL state for. + ForwardSSCTokenToActive = "new_token" ) var ( @@ -576,6 +581,9 @@ type Core struct { enableResponseHeaderHostname bool enableResponseHeaderRaftNodeID bool + // disableSSCTokens is used to disable server side consistent token creation/usage + disableSSCTokens bool + // versionTimestamps is a map of vault versions to timestamps when the version // was first run. Note that because perf standbys should be upgraded first, and // only the active node will actually write the new version timestamp, a perf @@ -702,6 +710,9 @@ type CoreConfig struct { // Whether to send headers in the HTTP response showing hostname or raft node ID EnableResponseHeaderHostname bool EnableResponseHeaderRaftNodeID bool + + // DisableSSCTokens is used to disable the use of server side consistent tokens + DisableSSCTokens bool } // GetServiceRegistration returns the config's ServiceRegistration, or nil if it does @@ -844,6 +855,7 @@ func CreateCore(conf *CoreConfig) (*Core, error) { disableAutopilot: conf.DisableAutopilot, enableResponseHeaderHostname: conf.EnableResponseHeaderHostname, enableResponseHeaderRaftNodeID: conf.EnableResponseHeaderRaftNodeID, + disableSSCTokens: conf.DisableSSCTokens, } c.standbyStopCh.Store(make(chan struct{})) atomic.StoreUint32(c.sealed, 1) @@ -1098,6 +1110,11 @@ func (c *Core) RaftNodeIDHeaderEnabled() bool { return c.enableResponseHeaderRaftNodeID } +// DisableSSCTokens determines whether to use server side consistent tokens or not. +func (c *Core) DisableSSCTokens() bool { + return c.disableSSCTokens +} + // Shutdown is invoked when the Vault instance is about to be terminated. It // should not be accessible as part of an API call as it will cause an availability // problem. It is only used to gracefully quit in the case of HA so that failover @@ -2064,6 +2081,9 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c if err := c.setupQuotas(ctx, false); err != nil { return err } + if err := c.setupHeaderHMACKey(ctx, false); err != nil { + return err + } if !c.IsDRSecondary() { if err := c.startRollback(); err != nil { return err diff --git a/vault/core_test.go b/vault/core_test.go index 157d6db1e3..eec131047a 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "strings" "sync" "testing" "time" @@ -148,6 +149,103 @@ func TestCore_Unseal_MultiShare(t *testing.T) { } } +// TestCore_UseSSCTokenToggleOn will check that the root SSC +// token can be used even when disableSSCTokens is toggled on +func TestCore_UseSSCTokenToggleOn(t *testing.T) { + c, _, root := TestCoreUnsealed(t) + c.disableSSCTokens = true + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "secret/test", + Data: map[string]interface{}{ + "foo": "bar", + "lease": "1h", + }, + ClientToken: root, + } + ctx := namespace.RootContext(nil) + resp, err := c.HandleRequest(ctx, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + + // Read the key + req.Operation = logical.ReadOperation + req.Data = nil + req.SetTokenEntry(&logical.TokenEntry{ID: root, NamespaceID: "root", Policies: []string{"root"}}) + resp, err = c.HandleRequest(ctx, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil || resp.Secret == nil || resp.Data == nil { + t.Fatalf("bad: %#v", resp) + } + if resp.Secret.TTL != time.Hour { + t.Fatalf("bad: %#v", resp.Secret) + } + if resp.Secret.LeaseID == "" { + t.Fatalf("bad: %#v", resp.Secret) + } + if resp.Data["foo"] != "bar" { + t.Fatalf("bad: %#v", resp.Data) + } +} + +// TestCore_UseNonSSCTokenToggleOff will check that the root +// non-SSC token can be used even when disableSSCTokens is toggled +// off. +func TestCore_UseNonSSCTokenToggleOff(t *testing.T) { + coreConfig := &CoreConfig{ + DisableSSCTokens: true, + } + c, _, root := TestCoreUnsealedWithConfig(t, coreConfig) + if len(root) > TokenLength+OldTokenPrefixLength || !strings.HasPrefix(root, "s.") { + t.Fatalf("token is not an old token type: %s, %d", root, len(root)) + } + c.disableSSCTokens = false + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "secret/test", + Data: map[string]interface{}{ + "foo": "bar", + "lease": "1h", + }, + ClientToken: root, + } + ctx := namespace.RootContext(nil) + resp, err := c.HandleRequest(ctx, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %#v", resp) + } + + // Read the key + req.Operation = logical.ReadOperation + req.Data = nil + req.SetTokenEntry(&logical.TokenEntry{ID: root, NamespaceID: "root", Policies: []string{"root"}}) + resp, err = c.HandleRequest(ctx, req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil || resp.Secret == nil || resp.Data == nil { + t.Fatalf("bad: %#v", resp) + } + if resp.Secret.TTL != time.Hour { + t.Fatalf("bad: %#v", resp.Secret) + } + if resp.Secret.LeaseID == "" { + t.Fatalf("bad: %#v", resp.Secret) + } + if resp.Data["foo"] != "bar" { + t.Fatalf("bad: %#v", resp.Data) + } +} + func TestCore_Unseal_Single(t *testing.T) { c := TestCore(t) @@ -745,12 +843,15 @@ func TestCore_HandleLogin_Token(t *testing.T) { } // Check the policy and metadata - te, err := c.tokenStore.Lookup(namespace.RootContext(nil), clientToken) - if err != nil { - t.Fatalf("err: %v", err) + innerToken, _ := c.DecodeSSCToken(clientToken) + te, err := c.tokenStore.Lookup(namespace.RootContext(nil), innerToken) + if err != nil || te == nil { + t.Fatalf("tok: %s, err: %v", clientToken, err) } + + expectedID, _ := c.DecodeSSCToken(clientToken) expect := &logical.TokenEntry{ - ID: clientToken, + ID: expectedID, Accessor: te.Accessor, Parent: "", Policies: []string{"bar", "default", "foo"}, @@ -830,7 +931,7 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) { t.Fatalf("bad: %#v", noop) } if !reflect.DeepEqual(noop.RespAuth[1], auth) { - t.Fatalf("bad: %#v", auth) + t.Fatalf("bad: %#v, vs %#v", auth, noop.RespAuth) } if len(noop.RespReq) != 2 || !reflect.DeepEqual(noop.RespReq[1], req) { t.Fatalf("Bad: %#v", noop.RespReq[1]) @@ -1056,10 +1157,14 @@ func TestCore_HandleRequest_CreateToken_Lease(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } + + expectedID, _ := c.DecodeSSCToken(clientToken) + expectedRootID, _ := c.DecodeSSCToken(root) + expect := &logical.TokenEntry{ - ID: clientToken, + ID: expectedID, Accessor: te.Accessor, - Parent: root, + Parent: expectedRootID, Policies: []string{"default", "foo"}, Path: "auth/token/create", DisplayName: "token", @@ -1104,10 +1209,14 @@ func TestCore_HandleRequest_CreateToken_NoDefaultPolicy(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } + + expectedID, _ := c.DecodeSSCToken(clientToken) + expectedRootID, _ := c.DecodeSSCToken(root) + expect := &logical.TokenEntry{ - ID: clientToken, + ID: expectedID, Accessor: te.Accessor, - Parent: root, + Parent: expectedRootID, Policies: []string{"foo"}, Path: "auth/token/create", DisplayName: "token", @@ -2061,8 +2170,9 @@ func TestCore_RenewToken_SingleRegister(t *testing.T) { } // Verify the token exists - if resp.Data["id"] != newClient { - t.Fatalf("bad: %#v", resp.Data) + if newClient != resp.Data["id"].(string) { + t.Fatalf("bad: return IDs: expected %v, got %v", + resp.Data["id"], newClient) } } diff --git a/vault/core_util.go b/vault/core_util.go index e30cf46ba5..99a0f8d987 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -127,6 +127,10 @@ func (c *Core) namepaceByPath(string) *namespace.Namespace { return namespace.RootNamespace } +func (c *Core) HasWALState(required *logical.WALState, perfStandby bool) bool { + return true +} + func (c *Core) setupReplicatedClusterPrimary(*replication.Cluster) error { return nil } func (c *Core) perfStandbyCount() int { return 0 } @@ -177,6 +181,10 @@ func (c *Core) AllowForwardingViaHeader() bool { return false } +func (c *Core) ForwardToActive() string { + return "" +} + func (c *Core) MissingRequiredState(raw []string, perfStandby bool) bool { return false } diff --git a/vault/core_util_common.go b/vault/core_util_common.go new file mode 100644 index 0000000000..6e05947740 --- /dev/null +++ b/vault/core_util_common.go @@ -0,0 +1,57 @@ +package vault + +import ( + "context" + + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/logical" +) + +func (c *Core) loadHeaderHMACKey(ctx context.Context) error { + ent, err := c.barrier.Get(ctx, indexHeaderHMACKeyPath) + if err != nil { + return err + } + + if ent != nil { + c.IndexHeaderHMACKey.Store(ent.Value) + } + return nil +} + +func (c *Core) headerHMACKey() []byte { + key := c.IndexHeaderHMACKey.Load() + if key == nil { + return nil + } + return key.([]byte) +} + +func (c *Core) setupHeaderHMACKey(ctx context.Context, isPerfStandby bool) error { + if c.IsPerfSecondary() || c.IsDRSecondary() || isPerfStandby { + return c.loadHeaderHMACKey(ctx) + } + ent, err := c.barrier.Get(ctx, indexHeaderHMACKeyPath) + if err != nil { + return err + } + + if ent != nil { + c.IndexHeaderHMACKey.Store(ent.Value) + return nil + } + + key, err := uuid.GenerateUUID() + if err != nil { + return err + } + err = c.barrier.Put(ctx, &logical.StorageEntry{ + Key: indexHeaderHMACKeyPath, + Value: []byte(key), + }) + if err != nil { + return err + } + c.IndexHeaderHMACKey.Store([]byte(key)) + return nil +} diff --git a/vault/expiration.go b/vault/expiration.go index d5574a101e..f6ea50fc67 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -1581,6 +1581,11 @@ func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenE // Setup revocation timer m.updatePending(&le) + if strings.HasPrefix(auth.ClientToken, consts.ServiceTokenPrefix) { + generatedTokenEntry := logical.TokenEntry{Policies: auth.Policies} + tok := m.tokenStore.GenerateSSCTokenID(auth.ClientToken, logical.IndexStateFromContext(ctx), &generatedTokenEntry) + te.ExternalID = tok + } return nil } diff --git a/vault/external_tests/token/batch_token_test.go b/vault/external_tests/token/batch_token_test.go index 5056d16a08..a344b5eed9 100644 --- a/vault/external_tests/token/batch_token_test.go +++ b/vault/external_tests/token/batch_token_test.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/credential/approle" vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" ) @@ -133,10 +134,10 @@ path "kv/*" { if resp.Auth.ClientToken == "" { t.Fatal("expected a client token") } - if batch && !strings.HasPrefix(resp.Auth.ClientToken, "b.") { + if batch && !strings.HasPrefix(resp.Auth.ClientToken, consts.BatchTokenPrefix) { t.Fatal("expected a batch token") } - if !batch && strings.HasPrefix(resp.Auth.ClientToken, "b.") { + if !batch && strings.HasPrefix(resp.Auth.ClientToken, consts.BatchTokenPrefix) { t.Fatal("expected a non-batch token") } return resp.Auth.ClientToken @@ -268,7 +269,7 @@ path "kv/*" { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "b." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.BatchTokenPrefix { t.Fatal(secret.Auth.ClientToken) } @@ -354,7 +355,7 @@ func TestTokenStore_Roles_Batch(t *testing.T) { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "s." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.ServiceTokenPrefix { t.Fatal(secret.Auth.ClientToken) } } @@ -397,7 +398,7 @@ func TestTokenStore_Roles_Batch(t *testing.T) { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "b." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.BatchTokenPrefix { t.Fatal(secret.Auth.ClientToken) } } @@ -424,7 +425,7 @@ func TestTokenStore_Roles_Batch(t *testing.T) { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "b." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.BatchTokenPrefix { t.Fatal(secret.Auth.ClientToken) } // Client specifies service @@ -441,7 +442,7 @@ func TestTokenStore_Roles_Batch(t *testing.T) { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "s." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.ServiceTokenPrefix { t.Fatal(secret.Auth.ClientToken) } // Client doesn't specify @@ -457,7 +458,7 @@ func TestTokenStore_Roles_Batch(t *testing.T) { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "s." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.ServiceTokenPrefix { t.Fatal(secret.Auth.ClientToken) } } @@ -484,7 +485,7 @@ func TestTokenStore_Roles_Batch(t *testing.T) { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "b." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.BatchTokenPrefix { t.Fatal(secret.Auth.ClientToken) } // Client specifies service @@ -501,7 +502,7 @@ func TestTokenStore_Roles_Batch(t *testing.T) { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "s." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.ServiceTokenPrefix { t.Fatal(secret.Auth.ClientToken) } // Client doesn't specify @@ -517,7 +518,7 @@ func TestTokenStore_Roles_Batch(t *testing.T) { if err != nil { t.Fatal(err) } - if secret.Auth.ClientToken[0:2] != "b." { + if secret.Auth.ClientToken[0:vault.TokenPrefixLength] != consts.BatchTokenPrefix { t.Fatal(secret.Auth.ClientToken) } } diff --git a/vault/generate_root.go b/vault/generate_root.go index 231b41016d..ec530ff96a 100644 --- a/vault/generate_root.go +++ b/vault/generate_root.go @@ -64,7 +64,7 @@ func (g generateStandardRootToken) generate(ctx context.Context, c *Core) (strin c.tokenStore.revokeOrphan(ctx, te.ID) } - return te.ID, cleanupFunc, nil + return te.ExternalID, cleanupFunc, nil } // GenerateRootConfig holds the configuration for a root generation @@ -134,7 +134,8 @@ func (c *Core) GenerateRootInit(otp, pgpKey string, strategy GenerateRootStrateg var fingerprint string switch { case len(otp) > 0: - if len(otp) != TokenLength+2 { + if (len(otp) != TokenLength+TokenPrefixLength && !c.DisableSSCTokens()) || + (len(otp) != TokenLength+OldTokenPrefixLength && c.DisableSSCTokens()) { return fmt.Errorf("OTP string is wrong length") } diff --git a/vault/generate_root_recovery.go b/vault/generate_root_recovery.go index afafb5c599..9757c42e5f 100644 --- a/vault/generate_root_recovery.go +++ b/vault/generate_root_recovery.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/hashicorp/vault/sdk/helper/consts" "go.uber.org/atomic" ) @@ -40,11 +41,18 @@ func (g *generateRecoveryToken) authenticate(ctx context.Context, c *Core, combi } func (g *generateRecoveryToken) generate(ctx context.Context, c *Core) (string, func(), error) { - id, err := base62.Random(TokenLength) + var id string + var err error + id, err = base62.Random(TokenLength) if err != nil { return "", nil, err } - token := "r." + id + var token string + if c.DisableSSCTokens() { + token = consts.LegacyRecoveryTokenPrefix + id + } else { + token = consts.RecoveryTokenPrefix + id + } g.token.Store(token) return token, func() { g.token.Store("") }, nil diff --git a/vault/generate_root_test.go b/vault/generate_root_test.go index 9401be7cb1..7b80d3a447 100644 --- a/vault/generate_root_test.go +++ b/vault/generate_root_test.go @@ -45,7 +45,7 @@ func testCore_GenerateRoot_Lifecycle_Common(t *testing.T, c *Core, keys [][]byte t.Fatalf("err: %v", err) } - otp, err := base62.Random(26) + otp, err := base62.Random(TokenPrefixLength + TokenLength) if err != nil { t.Fatal(err) } @@ -89,7 +89,7 @@ func TestCore_GenerateRoot_Init(t *testing.T) { } func testCore_GenerateRoot_Init_Common(t *testing.T, c *Core) { - otp, err := base62.Random(26) + otp, err := base62.Random(TokenPrefixLength + TokenLength) if err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func TestCore_GenerateRoot_InvalidMasterNonce(t *testing.T) { } func testCore_GenerateRoot_InvalidMasterNonce_Common(t *testing.T, c *Core, keys [][]byte) { - otp, err := base62.Random(26) + otp, err := base62.Random(TokenPrefixLength + TokenLength) if err != nil { t.Fatal(err) } @@ -154,7 +154,7 @@ func TestCore_GenerateRoot_Update_OTP(t *testing.T) { } func testCore_GenerateRoot_Update_OTP_Common(t *testing.T, c *Core, keys [][]byte) { - otp, err := base62.Random(26) + otp, err := base62.Random(TokenPrefixLength + TokenLength) if err != nil { t.Fatal(err) } diff --git a/vault/ha.go b/vault/ha.go index 83ca3604f2..12f9d6d745 100644 --- a/vault/ha.go +++ b/vault/ha.go @@ -62,6 +62,8 @@ func (c *Core) Standby() (bool, error) { } // PerfStandby checks if the vault is a performance standby +// This function cannot be used during request handling +// because this causes a deadlock with the statelock. func (c *Core) PerfStandby() bool { c.stateLock.RLock() perfStandby := c.perfStandby diff --git a/vault/init.go b/vault/init.go index 6efbcd8267..f092da2941 100644 --- a/vault/init.go +++ b/vault/init.go @@ -394,7 +394,7 @@ func (c *Core) Initialize(ctx context.Context, initParams *InitParams) (*InitRes c.logger.Error("root token generation failed", "error", err) return nil, err } - results.RootToken = rootToken.ID + results.RootToken = rootToken.ExternalID c.logger.Info("root token generated") if initParams.RootTokenPGPKey != "" { diff --git a/vault/request_handling.go b/vault/request_handling.go index e1a4915f1e..61f7a3bf08 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -2,12 +2,15 @@ package vault import ( "context" + "crypto/hmac" + "encoding/base64" "errors" "fmt" "strings" "time" metrics "github.com/armon/go-metrics" + "github.com/golang/protobuf/proto" "github.com/hashicorp/errwrap" multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/go-secure-stdlib/strutil" @@ -24,6 +27,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault/quotas" + "github.com/hashicorp/vault/vault/tokens" uberAtomic "go.uber.org/atomic" ) @@ -449,7 +453,6 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R ctx = context.WithValue(ctx, logical.CtxKeyInFlightRequestID{}, inFlightReqID) } resp, err = c.handleCancelableRequest(ctx, req) - req.SetTokenEntry(nil) cancel() return resp, err @@ -498,6 +501,8 @@ func (c *Core) handleCancelableRequest(ctx context.Context, req *logical.Request if err != nil { return nil, fmt.Errorf("could not parse namespace from http context: %w", err) } + var requestBodyToken string + var returnRequestAuthToken bool // req.Path will be relative by this point. The prefix check is first // to fail faster if we're not in this situation since it's a hot path @@ -536,6 +541,7 @@ func (c *Core) handleCancelableRequest(ctx context.Context, req *logical.Request // requests for these paths always go to the token NS case "auth/token/lookup-self", "auth/token/renew-self", "auth/token/revoke-self": ctx = newCtx + returnRequestAuthToken = true // For the following operations, we can set the proper namespace context // using the token's embedded nsID if a relative path was provided. @@ -556,6 +562,16 @@ func (c *Core) handleCancelableRequest(ctx context.Context, req *logical.Request if token == nil { return logical.ErrorResponse("invalid token"), logical.ErrPermissionDenied } + // We don't care if the token is an server side consistent token or not. Either way, we're going + // to be returning it for these paths instead of the short token stored in vault. + requestBodyToken = token.(string) + if IsSSCToken(token.(string)) { + token, err = c.CheckSSCToken(ctx, token.(string), c.isLoginRequest(ctx, req), c.perfStandby) + if err != nil { + return nil, fmt.Errorf("server side consistent token check failed: %w", err) + } + req.Data["token"] = token + } _, nsID := namespace.SplitIDFromString(token.(string)) if nsID != "" { ns, err := NamespaceByID(ctx, nsID, c) @@ -616,12 +632,32 @@ func (c *Core) handleCancelableRequest(ctx context.Context, req *logical.Request walState := &logical.WALState{} ctx = logical.IndexStateContext(ctx, walState) var auth *logical.Auth - if c.router.LoginPath(ctx, req.Path) { + if c.isLoginRequest(ctx, req) { resp, auth, err = c.handleLoginRequest(ctx, req) } else { resp, auth, err = c.handleRequest(ctx, req) } + // If we saved the token in the request, we should return it in the response + // data. + if resp != nil && resp.Data != nil { + if _, ok := resp.Data["error"]; !ok { + if requestBodyToken != "" { + resp.Data["id"] = requestBodyToken + } else if returnRequestAuthToken && req.InboundSSCToken != "" { + resp.Data["id"] = req.InboundSSCToken + } + } + } + if resp != nil && resp.Auth != nil && requestBodyToken != "" { + // if a client token has already been set and the request body token's internal token + // is equal to that value, then we can return the original request body token + tok, _ := c.DecodeSSCToken(requestBodyToken) + if resp.Auth.ClientToken == tok { + resp.Auth.ClientToken = requestBodyToken + } + } + // Ensure we don't leak internal data if resp != nil { if resp.Secret != nil { @@ -758,6 +794,10 @@ func (c *Core) doRouting(ctx context.Context, req *logical.Request) (*logical.Re return resp, err } +func (c *Core) isLoginRequest(ctx context.Context, req *logical.Request) bool { + return c.router.LoginPath(ctx, req.Path) +} + func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp *logical.Response, retAuth *logical.Auth, retErr error) { defer metrics.MeasureSince([]string{"core", "handle_request"}, time.Now()) @@ -1108,6 +1148,10 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp // We build the "policies" list to be returned by starting with // token policies, and add identity policies right after this // conditional + tok, _ := c.DecodeSSCToken(req.InboundSSCToken) + if resp.Auth.ClientToken == tok { + resp.Auth.ClientToken = req.InboundSSCToken + } resp.Auth.Policies = policyutil.SanitizePolicies(resp.Auth.TokenPolicies, policyutil.DoNotAddDefaultPolicy) } else { resp.Auth.TokenPolicies = policyutil.SanitizePolicies(resp.Auth.Policies, policyutil.DoNotAddDefaultPolicy) @@ -1115,12 +1159,13 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp switch resp.Auth.TokenType { case logical.TokenTypeBatch: case logical.TokenTypeService: - if err := c.expiration.RegisterAuth(ctx, &logical.TokenEntry{ + registeredTokenEntry := &logical.TokenEntry{ TTL: auth.TTL, Policies: auth.TokenPolicies, Path: resp.Auth.CreationPath, NamespaceID: ns.ID, - }, resp.Auth); err != nil { + } + if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth); err != nil { // Best-effort clean up on error, so we log the cleanup error as // a warning but still return as internal error. if err := c.tokenStore.revokeOrphan(ctx, resp.Auth.ClientToken); err != nil { @@ -1130,6 +1175,9 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp retErr = multierror.Append(retErr, ErrInternalError) return nil, auth, retErr } + if registeredTokenEntry.ExternalID != "" { + resp.Auth.ClientToken = registeredTokenEntry.ExternalID + } leaseGenerated = true } } @@ -1570,6 +1618,9 @@ func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path st c.logger.Error("failed to register token lease during login request", "request_path", path, "error", err) return ErrInternalError } + if te.ExternalID != "" { + auth.ClientToken = te.ExternalID + } } return nil @@ -1590,12 +1641,20 @@ func (c *Core) PopulateTokenEntry(ctx context.Context, req *logical.Request) err // a token from a Vault version pre-accessors. We ignore errors for // JWTs. token := req.ClientToken + var err error + req.InboundSSCToken = token + if IsSSCToken(token) { + token, err = c.CheckSSCToken(ctx, token, c.isLoginRequest(ctx, req), c.perfStandby) + if err != nil { + return err + } + } + req.ClientToken = token te, err := c.LookupToken(ctx, token) if err != nil { - dotCount := strings.Count(token, ".") // If we have two dots but the second char is a dot it's a vault // token of the form s.SOMETHING.nsid, not a JWT - if dotCount != 2 || (dotCount == 2 && token[1] == '.') { + if !IsJWT(token) { return fmt.Errorf("error performing token check: %w", err) } } @@ -1606,3 +1665,129 @@ func (c *Core) PopulateTokenEntry(ctx context.Context, req *logical.Request) err } return nil } + +func (c *Core) CheckSSCToken(ctx context.Context, token string, unauth bool, isPerfStandby bool) (string, error) { + if unauth && token != "" { + // This token shouldn't really be here, but alas it was sent along with the request + // Since we're already knee deep in the token checking code pre-existing token checking + // code, we have to deal with this token whether we like it or not. So, we'll just try + // to get the inner token, and if that fails, return the token as-is. We intentionally + // will skip any token checks, because this is an unauthenticated paths and the token + // is just a nuisance rather than a means of auth. + + // We cannot return whatever we like here, because if we do then CheckToken, which looks up + // the corresponding lease, will not find the token entry and lease. There are unauth'ed + // endpoints that use the token entry (such as sys/ui/mounts/internal) to do custom token + // checks, which would then fail. Therefore, we must try to get whatever thing is tied to + // token entries, but we must explicitly not do any SSC Token checks. + tok, err := c.DecodeSSCToken(token) + if err != nil || tok == "" { + return token, nil + } + return tok, nil + } + return c.checkSSCTokenInternal(ctx, token, isPerfStandby) +} + +// DecodeSSCToken returns the random part of an SSCToken without +// performing any signature or WAL checks. +func (c *Core) DecodeSSCToken(token string) (string, error) { + // Skip batch and old style service tokens. These can have the prefix "b.", + // "s." (for old tokens) or "hvb." + if !IsSSCToken(token) { + return token, nil + } + tok, err := c.DecodeSSCTokenInternal(token) + if err != nil { + return "", err + } + return tok.Random, nil +} + +// DecodeSSCTokenInternal is a helper used to get the inner part of a SSC token without +// checking the token signature or the WAL index. +func (c *Core) DecodeSSCTokenInternal(token string) (*tokens.Token, error) { + signedToken := &tokens.SignedToken{} + + // Skip batch and old style service tokens. These can have the prefix "b.", + // "s." (for old tokens) or "hvb." + if !strings.HasPrefix(token, consts.ServiceTokenPrefix) { + return nil, fmt.Errorf("not service token") + } + + // Consider the suffix of the token only when unmarshalling + suffixToken := token[4:] + + tokenBytes, err := base64.RawURLEncoding.DecodeString(suffixToken) + if err != nil { + return nil, fmt.Errorf("can't decode token") + } + + err = proto.Unmarshal(tokenBytes, signedToken) + if err != nil { + return nil, err + } + plainToken := &tokens.Token{} + err2 := proto.Unmarshal([]byte(signedToken.Token), plainToken) + if err2 != nil { + return nil, err2 + } + return plainToken, nil +} + +func (c *Core) checkSSCTokenInternal(ctx context.Context, token string, isPerfStandby bool) (string, error) { + signedToken := &tokens.SignedToken{} + + // Skip batch and old style service tokens. These can have the prefix "b.", + // "s." (for old tokens) or "hvb." + if !strings.HasPrefix(token, consts.ServiceTokenPrefix) { + return token, nil + } + + // Check token length to guess if this is an server side consistent token or not. + // Note that even when the DisableSSCTokens flag is set, index + // bearing tokens that have already been given out may still be used. + if !IsSSCToken(token) { + return token, nil + } + + // Consider the suffix of the token only when unmarshalling + suffixToken := token[4:] + + tokenBytes, err := base64.RawURLEncoding.DecodeString(suffixToken) + if err != nil { + c.logger.Warn("cannot decode token", "error", err) + return token, nil + } + + err = proto.Unmarshal(tokenBytes, signedToken) + if err != nil { + return "", fmt.Errorf("error occurred when unmarshalling ssc token: %w", err) + } + hm, err := c.tokenStore.CalculateSignedTokenHMAC(signedToken.Token) + if !hmac.Equal(hm, signedToken.Hmac) { + return "", fmt.Errorf("token mac for %+v is incorrect: err %w", signedToken, err) + } + plainToken := &tokens.Token{} + err = proto.Unmarshal([]byte(signedToken.Token), plainToken) + if err != nil { + return "", err + } + ep := int(plainToken.IndexEpoch) + if ep < c.tokenStore.GetSSCTokensGenerationCounter() { + return plainToken.Random, nil + } + + requiredWalState := &logical.WALState{ClusterID: c.clusterID.Load(), LocalIndex: plainToken.LocalIndex, ReplicatedIndex: 0} + if c.HasWALState(requiredWalState, isPerfStandby) { + return plainToken.Random, nil + } + // Make sure to forward the request instead of checking the token if the flag + // is set and we're on a perf standby + if c.ForwardToActive() == ForwardSSCTokenToActive && isPerfStandby { + return "", logical.ErrPerfStandbyPleaseForward + } + // In this case, the server side consistent token cannot be used on this node. We return the appropriate + // status code. + return "", logical.ErrMissingRequiredState +} diff --git a/vault/router.go b/vault/router.go index 6426e4eb8f..98c805d680 100644 --- a/vault/router.go +++ b/vault/router.go @@ -603,7 +603,8 @@ func (r *Router) routeCommon(ctx context.Context, req *logical.Request, existenc } switch { - case te.NamespaceID == namespace.RootNamespaceID && !strings.HasPrefix(req.ClientToken, "s."): + case te.NamespaceID == namespace.RootNamespaceID && !strings.HasPrefix(req.ClientToken, "s.") && + !strings.HasPrefix(req.ClientToken, consts.ServiceTokenPrefix): // In order for the token store to revoke later, we need to have the same // salted ID, so we double-salt what's going to the cubbyhole backend salt, err := r.tokenStoreSaltFunc(ctx) @@ -643,6 +644,13 @@ func (r *Router) routeCommon(ctx context.Context, req *logical.Request, existenc headers := req.Headers req.Headers = nil + // Cache the saved request SSC token + inboundToken := req.InboundSSCToken + + // Ensure that the inbound token we cache in the + // request during token creation isn't sent to backends + req.InboundSSCToken = "" + // Filter and add passthrough headers to the backend var passthroughRequestHeaders []string if rawVal, ok := re.mountEntry.synthesizedConfigCache.Load("passthrough_request_headers"); ok { @@ -696,6 +704,13 @@ func (r *Router) routeCommon(ctx context.Context, req *logical.Request, existenc req.MFACreds = originalMFACreds + req.InboundSSCToken = inboundToken + + // Before resetting the tokenEntry, see if an ExternalID was added + if req.TokenEntry() != nil && req.TokenEntry().ExternalID != "" { + reqTokenEntry.ExternalID = req.TokenEntry().ExternalID + } + req.SetTokenEntry(reqTokenEntry) req.ControlGroup = originalControlGroup }() diff --git a/vault/testing.go b/vault/testing.go index 87cc68c76c..4e12363ab4 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -183,6 +183,7 @@ func TestCoreWithSealAndUI(t testing.T, opts *CoreConfig) *Core { conf.NumExpirationWorkers = numExpirationWorkersTest conf.RawConfig = opts.RawConfig conf.EnableResponseHeaderHostname = opts.EnableResponseHeaderHostname + conf.DisableSSCTokens = opts.DisableSSCTokens if opts.Logger != nil { conf.Logger = opts.Logger @@ -327,7 +328,11 @@ func TestCoreInitClusterWrapperSetup(t testing.T, core *Core, handler http.Handl if err != nil { t.Fatalf("err: %s", err) } - return result.SecretShares, result.RecoveryShares, result.RootToken + innerToken, err := core.DecodeSSCToken(result.RootToken) + if err != nil { + t.Fatalf("err: %s", err) + } + return result.SecretShares, result.RecoveryShares, innerToken } func TestCoreUnseal(core *Core, key []byte) (bool, error) { diff --git a/vault/token_store.go b/vault/token_store.go index 346efec3cf..d3924afe38 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -2,6 +2,8 @@ package vault import ( "context" + "crypto/hmac" + "crypto/sha256" "encoding/base64" "encoding/json" "errors" @@ -34,6 +36,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/plugin/pb" + "github.com/hashicorp/vault/vault/tokens" "github.com/mitchellh/mapstructure" ) @@ -62,13 +65,31 @@ const ( // that the token is but is currently fulfilling its final use; after this // request it will not be able to be looked up as being valid. tokenRevocationPending = -1 -) -var ( // TokenLength is the size of tokens we are currently generating, without // any namespace information TokenLength = 24 + // MaxNsIdLength is the maximum namespace ID length (4 characters prepended by a ".") + MaxNsIdLength = 5 + + // TokenPrefixLength is the length of the new token prefixes ("hvs.", "hvb.", + // and "hvr.") + TokenPrefixLength = 4 + + // OldTokenPrefixLength is the length of the old token prefixes ("s.", "b.". "r.") + OldTokenPrefixLength = 2 + + // GenerationCounterBuffer is a buffer for the generation counter estimation in the + // case where a counter cannot be retrieved from storage + GenerationCounterBuffer = 5 + + // MaxRetrySSCTokensGenerationCounter is the maximum number of retries the TokenStore + // will make when attempting to get the SSCTokensGenerationCounter + MaxRetrySSCTokensGenerationCounter = 3 +) + +var ( // displayNameSanitize is used to sanitize a display name given to a token. displayNameSanitize = regexp.MustCompile("[^a-zA-Z0-9-]") @@ -92,7 +113,7 @@ var ( view := storage.(*BarrierView) switch { - case te.NamespaceID == namespace.RootNamespaceID && !strings.HasPrefix(te.ID, "s."): + case te.NamespaceID == namespace.RootNamespaceID && !IsServiceToken(te.ID): saltedID, err := ts.SaltID(ctx, te.ID) if err != nil { return err @@ -532,6 +553,12 @@ type TokenStore struct { identityPoliciesDeriverFunc func(string) (*identity.Entity, []string, error) quitContext context.Context + + // sscTokensGenerationCounter is a per-cluster version that counts how many + // "sync points" the cluster has encountered in its lifecycle. "Sync points" are the + // number of times all nodes in the cluster have stepped down. Currently the only sync + // point is a DR cluster promoting to the primary. + sscTokensGenerationCounter SSCTokenGenerationCounter } // NewTokenStore is used to construct a token store that is @@ -586,6 +613,10 @@ func NewTokenStore(ctx context.Context, logger log.Logger, core *Core, config *l t.Backend.Setup(ctx, config) + if err := t.loadSSCTokensGenerationCounter(ctx); err != nil { + return t, err + } + return t, nil } @@ -865,6 +896,8 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err if userSelectedID { switch { + case strings.HasPrefix(entry.ID, consts.ServiceTokenPrefix): + return fmt.Errorf("custom token ID cannot have the 'hvs.' prefix") case strings.HasPrefix(entry.ID, "s."): return fmt.Errorf("custom token ID cannot have the 's.' prefix") case strings.Contains(entry.ID, "."): @@ -873,7 +906,11 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err } if !userSelectedID { - entry.ID = fmt.Sprintf("s.%s", entry.ID) + if !ts.core.DisableSSCTokens() { + entry.ID = fmt.Sprintf("hvs.%s", entry.ID) + } else { + entry.ID = fmt.Sprintf("s.%s", entry.ID) + } } // Attach namespace ID for tokens that are not belonging to the root @@ -882,7 +919,7 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err entry.ID = fmt.Sprintf("%s.%s", entry.ID, tokenNS.ID) } - if tokenNS.ID != namespace.RootNamespaceID || strings.HasPrefix(entry.ID, "s.") { + if tokenNS.ID != namespace.RootNamespaceID || strings.HasPrefix(entry.ID, consts.ServiceTokenPrefix) { if entry.CubbyholeID == "" { cubbyholeID, err := base62.Random(TokenLength) if err != nil { @@ -906,7 +943,15 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err return err } - return ts.storeCommon(ctx, entry, true) + err = ts.storeCommon(ctx, entry, true) + if err != nil { + return err + } + entry.ExternalID = entry.ID + if !userSelectedID && !ts.core.DisableSSCTokens() { + entry.ExternalID = ts.GenerateSSCTokenID(entry.ID, logical.IndexStateFromContext(ctx), entry) + } + return nil case logical.TokenTypeBatch: // Ensure fields we don't support/care about are nilled, proto marshal, @@ -946,7 +991,11 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err } bEntry := base64.RawURLEncoding.EncodeToString(eEntry) - entry.ID = fmt.Sprintf("b.%s", bEntry) + if ts.core.DisableSSCTokens() { + entry.ID = fmt.Sprintf("b.%s", bEntry) + } else { + entry.ID = fmt.Sprintf("hvb.%s", bEntry) + } if tokenNS.ID != namespace.RootNamespaceID { entry.ID = fmt.Sprintf("%s.%s", entry.ID, tokenNS.ID) @@ -959,6 +1008,84 @@ func (ts *TokenStore) create(ctx context.Context, entry *logical.TokenEntry) err } } +// GenerateSSCTokenID generates the ID field of the TokenEntry struct for newly +// minted service tokens. This function is meant to be robust so as to allow vault +// to continue operating even in the case where IDs can't be generated. Thus it logs +// errors as opposed to throwing them. +func (ts *TokenStore) GenerateSSCTokenID(innerToken string, walState *logical.WALState, te *logical.TokenEntry) string { + // Set up the prefix prepending function. This should really only be used in + // the token ID generation code itself. + prependServicePrefix := func(externalToken string) string { + if strings.HasPrefix(externalToken, consts.ServiceTokenPrefix) { + // We didn't generate a SSC token and furthermore are attempting + // to regenerate a token that already has passed through + // GenerateSSCTokenID, as it has a prefix. + return externalToken + } + return consts.ServiceTokenPrefix + externalToken + } + + // If we are not using server side consistent tokens, log it and return here + if ts.core.DisableSSCTokens() { + ts.logger.Trace("server side consistent tokens are disabled") + return prependServicePrefix(innerToken) + } + + // If there is no WAL state, do not throw an error as it may be a single + // node cluster, or an OSS core. Instead, log that this has happened and + // create a walState with nil values to signify that these values should + // be ignored + if walState == nil { + ts.logger.Debug("no wal state found when generating token") + walState = &logical.WALState{} + } + if te.IsRoot() { + return prependServicePrefix(innerToken) + } + + // If the token is a root token, we will always set the index and epoch to 0 so as to ensure + // that root tokens are always fixed size. This is required because during root token + // generation, the size needs to be known to create the OTP. + + localIndex := walState.LocalIndex + tokenGenerationCounter := uint32(ts.GetSSCTokensGenerationCounter()) + + t := tokens.Token{Random: innerToken, LocalIndex: localIndex, IndexEpoch: tokenGenerationCounter} + marshalledToken, err := proto.Marshal(&t) + if err != nil { + ts.logger.Error("unable to marshal token", "error", err) + return prependServicePrefix(innerToken) + } + + hmac, err := ts.CalculateSignedTokenHMAC(marshalledToken) + if err != nil { + // If we can't calculate the HMAC for any reason, we should log an error + // but still allow vault to function, using the old token instead. + ts.logger.Error("unable to calculate token signature", "error", err) + return prependServicePrefix(innerToken) + } + st := tokens.SignedToken{TokenVersion: 1, Token: marshalledToken, Hmac: hmac} + + marshalledSignedToken, err := proto.Marshal(&st) + if err != nil { + ts.logger.Error("unable to marshal signed token", "error", err) + return prependServicePrefix(innerToken) + } + generatedSSCToken := base64.RawURLEncoding.EncodeToString(marshalledSignedToken) + return prependServicePrefix(generatedSSCToken) +} + +func (ts *TokenStore) CalculateSignedTokenHMAC(marshalledToken []byte) ([]byte, error) { + key := ts.core.headerHMACKey() + if key == nil { + return nil, errors.New("token hmac key has not been initialized or has not been replicated yet to the active node") + } + + hm := hmac.New(sha256.New, key) + hm.Write([]byte(marshalledToken)) + return hm.Sum(nil), nil +} + // Store is used to store an updated token entry without writing the // secondary index. func (ts *TokenStore) store(ctx context.Context, entry *logical.TokenEntry) error { @@ -1121,7 +1248,7 @@ func (ts *TokenStore) Lookup(ctx context.Context, id string) (*logical.TokenEntr } // If it starts with "b." it's a batch token - if len(id) > 2 && strings.HasPrefix(id, "b.") { + if IsBatchToken(id) { return ts.lookupBatchToken(ctx, id) } @@ -1132,6 +1259,16 @@ func (ts *TokenStore) Lookup(ctx context.Context, id string) (*logical.TokenEntr return ts.lookupInternal(ctx, id, false, false) } +func (ts *TokenStore) stripBatchPrefix(id string) string { + if strings.HasPrefix(id, "b.") { + return id[2:] + } + if strings.HasPrefix(id, consts.BatchTokenPrefix) { + return id[4:] + } + return "" +} + // lookupTainted is used to find a token that may or may not be tainted given // its ID. It acquires a read lock, then calls lookupInternal. func (ts *TokenStore) lookupTainted(ctx context.Context, id string) (*logical.TokenEntry, error) { @@ -1149,7 +1286,7 @@ func (ts *TokenStore) lookupTainted(ctx context.Context, id string) (*logical.To func (ts *TokenStore) lookupBatchTokenInternal(ctx context.Context, id string) (*logical.TokenEntry, error) { // Strip the b. from the front and namespace ID from the back - bEntry, _ := namespace.SplitIDFromString(id[2:]) + bEntry, _ := namespace.SplitIDFromString(ts.stripBatchPrefix(id)) eEntry, err := base64.RawURLEncoding.DecodeString(bEntry) if err != nil { @@ -1210,11 +1347,25 @@ func (ts *TokenStore) lookupInternal(ctx context.Context, id string, salted, tai return nil, fmt.Errorf("failed to find namespace in context: %w", err) } - // If it starts with "b." it's a batch token - if len(id) > 2 && strings.HasPrefix(id, "b.") { + // If it starts with "b." or consts.BatchTokenPrefix it's a batch token + if IsBatchToken(id) { return ts.lookupBatchToken(ctx, id) } + // lookupInternal is called internally with tokens that oftentimes come from request + // parameters that we cannot really guess. Most notably, these calls come from either + // validateWrappedToken and/or lookupTokenTainted, used in the wrapping token logic. + // We can't really catch all these instances of lookup token, so we have to check the + // SSC token in this function itself. + if IsSSCToken(id) { + internalID, err := ts.core.DecodeSSCToken(id) + if err == nil && internalID != "" { + // A malformed token was passed in, is our best guess here. Just use id going + // forward. + id = internalID + } + } + var raw *logical.StorageEntry lookupID := id @@ -2033,7 +2184,7 @@ func (ts *TokenStore) handleTidy(ctx context.Context, req *logical.Request, data default: // Cache the cubbyhole storage key when the token is valid switch { - case te.NamespaceID == namespace.RootNamespaceID && !strings.HasPrefix(te.ID, "s."): + case te.NamespaceID == namespace.RootNamespaceID && !IsServiceToken(te.ID): saltedID, err := ts.SaltID(quitCtx, te.ID) if err != nil { tidyErrors = multierror.Append(tidyErrors, fmt.Errorf("failed to create salted token id: %w", err)) diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 8f0d36a0fd..011ac641b0 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -14,6 +14,8 @@ import ( "time" "github.com/go-test/deep" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/parseutil" @@ -22,6 +24,7 @@ import ( "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/locksutil" "github.com/hashicorp/vault/sdk/helper/tokenutil" "github.com/hashicorp/vault/sdk/logical" @@ -223,7 +226,7 @@ func TestTokenStore_Salting(t *testing.T) { t.Fatalf("expected sha1 hash; got sha2-256 hmac") } - saltedID, err = ts.SaltID(namespace.RootContext(nil), "s.foo") + saltedID, err = ts.SaltID(namespace.RootContext(nil), "hvs.foo") if err != nil { t.Fatal(err) } @@ -240,7 +243,7 @@ func TestTokenStore_Salting(t *testing.T) { t.Fatalf("expected sha2-256 hmac; got sha1 hash") } - saltedID, err = ts.SaltID(nsCtx, "s.foo") + saltedID, err = ts.SaltID(nsCtx, "hvs.foo") if err != nil { t.Fatal(err) } @@ -753,7 +756,8 @@ func TestTokenStore_HandleRequest_ListAccessors(t *testing.T) { } // Revoke root to make the number of accessors match - salted, err := ts.SaltID(namespace.RootContext(nil), root) + internalRoot, _ := c.DecodeSSCToken(root) + salted, err := ts.SaltID(namespace.RootContext(nil), internalRoot) if err != nil { t.Fatal(err) } @@ -1000,9 +1004,7 @@ func TestTokenStore_RootToken(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } - if !reflect.DeepEqual(out, te) { - t.Fatalf("bad: expected:%#v\nactual:%#v", te, out) - } + deepEqualTokenEntries(t, out, te) } func TestTokenStore_NoRootBatch(t *testing.T) { @@ -1045,9 +1047,7 @@ func TestTokenStore_CreateLookup(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } - if !reflect.DeepEqual(out, ent) { - t.Fatalf("bad: expected:%#v\nactual:%#v", ent, out) - } + deepEqualTokenEntries(t, out, ent) // New store should share the salt ts2, err := NewTokenStore(namespace.RootContext(nil), hclog.New(&hclog.LoggerOptions{}), c, getBackendConfig(c)) @@ -1061,9 +1061,7 @@ func TestTokenStore_CreateLookup(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } - if !reflect.DeepEqual(out, ent) { - t.Fatalf("bad: expected:%#v\nactual:%#v", ent, out) - } + deepEqualTokenEntries(t, out, ent) } func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) { @@ -1089,9 +1087,7 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } - if !reflect.DeepEqual(out, ent) { - t.Fatalf("bad: expected:%#v\nactual:%#v", ent, out) - } + deepEqualTokenEntries(t, out, ent) // New store should share the salt ts2, err := NewTokenStore(namespace.RootContext(nil), hclog.New(&hclog.LoggerOptions{}), c, getBackendConfig(c)) @@ -1105,9 +1101,7 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } - if !reflect.DeepEqual(out, ent) { - t.Fatalf("bad: expected:%#v\nactual:%#v", ent, out) - } + deepEqualTokenEntries(t, out, ent) } func TestTokenStore_CreateLookup_ExpirationInRestoreMode(t *testing.T) { @@ -1150,9 +1144,7 @@ func TestTokenStore_CreateLookup_ExpirationInRestoreMode(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } - if !reflect.DeepEqual(out, ent) { - t.Fatalf("bad: expected:%#v\nactual:%#v", ent, out) - } + deepEqualTokenEntries(t, out, ent) // Set to expired lease time le.ExpireTime = time.Now().Add(-1 * time.Hour) @@ -1384,9 +1376,7 @@ func TestTokenStore_Revoke_Orphan(t *testing.T) { // Unset the expected token parent's ID ent2.Parent = "" - if !reflect.DeepEqual(out, ent2) { - t.Fatalf("bad:\nexpected:%#v\nactual:%#v", ent2, out) - } + deepEqualTokenEntries(t, out, ent2) } // This was the original function name, and now it just calls @@ -1664,6 +1654,12 @@ func TestTokenStore_HandleRequest_CreateToken_DisplayName(t *testing.T) { } } +func deepEqualTokenEntries(t *testing.T, a *logical.TokenEntry, b *logical.TokenEntry) { + if diff := cmp.Diff(a, b, cmpopts.IgnoreFields(logical.TokenEntry{}, "ExternalID")); diff != "" { + t.Fatalf("bad diff in token entries: %s", diff) + } +} + func TestTokenStore_HandleRequest_CreateToken_NumUses(t *testing.T) { c, _, root := TestCoreUnsealed(t) ts := c.tokenStore @@ -2421,8 +2417,10 @@ func testTokenStoreHandleRequestLookup(t *testing.T, batch, periodic bool) { t.Fatalf("bad: %#v", resp) } + internalRoot, _ := c.DecodeSSCToken(root) + exp := map[string]interface{}{ - "id": root, + "id": internalRoot, "accessor": resp.Data["accessor"].(string), "policies": []string{"root"}, "path": "auth/token/root", @@ -6000,7 +5998,7 @@ func TestTokenStore_TokenID(t *testing.T) { c, _, initToken := TestCoreUnsealed(t) ts := c.tokenStore - // Ensure that a regular service token has a "s." prefix + // Ensure that a regular service token has a consts.ServiceTokenPrefix prefix resp, err := ts.HandleRequest(namespace.RootContext(nil), &logical.Request{ ClientToken: initToken, Path: "create", @@ -6009,8 +6007,8 @@ func TestTokenStore_TokenID(t *testing.T) { if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("bad: resp: %#v\nerr: %v", resp, err) } - if !strings.HasPrefix(resp.Auth.ClientToken, "s.") { - t.Fatalf("token %q does not have a 's.' prefix", resp.Auth.ClientToken) + if !strings.HasPrefix(resp.Auth.ClientToken, consts.ServiceTokenPrefix) { + t.Fatalf("token %q does not have a 'hvs.' prefix", resp.Auth.ClientToken) } }) @@ -6041,19 +6039,19 @@ func TestTokenStore_TokenID(t *testing.T) { c, _, initToken := TestCoreUnsealed(t) ts := c.tokenStore - // Ensure that custom token ID having a "s." prefix fails + // Ensure that custom token ID having a consts.ServiceTokenPrefix prefix fails resp, err := ts.HandleRequest(namespace.RootContext(nil), &logical.Request{ ClientToken: initToken, Path: "create", Operation: logical.UpdateOperation, Data: map[string]interface{}{ - "id": "s.foobar", + "id": "hvs.foobar", }, }) if err == nil { t.Fatalf("expected an error") } - if resp.Error().Error() != "custom token ID cannot have the 's.' prefix" { + if resp.Error().Error() != "custom token ID cannot have the 'hvs.' prefix" { t.Fatalf("expected input error not present in error response") } }) diff --git a/vault/token_store_util_common.go b/vault/token_store_util_common.go new file mode 100644 index 0000000000..235e2d2167 --- /dev/null +++ b/vault/token_store_util_common.go @@ -0,0 +1,59 @@ +package vault + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/hashicorp/vault/sdk/logical" +) + +const sscGenCounterPath string = "core/sscGenCounter/" + +type SSCTokenGenerationCounter struct { + Counter int +} + +func (ts *TokenStore) GetSSCTokensGenerationCounter() int { + return ts.sscTokensGenerationCounter.Counter +} + +func (ts *TokenStore) loadSSCTokensGenerationCounter(ctx context.Context) error { + sscTokensGenerationCounterStorageVal, err := ts.core.barrier.Get(ctx, sscGenCounterPath) + if err != nil { + return fmt.Errorf("unable to retrieve SSCTokenGenerationCounter from storage: err %w", err) + } + if sscTokensGenerationCounterStorageVal == nil { + ts.logger.Trace("no token generation counter found in storage") + ts.sscTokensGenerationCounter = SSCTokenGenerationCounter{Counter: 0} + return nil + } + var sscTokensGenerationCounter SSCTokenGenerationCounter + err = json.Unmarshal(sscTokensGenerationCounterStorageVal.Value, &sscTokensGenerationCounter) + if err != nil { + return fmt.Errorf("malformed token generation counter found in storage: err %w", err) + } + ts.sscTokensGenerationCounter = sscTokensGenerationCounter + return nil +} + +func (ts *TokenStore) UpdateSSCTokensGenerationCounter(ctx context.Context) error { + ts.sscTokensGenerationCounter.Counter += 1 + if ts.sscTokensGenerationCounter.Counter <= 0 { + // Don't store the 0 value + ts.logger.Warn("attempt to store non-positive token generation counter was ignored", + "sscTokensGenerationCounter", ts.sscTokensGenerationCounter.Counter) + } + marshalledCtr, err := json.Marshal(ts.sscTokensGenerationCounter) + if err != nil { + return err + } + err = ts.core.barrier.Put(ctx, &logical.StorageEntry{ + Key: sscGenCounterPath, + Value: marshalledCtr, + }) + if err != nil { + return err + } + return nil +} diff --git a/vault/tokens/token.pb.go b/vault/tokens/token.pb.go new file mode 100644 index 0000000000..4604ab3c59 --- /dev/null +++ b/vault/tokens/token.pb.go @@ -0,0 +1,247 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.17.3 +// source: vault/tokens/token.proto + +package tokens + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// SignedToken +type SignedToken struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TokenVersion uint64 `protobuf:"varint,1,opt,name=token_version,json=tokenVersion,proto3" json:"token_version,omitempty"` // always 1 for now + Hmac []byte `protobuf:"bytes,2,opt,name=hmac,proto3" json:"hmac,omitempty"` // HMAC of token + Token []byte `protobuf:"bytes,3,opt,name=token,proto3" json:"token,omitempty"` // protobuf-marshalled Token message +} + +func (x *SignedToken) Reset() { + *x = SignedToken{} + if protoimpl.UnsafeEnabled { + mi := &file_vault_tokens_token_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SignedToken) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SignedToken) ProtoMessage() {} + +func (x *SignedToken) ProtoReflect() protoreflect.Message { + mi := &file_vault_tokens_token_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SignedToken.ProtoReflect.Descriptor instead. +func (*SignedToken) Descriptor() ([]byte, []int) { + return file_vault_tokens_token_proto_rawDescGZIP(), []int{0} +} + +func (x *SignedToken) GetTokenVersion() uint64 { + if x != nil { + return x.TokenVersion + } + return 0 +} + +func (x *SignedToken) GetHmac() []byte { + if x != nil { + return x.Hmac + } + return nil +} + +func (x *SignedToken) GetToken() []byte { + if x != nil { + return x.Token + } + return nil +} + +type Token struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Random string `protobuf:"bytes,1,opt,name=random,proto3" json:"random,omitempty"` // unencoded equiv of former $randbase62 + LocalIndex uint64 `protobuf:"varint,2,opt,name=local_index,json=localIndex,proto3" json:"local_index,omitempty"` // required storage state to have this token + IndexEpoch uint32 `protobuf:"varint,3,opt,name=index_epoch,json=indexEpoch,proto3" json:"index_epoch,omitempty"` +} + +func (x *Token) Reset() { + *x = Token{} + if protoimpl.UnsafeEnabled { + mi := &file_vault_tokens_token_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Token) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Token) ProtoMessage() {} + +func (x *Token) ProtoReflect() protoreflect.Message { + mi := &file_vault_tokens_token_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Token.ProtoReflect.Descriptor instead. +func (*Token) Descriptor() ([]byte, []int) { + return file_vault_tokens_token_proto_rawDescGZIP(), []int{1} +} + +func (x *Token) GetRandom() string { + if x != nil { + return x.Random + } + return "" +} + +func (x *Token) GetLocalIndex() uint64 { + if x != nil { + return x.LocalIndex + } + return 0 +} + +func (x *Token) GetIndexEpoch() uint32 { + if x != nil { + return x.IndexEpoch + } + return 0 +} + +var File_vault_tokens_token_proto protoreflect.FileDescriptor + +var file_vault_tokens_token_proto_rawDesc = []byte{ + 0x0a, 0x18, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x2f, 0x74, + 0x6f, 0x6b, 0x65, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x73, 0x22, 0x5c, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x65, 0x64, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0c, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x56, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x6d, 0x61, 0x63, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x68, 0x6d, 0x61, 0x63, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x22, 0x61, 0x0a, 0x05, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x61, 0x6e, + 0x64, 0x6f, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x72, 0x61, 0x6e, 0x64, 0x6f, + 0x6d, 0x12, 0x1f, 0x0a, 0x0b, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0a, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x6e, 0x64, + 0x65, 0x78, 0x12, 0x1f, 0x0a, 0x0b, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x5f, 0x65, 0x70, 0x6f, 0x63, + 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x45, 0x70, + 0x6f, 0x63, 0x68, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x76, 0x61, 0x75, 0x6c, + 0x74, 0x2f, 0x76, 0x61, 0x75, 0x6c, 0x74, 0x2f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_vault_tokens_token_proto_rawDescOnce sync.Once + file_vault_tokens_token_proto_rawDescData = file_vault_tokens_token_proto_rawDesc +) + +func file_vault_tokens_token_proto_rawDescGZIP() []byte { + file_vault_tokens_token_proto_rawDescOnce.Do(func() { + file_vault_tokens_token_proto_rawDescData = protoimpl.X.CompressGZIP(file_vault_tokens_token_proto_rawDescData) + }) + return file_vault_tokens_token_proto_rawDescData +} + +var file_vault_tokens_token_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_vault_tokens_token_proto_goTypes = []interface{}{ + (*SignedToken)(nil), // 0: tokens.SignedToken + (*Token)(nil), // 1: tokens.Token +} +var file_vault_tokens_token_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_vault_tokens_token_proto_init() } +func file_vault_tokens_token_proto_init() { + if File_vault_tokens_token_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_vault_tokens_token_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SignedToken); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_vault_tokens_token_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Token); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_vault_tokens_token_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_vault_tokens_token_proto_goTypes, + DependencyIndexes: file_vault_tokens_token_proto_depIdxs, + MessageInfos: file_vault_tokens_token_proto_msgTypes, + }.Build() + File_vault_tokens_token_proto = out.File + file_vault_tokens_token_proto_rawDesc = nil + file_vault_tokens_token_proto_goTypes = nil + file_vault_tokens_token_proto_depIdxs = nil +} diff --git a/vault/tokens/token.proto b/vault/tokens/token.proto new file mode 100644 index 0000000000..ae4364a3af --- /dev/null +++ b/vault/tokens/token.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +option go_package = "github.com/hashicorp/vault/vault/tokens"; + +package tokens; + +// SignedToken +message SignedToken { + uint64 token_version = 1; // always 1 for now + bytes hmac = 2; // HMAC of token + bytes token = 3; // protobuf-marshalled Token message +} + +message Token { + string random = 1; // unencoded equiv of former $randbase62 + uint64 local_index = 2; // required storage state to have this token + uint32 index_epoch = 3; +} \ No newline at end of file diff --git a/vault/version_store.go b/vault/version_store.go index 7e1d446904..6890bcf187 100644 --- a/vault/version_store.go +++ b/vault/version_store.go @@ -4,12 +4,16 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" + "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" ) -const vaultVersionPath string = "core/versions/" +const ( + vaultVersionPath string = "core/versions/" +) // storeVersionTimestamp will store the version and timestamp pair to storage // only if no entry for that version already exists in storage. Version @@ -128,3 +132,23 @@ func (c *Core) loadVersionTimestamps(ctx context.Context) error { } return nil } + +func IsJWT(token string) bool { + return len(token) > 3 && strings.Count(token, ".") == 2 && + (token[3] != '.' && token[1] != '.') +} + +func IsSSCToken(token string) bool { + return len(token) > MaxNsIdLength+TokenLength+TokenPrefixLength && + strings.HasPrefix(token, consts.ServiceTokenPrefix) +} + +func IsServiceToken(token string) bool { + return strings.HasPrefix(token, consts.ServiceTokenPrefix) || + strings.HasPrefix(token, consts.LegacyServiceTokenPrefix) +} + +func IsBatchToken(token string) bool { + return strings.HasPrefix(token, consts.LegacyBatchTokenPrefix) || + strings.HasPrefix(token, consts.BatchTokenPrefix) +} diff --git a/vault/wrapping.go b/vault/wrapping.go index b56b8aef1a..5b1d320806 100644 --- a/vault/wrapping.go +++ b/vault/wrapping.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" "fmt" - "strings" "time" "github.com/armon/go-metrics" @@ -166,7 +165,7 @@ DONELISTHANDLING: }, ) - resp.WrapInfo.Token = te.ID + resp.WrapInfo.Token = te.ExternalID resp.WrapInfo.Accessor = te.Accessor resp.WrapInfo.CreationTime = creationTime // If this is not a rewrap, store the request path as creation_path @@ -403,7 +402,7 @@ func (c *Core) validateWrappingToken(ctx context.Context, req *logical.Request) // token to be a JWT -- namespaced tokens have two dots too, but Vault // token types (for now at least) begin with a letter representing a type // and then a dot. - if strings.Count(token, ".") == 2 && token[1] != '.' { + if IsJWT(token) { // Implement the jose library way parsedJWT, err := squarejwt.ParseSigned(token) if err != nil {