From c68e270863cd5518a3380b73c721fe822592b6b8 Mon Sep 17 00:00:00 2001
From: Vishal Nayak
Date: Fri, 26 Jun 2020 17:13:16 -0400
Subject: [PATCH] Resource Quotas: Rate Limiting (#9330)

---
 api/response.go                            |   4 +-
 go.mod                                     |   1 +
 http/handler.go                            |  49 +-
 http/logical.go                            |  30 +-
 http/logical_test.go                       |  20 +-
 http/sys_seal.go                           |   4 +-
 http/util.go                               |  61 +-
 sdk/logical/error.go                       |   8 +
 sdk/logical/response_util.go               |   6 +-
 vault/auth.go                              |   5 +
 vault/core.go                              |  46 +-
 vault/core_util.go                         |  23 +-
 vault/expiration.go                        |  90 +-
 vault/external_tests/quotas/quotas_test.go | 384 ++++++++
 vault/logical_system.go                    |  25 +-
 vault/logical_system_quotas.go             | 272 ++++++
 vault/logical_system_test.go               |   2 +-
 vault/mount.go                             |   5 +
 vault/quotas/quotas.go                     | 860 ++++++++++++++++++
 vault/quotas/quotas_rate_limit.go          | 282 ++++++
 vault/quotas/quotas_rate_limit_test.go     | 173 ++++
 vault/quotas/quotas_test.go                |  67 ++
 vault/quotas/quotas_util.go                |  65 ++
 vault/request_handling.go                  |  81 +-
 vault/router.go                            |   8 +
 .../hashicorp/vault/api/response.go        |   4 +-
 .../hashicorp/vault/sdk/logical/error.go   |   8 +
 .../vault/sdk/logical/response_util.go     |   6 +-
 website/pages/docs/internals/telemetry.mdx |  11 +
 29 files changed, 2516 insertions(+), 84 deletions(-)
 create mode 100644 vault/external_tests/quotas/quotas_test.go
 create mode 100644 vault/logical_system_quotas.go
 create mode 100644 vault/quotas/quotas.go
 create mode 100644 vault/quotas/quotas_rate_limit.go
 create mode 100644 vault/quotas/quotas_rate_limit_test.go
 create mode 100644 vault/quotas/quotas_test.go
 create mode 100644 vault/quotas/quotas_util.go

diff --git a/api/response.go b/api/response.go
index aed2a52e08..ae350c9791 100644
--- a/api/response.go
+++ b/api/response.go
@@ -27,8 +27,8 @@ func (r *Response) DecodeJSON(out interface{}) error {
 // body must still be closed manually.
 func (r *Response) Error() error {
 	// 200 to 399 are okay status codes. 429 is the code for health status of
-	// standby nodes.
-	if (r.StatusCode >= 200 && r.StatusCode < 400) || r.StatusCode == 429 {
+	// standby nodes; otherwise, 429 is treated as a quota limit having been reached.
+	if (r.StatusCode >= 200 && r.StatusCode < 400) || (r.StatusCode == 429 && r.Request.URL.Path == "/v1/sys/health") {
 		return nil
 	}

diff --git a/go.mod b/go.mod
index 04f65c2b14..167aece7de 100644
--- a/go.mod
+++ b/go.mod
@@ -146,6 +146,7 @@ require (
 	golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9
 	golang.org/x/net v0.0.0-20200602114024-627f9648deb9
 	golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
+	golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1
 	golang.org/x/tools v0.0.0-20200416214402-fc959738d646
 	google.golang.org/api v0.24.0
 	google.golang.org/grpc v1.29.1

diff --git a/http/handler.go b/http/handler.go
index 31047958a9..40c63bd625 100644
--- a/http/handler.go
+++ b/http/handler.go
@@ -176,8 +176,8 @@ func Handler(props *vault.HandlerProperties) http.Handler {
 	// Wrap the handler in another handler to trigger all help paths.
 	helpWrappedHandler := wrapHelpHandler(mux, core)
 	corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core)
-
-	genericWrappedHandler := genericWrapping(core, corsWrappedHandler, props)
+	quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core)
+	genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props)

 	// Wrap the handler with PrintablePathCheckHandler to check for non-printable
 	// characters in the request path.
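The ordering of these wrappers matters: `rateLimitQuotaWrapping` sits inside `genericWrapping`, so by the time the quota check runs, the generic wrapper has already parsed the `*logical.Request` and tucked it into the response writer. A minimal, self-contained sketch of that writer-embedding pattern (the handler and field names below are stand-ins, not the patch's exact code):

```go
package main

import (
	"fmt"
	"net/http"
)

// requestCarrier plays the role of LogicalResponseWriter: it embeds the real
// http.ResponseWriter and carries per-request state down to inner handlers.
type requestCarrier struct {
	http.ResponseWriter
	path string // stand-in for the parsed *logical.Request
}

// withGeneric plays the role of wrapGenericHandler: the outermost layer parses
// the request once and hands the result down inside the writer.
func withGeneric(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(&requestCarrier{ResponseWriter: w, path: r.URL.Path}, r)
	})
}

// withQuota plays the role of rateLimitQuotaWrapping: it recovers the carried
// state via a type assertion and decides whether the request may proceed.
func withQuota(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if w.(*requestCarrier).path == "/v1/limited" { // stand-in for a quota decision
			http.Error(w, "quota exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

func main() {
	inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "ok")
	})
	// Same nesting order as the patch: generic(quota(inner)).
	http.ListenAndServe(":8200", withGeneric(withQuota(inner)))
}
```

Carrying the parsed request through the embedded writer is also why the body-teeing that `handleAuditNonLogical` used to do can move up into `wrapGenericHandler` in the next hunk: the request is built exactly once and every inner layer reads the same copy.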
@@ -221,26 +221,14 @@ func (w *copyResponseWriter) WriteHeader(code int) { func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origBody := new(bytes.Buffer) - reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody)) - r.Body = reader - req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) - if err != nil || status != 0 { - respondError(w, status, err) - return - } - if origBody != nil { - r.Body = ioutil.NopCloser(origBody) - } input := &logical.LogInput{ - Request: req, + Request: w.(*LogicalResponseWriter).request, } - core.AuditLogger().AuditRequest(r.Context(), input) cw := newCopyResponseWriter(w) h.ServeHTTP(cw, r) data := make(map[string]interface{}) - err = jsonutil.DecodeJSON(cw.body.Bytes(), &data) + err := jsonutil.DecodeJSON(cw.body.Bytes(), &data) if err != nil { // best effort, ignore } @@ -249,7 +237,13 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { core.AuditLogger().AuditResponse(r.Context(), input) return }) +} +// LogicalResponseWriter is used to carry the logical request from generic +// handler down to all the middleware http handlers. +type LogicalResponseWriter struct { + http.ResponseWriter + request *logical.Request } // wrapGenericHandler wraps the handler with an extra layer of handler where @@ -288,6 +282,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr } ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) r = r.WithContext(ctx) + r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) switch { case strings.HasPrefix(r.URL.Path, "/v1/"): @@ -306,7 +301,27 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr return } - h.ServeHTTP(w, r) + origBody := new(bytes.Buffer) + reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody)) + r.Body = reader + req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) + if err != nil || status != 0 { + respondError(w, status, err) + return + } + // Reset the body since logical request creation already read the + // request body. + r.Body = ioutil.NopCloser(origBody) + + // Set the mount path in the request + req.MountPoint = core.MatchingMount(r.Context(), req.Path) + + // Pass the logical request down through the response writer + h.ServeHTTP(&LogicalResponseWriter{ + ResponseWriter: w, + request: req, + }, r) + cancelFunc() return }) diff --git a/http/logical.go b/http/logical.go index 2a7e0c9a49..2a5c202436 100644 --- a/http/logical.go +++ b/http/logical.go @@ -141,6 +141,7 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. } case "OPTIONS": + case "HEAD": default: return nil, nil, http.StatusMethodNotAllowed, nil } @@ -169,36 +170,32 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. 
return req, origBody, 0, nil } -func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { - req, origBody, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) - if err != nil || status != 0 { - return nil, nil, status, err - } - +func setupLogicalRequest(core *vault.Core, req *logical.Request, r *http.Request) (*logical.Request, int, error) { + var err error req, err = requestAuth(core, r, req) if err != nil { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { - return nil, nil, http.StatusForbidden, nil + return nil, http.StatusForbidden, nil } - return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err) + return nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err) } req, err = requestWrapInfo(r, req) if err != nil { - return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err) + return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err) } err = parseMFAHeader(req) if err != nil { - return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err) + return nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err) } err = requestPolicyOverride(r, req) if err != nil { - return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err) + return nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err) } - return req, origBody, 0, nil + return req, 0, nil } // handleLogical returns a handler for processing logical requests. These requests @@ -257,7 +254,7 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han // toggles. Refer to usage on functions for possible behaviors. 
func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - req, origBody, statusCode, err := buildLogicalRequest(core, w, r) + req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r) if err != nil || statusCode != 0 { respondError(w, statusCode, err) return @@ -270,10 +267,6 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) return } - - if origBody != nil { - r.Body = origBody - } forwardRequest(core, w, r) return } @@ -398,9 +391,6 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) return case needsForward && !noForward: - if origBody != nil { - r.Body = origBody - } forwardRequest(core, w, r) return case !ok: diff --git a/http/logical_test.go b/http/logical_test.go index 550f5f782b..8765993c4d 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -281,7 +281,13 @@ func TestLogical_ListSuffix(t *testing.T) { req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil) req = req.WithContext(namespace.RootContext(nil)) req.Header.Add(consts.AuthHeaderName, rootToken) - lreq, _, status, err := buildLogicalRequest(core, nil, req) + + lreq, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), nil, req) + if err != nil || status != 0 { + t.Fatal(err) + } + + lreq, status, err = setupLogicalRequest(core, lreq, req) if err != nil { t.Fatal(err) } @@ -295,7 +301,11 @@ func TestLogical_ListSuffix(t *testing.T) { req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil) req = req.WithContext(namespace.RootContext(nil)) req.Header.Add(consts.AuthHeaderName, rootToken) - lreq, _, status, err = buildLogicalRequest(core, nil, req) + lreq, _, status, err = buildLogicalRequestNoAuth(core.PerfStandby(), nil, req) + if err != nil || status != 0 { + t.Fatal(err) + } + lreq, status, err = setupLogicalRequest(core, lreq, req) if err != nil { t.Fatal(err) } @@ -309,7 +319,11 @@ func TestLogical_ListSuffix(t *testing.T) { req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil) req = req.WithContext(namespace.RootContext(nil)) req.Header.Add(consts.AuthHeaderName, rootToken) - lreq, _, status, err = buildLogicalRequest(core, nil, req) + lreq, _, status, err = buildLogicalRequestNoAuth(core.PerfStandby(), nil, req) + if err != nil || status != 0 { + t.Fatal(err) + } + lreq, status, err = setupLogicalRequest(core, lreq, req) if err != nil { t.Fatal(err) } diff --git a/http/sys_seal.go b/http/sys_seal.go index a13573addd..17d35c9b0d 100644 --- a/http/sys_seal.go +++ b/http/sys_seal.go @@ -17,7 +17,7 @@ import ( func handleSysSeal(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - req, _, statusCode, err := buildLogicalRequest(core, w, r) + req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r) if err != nil || statusCode != 0 { respondError(w, statusCode, err) return @@ -47,7 +47,7 @@ func handleSysSeal(core *vault.Core) http.Handler { func handleSysStepDown(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - req, _, statusCode, err := buildLogicalRequest(core, w, r) + req, statusCode, err := setupLogicalRequest(core, 
w.(*LogicalResponseWriter).request, r)
 		if err != nil || statusCode != 0 {
 			respondError(w, statusCode, err)
 			return

diff --git a/http/util.go b/http/util.go
index c4bd282d84..484186b11e 100644
--- a/http/util.go
+++ b/http/util.go
@@ -1,15 +1,21 @@
 package http

 import (
+	"fmt"
+	"net"
 	"net/http"
+	"strings"

+	"github.com/hashicorp/errwrap"
 	"github.com/hashicorp/vault/helper/namespace"
+	"github.com/hashicorp/vault/sdk/logical"
 	"github.com/hashicorp/vault/vault"
+	"github.com/hashicorp/vault/vault/quotas"
 )

 var (
 	adjustRequest = func(c *vault.Core, r *http.Request) (*http.Request, int) {
-		return r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)), 0
+		return r, 0
 	}

 	genericWrapping = func(core *vault.Core, in http.Handler, props *vault.HandlerProperties) http.Handler {
@@ -22,3 +28,56 @@ var (

 	nonVotersAllowed = false
 )
+
+func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ns, err := namespace.FromContext(r.Context())
+		if err != nil {
+			respondError(w, http.StatusInternalServerError, err)
+			return
+		}
+		req := w.(*LogicalResponseWriter).request
+		quotaResp, err := core.ApplyRateLimitQuota(&quotas.Request{
+			Type:          quotas.TypeRateLimit,
+			Path:          req.Path,
+			MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path),
+			NamespacePath: ns.Path,
+			ClientAddress: parseRemoteIPAddress(r),
+		})
+		if err != nil {
+			core.Logger().Error("failed to apply quota", "path", req.Path, "error", err)
+			respondError(w, http.StatusUnprocessableEntity, err)
+			return
+		}
+
+		if !quotaResp.Allowed {
+			quotaErr := errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrRateLimitQuotaExceeded)
+			respondError(w, http.StatusTooManyRequests, quotaErr)
+
+			if core.Logger().IsTrace() {
+				core.Logger().Trace("request rejected due to rate limit quota violation", "request_path", req.Path)
+			}
+
+			if core.RateLimitAuditLoggingEnabled() {
+				_ = core.AuditLogger().AuditRequest(r.Context(), &logical.LogInput{
+					Request:  req,
+					OuterErr: quotaErr,
+				})
+			}
+
+			return
+		}
+
+		handler.ServeHTTP(w, r)
+		return
+	})
+}
+
+func parseRemoteIPAddress(r *http.Request) string {
+	ip, _, err := net.SplitHostPort(r.RemoteAddr)
+	if err != nil {
+		return ""
+	}
+
+	return ip
+}

diff --git a/sdk/logical/error.go b/sdk/logical/error.go
index fd896a6ce3..aab73cc066 100644
--- a/sdk/logical/error.go
+++ b/sdk/logical/error.go
@@ -28,6 +28,14 @@ var (
 	// ErrPerfStandbyForward is returned when Vault is in a state such that a
 	// perf standby cannot satisfy a request
 	ErrPerfStandbyPleaseForward = errors.New("please forward to the active node")
+
+	// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease
+	// count quota being exceeded.
+	ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded")
+
+	// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a
+	// rate limit quota being exceeded.
+ ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded") ) type HTTPCodedError interface { diff --git a/sdk/logical/response_util.go b/sdk/logical/response_util.go index ee57f8e05a..ce743507fb 100644 --- a/sdk/logical/response_util.go +++ b/sdk/logical/response_util.go @@ -81,7 +81,7 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { } }) if allErrors != nil { - return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors) + return codedErr.Code, multierror.Append(fmt.Errorf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg), allErrors) } return codedErr.Code, errors.New(codedErr.Msg) } @@ -110,6 +110,10 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { statusCode = http.StatusBadRequest case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): statusCode = http.StatusBadGateway + case errwrap.Contains(err, ErrRateLimitQuotaExceeded.Error()): + statusCode = http.StatusTooManyRequests + case errwrap.Contains(err, ErrLeaseCountQuotaExceeded.Error()): + statusCode = http.StatusTooManyRequests } } diff --git a/vault/auth.go b/vault/auth.go index 3d4e169d03..f97ccff7b7 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -339,6 +339,11 @@ func (c *Core) disableCredentialInternal(ctx context.Context, path string, updat removePathCheckers(c, entry, viewPath) + if err := c.quotaManager.HandleBackendDisabling(ctx, ns.Path, path); err != nil { + c.logger.Error("failed to update quotas after disabling auth", "path", path, "error", err) + return err + } + if c.logger.IsInfo() { c.logger.Info("disabled credential backend", "path", path) } diff --git a/vault/core.go b/vault/core.go index 375912d438..31fea7adf2 100644 --- a/vault/core.go +++ b/vault/core.go @@ -43,6 +43,7 @@ import ( sr "github.com/hashicorp/vault/serviceregistration" "github.com/hashicorp/vault/shamir" "github.com/hashicorp/vault/vault/cluster" + "github.com/hashicorp/vault/vault/quotas" vaultseal "github.com/hashicorp/vault/vault/seal" "github.com/patrickmn/go-cache" "google.golang.org/grpc" @@ -97,6 +98,7 @@ var ( enterprisePostUnseal = enterprisePostUnsealImpl enterprisePreSeal = enterprisePreSealImpl enterpriseSetupFilteredPaths = enterpriseSetupFilteredPathsImpl + enterpriseSetupQuotas = enterpriseSetupQuotasImpl startReplication = startReplicationImpl stopReplication = stopReplicationImpl LastWAL = lastWALImpl @@ -520,6 +522,8 @@ type Core struct { // can test an upgrade to a version that includes the fixes from // https://github.com/hashicorp/vault-enterprise/pull/1103 PR1103disabled bool + + quotaManager *quotas.Manager } // CoreConfig is used to parameterize a core @@ -944,7 +948,9 @@ func NewCore(conf *CoreConfig) (*Core, error) { c.clusterListener.Store((*cluster.Listener)(nil)) - err = c.adjustForSealMigration(conf.UnwrapSeal) + quotasLogger := conf.Logger.Named("quotas") + c.allLoggers = append(c.allLoggers, quotasLogger) + c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink) if err != nil { return nil, err } @@ -1822,7 +1828,10 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock, performCleanup } } - postSealInternal(c) + if err := postSealInternal(c); err != nil { + c.logger.Error("post seal error", "error", err) + return err + } c.logger.Info("vault is sealed") @@ -1892,6 +1901,9 @@ func (s standardUnsealStrategy) unseal(ctx 
context.Context, logger log.Logger, c
 	if err := c.setupCredentials(ctx); err != nil {
 		return err
 	}
+	if err := c.setupQuotas(ctx, false); err != nil {
+		return err
+	}
 	if !c.IsDRSecondary() {
 		if err := c.startRollback(); err != nil {
 			return err
 		}
@@ -2078,6 +2090,10 @@ func enterpriseSetupFilteredPathsImpl(c *Core) error {
 	return nil
 }

+func enterpriseSetupQuotasImpl(ctx context.Context, c *Core) error {
+	return nil
+}
+
 func startReplicationImpl(c *Core) error {
 	return nil
 }
@@ -2474,3 +2490,29 @@ func (c *Core) readFeatureFlags(ctx context.Context) (*FeatureFlags, error) {
 	}
 	return &flags, nil
 }
+
+// MatchingMount returns the path of the mount that will be responsible for
+// handling the given request path.
+func (c *Core) MatchingMount(ctx context.Context, reqPath string) string {
+	return c.router.MatchingMount(ctx, reqPath)
+}
+
+func (c *Core) setupQuotas(ctx context.Context, isPerfStandby bool) error {
+	if c.quotaManager == nil {
+		return nil
+	}
+
+	return c.quotaManager.Setup(ctx, c.systemBarrierView, isPerfStandby)
+}
+
+// ApplyRateLimitQuota checks the request against all the applicable quota rules
+func (c *Core) ApplyRateLimitQuota(req *quotas.Request) (quotas.Response, error) {
+	req.Type = quotas.TypeRateLimit
+	return c.quotaManager.ApplyQuota(req)
+}
+
+// RateLimitAuditLoggingEnabled returns whether the quota configuration allows audit
+// logging of request rejections due to rate limiting quota rule violations.
+func (c *Core) RateLimitAuditLoggingEnabled() bool {
+	return c.quotaManager.RateLimitAuditLoggingEnabled()
+}

diff --git a/vault/core_util.go b/vault/core_util.go
index 226f5554e3..3cfb0c615f 100644
--- a/vault/core_util.go
+++ b/vault/core_util.go
@@ -9,6 +9,7 @@ import (
 	"github.com/hashicorp/vault/sdk/helper/license"
 	"github.com/hashicorp/vault/sdk/logical"
 	"github.com/hashicorp/vault/sdk/physical"
+	"github.com/hashicorp/vault/vault/quotas"
 	"github.com/hashicorp/vault/vault/replication"
 )

@@ -58,7 +59,7 @@ func addExtraCredentialBackends(*Core, map[string]logical.Factory) {}

 func preUnsealInternal(context.Context, *Core) error { return nil }

-func postSealInternal(*Core) {}
+func postSealInternal(*Core) error { return nil }

 func preSealPhysical(c *Core) {
 	switch c.sealUnwrapper.(type) {
@@ -132,3 +133,23 @@ func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, chan struct{},
 func (c *Core) initSealsForMigration() {}

 func (c *Core) postSealMigration(ctx context.Context) error { return nil }
+
+func (c *Core) applyLeaseCountQuota(in *quotas.Request) (*quotas.Response, error) {
+	return &quotas.Response{Allowed: true}, nil
+}
+
+func (c *Core) ackLeaseQuota(access quotas.Access, leaseGenerated bool) error {
+	return nil
+}
+
+func (c *Core) quotaLeaseWalker(ctx context.Context, callback func(request *quotas.Request) bool) error {
+	return nil
+}
+
+func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leaseIDs []string) error {
+	return nil
+}
+
+func (c *Core) namespaceByPath(path string) *namespace.Namespace {
+	return namespace.RootNamespace
+}

diff --git a/vault/expiration.go b/vault/expiration.go
index 8d0067f45b..143700526c 100644
--- a/vault/expiration.go
+++ b/vault/expiration.go
@@ -12,17 +12,18 @@ import (
 	"sync/atomic"
 	"time"

-	metrics "github.com/armon/go-metrics"
-	"github.com/hashicorp/errwrap"
-	log "github.com/hashicorp/go-hclog"
-	multierror "github.com/hashicorp/go-multierror"
-	"github.com/hashicorp/vault/helper/namespace"
 	"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/base62" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/locksutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/quotas" + metrics "github.com/armon/go-metrics" + "github.com/hashicorp/errwrap" + log "github.com/hashicorp/go-hclog" + multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/vault/helper/namespace" uberAtomic "go.uber.org/atomic" ) @@ -256,23 +257,60 @@ func (m *ExpirationManager) inRestoreMode() bool { } func (m *ExpirationManager) invalidate(key string) { - switch { case strings.HasPrefix(key, leaseViewPrefix): - // Clear from the pending expiration leaseID := strings.TrimPrefix(key, leaseViewPrefix) - m.pendingLock.Lock() - if info, ok := m.pending.Load(leaseID); ok { - pending := info.(pendingInfo) - pending.timer.Stop() - m.pending.Delete(leaseID) - m.leaseCount-- + ctx := m.quitContext + _, nsID := namespace.SplitIDFromString(leaseID) + leaseNS := namespace.RootNamespace + var err error + if nsID != "" { + leaseNS, err = NamespaceByID(ctx, nsID, m.core) + if err != nil { + m.logger.Error("failed to invalidate lease entry", "error", err) + return + } } - // If in the nonexpiring map, remove there. - m.nonexpiring.Delete(leaseID) + le, err := m.loadEntryInternal(namespace.ContextWithNamespace(ctx, leaseNS), leaseID, false, false) + if err != nil { + m.logger.Error("failed to invalidate lease entry", "error", err) + return + } - m.pendingLock.Unlock() + m.pendingLock.Lock() + defer m.pendingLock.Unlock() + info, ok := m.pending.Load(leaseID) + switch { + case ok: + switch { + case le == nil: + // Handle lease deletion + pending := info.(pendingInfo) + pending.timer.Stop() + m.pending.Delete(leaseID) + m.leaseCount-- + + // If in the nonexpiring map, remove there. + m.nonexpiring.Delete(leaseID) + + if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { + m.logger.Error("failed to handle lease delete invalidation", "error", err) + return + } + default: + // Handle lease update + m.updatePendingInternal(le, le.ExpireTime.Sub(time.Now())) + } + default: + // There is no entry in the pending map and the invalidation + // resulted in a nil entry. This should ideally never happen. + if le == nil { + return + } + // Handle lease creation + m.updatePendingInternal(le, le.ExpireTime.Sub(time.Now())) + } } } @@ -692,13 +730,18 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, fo } } - // Clear the expiration handler (or remove from the list of non-expiring tokens.) 
+ // Clear the expiration handler m.pendingLock.Lock() if info, ok := m.pending.Load(leaseID); ok { pending := info.(pendingInfo) pending.timer.Stop() m.pending.Delete(leaseID) m.leaseCount-- + if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { + m.pendingLock.Unlock() + m.logger.Error("failed to handle lease path deletion", "error", err) + return err + } } m.nonexpiring.Delete(leaseID) m.pendingLock.Unlock() @@ -1420,10 +1463,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry, leaseTotal tim info.(pendingInfo).timer.Stop() m.pending.Delete(le.LeaseID) m.leaseCount-- + if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []string{le.LeaseID}); err != nil { + m.logger.Error("failed to handle lease path deletion", "error", err) + return + } } return } + leaseCreated := false // Create entry if it does not exist or reset if it does if ok { pending = info.(pendingInfo) @@ -1439,12 +1487,20 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry, leaseTotal tim } // new lease m.leaseCount++ + leaseCreated = true } // Retain some information in-memory pending.cachedLeaseInfo = m.inMemoryLeaseInfo(le) m.pending.Store(le.LeaseID, pending) + + if leaseCreated { + if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []string{le.LeaseID}); err != nil { + m.logger.Error("failed to handle lease creation", "error", err) + return + } + } } // revokeEntry is used to attempt revocation of an internal entry diff --git a/vault/external_tests/quotas/quotas_test.go b/vault/external_tests/quotas/quotas_test.go new file mode 100644 index 0000000000..4a151fb800 --- /dev/null +++ b/vault/external_tests/quotas/quotas_test.go @@ -0,0 +1,384 @@ +package quotas + +import ( + "fmt" + "testing" + "time" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/sdk/logical" + + "github.com/hashicorp/vault/builtin/credential/userpass" + "github.com/hashicorp/vault/builtin/logical/pki" + "github.com/hashicorp/vault/helper/testhelpers/teststorage" + "github.com/hashicorp/vault/vault" + "go.uber.org/atomic" +) + +const ( + testLookupOnlyPolicy = ` +path "/auth/token/lookup" { + capabilities = [ "create", "update"] +} +` +) + +var ( + coreConfig = &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "pki": pki.Factory, + }, + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + } +) + +func setupMounts(t *testing.T, client *api.Client) { + t.Helper() + + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{ + "password": "bar", + }) + if err != nil { + t.Fatal(err) + } + + err = client.Sys().Mount("pki", &api.MountInput{ + Type: "pki", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{ + "common_name": "testvault.com", + "ttl": "200h", + "ip_sans": "127.0.0.1", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("pki/roles/test", map[string]interface{}{ + "require_cn": false, + "allowed_domains": "testvault.com", + "allow_subdomains": true, + "max_ttl": "2h", + "generate_lease": true, + }) + if err != nil { + t.Fatal(err) + } + +} + +func teardownMounts(t *testing.T, client *api.Client) { + t.Helper() + if err := client.Sys().Unmount("pki"); err != nil { + 
t.Fatal(err)
+	}
+	if err := client.Sys().DisableAuth("userpass"); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func testRPS(reqFunc func(numSuccess, numFail *atomic.Int32), d time.Duration) (int32, int32, time.Duration) {
+	numSuccess := atomic.NewInt32(0)
+	numFail := atomic.NewInt32(0)
+
+	start := time.Now()
+	end := start.Add(d)
+	for time.Now().Before(end) {
+		reqFunc(numSuccess, numFail)
+	}
+
+	return numSuccess.Load(), numFail.Load(), time.Since(start)
+}
+
+func waitForRemovalOrTimeout(c *api.Client, path string, tick, to time.Duration) error {
+	ticker := time.Tick(tick)
+	timeout := time.After(to)
+
+	// wait for the resource to be removed
+	for {
+		select {
+		case <-timeout:
+			return fmt.Errorf("timeout exceeded waiting for resource to be deleted: %s", path)
+
+		case <-ticker:
+			resp, err := c.Logical().Read(path)
+			if err != nil {
+				return err
+			}
+
+			if resp == nil {
+				return nil
+			}
+		}
+	}
+}
+
+func TestQuotas_RateLimitQuota_Mount(t *testing.T) {
+	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
+	cluster := vault.NewTestCluster(t, conf, opts)
+	cluster.Start()
+	defer cluster.Cleanup()
+
+	core := cluster.Cores[0].Core
+	client := cluster.Cores[0].Client
+	vault.TestWaitActive(t, core)
+
+	err := client.Sys().Mount("pki", &api.MountInput{
+		Type: "pki",
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
+		"common_name": "testvault.com",
+		"ttl":         "200h",
+		"ip_sans":     "127.0.0.1",
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
+		"require_cn":       false,
+		"allowed_domains":  "testvault.com",
+		"allow_subdomains": true,
+		"max_ttl":          "2h",
+		"generate_lease":   true,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	reqFunc := func(numSuccess, numFail *atomic.Int32) {
+		_, err := client.Logical().Read("pki/cert/ca_chain")
+
+		if err != nil {
+			numFail.Add(1)
+		} else {
+			numSuccess.Add(1)
+		}
+	}
+
+	// Create a rate limit quota with a low RPS of 7.7, which means we can process
+	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
+	// by a refill rate of 7.7 per-second.
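+	// Worked example of that ceiling: over an elapsed window of ~5 seconds the
+	// bucket admits at most burst + rate*seconds = 8 + 7.7*5 = 46.5 requests,
+	// which is why the assertions below allow numSuccess up to int32(ideal + 1),
+	// i.e. 47.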
+ _, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{ + "rate": 7.7, + "burst": 8, + "path": "pki/", + }) + if err != nil { + t.Fatal(err) + } + + numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second) + + // evaluate the ideal RPS as (burst + (RPS * totalSeconds)) + ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second)) + + // ensure there were some failed requests + if numFail == 0 { + t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed) + } + + // ensure that we should never get more requests than allowed + if want := int32(ideal + 1); numSuccess > want { + t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed) + } + + // update the rate limit quota with a high RPS such that no requests should fail + _, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{ + "rate": 1000.0, + "burst": 3000, + "path": "pki/", + }) + if err != nil { + t.Fatal(err) + } + + _, numFail, _ = testRPS(reqFunc, 5*time.Second) + if numFail > 0 { + t.Fatalf("unexpected number of failed requests: %d", numFail) + } +} + +func TestQuotas_RateLimitQuota_MountPrecedence(t *testing.T) { + conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil) + cluster := vault.NewTestCluster(t, conf, opts) + cluster.Start() + defer cluster.Cleanup() + + core := cluster.Cores[0].Core + client := cluster.Cores[0].Client + + vault.TestWaitActive(t, core) + + // create PKI mount + err := client.Sys().Mount("pki", &api.MountInput{ + Type: "pki", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{ + "common_name": "testvault.com", + "ttl": "200h", + "ip_sans": "127.0.0.1", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("pki/roles/test", map[string]interface{}{ + "require_cn": false, + "allowed_domains": "testvault.com", + "allow_subdomains": true, + "max_ttl": "2h", + "generate_lease": true, + }) + if err != nil { + t.Fatal(err) + } + + // create a root rate limit quota + _, err = client.Logical().Write("sys/quotas/rate-limit/root-rlq", map[string]interface{}{ + "name": "root-rlq", + "rate": 14.7, + "burst": 15, + }) + if err != nil { + t.Fatal(err) + } + + // create a mount rate limit quota with a lower RPS than the root rate limit quota + _, err = client.Logical().Write("sys/quotas/rate-limit/mount-rlq", map[string]interface{}{ + "name": "mount-rlq", + "rate": 7.7, + "burst": 8, + "path": "pki/", + }) + if err != nil { + t.Fatal(err) + } + + // ensure mount rate limit quota takes precedence over root rate limit quota + reqFunc := func(numSuccess, numFail *atomic.Int32) { + _, err := client.Logical().Read("pki/cert/ca_chain") + + if err != nil { + numFail.Add(1) + } else { + numSuccess.Add(1) + } + } + + // ensure mount rate limit quota takes precedence over root rate limit quota + numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second) + + // evaluate the ideal RPS as (burst + (RPS * totalSeconds)) + ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second)) + + // ensure there were some failed requests + if numFail == 0 { + t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed) + } + + // ensure that we should never get more requests than allowed + if want := int32(ideal + 1); numSuccess > want { + t.Fatalf("too many successful requests; want: %d, 
numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed) + } +} + +func TestQuotas_RateLimitQuota(t *testing.T) { + conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil) + cluster := vault.NewTestCluster(t, conf, opts) + cluster.Start() + defer cluster.Cleanup() + + core := cluster.Cores[0].Core + client := cluster.Cores[0].Client + + vault.TestWaitActive(t, core) + + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + _, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{ + "password": "bar", + }) + if err != nil { + t.Fatal(err) + } + + // Create a rate limit quota with a low RPS of 7.7, which means we can process + // ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed + // by a refill rate of 7.7 per-second. + _, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{ + "rate": 7.7, + "burst": 8, + }) + if err != nil { + t.Fatal(err) + } + + reqFunc := func(numSuccess, numFail *atomic.Int32) { + _, err := client.Logical().Read("sys/quotas/rate-limit/rlq") + + if err != nil { + numFail.Add(1) + } else { + numSuccess.Add(1) + } + } + + numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second) + + // evaluate the ideal RPS as (burst + (RPS * totalSeconds)) + ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second)) + + // ensure there were some failed requests + if numFail == 0 { + t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed) + } + + // ensure that we should never get more requests than allowed + if want := int32(ideal + 1); numSuccess > want { + t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed) + } + + // allow time (1s) for rate limit to refill before updating the quota + time.Sleep(time.Second) + + // update the rate limit quota with a high RPS such that no requests should fail + _, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{ + "rate": 1000.0, + "burst": 3000, + }) + if err != nil { + t.Fatal(err) + } + + _, numFail, _ = testRPS(reqFunc, 5*time.Second) + if numFail > 0 { + t.Fatalf("unexpected number of failed requests: %d", numFail) + } +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 81e209885b..da56f3e6da 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -160,6 +160,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend { b.Backend.Paths = append(b.Backend.Paths, b.metricsPath()) b.Backend.Paths = append(b.Backend.Paths, b.monitorPath()) b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath()) + b.Backend.Paths = append(b.Backend.Paths, b.quotasPaths()...) if core.rawEnabled { b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) 
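With `b.quotasPaths()` registered above, the system backend exposes `sys/quotas/config` and `sys/quotas/rate-limit/<name>` (defined in the new `logical_system_quotas.go` below). A hedged sketch of exercising them with the Go API client; the quota name `global-rlq`, the parameter values, and the dev-mode token are illustrative, not from the patch:

```go
package main

import (
	"log"

	"github.com/hashicorp/vault/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}
	client.SetToken("root") // assumes a dev-mode server

	// Create (or update) a rate limit quota; a blank "path" makes it global.
	_, err = client.Logical().Write("sys/quotas/rate-limit/global-rlq", map[string]interface{}{
		"rate":  100.0,
		"burst": 200,
		"path":  "",
	})
	if err != nil {
		log.Fatal(err)
	}

	// Turn on audit logging of rate-limit rejections.
	_, err = client.Logical().Write("sys/quotas/config", map[string]interface{}{
		"enable_rate_limit_audit_logging": true,
	})
	if err != nil {
		log.Fatal(err)
	}
}
```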
@@ -751,7 +752,7 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d // Get all the options path := data.Get("path").(string) - path = sanitizeMountPath(path) + path = sanitizePath(path) logicalType := data.Get("type").(string) description := data.Get("description").(string) @@ -934,7 +935,7 @@ func handleErrorNoReadOnlyForward( // handleUnmount is used to unmount a path func (b *SystemBackend) handleUnmount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { path := data.Get("path").(string) - path = sanitizeMountPath(path) + path = sanitizePath(path) ns, err := namespace.FromContext(ctx) if err != nil { @@ -1029,6 +1030,12 @@ func (b *SystemBackend) handleRemount(ctx context.Context, req *logical.Request, return handleError(err) } + // Update quotas with the new path + if err := b.Core.quotaManager.HandleRemount(ctx, ns.Path, sanitizePath(fromPath), sanitizePath(toPath)); err != nil { + b.Core.logger.Error("failed to update quotas after remount", "ns_path", ns.Path, "from_path", fromPath, "to_path", toPath, "error", err) + return handleError(err) + } + return nil, nil } @@ -1060,7 +1067,7 @@ func (b *SystemBackend) handleMountTuneRead(ctx context.Context, req *logical.Re // handleTuneReadCommon returns the config settings of a path func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) (*logical.Response, error) { - path = sanitizeMountPath(path) + path = sanitizePath(path) sysView := b.Core.router.MatchingSystemView(ctx, path) if sysView == nil { @@ -1146,7 +1153,7 @@ func (b *SystemBackend) handleMountTuneWrite(ctx context.Context, req *logical.R func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, data *framework.FieldData) (*logical.Response, error) { repState := b.Core.ReplicationState() - path = sanitizeMountPath(path) + path = sanitizePath(path) // Prevent protected paths from being changed for _, p := range untunableMounts { @@ -1716,7 +1723,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque // Get all the options path := data.Get("path").(string) - path = sanitizeMountPath(path) + path = sanitizePath(path) logicalType := data.Get("type").(string) description := data.Get("description").(string) pluginName := data.Get("plugin_name").(string) @@ -1857,7 +1864,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque // handleDisableAuth is used to disable a credential backend func (b *SystemBackend) handleDisableAuth(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { path := data.Get("path").(string) - path = sanitizeMountPath(path) + path = sanitizePath(path) ns, err := namespace.FromContext(ctx) if err != nil { @@ -2272,7 +2279,7 @@ func (b *SystemBackend) handleAuditHash(ctx context.Context, req *logical.Reques return logical.ErrorResponse("the \"input\" parameter is empty"), nil } - path = sanitizeMountPath(path) + path = sanitizePath(path) hash, err := b.Core.auditBroker.GetHash(ctx, path, input) if err != nil { @@ -3258,7 +3265,7 @@ func (b *SystemBackend) pathInternalUIMountRead(ctx context.Context, req *logica if path == "" { return logical.ErrorResponse("path not set"), logical.ErrInvalidRequest } - path = sanitizeMountPath(path) + path = sanitizePath(path) errResp := logical.ErrorResponse(fmt.Sprintf("preflight capability check returned 403, please ensure client's policies grant access to path %q", path)) @@ -3576,7 +3583,7 @@ func (b 
*SystemBackend) pathInternalOpenAPI(ctx context.Context, req *logical.Re return resp, nil } -func sanitizeMountPath(path string) string { +func sanitizePath(path string) string { if !strings.HasSuffix(path, "/") { path += "/" } diff --git a/vault/logical_system_quotas.go b/vault/logical_system_quotas.go new file mode 100644 index 0000000000..12047a5168 --- /dev/null +++ b/vault/logical_system_quotas.go @@ -0,0 +1,272 @@ +package vault + +import ( + "context" + "strings" + + "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/quotas" +) + +// quotasPaths returns paths that enable quota management +func (b *SystemBackend) quotasPaths() []*framework.Path { + return []*framework.Path{ + { + Pattern: "quotas/config$", + Fields: map[string]*framework.FieldSchema{ + "enable_rate_limit_audit_logging": { + Type: framework.TypeBool, + Description: "If set, starts audit logging of requests that get rejected due to rate limit quota rule violations.", + }, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.handleQuotasConfigUpdate(), + }, + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handleQuotasConfigRead(), + }, + }, + HelpSynopsis: strings.TrimSpace(quotasHelp["quotas-config"][0]), + HelpDescription: strings.TrimSpace(quotasHelp["quotas-config"][1]), + }, + { + Pattern: "quotas/rate-limit/?$", + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ListOperation: &framework.PathOperation{ + Callback: b.handleRateLimitQuotasList(), + }, + }, + HelpSynopsis: strings.TrimSpace(quotasHelp["rate-limit-list"][0]), + HelpDescription: strings.TrimSpace(quotasHelp["rate-limit-list"][1]), + }, + { + Pattern: "quotas/rate-limit/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "type": { + Type: framework.TypeString, + Description: "Type of the quota rule.", + }, + "name": { + Type: framework.TypeString, + Description: "Name of the quota rule.", + }, + "path": { + Type: framework.TypeString, + Description: `Path of the mount or namespace to apply the quota. A blank path configures a +global quota. For example namespace1/ adds a quota to a full namespace, +namespace1/auth/userpass adds a quota to userpass in namespace1.`, + }, + "rate": { + Type: framework.TypeFloat, + Description: `The rate at which allowed requests are refilled per second by the quota rule. +Internally, a token-bucket algorithm is used which has a size of 'burst', initially full. The quota +limits requests to 'rate' per-second, with a maximum burst size of 'burst'. Each request takes a single +token from this bucket. The 'rate' must be positive.`, + }, + "burst": { + Type: framework.TypeInt, + Description: `The maximum number of requests at any given second to be allowed by the quota +rule. There is a one-to-one mapping between requests and tokens in the rate limit quota. 
A client
+may perform up to 'burst' requests at once, at which point they may invoke additional requests
+at 'rate' per-second.`,
+				},
+			},
+			Operations: map[logical.Operation]framework.OperationHandler{
+				logical.UpdateOperation: &framework.PathOperation{
+					Callback: b.handleRateLimitQuotasUpdate(),
+				},
+				logical.ReadOperation: &framework.PathOperation{
+					Callback: b.handleRateLimitQuotasRead(),
+				},
+				logical.DeleteOperation: &framework.PathOperation{
+					Callback: b.handleRateLimitQuotasDelete(),
+				},
+			},
+			HelpSynopsis:    strings.TrimSpace(quotasHelp["rate-limit"][0]),
+			HelpDescription: strings.TrimSpace(quotasHelp["rate-limit"][1]),
+		},
+	}
+}
+
+func (b *SystemBackend) handleQuotasConfigUpdate() framework.OperationFunc {
+	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
+		config, err := quotas.LoadConfig(ctx, b.Core.systemBarrierView)
+		if err != nil {
+			return nil, err
+		}
+		config.EnableRateLimitAuditLogging = d.Get("enable_rate_limit_audit_logging").(bool)
+
+		entry, err := logical.StorageEntryJSON(quotas.ConfigPath, config)
+		if err != nil {
+			return nil, err
+		}
+		if err := req.Storage.Put(ctx, entry); err != nil {
+			return nil, err
+		}
+
+		b.Core.quotaManager.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)
+		return nil, nil
+	}
+}
+
+func (b *SystemBackend) handleQuotasConfigRead() framework.OperationFunc {
+	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
+		config := b.Core.quotaManager.Config()
+		return &logical.Response{
+			Data: map[string]interface{}{
+				"enable_rate_limit_audit_logging": config.EnableRateLimitAuditLogging,
+			},
+		}, nil
+	}
+}
+
+func (b *SystemBackend) handleRateLimitQuotasList() framework.OperationFunc {
+	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
+		names, err := b.Core.quotaManager.QuotaNames(quotas.TypeRateLimit)
+		if err != nil {
+			return nil, err
+		}
+
+		return logical.ListResponse(names), nil
+	}
+}
+
+func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc {
+	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
+		name := d.Get("name").(string)
+
+		qType := quotas.TypeRateLimit.String()
+		rate := d.Get("rate").(float64)
+		if rate <= 0 {
+			return logical.ErrorResponse("'rate' is invalid"), nil
+		}
+
+		burst := d.Get("burst").(int)
+		if burst < int(rate) {
+			return logical.ErrorResponse("'burst' must be greater than or equal to 'rate' as an integer value"), nil
+		}
+
+		mountPath := sanitizePath(d.Get("path").(string))
+		ns := b.Core.namespaceByPath(mountPath)
+		if ns.ID != namespace.RootNamespaceID {
+			mountPath = strings.TrimPrefix(mountPath, ns.Path)
+		}
+
+		if mountPath != "" {
+			match := b.Core.router.MatchingMount(namespace.ContextWithNamespace(ctx, ns), mountPath)
+			if match == "" {
+				return logical.ErrorResponse("invalid mount path %q", mountPath), nil
+			}
+		}
+
+		// Disallow duplicate quotas with same precedence and similar
+		// properties.
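+		// ("Similar properties" here means the same quota type, namespace, and
+		// mount path: QuotaByFactors looks that scope up in memdb, and an
+		// existing rule under a different name causes the write to be rejected.)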
+ quota, err := b.Core.quotaManager.QuotaByFactors(ctx, qType, ns.Path, mountPath) + if err != nil { + return nil, err + } + if quota != nil && quota.QuotaName() != name { + return logical.ErrorResponse("quota rule with similar properties exists under the name %q", quota.QuotaName()), nil + } + + switch { + case quota == nil: + quota = quotas.NewRateLimitQuota(name, ns.Path, mountPath, rate, burst) + default: + rlq := quota.(*quotas.RateLimitQuota) + rlq.NamespacePath = ns.Path + rlq.MountPath = mountPath + rlq.Rate = rate + rlq.Burst = burst + } + + entry, err := logical.StorageEntryJSON(quotas.QuotaStoragePath(qType, name), quota) + if err != nil { + return nil, err + } + + if err := req.Storage.Put(ctx, entry); err != nil { + return nil, err + } + + if err := b.Core.quotaManager.SetQuota(ctx, qType, quota, false); err != nil { + return nil, err + } + + return nil, nil + } +} + +func (b *SystemBackend) handleRateLimitQuotasRead() framework.OperationFunc { + return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + qType := quotas.TypeRateLimit.String() + + quota, err := b.Core.quotaManager.QuotaByName(qType, name) + if err != nil { + return nil, err + } + if quota == nil { + return nil, nil + } + + rlq := quota.(*quotas.RateLimitQuota) + + nsPath := rlq.NamespacePath + if rlq.NamespacePath == "root" { + nsPath = "" + } + + data := map[string]interface{}{ + "type": qType, + "name": rlq.Name, + "path": nsPath + rlq.MountPath, + "rate": rlq.Rate, + "burst": rlq.Burst, + } + + return &logical.Response{ + Data: data, + }, nil + } +} + +func (b *SystemBackend) handleRateLimitQuotasDelete() framework.OperationFunc { + return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + qType := quotas.TypeRateLimit.String() + + if err := req.Storage.Delete(ctx, quotas.QuotaStoragePath(qType, name)); err != nil { + return nil, err + } + + if err := b.Core.quotaManager.DeleteQuota(ctx, qType, name); err != nil { + return nil, err + } + + return nil, nil + } +} + +var quotasHelp = map[string][2]string{ + "quotas-config": { + "Create, update and read the quota configuration.", + "", + }, + "rate-limit": { + `Get, create or update rate limit resource quota for an optional namespace or +mount.`, + `A rate limit quota will enforce rate limiting using a token bucket algorithm. A +rate limit quota can be created at the root level or defined on a namespace or +mount by specifying a 'path'. The rate limiter is applied to each unique client +IP address. 
A client may invoke 'burst' requests at any given second, after
which they may invoke additional requests at 'rate' per-second.`,
+	},
+	"rate-limit-list": {
+		"Lists the names of all the rate limit quotas.",
+		"This list contains quota definitions from all the namespaces.",
+	},
+}

diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go
index 14a725a2f3..296719e137 100644
--- a/vault/logical_system_test.go
+++ b/vault/logical_system_test.go
@@ -2654,7 +2654,7 @@ func TestSystemBackend_PathWildcardPreflight(t *testing.T) {
 	// Add another mount
 	me := &MountEntry{
 		Table:   mountTableType,
-		Path:    sanitizeMountPath("kv-v1"),
+		Path:    sanitizePath("kv-v1"),
 		Type:    "kv",
 		Options: map[string]string{"version": "1"},
 	}

diff --git a/vault/mount.go b/vault/mount.go
index 1a8c2c0e8b..a78c228da5 100644
--- a/vault/mount.go
+++ b/vault/mount.go
@@ -664,6 +664,11 @@ func (c *Core) unmountInternal(ctx context.Context, path string, updateStorage b
 	removePathCheckers(c, entry, viewPath)

+	if err := c.quotaManager.HandleBackendDisabling(ctx, ns.Path, path); err != nil {
+		c.logger.Error("failed to update quotas after disabling mount", "path", path, "error", err)
+		return err
+	}
+
 	if c.logger.IsInfo() {
 		c.logger.Info("successfully unmounted", "path", path, "namespace", ns.Path)
 	}

diff --git a/vault/quotas/quotas.go b/vault/quotas/quotas.go
new file mode 100644
index 0000000000..e8f07e2452
--- /dev/null
+++ b/vault/quotas/quotas.go
@@ -0,0 +1,860 @@
+package quotas
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"path"
+	"strings"
+	"sync"
+
+	log "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/go-memdb"
+	"github.com/hashicorp/vault/helper/metricsutil"
+	"github.com/hashicorp/vault/sdk/logical"
+)
+
+// Type represents the quota kind
+type Type string
+
+const (
+	// TypeRateLimit represents the rate limiting quota type
+	TypeRateLimit Type = "rate-limit"
+
+	// TypeLeaseCount represents the lease count limiting quota type
+	TypeLeaseCount Type = "lease-count"
+)
+
+// LeaseAction is the action taken by the expiration manager on the lease. The
+// quota manager will use this information to update the lease path cache and
+// to update counters for relevant quota rules.
+type LeaseAction uint32
+
+// String converts each lease action into its string equivalent value
+func (la LeaseAction) String() string {
+	switch la {
+	case LeaseActionLoaded:
+		return "loaded"
+	case LeaseActionCreated:
+		return "created"
+	case LeaseActionDeleted:
+		return "deleted"
+	case LeaseActionAllow:
+		return "allow"
+	}
+	return "unknown"
+}
+
+const (
+	_ LeaseAction = iota
+
+	// LeaseActionLoaded indicates the loading of a lease in the expiration manager
+	// after unseal.
+	LeaseActionLoaded
+
+	// LeaseActionCreated indicates that a lease is created in the expiration manager.
+	LeaseActionCreated
+
+	// LeaseActionDeleted indicates that a lease has expired and been deleted in the
+	// expiration manager.
+	LeaseActionDeleted
+
+	// LeaseActionAllow is used to indicate to the lease count checker that
+	// incCounter is being called from Allow(). All the other actions indicate that
+	// the action took place on the lease in the expiration manager.
+	LeaseActionAllow
+)
+
+type leaseWalkFunc func(context.Context, func(request *Request) bool) error
+
+// String converts each quota type into its string equivalent value
+func (q Type) String() string {
+	switch q {
+	case TypeLeaseCount:
+		return "lease-count"
+	case TypeRateLimit:
+		return "rate-limit"
+	}
+	return "unknown"
+}
+
+const (
+	indexID             = "id"
+	indexName           = "name"
+	indexNamespace      = "ns"
+	indexNamespaceMount = "ns_mount"
+)
+
+const (
+	// StoragePrefix is the prefix for the physical location where quota rules are
+	// persisted.
+	StoragePrefix = "quotas/"
+
+	// ConfigPath is the physical location where the quota configuration is
+	// persisted.
+	ConfigPath = StoragePrefix + "config"
+)
+
+var (
+	// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease
+	// count quota being exceeded.
+	ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded")
+
+	// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a
+	// rate limit quota being exceeded.
+	ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded")
+)
+
+// Access provides information to reach back to the quota checker.
+type Access interface {
+	// QuotaID is the identifier of the quota that issued this access.
+	QuotaID() string
+}
+
+// Ensure that access implements the Access interface.
+var _ Access = (*access)(nil)
+
+// access implements the Access interface
+type access struct {
+	quotaID string
+}
+
+// QuotaID returns the identifier of the quota rule to which this access
+// refers.
+func (a *access) QuotaID() string {
+	return a.quotaID
+}
+
+// Manager holds all the existing quota rules. For any given input, the manager
+// checks them against any applicable quota rules.
+type Manager struct {
+	entManager
+
+	// db holds the in-memory instances of all active quota rules indexed by
+	// some of the quota properties.
+	db *memdb.MemDB
+
+	// config contains operator preferences and quota behaviors
+	config *Config
+
+	storage logical.Storage
+	ctx     context.Context
+
+	logger     log.Logger
+	metricSink *metricsutil.ClusterMetricSink
+	lock       *sync.RWMutex
+}
+
+// Quota represents the common properties of every quota type
+type Quota interface {
+	// allow checks if the request is allowed by the quota type implementation.
+	allow(*Request) (Response, error)
+
+	// quotaID is the identifier of the quota rule
+	quotaID() string
+
+	// QuotaName is the name of the quota rule
+	QuotaName() string
+
+	// initialize sets up the fields in the quota type to begin operating
+	initialize(log.Logger, *metricsutil.ClusterMetricSink) error
+
+	// close defines any cleanup behavior that needs to be executed when a quota
+	// rule is deleted.
+	close() error
+
+	// handleRemount updates the quota with the new mount path
+	handleRemount(string)
+}
+
+// Response holds information about the result of the Allow() call. The response
+// can optionally have the Access field set, which is used to reach back into
+// the quota rule that sent this response.
+type Response struct {
+	// Allowed is set if the quota allows the request
+	Allowed bool
+
+	// Access is the handle to reach back into the quota rule that processed the
+	// quota request. This may not be set all the time.
+	Access Access
+}
+
+// Config holds operator preferences around quota behaviors
+type Config struct {
+	// EnableRateLimitAuditLogging, if set, starts audit logging of the
+	// request rejections that arise due to rate limit quota violations.
+	EnableRateLimitAuditLogging bool `json:"enable_rate_limit_audit_logging"`
+}
+
+// Request contains information required by the quota manager to query and
+// apply the quota rules.
+type Request struct {
+	// Type is the quota type
+	Type Type
+
+	// Path is the request path for which quota rules are being queried
+	Path string
+
+	// NamespacePath is the namespace path to which the request belongs
+	NamespacePath string
+
+	// MountPath is the mount path to which the request is made
+	MountPath string
+
+	// ClientAddress is the client's unique addressable string (e.g. IP address).
+	// It can be empty if the quota type does not need it.
+	ClientAddress string
+}
+
+// NewManager creates and initializes a new quota manager to hold all the quota
+// rules and to process incoming requests.
+func NewManager(logger log.Logger, walkFunc leaseWalkFunc, ms *metricsutil.ClusterMetricSink) (*Manager, error) {
+	db, err := memdb.NewMemDB(dbSchema())
+	if err != nil {
+		return nil, err
+	}
+
+	manager := &Manager{
+		db:         db,
+		logger:     logger,
+		metricSink: ms,
+		config:     new(Config),
+		lock:       new(sync.RWMutex),
+	}
+
+	manager.init(walkFunc)
+
+	return manager, nil
+}
+
+// SetQuota adds a new quota rule to the db.
+func (m *Manager) SetQuota(ctx context.Context, qType string, quota Quota, loading bool) error {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+	return m.setQuotaLocked(ctx, qType, quota, loading)
+}
+
+// setQuotaLocked should be called with the manager's lock held
+func (m *Manager) setQuotaLocked(ctx context.Context, qType string, quota Quota, loading bool) error {
+	if qType == TypeLeaseCount.String() {
+		m.setIsPerfStandby(quota)
+	}
+
+	txn := m.db.Txn(true)
+	defer txn.Abort()
+
+	raw, err := txn.First(qType, "id", quota.quotaID())
+	if err != nil {
+		return err
+	}
+
+	// If there already exists an entry in the db, remove that first.
+ if raw != nil { + err = txn.Delete(qType, raw) + if err != nil { + return err + } + } + + // Initialize the quota type implementation + if err := quota.initialize(m.logger, m.metricSink); err != nil { + return err + } + + // Add the initialized quota type implementation to the db + if err := txn.Insert(qType, quota); err != nil { + return err + } + + if loading { + txn.Commit() + return nil + } + + // For the lease count type, recompute the counters + if !loading && qType == TypeLeaseCount.String() { + if err := m.recomputeLeaseCounts(ctx, txn); err != nil { + return err + } + } + + txn.Commit() + return nil +} + +// QuotaNames returns the names of all the quota rules for a given type +func (m *Manager) QuotaNames(qType Type) ([]string, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + txn := m.db.Txn(false) + iter, err := txn.Get(qType.String(), indexID) + if err != nil { + return nil, err + } + var names []string + for raw := iter.Next(); raw != nil; raw = iter.Next() { + names = append(names, raw.(Quota).QuotaName()) + } + return names, nil +} + +// QuotaByID queries for a quota rule in the db for a given quota ID +func (m *Manager) QuotaByID(qType string, id string) (Quota, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + txn := m.db.Txn(false) + + quotaRaw, err := txn.First(qType, indexID, id) + if err != nil { + return nil, err + } + if quotaRaw == nil { + return nil, nil + } + + return quotaRaw.(Quota), nil +} + +// QuotaByName queries for a quota rule in the db for a given quota name +func (m *Manager) QuotaByName(qType string, name string) (Quota, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + txn := m.db.Txn(false) + + quotaRaw, err := txn.First(qType, indexName, name) + if err != nil { + return nil, err + } + if quotaRaw == nil { + return nil, nil + } + + return quotaRaw.(Quota), nil +} + +// QuotaByFactors returns the quota rule that matches the provided factors +func (m *Manager) QuotaByFactors(ctx context.Context, qType, nsPath, mountPath string) (Quota, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + // nsPath would have been made non-empty during insertion. Use non-empty value + // during query as well. + if nsPath == "" { + nsPath = "root" + } + + idx := indexNamespace + args := []interface{}{nsPath, false} + if mountPath != "" { + idx = indexNamespaceMount + args = []interface{}{nsPath, mountPath} + } + + txn := m.db.Txn(false) + iter, err := txn.Get(qType, idx, args...) + if err != nil { + return nil, err + } + var quotas []Quota + for raw := iter.Next(); raw != nil; raw = iter.Next() { + quotas = append(quotas, raw.(Quota)) + } + if len(quotas) > 1 { + return nil, fmt.Errorf("conflicting quota definitions detected") + } + if len(quotas) == 0 { + return nil, nil + } + + return quotas[0], nil +} + +// queryQuota returns the quota rule that is applicable for the given request. It +// queries all the quota rules that are defined against request values and finds +// the quota rule that takes priority. +// +// Priority rules are as follows: +// - namespace specific quota takes precedence over global quota +// - mount specific quota takes precedence over namespace specific quota +func (m *Manager) queryQuota(txn *memdb.Txn, req *Request) (Quota, error) { + if txn == nil { + txn = m.db.Txn(false) + } + + // ns would have been made non-empty during insertion. Use non-empty + // value during query as well. 
+	if req.NamespacePath == "" {
+		req.NamespacePath = "root"
+	}
+
+	//
+	// Find a match from the most specific applicable quota rule to the least
+	// specific one.
+	//
+	quotaFetchFunc := func(idx string, args ...interface{}) (Quota, error) {
+		iter, err := txn.Get(req.Type.String(), idx, args...)
+		if err != nil {
+			return nil, err
+		}
+		var quotas []Quota
+		for raw := iter.Next(); raw != nil; raw = iter.Next() {
+			quota := raw.(Quota)
+			quotas = append(quotas, quota)
+		}
+		if len(quotas) > 1 {
+			return nil, fmt.Errorf("conflicting quota definitions detected")
+		}
+		if len(quotas) == 0 {
+			return nil, nil
+		}
+
+		return quotas[0], nil
+	}
+
+	// Fetch mount quota
+	quota, err := quotaFetchFunc(indexNamespaceMount, req.NamespacePath, req.MountPath)
+	if err != nil {
+		return nil, err
+	}
+	if quota != nil {
+		return quota, nil
+	}
+
+	// Fetch ns quota. If NamespacePath is root, this will return the global quota.
+	quota, err = quotaFetchFunc(indexNamespace, req.NamespacePath, false)
+	if err != nil {
+		return nil, err
+	}
+	if quota != nil {
+		return quota, nil
+	}
+
+	// If the request belongs to the "root" namespace, then we have already looked
+	// at global quotas when fetching the namespace specific quota rule. When the
+	// request belongs to a non-root namespace, and when there are no namespace
+	// specific quota rules present, we fall back on the global quotas.
+	if req.NamespacePath == "root" {
+		return nil, nil
+	}
+
+	// Fetch global quota
+	quota, err = quotaFetchFunc(indexNamespace, "root", false)
+	if err != nil {
+		return nil, err
+	}
+	if quota != nil {
+		return quota, nil
+	}
+
+	return nil, nil
+}
+
+// DeleteQuota removes a quota rule from the db for a given name
+func (m *Manager) DeleteQuota(ctx context.Context, qType string, name string) error {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	txn := m.db.Txn(true)
+	defer txn.Abort()
+
+	raw, err := txn.First(qType, indexName, name)
+	if err != nil {
+		return err
+	}
+	if raw == nil {
+		return nil
+	}
+
+	quota := raw.(Quota)
+	if err := quota.close(); err != nil {
+		return err
+	}
+
+	err = txn.Delete(qType, raw)
+	if err != nil {
+		return err
+	}
+
+	// For the lease count type, recompute the counters
+	if qType == TypeLeaseCount.String() {
+		if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
+			return err
+		}
+	}
+
+	txn.Commit()
+	return nil
+}
+
+// ApplyQuota runs the request against any quota rule that is applicable to it.
+// If there are multiple quota rules that match the request parameters, the rule
+// that takes precedence will be used to allow or reject the request.
+func (m *Manager) ApplyQuota(req *Request) (Response, error) {
+	var resp Response
+
+	quota, err := m.queryQuota(nil, req)
+	if err != nil {
+		return resp, err
+	}
+
+	// If there is no quota defined, allow the request.
+	if quota == nil {
+		resp.Allowed = true
+		return resp, nil
+	}
+
+	// If the quota type is lease count, and if the path is not known to
+	// generate leases, allow the request.
+	if req.Type == TypeLeaseCount && !m.inLeasePathCache(req.Path) {
+		resp.Allowed = true
+		return resp, nil
+	}
+
+	return quota.allow(req)
+}
+
+// SetEnableRateLimitAuditLogging updates the operator preference regarding the
+// audit logging behavior.
+func (m *Manager) SetEnableRateLimitAuditLogging(val bool) {
+	m.config.EnableRateLimitAuditLogging = val
+}
+
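+// Illustrative sketch (not part of this change) of applying a quota from a
+// caller holding a *Manager; mgr and the field values here are hypothetical.
+// Because of the precedence rules above, a mount scoped rule on "secret/"
+// would win over a namespace wide or global rule:
+//
+//	resp, err := mgr.ApplyQuota(&Request{
+//		Type:          TypeRateLimit,
+//		Path:          "secret/data/foo",
+//		NamespacePath: "root",
+//		MountPath:     "secret/",
+//		ClientAddress: "10.0.0.1",
+//	})
+//	if err == nil && !resp.Allowed {
+//		// respond with HTTP 429 and ErrRateLimitQuotaExceeded
+//	}
+
+// RateLimitAuditLoggingEnabled returns whether the quota configuration allows
+// audit logging of request rejections due to rate limit quota rule violations.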
+func (m *Manager) RateLimitAuditLoggingEnabled() bool {
+	return m.config.EnableRateLimitAuditLogging
+}
+
+// Config returns the operator preferences in the quota manager
+func (m *Manager) Config() *Config {
+	return m.config
+}
+
+// Reset will clear all the quotas from the db and clear the lease path cache.
+func (m *Manager) Reset() error {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	var err error
+	m.db, err = memdb.NewMemDB(dbSchema())
+	if err != nil {
+		return err
+	}
+
+	m.storage = nil
+	m.ctx = nil
+
+	m.entManager.Reset()
+	return nil
+}
+
+// dbSchema creates a DB schema for holding all the quota rules. It creates a
+// table for each supported type of quota.
+func dbSchema() *memdb.DBSchema {
+	schema := &memdb.DBSchema{
+		Tables: make(map[string]*memdb.TableSchema),
+	}
+
+	commonSchema := func(name string) *memdb.TableSchema {
+		return &memdb.TableSchema{
+			Name: name,
+			Indexes: map[string]*memdb.IndexSchema{
+				indexID: {
+					Name:   indexID,
+					Unique: true,
+					Indexer: &memdb.StringFieldIndex{
+						Field: "ID",
+					},
+				},
+				indexName: {
+					Name:   indexName,
+					Unique: true,
+					Indexer: &memdb.StringFieldIndex{
+						Field: "Name",
+					},
+				},
+				indexNamespace: {
+					Name: indexNamespace,
+					Indexer: &memdb.CompoundMultiIndex{
+						Indexes: []memdb.Indexer{
+							&memdb.StringFieldIndex{
+								Field: "NamespacePath",
+							},
+							// By sending false as the query parameter, we can
+							// query just the namespace specific quota.
+							&memdb.FieldSetIndex{
+								Field: "MountPath",
+							},
+						},
+					},
+				},
+				indexNamespaceMount: {
+					Name:         indexNamespaceMount,
+					AllowMissing: true,
+					Indexer: &memdb.CompoundMultiIndex{
+						Indexes: []memdb.Indexer{
+							&memdb.StringFieldIndex{
+								Field: "NamespacePath",
+							},
+							&memdb.StringFieldIndex{
+								Field: "MountPath",
+							},
+						},
+					},
+				},
+			},
+		}
+	}
+
+	// Create a table per quota type. This allows names to be reused between
+	// different quota types and makes querying a bit easier.
+	for _, name := range quotaTypes() {
+		schema.Tables[name] = commonSchema(name)
+	}
+
+	return schema
+}
+
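+// Illustrative sketch (not part of this change) of how the indexes above are
+// queried. Passing false to the FieldSetIndex restricts the match to quotas
+// whose MountPath is unset, i.e. namespace wide rules; the table name and
+// values here are hypothetical:
+//
+//	txn := db.Txn(false)
+//	// namespace wide rate limit quotas in the root namespace
+//	nsIter, err := txn.Get("rate-limit", indexNamespace, "root", false)
+//	// mount scoped rate limit quotas on "secret/" in the root namespace
+//	mountIter, err := txn.Get("rate-limit", indexNamespaceMount, "root", "secret/")
+
+// Invalidate receives notifications from the replication sub-system when a key
+// is updated in the storage. This function reads the key from storage and
+// updates the caches and data structures to reflect those updates.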
+func (m *Manager) Invalidate(key string) {
+	switch key {
+	case "config":
+		config, err := LoadConfig(m.ctx, m.storage)
+		if err != nil {
+			m.logger.Error("failed to invalidate quota config", "error", err)
+			return
+		}
+		m.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)
+	default:
+		splitKeys := strings.Split(key, "/")
+		if len(splitKeys) != 2 {
+			m.logger.Error("incorrect key while invalidating quota rule")
+			return
+		}
+		qType := splitKeys[0]
+		name := splitKeys[1]
+
+		// Read quota rule from storage
+		quota, err := Load(m.ctx, m.storage, qType, name)
+		if err != nil {
+			m.logger.Error("failed to read invalidated quota rule", "error", err)
+			return
+		}
+
+		switch {
+		case quota == nil:
+			// Handle quota deletion
+			if err := m.DeleteQuota(m.ctx, qType, name); err != nil {
+				m.logger.Error("failed to delete invalidated quota rule", "error", err)
+				return
+			}
+		default:
+			// Handle quota update
+			if err := m.SetQuota(m.ctx, qType, quota, false); err != nil {
+				m.logger.Error("failed to update invalidated quota rule", "error", err)
+				return
+			}
+		}
+	}
+}
+
+// LoadConfig reads the quota configuration from the underlying storage
+func LoadConfig(ctx context.Context, storage logical.Storage) (*Config, error) {
+	var config Config
+	entry, err := storage.Get(ctx, ConfigPath)
+	if err != nil {
+		return nil, err
+	}
+	if entry == nil {
+		return &config, nil
+	}
+
+	err = entry.DecodeJSON(&config)
+	if err != nil {
+		return nil, err
+	}
+
+	return &config, nil
+}
+
+// Load reads the quota rule from the underlying storage
+func Load(ctx context.Context, storage logical.Storage, qType, name string) (Quota, error) {
+	var quota Quota
+	entry, err := storage.Get(ctx, QuotaStoragePath(qType, name))
+	if err != nil {
+		return nil, err
+	}
+	if entry == nil {
+		return nil, nil
+	}
+
+	switch qType {
+	case TypeRateLimit.String():
+		quota = &RateLimitQuota{}
+	case TypeLeaseCount.String():
+		quota = &LeaseCountQuota{}
+	default:
+		return nil, fmt.Errorf("unsupported type: %v", qType)
+	}
+
+	err = entry.DecodeJSON(quota)
+	if err != nil {
+		return nil, err
+	}
+
+	return quota, nil
+}
+
+// Setup loads the quota configuration and all the quota rules into the
+// quota manager.
+func (m *Manager) Setup(ctx context.Context, storage logical.Storage, isPerfStandby bool) error {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	m.storage = storage
+	m.ctx = ctx
+	m.isPerfStandby = isPerfStandby
+
+	// Load the quota configuration from storage and load it into the quota
+	// manager.
+	config, err := LoadConfig(ctx, storage)
+	if err != nil {
+		return err
+	}
+	m.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)
+
+	// Load the quota rules for all supported types from storage and load them
+	// into the quota manager.
+	for _, qType := range quotaTypes() {
+		names, err := logical.CollectKeys(ctx, logical.NewStorageView(storage, StoragePrefix+qType+"/"))
+		if err != nil {
+			return err
+		}
+		for _, name := range names {
+			quota, err := Load(ctx, m.storage, qType, name)
+			if err != nil {
+				return err
+			}
+
+			if quota == nil {
+				continue
+			}
+
+			err = m.setQuotaLocked(ctx, qType, quota, true)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// QuotaStoragePath returns the storage path suffix for persisting the quota
+// rule.
+func QuotaStoragePath(quotaType, name string) string {
+	return path.Join(StoragePrefix+quotaType, name)
+}
+
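+// Illustrative sketch (not part of this change): for a rate limit quota named
+// "global-rate" (a hypothetical name), the persisted location would be
+//
+//	QuotaStoragePath("rate-limit", "global-rate") // "quotas/rate-limit/global-rate"
+
+// HandleRemount updates the quota subsystem about the remount operation that
+// took place. The quota manager will trigger the quota specific updates,
+// including the mount path update.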
+func (m *Manager) HandleRemount(ctx context.Context, nsPath, fromPath, toPath string) error {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	txn := m.db.Txn(true)
+	defer txn.Abort()
+
+	// nsPath would have been made non-empty during insertion. Use non-empty value
+	// during query as well.
+	if nsPath == "" {
+		nsPath = "root"
+	}
+
+	idx := indexNamespaceMount
+	leaseQuotaUpdated := false
+	args := []interface{}{nsPath, fromPath}
+	for _, quotaType := range quotaTypes() {
+		iter, err := txn.Get(quotaType, idx, args...)
+		if err != nil {
+			return err
+		}
+		for raw := iter.Next(); raw != nil; raw = iter.Next() {
+			quota := raw.(Quota)
+			quota.handleRemount(toPath)
+			entry, err := logical.StorageEntryJSON(QuotaStoragePath(quotaType, quota.QuotaName()), quota)
+			if err != nil {
+				return err
+			}
+			if err := m.storage.Put(ctx, entry); err != nil {
+				return err
+			}
+			if quotaType == TypeLeaseCount.String() {
+				leaseQuotaUpdated = true
+			}
+		}
+	}
+
+	if leaseQuotaUpdated {
+		if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
+			return err
+		}
+	}
+
+	txn.Commit()
+
+	return nil
+}
+
+// HandleBackendDisabling updates the quota subsystem when an auth or secrets
+// engine is disabled.
+func (m *Manager) HandleBackendDisabling(ctx context.Context, nsPath, mountPath string) error {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	txn := m.db.Txn(true)
+	defer txn.Abort()
+
+	// nsPath would have been made non-empty during insertion. Use non-empty value
+	// during query as well.
+	if nsPath == "" {
+		nsPath = "root"
+	}
+
+	idx := indexNamespaceMount
+	leaseQuotaDeleted := false
+	args := []interface{}{nsPath, mountPath}
+	for _, quotaType := range quotaTypes() {
+		iter, err := txn.Get(quotaType, idx, args...)
+		if err != nil {
+			return err
+		}
+		for raw := iter.Next(); raw != nil; raw = iter.Next() {
+			if err := txn.Delete(quotaType, raw); err != nil {
+				return fmt.Errorf("failed to delete quota from db after mount disabling; namespace %q, err %v", nsPath, err)
+			}
+			quota := raw.(Quota)
+			if err := m.storage.Delete(ctx, QuotaStoragePath(quotaType, quota.QuotaName())); err != nil {
+				return fmt.Errorf("failed to delete quota from storage after mount disabling; namespace %q, err %v", nsPath, err)
+			}
+			if quotaType == TypeLeaseCount.String() {
+				leaseQuotaDeleted = true
+			}
+		}
+	}
+
+	if leaseQuotaDeleted {
+		if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
+			return err
+		}
+	}
+
+	txn.Commit()
+
+	return nil
+}
diff --git a/vault/quotas/quotas_rate_limit.go b/vault/quotas/quotas_rate_limit.go
new file mode 100644
index 0000000000..57f240f38e
--- /dev/null
+++ b/vault/quotas/quotas_rate_limit.go
@@ -0,0 +1,282 @@
+package quotas
+
+import (
+	"fmt"
+	"math"
+	"sync"
+	"time"
+
+	"github.com/armon/go-metrics"
+	log "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/go-uuid"
+	"github.com/hashicorp/vault/helper/metricsutil"
+	"github.com/hashicorp/vault/sdk/helper/pathmanager"
+	"golang.org/x/time/rate"
+)
+
+var rateLimitExemptPaths = pathmanager.New()
+
+const (
+	// DefaultRateLimitPurgeInterval defines the default purge interval used by a
+	// RateLimitQuota to remove stale client rate limiters.
+	DefaultRateLimitPurgeInterval = time.Minute
+
+	// DefaultRateLimitStaleAge defines the default stale age of a client limiter.
+	DefaultRateLimitStaleAge = 3 * time.Minute
+
+	// EnvVaultEnableRateLimitAuditLogging is used to enable audit logging of
+	// requests that get rejected due to rate limit quota violations.
+	EnvVaultEnableRateLimitAuditLogging = "VAULT_ENABLE_RATE_LIMIT_AUDIT_LOGGING"
+)
+
+func init() {
+	rateLimitExemptPaths.AddPaths([]string{
+		"/v1/sys/generate-recovery-token/attempt",
+		"/v1/sys/generate-recovery-token/update",
+		"/v1/sys/generate-root/attempt",
+		"/v1/sys/generate-root/update",
+		"/v1/sys/health",
+		"/v1/sys/seal-status",
+		"/v1/sys/unseal",
+	})
+}
+
+// ClientRateLimiter defines a token bucket based rate limiter for a unique
+// addressable client (e.g. IP address). Whenever this client attempts to make
+// a request, the lastSeen value will be updated.
+type ClientRateLimiter struct {
+	// lastSeen is the timestamp of the client's most recent request.
+	lastSeen time.Time
+
+	// limiter represents an instance of a token bucket based rate limiter.
+	limiter *rate.Limiter
+}
+
+// newClientRateLimiter returns a token bucket based rate limiter for a client
+// that is uniquely addressable, where maxRequests defines the requests-per-second
+// and burstSize defines the maximum burst allowed. A caller may provide -1 for
+// burstSize to allow the burst value to be roughly equivalent to the RPS. Note,
+// the underlying rate limiter is already thread-safe.
+func newClientRateLimiter(maxRequests float64, burstSize int) *ClientRateLimiter {
+	if burstSize < 0 {
+		burstSize = int(math.Ceil(maxRequests))
+	}
+
+	return &ClientRateLimiter{
+		lastSeen: time.Now().UTC(),
+		limiter:  rate.NewLimiter(rate.Limit(maxRequests), burstSize),
+	}
+}
+
+// Ensure that RateLimitQuota implements the Quota interface
+var _ Quota = (*RateLimitQuota)(nil)
+
+// RateLimitQuota represents the quota rule properties that are used to limit
+// the number of requests per second for a namespace or mount.
+type RateLimitQuota struct {
+	// ID is the identifier of the quota
+	ID string `json:"id"`
+
+	// Type of quota this represents
+	Type Type `json:"type"`
+
+	// Name of the quota rule
+	Name string `json:"name"`
+
+	// NamespacePath is the path of the namespace to which this quota is
+	// applicable.
+	NamespacePath string `json:"namespace_path"`
+
+	// MountPath is the path of the mount to which this quota is applicable
+	MountPath string `json:"mount_path"`
+
+	// Rate defines the rate at which allowed requests are refilled per second.
+	Rate float64 `json:"rate"`
+
+	// Burst defines the maximum number of requests allowed at any given moment.
+	Burst int `json:"burst"`
+
+	lock         *sync.Mutex
+	logger       log.Logger
+	metricSink   *metricsutil.ClusterMetricSink
+	purgeEnabled bool
+
+	// purgeInterval defines the interval at which the RateLimitQuota
+	// attempts to remove stale entries from the rateQuotas mapping.
+	purgeInterval time.Duration
+	closeCh       chan struct{}
+
+	// staleAge defines the age after which a clientRateLimiter is
+	// considered stale. A clientRateLimiter is considered stale if the delta
+	// between the current purge time and its lastSeen timestamp is greater than
+	// this value.
+	staleAge time.Duration
+
+	// rateQuotas contains a mapping from a unique addressable client (e.g. IP address)
+	// to a clientRateLimiter reference. Every purge interval, the RateLimitQuota
+	// will attempt to remove stale entries from the mapping.
+	rateQuotas map[string]*ClientRateLimiter
+}
+
+// NewRateLimitQuota creates a quota checker for imposing limits on the number
+// of requests per second.
+func NewRateLimitQuota(name, nsPath, mountPath string, rate float64, burst int) *RateLimitQuota {
+	return &RateLimitQuota{
+		Name:          name,
+		Type:          TypeRateLimit,
+		NamespacePath: nsPath,
+		MountPath:     mountPath,
+		Rate:          rate,
+		Burst:         burst,
+	}
+}
+
+// initialize ensures the namespace and max requests are initialized, sets the ID
+// if it's currently empty, sets the purge interval and stale age to default
+// values, and finally starts the client purge go routine if it has not already
+// been started. Note, initialize will reset the internal rateQuotas mapping.
+func (rlq *RateLimitQuota) initialize(logger log.Logger, ms *metricsutil.ClusterMetricSink) error {
+	if rlq.lock == nil {
+		rlq.lock = new(sync.Mutex)
+	}
+
+	rlq.lock.Lock()
+	defer rlq.lock.Unlock()
+
+	// Memdb requires a non-empty value for indexing
+	if rlq.NamespacePath == "" {
+		rlq.NamespacePath = "root"
+	}
+
+	if rlq.Rate <= 0 {
+		return fmt.Errorf("invalid avg rps: %v", rlq.Rate)
+	}
+
+	if rlq.Burst < int(rlq.Rate) {
+		return fmt.Errorf("burst size (%v) must be greater than or equal to average rps (%v)", rlq.Burst, rlq.Rate)
+	}
+
+	if logger != nil {
+		rlq.logger = logger
+	}
+
+	if rlq.metricSink == nil {
+		rlq.metricSink = ms
+	}
+
+	if rlq.ID == "" {
+		id, err := uuid.GenerateUUID()
+		if err != nil {
+			return err
+		}
+
+		rlq.ID = id
+	}
+
+	rlq.purgeInterval = DefaultRateLimitPurgeInterval
+	rlq.staleAge = DefaultRateLimitStaleAge
+	rlq.rateQuotas = make(map[string]*ClientRateLimiter)
+
+	if !rlq.purgeEnabled {
+		rlq.purgeEnabled = true
+		rlq.closeCh = make(chan struct{})
+		go rlq.purgeClientsLoop()
+	}
+
+	return nil
+}
+
+// quotaID returns the identifier of the quota rule
+func (rlq *RateLimitQuota) quotaID() string {
+	return rlq.ID
+}
+
+// QuotaName returns the name of the quota rule
+func (rlq *RateLimitQuota) QuotaName() string {
+	return rlq.Name
+}
+
+// purgeClientsLoop performs a blocking process where every purgeInterval
+// duration, we look for stale clients to remove from the rateQuotas map.
+// A ClientRateLimiter is considered stale if the time since its lastSeen
+// timestamp is greater than or equal to the stale age. The loop will continue
+// to run indefinitely until closeCh is closed, at which point we stop the
+// ticker and exit.
+func (rlq *RateLimitQuota) purgeClientsLoop() {
+	ticker := time.NewTicker(rlq.purgeInterval)
+
+	for {
+		select {
+		case t := <-ticker.C:
+			rlq.lock.Lock()
+
+			for client, crl := range rlq.rateQuotas {
+				if t.UTC().Sub(crl.lastSeen) >= rlq.staleAge {
+					delete(rlq.rateQuotas, client)
+				}
+			}
+
+			rlq.lock.Unlock()
+
+		case <-rlq.closeCh:
+			ticker.Stop()
+			rlq.purgeEnabled = false
+			return
+		}
+	}
+}
+
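+// Illustrative sketch (not part of this change): each client address maps to
+// its own token bucket, so one noisy client cannot consume another client's
+// budget. The addresses are hypothetical:
+//
+//	a := rlq.clientRateLimiter("10.0.0.1")
+//	b := rlq.clientRateLimiter("10.0.0.2")
+//	_ = a.limiter.Allow() // draws a token from 10.0.0.1's bucket only
+//	_ = b.limiter.Allow() // draws a token from 10.0.0.2's bucket only
+
+// clientRateLimiter returns a reference to a ClientRateLimiter based on a
+// provided client address (e.g. IP address). If the ClientRateLimiter does not
+// exist in the RateLimitQuota's mapping, one will be created and set. The
+// created ClientRateLimiter will have its requests-per-second set to
+// RateLimitQuota.Rate. If the ClientRateLimiter already exists, the lastSeen
+// timestamp will be updated.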
+func (rlq *RateLimitQuota) clientRateLimiter(addr string) *ClientRateLimiter {
+	rlq.lock.Lock()
+	defer rlq.lock.Unlock()
+
+	crl, ok := rlq.rateQuotas[addr]
+	if !ok {
+		limiter := newClientRateLimiter(rlq.Rate, rlq.Burst)
+		rlq.rateQuotas[addr] = limiter
+		return limiter
+	}
+
+	crl.lastSeen = time.Now().UTC()
+	return crl
+}
+
+// allow decides if the request is allowed by the quota. An error will be
+// returned if the client address is empty. If the path is exempt, the
+// quota will not be evaluated. Otherwise, the client rate limiter is retrieved
+// by address and the rate limit quota is checked against that limiter.
+func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
+	var resp Response
+
+	// Skip rate limit checks for paths that are exempt from rate limiting.
+	if rateLimitExemptPaths.HasPath(req.Path) {
+		resp.Allowed = true
+		return resp, nil
+	}
+
+	if req.ClientAddress == "" {
+		return resp, fmt.Errorf("missing request client address in quota request")
+	}
+
+	resp.Allowed = rlq.clientRateLimiter(req.ClientAddress).limiter.Allow()
+	if !resp.Allowed {
+		rlq.metricSink.IncrCounterWithLabels([]string{"quota", "rate_limit", "violation"}, 1, []metrics.Label{{"name", rlq.Name}})
+	}
+
+	return resp, nil
+}
+
+// close stops the current running client purge loop.
+func (rlq *RateLimitQuota) close() error {
+	close(rlq.closeCh)
+	return nil
+}
+
+func (rlq *RateLimitQuota) handleRemount(toPath string) {
+	rlq.MountPath = toPath
+}
diff --git a/vault/quotas/quotas_rate_limit_test.go b/vault/quotas/quotas_rate_limit_test.go
new file mode 100644
index 0000000000..8736cbd5d1
--- /dev/null
+++ b/vault/quotas/quotas_rate_limit_test.go
@@ -0,0 +1,173 @@
+package quotas
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+	"time"
+
+	log "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/vault/helper/metricsutil"
+	"github.com/hashicorp/vault/sdk/helper/logging"
+	"go.uber.org/atomic"
+)
+
+func TestNewClientRateLimiter(t *testing.T) {
+	testCases := []struct {
+		maxRequests   float64
+		burstSize     int
+		expectedBurst int
+	}{
+		{1000, -1, 1000},
+		{1000, 5000, 5000},
+		{16.1, -1, 17},
+		{16.7, -1, 17},
+		{16.7, 100, 100},
+	}
+
+	for _, tc := range testCases {
+		crl := newClientRateLimiter(tc.maxRequests, tc.burstSize)
+		b := crl.limiter.Burst()
+		if b != tc.expectedBurst {
+			t.Fatalf("unexpected burst size; expected: %d, got: %d", tc.expectedBurst, b)
+		}
+	}
+}
+
+func TestNewRateLimitQuota(t *testing.T) {
+	rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, 50)
+	if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil {
+		t.Fatal(err)
+	}
+
+	if !rlq.purgeEnabled {
+		t.Fatal("expected rate limit quota to start purge loop")
+	}
+
+	if rlq.purgeInterval != DefaultRateLimitPurgeInterval {
+		t.Fatalf("unexpected purgeInterval; expected: %d, got: %d", DefaultRateLimitPurgeInterval, rlq.purgeInterval)
+	}
+	if rlq.staleAge != DefaultRateLimitStaleAge {
+		t.Fatalf("unexpected staleAge; expected: %d, got: %d", DefaultRateLimitStaleAge, rlq.staleAge)
+	}
+}
+
+func TestRateLimitQuota_Close(t *testing.T) {
+	rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, 50)
+
+	if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := rlq.close(); err != nil {
+		t.Fatalf("unexpected error when closing: %v", err)
+	}
+
+	time.Sleep(time.Second) // allow enough time for purgeClientsLoop to receive on closeCh
+
+	if rlq.purgeEnabled {
+		
t.Fatal("expected client purging to be disabled after close") + } +} + +func TestRateLimitQuota_Allow(t *testing.T) { + rlq := &RateLimitQuota{ + Name: "test-rate-limiter", + Type: TypeRateLimit, + NamespacePath: "qa", + MountPath: "/foo/bar", + Rate: 16.7, + Burst: 83, + purgeEnabled: true, // to allow manual setting of purgeInterval and staleAge + } + + if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil { + t.Fatal(err) + } + + // override value and manually start purgeClientsLoop for testing purposes + rlq.purgeInterval = 10 * time.Second + rlq.staleAge = 10 * time.Second + go rlq.purgeClientsLoop() + + var wg sync.WaitGroup + + type clientResult struct { + atomicNumAllow *atomic.Int32 + atomicNumFail *atomic.Int32 + } + + reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) { + defer wg.Done() + + resp, err := rlq.allow(&Request{ClientAddress: addr}) + if err != nil { + return + } + + if resp.Allowed { + atomicNumAllow.Add(1) + } else { + atomicNumFail.Add(1) + } + } + + results := make(map[string]*clientResult) + + start := time.Now() + end := start.Add(5 * time.Second) + for time.Now().Before(end) { + + for i := 0; i < 5; i++ { + wg.Add(1) + + addr := fmt.Sprintf("127.0.0.%d", i) + cr, ok := results[addr] + if !ok { + results[addr] = &clientResult{atomicNumAllow: atomic.NewInt32(0), atomicNumFail: atomic.NewInt32(0)} + cr = results[addr] + } + + go reqFunc(addr, cr.atomicNumAllow, cr.atomicNumFail) + + time.Sleep(2 * time.Millisecond) + } + + } + + wg.Wait() + + if got, expected := len(results), len(rlq.rateQuotas); got != expected { + t.Fatalf("unexpected number of tracked client rate limit quotas; got %d, expected; %d", got, expected) + } + + elapsed := time.Since(start) + + // evaluate the ideal RPS as (burst + (RPS * totalSeconds)) + ideal := float64(rlq.Burst) + (rlq.Rate * float64(elapsed) / float64(time.Second)) + + for addr, cr := range results { + numAllow := cr.atomicNumAllow.Load() + numFail := cr.atomicNumFail.Load() + + // ensure there were some failed requests for the namespace + if numFail == 0 { + t.Fatalf("expected some requests to fail; addr: %s, numSuccess: %d, numFail: %d, elapsed: %d", addr, numAllow, numFail, elapsed) + } + + // ensure that we should never get more requests than allowed for the namespace + if want := int32(ideal + 1); numAllow > want { + t.Fatalf("too many successful requests; addr: %s, want: %d, numSuccess: %d, numFail: %d, elapsed: %d", addr, want, numAllow, numFail, elapsed) + } + } + + // allow enough time for the client to be purged + time.Sleep(rlq.purgeInterval * 2) + + for addr := range results { + rlc, ok := rlq.rateQuotas[addr] + if ok || rlc != nil { + t.Fatalf("expected stale client to be purged: %s", addr) + } + } +} diff --git a/vault/quotas/quotas_test.go b/vault/quotas/quotas_test.go new file mode 100644 index 0000000000..79b87169a6 --- /dev/null +++ b/vault/quotas/quotas_test.go @@ -0,0 +1,67 @@ +package quotas + +import ( + "context" + "testing" + + "github.com/go-test/deep" + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/metricsutil" + "github.com/hashicorp/vault/sdk/helper/logging" +) + +func TestQuotas_Precedence(t *testing.T) { + qm, err := NewManager(logging.NewVaultLogger(log.Trace), nil, metricsutil.BlackholeSink()) + if err != nil { + t.Fatal(err) + } + + setQuotaFunc := func(t *testing.T, name, nsPath, mountPath string) Quota { + t.Helper() + quota := NewRateLimitQuota(name, nsPath, mountPath, 10, 20) + err := 
qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true) + if err != nil { + t.Fatal(err) + } + return quota + } + + checkQuotaFunc := func(t *testing.T, nsPath, mountPath string, expected Quota) { + t.Helper() + quota, err := qm.queryQuota(nil, &Request{ + Type: TypeRateLimit, + NamespacePath: nsPath, + MountPath: mountPath, + }) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(expected, quota); len(diff) > 0 { + t.Fatal(diff) + } + } + + // No quota present. Expect nil. + checkQuotaFunc(t, "", "", nil) + + // Define global quota and expect that to be returned. + rateLimitGlobalQuota := setQuotaFunc(t, "rateLimitGlobalQuota", "", "") + checkQuotaFunc(t, "", "", rateLimitGlobalQuota) + + // Define a global mount specific quota and expect that to be returned. + rateLimitGlobalMountQuota := setQuotaFunc(t, "rateLimitGlobalMountQuota", "", "testmount") + checkQuotaFunc(t, "", "testmount", rateLimitGlobalMountQuota) + + // Define a namespace quota and expect that to be returned. + rateLimitNSQuota := setQuotaFunc(t, "rateLimitNSQuota", "testns", "") + checkQuotaFunc(t, "testns", "", rateLimitNSQuota) + + // Define a namespace mount specific quota and expect that to be returned. + rateLimitNSMountQuota := setQuotaFunc(t, "rateLimitNSMountQuota", "testns", "testmount") + checkQuotaFunc(t, "testns", "testmount", rateLimitNSMountQuota) + + // Now that many quota types are defined, verify that the most specific + // matches are returned per namespace. + checkQuotaFunc(t, "", "", rateLimitGlobalQuota) + checkQuotaFunc(t, "testns", "", rateLimitNSQuota) +} diff --git a/vault/quotas/quotas_util.go b/vault/quotas/quotas_util.go new file mode 100644 index 0000000000..983417476c --- /dev/null +++ b/vault/quotas/quotas_util.go @@ -0,0 +1,65 @@ +// +build !enterprise + +package quotas + +import ( + "context" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/helper/metricsutil" + + memdb "github.com/hashicorp/go-memdb" +) + +func quotaTypes() []string { + return []string{ + TypeRateLimit.String(), + } +} + +func (m *Manager) init(walkFunc leaseWalkFunc) {} + +func (m *Manager) recomputeLeaseCounts(ctx context.Context, txn *memdb.Txn) error { + return nil +} + +func (m *Manager) setIsPerfStandby(quota Quota) {} + +func (m *Manager) inLeasePathCache(path string) bool { + return false +} + +type entManager struct { + isPerfStandby bool +} + +func (*entManager) Reset() error { + return nil +} + +type LeaseCountQuota struct { +} + +func (l LeaseCountQuota) allow(request *Request) (Response, error) { + panic("implement me") +} + +func (l LeaseCountQuota) quotaID() string { + panic("implement me") +} + +func (l LeaseCountQuota) QuotaName() string { + panic("implement me") +} + +func (l LeaseCountQuota) initialize(logger log.Logger, sink *metricsutil.ClusterMetricSink) error { + panic("implement me") +} + +func (l LeaseCountQuota) close() error { + panic("implement me") +} + +func (l LeaseCountQuota) handleRemount(s string) { + panic("implement me") +} diff --git a/vault/request_handling.go b/vault/request_handling.go index e0086bb169..b082a1cb5a 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -24,6 +24,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/quotas" uberAtomic "go.uber.org/atomic" ) @@ -539,7 +540,6 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp } // Create an 
audit trail of the response
-
 	if !isControlGroupRun(req) {
 		switch req.Path {
 		case "sys/replication/dr/status", "sys/replication/performance/status", "sys/replication/status":
@@ -708,6 +708,36 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
 		}
 	}
 
+	leaseGenerated := false
+	quotaResp, quotaErr := c.applyLeaseCountQuota(&quotas.Request{
+		Path:          req.Path,
+		MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path),
+		NamespacePath: ns.Path,
+	})
+	if quotaErr != nil {
+		c.logger.Error("failed to apply quota", "path", req.Path, "error", quotaErr)
+		retErr = multierror.Append(retErr, quotaErr)
+		return nil, auth, retErr
+	}
+
+	if !quotaResp.Allowed {
+		if c.logger.IsTrace() {
+			c.logger.Trace("request rejected due to lease count quota violation", "request_path", req.Path)
+		}
+
+		retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrLeaseCountQuotaExceeded))
+		return nil, auth, retErr
+	}
+
+	defer func() {
+		if quotaResp.Access != nil {
+			quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated)
+			if quotaAckErr != nil {
+				retErr = multierror.Append(retErr, quotaAckErr)
+			}
+		}
+	}()
+
 	// Route the request
 	resp, routeErr := c.doRouting(ctx, req)
 	if resp != nil {
@@ -827,6 +857,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
 			retErr = multierror.Append(retErr, ErrInternalError)
 			return nil, auth, retErr
 		}
+		leaseGenerated = true
 		resp.Secret.LeaseID = leaseID
 
 		// Get the actual time of the lease
@@ -917,6 +948,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
 				retErr = multierror.Append(retErr, ErrInternalError)
 				return nil, auth, retErr
 			}
+			leaseGenerated = true
 		}
 	}
 
@@ -1073,6 +1105,46 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
 	// If the response generated an authentication, then generate the token
 	if resp != nil && resp.Auth != nil {
+		ns, err := namespace.FromContext(ctx)
+		if err != nil {
+			c.logger.Error("failed to get namespace from context", "error", err)
+			retErr = multierror.Append(retErr, ErrInternalError)
+			return
+		}
+
+		leaseGenerated := false
+
+		// The request successfully authenticated itself. Run the quota checks
+		// before creating the lease.
+		quotaResp, quotaErr := c.applyLeaseCountQuota(&quotas.Request{
+			Path:          req.Path,
+			MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path),
+			NamespacePath: ns.Path,
+		})
+
+		if quotaErr != nil {
+			c.logger.Error("failed to apply quota", "path", req.Path, "error", quotaErr)
+			retErr = multierror.Append(retErr, quotaErr)
+			return
+		}
+
+		if !quotaResp.Allowed {
+			if c.logger.IsTrace() {
+				c.logger.Trace("request rejected due to lease count quota violation", "request_path", req.Path)
+			}
+
+			retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrLeaseCountQuotaExceeded))
+			return
+		}
+
+		defer func() {
+			if quotaResp.Access != nil {
+				quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated)
+				if quotaAckErr != nil {
+					retErr = multierror.Append(retErr, quotaAckErr)
+				}
+			}
+		}()
 
 		var entity *identity.Entity
 		auth = resp.Auth
@@ -1141,10 +1213,6 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
 			resp.AddWarning(warning)
 		}
 
-		ns, err := namespace.FromContext(ctx)
-		if err != nil {
-			return nil, nil, err
-		}
 		_, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, ns, auth.EntityID)
 		if err != nil {
 			return nil, nil, ErrInternalError
@@ -1181,6 +1249,9 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
 		err = registerFunc(ctx, tokenTTL, req.Path, auth)
 		switch {
 		case err == nil:
+			if auth.TokenType != logical.TokenTypeBatch {
+				leaseGenerated = true
+			}
 		case err == ErrInternalError:
 			return nil, auth, err
 		default:
diff --git a/vault/router.go b/vault/router.go
index 2666c0071d..9624b09b59 100644
--- a/vault/router.go
+++ b/vault/router.go
@@ -422,6 +422,14 @@ func (r *Router) MatchingSystemView(ctx context.Context, path string) logical.Sy
 	return raw.(*routeEntry).backend.System()
 }
 
+func (r *Router) MatchingMountByAPIPath(ctx context.Context, path string) string {
+	me, _, _ := r.matchingMountEntryByPath(ctx, path, true)
+	if me == nil {
+		return ""
+	}
+	return me.Path
+}
+
 // MatchingStoragePrefixByAPIPath the storage prefix for the given api path
 func (r *Router) MatchingStoragePrefixByAPIPath(ctx context.Context, path string) (string, bool) {
 	ns, err := namespace.FromContext(ctx)
diff --git a/vendor/github.com/hashicorp/vault/api/response.go b/vendor/github.com/hashicorp/vault/api/response.go
index aed2a52e08..ae350c9791 100644
--- a/vendor/github.com/hashicorp/vault/api/response.go
+++ b/vendor/github.com/hashicorp/vault/api/response.go
@@ -27,8 +27,8 @@ func (r *Response) DecodeJSON(out interface{}) error {
 // body must still be closed manually.
 func (r *Response) Error() error {
 	// 200 to 399 are okay status codes. 429 is the code for health status of
-	// standby nodes.
-	if (r.StatusCode >= 200 && r.StatusCode < 400) || r.StatusCode == 429 {
+	// standby nodes, otherwise, 429 is treated as quota limit reached.
+ if (r.StatusCode >= 200 && r.StatusCode < 400) || (r.StatusCode == 429 && r.Request.URL.Path == "/v1/sys/health") { return nil } diff --git a/vendor/github.com/hashicorp/vault/sdk/logical/error.go b/vendor/github.com/hashicorp/vault/sdk/logical/error.go index fd896a6ce3..aab73cc066 100644 --- a/vendor/github.com/hashicorp/vault/sdk/logical/error.go +++ b/vendor/github.com/hashicorp/vault/sdk/logical/error.go @@ -28,6 +28,14 @@ var ( // ErrPerfStandbyForward is returned when Vault is in a state such that a // perf standby cannot satisfy a request ErrPerfStandbyPleaseForward = errors.New("please forward to the active node") + + // ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease + // count quota being exceeded. + ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded") + + // ErrRateLimitQuotaExceeded is returned when a request is rejected due to a + // rate limit quota being exceeded. + ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded") ) type HTTPCodedError interface { diff --git a/vendor/github.com/hashicorp/vault/sdk/logical/response_util.go b/vendor/github.com/hashicorp/vault/sdk/logical/response_util.go index ee57f8e05a..ce743507fb 100644 --- a/vendor/github.com/hashicorp/vault/sdk/logical/response_util.go +++ b/vendor/github.com/hashicorp/vault/sdk/logical/response_util.go @@ -81,7 +81,7 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { } }) if allErrors != nil { - return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors) + return codedErr.Code, multierror.Append(fmt.Errorf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg), allErrors) } return codedErr.Code, errors.New(codedErr.Msg) } @@ -110,6 +110,10 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { statusCode = http.StatusBadRequest case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): statusCode = http.StatusBadGateway + case errwrap.Contains(err, ErrRateLimitQuotaExceeded.Error()): + statusCode = http.StatusTooManyRequests + case errwrap.Contains(err, ErrLeaseCountQuotaExceeded.Error()): + statusCode = http.StatusTooManyRequests } } diff --git a/website/pages/docs/internals/telemetry.mdx b/website/pages/docs/internals/telemetry.mdx index e8d0dc1dd4..6d71dbd206 100644 --- a/website/pages/docs/internals/telemetry.mdx +++ b/website/pages/docs/internals/telemetry.mdx @@ -163,6 +163,17 @@ These metrics cover measurement of token, identity, and lease operations, and co | `vault.token.revoke-tree` | Time taken to revoke a token tree | ms | summary | | `vault.token.store` | Time taken to store an updated token entry without writing to the secondary index | ms | summary | +## Resource Quota Metrics + +These metrics relate to rate limit and lease count quotas. Each metric comes with a label "name" identifying the specific quota. 
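+
+For example, a request rejected by a rate limit quota named `global-rate` (an
+illustrative name) would increment `quota.rate_limit.violation` with the label
+`name="global-rate"`.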
+ +| Metric | Description | Unit | Type | +| :---------------------------- | :---------------------------------------------------------------- | :---- | :------ | +| `quota.rate_limit.violation` | Total number of rate limit quota violations | quota | counter | +| `quota.lease_count.violation` | Total number of lease count quota violations | quota | counter | +| `quota.lease_count.max` | Total maximum amount of leases allowed by the lease count quota | lease | gauge | +| `quota.lease_count.counter` | Total current amount of leases generated by the lease count quota | lease | gauge | + ## Merkle Tree and Write Ahead Log Metrics These metrics relate to internal operations on Merkle Trees and Write Ahead Logs (WAL)