Resource Quotas: Rate Limiting (#9330)

This commit is contained in:
Vishal Nayak
2020-06-26 17:13:16 -04:00
committed by GitHub
parent ab08ff4c47
commit c68e270863
29 changed files with 2516 additions and 84 deletions

View File

@@ -27,8 +27,8 @@ func (r *Response) DecodeJSON(out interface{}) error {
// body must still be closed manually. // body must still be closed manually.
func (r *Response) Error() error { func (r *Response) Error() error {
// 200 to 399 are okay status codes. 429 is the code for health status of // 200 to 399 are okay status codes. 429 is the code for health status of
// standby nodes. // standby nodes, otherwise, 429 is treated as quota limit reached.
if (r.StatusCode >= 200 && r.StatusCode < 400) || r.StatusCode == 429 { if (r.StatusCode >= 200 && r.StatusCode < 400) || (r.StatusCode == 429 && r.Request.URL.Path == "/v1/sys/health") {
return nil return nil
} }

1
go.mod
View File

@@ -146,6 +146,7 @@ require (
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9
golang.org/x/net v0.0.0-20200602114024-627f9648deb9 golang.org/x/net v0.0.0-20200602114024-627f9648deb9
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1
golang.org/x/tools v0.0.0-20200416214402-fc959738d646 golang.org/x/tools v0.0.0-20200416214402-fc959738d646
google.golang.org/api v0.24.0 google.golang.org/api v0.24.0
google.golang.org/grpc v1.29.1 google.golang.org/grpc v1.29.1

View File

@@ -176,8 +176,8 @@ func Handler(props *vault.HandlerProperties) http.Handler {
// Wrap the handler in another handler to trigger all help paths. // Wrap the handler in another handler to trigger all help paths.
helpWrappedHandler := wrapHelpHandler(mux, core) helpWrappedHandler := wrapHelpHandler(mux, core)
corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core) corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core)
quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core)
genericWrappedHandler := genericWrapping(core, corsWrappedHandler, props) genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props)
// Wrap the handler with PrintablePathCheckHandler to check for non-printable // Wrap the handler with PrintablePathCheckHandler to check for non-printable
// characters in the request path. // characters in the request path.
@@ -221,26 +221,14 @@ func (w *copyResponseWriter) WriteHeader(code int) {
func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origBody := new(bytes.Buffer)
reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody))
r.Body = reader
req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r)
if err != nil || status != 0 {
respondError(w, status, err)
return
}
if origBody != nil {
r.Body = ioutil.NopCloser(origBody)
}
input := &logical.LogInput{ input := &logical.LogInput{
Request: req, Request: w.(*LogicalResponseWriter).request,
} }
core.AuditLogger().AuditRequest(r.Context(), input) core.AuditLogger().AuditRequest(r.Context(), input)
cw := newCopyResponseWriter(w) cw := newCopyResponseWriter(w)
h.ServeHTTP(cw, r) h.ServeHTTP(cw, r)
data := make(map[string]interface{}) data := make(map[string]interface{})
err = jsonutil.DecodeJSON(cw.body.Bytes(), &data) err := jsonutil.DecodeJSON(cw.body.Bytes(), &data)
if err != nil { if err != nil {
// best effort, ignore // best effort, ignore
} }
@@ -249,7 +237,13 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
core.AuditLogger().AuditResponse(r.Context(), input) core.AuditLogger().AuditResponse(r.Context(), input)
return return
}) })
}
// LogicalResponseWriter is used to carry the logical request from generic
// handler down to all the middleware http handlers.
type LogicalResponseWriter struct {
http.ResponseWriter
request *logical.Request
} }
// wrapGenericHandler wraps the handler with an extra layer of handler where // wrapGenericHandler wraps the handler with an extra layer of handler where
@@ -288,6 +282,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
} }
ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) ctx = context.WithValue(ctx, "original_request_path", r.URL.Path)
r = r.WithContext(ctx) r = r.WithContext(ctx)
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
switch { switch {
case strings.HasPrefix(r.URL.Path, "/v1/"): case strings.HasPrefix(r.URL.Path, "/v1/"):
@@ -306,7 +301,27 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
return return
} }
h.ServeHTTP(w, r) origBody := new(bytes.Buffer)
reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody))
r.Body = reader
req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r)
if err != nil || status != 0 {
respondError(w, status, err)
return
}
// Reset the body since logical request creation already read the
// request body.
r.Body = ioutil.NopCloser(origBody)
// Set the mount path in the request
req.MountPoint = core.MatchingMount(r.Context(), req.Path)
// Pass the logical request down through the response writer
h.ServeHTTP(&LogicalResponseWriter{
ResponseWriter: w,
request: req,
}, r)
cancelFunc() cancelFunc()
return return
}) })

View File

@@ -141,6 +141,7 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.
} }
case "OPTIONS": case "OPTIONS":
case "HEAD":
default: default:
return nil, nil, http.StatusMethodNotAllowed, nil return nil, nil, http.StatusMethodNotAllowed, nil
} }
@@ -169,36 +170,32 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http.
return req, origBody, 0, nil return req, origBody, 0, nil
} }
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { func setupLogicalRequest(core *vault.Core, req *logical.Request, r *http.Request) (*logical.Request, int, error) {
req, origBody, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) var err error
if err != nil || status != 0 {
return nil, nil, status, err
}
req, err = requestAuth(core, r, req) req, err = requestAuth(core, r, req)
if err != nil { if err != nil {
if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
return nil, nil, http.StatusForbidden, nil return nil, http.StatusForbidden, nil
} }
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err) return nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err)
} }
req, err = requestWrapInfo(r, req) req, err = requestWrapInfo(r, req)
if err != nil { if err != nil {
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err) return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
} }
err = parseMFAHeader(req) err = parseMFAHeader(req)
if err != nil { if err != nil {
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err) return nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err)
} }
err = requestPolicyOverride(r, req) err = requestPolicyOverride(r, req)
if err != nil { if err != nil {
return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err) return nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
} }
return req, origBody, 0, nil return req, 0, nil
} }
// handleLogical returns a handler for processing logical requests. These requests // handleLogical returns a handler for processing logical requests. These requests
@@ -257,7 +254,7 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han
// toggles. Refer to usage on functions for possible behaviors. // toggles. Refer to usage on functions for possible behaviors.
func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler { func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, origBody, statusCode, err := buildLogicalRequest(core, w, r) req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r)
if err != nil || statusCode != 0 { if err != nil || statusCode != 0 {
respondError(w, statusCode, err) respondError(w, statusCode, err)
return return
@@ -270,10 +267,6 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw
respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly)
return return
} }
if origBody != nil {
r.Body = origBody
}
forwardRequest(core, w, r) forwardRequest(core, w, r)
return return
} }
@@ -398,9 +391,6 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw
respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly)
return return
case needsForward && !noForward: case needsForward && !noForward:
if origBody != nil {
r.Body = origBody
}
forwardRequest(core, w, r) forwardRequest(core, w, r)
return return
case !ok: case !ok:

View File

@@ -281,7 +281,13 @@ func TestLogical_ListSuffix(t *testing.T) {
req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil) req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil)
req = req.WithContext(namespace.RootContext(nil)) req = req.WithContext(namespace.RootContext(nil))
req.Header.Add(consts.AuthHeaderName, rootToken) req.Header.Add(consts.AuthHeaderName, rootToken)
lreq, _, status, err := buildLogicalRequest(core, nil, req)
lreq, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), nil, req)
if err != nil || status != 0 {
t.Fatal(err)
}
lreq, status, err = setupLogicalRequest(core, lreq, req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -295,7 +301,11 @@ func TestLogical_ListSuffix(t *testing.T) {
req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil) req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil)
req = req.WithContext(namespace.RootContext(nil)) req = req.WithContext(namespace.RootContext(nil))
req.Header.Add(consts.AuthHeaderName, rootToken) req.Header.Add(consts.AuthHeaderName, rootToken)
lreq, _, status, err = buildLogicalRequest(core, nil, req) lreq, _, status, err = buildLogicalRequestNoAuth(core.PerfStandby(), nil, req)
if err != nil || status != 0 {
t.Fatal(err)
}
lreq, status, err = setupLogicalRequest(core, lreq, req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -309,7 +319,11 @@ func TestLogical_ListSuffix(t *testing.T) {
req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil) req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil)
req = req.WithContext(namespace.RootContext(nil)) req = req.WithContext(namespace.RootContext(nil))
req.Header.Add(consts.AuthHeaderName, rootToken) req.Header.Add(consts.AuthHeaderName, rootToken)
lreq, _, status, err = buildLogicalRequest(core, nil, req) lreq, _, status, err = buildLogicalRequestNoAuth(core.PerfStandby(), nil, req)
if err != nil || status != 0 {
t.Fatal(err)
}
lreq, status, err = setupLogicalRequest(core, lreq, req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -17,7 +17,7 @@ import (
func handleSysSeal(core *vault.Core) http.Handler { func handleSysSeal(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, _, statusCode, err := buildLogicalRequest(core, w, r) req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r)
if err != nil || statusCode != 0 { if err != nil || statusCode != 0 {
respondError(w, statusCode, err) respondError(w, statusCode, err)
return return
@@ -47,7 +47,7 @@ func handleSysSeal(core *vault.Core) http.Handler {
func handleSysStepDown(core *vault.Core) http.Handler { func handleSysStepDown(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, _, statusCode, err := buildLogicalRequest(core, w, r) req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r)
if err != nil || statusCode != 0 { if err != nil || statusCode != 0 {
respondError(w, statusCode, err) respondError(w, statusCode, err)
return return

View File

@@ -1,15 +1,21 @@
package http package http
import ( import (
"fmt"
"net"
"net/http" "net/http"
"strings"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/quotas"
) )
var ( var (
adjustRequest = func(c *vault.Core, r *http.Request) (*http.Request, int) { adjustRequest = func(c *vault.Core, r *http.Request) (*http.Request, int) {
return r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)), 0 return r, 0
} }
genericWrapping = func(core *vault.Core, in http.Handler, props *vault.HandlerProperties) http.Handler { genericWrapping = func(core *vault.Core, in http.Handler, props *vault.HandlerProperties) http.Handler {
@@ -22,3 +28,56 @@ var (
nonVotersAllowed = false nonVotersAllowed = false
) )
func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ns, err := namespace.FromContext(r.Context())
if err != nil {
respondError(w, http.StatusInternalServerError, err)
return
}
req := w.(*LogicalResponseWriter).request
quotaResp, err := core.ApplyRateLimitQuota(&quotas.Request{
Type: quotas.TypeRateLimit,
Path: req.Path,
MountPath: strings.TrimPrefix(req.MountPoint, ns.Path),
NamespacePath: ns.Path,
ClientAddress: parseRemoteIPAddress(r),
})
if err != nil {
core.Logger().Error("failed to apply quota", "path", req.Path, "error", err)
respondError(w, http.StatusUnprocessableEntity, err)
return
}
if !quotaResp.Allowed {
quotaErr := errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrRateLimitQuotaExceeded)
respondError(w, http.StatusTooManyRequests, quotaErr)
if core.Logger().IsTrace() {
core.Logger().Trace("request rejected due to lease count quota violation", "request_path", req.Path)
}
if core.RateLimitAuditLoggingEnabled() {
_ = core.AuditLogger().AuditRequest(r.Context(), &logical.LogInput{
Request: req,
OuterErr: quotaErr,
})
}
return
}
handler.ServeHTTP(w, r)
return
})
}
func parseRemoteIPAddress(r *http.Request) string {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return ""
}
return ip
}

View File

@@ -28,6 +28,14 @@ var (
// ErrPerfStandbyForward is returned when Vault is in a state such that a // ErrPerfStandbyForward is returned when Vault is in a state such that a
// perf standby cannot satisfy a request // perf standby cannot satisfy a request
ErrPerfStandbyPleaseForward = errors.New("please forward to the active node") ErrPerfStandbyPleaseForward = errors.New("please forward to the active node")
// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease
// count quota being exceeded.
ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded")
// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a
// rate limit quota being exceeded.
ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded")
) )
type HTTPCodedError interface { type HTTPCodedError interface {

View File

@@ -81,7 +81,7 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
} }
}) })
if allErrors != nil { if allErrors != nil {
return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors) return codedErr.Code, multierror.Append(fmt.Errorf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg), allErrors)
} }
return codedErr.Code, errors.New(codedErr.Msg) return codedErr.Code, errors.New(codedErr.Msg)
} }
@@ -110,6 +110,10 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
statusCode = http.StatusBadRequest statusCode = http.StatusBadRequest
case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): case errwrap.Contains(err, ErrUpstreamRateLimited.Error()):
statusCode = http.StatusBadGateway statusCode = http.StatusBadGateway
case errwrap.Contains(err, ErrRateLimitQuotaExceeded.Error()):
statusCode = http.StatusTooManyRequests
case errwrap.Contains(err, ErrLeaseCountQuotaExceeded.Error()):
statusCode = http.StatusTooManyRequests
} }
} }

