mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 02:02:43 +00:00 
			
		
		
		
	VAULT-6614 Enable role based quotas for lease-count quotas (OSS) (#16157)
* VAULT-6613 add DetermineRoleFromLoginRequest function to Core * Fix body handling * Role resolution for rate limit quotas * VAULT-6613 update precedence test * Add changelog * VAULT-6614 start of changes for roles in LCQs * Expiration changes for leases * Add role information to RequestAuth * VAULT-6614 Test updates * VAULT-6614 Add expiration test with roles * VAULT-6614 fix comment * VAULT-6614 Protobuf on OSS * VAULT-6614 Add rlock to determine role code * VAULT-6614 Try lock instead of rlock * VAULT-6614 back to rlock while I think about this more * VAULT-6614 Additional safety for nil dereference * VAULT-6614 Use %q over %s * VAULT-6614 Add overloading to plugin backends * VAULT-6614 RLocks instead * VAULT-6614 Fix return for backend factory
This commit is contained in:
		| @@ -7,6 +7,8 @@ import ( | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
|  | ||||
| 	log "github.com/hashicorp/go-hclog" | ||||
|  | ||||
| 	uuid "github.com/hashicorp/go-uuid" | ||||
| 	"github.com/hashicorp/vault/sdk/framework" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/consts" | ||||
| @@ -38,7 +40,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, | ||||
|  | ||||
| // Backend returns an instance of the backend, either as a plugin if external | ||||
| // or as a concrete implementation if builtin, casted as logical.Backend. | ||||
| func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { | ||||
| func Backend(ctx context.Context, conf *logical.BackendConfig) (*PluginBackend, error) { | ||||
| 	var b PluginBackend | ||||
|  | ||||
| 	name := conf.Config["plugin_name"] | ||||
| @@ -80,7 +82,7 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, | ||||
|  | ||||
| // PluginBackend is a thin wrapper around plugin.BackendPluginClient | ||||
| type PluginBackend struct { | ||||
| 	logical.Backend | ||||
| 	Backend logical.Backend | ||||
| 	sync.RWMutex | ||||
|  | ||||
| 	config *logical.BackendConfig | ||||
| @@ -118,12 +120,12 @@ func (b *PluginBackend) startBackend(ctx context.Context, storage logical.Storag | ||||
| 	if !b.loaded { | ||||
| 		if b.Backend.Type() != nb.Type() { | ||||
| 			nb.Cleanup(ctx) | ||||
| 			b.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType) | ||||
| 			b.Backend.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType) | ||||
| 			return ErrMismatchType | ||||
| 		} | ||||
| 		if !reflect.DeepEqual(b.Backend.SpecialPaths(), nb.SpecialPaths()) { | ||||
| 			nb.Cleanup(ctx) | ||||
| 			b.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths) | ||||
| 			b.Backend.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths) | ||||
| 			return ErrMismatchPaths | ||||
| 		} | ||||
| 	} | ||||
| @@ -169,7 +171,7 @@ func (b *PluginBackend) lazyLoadBackend(ctx context.Context, storage logical.Sto | ||||
| 		// Reload plugin if it's an rpc.ErrShutdown | ||||
| 		b.Lock() | ||||
| 		if b.canary == canary { | ||||
| 			b.Logger().Debug("reloading plugin backend", "plugin", b.config.Config["plugin_name"]) | ||||
| 			b.Backend.Logger().Debug("reloading plugin backend", "plugin", b.config.Config["plugin_name"]) | ||||
| 			err := b.startBackend(ctx, storage) | ||||
| 			if err != nil { | ||||
| 				b.Unlock() | ||||
| @@ -220,3 +222,52 @@ func (b *PluginBackend) HandleExistenceCheck(ctx context.Context, req *logical.R | ||||
| func (b *PluginBackend) Initialize(ctx context.Context, req *logical.InitializationRequest) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SpecialPaths is a thin wrapper used to ensure we grab the lock for race purposes | ||||
| func (b *PluginBackend) SpecialPaths() *logical.Paths { | ||||
| 	b.RLock() | ||||
| 	defer b.RUnlock() | ||||
| 	return b.Backend.SpecialPaths() | ||||
| } | ||||
|  | ||||
| // System is a thin wrapper used to ensure we grab the lock for race purposes | ||||
| func (b *PluginBackend) System() logical.SystemView { | ||||
| 	b.RLock() | ||||
| 	defer b.RUnlock() | ||||
| 	return b.Backend.System() | ||||
| } | ||||
|  | ||||
| // Logger is a thin wrapper used to ensure we grab the lock for race purposes | ||||
| func (b *PluginBackend) Logger() log.Logger { | ||||
| 	b.RLock() | ||||
| 	defer b.RUnlock() | ||||
| 	return b.Backend.Logger() | ||||
| } | ||||
|  | ||||
| // Cleanup is a thin wrapper used to ensure we grab the lock for race purposes | ||||
| func (b *PluginBackend) Cleanup(ctx context.Context) { | ||||
| 	b.RLock() | ||||
| 	defer b.RUnlock() | ||||
| 	b.Backend.Cleanup(ctx) | ||||
| } | ||||
|  | ||||
| // InvalidateKey is a thin wrapper used to ensure we grab the lock for race purposes | ||||
| func (b *PluginBackend) InvalidateKey(ctx context.Context, key string) { | ||||
| 	b.RLock() | ||||
| 	defer b.RUnlock() | ||||
| 	b.Backend.InvalidateKey(ctx, key) | ||||
| } | ||||
|  | ||||
| // Setup is a thin wrapper used to ensure we grab the lock for race purposes | ||||
| func (b *PluginBackend) Setup(ctx context.Context, config *logical.BackendConfig) error { | ||||
| 	b.RLock() | ||||
| 	defer b.RUnlock() | ||||
| 	return b.Backend.Setup(ctx, config) | ||||
| } | ||||
|  | ||||
| // Type is a thin wrapper used to ensure we grab the lock for race purposes | ||||
| func (b *PluginBackend) Type() logical.BackendType { | ||||
| 	b.RLock() | ||||
| 	defer b.RUnlock() | ||||
| 	return b.Backend.Type() | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: helper/forwarding/types.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: helper/identity/mfa/types.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: helper/identity/types.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: helper/storagepacker/types.proto | ||||
|  | ||||
|   | ||||
| @@ -64,7 +64,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler | ||||
| 			Type:          quotas.TypeRateLimit, | ||||
| 			Path:          path, | ||||
| 			MountPath:     mountPath, | ||||
| 			Role:          core.DetermineRoleFromLoginRequest(mountPath, bodyBytes, r.Context()), | ||||
| 			Role:          core.DetermineRoleFromLoginRequestFromBytes(mountPath, bodyBytes, r.Context()), | ||||
| 			NamespacePath: ns.Path, | ||||
| 			ClientAddress: parseRemoteIPAddress(r), | ||||
| 		}) | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: physical/raft/types.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: sdk/database/dbplugin/database.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: sdk/database/dbplugin/v5/proto/database.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: sdk/helper/pluginutil/multiplexing.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: sdk/logical/identity.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: sdk/logical/plugin.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: sdk/plugin/pb/backend.proto | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: vault/activity/activity_log.proto | ||||
|  | ||||
|   | ||||
| @@ -175,7 +175,7 @@ func (e *ErrInvalidKey) Error() string { | ||||
| 	return fmt.Sprintf("invalid key: %v", e.Reason) | ||||
| } | ||||
|  | ||||
| type RegisterAuthFunc func(context.Context, time.Duration, string, *logical.Auth) error | ||||
| type RegisterAuthFunc func(context.Context, time.Duration, string, *logical.Auth, string) error | ||||
|  | ||||
| type activeAdvertisement struct { | ||||
| 	RedirectAddr     string                     `json:"redirect_addr"` | ||||
| @@ -3324,15 +3324,9 @@ func (c *Core) CheckPluginPerms(pluginName string) (err error) { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given | ||||
| // login request | ||||
| func (c *Core) DetermineRoleFromLoginRequest(mountPoint string, payload []byte, ctx context.Context) string { | ||||
| 	matchingBackend := c.router.MatchingBackend(ctx, mountPoint) | ||||
| 	if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential { | ||||
| 		// Role based quotas do not apply to this request | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| // DetermineRoleFromLoginRequestFromBytes will determine the role that should be applied to a quota for a given | ||||
| // login request, accepting a byte payload | ||||
| func (c *Core) DetermineRoleFromLoginRequestFromBytes(mountPoint string, payload []byte, ctx context.Context) string { | ||||
| 	data := make(map[string]interface{}) | ||||
| 	err := jsonutil.DecodeJSON(payload, &data) | ||||
| 	if err != nil { | ||||
| @@ -3340,6 +3334,20 @@ func (c *Core) DetermineRoleFromLoginRequest(mountPoint string, payload []byte, | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	return c.DetermineRoleFromLoginRequest(mountPoint, data, ctx) | ||||
| } | ||||
|  | ||||
| // DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given | ||||
| // login request | ||||
| func (c *Core) DetermineRoleFromLoginRequest(mountPoint string, data map[string]interface{}, ctx context.Context) string { | ||||
| 	c.authLock.RLock() | ||||
| 	defer c.authLock.RUnlock() | ||||
| 	matchingBackend := c.router.MatchingBackend(ctx, mountPoint) | ||||
| 	if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential { | ||||
| 		// Role based quotas do not apply to this request | ||||
| 		return "" | ||||
| 	} | ||||
|  | ||||
| 	resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{ | ||||
| 		MountPoint: mountPoint, | ||||
| 		Path:       "login", | ||||
|   | ||||
| @@ -166,7 +166,7 @@ func (c *Core) quotaLeaseWalker(ctx context.Context, callback func(request *quot | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leaseIDs []string) error { | ||||
| func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leases []*quotas.QuotaLeaseInformation) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -471,9 +471,15 @@ func (m *ExpirationManager) invalidate(key string) { | ||||
| 				m.pending.Delete(leaseID) | ||||
| 				m.leaseCount-- | ||||
|  | ||||
| 				if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { | ||||
| 					m.logger.Error("failed to update quota on lease invalidation", "error", err) | ||||
| 					return | ||||
| 				// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to | ||||
| 				// accurately update quota lease information. | ||||
| 				// Note that cachedLeaseInfo should never be nil under normal operation. | ||||
| 				if pending.cachedLeaseInfo != nil { | ||||
| 					leaseInfo := "as.QuotaLeaseInformation{LeaseId: leaseID, Role: pending.cachedLeaseInfo.LoginRole} | ||||
| 					if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { | ||||
| 						m.logger.Error("failed to update quota on lease invalidation", "error", err) | ||||
| 						return | ||||
| 					} | ||||
| 				} | ||||
| 			default: | ||||
| 				// Update the lease in memory | ||||
| @@ -486,14 +492,21 @@ func (m *ExpirationManager) invalidate(key string) { | ||||
| 				// other maps, and update metrics/quotas if appropriate. | ||||
| 				m.nonexpiring.Delete(leaseID) | ||||
|  | ||||
| 				if _, ok := m.irrevocable.Load(leaseID); ok { | ||||
| 				if info, ok := m.irrevocable.Load(leaseID); ok { | ||||
| 					irrevocable := info.(pendingInfo) | ||||
| 					m.irrevocable.Delete(leaseID) | ||||
| 					m.irrevocableLeaseCount-- | ||||
|  | ||||
| 					m.leaseCount-- | ||||
| 					if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { | ||||
| 						m.logger.Error("failed to update quota on lease invalidation", "error", err) | ||||
| 						return | ||||
| 					// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to | ||||
| 					// accurately update quota lease information. | ||||
| 					// Note that cachedLeaseInfo should never be nil under normal operation. | ||||
| 					if irrevocable.cachedLeaseInfo != nil { | ||||
| 						leaseInfo := "as.QuotaLeaseInformation{LeaseId: leaseID, Role: irrevocable.cachedLeaseInfo.LoginRole} | ||||
| 						if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { | ||||
| 							m.logger.Error("failed to update quota on lease invalidation", "error", err) | ||||
| 							return | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 				return | ||||
| @@ -1389,7 +1402,7 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request | ||||
| // Register is used to take a request and response with an associated | ||||
| // lease. The secret gets assigned a LeaseID and the management of | ||||
| // of lease is assumed by the expiration manager. | ||||
| func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, resp *logical.Response) (id string, retErr error) { | ||||
| func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, resp *logical.Response, loginRole string) (id string, retErr error) { | ||||
| 	defer metrics.MeasureSince([]string{"expire", "register"}, time.Now()) | ||||
|  | ||||
| 	te := req.TokenEntry() | ||||
| @@ -1431,6 +1444,7 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, | ||||
| 		Path:            req.Path, | ||||
| 		Data:            resp.Data, | ||||
| 		Secret:          resp.Secret, | ||||
| 		LoginRole:       loginRole, | ||||
| 		IssueTime:       time.Now(), | ||||
| 		ExpireTime:      resp.Secret.ExpirationTime(), | ||||
| 		namespace:       ns, | ||||
| @@ -1524,7 +1538,7 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, | ||||
| // RegisterAuth is used to take an Auth response with an associated lease. | ||||
| // The token does not get a LeaseID, but the lease management is handled by | ||||
| // the expiration manager. | ||||
| func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenEntry, auth *logical.Auth) error { | ||||
| func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenEntry, auth *logical.Auth, loginRole string) error { | ||||
| 	defer metrics.MeasureSince([]string{"expire", "register-auth"}, time.Now()) | ||||
|  | ||||
| 	// Triggers failure of RegisterAuth. This should only be set and triggered | ||||
| @@ -1576,6 +1590,7 @@ func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenE | ||||
| 		ClientToken: auth.ClientToken, | ||||
| 		Auth:        auth, | ||||
| 		Path:        te.Path, | ||||
| 		LoginRole:   loginRole, | ||||
| 		IssueTime:   time.Now(), | ||||
| 		ExpireTime:  authExpirationTime, | ||||
| 		namespace:   tokenNS, | ||||
| @@ -1721,6 +1736,7 @@ func (m *ExpirationManager) inMemoryLeaseInfo(le *leaseEntry) *leaseEntry { | ||||
| 	if le.isIrrevocable() { | ||||
| 		ret.RevokeErr = le.RevokeErr | ||||
| 	} | ||||
| 	ret.LoginRole = le.LoginRole | ||||
| 	return ret | ||||
| } | ||||
|  | ||||
| @@ -1795,9 +1811,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry) { | ||||
| 			info.(pendingInfo).timer.Stop() | ||||
| 			m.pending.Delete(le.LeaseID) | ||||
| 			m.leaseCount-- | ||||
| 			if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []string{le.LeaseID}); err != nil { | ||||
| 				m.logger.Error("failed to update quota on lease deletion", "error", err) | ||||
| 				return | ||||
| 			// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to | ||||
| 			// accurately update quota lease information. | ||||
| 			// Note that cachedLeaseInfo should never be nil under normal operation. | ||||
| 			if pending.cachedLeaseInfo != nil { | ||||
| 				leaseInfo := "as.QuotaLeaseInformation{LeaseId: le.LeaseID, Role: le.LoginRole} | ||||
| 				if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { | ||||
| 					m.logger.Error("failed to update quota on lease deletion", "error", err) | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 		return | ||||
| @@ -1849,9 +1871,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry) { | ||||
|  | ||||
| 	if leaseCreated { | ||||
| 		m.leaseCount++ | ||||
| 		if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []string{le.LeaseID}); err != nil { | ||||
| 			m.logger.Error("failed to update quota on lease creation", "error", err) | ||||
| 			return | ||||
| 		// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to | ||||
| 		// accurately update quota lease information. | ||||
| 		// Note that cachedLeaseInfo should never be nil under normal operation. | ||||
| 		if pending.cachedLeaseInfo != nil { | ||||
| 			leaseInfo := "as.QuotaLeaseInformation{LeaseId: le.LeaseID, Role: le.LoginRole} | ||||
| 			if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { | ||||
| 				m.logger.Error("failed to update quota on lease creation", "error", err) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| @@ -2450,9 +2478,15 @@ func (m *ExpirationManager) removeFromPending(ctx context.Context, leaseID strin | ||||
| 		m.pending.Delete(leaseID) | ||||
| 		if decrementCounters { | ||||
| 			m.leaseCount-- | ||||
| 			// Log but do not fail; unit tests (and maybe Tidy on production systems) | ||||
| 			if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { | ||||
| 				m.logger.Error("failed to update quota on revocation", "error", err) | ||||
| 			// Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to | ||||
| 			// accurately update quota lease information. | ||||
| 			// Note that cachedLeaseInfo should never be nil under normal operation. | ||||
| 			if pending.cachedLeaseInfo != nil { | ||||
| 				leaseInfo := "as.QuotaLeaseInformation{LeaseId: leaseID, Role: pending.cachedLeaseInfo.LoginRole} | ||||
| 				// Log but do not fail; unit tests (and maybe Tidy on production systems) | ||||
| 				if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { | ||||
| 					m.logger.Error("failed to update quota on revocation", "error", err) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| @@ -2663,6 +2697,11 @@ type leaseEntry struct { | ||||
| 	ExpireTime      time.Time              `json:"expire_time"` | ||||
| 	LastRenewalTime time.Time              `json:"last_renewal_time"` | ||||
|  | ||||
| 	// LoginRole is used to indicate which login role (if applicable) this lease | ||||
| 	// was created with. This is required to decrement lease count quotas | ||||
| 	// based on login roles upon lease expiry. | ||||
| 	LoginRole string `json:"login_role"` | ||||
|  | ||||
| 	// Version is used to track new different versions of leases. V0 (or | ||||
| 	// zero-value) had non-root namespaced secondary indexes live in the root | ||||
| 	// namespace, and V1 has secondary indexes live in the matching namespace. | ||||
|   | ||||
| @@ -324,6 +324,103 @@ func TestExpiration_TotalLeaseCount(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExpiration_TotalLeaseCount_WithRoles(t *testing.T) { | ||||
| 	// Quotas and internal lease count tracker are coupled, so this is a proxy | ||||
| 	// for testing the total lease count quota | ||||
| 	c, _, _ := TestCoreUnsealed(t) | ||||
| 	exp := c.expiration | ||||
|  | ||||
| 	expectedCount := 0 | ||||
| 	otherNS := &namespace.Namespace{ | ||||
| 		ID:   "nsid", | ||||
| 		Path: "foo/bar", | ||||
| 	} | ||||
| 	for i := 0; i < 50; i++ { | ||||
| 		le := &leaseEntry{ | ||||
| 			LeaseID:    "lease" + fmt.Sprintf("%d", i), | ||||
| 			Path:       "foo/bar/" + fmt.Sprintf("%d", i), | ||||
| 			LoginRole:  "loginRole" + fmt.Sprintf("%d", i), | ||||
| 			namespace:  namespace.RootNamespace, | ||||
| 			IssueTime:  time.Now(), | ||||
| 			ExpireTime: time.Now().Add(time.Hour), | ||||
| 		} | ||||
|  | ||||
| 		otherNSle := &leaseEntry{ | ||||
| 			LeaseID:    "lease" + fmt.Sprintf("%d", i) + "/blah.nsid", | ||||
| 			Path:       "foo/bar/" + fmt.Sprintf("%d", i) + "/blah.nsid", | ||||
| 			LoginRole:  "loginRole" + fmt.Sprintf("%d", i), | ||||
| 			namespace:  otherNS, | ||||
| 			IssueTime:  time.Now(), | ||||
| 			ExpireTime: time.Now().Add(time.Hour), | ||||
| 		} | ||||
|  | ||||
| 		exp.pendingLock.Lock() | ||||
| 		if err := exp.persistEntry(namespace.RootContext(nil), le); err != nil { | ||||
| 			exp.pendingLock.Unlock() | ||||
| 			t.Fatalf("error persisting irrevocable entry: %v", err) | ||||
| 		} | ||||
| 		exp.updatePendingInternal(le) | ||||
| 		expectedCount++ | ||||
|  | ||||
| 		if err := exp.persistEntry(namespace.RootContext(nil), otherNSle); err != nil { | ||||
| 			exp.pendingLock.Unlock() | ||||
| 			t.Fatalf("error persisting irrevocable entry: %v", err) | ||||
| 		} | ||||
| 		exp.updatePendingInternal(otherNSle) | ||||
| 		expectedCount++ | ||||
| 		exp.pendingLock.Unlock() | ||||
| 	} | ||||
|  | ||||
| 	// add some irrevocable leases to each count to ensure they are counted too | ||||
| 	// note: irrevocable leases almost certainly have an expire time set in the | ||||
| 	// past, but for this exercise it should be fine to set it to whatever | ||||
| 	for i := 50; i < 60; i++ { | ||||
| 		le := &leaseEntry{ | ||||
| 			LeaseID:    "lease" + fmt.Sprintf("%d", i+1), | ||||
| 			Path:       "foo/bar/" + fmt.Sprintf("%d", i+1), | ||||
| 			LoginRole:  "loginRole" + fmt.Sprintf("%d", i), | ||||
| 			namespace:  namespace.RootNamespace, | ||||
| 			IssueTime:  time.Now(), | ||||
| 			ExpireTime: time.Now(), | ||||
| 			RevokeErr:  "some err message", | ||||
| 		} | ||||
|  | ||||
| 		otherNSle := &leaseEntry{ | ||||
| 			LeaseID:    "lease" + fmt.Sprintf("%d", i+1) + "/blah.nsid", | ||||
| 			Path:       "foo/bar/" + fmt.Sprintf("%d", i+1) + "/blah.nsid", | ||||
| 			LoginRole:  "loginRole" + fmt.Sprintf("%d", i), | ||||
| 			namespace:  otherNS, | ||||
| 			IssueTime:  time.Now(), | ||||
| 			ExpireTime: time.Now(), | ||||
| 			RevokeErr:  "some err message", | ||||
| 		} | ||||
|  | ||||
| 		exp.pendingLock.Lock() | ||||
| 		if err := exp.persistEntry(namespace.RootContext(nil), le); err != nil { | ||||
| 			exp.pendingLock.Unlock() | ||||
| 			t.Fatalf("error persisting irrevocable entry: %v", err) | ||||
| 		} | ||||
| 		exp.updatePendingInternal(le) | ||||
| 		expectedCount++ | ||||
|  | ||||
| 		if err := exp.persistEntry(namespace.RootContext(nil), otherNSle); err != nil { | ||||
| 			exp.pendingLock.Unlock() | ||||
| 			t.Fatalf("error persisting irrevocable entry: %v", err) | ||||
| 		} | ||||
| 		exp.updatePendingInternal(otherNSle) | ||||
| 		expectedCount++ | ||||
| 		exp.pendingLock.Unlock() | ||||
| 	} | ||||
|  | ||||
| 	exp.pendingLock.RLock() | ||||
| 	count := exp.leaseCount | ||||
| 	exp.pendingLock.RUnlock() | ||||
|  | ||||
| 	if count != expectedCount { | ||||
| 		t.Errorf("bad lease count. expected %d, got %d", expectedCount, count) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExpiration_Tidy(t *testing.T) { | ||||
| 	var err error | ||||
|  | ||||
| @@ -477,7 +574,7 @@ func TestExpiration_Tidy(t *testing.T) { | ||||
| 				"test_key": "test_value", | ||||
| 			}, | ||||
| 		} | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("err: %v", err) | ||||
| 		} | ||||
| @@ -636,7 +733,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend, | ||||
| 				"secret_key": "abcd", | ||||
| 			}, | ||||
| 		} | ||||
| 		_, err = exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 		_, err = exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 		if err != nil { | ||||
| 			b.Fatalf("err: %v", err) | ||||
| 		} | ||||
| @@ -698,7 +795,7 @@ func BenchmarkExpiration_Create_Leases(b *testing.B) { | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		req.Path = fmt.Sprintf("prod/aws/%d", i) | ||||
| 		_, err = exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 		_, err = exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 		if err != nil { | ||||
| 			b.Fatalf("err: %v", err) | ||||
| 		} | ||||
| @@ -743,7 +840,7 @@ func TestExpiration_Restore(t *testing.T) { | ||||
| 				"secret_key": "abcd", | ||||
| 			}, | ||||
| 		} | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("err: %v", err) | ||||
| 		} | ||||
| @@ -815,7 +912,7 @@ func TestExpiration_Register(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -829,6 +926,49 @@ func TestExpiration_Register(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExpiration_Register_Role(t *testing.T) { | ||||
| 	exp := mockExpiration(t) | ||||
| 	role := "role1" | ||||
| 	req := &logical.Request{ | ||||
| 		Operation:   logical.ReadOperation, | ||||
| 		Path:        "prod/aws/foo", | ||||
| 		ClientToken: "foobar", | ||||
| 	} | ||||
| 	req.SetTokenEntry(&logical.TokenEntry{ID: "foobar", NamespaceID: "root"}) | ||||
| 	resp := &logical.Response{ | ||||
| 		Secret: &logical.Secret{ | ||||
| 			LeaseOptions: logical.LeaseOptions{ | ||||
| 				TTL: time.Hour, | ||||
| 			}, | ||||
| 		}, | ||||
| 		Data: map[string]interface{}{ | ||||
| 			"access_key": "xyz", | ||||
| 			"secret_key": "abcd", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp, role) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if !strings.HasPrefix(id, req.Path) { | ||||
| 		t.Fatalf("bad: %s", id) | ||||
| 	} | ||||
|  | ||||
| 	if len(id) <= len(req.Path) { | ||||
| 		t.Fatalf("bad: %s", id) | ||||
| 	} | ||||
|  | ||||
| 	le, err := exp.loadEntry(exp.quitContext, id) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| 	if le.LoginRole != role { | ||||
| 		t.Fatalf("Login role incorrect. Expected %s, received %s", role, le.LoginRole) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExpiration_Register_BatchToken(t *testing.T) { | ||||
| 	c, _, rootToken := TestCoreUnsealed(t) | ||||
| 	exp := c.expiration | ||||
| @@ -883,7 +1023,7 @@ func TestExpiration_Register_BatchToken(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	leaseID, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 	leaseID, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -952,7 +1092,7 @@ func TestExpiration_RegisterAuth(t *testing.T) { | ||||
| 		Path:        "auth/github/login", | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -961,7 +1101,41 @@ func TestExpiration_RegisterAuth(t *testing.T) { | ||||
| 		Path:        "auth/github/../login", | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err == nil { | ||||
| 		t.Fatal("expected error") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExpiration_RegisterAuth_Role(t *testing.T) { | ||||
| 	exp := mockExpiration(t) | ||||
| 	role := "role1" | ||||
| 	root, err := exp.tokenStore.rootToken(context.Background()) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	auth := &logical.Auth{ | ||||
| 		ClientToken: root.ID, | ||||
| 		LeaseOptions: logical.LeaseOptions{ | ||||
| 			TTL: time.Hour, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	te := &logical.TokenEntry{ | ||||
| 		Path:        "auth/github/login", | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, role) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	te = &logical.TokenEntry{ | ||||
| 		Path:        "auth/github/../login", | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, role) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("expected error") | ||||
| 	} | ||||
| @@ -985,7 +1159,7 @@ func TestExpiration_RegisterAuth_NoLease(t *testing.T) { | ||||
| 		Policies:    []string{"root"}, | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1034,13 +1208,13 @@ func TestExpiration_RegisterAuth_NoTTL(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	// First on core | ||||
| 	err = c.RegisterAuth(ctx, 0, "auth/github/login", auth) | ||||
| 	err = c.RegisterAuth(ctx, 0, "auth/github/login", auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	auth.TokenPolicies[0] = "default" | ||||
| 	err = c.RegisterAuth(ctx, 0, "auth/github/login", auth) | ||||
| 	err = c.RegisterAuth(ctx, 0, "auth/github/login", auth, "") | ||||
| 	if err == nil { | ||||
| 		t.Fatal("expected error") | ||||
| 	} | ||||
| @@ -1053,14 +1227,14 @@ func TestExpiration_RegisterAuth_NoTTL(t *testing.T) { | ||||
| 		Policies:    []string{"root"}, | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(ctx, te, auth) | ||||
| 	err = exp.RegisterAuth(ctx, te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Test non-root token with zero TTL | ||||
| 	te.Policies = []string{"default"} | ||||
| 	err = exp.RegisterAuth(ctx, te, auth) | ||||
| 	err = exp.RegisterAuth(ctx, te, auth, "") | ||||
| 	if err == nil { | ||||
| 		t.Fatal("expected error") | ||||
| 	} | ||||
| @@ -1098,7 +1272,7 @@ func TestExpiration_Revoke(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1145,7 +1319,7 @@ func TestExpiration_RevokeOnExpire(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	_, err = exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 	_, err = exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1208,7 +1382,7 @@ func TestExpiration_RevokePrefix(t *testing.T) { | ||||
| 				"secret_key": "abcd", | ||||
| 			}, | ||||
| 		} | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("err: %v", err) | ||||
| 		} | ||||
| @@ -1277,7 +1451,7 @@ func TestExpiration_RevokeByToken(t *testing.T) { | ||||
| 				"secret_key": "abcd", | ||||
| 			}, | ||||
| 		} | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("err: %v", err) | ||||
| 		} | ||||
| @@ -1376,7 +1550,7 @@ func TestExpiration_RevokeByToken_Blocking(t *testing.T) { | ||||
| 				"secret_key": "abcd", | ||||
| 			}, | ||||
| 		} | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 		_, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("err: %v", err) | ||||
| 		} | ||||
| @@ -1448,7 +1622,7 @@ func TestExpiration_RenewToken(t *testing.T) { | ||||
| 		Path:        "auth/token/login", | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1497,7 +1671,7 @@ func TestExpiration_RenewToken_period(t *testing.T) { | ||||
| 		Path:        "auth/token/login", | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err := exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err := exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1578,7 +1752,7 @@ func TestExpiration_RenewToken_period_backend(t *testing.T) { | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
|  | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1635,7 +1809,7 @@ func TestExpiration_RenewToken_NotRenewable(t *testing.T) { | ||||
| 		Path:        "auth/foo/login", | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1688,7 +1862,7 @@ func TestExpiration_Renew(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1759,7 +1933,7 @@ func TestExpiration_Renew_NotRenewable(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1810,7 +1984,7 @@ func TestExpiration_Renew_RevokeOnExpire(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 	id, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1887,7 +2061,7 @@ func TestExpiration_Renew_FinalSecond(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	ctx := namespace.RootContext(nil) | ||||
| 	id, err := exp.Register(ctx, req, resp) | ||||
| 	id, err := exp.Register(ctx, req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1962,7 +2136,7 @@ func TestExpiration_Renew_FinalSecond_Lease(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	ctx := namespace.RootContext(nil) | ||||
| 	id, err := exp.Register(ctx, req, resp) | ||||
| 	id, err := exp.Register(ctx, req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -2647,7 +2821,7 @@ func sampleToken(t *testing.T, exp *ExpirationManager, path string, expiring boo | ||||
| 		Policies:    auth.Policies, | ||||
| 	} | ||||
|  | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -2822,7 +2996,7 @@ func registerOneLease(t *testing.T, ctx context.Context, exp *ExpirationManager) | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	leaseID, err := exp.Register(ctx, req, resp) | ||||
| 	leaseID, err := exp.Register(ctx, req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|   | ||||
| @@ -211,7 +211,7 @@ func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc { | ||||
| 			} | ||||
| 			authBackend := b.Core.router.MatchingBackend(namespace.ContextWithNamespace(ctx, ns), mountPath) | ||||
| 			if authBackend == nil || authBackend.Type() != logical.TypeCredential { | ||||
| 				return logical.ErrorResponse("Mount path '%s' is not a valid auth method and therefore unsuitable for use with role-based quotas", mountPath), nil | ||||
| 				return logical.ErrorResponse("Mount path %q is not a valid auth method and therefore unsuitable for use with role-based quotas", mountPath), nil | ||||
| 			} | ||||
| 			// We will always error as we aren't supplying real data, but we're looking for "unsupported operation" in particular | ||||
| 			_, err := authBackend.HandleRequest(ctx, &logical.Request{ | ||||
| @@ -219,7 +219,7 @@ func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc { | ||||
| 				Operation: logical.ResolveRoleOperation, | ||||
| 			}) | ||||
| 			if err != nil && (err == logical.ErrUnsupportedOperation || err == logical.ErrUnsupportedPath) { | ||||
| 				return logical.ErrorResponse("Mount path '%s' does not support use with role-based quotas", mountPath), nil | ||||
| 				return logical.ErrorResponse("Mount path %q does not support use with role-based quotas", mountPath), nil | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
|   | ||||
| @@ -1708,7 +1708,7 @@ func TestSystemBackend_revokePrefixAuth_newUrl(t *testing.T) { | ||||
| 			TTL: time.Hour, | ||||
| 		}, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1772,7 +1772,7 @@ func TestSystemBackend_revokePrefixAuth_origUrl(t *testing.T) { | ||||
| 			TTL: time.Hour, | ||||
| 		}, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -3617,7 +3617,7 @@ func TestSystemBackend_PathWildcardPreflight(t *testing.T) { | ||||
| 		ClientToken: te.ID, | ||||
| 		Accessor:    te.Accessor, | ||||
| 		Orphan:      true, | ||||
| 	}); err != nil { | ||||
| 	}, ""); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -716,7 +716,7 @@ func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logic | ||||
| 	} | ||||
|  | ||||
| 	// MFA validation has passed. Let's generate the token | ||||
| 	resp, err := b.Core.LoginMFACreateToken(ctx, cachedResponseAuth.RequestPath, cachedResponseAuth.CachedAuth) | ||||
| 	resp, err := b.Core.LoginMFACreateToken(ctx, cachedResponseAuth.RequestPath, cachedResponseAuth.CachedAuth, req.Data) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create a token. error: %v", err) | ||||
| 	} | ||||
| @@ -742,7 +742,7 @@ func (c *Core) teardownLoginMFA() error { | ||||
|  | ||||
| // LoginMFACreateToken creates a token after the login MFA is validated. | ||||
| // It also applies the lease quotas on the original login request path. | ||||
| func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth) (*logical.Response, error) { | ||||
| func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth, loginRequestData map[string]interface{}) (*logical.Response, error) { | ||||
| 	auth := cachedAuth | ||||
| 	resp := &logical.Response{ | ||||
| 		Auth: auth, | ||||
| @@ -761,6 +761,7 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu | ||||
| 	quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ | ||||
| 		Path:          reqPath, | ||||
| 		MountPath:     strings.TrimPrefix(mountPoint, ns.Path), | ||||
| 		Role:          c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx), | ||||
| 		NamespacePath: ns.Path, | ||||
| 	}) | ||||
|  | ||||
| @@ -780,7 +781,7 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu | ||||
| 	// note that we don't need to handle the error for the following function right away. | ||||
| 	// The function takes the response as in input variable and modify it. So, the returned | ||||
| 	// arguments are resp and err. | ||||
| 	leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, resp) | ||||
| 	leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, resp, loginRequestData) | ||||
|  | ||||
| 	if quotaResp.Access != nil { | ||||
| 		quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated) | ||||
|   | ||||
| @@ -168,6 +168,17 @@ type Manager struct { | ||||
| 	lock       *sync.RWMutex | ||||
| } | ||||
|  | ||||
| // QuotaLeaseInformation contains all of the information lease-count quotas require | ||||
| // from a lease to uniquely identify the lease-count quota to increment/decrement | ||||
| type QuotaLeaseInformation struct { | ||||
| 	// We can determine path and namespace from leaseId | ||||
| 	LeaseId string | ||||
|  | ||||
| 	// We need the role as it's not part of the leaseId, and is required | ||||
| 	// to uniquely identify a lease count quota | ||||
| 	Role string | ||||
| } | ||||
|  | ||||
| // Quota represents the common properties of every quota type | ||||
| type Quota interface { | ||||
| 	// allow checks the if the request is allowed by the quota type implementation. | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: vault/request_forwarding_service.proto | ||||
|  | ||||
|   | ||||
| @@ -969,9 +969,11 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp | ||||
| 	} | ||||
|  | ||||
| 	leaseGenerated := false | ||||
| 	loginRole := c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx) | ||||
| 	quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ | ||||
| 		Path:          req.Path, | ||||
| 		MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path), | ||||
| 		Role:          loginRole, | ||||
| 		NamespacePath: ns.Path, | ||||
| 	}) | ||||
| 	if quotaErr != nil { | ||||
| @@ -1111,7 +1113,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp | ||||
| 				return nil, auth, retErr | ||||
| 			} | ||||
|  | ||||
| 			leaseID, err := registerFunc(ctx, req, resp) | ||||
| 			leaseID, err := registerFunc(ctx, req, resp, loginRole) | ||||
| 			if err != nil { | ||||
| 				c.logger.Error("failed to register lease", "request_path", req.Path, "error", err) | ||||
| 				retErr = multierror.Append(retErr, ErrInternalError) | ||||
| @@ -1191,7 +1193,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp | ||||
| 					Path:        resp.Auth.CreationPath, | ||||
| 					NamespaceID: ns.ID, | ||||
| 				} | ||||
| 				if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth); err != nil { | ||||
| 				if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth, c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)); err != nil { | ||||
| 					// Best-effort clean up on error, so we log the cleanup error as | ||||
| 					// a warning but still return as internal error. | ||||
| 					if err := c.tokenStore.revokeOrphan(ctx, resp.Auth.ClientToken); err != nil { | ||||
| @@ -1390,6 +1392,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re | ||||
| 		quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ | ||||
| 			Path:          req.Path, | ||||
| 			MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path), | ||||
| 			Role:          c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx), | ||||
| 			NamespacePath: ns.Path, | ||||
| 		}) | ||||
|  | ||||
| @@ -1576,7 +1579,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re | ||||
| 		// Attach the display name, might be used by audit backends | ||||
| 		req.DisplayName = auth.DisplayName | ||||
|  | ||||
| 		leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, resp) | ||||
| 		leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, resp, req.Data) | ||||
| 		leaseGenerated = leaseGen | ||||
| 		if errCreateToken != nil { | ||||
| 			return respTokenCreate, nil, errCreateToken | ||||
| @@ -1607,7 +1610,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re | ||||
| // LoginCreateToken creates a token as a result of a login request. | ||||
| // If MFA is enforced, mfa/validate endpoint calls this functions | ||||
| // after successful MFA validation to generate the token. | ||||
| func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint string, resp *logical.Response) (bool, *logical.Response, error) { | ||||
| func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint string, resp *logical.Response, loginRequestData map[string]interface{}) (bool, *logical.Response, error) { | ||||
| 	auth := resp.Auth | ||||
|  | ||||
| 	source := strings.TrimPrefix(mountPoint, credentialRoutePrefix) | ||||
| @@ -1669,7 +1672,7 @@ func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, re | ||||
| 	} | ||||
|  | ||||
| 	leaseGenerated := false | ||||
| 	err = registerFunc(ctx, tokenTTL, reqPath, auth) | ||||
| 	err = registerFunc(ctx, tokenTTL, reqPath, auth, c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx)) | ||||
| 	switch { | ||||
| 	case err == nil: | ||||
| 		if auth.TokenType != logical.TokenTypeBatch { | ||||
| @@ -1736,7 +1739,9 @@ func blockRequestIfErrorImpl(_ *Core, _, _ string) error { return nil } | ||||
|  | ||||
| // RegisterAuth uses a logical.Auth object to create a token entry in the token | ||||
| // store, and registers a corresponding token lease to the expiration manager. | ||||
| func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path string, auth *logical.Auth) error { | ||||
| // role is the login role used as part of the creation of the token entry. If not | ||||
| // relevant, can be omitted (by being provided as ""). | ||||
| func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path string, auth *logical.Auth, role string) error { | ||||
| 	// We first assign token policies to what was returned from the backend | ||||
| 	// via auth.Policies. Then, we get the full set of policies into | ||||
| 	// auth.Policies from the backend + entity information -- this is not | ||||
| @@ -1786,7 +1791,7 @@ func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path st | ||||
| 		auth.Renewable = false | ||||
| 	case logical.TokenTypeService: | ||||
| 		// Register with the expiration manager | ||||
| 		if err := c.expiration.RegisterAuth(ctx, &te, auth); err != nil { | ||||
| 		if err := c.expiration.RegisterAuth(ctx, &te, auth, role); err != nil { | ||||
| 			if err := c.tokenStore.revokeOrphan(ctx, te.ID); err != nil { | ||||
| 				c.logger.Warn("failed to clean up token lease during login request", "request_path", path, "error", err) | ||||
| 			} | ||||
|   | ||||
| @@ -42,7 +42,7 @@ func forward(ctx context.Context, c *Core, req *logical.Request) (*logical.Respo | ||||
| 	panic("forward called in OSS Vault") | ||||
| } | ||||
|  | ||||
| func getLeaseRegisterFunc(c *Core) (func(context.Context, *logical.Request, *logical.Response) (string, error), error) { | ||||
| func getLeaseRegisterFunc(c *Core) (func(context.Context, *logical.Request, *logical.Response, string) (string, error), error) { | ||||
| 	return c.expiration.Register, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -330,7 +330,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { | ||||
| 		NamespaceID:    namespace.RootNamespaceID, | ||||
| 	} | ||||
|  | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), registryEntry, auth); err != nil { | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), registryEntry, auth, ""); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| @@ -375,7 +375,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { | ||||
| 		}, | ||||
| 		ClientToken: ent.ID, | ||||
| 	} | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil { | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| @@ -420,7 +420,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { | ||||
| 		}, | ||||
| 		ClientToken: ent.ID, | ||||
| 	} | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil { | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| @@ -462,7 +462,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { | ||||
| 		}, | ||||
| 		ClientToken: ent.ID, | ||||
| 	} | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil { | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| @@ -496,7 +496,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { | ||||
| 		}, | ||||
| 		ClientToken: ent.ID, | ||||
| 	} | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil { | ||||
| 	if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| @@ -572,7 +572,7 @@ func testMakeTokenViaRequestContext(t testing.TB, ctx context.Context, ts *Token | ||||
| 	} | ||||
|  | ||||
| 	if resp.Auth.TokenType != logical.TokenTypeBatch { | ||||
| 		if err := ts.expiration.RegisterAuth(ctx, te, resp.Auth); err != nil { | ||||
| 		if err := ts.expiration.RegisterAuth(ctx, te, resp.Auth, ""); err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
| @@ -618,7 +618,7 @@ func testMakeTokenDirectly(t testing.TB, ts *TokenStore, te *logical.TokenEntry) | ||||
| 		CreationPath:   te.Path, | ||||
| 		TokenType:      te.Type, | ||||
| 	} | ||||
| 	err := ts.expiration.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err := ts.expiration.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	switch err { | ||||
| 	case nil: | ||||
| 		if te.Type == logical.TokenTypeBatch { | ||||
| @@ -861,7 +861,7 @@ func TestTokenStore_HandleRequest_Renew_Revoke_Accessor(t *testing.T) { | ||||
| 		t.Fatal("token entry was nil") | ||||
| 	} | ||||
|  | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -1322,7 +1322,7 @@ func TestTokenStore_Revoke_Leases(t *testing.T) { | ||||
| 			"secret_key": "abcd", | ||||
| 		}, | ||||
| 	} | ||||
| 	leaseID, err := ts.expiration.Register(namespace.RootContext(nil), req, resp) | ||||
| 	leaseID, err := ts.expiration.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -2208,7 +2208,7 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) { | ||||
| 			Renewable: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -2230,7 +2230,7 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) { | ||||
| 			Renewable: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -2623,7 +2623,7 @@ func TestTokenStore_HandleRequest_Renew(t *testing.T) { | ||||
| 			Renewable: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), root, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), root, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -3113,7 +3113,7 @@ func TestTokenStore_HandleRequest_RenewSelf(t *testing.T) { | ||||
| 			Renewable: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), root, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), root, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -5787,7 +5787,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) { | ||||
| 		NamespaceID: namespace.RootNamespaceID, | ||||
| 	} | ||||
|  | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) | ||||
| 	err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("err: %v", err) | ||||
| 	} | ||||
| @@ -5820,7 +5820,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) { | ||||
| 	leases := []string{} | ||||
|  | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		leaseID, err := exp.Register(namespace.RootContext(nil), req, resp) | ||||
| 		leaseID, err := exp.Register(namespace.RootContext(nil), req, resp, "") | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| // Code generated by protoc-gen-go. DO NOT EDIT. | ||||
| // versions: | ||||
| // 	protoc-gen-go v1.27.1 | ||||
| // 	protoc-gen-go v1.28.0 | ||||
| // 	protoc        v3.19.4 | ||||
| // source: vault/tokens/token.proto | ||||
|  | ||||
|   | ||||
| @@ -325,7 +325,7 @@ DONELISTHANDLING: | ||||
| 	} | ||||
|  | ||||
| 	// Register the wrapped token with the expiration manager | ||||
| 	if err := c.expiration.RegisterAuth(ctx, &te, wAuth); err != nil { | ||||
| 	if err := c.expiration.RegisterAuth(ctx, &te, wAuth, c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)); err != nil { | ||||
| 		// Revoke since it's not yet being tracked for expiration | ||||
| 		c.tokenStore.revokeOrphan(ctx, te.ID) | ||||
| 		c.logger.Error("failed to register cubbyhole wrapping token lease", "request_path", req.Path, "error", err) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Violet Hynes
					Violet Hynes