View File

@@ -339,6 +339,11 @@ func (c *Core) disableCredentialInternal(ctx context.Context, path string, updat
removePathCheckers(c, entry, viewPath) removePathCheckers(c, entry, viewPath)
if err := c.quotaManager.HandleBackendDisabling(ctx, ns.Path, path); err != nil {
c.logger.Error("failed to update quotas after disabling auth", "path", path, "error", err)
return err
}
if c.logger.IsInfo() { if c.logger.IsInfo() {
c.logger.Info("disabled credential backend", "path", path) c.logger.Info("disabled credential backend", "path", path)
} }

View File

@@ -43,6 +43,7 @@ import (
sr "github.com/hashicorp/vault/serviceregistration" sr "github.com/hashicorp/vault/serviceregistration"
"github.com/hashicorp/vault/shamir" "github.com/hashicorp/vault/shamir"
"github.com/hashicorp/vault/vault/cluster" "github.com/hashicorp/vault/vault/cluster"
"github.com/hashicorp/vault/vault/quotas"
vaultseal "github.com/hashicorp/vault/vault/seal" vaultseal "github.com/hashicorp/vault/vault/seal"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -97,6 +98,7 @@ var (
enterprisePostUnseal = enterprisePostUnsealImpl enterprisePostUnseal = enterprisePostUnsealImpl
enterprisePreSeal = enterprisePreSealImpl enterprisePreSeal = enterprisePreSealImpl
enterpriseSetupFilteredPaths = enterpriseSetupFilteredPathsImpl enterpriseSetupFilteredPaths = enterpriseSetupFilteredPathsImpl
enterpriseSetupQuotas = enterpriseSetupQuotasImpl
startReplication = startReplicationImpl startReplication = startReplicationImpl
stopReplication = stopReplicationImpl stopReplication = stopReplicationImpl
LastWAL = lastWALImpl LastWAL = lastWALImpl
@@ -520,6 +522,8 @@ type Core struct {
// can test an upgrade to a version that includes the fixes from // can test an upgrade to a version that includes the fixes from
// https://github.com/hashicorp/vault-enterprise/pull/1103 // https://github.com/hashicorp/vault-enterprise/pull/1103
PR1103disabled bool PR1103disabled bool
quotaManager *quotas.Manager
} }
// CoreConfig is used to parameterize a core // CoreConfig is used to parameterize a core
@@ -944,7 +948,9 @@ func NewCore(conf *CoreConfig) (*Core, error) {
c.clusterListener.Store((*cluster.Listener)(nil)) c.clusterListener.Store((*cluster.Listener)(nil))
err = c.adjustForSealMigration(conf.UnwrapSeal) quotasLogger := conf.Logger.Named("quotas")
c.allLoggers = append(c.allLoggers, quotasLogger)
c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1822,7 +1828,10 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock, performCleanup
} }
} }
postSealInternal(c) if err := postSealInternal(c); err != nil {
c.logger.Error("post seal error", "error", err)
return err
}
c.logger.Info("vault is sealed") c.logger.Info("vault is sealed")
@@ -1892,6 +1901,9 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
if err := c.setupCredentials(ctx); err != nil { if err := c.setupCredentials(ctx); err != nil {
return err return err
} }
if err := c.setupQuotas(ctx, false); err != nil {
return err
}
if !c.IsDRSecondary() { if !c.IsDRSecondary() {
if err := c.startRollback(); err != nil { if err := c.startRollback(); err != nil {
return err return err
@@ -2078,6 +2090,10 @@ func enterpriseSetupFilteredPathsImpl(c *Core) error {
return nil return nil
} }
func enterpriseSetupQuotasImpl(ctx context.Context, c *Core) error {
return nil
}
func startReplicationImpl(c *Core) error { func startReplicationImpl(c *Core) error {
return nil return nil
} }
@@ -2474,3 +2490,29 @@ func (c *Core) readFeatureFlags(ctx context.Context) (*FeatureFlags, error) {
} }
return &flags, nil return &flags, nil
} }
// MatchingMount returns the path of the mount that will be responsible for
// handling the given request path.
func (c *Core) MatchingMount(ctx context.Context, reqPath string) string {
return c.router.MatchingMount(ctx, reqPath)
}
func (c *Core) setupQuotas(ctx context.Context, isPerfStandby bool) error {
if c.quotaManager == nil {
return nil
}
return c.quotaManager.Setup(ctx, c.systemBarrierView, isPerfStandby)
}
// ApplyRateLimitQuota checks the request against all the applicable quota rules
func (c *Core) ApplyRateLimitQuota(req *quotas.Request) (quotas.Response, error) {
req.Type = quotas.TypeRateLimit
return c.quotaManager.ApplyQuota(req)
}
// RateLimitAuditLoggingEnabled returns if the quota configuration allows audit
// logging of request rejections due to rate limiting quota rule violations.
func (c *Core) RateLimitAuditLoggingEnabled() bool {
return c.quotaManager.RateLimitAuditLoggingEnabled()
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/license" "github.com/hashicorp/vault/sdk/helper/license"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/vault/quotas"
"github.com/hashicorp/vault/vault/replication" "github.com/hashicorp/vault/vault/replication"
) )
@@ -58,7 +59,7 @@ func addExtraCredentialBackends(*Core, map[string]logical.Factory) {}
func preUnsealInternal(context.Context, *Core) error { return nil } func preUnsealInternal(context.Context, *Core) error { return nil }
func postSealInternal(*Core) {} func postSealInternal(*Core) error { return nil }
func preSealPhysical(c *Core) { func preSealPhysical(c *Core) {
switch c.sealUnwrapper.(type) { switch c.sealUnwrapper.(type) {
@@ -132,3 +133,23 @@ func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, chan struct{},
func (c *Core) initSealsForMigration() {} func (c *Core) initSealsForMigration() {}
func (c *Core) postSealMigration(ctx context.Context) error { return nil } func (c *Core) postSealMigration(ctx context.Context) error { return nil }
func (c *Core) applyLeaseCountQuota(in *quotas.Request) (*quotas.Response, error) {
return &quotas.Response{Allowed: true}, nil
}
func (c *Core) ackLeaseQuota(access quotas.Access, leaseGenerated bool) error {
return nil
}
func (c *Core) quotaLeaseWalker(ctx context.Context, callback func(request *quotas.Request) bool) error {
return nil
}
func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leaseIDs []string) error {
return nil
}
func (c *Core) namespaceByPath(path string) *namespace.Namespace {
return namespace.RootNamespace
}

View File

@@ -12,17 +12,18 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
metrics "github.com/armon/go-metrics"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/base62" "github.com/hashicorp/vault/sdk/helper/base62"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/locksutil" "github.com/hashicorp/vault/sdk/helper/locksutil"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/quotas"
metrics "github.com/armon/go-metrics"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/namespace"
uberAtomic "go.uber.org/atomic" uberAtomic "go.uber.org/atomic"
) )
@@ -256,23 +257,60 @@ func (m *ExpirationManager) inRestoreMode() bool {
} }
func (m *ExpirationManager) invalidate(key string) { func (m *ExpirationManager) invalidate(key string) {
switch { switch {
case strings.HasPrefix(key, leaseViewPrefix): case strings.HasPrefix(key, leaseViewPrefix):
// Clear from the pending expiration
leaseID := strings.TrimPrefix(key, leaseViewPrefix) leaseID := strings.TrimPrefix(key, leaseViewPrefix)
m.pendingLock.Lock() ctx := m.quitContext
if info, ok := m.pending.Load(leaseID); ok { _, nsID := namespace.SplitIDFromString(leaseID)
pending := info.(pendingInfo) leaseNS := namespace.RootNamespace
pending.timer.Stop() var err error
m.pending.Delete(leaseID) if nsID != "" {
m.leaseCount-- leaseNS, err = NamespaceByID(ctx, nsID, m.core)
if err != nil {
m.logger.Error("failed to invalidate lease entry", "error", err)
return
}
} }
// If in the nonexpiring map, remove there. le, err := m.loadEntryInternal(namespace.ContextWithNamespace(ctx, leaseNS), leaseID, false, false)
m.nonexpiring.Delete(leaseID) if err != nil {
m.logger.Error("failed to invalidate lease entry", "error", err)
return
}
m.pendingLock.Unlock() m.pendingLock.Lock()
defer m.pendingLock.Unlock()
info, ok := m.pending.Load(leaseID)
switch {
case ok:
switch {
case le == nil:
// Handle lease deletion
pending := info.(pendingInfo)
pending.timer.Stop()
m.pending.Delete(leaseID)
m.leaseCount--
// If in the nonexpiring map, remove there.
m.nonexpiring.Delete(leaseID)
if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil {
m.logger.Error("failed to handle lease delete invalidation", "error", err)
return
}
default:
// Handle lease update
m.updatePendingInternal(le, le.ExpireTime.Sub(time.Now()))
}
default:
// There is no entry in the pending map and the invalidation
// resulted in a nil entry. This should ideally never happen.
if le == nil {
return
}
// Handle lease creation
m.updatePendingInternal(le, le.ExpireTime.Sub(time.Now()))
}
} }
} }
@@ -692,13 +730,18 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, fo
} }
} }
// Clear the expiration handler (or remove from the list of non-expiring tokens.) // Clear the expiration handler
m.pendingLock.Lock() m.pendingLock.Lock()
if info, ok := m.pending.Load(leaseID); ok { if info, ok := m.pending.Load(leaseID); ok {
pending := info.(pendingInfo) pending := info.(pendingInfo)
pending.timer.Stop() pending.timer.Stop()
m.pending.Delete(leaseID) m.pending.Delete(leaseID)
m.leaseCount-- m.leaseCount--
if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil {
m.pendingLock.Unlock()
m.logger.Error("failed to handle lease path deletion", "error", err)
return err
}
} }
m.nonexpiring.Delete(leaseID) m.nonexpiring.Delete(leaseID)
m.pendingLock.Unlock() m.pendingLock.Unlock()
@@ -1420,10 +1463,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry, leaseTotal tim
info.(pendingInfo).timer.Stop() info.(pendingInfo).timer.Stop()
m.pending.Delete(le.LeaseID) m.pending.Delete(le.LeaseID)
m.leaseCount-- m.leaseCount--
if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []string{le.LeaseID}); err != nil {
m.logger.Error("failed to handle lease path deletion", "error", err)
return
}
} }
return return
} }
leaseCreated := false
// Create entry if it does not exist or reset if it does // Create entry if it does not exist or reset if it does
if ok { if ok {
pending = info.(pendingInfo) pending = info.(pendingInfo)
@@ -1439,12 +1487,20 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry, leaseTotal tim
} }
// new lease // new lease
m.leaseCount++ m.leaseCount++
leaseCreated = true
} }
// Retain some information in-memory // Retain some information in-memory
pending.cachedLeaseInfo = m.inMemoryLeaseInfo(le) pending.cachedLeaseInfo = m.inMemoryLeaseInfo(le)
m.pending.Store(le.LeaseID, pending) m.pending.Store(le.LeaseID, pending)
if leaseCreated {
if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []string{le.LeaseID}); err != nil {
m.logger.Error("failed to handle lease creation", "error", err)
return
}
}
} }
// revokeEntry is used to attempt revocation of an internal entry // revokeEntry is used to attempt revocation of an internal entry

View File

@@ -0,0 +1,384 @@
package quotas
import (
"fmt"
"testing"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/builtin/logical/pki"
"github.com/hashicorp/vault/helper/testhelpers/teststorage"
"github.com/hashicorp/vault/vault"
"go.uber.org/atomic"
)
const (
testLookupOnlyPolicy = `
path "/auth/token/lookup" {
capabilities = [ "create", "update"]
}
`
)
var (
coreConfig = &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"pki": pki.Factory,
},
CredentialBackends: map[string]logical.Factory{
"userpass": userpass.Factory,
},
}
)
func setupMounts(t *testing.T, client *api.Client) {
t.Helper()
err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
Type: "userpass",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
"password": "bar",
})
if err != nil {
t.Fatal(err)
}
err = client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"common_name": "testvault.com",
"ttl": "200h",
"ip_sans": "127.0.0.1",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
"require_cn": false,
"allowed_domains": "testvault.com",
"allow_subdomains": true,
"max_ttl": "2h",
"generate_lease": true,
})
if err != nil {
t.Fatal(err)
}
}
func teardownMounts(t *testing.T, client *api.Client) {
t.Helper()
if err := client.Sys().Unmount("pki"); err != nil {
t.Fatal(err)
}
if err := client.Sys().DisableAuth("userpass"); err != nil {
t.Fatal(err)
}
}
func testRPS(reqFunc func(numSuccess, numFail *atomic.Int32), d time.Duration) (int32, int32, time.Duration) {
numSuccess := atomic.NewInt32(0)
numFail := atomic.NewInt32(0)
start := time.Now()
end := start.Add(d)
for time.Now().Before(end) {
reqFunc(numSuccess, numFail)
}
return numSuccess.Load(), numFail.Load(), time.Since(start)
}
func waitForRemovalOrTimeout(c *api.Client, path string, tick, to time.Duration) error {
ticker := time.Tick(tick)
timeout := time.After(to)
// wait for the resource to be removed
for {
select {
case <-timeout:
return fmt.Errorf("timeout exceeding waiting for resource to be deleted: %s", path)
case <-ticker:
resp, err := c.Logical().Read(path)
if err != nil {
return err
}
if resp == nil {
return nil
}
}
}
}
func TestQuotas_RateLimitQuota_Mount(t *testing.T) {
conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
cluster := vault.NewTestCluster(t, conf, opts)
cluster.Start()
defer cluster.Cleanup()
core := cluster.Cores[0].Core
client := cluster.Cores[0].Client
vault.TestWaitActive(t, core)
err := client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"common_name": "testvault.com",
"ttl": "200h",
"ip_sans": "127.0.0.1",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
"require_cn": false,
"allowed_domains": "testvault.com",
"allow_subdomains": true,
"max_ttl": "2h",
"generate_lease": true,
})
if err != nil {
t.Fatal(err)
}
reqFunc := func(numSuccess, numFail *atomic.Int32) {
_, err := client.Logical().Read("pki/cert/ca_chain")
if err != nil {
numFail.Add(1)
} else {
numSuccess.Add(1)
}
}
// Create a rate limit quota with a low RPS of 7.7, which means we can process
// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
// by a refill rate of 7.7 per-second.
_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
"rate": 7.7,
"burst": 8,
"path": "pki/",
})
if err != nil {
t.Fatal(err)
}
numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)
// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))
// ensure there were some failed requests
if numFail == 0 {
t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
}
// ensure that we should never get more requests than allowed
if want := int32(ideal + 1); numSuccess > want {
t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
}
// update the rate limit quota with a high RPS such that no requests should fail
_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
"rate": 1000.0,
"burst": 3000,
"path": "pki/",
})
if err != nil {
t.Fatal(err)
}
_, numFail, _ = testRPS(reqFunc, 5*time.Second)
if numFail > 0 {
t.Fatalf("unexpected number of failed requests: %d", numFail)
}
}
func TestQuotas_RateLimitQuota_MountPrecedence(t *testing.T) {
conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
cluster := vault.NewTestCluster(t, conf, opts)
cluster.Start()
defer cluster.Cleanup()
core := cluster.Cores[0].Core
client := cluster.Cores[0].Client
vault.TestWaitActive(t, core)
// create PKI mount
err := client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"common_name": "testvault.com",
"ttl": "200h",
"ip_sans": "127.0.0.1",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{
"require_cn": false,
"allowed_domains": "testvault.com",
"allow_subdomains": true,
"max_ttl": "2h",
"generate_lease": true,
})
if err != nil {
t.Fatal(err)
}
// create a root rate limit quota
_, err = client.Logical().Write("sys/quotas/rate-limit/root-rlq", map[string]interface{}{
"name": "root-rlq",
"rate": 14.7,
"burst": 15,
})
if err != nil {
t.Fatal(err)
}
// create a mount rate limit quota with a lower RPS than the root rate limit quota
_, err = client.Logical().Write("sys/quotas/rate-limit/mount-rlq", map[string]interface{}{
"name": "mount-rlq",
"rate": 7.7,
"burst": 8,
"path": "pki/",
})
if err != nil {
t.Fatal(err)
}
// ensure mount rate limit quota takes precedence over root rate limit quota
reqFunc := func(numSuccess, numFail *atomic.Int32) {
_, err := client.Logical().Read("pki/cert/ca_chain")
if err != nil {
numFail.Add(1)
} else {
numSuccess.Add(1)
}
}
// ensure mount rate limit quota takes precedence over root rate limit quota
numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)
// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))
// ensure there were some failed requests
if numFail == 0 {
t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
}
// ensure that we should never get more requests than allowed
if want := int32(ideal + 1); numSuccess > want {
t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
}
}
func TestQuotas_RateLimitQuota(t *testing.T) {
conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil)
cluster := vault.NewTestCluster(t, conf, opts)
cluster.Start()
defer cluster.Cleanup()
core := cluster.Cores[0].Core
client := cluster.Cores[0].Client
vault.TestWaitActive(t, core)
err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
Type: "userpass",
})
if err != nil {
t.Fatal(err)
}
_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{
"password": "bar",
})
if err != nil {
t.Fatal(err)
}
// Create a rate limit quota with a low RPS of 7.7, which means we can process
// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed
// by a refill rate of 7.7 per-second.
_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
"rate": 7.7,
"burst": 8,
})
if err != nil {
t.Fatal(err)
}
reqFunc := func(numSuccess, numFail *atomic.Int32) {
_, err := client.Logical().Read("sys/quotas/rate-limit/rlq")
if err != nil {
numFail.Add(1)
} else {
numSuccess.Add(1)
}
}
numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second)
// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second))
// ensure there were some failed requests
if numFail == 0 {
t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed)
}
// ensure that we should never get more requests than allowed
if want := int32(ideal + 1); numSuccess > want {
t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed)
}
// allow time (1s) for rate limit to refill before updating the quota
time.Sleep(time.Second)
// update the rate limit quota with a high RPS such that no requests should fail
_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{
"rate": 1000.0,
"burst": 3000,
})
if err != nil {
t.Fatal(err)
}
_, numFail, _ = testRPS(reqFunc, 5*time.Second)
if numFail > 0 {
t.Fatalf("unexpected number of failed requests: %d", numFail)
}
}

View File

@@ -160,6 +160,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend {
b.Backend.Paths = append(b.Backend.Paths, b.metricsPath()) b.Backend.Paths = append(b.Backend.Paths, b.metricsPath())
b.Backend.Paths = append(b.Backend.Paths, b.monitorPath()) b.Backend.Paths = append(b.Backend.Paths, b.monitorPath())
b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath()) b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath())
b.Backend.Paths = append(b.Backend.Paths, b.quotasPaths()...)
if core.rawEnabled { if core.rawEnabled {
b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...)
@@ -751,7 +752,7 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d
// Get all the options // Get all the options
path := data.Get("path").(string) path := data.Get("path").(string)
path = sanitizeMountPath(path) path = sanitizePath(path)
logicalType := data.Get("type").(string) logicalType := data.Get("type").(string)
description := data.Get("description").(string) description := data.Get("description").(string)
@@ -934,7 +935,7 @@ func handleErrorNoReadOnlyForward(
// handleUnmount is used to unmount a path // handleUnmount is used to unmount a path
func (b *SystemBackend) handleUnmount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *SystemBackend) handleUnmount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string) path := data.Get("path").(string)
path = sanitizeMountPath(path) path = sanitizePath(path)
ns, err := namespace.FromContext(ctx) ns, err := namespace.FromContext(ctx)
if err != nil { if err != nil {
@@ -1029,6 +1030,12 @@ func (b *SystemBackend) handleRemount(ctx context.Context, req *logical.Request,
return handleError(err) return handleError(err)
} }
// Update quotas with the new path
if err := b.Core.quotaManager.HandleRemount(ctx, ns.Path, sanitizePath(fromPath), sanitizePath(toPath)); err != nil {
b.Core.logger.Error("failed to update quotas after remount", "ns_path", ns.Path, "from_path", fromPath, "to_path", toPath, "error", err)
return handleError(err)
}
return nil, nil return nil, nil
} }
@@ -1060,7 +1067,7 @@ func (b *SystemBackend) handleMountTuneRead(ctx context.Context, req *logical.Re
// handleTuneReadCommon returns the config settings of a path // handleTuneReadCommon returns the config settings of a path
func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) (*logical.Response, error) { func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) (*logical.Response, error) {
path = sanitizeMountPath(path) path = sanitizePath(path)
sysView := b.Core.router.MatchingSystemView(ctx, path) sysView := b.Core.router.MatchingSystemView(ctx, path)
if sysView == nil { if sysView == nil {
@@ -1146,7 +1153,7 @@ func (b *SystemBackend) handleMountTuneWrite(ctx context.Context, req *logical.R
func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, data *framework.FieldData) (*logical.Response, error) { func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, data *framework.FieldData) (*logical.Response, error) {
repState := b.Core.ReplicationState() repState := b.Core.ReplicationState()
path = sanitizeMountPath(path) path = sanitizePath(path)
// Prevent protected paths from being changed // Prevent protected paths from being changed
for _, p := range untunableMounts { for _, p := range untunableMounts {
@@ -1716,7 +1723,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque
// Get all the options // Get all the options
path := data.Get("path").(string) path := data.Get("path").(string)
path = sanitizeMountPath(path) path = sanitizePath(path)
logicalType := data.Get("type").(string) logicalType := data.Get("type").(string)
description := data.Get("description").(string) description := data.Get("description").(string)
pluginName := data.Get("plugin_name").(string) pluginName := data.Get("plugin_name").(string)
@@ -1857,7 +1864,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque
// handleDisableAuth is used to disable a credential backend // handleDisableAuth is used to disable a credential backend
func (b *SystemBackend) handleDisableAuth(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { func (b *SystemBackend) handleDisableAuth(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
path := data.Get("path").(string) path := data.Get("path").(string)
path = sanitizeMountPath(path) path = sanitizePath(path)
ns, err := namespace.FromContext(ctx) ns, err := namespace.FromContext(ctx)
if err != nil { if err != nil {
@@ -2272,7 +2279,7 @@ func (b *SystemBackend) handleAuditHash(ctx context.Context, req *logical.Reques
return logical.ErrorResponse("the \"input\" parameter is empty"), nil return logical.ErrorResponse("the \"input\" parameter is empty"), nil
} }
path = sanitizeMountPath(path) path = sanitizePath(path)
hash, err := b.Core.auditBroker.GetHash(ctx, path, input) hash, err := b.Core.auditBroker.GetHash(ctx, path, input)
if err != nil { if err != nil {
@@ -3258,7 +3265,7 @@ func (b *SystemBackend) pathInternalUIMountRead(ctx context.Context, req *logica
if path == "" { if path == "" {
return logical.ErrorResponse("path not set"), logical.ErrInvalidRequest return logical.ErrorResponse("path not set"), logical.ErrInvalidRequest
} }
path = sanitizeMountPath(path) path = sanitizePath(path)
errResp := logical.ErrorResponse(fmt.Sprintf("preflight capability check returned 403, please ensure client's policies grant access to path %q", path)) errResp := logical.ErrorResponse(fmt.Sprintf("preflight capability check returned 403, please ensure client's policies grant access to path %q", path))
@@ -3576,7 +3583,7 @@ func (b *SystemBackend) pathInternalOpenAPI(ctx context.Context, req *logical.Re
return resp, nil return resp, nil
} }
func sanitizeMountPath(path string) string { func sanitizePath(path string) string {
if !strings.HasSuffix(path, "/") { if !strings.HasSuffix(path, "/") {
path += "/" path += "/"
} }

View File

@@ -0,0 +1,272 @@
package vault
import (
"context"
"strings"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/quotas"
)
// quotasPaths returns paths that enable quota management
func (b *SystemBackend) quotasPaths() []*framework.Path {
return []*framework.Path{
{
Pattern: "quotas/config$",
Fields: map[string]*framework.FieldSchema{
"enable_rate_limit_audit_logging": {
Type: framework.TypeBool,
Description: "If set, starts audit logging of requests that get rejected due to rate limit quota rule violations.",
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.UpdateOperation: &framework.PathOperation{
Callback: b.handleQuotasConfigUpdate(),
},
logical.ReadOperation: &framework.PathOperation{
Callback: b.handleQuotasConfigRead(),
},
},
HelpSynopsis: strings.TrimSpace(quotasHelp["quotas-config"][0]),
HelpDescription: strings.TrimSpace(quotasHelp["quotas-config"][1]),
},
{
Pattern: "quotas/rate-limit/?$",
Operations: map[logical.Operation]framework.OperationHandler{
logical.ListOperation: &framework.PathOperation{
Callback: b.handleRateLimitQuotasList(),
},
},
HelpSynopsis: strings.TrimSpace(quotasHelp["rate-limit-list"][0]),
HelpDescription: strings.TrimSpace(quotasHelp["rate-limit-list"][1]),
},
{
Pattern: "quotas/rate-limit/" + framework.GenericNameRegex("name"),
Fields: map[string]*framework.FieldSchema{
"type": {
Type: framework.TypeString,
Description: "Type of the quota rule.",
},
"name": {
Type: framework.TypeString,
Description: "Name of the quota rule.",
},
"path": {
Type: framework.TypeString,
Description: `Path of the mount or namespace to apply the quota. A blank path configures a
global quota. For example namespace1/ adds a quota to a full namespace,
namespace1/auth/userpass adds a quota to userpass in namespace1.`,
},
"rate": {
Type: framework.TypeFloat,
Description: `The rate at which allowed requests are refilled per second by the quota rule.
Internally, a token-bucket algorithm is used which has a size of 'burst', initially full. The quota
limits requests to 'rate' per-second, with a maximum burst size of 'burst'. Each request takes a single
token from this bucket. The 'rate' must be positive.`,
},
"burst": {
Type: framework.TypeInt,
Description: `The maximum number of requests at any given second to be allowed by the quota
rule. There is a one-to-one mapping between requests and tokens in the rate limit quota. A client
may perform up to 'burst' requests at once, at which they they may invoke additional requests at
'rate' per-second.`,
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.UpdateOperation: &framework.PathOperation{
Callback: b.handleRateLimitQuotasUpdate(),
},
logical.ReadOperation: &framework.PathOperation{
Callback: b.handleRateLimitQuotasRead(),
},
logical.DeleteOperation: &framework.PathOperation{
Callback: b.handleRateLimitQuotasDelete(),
},
},
HelpSynopsis: strings.TrimSpace(quotasHelp["rate-limit"][0]),
HelpDescription: strings.TrimSpace(quotasHelp["rate-limit"][1]),
},
}
}
func (b *SystemBackend) handleQuotasConfigUpdate() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
config, err := quotas.LoadConfig(ctx, b.Core.systemBarrierView)
if err != nil {
return nil, err
}
config.EnableRateLimitAuditLogging = d.Get("enable_rate_limit_audit_logging").(bool)
entry, err := logical.StorageEntryJSON(quotas.ConfigPath, config)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
b.Core.quotaManager.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)
return nil, nil
}
}
func (b *SystemBackend) handleQuotasConfigRead() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
config := b.Core.quotaManager.Config()
return &logical.Response{
Data: map[string]interface{}{
"enable_rate_limit_audit_logging": config.EnableRateLimitAuditLogging,
},
}, nil
}
}
func (b *SystemBackend) handleRateLimitQuotasList() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
names, err := b.Core.quotaManager.QuotaNames(quotas.TypeRateLimit)
if err != nil {
return nil, err
}
return logical.ListResponse(names), nil
}
}
func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
qType := quotas.TypeRateLimit.String()
rate := d.Get("rate").(float64)
if rate <= 0 {
return logical.ErrorResponse("'rate' is invalid"), nil
}
burst := d.Get("burst").(int)
if burst < int(rate) {
return logical.ErrorResponse("'burst' must be greater than or equal to 'rate' as an integer value"), nil
}
mountPath := sanitizePath(d.Get("path").(string))
ns := b.Core.namespaceByPath(mountPath)
if ns.ID != namespace.RootNamespaceID {
mountPath = strings.TrimPrefix(mountPath, ns.Path)
}
if mountPath != "" {
match := b.Core.router.MatchingMount(namespace.ContextWithNamespace(ctx, ns), mountPath)
if match == "" {
return logical.ErrorResponse("invalid mount path %q", mountPath), nil
}
}
// Disallow duplicate quotas with same precedence and similar
// properties.
quota, err := b.Core.quotaManager.QuotaByFactors(ctx, qType, ns.Path, mountPath)
if err != nil {
return nil, err
}
if quota != nil && quota.QuotaName() != name {
return logical.ErrorResponse("quota rule with similar properties exists under the name %q", quota.QuotaName()), nil
}
switch {
case quota == nil:
quota = quotas.NewRateLimitQuota(name, ns.Path, mountPath, rate, burst)
default:
rlq := quota.(*quotas.RateLimitQuota)
rlq.NamespacePath = ns.Path
rlq.MountPath = mountPath
rlq.Rate = rate
rlq.Burst = burst
}
entry, err := logical.StorageEntryJSON(quotas.QuotaStoragePath(qType, name), quota)
if err != nil {
return nil, err
}
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
if err := b.Core.quotaManager.SetQuota(ctx, qType, quota, false); err != nil {
return nil, err
}
return nil, nil
}
}
func (b *SystemBackend) handleRateLimitQuotasRead() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
qType := quotas.TypeRateLimit.String()
quota, err := b.Core.quotaManager.QuotaByName(qType, name)
if err != nil {
return nil, err
}
if quota == nil {
return nil, nil
}
rlq := quota.(*quotas.RateLimitQuota)
nsPath := rlq.NamespacePath
if rlq.NamespacePath == "root" {
nsPath = ""
}
data := map[string]interface{}{
"type": qType,
"name": rlq.Name,
"path": nsPath + rlq.MountPath,
"rate": rlq.Rate,
"burst": rlq.Burst,
}
return &logical.Response{
Data: data,
}, nil
}
}
func (b *SystemBackend) handleRateLimitQuotasDelete() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
qType := quotas.TypeRateLimit.String()
if err := req.Storage.Delete(ctx, quotas.QuotaStoragePath(qType, name)); err != nil {
return nil, err
}
if err := b.Core.quotaManager.DeleteQuota(ctx, qType, name); err != nil {
return nil, err
}
return nil, nil
}
}
var quotasHelp = map[string][2]string{
"quotas-config": {
"Create, update and read the quota configuration.",
"",
},
"rate-limit": {
`Get, create or update rate limit resource quota for an optional namespace or
mount.`,
`A rate limit quota will enforce rate limiting using a token bucket algorithm. A
rate limit quota can be created at the root level or defined on a namespace or
mount by specifying a 'path'. The rate limiter is applied to each unique client
IP address. A client may invoke 'burst' requests at any given second, after
which they may invoke additional requests at 'rate' per-second.`,
},
"rate-limit-list": {
"Lists the names of all the rate limit quotas.",
"This list contains quota definitions from all the namespaces.",
},
}

View File

@@ -2654,7 +2654,7 @@ func TestSystemBackend_PathWildcardPreflight(t *testing.T) {
// Add another mount // Add another mount
me := &MountEntry{ me := &MountEntry{
Table: mountTableType, Table: mountTableType,
Path: sanitizeMountPath("kv-v1"), Path: sanitizePath("kv-v1"),
Type: "kv", Type: "kv",
Options: map[string]string{"version": "1"}, Options: map[string]string{"version": "1"},
} }

View File

@@ -664,6 +664,11 @@ func (c *Core) unmountInternal(ctx context.Context, path string, updateStorage b
removePathCheckers(c, entry, viewPath) removePathCheckers(c, entry, viewPath)
if err := c.quotaManager.HandleBackendDisabling(ctx, ns.Path, path); err != nil {
c.logger.Error("failed to update quotas after disabling mount", "path", path, "error", err)
return err
}
if c.logger.IsInfo() { if c.logger.IsInfo() {
c.logger.Info("successfully unmounted", "path", path, "namespace", ns.Path) c.logger.Info("successfully unmounted", "path", path, "namespace", ns.Path)
} }

860
vault/quotas/quotas.go Normal file
View File

@@ -0,0 +1,860 @@
package quotas
import (
"context"
"errors"
"fmt"
"path"
"strings"
"sync"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/sdk/logical"
)
// Type represents the quota kind
type Type string
const (
// TypeRateLimit represents the rate limiting quota type
TypeRateLimit Type = "rate-limit"
// TypeLeaseCount represents the lease count limiting quota type
TypeLeaseCount Type = "lease-count"
)
// LeaseAction is the action taken by the expiration manager on the lease. The
// quota manager will use this information to update the lease path cache and
// updating counters for relevant quota rules.
type LeaseAction uint32
// String converts each lease action into its string equivalent value
func (la LeaseAction) String() string {
switch la {
case LeaseActionLoaded:
return "loaded"
case LeaseActionCreated:
return "created"
case LeaseActionDeleted:
return "deleted"
case LeaseActionAllow:
return "allow"
}
return "unknown"
}
const (
_ LeaseAction = iota
// LeaseActionLoaded indicates loading of lease in the expiration manager after
// unseal.
LeaseActionLoaded
// LeaseActionCreated indicates that a lease is created in the expiration manager.
LeaseActionCreated
// LeaseActionDeleted indicates that is lease is expired and deleted in the
// expiration manager.
LeaseActionDeleted
// LeaseActionAllow will be used to indicate the lease count checker that
// incCounter is called from Allow(). All the rest of the actions indicate the
// action took place on the lease in the expiration manager.
LeaseActionAllow
)
type leaseWalkFunc func(context.Context, func(request *Request) bool) error
// String converts each quota type into its string equivalent value
func (q Type) String() string {
switch q {
case TypeLeaseCount:
return "lease-count"
case TypeRateLimit:
return "rate-limit"
}
return "unknown"
}
const (
indexID = "id"
indexName = "name"
indexNamespace = "ns"
indexNamespaceMount = "ns_mount"
)
const (
// StoragePrefix is the prefix for the physical location where quota rules are
// persisted.
StoragePrefix = "quotas/"
// ConfigPath is the physical location where the quota configuration is
// persisted.
ConfigPath = StoragePrefix + "config"
)
var (
// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease
// count quota being exceeded.
ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded")
// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a
// rate limit quota being exceeded.
ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded")
)
// Access provides information to reach back to the quota checker.
type Access interface {
// QuotaID is the identifier of the quota that issued this access.
QuotaID() string
}
// Ensure that access implements the Access interface.
var _ Access = (*access)(nil)
// access implements the Access interface
type access struct {
quotaID string
}
// QuotaID returns the identifier of the quota rule to which this access refers
// to.
func (a *access) QuotaID() string {
return a.quotaID
}
// Manager holds all the existing quota rules. For any given input. the manager
// checks them against any applicable quota rules.
type Manager struct {
entManager
// db holds the in memory instances of all active quota rules indexed by
// some of the quota properties.
db *memdb.MemDB
// config containing operator preferences and quota behaviors
config *Config
storage logical.Storage
ctx context.Context
logger log.Logger
metricSink *metricsutil.ClusterMetricSink
lock *sync.RWMutex
}
// Quota represents the common properties of every quota type
type Quota interface {
// allow checks the if the request is allowed by the quota type implementation.
allow(*Request) (Response, error)
// quotaID is the identifier of the quota rule
quotaID() string
// QuotaName is the name of the quota rule
QuotaName() string
// initialize sets up the fields in the quota type to begin operating
initialize(log.Logger, *metricsutil.ClusterMetricSink) error
// close defines any cleanup behavior that needs to be executed when a quota
// rule is deleted.
close() error
// handleRemount takes in the new mount path in the quota
handleRemount(string)
}
// Response holds information about the result of the Allow() call. The response
// can optionally have the Access field set, which is used to reach back into
// the quota rule that sent this response.
type Response struct {
// Allowed is set if the quota allows the request
Allowed bool
// Access is the handle to reach back into the quota rule that processed the
// quota request. This may not be set all the time.
Access Access
}
// Config holds operator preferences around quota behaviors
type Config struct {
// EnableRateLimitAuditLogging, if set, starts audit logging of the
// request rejections that arise due to rate limit quota violations.
EnableRateLimitAuditLogging bool `json:"enable_rate_limit_audit_logging"`
}
// Request contains information required by the quota manager to query and
// apply the quota rules.
type Request struct {
// Type is the quota type
Type Type
// Path is the request path to which quota rules are being queried for
Path string
// NamespacePath is the namespace path to which the request belongs
NamespacePath string
// MountPath is the mount path to which the request is made
MountPath string
// ClientAddress is client unique addressable string (e.g. IP address). It can
// be empty if the quota type does not need it.
ClientAddress string
}
// NewManager creates and initializes a new quota manager to hold all the quota
// rules and to process incoming requests.
func NewManager(logger log.Logger, walkFunc leaseWalkFunc, ms *metricsutil.ClusterMetricSink) (*Manager, error) {
db, err := memdb.NewMemDB(dbSchema())
if err != nil {
return nil, err
}
manager := &Manager{
db: db,
logger: logger,
metricSink: ms,
config: new(Config),
lock: new(sync.RWMutex),
}
manager.init(walkFunc)
return manager, nil
}
// SetQuota adds a new quota rule to the db.
func (m *Manager) SetQuota(ctx context.Context, qType string, quota Quota, loading bool) error {
m.lock.Lock()
defer m.lock.Unlock()
return m.setQuotaLocked(ctx, qType, quota, loading)
}
// setQuotaLocked should be called with the manager's lock held
func (m *Manager) setQuotaLocked(ctx context.Context, qType string, quota Quota, loading bool) error {
if qType == TypeLeaseCount.String() {
m.setIsPerfStandby(quota)
}
txn := m.db.Txn(true)
defer txn.Abort()
raw, err := txn.First(qType, "id", quota.quotaID())
if err != nil {
return err
}
// If there already exists an entry in the db, remove that first.
if raw != nil {
err = txn.Delete(qType, raw)
if err != nil {
return err
}
}
// Initialize the quota type implementation
if err := quota.initialize(m.logger, m.metricSink); err != nil {
return err
}
// Add the initialized quota type implementation to the db
if err := txn.Insert(qType, quota); err != nil {
return err
}
if loading {
txn.Commit()
return nil
}
// For the lease count type, recompute the counters
if !loading && qType == TypeLeaseCount.String() {
if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
return err
}
}
txn.Commit()
return nil
}
// QuotaNames returns the names of all the quota rules for a given type
func (m *Manager) QuotaNames(qType Type) ([]string, error) {
m.lock.RLock()
defer m.lock.RUnlock()
txn := m.db.Txn(false)
iter, err := txn.Get(qType.String(), indexID)
if err != nil {
return nil, err
}
var names []string
for raw := iter.Next(); raw != nil; raw = iter.Next() {
names = append(names, raw.(Quota).QuotaName())
}
return names, nil
}
// QuotaByID queries for a quota rule in the db for a given quota ID
func (m *Manager) QuotaByID(qType string, id string) (Quota, error) {
m.lock.RLock()
defer m.lock.RUnlock()
txn := m.db.Txn(false)
quotaRaw, err := txn.First(qType, indexID, id)
if err != nil {
return nil, err
}
if quotaRaw == nil {
return nil, nil
}
return quotaRaw.(Quota), nil
}
// QuotaByName queries for a quota rule in the db for a given quota name
func (m *Manager) QuotaByName(qType string, name string) (Quota, error) {
m.lock.RLock()
defer m.lock.RUnlock()
txn := m.db.Txn(false)
quotaRaw, err := txn.First(qType, indexName, name)
if err != nil {
return nil, err
}
if quotaRaw == nil {
return nil, nil
}
return quotaRaw.(Quota), nil
}
// QuotaByFactors returns the quota rule that matches the provided factors
func (m *Manager) QuotaByFactors(ctx context.Context, qType, nsPath, mountPath string) (Quota, error) {
m.lock.RLock()
defer m.lock.RUnlock()
// nsPath would have been made non-empty during insertion. Use non-empty value
// during query as well.
if nsPath == "" {
nsPath = "root"
}
idx := indexNamespace
args := []interface{}{nsPath, false}
if mountPath != "" {
idx = indexNamespaceMount
args = []interface{}{nsPath, mountPath}
}
txn := m.db.Txn(false)
iter, err := txn.Get(qType, idx, args...)
if err != nil {
return nil, err
}
var quotas []Quota
for raw := iter.Next(); raw != nil; raw = iter.Next() {
quotas = append(quotas, raw.(Quota))
}
if len(quotas) > 1 {
return nil, fmt.Errorf("conflicting quota definitions detected")
}
if len(quotas) == 0 {
return nil, nil
}
return quotas[0], nil
}
// queryQuota returns the quota rule that is applicable for the given request. It
// queries all the quota rules that are defined against request values and finds
// the quota rule that takes priority.
//
// Priority rules are as follows:
// - namespace specific quota takes precedence over global quota
// - mount specific quota takes precedence over namespace specific quota
func (m *Manager) queryQuota(txn *memdb.Txn, req *Request) (Quota, error) {
if txn == nil {
txn = m.db.Txn(false)
}
// ns would have been made non-empty during insertion. Use non-empty
// value during query as well.
if req.NamespacePath == "" {
req.NamespacePath = "root"
}
//
// Find a match from most specific applicable quota rule to less specific one.
//
quotaFetchFunc := func(idx string, args ...interface{}) (Quota, error) {
iter, err := txn.Get(req.Type.String(), idx, args...)
if err != nil {
return nil, err
}
var quotas []Quota
for raw := iter.Next(); raw != nil; raw = iter.Next() {
quota := raw.(Quota)
quotas = append(quotas, quota)
}
if len(quotas) > 1 {
return nil, fmt.Errorf("conflicting quota definitions detected")
}
if len(quotas) == 0 {
return nil, nil
}
return quotas[0], nil
}
// Fetch mount quota
quota, err := quotaFetchFunc(indexNamespaceMount, req.NamespacePath, req.MountPath)
if err != nil {
return nil, err
}
if quota != nil {
return quota, nil
}
// Fetch ns quota. If NamespacePath is root, this will return the global quota.
quota, err = quotaFetchFunc(indexNamespace, req.NamespacePath, false)
if err != nil {
return nil, err
}
if quota != nil {
return quota, nil
}
// If the request belongs to "root" namespace, then we have already looked at
// global quotas when fetching namespace specific quota rule. When the request
// belongs to a non-root namespace, and when there are no namespace specific
// quota rules present, we fallback on the global quotas.
if req.NamespacePath == "root" {
return nil, nil
}
// Fetch global quota
quota, err = quotaFetchFunc(indexNamespace, "root", false)
if err != nil {
return nil, err
}
if quota != nil {
return quota, nil
}
return nil, nil
}
// DeleteQuota removes a quota rule from the db for a given name
func (m *Manager) DeleteQuota(ctx context.Context, qType string, name string) error {
m.lock.Lock()
defer m.lock.Unlock()
txn := m.db.Txn(true)
defer txn.Abort()
raw, err := txn.First(qType, indexName, name)
if err != nil {
return err
}
if raw == nil {
return nil
}
quota := raw.(Quota)
if err := quota.close(); err != nil {
return err
}
err = txn.Delete(qType, raw)
if err != nil {
return err
}
// For the lease count type, recompute the counters
if qType == TypeLeaseCount.String() {
if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
return err
}
}
txn.Commit()
return nil
}
// ApplyQuota runs the request against any quota rule that is applicable to it. If
// there are multiple quota rule that matches the request parameters, rule that
// takes precedence will be used to allow/reject the request.
func (m *Manager) ApplyQuota(req *Request) (Response, error) {
var resp Response
quota, err := m.queryQuota(nil, req)
if err != nil {
return resp, err
}
// If there is no quota defined, allow the request.
if quota == nil {
resp.Allowed = true
return resp, nil
}
// If the quota type is lease count, and if the path is not known to
// generate leases, allow the request.
if req.Type == TypeLeaseCount && !m.inLeasePathCache(req.Path) {
resp.Allowed = true
return resp, nil
}
return quota.allow(req)
}
// SetEnableRateLimitAuditLogging updates the operator preference regarding the
// audit logging behavior.
func (m *Manager) SetEnableRateLimitAuditLogging(val bool) {
m.config.EnableRateLimitAuditLogging = val
}
// RateLimitAuditLoggingEnabled returns if the quota configuration allows audit
// logging of request rejections due to rate limiting quota rule violations.
func (m *Manager) RateLimitAuditLoggingEnabled() bool {
return m.config.EnableRateLimitAuditLogging
}
// Config returns the operator preferences in the quota manager
func (m *Manager) Config() *Config {
return m.config
}
// Reset will clear all the quotas from the db and clear the lease path cache.
func (m *Manager) Reset() error {
m.lock.Lock()
defer m.lock.Unlock()
var err error
m.db, err = memdb.NewMemDB(dbSchema())
if err != nil {
return err
}
m.storage = nil
m.ctx = nil
m.entManager.Reset()
return nil
}
// dbSchema creates a DB schema for holding all the quota rules. It creates a
// table for each supported type of quota.
func dbSchema() *memdb.DBSchema {
schema := &memdb.DBSchema{
Tables: make(map[string]*memdb.TableSchema),
}
commonSchema := func(name string) *memdb.TableSchema {
return &memdb.TableSchema{
Name: name,
Indexes: map[string]*memdb.IndexSchema{
indexID: {
Name: indexID,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "ID",
},
},
indexName: {
Name: indexName,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
},
},
indexNamespace: {
Name: indexNamespace,
Indexer: &memdb.CompoundMultiIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "NamespacePath",
},
// By sending false as the query parameter, we can
// query just the namespace specific quota.
&memdb.FieldSetIndex{
Field: "MountPath",
},
},
},
},
indexNamespaceMount: {
Name: indexNamespaceMount,
AllowMissing: true,
Indexer: &memdb.CompoundMultiIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "NamespacePath",
},
&memdb.StringFieldIndex{
Field: "MountPath",
},
},
},
},
},
}
}
// Create a table per quota type. This allows names to be reused between
// different quota types and querying a bit easier.
for _, name := range quotaTypes() {
schema.Tables[name] = commonSchema(name)
}
return schema
}
// Invalidate receives notifications from the replication sub-system when a key
// is updated in the storage. This function will read the key from storage and
// updates the caches and data structures to reflect those updates.
func (m *Manager) Invalidate(key string) {
switch key {
case "config":
config, err := LoadConfig(m.ctx, m.storage)
if err != nil {
m.logger.Error("failed to invalidate quota config", "error", err)
return
}
m.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)
default:
splitKeys := strings.Split(key, "/")
if len(splitKeys) != 2 {
m.logger.Error("incorrect key while invalidating quota rule")
return
}
qType := splitKeys[0]
name := splitKeys[1]
// Read quota rule from storage
quota, err := Load(m.ctx, m.storage, qType, name)
if err != nil {
m.logger.Error("failed to read invalidated quota rule", "error", err)
return
}
switch {
case quota == nil:
// Handle quota deletion
if err := m.DeleteQuota(m.ctx, qType, name); err != nil {
m.logger.Error("failed to delete invalidated quota rule", "error", err)
return
}
default:
// Handle quota update
if err := m.SetQuota(m.ctx, qType, quota, false); err != nil {
m.logger.Error("failed to update invalidated quota rule", "error", err)
return
}
}
}
}
// LoadConfig reads the quota configuration from the underlying storage
func LoadConfig(ctx context.Context, storage logical.Storage) (*Config, error) {
var config Config
entry, err := storage.Get(ctx, ConfigPath)
if err != nil {
return nil, err
}
if entry == nil {
return &config, nil
}
err = entry.DecodeJSON(&config)
if err != nil {
return nil, err
}
return &config, nil
}
// Load reads the quota rule from the underlying storage
func Load(ctx context.Context, storage logical.Storage, qType, name string) (Quota, error) {
var quota Quota
entry, err := storage.Get(ctx, QuotaStoragePath(qType, name))
if err != nil {
return nil, err
}
if entry == nil {
return nil, nil
}
switch qType {
case TypeRateLimit.String():
quota = &RateLimitQuota{}
case TypeLeaseCount.String():
quota = &LeaseCountQuota{}
default:
return nil, fmt.Errorf("unsupported type: %v", qType)
}
err = entry.DecodeJSON(quota)
if err != nil {
return nil, err
}
return quota, nil
}
// Setup loads the quota configuration and all the quota rules into the
// quota manager.
func (m *Manager) Setup(ctx context.Context, storage logical.Storage, isPerfStandby bool) error {
m.lock.Lock()
defer m.lock.Unlock()
m.storage = storage
m.ctx = ctx
m.isPerfStandby = isPerfStandby
// Load the quota configuration from storage and load it into the quota
// manager.
config, err := LoadConfig(ctx, storage)
if err != nil {
return err
}
m.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging)
// Load the quota rules for all supported types from storage and load it in
// the quota manager.
for _, qType := range quotaTypes() {
names, err := logical.CollectKeys(ctx, logical.NewStorageView(storage, StoragePrefix+qType+"/"))
if err != nil {
return nil
}
for _, name := range names {
quota, err := Load(ctx, m.storage, qType, name)
if err != nil {
return err
}
if quota == nil {
continue
}
err = m.setQuotaLocked(ctx, qType, quota, true)
if err != nil {
return err
}
}
}
return nil
}
// QuotaStoragePath returns the storage path suffix for persisting the quota
// rule.
func QuotaStoragePath(quotaType, name string) string {
return path.Join(StoragePrefix+quotaType, name)
}
// HandleRemount updates the quota subsystem about the remount operation that
// took place. Quota manager will trigger the quota specific updates including
// the mount path update..
func (m *Manager) HandleRemount(ctx context.Context, nsPath, fromPath, toPath string) error {
m.lock.Lock()
defer m.lock.Unlock()
txn := m.db.Txn(true)
defer txn.Abort()
// nsPath would have been made non-empty during insertion. Use non-empty value
// during query as well.
if nsPath == "" {
nsPath = "root"
}
idx := indexNamespaceMount
leaseQuotaUpdated := false
args := []interface{}{nsPath, fromPath}
for _, quotaType := range quotaTypes() {
iter, err := txn.Get(quotaType, idx, args...)
if err != nil {
return err
}
for raw := iter.Next(); raw != nil; raw = iter.Next() {
quota := raw.(Quota)
quota.handleRemount(toPath)
entry, err := logical.StorageEntryJSON(QuotaStoragePath(quotaType, quota.QuotaName()), quota)
if err != nil {
return err
}
if err := m.storage.Put(ctx, entry); err != nil {
return err
}
if quotaType == TypeLeaseCount.String() {
leaseQuotaUpdated = true
}
}
}
if leaseQuotaUpdated {
if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
return err
}
}
txn.Commit()
return nil
}
// HandleBackendDisabling updates the quota subsystem with the disabling of auth
// or secret engine disabling.
func (m *Manager) HandleBackendDisabling(ctx context.Context, nsPath, mountPath string) error {
m.lock.Lock()
defer m.lock.Unlock()
txn := m.db.Txn(true)
defer txn.Abort()
// nsPath would have been made non-empty during insertion. Use non-empty value
// during query as well.
if nsPath == "" {
nsPath = "root"
}
idx := indexNamespaceMount
leaseQuotaDeleted := false
args := []interface{}{nsPath, mountPath}
for _, quotaType := range quotaTypes() {
iter, err := txn.Get(quotaType, idx, args...)
if err != nil {
return err
}
for raw := iter.Next(); raw != nil; raw = iter.Next() {
if err := txn.Delete(quotaType, raw); err != nil {
return fmt.Errorf("failed to delete quota from db after mount disabling; namespace %q, err %v", nsPath, err)
}
quota := raw.(Quota)
if err := m.storage.Delete(ctx, QuotaStoragePath(quotaType, quota.QuotaName())); err != nil {
return fmt.Errorf("failed to delete quota from storage after mount disabling; namespace %q, err %v", nsPath, err)
}
if quotaType == TypeLeaseCount.String() {
leaseQuotaDeleted = true
}
}
}
if leaseQuotaDeleted {
if err := m.recomputeLeaseCounts(ctx, txn); err != nil {
return err
}
}
txn.Commit()
return nil
}

View File

@@ -0,0 +1,282 @@
package quotas
import (
"fmt"
"math"
"sync"
"time"
"github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/sdk/helper/pathmanager"
"golang.org/x/time/rate"
)
var rateLimitExemptPaths = pathmanager.New()
const (
// DefaultRateLimitPurgeInterval defines the default purge interval used by a
// RateLimitQuota to remove stale client rate limiters.
DefaultRateLimitPurgeInterval = time.Minute
// DefaultRateLimitStaleAge defines the default stale age of a client limiter.
DefaultRateLimitStaleAge = 3 * time.Minute
// EnvVaultEnableRateLimitAuditLogging is used to enable audit logging of
// requests that get rejected due to rate limit quota violations.
EnvVaultEnableRateLimitAuditLogging = "VAULT_ENABLE_RATE_LIMIT_AUDIT_LOGGING"
)
func init() {
rateLimitExemptPaths.AddPaths([]string{
"/v1/sys/generate-recovery-token/attempt",
"/v1/sys/generate-recovery-token/update",
"/v1/sys/generate-root/attempt",
"/v1/sys/generate-root/update",
"/v1/sys/health",
"/v1/sys/seal-status",
"/v1/sys/unseal",
})
}
// ClientRateLimiter defines a token bucket based rate limiter for a unique
// addressable client (e.g. IP address). Whenever this client attempts to make
// a request, the lastSeen value will be updated.
type ClientRateLimiter struct {
// lastSeen defines the UNIX timestamp the client last made a request.
lastSeen time.Time
// limiter represents an instance of a token bucket based rate limiter.
limiter *rate.Limiter
}
// newClientRateLimiter returns a token bucket based rate limiter for a client
// that is uniquely addressable, where maxRequests defines the requests-per-second
// and burstSize defines the maximum burst allowed. A caller may provide -1 for
// burstSize to allow the burst value to be roughly equivalent to the RPS. Note,
// the underlying rate limiter is already thread-safe.
func newClientRateLimiter(maxRequests float64, burstSize int) *ClientRateLimiter {
if burstSize < 0 {
burstSize = int(math.Ceil(maxRequests))
}
return &ClientRateLimiter{
lastSeen: time.Now().UTC(),
limiter: rate.NewLimiter(rate.Limit(maxRequests), burstSize),
}
}
// Ensure that RateLimitQuota implements the Quota interface
var _ Quota = (*RateLimitQuota)(nil)
// RateLimitQuota represents the quota rule properties that is used to limit the
// number of requests per second for a namespace or mount.
type RateLimitQuota struct {
// ID is the identifier of the quota
ID string `json:"id"`
// Type of quota this represents
Type Type `json:"type"`
// Name of the quota rule
Name string `json:"name"`
// NamespacePath is the path of the namespace to which this quota is
// applicable.
NamespacePath string `json:"namespace_path"`
// MountPath is the path of the mount to which this quota is applicable
MountPath string `json:"mount_path"`
// Rate defines the rate of which allowed requests are refilled per second.
Rate float64 `json:"rate"`
// Burst defines maximum number of requests at any given moment to be allowed.
Burst int `json:"burst"`
lock *sync.Mutex
logger log.Logger
metricSink *metricsutil.ClusterMetricSink
purgeEnabled bool
// purgeInterval defines the interval in seconds in which the RateLimitQuota
// attempts to remove stale entries from the rateQuotas mapping.
purgeInterval time.Duration
closeCh chan struct{}
// staleAge defines the age in seconds in which a clientRateLimiter is
// considered stale. A clientRateLimiter is considered stale if the delta
// between the current purge time and its lastSeen timestamp is greater than
// this value.
staleAge time.Duration
// rateQuotas contains a mapping from a unique addressable client (e.g. IP address)
// to a clientRateLimiter reference. Every purgeInterval seconds, the RateLimitQuota
// will attempt to remove stale entries from the mapping.
rateQuotas map[string]*ClientRateLimiter
}
// NewRateLimitQuota creates a quota checker for imposing limits on the number
// of requests per second.
func NewRateLimitQuota(name, nsPath, mountPath string, rate float64, burst int) *RateLimitQuota {
return &RateLimitQuota{
Name: name,
Type: TypeRateLimit,
NamespacePath: nsPath,
MountPath: mountPath,
Rate: rate,
Burst: burst,
}
}
// jnitialize ensures the namespace and max requests are initialized, sets the ID
// if it's currently empty, sets the purge interval and stale age to default
// values, and finally starts the client purge go routine if it has been started
// already. Note, initialize will reset the internal rateQuotas mapping.
func (rlq *RateLimitQuota) initialize(logger log.Logger, ms *metricsutil.ClusterMetricSink) error {
if rlq.lock == nil {
rlq.lock = new(sync.Mutex)
}
rlq.lock.Lock()
defer rlq.lock.Unlock()
// Memdb requires a non-empty value for indexing
if rlq.NamespacePath == "" {
rlq.NamespacePath = "root"
}
if rlq.Rate <= 0 {
return fmt.Errorf("invalid avg rps: %v", rlq.Rate)
}
if rlq.Burst < int(rlq.Rate) {
return fmt.Errorf("burst size (%v) must be greater than or equal to average rps (%v)", rlq.Burst, rlq.Rate)
}
if logger != nil {
rlq.logger = logger
}
if rlq.metricSink == nil {
rlq.metricSink = ms
}
if rlq.ID == "" {
id, err := uuid.GenerateUUID()
if err != nil {
return err
}
rlq.ID = id
}
rlq.purgeInterval = DefaultRateLimitPurgeInterval
rlq.staleAge = DefaultRateLimitStaleAge
rlq.rateQuotas = make(map[string]*ClientRateLimiter)
if !rlq.purgeEnabled {
rlq.purgeEnabled = true
rlq.closeCh = make(chan struct{})
go rlq.purgeClientsLoop()
}
return nil
}
// quotaID returns the identifier of the quota rule
func (rlq *RateLimitQuota) quotaID() string {
return rlq.ID
}
// QuotaName returns the name of the quota rule
func (rlq *RateLimitQuota) QuotaName() string {
return rlq.Name
}
// purgeClientsLoop performs a blocking process where every purgeInterval
// duration, we look for stale clients to remove from the rateQuotas map.
// A ClientRateLimiter is considered stale if its lastSeen timestamp exceeds the
// current time. The loop will continue to run indefinitely until a value is
// sent on the closeCh in which we stop the ticker and exit.
func (rlq *RateLimitQuota) purgeClientsLoop() {
ticker := time.NewTicker(rlq.purgeInterval)
for {
select {
case t := <-ticker.C:
rlq.lock.Lock()
for client, crl := range rlq.rateQuotas {
if t.UTC().Sub(crl.lastSeen) >= rlq.staleAge {
delete(rlq.rateQuotas, client)
}
}
rlq.lock.Unlock()
case <-rlq.closeCh:
ticker.Stop()
rlq.purgeEnabled = false
return
}
}
}
// clientRateLimiter returns a reference to a ClientRateLimiter based on a
// provided client address (e.g. IP address). If the ClientRateLimiter does not
// exist in the RateLimitQuota's mapping, one will be created and set. The
// created RateLimitQuota will have its requests-per-second set to
// RateLimitQuota.AverageRps. If the ClientRateLimiter already exists, the
// lastSeen timestamp will be updated.
func (rlq *RateLimitQuota) clientRateLimiter(addr string) *ClientRateLimiter {
rlq.lock.Lock()
defer rlq.lock.Unlock()
crl, ok := rlq.rateQuotas[addr]
if !ok {
limiter := newClientRateLimiter(rlq.Rate, rlq.Burst)
rlq.rateQuotas[addr] = limiter
return limiter
}
crl.lastSeen = time.Now().UTC()
return crl
}
// allow decides if the request is allowed by the quota. An error will be
// returned if the request ID or address is empty. If the path is exempt, the
// quota will not be evaluated. Otherwise, the client rate limiter is retrieved
// by address and the rate limit quota is checked against that limiter.
func (rlq *RateLimitQuota) allow(req *Request) (Response, error) {
var resp Response
// Skip rate limit checks for paths that are exempt from rate limiting.
if rateLimitExemptPaths.HasPath(req.Path) {
resp.Allowed = true
return resp, nil
}
if req.ClientAddress == "" {
return resp, fmt.Errorf("missing request client address in quota request")
}
resp.Allowed = rlq.clientRateLimiter(req.ClientAddress).limiter.Allow()
if !resp.Allowed {
rlq.metricSink.IncrCounterWithLabels([]string{"quota", "rate_limit", "violation"}, 1, []metrics.Label{{"name", rlq.Name}})
}
return resp, nil
}
// close stops the current running client purge loop.
func (rlq *RateLimitQuota) close() error {
close(rlq.closeCh)
return nil
}
func (rlq *RateLimitQuota) handleRemount(toPath string) {
rlq.MountPath = toPath
}

View File

@@ -0,0 +1,173 @@
package quotas
import (
"fmt"
"sync"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/sdk/helper/logging"
"go.uber.org/atomic"
)
func TestNewClientRateLimiter(t *testing.T) {
testCases := []struct {
maxRequests float64
burstSize int
expectedBurst int
}{
{1000, -1, 1000},
{1000, 5000, 5000},
{16.1, -1, 17},
{16.7, -1, 17},
{16.7, 100, 100},
}
for _, tc := range testCases {
crl := newClientRateLimiter(tc.maxRequests, tc.burstSize)
b := crl.limiter.Burst()
if b != tc.expectedBurst {
t.Fatalf("unexpected burst size; expected: %d, got: %d", tc.expectedBurst, b)
}
}
}
func TestNewRateLimitQuota(t *testing.T) {
rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, 50)
if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil {
t.Fatal(err)
}
if !rlq.purgeEnabled {
t.Fatal("expected rate limit quota to start purge loop")
}
if rlq.purgeInterval != DefaultRateLimitPurgeInterval {
t.Fatalf("unexpected purgeInterval; expected: %d, got: %d", DefaultRateLimitPurgeInterval, rlq.purgeInterval)
}
if rlq.staleAge != DefaultRateLimitStaleAge {
t.Fatalf("unexpected staleAge; expected: %d, got: %d", DefaultRateLimitStaleAge, rlq.staleAge)
}
}
func TestRateLimitQuota_Close(t *testing.T) {
rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, 50)
if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil {
t.Fatal(err)
}
if err := rlq.close(); err != nil {
t.Fatalf("unexpected error when closing: %v", err)
}
time.Sleep(time.Second) // allow enough time for purgeClientsLoop to receive on closeCh
if rlq.purgeEnabled {
t.Fatal("expected client purging to be disabled after close")
}
}
func TestRateLimitQuota_Allow(t *testing.T) {
rlq := &RateLimitQuota{
Name: "test-rate-limiter",
Type: TypeRateLimit,
NamespacePath: "qa",
MountPath: "/foo/bar",
Rate: 16.7,
Burst: 83,
purgeEnabled: true, // to allow manual setting of purgeInterval and staleAge
}
if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil {
t.Fatal(err)
}
// override value and manually start purgeClientsLoop for testing purposes
rlq.purgeInterval = 10 * time.Second
rlq.staleAge = 10 * time.Second
go rlq.purgeClientsLoop()
var wg sync.WaitGroup
type clientResult struct {
atomicNumAllow *atomic.Int32
atomicNumFail *atomic.Int32
}
reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) {
defer wg.Done()
resp, err := rlq.allow(&Request{ClientAddress: addr})
if err != nil {
return
}
if resp.Allowed {
atomicNumAllow.Add(1)
} else {
atomicNumFail.Add(1)
}
}
results := make(map[string]*clientResult)
start := time.Now()
end := start.Add(5 * time.Second)
for time.Now().Before(end) {
for i := 0; i < 5; i++ {
wg.Add(1)
addr := fmt.Sprintf("127.0.0.%d", i)
cr, ok := results[addr]
if !ok {
results[addr] = &clientResult{atomicNumAllow: atomic.NewInt32(0), atomicNumFail: atomic.NewInt32(0)}
cr = results[addr]
}
go reqFunc(addr, cr.atomicNumAllow, cr.atomicNumFail)
time.Sleep(2 * time.Millisecond)
}
}
wg.Wait()
if got, expected := len(results), len(rlq.rateQuotas); got != expected {
t.Fatalf("unexpected number of tracked client rate limit quotas; got %d, expected; %d", got, expected)
}
elapsed := time.Since(start)
// evaluate the ideal RPS as (burst + (RPS * totalSeconds))
ideal := float64(rlq.Burst) + (rlq.Rate * float64(elapsed) / float64(time.Second))
for addr, cr := range results {
numAllow := cr.atomicNumAllow.Load()
numFail := cr.atomicNumFail.Load()
// ensure there were some failed requests for the namespace
if numFail == 0 {
t.Fatalf("expected some requests to fail; addr: %s, numSuccess: %d, numFail: %d, elapsed: %d", addr, numAllow, numFail, elapsed)
}
// ensure that we should never get more requests than allowed for the namespace
if want := int32(ideal + 1); numAllow > want {
t.Fatalf("too many successful requests; addr: %s, want: %d, numSuccess: %d, numFail: %d, elapsed: %d", addr, want, numAllow, numFail, elapsed)
}
}
// allow enough time for the client to be purged
time.Sleep(rlq.purgeInterval * 2)
for addr := range results {
rlc, ok := rlq.rateQuotas[addr]
if ok || rlc != nil {
t.Fatalf("expected stale client to be purged: %s", addr)
}
}
}

View File

@@ -0,0 +1,67 @@
package quotas
import (
"context"
"testing"
"github.com/go-test/deep"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/sdk/helper/logging"
)
func TestQuotas_Precedence(t *testing.T) {
qm, err := NewManager(logging.NewVaultLogger(log.Trace), nil, metricsutil.BlackholeSink())
if err != nil {
t.Fatal(err)
}
setQuotaFunc := func(t *testing.T, name, nsPath, mountPath string) Quota {
t.Helper()
quota := NewRateLimitQuota(name, nsPath, mountPath, 10, 20)
err := qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true)
if err != nil {
t.Fatal(err)
}
return quota
}
checkQuotaFunc := func(t *testing.T, nsPath, mountPath string, expected Quota) {
t.Helper()
quota, err := qm.queryQuota(nil, &Request{
Type: TypeRateLimit,
NamespacePath: nsPath,
MountPath: mountPath,
})
if err != nil {
t.Fatal(err)
}
if diff := deep.Equal(expected, quota); len(diff) > 0 {
t.Fatal(diff)
}
}
// No quota present. Expect nil.
checkQuotaFunc(t, "", "", nil)
// Define global quota and expect that to be returned.
rateLimitGlobalQuota := setQuotaFunc(t, "rateLimitGlobalQuota", "", "")
checkQuotaFunc(t, "", "", rateLimitGlobalQuota)
// Define a global mount specific quota and expect that to be returned.
rateLimitGlobalMountQuota := setQuotaFunc(t, "rateLimitGlobalMountQuota", "", "testmount")
checkQuotaFunc(t, "", "testmount", rateLimitGlobalMountQuota)
// Define a namespace quota and expect that to be returned.
rateLimitNSQuota := setQuotaFunc(t, "rateLimitNSQuota", "testns", "")
checkQuotaFunc(t, "testns", "", rateLimitNSQuota)
// Define a namespace mount specific quota and expect that to be returned.
rateLimitNSMountQuota := setQuotaFunc(t, "rateLimitNSMountQuota", "testns", "testmount")
checkQuotaFunc(t, "testns", "testmount", rateLimitNSMountQuota)
// Now that many quota types are defined, verify that the most specific
// matches are returned per namespace.
checkQuotaFunc(t, "", "", rateLimitGlobalQuota)
checkQuotaFunc(t, "testns", "", rateLimitNSQuota)
}

View File

@@ -0,0 +1,65 @@
// +build !enterprise
package quotas
import (
"context"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/metricsutil"
memdb "github.com/hashicorp/go-memdb"
)
func quotaTypes() []string {
return []string{
TypeRateLimit.String(),
}
}
func (m *Manager) init(walkFunc leaseWalkFunc) {}
func (m *Manager) recomputeLeaseCounts(ctx context.Context, txn *memdb.Txn) error {
return nil
}
func (m *Manager) setIsPerfStandby(quota Quota) {}
func (m *Manager) inLeasePathCache(path string) bool {
return false
}
type entManager struct {
isPerfStandby bool
}
func (*entManager) Reset() error {
return nil
}
type LeaseCountQuota struct {
}
func (l LeaseCountQuota) allow(request *Request) (Response, error) {
panic("implement me")
}
func (l LeaseCountQuota) quotaID() string {
panic("implement me")
}
func (l LeaseCountQuota) QuotaName() string {
panic("implement me")
}
func (l LeaseCountQuota) initialize(logger log.Logger, sink *metricsutil.ClusterMetricSink) error {
panic("implement me")
}
func (l LeaseCountQuota) close() error {
panic("implement me")
}
func (l LeaseCountQuota) handleRemount(s string) {
panic("implement me")
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/helper/wrapping" "github.com/hashicorp/vault/sdk/helper/wrapping"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/quotas"
uberAtomic "go.uber.org/atomic" uberAtomic "go.uber.org/atomic"
) )
@@ -539,7 +540,6 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp
} }
// Create an audit trail of the response // Create an audit trail of the response
if !isControlGroupRun(req) { if !isControlGroupRun(req) {
switch req.Path { switch req.Path {
case "sys/replication/dr/status", "sys/replication/performance/status", "sys/replication/status": case "sys/replication/dr/status", "sys/replication/performance/status", "sys/replication/status":
@@ -708,6 +708,36 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
} }
} }
leaseGenerated := false
quotaResp, quotaErr := c.applyLeaseCountQuota(&quotas.Request{
Path: req.Path,
MountPath: strings.TrimPrefix(req.MountPoint, ns.Path),
NamespacePath: ns.Path,
})
if quotaErr != nil {
c.logger.Error("failed to apply quota", "path", req.Path, "error", err)
retErr = multierror.Append(retErr, quotaErr)
return nil, auth, retErr
}
if !quotaResp.Allowed {
if c.logger.IsTrace() {
c.logger.Trace("request rejected due to lease count quota violation", "request_path", req.Path)
}
retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrLeaseCountQuotaExceeded))
return nil, auth, retErr
}
defer func() {
if quotaResp.Access != nil {
quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated)
if quotaAckErr != nil {
retErr = multierror.Append(retErr, quotaAckErr)
}
}
}()
// Route the request // Route the request
resp, routeErr := c.doRouting(ctx, req) resp, routeErr := c.doRouting(ctx, req)
if resp != nil { if resp != nil {
@@ -827,6 +857,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
retErr = multierror.Append(retErr, ErrInternalError) retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr return nil, auth, retErr
} }
leaseGenerated = true
resp.Secret.LeaseID = leaseID resp.Secret.LeaseID = leaseID
// Get the actual time of the lease // Get the actual time of the lease
@@ -917,6 +948,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
retErr = multierror.Append(retErr, ErrInternalError) retErr = multierror.Append(retErr, ErrInternalError)
return nil, auth, retErr return nil, auth, retErr
} }
leaseGenerated = true
} }
} }
@@ -1073,6 +1105,46 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
// If the response generated an authentication, then generate the token // If the response generated an authentication, then generate the token
if resp != nil && resp.Auth != nil { if resp != nil && resp.Auth != nil {
ns, err := namespace.FromContext(ctx)
if err != nil {
c.logger.Error("failed to get namespace from context", "error", err)
retErr = multierror.Append(retErr, ErrInternalError)
return
}
leaseGenerated := false
// The request successfully authenticated itself. Run the quota checks
// before creating lease.
quotaResp, quotaErr := c.applyLeaseCountQuota(&quotas.Request{
Path: req.Path,
MountPath: strings.TrimPrefix(req.MountPoint, ns.Path),
NamespacePath: ns.Path,
})
if quotaErr != nil {
c.logger.Error("failed to apply quota", "path", req.Path, "error", err)
retErr = multierror.Append(retErr, quotaErr)
return
}
if !quotaResp.Allowed {
if c.logger.IsTrace() {
c.logger.Trace("request rejected due to lease count quota violation", "request_path", req.Path)
}
retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrLeaseCountQuotaExceeded))
return
}
defer func() {
if quotaResp.Access != nil {
quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated)
if quotaAckErr != nil {
retErr = multierror.Append(retErr, quotaAckErr)
}
}
}()
var entity *identity.Entity var entity *identity.Entity
auth = resp.Auth auth = resp.Auth
@@ -1141,10 +1213,6 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
resp.AddWarning(warning) resp.AddWarning(warning)
} }
ns, err := namespace.FromContext(ctx)
if err != nil {
return nil, nil, err
}
_, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, ns, auth.EntityID) _, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, ns, auth.EntityID)
if err != nil { if err != nil {
return nil, nil, ErrInternalError return nil, nil, ErrInternalError
@@ -1181,6 +1249,9 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
err = registerFunc(ctx, tokenTTL, req.Path, auth) err = registerFunc(ctx, tokenTTL, req.Path, auth)
switch { switch {
case err == nil: case err == nil:
if auth.TokenType != logical.TokenTypeBatch {
leaseGenerated = true
}
case err == ErrInternalError: case err == ErrInternalError:
return nil, auth, err return nil, auth, err
default: default:

View File

@@ -422,6 +422,14 @@ func (r *Router) MatchingSystemView(ctx context.Context, path string) logical.Sy
return raw.(*routeEntry).backend.System() return raw.(*routeEntry).backend.System()
} }
func (r *Router) MatchingMountByAPIPath(ctx context.Context, path string) string {
me, _, _ := r.matchingMountEntryByPath(ctx, path, true)
if me == nil {
return ""
}
return me.Path
}
// MatchingStoragePrefixByAPIPath the storage prefix for the given api path // MatchingStoragePrefixByAPIPath the storage prefix for the given api path
func (r *Router) MatchingStoragePrefixByAPIPath(ctx context.Context, path string) (string, bool) { func (r *Router) MatchingStoragePrefixByAPIPath(ctx context.Context, path string) (string, bool) {
ns, err := namespace.FromContext(ctx) ns, err := namespace.FromContext(ctx)

View File

@@ -27,8 +27,8 @@ func (r *Response) DecodeJSON(out interface{}) error {
// body must still be closed manually. // body must still be closed manually.
func (r *Response) Error() error { func (r *Response) Error() error {
// 200 to 399 are okay status codes. 429 is the code for health status of // 200 to 399 are okay status codes. 429 is the code for health status of
// standby nodes. // standby nodes, otherwise, 429 is treated as quota limit reached.
if (r.StatusCode >= 200 && r.StatusCode < 400) || r.StatusCode == 429 { if (r.StatusCode >= 200 && r.StatusCode < 400) || (r.StatusCode == 429 && r.Request.URL.Path == "/v1/sys/health") {
return nil return nil
} }

View File

@@ -28,6 +28,14 @@ var (
// ErrPerfStandbyForward is returned when Vault is in a state such that a // ErrPerfStandbyForward is returned when Vault is in a state such that a
// perf standby cannot satisfy a request // perf standby cannot satisfy a request
ErrPerfStandbyPleaseForward = errors.New("please forward to the active node") ErrPerfStandbyPleaseForward = errors.New("please forward to the active node")
// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease
// count quota being exceeded.
ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded")
// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a
// rate limit quota being exceeded.
ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded")
) )
type HTTPCodedError interface { type HTTPCodedError interface {

View File

@@ -81,7 +81,7 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
} }
}) })
if allErrors != nil { if allErrors != nil {
return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors) return codedErr.Code, multierror.Append(fmt.Errorf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg), allErrors)
} }
return codedErr.Code, errors.New(codedErr.Msg) return codedErr.Code, errors.New(codedErr.Msg)
} }
@@ -110,6 +110,10 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
statusCode = http.StatusBadRequest statusCode = http.StatusBadRequest
case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): case errwrap.Contains(err, ErrUpstreamRateLimited.Error()):
statusCode = http.StatusBadGateway statusCode = http.StatusBadGateway
case errwrap.Contains(err, ErrRateLimitQuotaExceeded.Error()):
statusCode = http.StatusTooManyRequests
case errwrap.Contains(err, ErrLeaseCountQuotaExceeded.Error()):
statusCode = http.StatusTooManyRequests
} }
} }

View File

@@ -163,6 +163,17 @@ These metrics cover measurement of token, identity, and lease operations, and co
| `vault.token.revoke-tree` | Time taken to revoke a token tree | ms | summary | | `vault.token.revoke-tree` | Time taken to revoke a token tree | ms | summary |
| `vault.token.store` | Time taken to store an updated token entry without writing to the secondary index | ms | summary | | `vault.token.store` | Time taken to store an updated token entry without writing to the secondary index | ms | summary |
## Resource Quota Metrics
These metrics relate to rate limit and lease count quotas. Each metric comes with a label "name" identifying the specific quota.
| Metric | Description | Unit | Type |
| :---------------------------- | :---------------------------------------------------------------- | :---- | :------ |
| `quota.rate_limit.violation` | Total number of rate limit quota violations | quota | counter |
| `quota.lease_count.violation` | Total number of lease count quota violations | quota | counter |
| `quota.lease_count.max` | Total maximum amount of leases allowed by the lease count quota | lease | gauge |
| `quota.lease_count.counter` | Total current amount of leases generated by the lease count quota | lease | gauge |
## Merkle Tree and Write Ahead Log Metrics ## Merkle Tree and Write Ahead Log Metrics
These metrics relate to internal operations on Merkle Trees and Write Ahead Logs (WAL) These metrics relate to internal operations on Merkle Trees and Write Ahead Logs (WAL)