mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	Resource Quotas: Rate Limiting (#9330)
This commit is contained in:
		| @@ -27,8 +27,8 @@ func (r *Response) DecodeJSON(out interface{}) error { | |||||||
| // body must still be closed manually. | // body must still be closed manually. | ||||||
| func (r *Response) Error() error { | func (r *Response) Error() error { | ||||||
| 	// 200 to 399 are okay status codes. 429 is the code for health status of | 	// 200 to 399 are okay status codes. 429 is the code for health status of | ||||||
| 	// standby nodes. | 	// standby nodes, otherwise, 429 is treated as quota limit reached. | ||||||
| 	if (r.StatusCode >= 200 && r.StatusCode < 400) || r.StatusCode == 429 { | 	if (r.StatusCode >= 200 && r.StatusCode < 400) || (r.StatusCode == 429 && r.Request.URL.Path == "/v1/sys/health") { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							| @@ -146,6 +146,7 @@ require ( | |||||||
| 	golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 | 	golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 | ||||||
| 	golang.org/x/net v0.0.0-20200602114024-627f9648deb9 | 	golang.org/x/net v0.0.0-20200602114024-627f9648deb9 | ||||||
| 	golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d | 	golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d | ||||||
|  | 	golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 | ||||||
| 	golang.org/x/tools v0.0.0-20200416214402-fc959738d646 | 	golang.org/x/tools v0.0.0-20200416214402-fc959738d646 | ||||||
| 	google.golang.org/api v0.24.0 | 	google.golang.org/api v0.24.0 | ||||||
| 	google.golang.org/grpc v1.29.1 | 	google.golang.org/grpc v1.29.1 | ||||||
|   | |||||||
| @@ -176,8 +176,8 @@ func Handler(props *vault.HandlerProperties) http.Handler { | |||||||
| 	// Wrap the handler in another handler to trigger all help paths. | 	// Wrap the handler in another handler to trigger all help paths. | ||||||
| 	helpWrappedHandler := wrapHelpHandler(mux, core) | 	helpWrappedHandler := wrapHelpHandler(mux, core) | ||||||
| 	corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core) | 	corsWrappedHandler := wrapCORSHandler(helpWrappedHandler, core) | ||||||
|  | 	quotaWrappedHandler := rateLimitQuotaWrapping(corsWrappedHandler, core) | ||||||
| 	genericWrappedHandler := genericWrapping(core, corsWrappedHandler, props) | 	genericWrappedHandler := genericWrapping(core, quotaWrappedHandler, props) | ||||||
|  |  | ||||||
| 	// Wrap the handler with PrintablePathCheckHandler to check for non-printable | 	// Wrap the handler with PrintablePathCheckHandler to check for non-printable | ||||||
| 	// characters in the request path. | 	// characters in the request path. | ||||||
| @@ -221,26 +221,14 @@ func (w *copyResponseWriter) WriteHeader(code int) { | |||||||
|  |  | ||||||
| func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { | func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { | ||||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		origBody := new(bytes.Buffer) |  | ||||||
| 		reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody)) |  | ||||||
| 		r.Body = reader |  | ||||||
| 		req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) |  | ||||||
| 		if err != nil || status != 0 { |  | ||||||
| 			respondError(w, status, err) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		if origBody != nil { |  | ||||||
| 			r.Body = ioutil.NopCloser(origBody) |  | ||||||
| 		} |  | ||||||
| 		input := &logical.LogInput{ | 		input := &logical.LogInput{ | ||||||
| 			Request: req, | 			Request: w.(*LogicalResponseWriter).request, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		core.AuditLogger().AuditRequest(r.Context(), input) | 		core.AuditLogger().AuditRequest(r.Context(), input) | ||||||
| 		cw := newCopyResponseWriter(w) | 		cw := newCopyResponseWriter(w) | ||||||
| 		h.ServeHTTP(cw, r) | 		h.ServeHTTP(cw, r) | ||||||
| 		data := make(map[string]interface{}) | 		data := make(map[string]interface{}) | ||||||
| 		err = jsonutil.DecodeJSON(cw.body.Bytes(), &data) | 		err := jsonutil.DecodeJSON(cw.body.Bytes(), &data) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			// best effort, ignore | 			// best effort, ignore | ||||||
| 		} | 		} | ||||||
| @@ -249,7 +237,13 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { | |||||||
| 		core.AuditLogger().AuditResponse(r.Context(), input) | 		core.AuditLogger().AuditResponse(r.Context(), input) | ||||||
| 		return | 		return | ||||||
| 	}) | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // LogicalResponseWriter is used to carry the logical request from generic | ||||||
|  | // handler down to all the middleware http handlers. | ||||||
|  | type LogicalResponseWriter struct { | ||||||
|  | 	http.ResponseWriter | ||||||
|  | 	request *logical.Request | ||||||
| } | } | ||||||
|  |  | ||||||
| // wrapGenericHandler wraps the handler with an extra layer of handler where | // wrapGenericHandler wraps the handler with an extra layer of handler where | ||||||
| @@ -288,6 +282,7 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr | |||||||
| 		} | 		} | ||||||
| 		ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) | 		ctx = context.WithValue(ctx, "original_request_path", r.URL.Path) | ||||||
| 		r = r.WithContext(ctx) | 		r = r.WithContext(ctx) | ||||||
|  | 		r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) | ||||||
|  |  | ||||||
| 		switch { | 		switch { | ||||||
| 		case strings.HasPrefix(r.URL.Path, "/v1/"): | 		case strings.HasPrefix(r.URL.Path, "/v1/"): | ||||||
| @@ -306,7 +301,27 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		h.ServeHTTP(w, r) | 		origBody := new(bytes.Buffer) | ||||||
|  | 		reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody)) | ||||||
|  | 		r.Body = reader | ||||||
|  | 		req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) | ||||||
|  | 		if err != nil || status != 0 { | ||||||
|  | 			respondError(w, status, err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		// Reset the body since logical request creation already read the | ||||||
|  | 		// request body. | ||||||
|  | 		r.Body = ioutil.NopCloser(origBody) | ||||||
|  |  | ||||||
|  | 		// Set the mount path in the request | ||||||
|  | 		req.MountPoint = core.MatchingMount(r.Context(), req.Path) | ||||||
|  |  | ||||||
|  | 		// Pass the logical request down through the response writer | ||||||
|  | 		h.ServeHTTP(&LogicalResponseWriter{ | ||||||
|  | 			ResponseWriter: w, | ||||||
|  | 			request:        req, | ||||||
|  | 		}, r) | ||||||
|  |  | ||||||
| 		cancelFunc() | 		cancelFunc() | ||||||
| 		return | 		return | ||||||
| 	}) | 	}) | ||||||
|   | |||||||
| @@ -141,6 +141,7 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	case "OPTIONS": | 	case "OPTIONS": | ||||||
|  | 	case "HEAD": | ||||||
| 	default: | 	default: | ||||||
| 		return nil, nil, http.StatusMethodNotAllowed, nil | 		return nil, nil, http.StatusMethodNotAllowed, nil | ||||||
| 	} | 	} | ||||||
| @@ -169,36 +170,32 @@ func buildLogicalRequestNoAuth(perfStandby bool, w http.ResponseWriter, r *http. | |||||||
| 	return req, origBody, 0, nil | 	return req, origBody, 0, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) { | func setupLogicalRequest(core *vault.Core, req *logical.Request, r *http.Request) (*logical.Request, int, error) { | ||||||
| 	req, origBody, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) | 	var err error | ||||||
| 	if err != nil || status != 0 { |  | ||||||
| 		return nil, nil, status, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	req, err = requestAuth(core, r, req) | 	req, err = requestAuth(core, r, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { | 		if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) { | ||||||
| 			return nil, nil, http.StatusForbidden, nil | 			return nil, http.StatusForbidden, nil | ||||||
| 		} | 		} | ||||||
| 		return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err) | 		return nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	req, err = requestWrapInfo(r, req) | 	req, err = requestWrapInfo(r, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err) | 		return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = parseMFAHeader(req) | 	err = parseMFAHeader(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err) | 		return nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = requestPolicyOverride(r, req) | 	err = requestPolicyOverride(r, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err) | 		return nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return req, origBody, 0, nil | 	return req, 0, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // handleLogical returns a handler for processing logical requests. These requests | // handleLogical returns a handler for processing logical requests. These requests | ||||||
| @@ -257,7 +254,7 @@ func handleLogicalRecovery(raw *vault.RawBackend, token *atomic.String) http.Han | |||||||
| // toggles. Refer to usage on functions for possible behaviors. | // toggles. Refer to usage on functions for possible behaviors. | ||||||
| func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler { | func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForward bool) http.Handler { | ||||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		req, origBody, statusCode, err := buildLogicalRequest(core, w, r) | 		req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r) | ||||||
| 		if err != nil || statusCode != 0 { | 		if err != nil || statusCode != 0 { | ||||||
| 			respondError(w, statusCode, err) | 			respondError(w, statusCode, err) | ||||||
| 			return | 			return | ||||||
| @@ -270,10 +267,6 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw | |||||||
| 				respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) | 				respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if origBody != nil { |  | ||||||
| 				r.Body = origBody |  | ||||||
| 			} |  | ||||||
| 			forwardRequest(core, w, r) | 			forwardRequest(core, w, r) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @@ -398,9 +391,6 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw | |||||||
| 			respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) | 			respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) | ||||||
| 			return | 			return | ||||||
| 		case needsForward && !noForward: | 		case needsForward && !noForward: | ||||||
| 			if origBody != nil { |  | ||||||
| 				r.Body = origBody |  | ||||||
| 			} |  | ||||||
| 			forwardRequest(core, w, r) | 			forwardRequest(core, w, r) | ||||||
| 			return | 			return | ||||||
| 		case !ok: | 		case !ok: | ||||||
|   | |||||||
| @@ -281,7 +281,13 @@ func TestLogical_ListSuffix(t *testing.T) { | |||||||
| 	req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil) | 	req, _ := http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo", nil) | ||||||
| 	req = req.WithContext(namespace.RootContext(nil)) | 	req = req.WithContext(namespace.RootContext(nil)) | ||||||
| 	req.Header.Add(consts.AuthHeaderName, rootToken) | 	req.Header.Add(consts.AuthHeaderName, rootToken) | ||||||
| 	lreq, _, status, err := buildLogicalRequest(core, nil, req) |  | ||||||
|  | 	lreq, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), nil, req) | ||||||
|  | 	if err != nil || status != 0 { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	lreq, status, err = setupLogicalRequest(core, lreq, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| @@ -295,7 +301,11 @@ func TestLogical_ListSuffix(t *testing.T) { | |||||||
| 	req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil) | 	req, _ = http.NewRequest("GET", "http://127.0.0.1:8200/v1/secret/foo?list=true", nil) | ||||||
| 	req = req.WithContext(namespace.RootContext(nil)) | 	req = req.WithContext(namespace.RootContext(nil)) | ||||||
| 	req.Header.Add(consts.AuthHeaderName, rootToken) | 	req.Header.Add(consts.AuthHeaderName, rootToken) | ||||||
| 	lreq, _, status, err = buildLogicalRequest(core, nil, req) | 	lreq, _, status, err = buildLogicalRequestNoAuth(core.PerfStandby(), nil, req) | ||||||
|  | 	if err != nil || status != 0 { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	lreq, status, err = setupLogicalRequest(core, lreq, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| @@ -309,7 +319,11 @@ func TestLogical_ListSuffix(t *testing.T) { | |||||||
| 	req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil) | 	req, _ = http.NewRequest("LIST", "http://127.0.0.1:8200/v1/secret/foo", nil) | ||||||
| 	req = req.WithContext(namespace.RootContext(nil)) | 	req = req.WithContext(namespace.RootContext(nil)) | ||||||
| 	req.Header.Add(consts.AuthHeaderName, rootToken) | 	req.Header.Add(consts.AuthHeaderName, rootToken) | ||||||
| 	lreq, _, status, err = buildLogicalRequest(core, nil, req) | 	lreq, _, status, err = buildLogicalRequestNoAuth(core.PerfStandby(), nil, req) | ||||||
|  | 	if err != nil || status != 0 { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	lreq, status, err = setupLogicalRequest(core, lreq, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -17,7 +17,7 @@ import ( | |||||||
|  |  | ||||||
| func handleSysSeal(core *vault.Core) http.Handler { | func handleSysSeal(core *vault.Core) http.Handler { | ||||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		req, _, statusCode, err := buildLogicalRequest(core, w, r) | 		req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r) | ||||||
| 		if err != nil || statusCode != 0 { | 		if err != nil || statusCode != 0 { | ||||||
| 			respondError(w, statusCode, err) | 			respondError(w, statusCode, err) | ||||||
| 			return | 			return | ||||||
| @@ -47,7 +47,7 @@ func handleSysSeal(core *vault.Core) http.Handler { | |||||||
|  |  | ||||||
| func handleSysStepDown(core *vault.Core) http.Handler { | func handleSysStepDown(core *vault.Core) http.Handler { | ||||||
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		req, _, statusCode, err := buildLogicalRequest(core, w, r) | 		req, statusCode, err := setupLogicalRequest(core, w.(*LogicalResponseWriter).request, r) | ||||||
| 		if err != nil || statusCode != 0 { | 		if err != nil || statusCode != 0 { | ||||||
| 			respondError(w, statusCode, err) | 			respondError(w, statusCode, err) | ||||||
| 			return | 			return | ||||||
|   | |||||||
							
								
								
									
										61
									
								
								http/util.go
									
									
									
									
									
								
							
							
						
						
									
										61
									
								
								http/util.go
									
									
									
									
									
								
							| @@ -1,15 +1,21 @@ | |||||||
| package http | package http | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/errwrap" | ||||||
| 	"github.com/hashicorp/vault/helper/namespace" | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
| 	"github.com/hashicorp/vault/vault" | 	"github.com/hashicorp/vault/vault" | ||||||
|  | 	"github.com/hashicorp/vault/vault/quotas" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| 	adjustRequest = func(c *vault.Core, r *http.Request) (*http.Request, int) { | 	adjustRequest = func(c *vault.Core, r *http.Request) (*http.Request, int) { | ||||||
| 		return r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)), 0 | 		return r, 0 | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	genericWrapping = func(core *vault.Core, in http.Handler, props *vault.HandlerProperties) http.Handler { | 	genericWrapping = func(core *vault.Core, in http.Handler, props *vault.HandlerProperties) http.Handler { | ||||||
| @@ -22,3 +28,56 @@ var ( | |||||||
|  |  | ||||||
| 	nonVotersAllowed = false | 	nonVotersAllowed = false | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler { | ||||||
|  | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		ns, err := namespace.FromContext(r.Context()) | ||||||
|  | 		if err != nil { | ||||||
|  | 			respondError(w, http.StatusInternalServerError, err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		req := w.(*LogicalResponseWriter).request | ||||||
|  | 		quotaResp, err := core.ApplyRateLimitQuota("as.Request{ | ||||||
|  | 			Type:          quotas.TypeRateLimit, | ||||||
|  | 			Path:          req.Path, | ||||||
|  | 			MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path), | ||||||
|  | 			NamespacePath: ns.Path, | ||||||
|  | 			ClientAddress: parseRemoteIPAddress(r), | ||||||
|  | 		}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			core.Logger().Error("failed to apply quota", "path", req.Path, "error", err) | ||||||
|  | 			respondError(w, http.StatusUnprocessableEntity, err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !quotaResp.Allowed { | ||||||
|  | 			quotaErr := errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrRateLimitQuotaExceeded) | ||||||
|  | 			respondError(w, http.StatusTooManyRequests, quotaErr) | ||||||
|  |  | ||||||
|  | 			if core.Logger().IsTrace() { | ||||||
|  | 				core.Logger().Trace("request rejected due to lease count quota violation", "request_path", req.Path) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if core.RateLimitAuditLoggingEnabled() { | ||||||
|  | 				_ = core.AuditLogger().AuditRequest(r.Context(), &logical.LogInput{ | ||||||
|  | 					Request:  req, | ||||||
|  | 					OuterErr: quotaErr, | ||||||
|  | 				}) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		handler.ServeHTTP(w, r) | ||||||
|  | 		return | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseRemoteIPAddress(r *http.Request) string { | ||||||
|  | 	ip, _, err := net.SplitHostPort(r.RemoteAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return ip | ||||||
|  | } | ||||||
|   | |||||||
| @@ -28,6 +28,14 @@ var ( | |||||||
| 	// ErrPerfStandbyForward is returned when Vault is in a state such that a | 	// ErrPerfStandbyForward is returned when Vault is in a state such that a | ||||||
| 	// perf standby cannot satisfy a request | 	// perf standby cannot satisfy a request | ||||||
| 	ErrPerfStandbyPleaseForward = errors.New("please forward to the active node") | 	ErrPerfStandbyPleaseForward = errors.New("please forward to the active node") | ||||||
|  |  | ||||||
|  | 	// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease | ||||||
|  | 	// count quota being exceeded. | ||||||
|  | 	ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded") | ||||||
|  |  | ||||||
|  | 	// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a | ||||||
|  | 	// rate limit quota being exceeded. | ||||||
|  | 	ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded") | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type HTTPCodedError interface { | type HTTPCodedError interface { | ||||||
|   | |||||||
| @@ -81,7 +81,7 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { | |||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
| 		if allErrors != nil { | 		if allErrors != nil { | ||||||
| 			return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors) | 			return codedErr.Code, multierror.Append(fmt.Errorf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg), allErrors) | ||||||
| 		} | 		} | ||||||
| 		return codedErr.Code, errors.New(codedErr.Msg) | 		return codedErr.Code, errors.New(codedErr.Msg) | ||||||
| 	} | 	} | ||||||
| @@ -110,6 +110,10 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { | |||||||
| 			statusCode = http.StatusBadRequest | 			statusCode = http.StatusBadRequest | ||||||
| 		case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): | 		case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): | ||||||
| 			statusCode = http.StatusBadGateway | 			statusCode = http.StatusBadGateway | ||||||
|  | 		case errwrap.Contains(err, ErrRateLimitQuotaExceeded.Error()): | ||||||
|  | 			statusCode = http.StatusTooManyRequests | ||||||
|  | 		case errwrap.Contains(err, ErrLeaseCountQuotaExceeded.Error()): | ||||||
|  | 			statusCode = http.StatusTooManyRequests | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -339,6 +339,11 @@ func (c *Core) disableCredentialInternal(ctx context.Context, path string, updat | |||||||
|  |  | ||||||
| 	removePathCheckers(c, entry, viewPath) | 	removePathCheckers(c, entry, viewPath) | ||||||
|  |  | ||||||
|  | 	if err := c.quotaManager.HandleBackendDisabling(ctx, ns.Path, path); err != nil { | ||||||
|  | 		c.logger.Error("failed to update quotas after disabling auth", "path", path, "error", err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if c.logger.IsInfo() { | 	if c.logger.IsInfo() { | ||||||
| 		c.logger.Info("disabled credential backend", "path", path) | 		c.logger.Info("disabled credential backend", "path", path) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -43,6 +43,7 @@ import ( | |||||||
| 	sr "github.com/hashicorp/vault/serviceregistration" | 	sr "github.com/hashicorp/vault/serviceregistration" | ||||||
| 	"github.com/hashicorp/vault/shamir" | 	"github.com/hashicorp/vault/shamir" | ||||||
| 	"github.com/hashicorp/vault/vault/cluster" | 	"github.com/hashicorp/vault/vault/cluster" | ||||||
|  | 	"github.com/hashicorp/vault/vault/quotas" | ||||||
| 	vaultseal "github.com/hashicorp/vault/vault/seal" | 	vaultseal "github.com/hashicorp/vault/vault/seal" | ||||||
| 	"github.com/patrickmn/go-cache" | 	"github.com/patrickmn/go-cache" | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
| @@ -97,6 +98,7 @@ var ( | |||||||
| 	enterprisePostUnseal         = enterprisePostUnsealImpl | 	enterprisePostUnseal         = enterprisePostUnsealImpl | ||||||
| 	enterprisePreSeal            = enterprisePreSealImpl | 	enterprisePreSeal            = enterprisePreSealImpl | ||||||
| 	enterpriseSetupFilteredPaths = enterpriseSetupFilteredPathsImpl | 	enterpriseSetupFilteredPaths = enterpriseSetupFilteredPathsImpl | ||||||
|  | 	enterpriseSetupQuotas        = enterpriseSetupQuotasImpl | ||||||
| 	startReplication             = startReplicationImpl | 	startReplication             = startReplicationImpl | ||||||
| 	stopReplication              = stopReplicationImpl | 	stopReplication              = stopReplicationImpl | ||||||
| 	LastWAL                      = lastWALImpl | 	LastWAL                      = lastWALImpl | ||||||
| @@ -520,6 +522,8 @@ type Core struct { | |||||||
| 	// can test an upgrade to a version that includes the fixes from | 	// can test an upgrade to a version that includes the fixes from | ||||||
| 	// https://github.com/hashicorp/vault-enterprise/pull/1103 | 	// https://github.com/hashicorp/vault-enterprise/pull/1103 | ||||||
| 	PR1103disabled bool | 	PR1103disabled bool | ||||||
|  |  | ||||||
|  | 	quotaManager *quotas.Manager | ||||||
| } | } | ||||||
|  |  | ||||||
| // CoreConfig is used to parameterize a core | // CoreConfig is used to parameterize a core | ||||||
| @@ -944,7 +948,9 @@ func NewCore(conf *CoreConfig) (*Core, error) { | |||||||
|  |  | ||||||
| 	c.clusterListener.Store((*cluster.Listener)(nil)) | 	c.clusterListener.Store((*cluster.Listener)(nil)) | ||||||
|  |  | ||||||
| 	err = c.adjustForSealMigration(conf.UnwrapSeal) | 	quotasLogger := conf.Logger.Named("quotas") | ||||||
|  | 	c.allLoggers = append(c.allLoggers, quotasLogger) | ||||||
|  | 	c.quotaManager, err = quotas.NewManager(quotasLogger, c.quotaLeaseWalker, c.metricSink) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -1822,7 +1828,10 @@ func (c *Core) sealInternalWithOptions(grabStateLock, keepHALock, performCleanup | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	postSealInternal(c) | 	if err := postSealInternal(c); err != nil { | ||||||
|  | 		c.logger.Error("post seal error", "error", err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	c.logger.Info("vault is sealed") | 	c.logger.Info("vault is sealed") | ||||||
|  |  | ||||||
| @@ -1892,6 +1901,9 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c | |||||||
| 	if err := c.setupCredentials(ctx); err != nil { | 	if err := c.setupCredentials(ctx); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	if err := c.setupQuotas(ctx, false); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
| 	if !c.IsDRSecondary() { | 	if !c.IsDRSecondary() { | ||||||
| 		if err := c.startRollback(); err != nil { | 		if err := c.startRollback(); err != nil { | ||||||
| 			return err | 			return err | ||||||
| @@ -2078,6 +2090,10 @@ func enterpriseSetupFilteredPathsImpl(c *Core) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func enterpriseSetupQuotasImpl(ctx context.Context, c *Core) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func startReplicationImpl(c *Core) error { | func startReplicationImpl(c *Core) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -2474,3 +2490,29 @@ func (c *Core) readFeatureFlags(ctx context.Context) (*FeatureFlags, error) { | |||||||
| 	} | 	} | ||||||
| 	return &flags, nil | 	return &flags, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // MatchingMount returns the path of the mount that will be responsible for | ||||||
|  | // handling the given request path. | ||||||
|  | func (c *Core) MatchingMount(ctx context.Context, reqPath string) string { | ||||||
|  | 	return c.router.MatchingMount(ctx, reqPath) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Core) setupQuotas(ctx context.Context, isPerfStandby bool) error { | ||||||
|  | 	if c.quotaManager == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return c.quotaManager.Setup(ctx, c.systemBarrierView, isPerfStandby) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ApplyRateLimitQuota checks the request against all the applicable quota rules | ||||||
|  | func (c *Core) ApplyRateLimitQuota(req *quotas.Request) (quotas.Response, error) { | ||||||
|  | 	req.Type = quotas.TypeRateLimit | ||||||
|  | 	return c.quotaManager.ApplyQuota(req) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // RateLimitAuditLoggingEnabled returns if the quota configuration allows audit | ||||||
|  | // logging of request rejections due to rate limiting quota rule violations. | ||||||
|  | func (c *Core) RateLimitAuditLoggingEnabled() bool { | ||||||
|  | 	return c.quotaManager.RateLimitAuditLoggingEnabled() | ||||||
|  | } | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/sdk/helper/license" | 	"github.com/hashicorp/vault/sdk/helper/license" | ||||||
| 	"github.com/hashicorp/vault/sdk/logical" | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
| 	"github.com/hashicorp/vault/sdk/physical" | 	"github.com/hashicorp/vault/sdk/physical" | ||||||
|  | 	"github.com/hashicorp/vault/vault/quotas" | ||||||
| 	"github.com/hashicorp/vault/vault/replication" | 	"github.com/hashicorp/vault/vault/replication" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -58,7 +59,7 @@ func addExtraCredentialBackends(*Core, map[string]logical.Factory) {} | |||||||
|  |  | ||||||
| func preUnsealInternal(context.Context, *Core) error { return nil } | func preUnsealInternal(context.Context, *Core) error { return nil } | ||||||
|  |  | ||||||
| func postSealInternal(*Core) {} | func postSealInternal(*Core) error { return nil } | ||||||
|  |  | ||||||
| func preSealPhysical(c *Core) { | func preSealPhysical(c *Core) { | ||||||
| 	switch c.sealUnwrapper.(type) { | 	switch c.sealUnwrapper.(type) { | ||||||
| @@ -132,3 +133,23 @@ func (c *Core) perfStandbyClusterHandler() (*replication.Cluster, chan struct{}, | |||||||
| func (c *Core) initSealsForMigration() {} | func (c *Core) initSealsForMigration() {} | ||||||
|  |  | ||||||
| func (c *Core) postSealMigration(ctx context.Context) error { return nil } | func (c *Core) postSealMigration(ctx context.Context) error { return nil } | ||||||
|  |  | ||||||
|  | func (c *Core) applyLeaseCountQuota(in *quotas.Request) (*quotas.Response, error) { | ||||||
|  | 	return "as.Response{Allowed: true}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Core) ackLeaseQuota(access quotas.Access, leaseGenerated bool) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Core) quotaLeaseWalker(ctx context.Context, callback func(request *quotas.Request) bool) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leaseIDs []string) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Core) namespaceByPath(path string) *namespace.Namespace { | ||||||
|  | 	return namespace.RootNamespace | ||||||
|  | } | ||||||
|   | |||||||
| @@ -12,17 +12,18 @@ import ( | |||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	metrics "github.com/armon/go-metrics" |  | ||||||
| 	"github.com/hashicorp/errwrap" |  | ||||||
| 	log "github.com/hashicorp/go-hclog" |  | ||||||
| 	multierror "github.com/hashicorp/go-multierror" |  | ||||||
| 	"github.com/hashicorp/vault/helper/namespace" |  | ||||||
| 	"github.com/hashicorp/vault/sdk/framework" | 	"github.com/hashicorp/vault/sdk/framework" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/base62" | 	"github.com/hashicorp/vault/sdk/helper/base62" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/consts" | 	"github.com/hashicorp/vault/sdk/helper/consts" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/jsonutil" | 	"github.com/hashicorp/vault/sdk/helper/jsonutil" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/locksutil" | 	"github.com/hashicorp/vault/sdk/helper/locksutil" | ||||||
| 	"github.com/hashicorp/vault/sdk/logical" | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
|  | 	"github.com/hashicorp/vault/vault/quotas" | ||||||
|  | 	metrics "github.com/armon/go-metrics" | ||||||
|  | 	"github.com/hashicorp/errwrap" | ||||||
|  | 	log "github.com/hashicorp/go-hclog" | ||||||
|  | 	multierror "github.com/hashicorp/go-multierror" | ||||||
|  | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
| 	uberAtomic "go.uber.org/atomic" | 	uberAtomic "go.uber.org/atomic" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -256,23 +257,60 @@ func (m *ExpirationManager) inRestoreMode() bool { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (m *ExpirationManager) invalidate(key string) { | func (m *ExpirationManager) invalidate(key string) { | ||||||
|  |  | ||||||
| 	switch { | 	switch { | ||||||
| 	case strings.HasPrefix(key, leaseViewPrefix): | 	case strings.HasPrefix(key, leaseViewPrefix): | ||||||
| 		// Clear from the pending expiration |  | ||||||
| 		leaseID := strings.TrimPrefix(key, leaseViewPrefix) | 		leaseID := strings.TrimPrefix(key, leaseViewPrefix) | ||||||
|  | 		ctx := m.quitContext | ||||||
|  | 		_, nsID := namespace.SplitIDFromString(leaseID) | ||||||
|  | 		leaseNS := namespace.RootNamespace | ||||||
|  | 		var err error | ||||||
|  | 		if nsID != "" { | ||||||
|  | 			leaseNS, err = NamespaceByID(ctx, nsID, m.core) | ||||||
|  | 			if err != nil { | ||||||
|  | 				m.logger.Error("failed to invalidate lease entry", "error", err) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		le, err := m.loadEntryInternal(namespace.ContextWithNamespace(ctx, leaseNS), leaseID, false, false) | ||||||
|  | 		if err != nil { | ||||||
|  | 			m.logger.Error("failed to invalidate lease entry", "error", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		m.pendingLock.Lock() | 		m.pendingLock.Lock() | ||||||
| 		if info, ok := m.pending.Load(leaseID); ok { | 		defer m.pendingLock.Unlock() | ||||||
|  | 		info, ok := m.pending.Load(leaseID) | ||||||
|  | 		switch { | ||||||
|  | 		case ok: | ||||||
|  | 			switch { | ||||||
|  | 			case le == nil: | ||||||
|  | 				// Handle lease deletion | ||||||
| 				pending := info.(pendingInfo) | 				pending := info.(pendingInfo) | ||||||
| 				pending.timer.Stop() | 				pending.timer.Stop() | ||||||
| 				m.pending.Delete(leaseID) | 				m.pending.Delete(leaseID) | ||||||
| 				m.leaseCount-- | 				m.leaseCount-- | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 				// If in the nonexpiring map, remove there. | 				// If in the nonexpiring map, remove there. | ||||||
| 				m.nonexpiring.Delete(leaseID) | 				m.nonexpiring.Delete(leaseID) | ||||||
|  |  | ||||||
| 		m.pendingLock.Unlock() | 				if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { | ||||||
|  | 					m.logger.Error("failed to handle lease delete invalidation", "error", err) | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 			default: | ||||||
|  | 				// Handle lease update | ||||||
|  | 				m.updatePendingInternal(le, le.ExpireTime.Sub(time.Now())) | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			// There is no entry in the pending map and the invalidation | ||||||
|  | 			// resulted in a nil entry. This should ideally never happen. | ||||||
|  | 			if le == nil { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			// Handle lease creation | ||||||
|  | 			m.updatePendingInternal(le, le.ExpireTime.Sub(time.Now())) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -692,13 +730,18 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, fo | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Clear the expiration handler (or remove from the list of non-expiring tokens.) | 	// Clear the expiration handler | ||||||
| 	m.pendingLock.Lock() | 	m.pendingLock.Lock() | ||||||
| 	if info, ok := m.pending.Load(leaseID); ok { | 	if info, ok := m.pending.Load(leaseID); ok { | ||||||
| 		pending := info.(pendingInfo) | 		pending := info.(pendingInfo) | ||||||
| 		pending.timer.Stop() | 		pending.timer.Stop() | ||||||
| 		m.pending.Delete(leaseID) | 		m.pending.Delete(leaseID) | ||||||
| 		m.leaseCount-- | 		m.leaseCount-- | ||||||
|  | 		if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { | ||||||
|  | 			m.pendingLock.Unlock() | ||||||
|  | 			m.logger.Error("failed to handle lease path deletion", "error", err) | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	m.nonexpiring.Delete(leaseID) | 	m.nonexpiring.Delete(leaseID) | ||||||
| 	m.pendingLock.Unlock() | 	m.pendingLock.Unlock() | ||||||
| @@ -1420,10 +1463,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry, leaseTotal tim | |||||||
| 			info.(pendingInfo).timer.Stop() | 			info.(pendingInfo).timer.Stop() | ||||||
| 			m.pending.Delete(le.LeaseID) | 			m.pending.Delete(le.LeaseID) | ||||||
| 			m.leaseCount-- | 			m.leaseCount-- | ||||||
|  | 			if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []string{le.LeaseID}); err != nil { | ||||||
|  | 				m.logger.Error("failed to handle lease path deletion", "error", err) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	leaseCreated := false | ||||||
| 	// Create entry if it does not exist or reset if it does | 	// Create entry if it does not exist or reset if it does | ||||||
| 	if ok { | 	if ok { | ||||||
| 		pending = info.(pendingInfo) | 		pending = info.(pendingInfo) | ||||||
| @@ -1439,12 +1487,20 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry, leaseTotal tim | |||||||
| 		} | 		} | ||||||
| 		// new lease | 		// new lease | ||||||
| 		m.leaseCount++ | 		m.leaseCount++ | ||||||
|  | 		leaseCreated = true | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Retain some information in-memory | 	// Retain some information in-memory | ||||||
| 	pending.cachedLeaseInfo = m.inMemoryLeaseInfo(le) | 	pending.cachedLeaseInfo = m.inMemoryLeaseInfo(le) | ||||||
|  |  | ||||||
| 	m.pending.Store(le.LeaseID, pending) | 	m.pending.Store(le.LeaseID, pending) | ||||||
|  |  | ||||||
|  | 	if leaseCreated { | ||||||
|  | 		if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []string{le.LeaseID}); err != nil { | ||||||
|  | 			m.logger.Error("failed to handle lease creation", "error", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // revokeEntry is used to attempt revocation of an internal entry | // revokeEntry is used to attempt revocation of an internal entry | ||||||
|   | |||||||
							
								
								
									
										384
									
								
								vault/external_tests/quotas/quotas_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										384
									
								
								vault/external_tests/quotas/quotas_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,384 @@ | |||||||
|  | package quotas | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/api" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/builtin/credential/userpass" | ||||||
|  | 	"github.com/hashicorp/vault/builtin/logical/pki" | ||||||
|  | 	"github.com/hashicorp/vault/helper/testhelpers/teststorage" | ||||||
|  | 	"github.com/hashicorp/vault/vault" | ||||||
|  | 	"go.uber.org/atomic" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	testLookupOnlyPolicy = ` | ||||||
|  | path "/auth/token/lookup" { | ||||||
|  | 	capabilities = [ "create", "update"] | ||||||
|  | } | ||||||
|  | ` | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	coreConfig = &vault.CoreConfig{ | ||||||
|  | 		LogicalBackends: map[string]logical.Factory{ | ||||||
|  | 			"pki": pki.Factory, | ||||||
|  | 		}, | ||||||
|  | 		CredentialBackends: map[string]logical.Factory{ | ||||||
|  | 			"userpass": userpass.Factory, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func setupMounts(t *testing.T, client *api.Client) { | ||||||
|  | 	t.Helper() | ||||||
|  |  | ||||||
|  | 	err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ | ||||||
|  | 		Type: "userpass", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{ | ||||||
|  | 		"password": "bar", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = client.Sys().Mount("pki", &api.MountInput{ | ||||||
|  | 		Type: "pki", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{ | ||||||
|  | 		"common_name": "testvault.com", | ||||||
|  | 		"ttl":         "200h", | ||||||
|  | 		"ip_sans":     "127.0.0.1", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{ | ||||||
|  | 		"require_cn":       false, | ||||||
|  | 		"allowed_domains":  "testvault.com", | ||||||
|  | 		"allow_subdomains": true, | ||||||
|  | 		"max_ttl":          "2h", | ||||||
|  | 		"generate_lease":   true, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func teardownMounts(t *testing.T, client *api.Client) { | ||||||
|  | 	t.Helper() | ||||||
|  | 	if err := client.Sys().Unmount("pki"); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if err := client.Sys().DisableAuth("userpass"); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func testRPS(reqFunc func(numSuccess, numFail *atomic.Int32), d time.Duration) (int32, int32, time.Duration) { | ||||||
|  | 	numSuccess := atomic.NewInt32(0) | ||||||
|  | 	numFail := atomic.NewInt32(0) | ||||||
|  |  | ||||||
|  | 	start := time.Now() | ||||||
|  | 	end := start.Add(d) | ||||||
|  | 	for time.Now().Before(end) { | ||||||
|  | 		reqFunc(numSuccess, numFail) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return numSuccess.Load(), numFail.Load(), time.Since(start) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func waitForRemovalOrTimeout(c *api.Client, path string, tick, to time.Duration) error { | ||||||
|  | 	ticker := time.Tick(tick) | ||||||
|  | 	timeout := time.After(to) | ||||||
|  |  | ||||||
|  | 	// wait for the resource to be removed | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case <-timeout: | ||||||
|  | 			return fmt.Errorf("timeout exceeding waiting for resource to be deleted: %s", path) | ||||||
|  |  | ||||||
|  | 		case <-ticker: | ||||||
|  | 			resp, err := c.Logical().Read(path) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if resp == nil { | ||||||
|  | 				return nil | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestQuotas_RateLimitQuota_Mount(t *testing.T) { | ||||||
|  | 	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil) | ||||||
|  | 	cluster := vault.NewTestCluster(t, conf, opts) | ||||||
|  | 	cluster.Start() | ||||||
|  | 	defer cluster.Cleanup() | ||||||
|  |  | ||||||
|  | 	core := cluster.Cores[0].Core | ||||||
|  | 	client := cluster.Cores[0].Client | ||||||
|  | 	vault.TestWaitActive(t, core) | ||||||
|  |  | ||||||
|  | 	err := client.Sys().Mount("pki", &api.MountInput{ | ||||||
|  | 		Type: "pki", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{ | ||||||
|  | 		"common_name": "testvault.com", | ||||||
|  | 		"ttl":         "200h", | ||||||
|  | 		"ip_sans":     "127.0.0.1", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{ | ||||||
|  | 		"require_cn":       false, | ||||||
|  | 		"allowed_domains":  "testvault.com", | ||||||
|  | 		"allow_subdomains": true, | ||||||
|  | 		"max_ttl":          "2h", | ||||||
|  | 		"generate_lease":   true, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	reqFunc := func(numSuccess, numFail *atomic.Int32) { | ||||||
|  | 		_, err := client.Logical().Read("pki/cert/ca_chain") | ||||||
|  |  | ||||||
|  | 		if err != nil { | ||||||
|  | 			numFail.Add(1) | ||||||
|  | 		} else { | ||||||
|  | 			numSuccess.Add(1) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create a rate limit quota with a low RPS of 7.7, which means we can process | ||||||
|  | 	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed | ||||||
|  | 	// by a refill rate of 7.7 per-second. | ||||||
|  | 	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{ | ||||||
|  | 		"rate":  7.7, | ||||||
|  | 		"burst": 8, | ||||||
|  | 		"path":  "pki/", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second) | ||||||
|  |  | ||||||
|  | 	// evaluate the ideal RPS as (burst + (RPS * totalSeconds)) | ||||||
|  | 	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second)) | ||||||
|  |  | ||||||
|  | 	// ensure there were some failed requests | ||||||
|  | 	if numFail == 0 { | ||||||
|  | 		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ensure that we should never get more requests than allowed | ||||||
|  | 	if want := int32(ideal + 1); numSuccess > want { | ||||||
|  | 		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// update the rate limit quota with a high RPS such that no requests should fail | ||||||
|  | 	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{ | ||||||
|  | 		"rate":  1000.0, | ||||||
|  | 		"burst": 3000, | ||||||
|  | 		"path":  "pki/", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, numFail, _ = testRPS(reqFunc, 5*time.Second) | ||||||
|  | 	if numFail > 0 { | ||||||
|  | 		t.Fatalf("unexpected number of failed requests: %d", numFail) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestQuotas_RateLimitQuota_MountPrecedence(t *testing.T) { | ||||||
|  | 	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil) | ||||||
|  | 	cluster := vault.NewTestCluster(t, conf, opts) | ||||||
|  | 	cluster.Start() | ||||||
|  | 	defer cluster.Cleanup() | ||||||
|  |  | ||||||
|  | 	core := cluster.Cores[0].Core | ||||||
|  | 	client := cluster.Cores[0].Client | ||||||
|  |  | ||||||
|  | 	vault.TestWaitActive(t, core) | ||||||
|  |  | ||||||
|  | 	// create PKI mount | ||||||
|  | 	err := client.Sys().Mount("pki", &api.MountInput{ | ||||||
|  | 		Type: "pki", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = client.Logical().Write("pki/root/generate/internal", map[string]interface{}{ | ||||||
|  | 		"common_name": "testvault.com", | ||||||
|  | 		"ttl":         "200h", | ||||||
|  | 		"ip_sans":     "127.0.0.1", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = client.Logical().Write("pki/roles/test", map[string]interface{}{ | ||||||
|  | 		"require_cn":       false, | ||||||
|  | 		"allowed_domains":  "testvault.com", | ||||||
|  | 		"allow_subdomains": true, | ||||||
|  | 		"max_ttl":          "2h", | ||||||
|  | 		"generate_lease":   true, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// create a root rate limit quota | ||||||
|  | 	_, err = client.Logical().Write("sys/quotas/rate-limit/root-rlq", map[string]interface{}{ | ||||||
|  | 		"name":  "root-rlq", | ||||||
|  | 		"rate":  14.7, | ||||||
|  | 		"burst": 15, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// create a mount rate limit quota with a lower RPS than the root rate limit quota | ||||||
|  | 	_, err = client.Logical().Write("sys/quotas/rate-limit/mount-rlq", map[string]interface{}{ | ||||||
|  | 		"name":  "mount-rlq", | ||||||
|  | 		"rate":  7.7, | ||||||
|  | 		"burst": 8, | ||||||
|  | 		"path":  "pki/", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ensure mount rate limit quota takes precedence over root rate limit quota | ||||||
|  | 	reqFunc := func(numSuccess, numFail *atomic.Int32) { | ||||||
|  | 		_, err := client.Logical().Read("pki/cert/ca_chain") | ||||||
|  |  | ||||||
|  | 		if err != nil { | ||||||
|  | 			numFail.Add(1) | ||||||
|  | 		} else { | ||||||
|  | 			numSuccess.Add(1) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ensure mount rate limit quota takes precedence over root rate limit quota | ||||||
|  | 	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second) | ||||||
|  |  | ||||||
|  | 	// evaluate the ideal RPS as (burst + (RPS * totalSeconds)) | ||||||
|  | 	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second)) | ||||||
|  |  | ||||||
|  | 	// ensure there were some failed requests | ||||||
|  | 	if numFail == 0 { | ||||||
|  | 		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ensure that we should never get more requests than allowed | ||||||
|  | 	if want := int32(ideal + 1); numSuccess > want { | ||||||
|  | 		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestQuotas_RateLimitQuota(t *testing.T) { | ||||||
|  | 	conf, opts := teststorage.ClusterSetup(coreConfig, nil, nil) | ||||||
|  | 	cluster := vault.NewTestCluster(t, conf, opts) | ||||||
|  | 	cluster.Start() | ||||||
|  | 	defer cluster.Cleanup() | ||||||
|  |  | ||||||
|  | 	core := cluster.Cores[0].Core | ||||||
|  | 	client := cluster.Cores[0].Client | ||||||
|  |  | ||||||
|  | 	vault.TestWaitActive(t, core) | ||||||
|  |  | ||||||
|  | 	err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ | ||||||
|  | 		Type: "userpass", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = client.Logical().Write("auth/userpass/users/foo", map[string]interface{}{ | ||||||
|  | 		"password": "bar", | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create a rate limit quota with a low RPS of 7.7, which means we can process | ||||||
|  | 	// ⌈7.7⌉*2 requests in the span of roughly a second -- 8 initially, followed | ||||||
|  | 	// by a refill rate of 7.7 per-second. | ||||||
|  | 	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{ | ||||||
|  | 		"rate":  7.7, | ||||||
|  | 		"burst": 8, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	reqFunc := func(numSuccess, numFail *atomic.Int32) { | ||||||
|  | 		_, err := client.Logical().Read("sys/quotas/rate-limit/rlq") | ||||||
|  |  | ||||||
|  | 		if err != nil { | ||||||
|  | 			numFail.Add(1) | ||||||
|  | 		} else { | ||||||
|  | 			numSuccess.Add(1) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	numSuccess, numFail, elapsed := testRPS(reqFunc, 5*time.Second) | ||||||
|  |  | ||||||
|  | 	// evaluate the ideal RPS as (burst + (RPS * totalSeconds)) | ||||||
|  | 	ideal := 8 + (7.7 * float64(elapsed) / float64(time.Second)) | ||||||
|  |  | ||||||
|  | 	// ensure there were some failed requests | ||||||
|  | 	if numFail == 0 { | ||||||
|  | 		t.Fatalf("expected some requests to fail; numSuccess: %d, numFail: %d, elapsed: %d", numSuccess, numFail, elapsed) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ensure that we should never get more requests than allowed | ||||||
|  | 	if want := int32(ideal + 1); numSuccess > want { | ||||||
|  | 		t.Fatalf("too many successful requests; want: %d, numSuccess: %d, numFail: %d, elapsed: %d", want, numSuccess, numFail, elapsed) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// allow time (1s) for rate limit to refill before updating the quota | ||||||
|  | 	time.Sleep(time.Second) | ||||||
|  |  | ||||||
|  | 	// update the rate limit quota with a high RPS such that no requests should fail | ||||||
|  | 	_, err = client.Logical().Write("sys/quotas/rate-limit/rlq", map[string]interface{}{ | ||||||
|  | 		"rate":  1000.0, | ||||||
|  | 		"burst": 3000, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, numFail, _ = testRPS(reqFunc, 5*time.Second) | ||||||
|  | 	if numFail > 0 { | ||||||
|  | 		t.Fatalf("unexpected number of failed requests: %d", numFail) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -160,6 +160,7 @@ func NewSystemBackend(core *Core, logger log.Logger) *SystemBackend { | |||||||
| 	b.Backend.Paths = append(b.Backend.Paths, b.metricsPath()) | 	b.Backend.Paths = append(b.Backend.Paths, b.metricsPath()) | ||||||
| 	b.Backend.Paths = append(b.Backend.Paths, b.monitorPath()) | 	b.Backend.Paths = append(b.Backend.Paths, b.monitorPath()) | ||||||
| 	b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath()) | 	b.Backend.Paths = append(b.Backend.Paths, b.hostInfoPath()) | ||||||
|  | 	b.Backend.Paths = append(b.Backend.Paths, b.quotasPaths()...) | ||||||
|  |  | ||||||
| 	if core.rawEnabled { | 	if core.rawEnabled { | ||||||
| 		b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) | 		b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) | ||||||
| @@ -751,7 +752,7 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d | |||||||
|  |  | ||||||
| 	// Get all the options | 	// Get all the options | ||||||
| 	path := data.Get("path").(string) | 	path := data.Get("path").(string) | ||||||
| 	path = sanitizeMountPath(path) | 	path = sanitizePath(path) | ||||||
|  |  | ||||||
| 	logicalType := data.Get("type").(string) | 	logicalType := data.Get("type").(string) | ||||||
| 	description := data.Get("description").(string) | 	description := data.Get("description").(string) | ||||||
| @@ -934,7 +935,7 @@ func handleErrorNoReadOnlyForward( | |||||||
| // handleUnmount is used to unmount a path | // handleUnmount is used to unmount a path | ||||||
| func (b *SystemBackend) handleUnmount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | func (b *SystemBackend) handleUnmount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
| 	path := data.Get("path").(string) | 	path := data.Get("path").(string) | ||||||
| 	path = sanitizeMountPath(path) | 	path = sanitizePath(path) | ||||||
|  |  | ||||||
| 	ns, err := namespace.FromContext(ctx) | 	ns, err := namespace.FromContext(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -1029,6 +1030,12 @@ func (b *SystemBackend) handleRemount(ctx context.Context, req *logical.Request, | |||||||
| 		return handleError(err) | 		return handleError(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Update quotas with the new path | ||||||
|  | 	if err := b.Core.quotaManager.HandleRemount(ctx, ns.Path, sanitizePath(fromPath), sanitizePath(toPath)); err != nil { | ||||||
|  | 		b.Core.logger.Error("failed to update quotas after remount", "ns_path", ns.Path, "from_path", fromPath, "to_path", toPath, "error", err) | ||||||
|  | 		return handleError(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -1060,7 +1067,7 @@ func (b *SystemBackend) handleMountTuneRead(ctx context.Context, req *logical.Re | |||||||
|  |  | ||||||
| // handleTuneReadCommon returns the config settings of a path | // handleTuneReadCommon returns the config settings of a path | ||||||
| func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) (*logical.Response, error) { | func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) (*logical.Response, error) { | ||||||
| 	path = sanitizeMountPath(path) | 	path = sanitizePath(path) | ||||||
|  |  | ||||||
| 	sysView := b.Core.router.MatchingSystemView(ctx, path) | 	sysView := b.Core.router.MatchingSystemView(ctx, path) | ||||||
| 	if sysView == nil { | 	if sysView == nil { | ||||||
| @@ -1146,7 +1153,7 @@ func (b *SystemBackend) handleMountTuneWrite(ctx context.Context, req *logical.R | |||||||
| func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, data *framework.FieldData) (*logical.Response, error) { | func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, data *framework.FieldData) (*logical.Response, error) { | ||||||
| 	repState := b.Core.ReplicationState() | 	repState := b.Core.ReplicationState() | ||||||
|  |  | ||||||
| 	path = sanitizeMountPath(path) | 	path = sanitizePath(path) | ||||||
|  |  | ||||||
| 	// Prevent protected paths from being changed | 	// Prevent protected paths from being changed | ||||||
| 	for _, p := range untunableMounts { | 	for _, p := range untunableMounts { | ||||||
| @@ -1716,7 +1723,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque | |||||||
|  |  | ||||||
| 	// Get all the options | 	// Get all the options | ||||||
| 	path := data.Get("path").(string) | 	path := data.Get("path").(string) | ||||||
| 	path = sanitizeMountPath(path) | 	path = sanitizePath(path) | ||||||
| 	logicalType := data.Get("type").(string) | 	logicalType := data.Get("type").(string) | ||||||
| 	description := data.Get("description").(string) | 	description := data.Get("description").(string) | ||||||
| 	pluginName := data.Get("plugin_name").(string) | 	pluginName := data.Get("plugin_name").(string) | ||||||
| @@ -1857,7 +1864,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque | |||||||
| // handleDisableAuth is used to disable a credential backend | // handleDisableAuth is used to disable a credential backend | ||||||
| func (b *SystemBackend) handleDisableAuth(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | func (b *SystemBackend) handleDisableAuth(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { | ||||||
| 	path := data.Get("path").(string) | 	path := data.Get("path").(string) | ||||||
| 	path = sanitizeMountPath(path) | 	path = sanitizePath(path) | ||||||
|  |  | ||||||
| 	ns, err := namespace.FromContext(ctx) | 	ns, err := namespace.FromContext(ctx) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -2272,7 +2279,7 @@ func (b *SystemBackend) handleAuditHash(ctx context.Context, req *logical.Reques | |||||||
| 		return logical.ErrorResponse("the \"input\" parameter is empty"), nil | 		return logical.ErrorResponse("the \"input\" parameter is empty"), nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	path = sanitizeMountPath(path) | 	path = sanitizePath(path) | ||||||
|  |  | ||||||
| 	hash, err := b.Core.auditBroker.GetHash(ctx, path, input) | 	hash, err := b.Core.auditBroker.GetHash(ctx, path, input) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -3258,7 +3265,7 @@ func (b *SystemBackend) pathInternalUIMountRead(ctx context.Context, req *logica | |||||||
| 	if path == "" { | 	if path == "" { | ||||||
| 		return logical.ErrorResponse("path not set"), logical.ErrInvalidRequest | 		return logical.ErrorResponse("path not set"), logical.ErrInvalidRequest | ||||||
| 	} | 	} | ||||||
| 	path = sanitizeMountPath(path) | 	path = sanitizePath(path) | ||||||
|  |  | ||||||
| 	errResp := logical.ErrorResponse(fmt.Sprintf("preflight capability check returned 403, please ensure client's policies grant access to path %q", path)) | 	errResp := logical.ErrorResponse(fmt.Sprintf("preflight capability check returned 403, please ensure client's policies grant access to path %q", path)) | ||||||
|  |  | ||||||
| @@ -3576,7 +3583,7 @@ func (b *SystemBackend) pathInternalOpenAPI(ctx context.Context, req *logical.Re | |||||||
| 	return resp, nil | 	return resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func sanitizeMountPath(path string) string { | func sanitizePath(path string) string { | ||||||
| 	if !strings.HasSuffix(path, "/") { | 	if !strings.HasSuffix(path, "/") { | ||||||
| 		path += "/" | 		path += "/" | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										272
									
								
								vault/logical_system_quotas.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										272
									
								
								vault/logical_system_quotas.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,272 @@ | |||||||
|  | package vault | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/helper/namespace" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/framework" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
|  | 	"github.com/hashicorp/vault/vault/quotas" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // quotasPaths returns paths that enable quota management | ||||||
|  | func (b *SystemBackend) quotasPaths() []*framework.Path { | ||||||
|  | 	return []*framework.Path{ | ||||||
|  | 		{ | ||||||
|  | 			Pattern: "quotas/config$", | ||||||
|  | 			Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 				"enable_rate_limit_audit_logging": { | ||||||
|  | 					Type:        framework.TypeBool, | ||||||
|  | 					Description: "If set, starts audit logging of requests that get rejected due to rate limit quota rule violations.", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			Operations: map[logical.Operation]framework.OperationHandler{ | ||||||
|  | 				logical.UpdateOperation: &framework.PathOperation{ | ||||||
|  | 					Callback: b.handleQuotasConfigUpdate(), | ||||||
|  | 				}, | ||||||
|  | 				logical.ReadOperation: &framework.PathOperation{ | ||||||
|  | 					Callback: b.handleQuotasConfigRead(), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			HelpSynopsis:    strings.TrimSpace(quotasHelp["quotas-config"][0]), | ||||||
|  | 			HelpDescription: strings.TrimSpace(quotasHelp["quotas-config"][1]), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Pattern: "quotas/rate-limit/?$", | ||||||
|  | 			Operations: map[logical.Operation]framework.OperationHandler{ | ||||||
|  | 				logical.ListOperation: &framework.PathOperation{ | ||||||
|  | 					Callback: b.handleRateLimitQuotasList(), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			HelpSynopsis:    strings.TrimSpace(quotasHelp["rate-limit-list"][0]), | ||||||
|  | 			HelpDescription: strings.TrimSpace(quotasHelp["rate-limit-list"][1]), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Pattern: "quotas/rate-limit/" + framework.GenericNameRegex("name"), | ||||||
|  | 			Fields: map[string]*framework.FieldSchema{ | ||||||
|  | 				"type": { | ||||||
|  | 					Type:        framework.TypeString, | ||||||
|  | 					Description: "Type of the quota rule.", | ||||||
|  | 				}, | ||||||
|  | 				"name": { | ||||||
|  | 					Type:        framework.TypeString, | ||||||
|  | 					Description: "Name of the quota rule.", | ||||||
|  | 				}, | ||||||
|  | 				"path": { | ||||||
|  | 					Type: framework.TypeString, | ||||||
|  | 					Description: `Path of the mount or namespace to apply the quota. A blank path configures a | ||||||
|  | global quota. For example namespace1/ adds a quota to a full namespace, | ||||||
|  | namespace1/auth/userpass adds a quota to userpass in namespace1.`, | ||||||
|  | 				}, | ||||||
|  | 				"rate": { | ||||||
|  | 					Type: framework.TypeFloat, | ||||||
|  | 					Description: `The rate at which allowed requests are refilled per second by the quota rule. | ||||||
|  | Internally, a token-bucket algorithm is used which has a size of 'burst', initially full. The quota | ||||||
|  | limits requests to 'rate' per-second, with a maximum burst size of 'burst'. Each request takes a single | ||||||
|  | token from this bucket. The 'rate' must be positive.`, | ||||||
|  | 				}, | ||||||
|  | 				"burst": { | ||||||
|  | 					Type: framework.TypeInt, | ||||||
|  | 					Description: `The maximum number of requests at any given second to be allowed by the quota | ||||||
|  | rule. There is a one-to-one mapping between requests and tokens in the rate limit quota. A client | ||||||
|  | may perform up to 'burst' requests at once, at which they they may invoke additional requests at | ||||||
|  | 'rate' per-second.`, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			Operations: map[logical.Operation]framework.OperationHandler{ | ||||||
|  | 				logical.UpdateOperation: &framework.PathOperation{ | ||||||
|  | 					Callback: b.handleRateLimitQuotasUpdate(), | ||||||
|  | 				}, | ||||||
|  | 				logical.ReadOperation: &framework.PathOperation{ | ||||||
|  | 					Callback: b.handleRateLimitQuotasRead(), | ||||||
|  | 				}, | ||||||
|  | 				logical.DeleteOperation: &framework.PathOperation{ | ||||||
|  | 					Callback: b.handleRateLimitQuotasDelete(), | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			HelpSynopsis:    strings.TrimSpace(quotasHelp["rate-limit"][0]), | ||||||
|  | 			HelpDescription: strings.TrimSpace(quotasHelp["rate-limit"][1]), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *SystemBackend) handleQuotasConfigUpdate() framework.OperationFunc { | ||||||
|  | 	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 		config, err := quotas.LoadConfig(ctx, b.Core.systemBarrierView) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		config.EnableRateLimitAuditLogging = d.Get("enable_rate_limit_audit_logging").(bool) | ||||||
|  |  | ||||||
|  | 		entry, err := logical.StorageEntryJSON(quotas.ConfigPath, config) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		if err := req.Storage.Put(ctx, entry); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		b.Core.quotaManager.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging) | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *SystemBackend) handleQuotasConfigRead() framework.OperationFunc { | ||||||
|  | 	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 		config := b.Core.quotaManager.Config() | ||||||
|  | 		return &logical.Response{ | ||||||
|  | 			Data: map[string]interface{}{ | ||||||
|  | 				"enable_rate_limit_audit_logging": config.EnableRateLimitAuditLogging, | ||||||
|  | 			}, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *SystemBackend) handleRateLimitQuotasList() framework.OperationFunc { | ||||||
|  | 	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 		names, err := b.Core.quotaManager.QuotaNames(quotas.TypeRateLimit) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return logical.ListResponse(names), nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc { | ||||||
|  | 	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 		name := d.Get("name").(string) | ||||||
|  |  | ||||||
|  | 		qType := quotas.TypeRateLimit.String() | ||||||
|  | 		rate := d.Get("rate").(float64) | ||||||
|  | 		if rate <= 0 { | ||||||
|  | 			return logical.ErrorResponse("'rate' is invalid"), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		burst := d.Get("burst").(int) | ||||||
|  | 		if burst < int(rate) { | ||||||
|  | 			return logical.ErrorResponse("'burst' must be greater than or equal to 'rate' as an integer value"), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		mountPath := sanitizePath(d.Get("path").(string)) | ||||||
|  | 		ns := b.Core.namespaceByPath(mountPath) | ||||||
|  | 		if ns.ID != namespace.RootNamespaceID { | ||||||
|  | 			mountPath = strings.TrimPrefix(mountPath, ns.Path) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if mountPath != "" { | ||||||
|  | 			match := b.Core.router.MatchingMount(namespace.ContextWithNamespace(ctx, ns), mountPath) | ||||||
|  | 			if match == "" { | ||||||
|  | 				return logical.ErrorResponse("invalid mount path %q", mountPath), nil | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Disallow duplicate quotas with same precedence and similar | ||||||
|  | 		// properties. | ||||||
|  | 		quota, err := b.Core.quotaManager.QuotaByFactors(ctx, qType, ns.Path, mountPath) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		if quota != nil && quota.QuotaName() != name { | ||||||
|  | 			return logical.ErrorResponse("quota rule with similar properties exists under the name %q", quota.QuotaName()), nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		switch { | ||||||
|  | 		case quota == nil: | ||||||
|  | 			quota = quotas.NewRateLimitQuota(name, ns.Path, mountPath, rate, burst) | ||||||
|  | 		default: | ||||||
|  | 			rlq := quota.(*quotas.RateLimitQuota) | ||||||
|  | 			rlq.NamespacePath = ns.Path | ||||||
|  | 			rlq.MountPath = mountPath | ||||||
|  | 			rlq.Rate = rate | ||||||
|  | 			rlq.Burst = burst | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		entry, err := logical.StorageEntryJSON(quotas.QuotaStoragePath(qType, name), quota) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if err := req.Storage.Put(ctx, entry); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if err := b.Core.quotaManager.SetQuota(ctx, qType, quota, false); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *SystemBackend) handleRateLimitQuotasRead() framework.OperationFunc { | ||||||
|  | 	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 		name := d.Get("name").(string) | ||||||
|  | 		qType := quotas.TypeRateLimit.String() | ||||||
|  |  | ||||||
|  | 		quota, err := b.Core.quotaManager.QuotaByName(qType, name) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		if quota == nil { | ||||||
|  | 			return nil, nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		rlq := quota.(*quotas.RateLimitQuota) | ||||||
|  |  | ||||||
|  | 		nsPath := rlq.NamespacePath | ||||||
|  | 		if rlq.NamespacePath == "root" { | ||||||
|  | 			nsPath = "" | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		data := map[string]interface{}{ | ||||||
|  | 			"type":  qType, | ||||||
|  | 			"name":  rlq.Name, | ||||||
|  | 			"path":  nsPath + rlq.MountPath, | ||||||
|  | 			"rate":  rlq.Rate, | ||||||
|  | 			"burst": rlq.Burst, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return &logical.Response{ | ||||||
|  | 			Data: data, | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *SystemBackend) handleRateLimitQuotasDelete() framework.OperationFunc { | ||||||
|  | 	return func(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { | ||||||
|  | 		name := d.Get("name").(string) | ||||||
|  | 		qType := quotas.TypeRateLimit.String() | ||||||
|  |  | ||||||
|  | 		if err := req.Storage.Delete(ctx, quotas.QuotaStoragePath(qType, name)); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if err := b.Core.quotaManager.DeleteQuota(ctx, qType, name); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var quotasHelp = map[string][2]string{ | ||||||
|  | 	"quotas-config": { | ||||||
|  | 		"Create, update and read the quota configuration.", | ||||||
|  | 		"", | ||||||
|  | 	}, | ||||||
|  | 	"rate-limit": { | ||||||
|  | 		`Get, create or update rate limit resource quota for an optional namespace or | ||||||
|  | mount.`, | ||||||
|  | 		`A rate limit quota will enforce rate limiting using a token bucket algorithm. A | ||||||
|  | rate limit quota can be created at the root level or defined on a namespace or | ||||||
|  | mount by specifying a 'path'. The rate limiter is applied to each unique client | ||||||
|  | IP address. A client may invoke 'burst' requests at any given second, after | ||||||
|  | which they may invoke additional requests at 'rate' per-second.`, | ||||||
|  | 	}, | ||||||
|  | 	"rate-limit-list": { | ||||||
|  | 		"Lists the names of all the rate limit quotas.", | ||||||
|  | 		"This list contains quota definitions from all the namespaces.", | ||||||
|  | 	}, | ||||||
|  | } | ||||||
| @@ -2654,7 +2654,7 @@ func TestSystemBackend_PathWildcardPreflight(t *testing.T) { | |||||||
| 	// Add another mount | 	// Add another mount | ||||||
| 	me := &MountEntry{ | 	me := &MountEntry{ | ||||||
| 		Table:   mountTableType, | 		Table:   mountTableType, | ||||||
| 		Path:    sanitizeMountPath("kv-v1"), | 		Path:    sanitizePath("kv-v1"), | ||||||
| 		Type:    "kv", | 		Type:    "kv", | ||||||
| 		Options: map[string]string{"version": "1"}, | 		Options: map[string]string{"version": "1"}, | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -664,6 +664,11 @@ func (c *Core) unmountInternal(ctx context.Context, path string, updateStorage b | |||||||
|  |  | ||||||
| 	removePathCheckers(c, entry, viewPath) | 	removePathCheckers(c, entry, viewPath) | ||||||
|  |  | ||||||
|  | 	if err := c.quotaManager.HandleBackendDisabling(ctx, ns.Path, path); err != nil { | ||||||
|  | 		c.logger.Error("failed to update quotas after disabling mount", "path", path, "error", err) | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if c.logger.IsInfo() { | 	if c.logger.IsInfo() { | ||||||
| 		c.logger.Info("successfully unmounted", "path", path, "namespace", ns.Path) | 		c.logger.Info("successfully unmounted", "path", path, "namespace", ns.Path) | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										860
									
								
								vault/quotas/quotas.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										860
									
								
								vault/quotas/quotas.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,860 @@ | |||||||
|  | package quotas | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"path" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  |  | ||||||
|  | 	log "github.com/hashicorp/go-hclog" | ||||||
|  | 	"github.com/hashicorp/go-memdb" | ||||||
|  | 	"github.com/hashicorp/vault/helper/metricsutil" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Type represents the quota kind | ||||||
|  | type Type string | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	// TypeRateLimit represents the rate limiting quota type | ||||||
|  | 	TypeRateLimit Type = "rate-limit" | ||||||
|  |  | ||||||
|  | 	// TypeLeaseCount represents the lease count limiting quota type | ||||||
|  | 	TypeLeaseCount Type = "lease-count" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // LeaseAction is the action taken by the expiration manager on the lease. The | ||||||
|  | // quota manager will use this information to update the lease path cache and | ||||||
|  | // updating counters for relevant quota rules. | ||||||
|  | type LeaseAction uint32 | ||||||
|  |  | ||||||
|  | // String converts each lease action into its string equivalent value | ||||||
|  | func (la LeaseAction) String() string { | ||||||
|  | 	switch la { | ||||||
|  | 	case LeaseActionLoaded: | ||||||
|  | 		return "loaded" | ||||||
|  | 	case LeaseActionCreated: | ||||||
|  | 		return "created" | ||||||
|  | 	case LeaseActionDeleted: | ||||||
|  | 		return "deleted" | ||||||
|  | 	case LeaseActionAllow: | ||||||
|  | 		return "allow" | ||||||
|  | 	} | ||||||
|  | 	return "unknown" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	_ LeaseAction = iota | ||||||
|  |  | ||||||
|  | 	// LeaseActionLoaded indicates loading of lease in the expiration manager after | ||||||
|  | 	// unseal. | ||||||
|  | 	LeaseActionLoaded | ||||||
|  |  | ||||||
|  | 	// LeaseActionCreated indicates that a lease is created in the expiration manager. | ||||||
|  | 	LeaseActionCreated | ||||||
|  |  | ||||||
|  | 	// LeaseActionDeleted indicates that is lease is expired and deleted in the | ||||||
|  | 	// expiration manager. | ||||||
|  | 	LeaseActionDeleted | ||||||
|  |  | ||||||
|  | 	// LeaseActionAllow will be used to indicate the lease count checker that | ||||||
|  | 	// incCounter is called from Allow(). All the rest of the actions indicate the | ||||||
|  | 	// action took place on the lease in the expiration manager. | ||||||
|  | 	LeaseActionAllow | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type leaseWalkFunc func(context.Context, func(request *Request) bool) error | ||||||
|  |  | ||||||
|  | // String converts each quota type into its string equivalent value | ||||||
|  | func (q Type) String() string { | ||||||
|  | 	switch q { | ||||||
|  | 	case TypeLeaseCount: | ||||||
|  | 		return "lease-count" | ||||||
|  | 	case TypeRateLimit: | ||||||
|  | 		return "rate-limit" | ||||||
|  | 	} | ||||||
|  | 	return "unknown" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	indexID             = "id" | ||||||
|  | 	indexName           = "name" | ||||||
|  | 	indexNamespace      = "ns" | ||||||
|  | 	indexNamespaceMount = "ns_mount" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	// StoragePrefix is the prefix for the physical location where quota rules are | ||||||
|  | 	// persisted. | ||||||
|  | 	StoragePrefix = "quotas/" | ||||||
|  |  | ||||||
|  | 	// ConfigPath is the physical location where the quota configuration is | ||||||
|  | 	// persisted. | ||||||
|  | 	ConfigPath = StoragePrefix + "config" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease | ||||||
|  | 	// count quota being exceeded. | ||||||
|  | 	ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded") | ||||||
|  |  | ||||||
|  | 	// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a | ||||||
|  | 	// rate limit quota being exceeded. | ||||||
|  | 	ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded") | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Access provides information to reach back to the quota checker. | ||||||
|  | type Access interface { | ||||||
|  | 	// QuotaID is the identifier of the quota that issued this access. | ||||||
|  | 	QuotaID() string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Ensure that access implements the Access interface. | ||||||
|  | var _ Access = (*access)(nil) | ||||||
|  |  | ||||||
|  | // access implements the Access interface | ||||||
|  | type access struct { | ||||||
|  | 	quotaID string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QuotaID returns the identifier of the quota rule to which this access refers | ||||||
|  | // to. | ||||||
|  | func (a *access) QuotaID() string { | ||||||
|  | 	return a.quotaID | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Manager holds all the existing quota rules. For any given input. the manager | ||||||
|  | // checks them against any applicable quota rules. | ||||||
|  | type Manager struct { | ||||||
|  | 	entManager | ||||||
|  |  | ||||||
|  | 	// db holds the in memory instances of all active quota rules indexed by | ||||||
|  | 	// some of the quota properties. | ||||||
|  | 	db *memdb.MemDB | ||||||
|  |  | ||||||
|  | 	// config containing operator preferences and quota behaviors | ||||||
|  | 	config *Config | ||||||
|  |  | ||||||
|  | 	storage logical.Storage | ||||||
|  | 	ctx     context.Context | ||||||
|  |  | ||||||
|  | 	logger     log.Logger | ||||||
|  | 	metricSink *metricsutil.ClusterMetricSink | ||||||
|  | 	lock       *sync.RWMutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Quota represents the common properties of every quota type | ||||||
|  | type Quota interface { | ||||||
|  | 	// allow checks the if the request is allowed by the quota type implementation. | ||||||
|  | 	allow(*Request) (Response, error) | ||||||
|  |  | ||||||
|  | 	// quotaID is the identifier of the quota rule | ||||||
|  | 	quotaID() string | ||||||
|  |  | ||||||
|  | 	// QuotaName is the name of the quota rule | ||||||
|  | 	QuotaName() string | ||||||
|  |  | ||||||
|  | 	// initialize sets up the fields in the quota type to begin operating | ||||||
|  | 	initialize(log.Logger, *metricsutil.ClusterMetricSink) error | ||||||
|  |  | ||||||
|  | 	// close defines any cleanup behavior that needs to be executed when a quota | ||||||
|  | 	// rule is deleted. | ||||||
|  | 	close() error | ||||||
|  |  | ||||||
|  | 	// handleRemount takes in the new mount path in the quota | ||||||
|  | 	handleRemount(string) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Response holds information about the result of the Allow() call. The response | ||||||
|  | // can optionally have the Access field set, which is used to reach back into | ||||||
|  | // the quota rule that sent this response. | ||||||
|  | type Response struct { | ||||||
|  | 	// Allowed is set if the quota allows the request | ||||||
|  | 	Allowed bool | ||||||
|  |  | ||||||
|  | 	// Access is the handle to reach back into the quota rule that processed the | ||||||
|  | 	// quota request. This may not be set all the time. | ||||||
|  | 	Access Access | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Config holds operator preferences around quota behaviors | ||||||
|  | type Config struct { | ||||||
|  | 	// EnableRateLimitAuditLogging, if set, starts audit logging of the | ||||||
|  | 	// request rejections that arise due to rate limit quota violations. | ||||||
|  | 	EnableRateLimitAuditLogging bool `json:"enable_rate_limit_audit_logging"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Request contains information required by the quota manager to query and | ||||||
|  | // apply the quota rules. | ||||||
|  | type Request struct { | ||||||
|  | 	// Type is the quota type | ||||||
|  | 	Type Type | ||||||
|  |  | ||||||
|  | 	// Path is the request path to which quota rules are being queried for | ||||||
|  | 	Path string | ||||||
|  |  | ||||||
|  | 	// NamespacePath is the namespace path to which the request belongs | ||||||
|  | 	NamespacePath string | ||||||
|  |  | ||||||
|  | 	// MountPath is the mount path to which the request is made | ||||||
|  | 	MountPath string | ||||||
|  |  | ||||||
|  | 	// ClientAddress is client unique addressable string (e.g. IP address). It can | ||||||
|  | 	// be empty if the quota type does not need it. | ||||||
|  | 	ClientAddress string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewManager creates and initializes a new quota manager to hold all the quota | ||||||
|  | // rules and to process incoming requests. | ||||||
|  | func NewManager(logger log.Logger, walkFunc leaseWalkFunc, ms *metricsutil.ClusterMetricSink) (*Manager, error) { | ||||||
|  | 	db, err := memdb.NewMemDB(dbSchema()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	manager := &Manager{ | ||||||
|  | 		db:         db, | ||||||
|  | 		logger:     logger, | ||||||
|  | 		metricSink: ms, | ||||||
|  | 		config:     new(Config), | ||||||
|  | 		lock:       new(sync.RWMutex), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	manager.init(walkFunc) | ||||||
|  |  | ||||||
|  | 	return manager, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // SetQuota adds a new quota rule to the db. | ||||||
|  | func (m *Manager) SetQuota(ctx context.Context, qType string, quota Quota, loading bool) error { | ||||||
|  | 	m.lock.Lock() | ||||||
|  | 	defer m.lock.Unlock() | ||||||
|  | 	return m.setQuotaLocked(ctx, qType, quota, loading) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // setQuotaLocked should be called with the manager's lock held | ||||||
|  | func (m *Manager) setQuotaLocked(ctx context.Context, qType string, quota Quota, loading bool) error { | ||||||
|  | 	if qType == TypeLeaseCount.String() { | ||||||
|  | 		m.setIsPerfStandby(quota) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	txn := m.db.Txn(true) | ||||||
|  | 	defer txn.Abort() | ||||||
|  |  | ||||||
|  | 	raw, err := txn.First(qType, "id", quota.quotaID()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If there already exists an entry in the db, remove that first. | ||||||
|  | 	if raw != nil { | ||||||
|  | 		err = txn.Delete(qType, raw) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Initialize the quota type implementation | ||||||
|  | 	if err := quota.initialize(m.logger, m.metricSink); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Add the initialized quota type implementation to the db | ||||||
|  | 	if err := txn.Insert(qType, quota); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if loading { | ||||||
|  | 		txn.Commit() | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// For the lease count type, recompute the counters | ||||||
|  | 	if !loading && qType == TypeLeaseCount.String() { | ||||||
|  | 		if err := m.recomputeLeaseCounts(ctx, txn); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	txn.Commit() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QuotaNames returns the names of all the quota rules for a given type | ||||||
|  | func (m *Manager) QuotaNames(qType Type) ([]string, error) { | ||||||
|  | 	m.lock.RLock() | ||||||
|  | 	defer m.lock.RUnlock() | ||||||
|  |  | ||||||
|  | 	txn := m.db.Txn(false) | ||||||
|  | 	iter, err := txn.Get(qType.String(), indexID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	var names []string | ||||||
|  | 	for raw := iter.Next(); raw != nil; raw = iter.Next() { | ||||||
|  | 		names = append(names, raw.(Quota).QuotaName()) | ||||||
|  | 	} | ||||||
|  | 	return names, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QuotaByID queries for a quota rule in the db for a given quota ID | ||||||
|  | func (m *Manager) QuotaByID(qType string, id string) (Quota, error) { | ||||||
|  | 	m.lock.RLock() | ||||||
|  | 	defer m.lock.RUnlock() | ||||||
|  |  | ||||||
|  | 	txn := m.db.Txn(false) | ||||||
|  |  | ||||||
|  | 	quotaRaw, err := txn.First(qType, indexID, id) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if quotaRaw == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return quotaRaw.(Quota), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QuotaByName queries for a quota rule in the db for a given quota name | ||||||
|  | func (m *Manager) QuotaByName(qType string, name string) (Quota, error) { | ||||||
|  | 	m.lock.RLock() | ||||||
|  | 	defer m.lock.RUnlock() | ||||||
|  |  | ||||||
|  | 	txn := m.db.Txn(false) | ||||||
|  |  | ||||||
|  | 	quotaRaw, err := txn.First(qType, indexName, name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if quotaRaw == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return quotaRaw.(Quota), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QuotaByFactors returns the quota rule that matches the provided factors | ||||||
|  | func (m *Manager) QuotaByFactors(ctx context.Context, qType, nsPath, mountPath string) (Quota, error) { | ||||||
|  | 	m.lock.RLock() | ||||||
|  | 	defer m.lock.RUnlock() | ||||||
|  |  | ||||||
|  | 	// nsPath would have been made non-empty during insertion. Use non-empty value | ||||||
|  | 	// during query as well. | ||||||
|  | 	if nsPath == "" { | ||||||
|  | 		nsPath = "root" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	idx := indexNamespace | ||||||
|  | 	args := []interface{}{nsPath, false} | ||||||
|  | 	if mountPath != "" { | ||||||
|  | 		idx = indexNamespaceMount | ||||||
|  | 		args = []interface{}{nsPath, mountPath} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	txn := m.db.Txn(false) | ||||||
|  | 	iter, err := txn.Get(qType, idx, args...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	var quotas []Quota | ||||||
|  | 	for raw := iter.Next(); raw != nil; raw = iter.Next() { | ||||||
|  | 		quotas = append(quotas, raw.(Quota)) | ||||||
|  | 	} | ||||||
|  | 	if len(quotas) > 1 { | ||||||
|  | 		return nil, fmt.Errorf("conflicting quota definitions detected") | ||||||
|  | 	} | ||||||
|  | 	if len(quotas) == 0 { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return quotas[0], nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // queryQuota returns the quota rule that is applicable for the given request. It | ||||||
|  | // queries all the quota rules that are defined against request values and finds | ||||||
|  | // the quota rule that takes priority. | ||||||
|  | // | ||||||
|  | // Priority rules are as follows: | ||||||
|  | // - namespace specific quota takes precedence over global quota | ||||||
|  | // - mount specific quota takes precedence over namespace specific quota | ||||||
|  | func (m *Manager) queryQuota(txn *memdb.Txn, req *Request) (Quota, error) { | ||||||
|  | 	if txn == nil { | ||||||
|  | 		txn = m.db.Txn(false) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// ns would have been made non-empty during insertion. Use non-empty | ||||||
|  | 	// value during query as well. | ||||||
|  | 	if req.NamespacePath == "" { | ||||||
|  | 		req.NamespacePath = "root" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// | ||||||
|  | 	// Find a match from most specific applicable quota rule to less specific one. | ||||||
|  | 	// | ||||||
|  | 	quotaFetchFunc := func(idx string, args ...interface{}) (Quota, error) { | ||||||
|  | 		iter, err := txn.Get(req.Type.String(), idx, args...) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		var quotas []Quota | ||||||
|  | 		for raw := iter.Next(); raw != nil; raw = iter.Next() { | ||||||
|  | 			quota := raw.(Quota) | ||||||
|  | 			quotas = append(quotas, quota) | ||||||
|  | 		} | ||||||
|  | 		if len(quotas) > 1 { | ||||||
|  | 			return nil, fmt.Errorf("conflicting quota definitions detected") | ||||||
|  | 		} | ||||||
|  | 		if len(quotas) == 0 { | ||||||
|  | 			return nil, nil | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return quotas[0], nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Fetch mount quota | ||||||
|  | 	quota, err := quotaFetchFunc(indexNamespaceMount, req.NamespacePath, req.MountPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if quota != nil { | ||||||
|  | 		return quota, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Fetch ns quota. If NamespacePath is root, this will return the global quota. | ||||||
|  | 	quota, err = quotaFetchFunc(indexNamespace, req.NamespacePath, false) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if quota != nil { | ||||||
|  | 		return quota, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If the request belongs to "root" namespace, then we have already looked at | ||||||
|  | 	// global quotas when fetching namespace specific quota rule. When the request | ||||||
|  | 	// belongs to a non-root namespace, and when there are no namespace specific | ||||||
|  | 	// quota rules present, we fallback on the global quotas. | ||||||
|  | 	if req.NamespacePath == "root" { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Fetch global quota | ||||||
|  | 	quota, err = quotaFetchFunc(indexNamespace, "root", false) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if quota != nil { | ||||||
|  | 		return quota, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // DeleteQuota removes a quota rule from the db for a given name | ||||||
|  | func (m *Manager) DeleteQuota(ctx context.Context, qType string, name string) error { | ||||||
|  | 	m.lock.Lock() | ||||||
|  | 	defer m.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	txn := m.db.Txn(true) | ||||||
|  | 	defer txn.Abort() | ||||||
|  |  | ||||||
|  | 	raw, err := txn.First(qType, indexName, name) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if raw == nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	quota := raw.(Quota) | ||||||
|  | 	if err := quota.close(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = txn.Delete(qType, raw) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// For the lease count type, recompute the counters | ||||||
|  | 	if qType == TypeLeaseCount.String() { | ||||||
|  | 		if err := m.recomputeLeaseCounts(ctx, txn); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	txn.Commit() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ApplyQuota runs the request against any quota rule that is applicable to it. If | ||||||
|  | // there are multiple quota rule that matches the request parameters, rule that | ||||||
|  | // takes precedence will be used to allow/reject the request. | ||||||
|  | func (m *Manager) ApplyQuota(req *Request) (Response, error) { | ||||||
|  | 	var resp Response | ||||||
|  |  | ||||||
|  | 	quota, err := m.queryQuota(nil, req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return resp, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If there is no quota defined, allow the request. | ||||||
|  | 	if quota == nil { | ||||||
|  | 		resp.Allowed = true | ||||||
|  | 		return resp, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If the quota type is lease count, and if the path is not known to | ||||||
|  | 	// generate leases, allow the request. | ||||||
|  | 	if req.Type == TypeLeaseCount && !m.inLeasePathCache(req.Path) { | ||||||
|  | 		resp.Allowed = true | ||||||
|  | 		return resp, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return quota.allow(req) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // SetEnableRateLimitAuditLogging updates the operator preference regarding the | ||||||
|  | // audit logging behavior. | ||||||
|  | func (m *Manager) SetEnableRateLimitAuditLogging(val bool) { | ||||||
|  | 	m.config.EnableRateLimitAuditLogging = val | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // RateLimitAuditLoggingEnabled returns if the quota configuration allows audit | ||||||
|  | // logging of request rejections due to rate limiting quota rule violations. | ||||||
|  | func (m *Manager) RateLimitAuditLoggingEnabled() bool { | ||||||
|  | 	return m.config.EnableRateLimitAuditLogging | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Config returns the operator preferences in the quota manager | ||||||
|  | func (m *Manager) Config() *Config { | ||||||
|  | 	return m.config | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Reset will clear all the quotas from the db and clear the lease path cache. | ||||||
|  | func (m *Manager) Reset() error { | ||||||
|  | 	m.lock.Lock() | ||||||
|  | 	defer m.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	var err error | ||||||
|  | 	m.db, err = memdb.NewMemDB(dbSchema()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	m.storage = nil | ||||||
|  | 	m.ctx = nil | ||||||
|  |  | ||||||
|  | 	m.entManager.Reset() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // dbSchema creates a DB schema for holding all the quota rules. It creates a | ||||||
|  | // table for each supported type of quota. | ||||||
|  | func dbSchema() *memdb.DBSchema { | ||||||
|  | 	schema := &memdb.DBSchema{ | ||||||
|  | 		Tables: make(map[string]*memdb.TableSchema), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	commonSchema := func(name string) *memdb.TableSchema { | ||||||
|  | 		return &memdb.TableSchema{ | ||||||
|  | 			Name: name, | ||||||
|  | 			Indexes: map[string]*memdb.IndexSchema{ | ||||||
|  | 				indexID: { | ||||||
|  | 					Name:   indexID, | ||||||
|  | 					Unique: true, | ||||||
|  | 					Indexer: &memdb.StringFieldIndex{ | ||||||
|  | 						Field: "ID", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				indexName: { | ||||||
|  | 					Name:   indexName, | ||||||
|  | 					Unique: true, | ||||||
|  | 					Indexer: &memdb.StringFieldIndex{ | ||||||
|  | 						Field: "Name", | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				indexNamespace: { | ||||||
|  | 					Name: indexNamespace, | ||||||
|  | 					Indexer: &memdb.CompoundMultiIndex{ | ||||||
|  | 						Indexes: []memdb.Indexer{ | ||||||
|  | 							&memdb.StringFieldIndex{ | ||||||
|  | 								Field: "NamespacePath", | ||||||
|  | 							}, | ||||||
|  | 							// By sending false as the query parameter, we can | ||||||
|  | 							// query just the namespace specific quota. | ||||||
|  | 							&memdb.FieldSetIndex{ | ||||||
|  | 								Field: "MountPath", | ||||||
|  | 							}, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				indexNamespaceMount: { | ||||||
|  | 					Name:         indexNamespaceMount, | ||||||
|  | 					AllowMissing: true, | ||||||
|  | 					Indexer: &memdb.CompoundMultiIndex{ | ||||||
|  | 						Indexes: []memdb.Indexer{ | ||||||
|  | 							&memdb.StringFieldIndex{ | ||||||
|  | 								Field: "NamespacePath", | ||||||
|  | 							}, | ||||||
|  | 							&memdb.StringFieldIndex{ | ||||||
|  | 								Field: "MountPath", | ||||||
|  | 							}, | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create a table per quota type. This allows names to be reused between | ||||||
|  | 	// different quota types and querying a bit easier. | ||||||
|  | 	for _, name := range quotaTypes() { | ||||||
|  | 		schema.Tables[name] = commonSchema(name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return schema | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Invalidate receives notifications from the replication sub-system when a key | ||||||
|  | // is updated in the storage. This function will read the key from storage and | ||||||
|  | // updates the caches and data structures to reflect those updates. | ||||||
|  | func (m *Manager) Invalidate(key string) { | ||||||
|  | 	switch key { | ||||||
|  | 	case "config": | ||||||
|  | 		config, err := LoadConfig(m.ctx, m.storage) | ||||||
|  | 		if err != nil { | ||||||
|  | 			m.logger.Error("failed to invalidate quota config", "error", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		m.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging) | ||||||
|  | 	default: | ||||||
|  | 		splitKeys := strings.Split(key, "/") | ||||||
|  | 		if len(splitKeys) != 2 { | ||||||
|  | 			m.logger.Error("incorrect key while invalidating quota rule") | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		qType := splitKeys[0] | ||||||
|  | 		name := splitKeys[1] | ||||||
|  |  | ||||||
|  | 		// Read quota rule from storage | ||||||
|  | 		quota, err := Load(m.ctx, m.storage, qType, name) | ||||||
|  | 		if err != nil { | ||||||
|  | 			m.logger.Error("failed to read invalidated quota rule", "error", err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		switch { | ||||||
|  | 		case quota == nil: | ||||||
|  | 			// Handle quota deletion | ||||||
|  | 			if err := m.DeleteQuota(m.ctx, qType, name); err != nil { | ||||||
|  | 				m.logger.Error("failed to delete invalidated quota rule", "error", err) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			// Handle quota update | ||||||
|  | 			if err := m.SetQuota(m.ctx, qType, quota, false); err != nil { | ||||||
|  | 				m.logger.Error("failed to update invalidated quota rule", "error", err) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // LoadConfig reads the quota configuration from the underlying storage | ||||||
|  | func LoadConfig(ctx context.Context, storage logical.Storage) (*Config, error) { | ||||||
|  | 	var config Config | ||||||
|  | 	entry, err := storage.Get(ctx, ConfigPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if entry == nil { | ||||||
|  | 		return &config, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = entry.DecodeJSON(&config) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &config, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Load reads the quota rule from the underlying storage | ||||||
|  | func Load(ctx context.Context, storage logical.Storage, qType, name string) (Quota, error) { | ||||||
|  | 	var quota Quota | ||||||
|  | 	entry, err := storage.Get(ctx, QuotaStoragePath(qType, name)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if entry == nil { | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	switch qType { | ||||||
|  | 	case TypeRateLimit.String(): | ||||||
|  | 		quota = &RateLimitQuota{} | ||||||
|  | 	case TypeLeaseCount.String(): | ||||||
|  | 		quota = &LeaseCountQuota{} | ||||||
|  | 	default: | ||||||
|  | 		return nil, fmt.Errorf("unsupported type: %v", qType) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = entry.DecodeJSON(quota) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return quota, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Setup loads the quota configuration and all the quota rules into the | ||||||
|  | // quota manager. | ||||||
|  | func (m *Manager) Setup(ctx context.Context, storage logical.Storage, isPerfStandby bool) error { | ||||||
|  | 	m.lock.Lock() | ||||||
|  | 	defer m.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	m.storage = storage | ||||||
|  | 	m.ctx = ctx | ||||||
|  | 	m.isPerfStandby = isPerfStandby | ||||||
|  |  | ||||||
|  | 	// Load the quota configuration from storage and load it into the quota | ||||||
|  | 	// manager. | ||||||
|  | 	config, err := LoadConfig(ctx, storage) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	m.SetEnableRateLimitAuditLogging(config.EnableRateLimitAuditLogging) | ||||||
|  |  | ||||||
|  | 	// Load the quota rules for all supported types from storage and load it in | ||||||
|  | 	// the quota manager. | ||||||
|  | 	for _, qType := range quotaTypes() { | ||||||
|  | 		names, err := logical.CollectKeys(ctx, logical.NewStorageView(storage, StoragePrefix+qType+"/")) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 		for _, name := range names { | ||||||
|  | 			quota, err := Load(ctx, m.storage, qType, name) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if quota == nil { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			err = m.setQuotaLocked(ctx, qType, quota, true) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QuotaStoragePath returns the storage path suffix for persisting the quota | ||||||
|  | // rule. | ||||||
|  | func QuotaStoragePath(quotaType, name string) string { | ||||||
|  | 	return path.Join(StoragePrefix+quotaType, name) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // HandleRemount updates the quota subsystem about the remount operation that | ||||||
|  | // took place. Quota manager will trigger the quota specific updates including | ||||||
|  | // the mount path update.. | ||||||
|  | func (m *Manager) HandleRemount(ctx context.Context, nsPath, fromPath, toPath string) error { | ||||||
|  | 	m.lock.Lock() | ||||||
|  | 	defer m.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	txn := m.db.Txn(true) | ||||||
|  | 	defer txn.Abort() | ||||||
|  |  | ||||||
|  | 	// nsPath would have been made non-empty during insertion. Use non-empty value | ||||||
|  | 	// during query as well. | ||||||
|  | 	if nsPath == "" { | ||||||
|  | 		nsPath = "root" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	idx := indexNamespaceMount | ||||||
|  | 	leaseQuotaUpdated := false | ||||||
|  | 	args := []interface{}{nsPath, fromPath} | ||||||
|  | 	for _, quotaType := range quotaTypes() { | ||||||
|  | 		iter, err := txn.Get(quotaType, idx, args...) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		for raw := iter.Next(); raw != nil; raw = iter.Next() { | ||||||
|  | 			quota := raw.(Quota) | ||||||
|  | 			quota.handleRemount(toPath) | ||||||
|  | 			entry, err := logical.StorageEntryJSON(QuotaStoragePath(quotaType, quota.QuotaName()), quota) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if err := m.storage.Put(ctx, entry); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 			if quotaType == TypeLeaseCount.String() { | ||||||
|  | 				leaseQuotaUpdated = true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if leaseQuotaUpdated { | ||||||
|  | 		if err := m.recomputeLeaseCounts(ctx, txn); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	txn.Commit() | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // HandleBackendDisabling updates the quota subsystem with the disabling of auth | ||||||
|  | // or secret engine disabling. | ||||||
|  | func (m *Manager) HandleBackendDisabling(ctx context.Context, nsPath, mountPath string) error { | ||||||
|  | 	m.lock.Lock() | ||||||
|  | 	defer m.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	txn := m.db.Txn(true) | ||||||
|  | 	defer txn.Abort() | ||||||
|  |  | ||||||
|  | 	// nsPath would have been made non-empty during insertion. Use non-empty value | ||||||
|  | 	// during query as well. | ||||||
|  | 	if nsPath == "" { | ||||||
|  | 		nsPath = "root" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	idx := indexNamespaceMount | ||||||
|  | 	leaseQuotaDeleted := false | ||||||
|  | 	args := []interface{}{nsPath, mountPath} | ||||||
|  | 	for _, quotaType := range quotaTypes() { | ||||||
|  | 		iter, err := txn.Get(quotaType, idx, args...) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		for raw := iter.Next(); raw != nil; raw = iter.Next() { | ||||||
|  | 			if err := txn.Delete(quotaType, raw); err != nil { | ||||||
|  | 				return fmt.Errorf("failed to delete quota from db after mount disabling; namespace %q, err %v", nsPath, err) | ||||||
|  | 			} | ||||||
|  | 			quota := raw.(Quota) | ||||||
|  | 			if err := m.storage.Delete(ctx, QuotaStoragePath(quotaType, quota.QuotaName())); err != nil { | ||||||
|  | 				return fmt.Errorf("failed to delete quota from storage after mount disabling; namespace %q, err %v", nsPath, err) | ||||||
|  | 			} | ||||||
|  | 			if quotaType == TypeLeaseCount.String() { | ||||||
|  | 				leaseQuotaDeleted = true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if leaseQuotaDeleted { | ||||||
|  | 		if err := m.recomputeLeaseCounts(ctx, txn); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	txn.Commit() | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										282
									
								
								vault/quotas/quotas_rate_limit.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										282
									
								
								vault/quotas/quotas_rate_limit.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,282 @@ | |||||||
|  | package quotas | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"math" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/armon/go-metrics" | ||||||
|  | 	log "github.com/hashicorp/go-hclog" | ||||||
|  | 	"github.com/hashicorp/go-uuid" | ||||||
|  | 	"github.com/hashicorp/vault/helper/metricsutil" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/helper/pathmanager" | ||||||
|  | 	"golang.org/x/time/rate" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var rateLimitExemptPaths = pathmanager.New() | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	// DefaultRateLimitPurgeInterval defines the default purge interval used by a | ||||||
|  | 	// RateLimitQuota to remove stale client rate limiters. | ||||||
|  | 	DefaultRateLimitPurgeInterval = time.Minute | ||||||
|  |  | ||||||
|  | 	// DefaultRateLimitStaleAge defines the default stale age of a client limiter. | ||||||
|  | 	DefaultRateLimitStaleAge = 3 * time.Minute | ||||||
|  |  | ||||||
|  | 	// EnvVaultEnableRateLimitAuditLogging is used to enable audit logging of | ||||||
|  | 	// requests that get rejected due to rate limit quota violations. | ||||||
|  | 	EnvVaultEnableRateLimitAuditLogging = "VAULT_ENABLE_RATE_LIMIT_AUDIT_LOGGING" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	rateLimitExemptPaths.AddPaths([]string{ | ||||||
|  | 		"/v1/sys/generate-recovery-token/attempt", | ||||||
|  | 		"/v1/sys/generate-recovery-token/update", | ||||||
|  | 		"/v1/sys/generate-root/attempt", | ||||||
|  | 		"/v1/sys/generate-root/update", | ||||||
|  | 		"/v1/sys/health", | ||||||
|  | 		"/v1/sys/seal-status", | ||||||
|  | 		"/v1/sys/unseal", | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ClientRateLimiter defines a token bucket based rate limiter for a unique | ||||||
|  | // addressable client (e.g. IP address). Whenever this client attempts to make | ||||||
|  | // a request, the lastSeen value will be updated. | ||||||
|  | type ClientRateLimiter struct { | ||||||
|  | 	// lastSeen defines the UNIX timestamp the client last made a request. | ||||||
|  | 	lastSeen time.Time | ||||||
|  |  | ||||||
|  | 	// limiter represents an instance of a token bucket based rate limiter. | ||||||
|  | 	limiter *rate.Limiter | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // newClientRateLimiter returns a token bucket based rate limiter for a client | ||||||
|  | // that is uniquely addressable, where maxRequests defines the requests-per-second | ||||||
|  | // and burstSize defines the maximum burst allowed. A caller may provide -1 for | ||||||
|  | // burstSize to allow the burst value to be roughly equivalent to the RPS. Note, | ||||||
|  | // the underlying rate limiter is already thread-safe. | ||||||
|  | func newClientRateLimiter(maxRequests float64, burstSize int) *ClientRateLimiter { | ||||||
|  | 	if burstSize < 0 { | ||||||
|  | 		burstSize = int(math.Ceil(maxRequests)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &ClientRateLimiter{ | ||||||
|  | 		lastSeen: time.Now().UTC(), | ||||||
|  | 		limiter:  rate.NewLimiter(rate.Limit(maxRequests), burstSize), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Ensure that RateLimitQuota implements the Quota interface | ||||||
|  | var _ Quota = (*RateLimitQuota)(nil) | ||||||
|  |  | ||||||
|  | // RateLimitQuota represents the quota rule properties that is used to limit the | ||||||
|  | // number of requests per second for a namespace or mount. | ||||||
|  | type RateLimitQuota struct { | ||||||
|  | 	// ID is the identifier of the quota | ||||||
|  | 	ID string `json:"id"` | ||||||
|  |  | ||||||
|  | 	// Type of quota this represents | ||||||
|  | 	Type Type `json:"type"` | ||||||
|  |  | ||||||
|  | 	// Name of the quota rule | ||||||
|  | 	Name string `json:"name"` | ||||||
|  |  | ||||||
|  | 	// NamespacePath is the path of the namespace to which this quota is | ||||||
|  | 	// applicable. | ||||||
|  | 	NamespacePath string `json:"namespace_path"` | ||||||
|  |  | ||||||
|  | 	// MountPath is the path of the mount to which this quota is applicable | ||||||
|  | 	MountPath string `json:"mount_path"` | ||||||
|  |  | ||||||
|  | 	// Rate defines the rate of which allowed requests are refilled per second. | ||||||
|  | 	Rate float64 `json:"rate"` | ||||||
|  |  | ||||||
|  | 	// Burst defines maximum number of requests at any given moment to be allowed. | ||||||
|  | 	Burst int `json:"burst"` | ||||||
|  |  | ||||||
|  | 	lock         *sync.Mutex | ||||||
|  | 	logger       log.Logger | ||||||
|  | 	metricSink   *metricsutil.ClusterMetricSink | ||||||
|  | 	purgeEnabled bool | ||||||
|  |  | ||||||
|  | 	// purgeInterval defines the interval in seconds in which the RateLimitQuota | ||||||
|  | 	// attempts to remove stale entries from the rateQuotas mapping. | ||||||
|  | 	purgeInterval time.Duration | ||||||
|  | 	closeCh       chan struct{} | ||||||
|  |  | ||||||
|  | 	// staleAge defines the age in seconds in which a clientRateLimiter is | ||||||
|  | 	// considered stale. A clientRateLimiter is considered stale if the delta | ||||||
|  | 	// between the current purge time and its lastSeen timestamp is greater than | ||||||
|  | 	// this value. | ||||||
|  | 	staleAge time.Duration | ||||||
|  |  | ||||||
|  | 	// rateQuotas contains a mapping from a unique addressable client (e.g. IP address) | ||||||
|  | 	// to a clientRateLimiter reference. Every purgeInterval seconds, the RateLimitQuota | ||||||
|  | 	// will attempt to remove stale entries from the mapping. | ||||||
|  | 	rateQuotas map[string]*ClientRateLimiter | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewRateLimitQuota creates a quota checker for imposing limits on the number | ||||||
|  | // of requests per second. | ||||||
|  | func NewRateLimitQuota(name, nsPath, mountPath string, rate float64, burst int) *RateLimitQuota { | ||||||
|  | 	return &RateLimitQuota{ | ||||||
|  | 		Name:          name, | ||||||
|  | 		Type:          TypeRateLimit, | ||||||
|  | 		NamespacePath: nsPath, | ||||||
|  | 		MountPath:     mountPath, | ||||||
|  | 		Rate:          rate, | ||||||
|  | 		Burst:         burst, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // jnitialize ensures the namespace and max requests are initialized, sets the ID | ||||||
|  | // if it's currently empty, sets the purge interval and stale age to default | ||||||
|  | // values, and finally starts the client purge go routine if it has been started | ||||||
|  | // already. Note, initialize will reset the internal rateQuotas mapping. | ||||||
|  | func (rlq *RateLimitQuota) initialize(logger log.Logger, ms *metricsutil.ClusterMetricSink) error { | ||||||
|  | 	if rlq.lock == nil { | ||||||
|  | 		rlq.lock = new(sync.Mutex) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rlq.lock.Lock() | ||||||
|  | 	defer rlq.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	// Memdb requires a non-empty value for indexing | ||||||
|  | 	if rlq.NamespacePath == "" { | ||||||
|  | 		rlq.NamespacePath = "root" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if rlq.Rate <= 0 { | ||||||
|  | 		return fmt.Errorf("invalid avg rps: %v", rlq.Rate) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if rlq.Burst < int(rlq.Rate) { | ||||||
|  | 		return fmt.Errorf("burst size (%v) must be greater than or equal to average rps (%v)", rlq.Burst, rlq.Rate) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if logger != nil { | ||||||
|  | 		rlq.logger = logger | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if rlq.metricSink == nil { | ||||||
|  | 		rlq.metricSink = ms | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if rlq.ID == "" { | ||||||
|  | 		id, err := uuid.GenerateUUID() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		rlq.ID = id | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rlq.purgeInterval = DefaultRateLimitPurgeInterval | ||||||
|  | 	rlq.staleAge = DefaultRateLimitStaleAge | ||||||
|  | 	rlq.rateQuotas = make(map[string]*ClientRateLimiter) | ||||||
|  |  | ||||||
|  | 	if !rlq.purgeEnabled { | ||||||
|  | 		rlq.purgeEnabled = true | ||||||
|  | 		rlq.closeCh = make(chan struct{}) | ||||||
|  | 		go rlq.purgeClientsLoop() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // quotaID returns the identifier of the quota rule | ||||||
|  | func (rlq *RateLimitQuota) quotaID() string { | ||||||
|  | 	return rlq.ID | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // QuotaName returns the name of the quota rule | ||||||
|  | func (rlq *RateLimitQuota) QuotaName() string { | ||||||
|  | 	return rlq.Name | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // purgeClientsLoop performs a blocking process where every purgeInterval | ||||||
|  | // duration, we look for stale clients to remove from the rateQuotas map. | ||||||
|  | // A ClientRateLimiter is considered stale if its lastSeen timestamp exceeds the | ||||||
|  | // current time. The loop will continue to run indefinitely until a value is | ||||||
|  | // sent on the closeCh in which we stop the ticker and exit. | ||||||
|  | func (rlq *RateLimitQuota) purgeClientsLoop() { | ||||||
|  | 	ticker := time.NewTicker(rlq.purgeInterval) | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case t := <-ticker.C: | ||||||
|  | 			rlq.lock.Lock() | ||||||
|  |  | ||||||
|  | 			for client, crl := range rlq.rateQuotas { | ||||||
|  | 				if t.UTC().Sub(crl.lastSeen) >= rlq.staleAge { | ||||||
|  | 					delete(rlq.rateQuotas, client) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			rlq.lock.Unlock() | ||||||
|  |  | ||||||
|  | 		case <-rlq.closeCh: | ||||||
|  | 			ticker.Stop() | ||||||
|  | 			rlq.purgeEnabled = false | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // clientRateLimiter returns a reference to a ClientRateLimiter based on a | ||||||
|  | // provided client address (e.g. IP address). If the ClientRateLimiter does not | ||||||
|  | // exist in the RateLimitQuota's mapping, one will be created and set. The | ||||||
|  | // created RateLimitQuota will have its requests-per-second set to | ||||||
|  | // RateLimitQuota.AverageRps. If the ClientRateLimiter already exists, the | ||||||
|  | // lastSeen timestamp will be updated. | ||||||
|  | func (rlq *RateLimitQuota) clientRateLimiter(addr string) *ClientRateLimiter { | ||||||
|  | 	rlq.lock.Lock() | ||||||
|  | 	defer rlq.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	crl, ok := rlq.rateQuotas[addr] | ||||||
|  | 	if !ok { | ||||||
|  | 		limiter := newClientRateLimiter(rlq.Rate, rlq.Burst) | ||||||
|  | 		rlq.rateQuotas[addr] = limiter | ||||||
|  | 		return limiter | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	crl.lastSeen = time.Now().UTC() | ||||||
|  | 	return crl | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // allow decides if the request is allowed by the quota. An error will be | ||||||
|  | // returned if the request ID or address is empty. If the path is exempt, the | ||||||
|  | // quota will not be evaluated. Otherwise, the client rate limiter is retrieved | ||||||
|  | // by address and the rate limit quota is checked against that limiter. | ||||||
|  | func (rlq *RateLimitQuota) allow(req *Request) (Response, error) { | ||||||
|  | 	var resp Response | ||||||
|  |  | ||||||
|  | 	// Skip rate limit checks for paths that are exempt from rate limiting. | ||||||
|  | 	if rateLimitExemptPaths.HasPath(req.Path) { | ||||||
|  | 		resp.Allowed = true | ||||||
|  | 		return resp, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if req.ClientAddress == "" { | ||||||
|  | 		return resp, fmt.Errorf("missing request client address in quota request") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp.Allowed = rlq.clientRateLimiter(req.ClientAddress).limiter.Allow() | ||||||
|  | 	if !resp.Allowed { | ||||||
|  | 		rlq.metricSink.IncrCounterWithLabels([]string{"quota", "rate_limit", "violation"}, 1, []metrics.Label{{"name", rlq.Name}}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return resp, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // close stops the current running client purge loop. | ||||||
|  | func (rlq *RateLimitQuota) close() error { | ||||||
|  | 	close(rlq.closeCh) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rlq *RateLimitQuota) handleRemount(toPath string) { | ||||||
|  | 	rlq.MountPath = toPath | ||||||
|  | } | ||||||
							
								
								
									
										173
									
								
								vault/quotas/quotas_rate_limit_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								vault/quotas/quotas_rate_limit_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,173 @@ | |||||||
|  | package quotas | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	log "github.com/hashicorp/go-hclog" | ||||||
|  | 	"github.com/hashicorp/vault/helper/metricsutil" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/helper/logging" | ||||||
|  | 	"go.uber.org/atomic" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestNewClientRateLimiter(t *testing.T) { | ||||||
|  | 	testCases := []struct { | ||||||
|  | 		maxRequests   float64 | ||||||
|  | 		burstSize     int | ||||||
|  | 		expectedBurst int | ||||||
|  | 	}{ | ||||||
|  | 		{1000, -1, 1000}, | ||||||
|  | 		{1000, 5000, 5000}, | ||||||
|  | 		{16.1, -1, 17}, | ||||||
|  | 		{16.7, -1, 17}, | ||||||
|  | 		{16.7, 100, 100}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range testCases { | ||||||
|  | 		crl := newClientRateLimiter(tc.maxRequests, tc.burstSize) | ||||||
|  | 		b := crl.limiter.Burst() | ||||||
|  | 		if b != tc.expectedBurst { | ||||||
|  | 			t.Fatalf("unexpected burst size; expected: %d, got: %d", tc.expectedBurst, b) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNewRateLimitQuota(t *testing.T) { | ||||||
|  | 	rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, 50) | ||||||
|  | 	if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !rlq.purgeEnabled { | ||||||
|  | 		t.Fatal("expected rate limit quota to start purge loop") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if rlq.purgeInterval != DefaultRateLimitPurgeInterval { | ||||||
|  | 		t.Fatalf("unexpected purgeInterval; expected: %d, got: %d", DefaultRateLimitPurgeInterval, rlq.purgeInterval) | ||||||
|  | 	} | ||||||
|  | 	if rlq.staleAge != DefaultRateLimitStaleAge { | ||||||
|  | 		t.Fatalf("unexpected staleAge; expected: %d, got: %d", DefaultRateLimitStaleAge, rlq.staleAge) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRateLimitQuota_Close(t *testing.T) { | ||||||
|  | 	rlq := NewRateLimitQuota("test-rate-limiter", "qa", "/foo/bar", 16.7, 50) | ||||||
|  |  | ||||||
|  | 	if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := rlq.close(); err != nil { | ||||||
|  | 		t.Fatalf("unexpected error when closing: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	time.Sleep(time.Second) // allow enough time for purgeClientsLoop to receive on closeCh | ||||||
|  |  | ||||||
|  | 	if rlq.purgeEnabled { | ||||||
|  | 		t.Fatal("expected client purging to be disabled after close") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRateLimitQuota_Allow(t *testing.T) { | ||||||
|  | 	rlq := &RateLimitQuota{ | ||||||
|  | 		Name:          "test-rate-limiter", | ||||||
|  | 		Type:          TypeRateLimit, | ||||||
|  | 		NamespacePath: "qa", | ||||||
|  | 		MountPath:     "/foo/bar", | ||||||
|  | 		Rate:          16.7, | ||||||
|  | 		Burst:         83, | ||||||
|  | 		purgeEnabled:  true, // to allow manual setting of purgeInterval and staleAge | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := rlq.initialize(logging.NewVaultLogger(log.Trace), metricsutil.BlackholeSink()); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// override value and manually start purgeClientsLoop for testing purposes | ||||||
|  | 	rlq.purgeInterval = 10 * time.Second | ||||||
|  | 	rlq.staleAge = 10 * time.Second | ||||||
|  | 	go rlq.purgeClientsLoop() | ||||||
|  |  | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  |  | ||||||
|  | 	type clientResult struct { | ||||||
|  | 		atomicNumAllow *atomic.Int32 | ||||||
|  | 		atomicNumFail  *atomic.Int32 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	reqFunc := func(addr string, atomicNumAllow, atomicNumFail *atomic.Int32) { | ||||||
|  | 		defer wg.Done() | ||||||
|  |  | ||||||
|  | 		resp, err := rlq.allow(&Request{ClientAddress: addr}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if resp.Allowed { | ||||||
|  | 			atomicNumAllow.Add(1) | ||||||
|  | 		} else { | ||||||
|  | 			atomicNumFail.Add(1) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	results := make(map[string]*clientResult) | ||||||
|  |  | ||||||
|  | 	start := time.Now() | ||||||
|  | 	end := start.Add(5 * time.Second) | ||||||
|  | 	for time.Now().Before(end) { | ||||||
|  |  | ||||||
|  | 		for i := 0; i < 5; i++ { | ||||||
|  | 			wg.Add(1) | ||||||
|  |  | ||||||
|  | 			addr := fmt.Sprintf("127.0.0.%d", i) | ||||||
|  | 			cr, ok := results[addr] | ||||||
|  | 			if !ok { | ||||||
|  | 				results[addr] = &clientResult{atomicNumAllow: atomic.NewInt32(0), atomicNumFail: atomic.NewInt32(0)} | ||||||
|  | 				cr = results[addr] | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			go reqFunc(addr, cr.atomicNumAllow, cr.atomicNumFail) | ||||||
|  |  | ||||||
|  | 			time.Sleep(2 * time.Millisecond) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	wg.Wait() | ||||||
|  |  | ||||||
|  | 	if got, expected := len(results), len(rlq.rateQuotas); got != expected { | ||||||
|  | 		t.Fatalf("unexpected number of tracked client rate limit quotas; got %d, expected; %d", got, expected) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	elapsed := time.Since(start) | ||||||
|  |  | ||||||
|  | 	// evaluate the ideal RPS as (burst + (RPS * totalSeconds)) | ||||||
|  | 	ideal := float64(rlq.Burst) + (rlq.Rate * float64(elapsed) / float64(time.Second)) | ||||||
|  |  | ||||||
|  | 	for addr, cr := range results { | ||||||
|  | 		numAllow := cr.atomicNumAllow.Load() | ||||||
|  | 		numFail := cr.atomicNumFail.Load() | ||||||
|  |  | ||||||
|  | 		// ensure there were some failed requests for the namespace | ||||||
|  | 		if numFail == 0 { | ||||||
|  | 			t.Fatalf("expected some requests to fail; addr: %s, numSuccess: %d, numFail: %d, elapsed: %d", addr, numAllow, numFail, elapsed) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// ensure that we should never get more requests than allowed for the namespace | ||||||
|  | 		if want := int32(ideal + 1); numAllow > want { | ||||||
|  | 			t.Fatalf("too many successful requests; addr: %s, want: %d, numSuccess: %d, numFail: %d, elapsed: %d", addr, want, numAllow, numFail, elapsed) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// allow enough time for the client to be purged | ||||||
|  | 	time.Sleep(rlq.purgeInterval * 2) | ||||||
|  |  | ||||||
|  | 	for addr := range results { | ||||||
|  | 		rlc, ok := rlq.rateQuotas[addr] | ||||||
|  | 		if ok || rlc != nil { | ||||||
|  | 			t.Fatalf("expected stale client to be purged: %s", addr) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										67
									
								
								vault/quotas/quotas_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								vault/quotas/quotas_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | |||||||
|  | package quotas | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/go-test/deep" | ||||||
|  | 	log "github.com/hashicorp/go-hclog" | ||||||
|  | 	"github.com/hashicorp/vault/helper/metricsutil" | ||||||
|  | 	"github.com/hashicorp/vault/sdk/helper/logging" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestQuotas_Precedence(t *testing.T) { | ||||||
|  | 	qm, err := NewManager(logging.NewVaultLogger(log.Trace), nil, metricsutil.BlackholeSink()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	setQuotaFunc := func(t *testing.T, name, nsPath, mountPath string) Quota { | ||||||
|  | 		t.Helper() | ||||||
|  | 		quota := NewRateLimitQuota(name, nsPath, mountPath, 10, 20) | ||||||
|  | 		err := qm.SetQuota(context.Background(), TypeRateLimit.String(), quota, true) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 		return quota | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	checkQuotaFunc := func(t *testing.T, nsPath, mountPath string, expected Quota) { | ||||||
|  | 		t.Helper() | ||||||
|  | 		quota, err := qm.queryQuota(nil, &Request{ | ||||||
|  | 			Type:          TypeRateLimit, | ||||||
|  | 			NamespacePath: nsPath, | ||||||
|  | 			MountPath:     mountPath, | ||||||
|  | 		}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 		if diff := deep.Equal(expected, quota); len(diff) > 0 { | ||||||
|  | 			t.Fatal(diff) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// No quota present. Expect nil. | ||||||
|  | 	checkQuotaFunc(t, "", "", nil) | ||||||
|  |  | ||||||
|  | 	// Define global quota and expect that to be returned. | ||||||
|  | 	rateLimitGlobalQuota := setQuotaFunc(t, "rateLimitGlobalQuota", "", "") | ||||||
|  | 	checkQuotaFunc(t, "", "", rateLimitGlobalQuota) | ||||||
|  |  | ||||||
|  | 	// Define a global mount specific quota and expect that to be returned. | ||||||
|  | 	rateLimitGlobalMountQuota := setQuotaFunc(t, "rateLimitGlobalMountQuota", "", "testmount") | ||||||
|  | 	checkQuotaFunc(t, "", "testmount", rateLimitGlobalMountQuota) | ||||||
|  |  | ||||||
|  | 	// Define a namespace quota and expect that to be returned. | ||||||
|  | 	rateLimitNSQuota := setQuotaFunc(t, "rateLimitNSQuota", "testns", "") | ||||||
|  | 	checkQuotaFunc(t, "testns", "", rateLimitNSQuota) | ||||||
|  |  | ||||||
|  | 	// Define a namespace mount specific quota and expect that to be returned. | ||||||
|  | 	rateLimitNSMountQuota := setQuotaFunc(t, "rateLimitNSMountQuota", "testns", "testmount") | ||||||
|  | 	checkQuotaFunc(t, "testns", "testmount", rateLimitNSMountQuota) | ||||||
|  |  | ||||||
|  | 	// Now that many quota types are defined, verify that the most specific | ||||||
|  | 	// matches are returned per namespace. | ||||||
|  | 	checkQuotaFunc(t, "", "", rateLimitGlobalQuota) | ||||||
|  | 	checkQuotaFunc(t, "testns", "", rateLimitNSQuota) | ||||||
|  | } | ||||||
							
								
								
									
										65
									
								
								vault/quotas/quotas_util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								vault/quotas/quotas_util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,65 @@ | |||||||
|  | // +build !enterprise | ||||||
|  |  | ||||||
|  | package quotas | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  |  | ||||||
|  | 	log "github.com/hashicorp/go-hclog" | ||||||
|  | 	"github.com/hashicorp/vault/helper/metricsutil" | ||||||
|  |  | ||||||
|  | 	memdb "github.com/hashicorp/go-memdb" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func quotaTypes() []string { | ||||||
|  | 	return []string{ | ||||||
|  | 		TypeRateLimit.String(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *Manager) init(walkFunc leaseWalkFunc) {} | ||||||
|  |  | ||||||
|  | func (m *Manager) recomputeLeaseCounts(ctx context.Context, txn *memdb.Txn) error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *Manager) setIsPerfStandby(quota Quota) {} | ||||||
|  |  | ||||||
|  | func (m *Manager) inLeasePathCache(path string) bool { | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type entManager struct { | ||||||
|  | 	isPerfStandby bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (*entManager) Reset() error { | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type LeaseCountQuota struct { | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l LeaseCountQuota) allow(request *Request) (Response, error) { | ||||||
|  | 	panic("implement me") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l LeaseCountQuota) quotaID() string { | ||||||
|  | 	panic("implement me") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l LeaseCountQuota) QuotaName() string { | ||||||
|  | 	panic("implement me") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l LeaseCountQuota) initialize(logger log.Logger, sink *metricsutil.ClusterMetricSink) error { | ||||||
|  | 	panic("implement me") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l LeaseCountQuota) close() error { | ||||||
|  | 	panic("implement me") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (l LeaseCountQuota) handleRemount(s string) { | ||||||
|  | 	panic("implement me") | ||||||
|  | } | ||||||
| @@ -24,6 +24,7 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/sdk/helper/strutil" | 	"github.com/hashicorp/vault/sdk/helper/strutil" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/wrapping" | 	"github.com/hashicorp/vault/sdk/helper/wrapping" | ||||||
| 	"github.com/hashicorp/vault/sdk/logical" | 	"github.com/hashicorp/vault/sdk/logical" | ||||||
|  | 	"github.com/hashicorp/vault/vault/quotas" | ||||||
| 	uberAtomic "go.uber.org/atomic" | 	uberAtomic "go.uber.org/atomic" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -539,7 +540,6 @@ func (c *Core) handleCancelableRequest(ctx context.Context, ns *namespace.Namesp | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Create an audit trail of the response | 	// Create an audit trail of the response | ||||||
|  |  | ||||||
| 	if !isControlGroupRun(req) { | 	if !isControlGroupRun(req) { | ||||||
| 		switch req.Path { | 		switch req.Path { | ||||||
| 		case "sys/replication/dr/status", "sys/replication/performance/status", "sys/replication/status": | 		case "sys/replication/dr/status", "sys/replication/performance/status", "sys/replication/status": | ||||||
| @@ -708,6 +708,36 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	leaseGenerated := false | ||||||
|  | 	quotaResp, quotaErr := c.applyLeaseCountQuota("as.Request{ | ||||||
|  | 		Path:          req.Path, | ||||||
|  | 		MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path), | ||||||
|  | 		NamespacePath: ns.Path, | ||||||
|  | 	}) | ||||||
|  | 	if quotaErr != nil { | ||||||
|  | 		c.logger.Error("failed to apply quota", "path", req.Path, "error", err) | ||||||
|  | 		retErr = multierror.Append(retErr, quotaErr) | ||||||
|  | 		return nil, auth, retErr | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !quotaResp.Allowed { | ||||||
|  | 		if c.logger.IsTrace() { | ||||||
|  | 			c.logger.Trace("request rejected due to lease count quota violation", "request_path", req.Path) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrLeaseCountQuotaExceeded)) | ||||||
|  | 		return nil, auth, retErr | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	defer func() { | ||||||
|  | 		if quotaResp.Access != nil { | ||||||
|  | 			quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated) | ||||||
|  | 			if quotaAckErr != nil { | ||||||
|  | 				retErr = multierror.Append(retErr, quotaAckErr) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
| 	// Route the request | 	// Route the request | ||||||
| 	resp, routeErr := c.doRouting(ctx, req) | 	resp, routeErr := c.doRouting(ctx, req) | ||||||
| 	if resp != nil { | 	if resp != nil { | ||||||
| @@ -827,6 +857,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp | |||||||
| 				retErr = multierror.Append(retErr, ErrInternalError) | 				retErr = multierror.Append(retErr, ErrInternalError) | ||||||
| 				return nil, auth, retErr | 				return nil, auth, retErr | ||||||
| 			} | 			} | ||||||
|  | 			leaseGenerated = true | ||||||
| 			resp.Secret.LeaseID = leaseID | 			resp.Secret.LeaseID = leaseID | ||||||
|  |  | ||||||
| 			// Get the actual time of the lease | 			// Get the actual time of the lease | ||||||
| @@ -917,6 +948,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp | |||||||
| 					retErr = multierror.Append(retErr, ErrInternalError) | 					retErr = multierror.Append(retErr, ErrInternalError) | ||||||
| 					return nil, auth, retErr | 					return nil, auth, retErr | ||||||
| 				} | 				} | ||||||
|  | 				leaseGenerated = true | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -1073,6 +1105,46 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re | |||||||
|  |  | ||||||
| 	// If the response generated an authentication, then generate the token | 	// If the response generated an authentication, then generate the token | ||||||
| 	if resp != nil && resp.Auth != nil { | 	if resp != nil && resp.Auth != nil { | ||||||
|  | 		ns, err := namespace.FromContext(ctx) | ||||||
|  | 		if err != nil { | ||||||
|  | 			c.logger.Error("failed to get namespace from context", "error", err) | ||||||
|  | 			retErr = multierror.Append(retErr, ErrInternalError) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		leaseGenerated := false | ||||||
|  |  | ||||||
|  | 		// The request successfully authenticated itself. Run the quota checks | ||||||
|  | 		// before creating lease. | ||||||
|  | 		quotaResp, quotaErr := c.applyLeaseCountQuota("as.Request{ | ||||||
|  | 			Path:          req.Path, | ||||||
|  | 			MountPath:     strings.TrimPrefix(req.MountPoint, ns.Path), | ||||||
|  | 			NamespacePath: ns.Path, | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		if quotaErr != nil { | ||||||
|  | 			c.logger.Error("failed to apply quota", "path", req.Path, "error", err) | ||||||
|  | 			retErr = multierror.Append(retErr, quotaErr) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !quotaResp.Allowed { | ||||||
|  | 			if c.logger.IsTrace() { | ||||||
|  | 				c.logger.Trace("request rejected due to lease count quota violation", "request_path", req.Path) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("request path %q: {{err}}", req.Path), quotas.ErrLeaseCountQuotaExceeded)) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		defer func() { | ||||||
|  | 			if quotaResp.Access != nil { | ||||||
|  | 				quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated) | ||||||
|  | 				if quotaAckErr != nil { | ||||||
|  | 					retErr = multierror.Append(retErr, quotaAckErr) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
|  |  | ||||||
| 		var entity *identity.Entity | 		var entity *identity.Entity | ||||||
| 		auth = resp.Auth | 		auth = resp.Auth | ||||||
| @@ -1141,10 +1213,6 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re | |||||||
| 			resp.AddWarning(warning) | 			resp.AddWarning(warning) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		ns, err := namespace.FromContext(ctx) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, nil, err |  | ||||||
| 		} |  | ||||||
| 		_, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, ns, auth.EntityID) | 		_, identityPolicies, err := c.fetchEntityAndDerivedPolicies(ctx, ns, auth.EntityID) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, nil, ErrInternalError | 			return nil, nil, ErrInternalError | ||||||
| @@ -1181,6 +1249,9 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re | |||||||
| 		err = registerFunc(ctx, tokenTTL, req.Path, auth) | 		err = registerFunc(ctx, tokenTTL, req.Path, auth) | ||||||
| 		switch { | 		switch { | ||||||
| 		case err == nil: | 		case err == nil: | ||||||
|  | 			if auth.TokenType != logical.TokenTypeBatch { | ||||||
|  | 				leaseGenerated = true | ||||||
|  | 			} | ||||||
| 		case err == ErrInternalError: | 		case err == ErrInternalError: | ||||||
| 			return nil, auth, err | 			return nil, auth, err | ||||||
| 		default: | 		default: | ||||||
|   | |||||||
| @@ -422,6 +422,14 @@ func (r *Router) MatchingSystemView(ctx context.Context, path string) logical.Sy | |||||||
| 	return raw.(*routeEntry).backend.System() | 	return raw.(*routeEntry).backend.System() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (r *Router) MatchingMountByAPIPath(ctx context.Context, path string) string { | ||||||
|  | 	me, _, _ := r.matchingMountEntryByPath(ctx, path, true) | ||||||
|  | 	if me == nil { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return me.Path | ||||||
|  | } | ||||||
|  |  | ||||||
| // MatchingStoragePrefixByAPIPath the storage prefix for the given api path | // MatchingStoragePrefixByAPIPath the storage prefix for the given api path | ||||||
| func (r *Router) MatchingStoragePrefixByAPIPath(ctx context.Context, path string) (string, bool) { | func (r *Router) MatchingStoragePrefixByAPIPath(ctx context.Context, path string) (string, bool) { | ||||||
| 	ns, err := namespace.FromContext(ctx) | 	ns, err := namespace.FromContext(ctx) | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								vendor/github.com/hashicorp/vault/api/response.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								vendor/github.com/hashicorp/vault/api/response.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -27,8 +27,8 @@ func (r *Response) DecodeJSON(out interface{}) error { | |||||||
| // body must still be closed manually. | // body must still be closed manually. | ||||||
| func (r *Response) Error() error { | func (r *Response) Error() error { | ||||||
| 	// 200 to 399 are okay status codes. 429 is the code for health status of | 	// 200 to 399 are okay status codes. 429 is the code for health status of | ||||||
| 	// standby nodes. | 	// standby nodes, otherwise, 429 is treated as quota limit reached. | ||||||
| 	if (r.StatusCode >= 200 && r.StatusCode < 400) || r.StatusCode == 429 { | 	if (r.StatusCode >= 200 && r.StatusCode < 400) || (r.StatusCode == 429 && r.Request.URL.Path == "/v1/sys/health") { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								vendor/github.com/hashicorp/vault/sdk/logical/error.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								vendor/github.com/hashicorp/vault/sdk/logical/error.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -28,6 +28,14 @@ var ( | |||||||
| 	// ErrPerfStandbyForward is returned when Vault is in a state such that a | 	// ErrPerfStandbyForward is returned when Vault is in a state such that a | ||||||
| 	// perf standby cannot satisfy a request | 	// perf standby cannot satisfy a request | ||||||
| 	ErrPerfStandbyPleaseForward = errors.New("please forward to the active node") | 	ErrPerfStandbyPleaseForward = errors.New("please forward to the active node") | ||||||
|  |  | ||||||
|  | 	// ErrLeaseCountQuotaExceeded is returned when a request is rejected due to a lease | ||||||
|  | 	// count quota being exceeded. | ||||||
|  | 	ErrLeaseCountQuotaExceeded = errors.New("lease count quota exceeded") | ||||||
|  |  | ||||||
|  | 	// ErrRateLimitQuotaExceeded is returned when a request is rejected due to a | ||||||
|  | 	// rate limit quota being exceeded. | ||||||
|  | 	ErrRateLimitQuotaExceeded = errors.New("rate limit quota exceeded") | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type HTTPCodedError interface { | type HTTPCodedError interface { | ||||||
|   | |||||||
							
								
								
									
										6
									
								
								vendor/github.com/hashicorp/vault/sdk/logical/response_util.go
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								vendor/github.com/hashicorp/vault/sdk/logical/response_util.go
									
									
									
										generated
									
									
										vendored
									
									
								
							| @@ -81,7 +81,7 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { | |||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
| 		if allErrors != nil { | 		if allErrors != nil { | ||||||
| 			return codedErr.Code, multierror.Append(errors.New(fmt.Sprintf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg)), allErrors) | 			return codedErr.Code, multierror.Append(fmt.Errorf("errors from both primary and secondary; primary error was %v; secondary errors follow", codedErr.Msg), allErrors) | ||||||
| 		} | 		} | ||||||
| 		return codedErr.Code, errors.New(codedErr.Msg) | 		return codedErr.Code, errors.New(codedErr.Msg) | ||||||
| 	} | 	} | ||||||
| @@ -110,6 +110,10 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { | |||||||
| 			statusCode = http.StatusBadRequest | 			statusCode = http.StatusBadRequest | ||||||
| 		case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): | 		case errwrap.Contains(err, ErrUpstreamRateLimited.Error()): | ||||||
| 			statusCode = http.StatusBadGateway | 			statusCode = http.StatusBadGateway | ||||||
|  | 		case errwrap.Contains(err, ErrRateLimitQuotaExceeded.Error()): | ||||||
|  | 			statusCode = http.StatusTooManyRequests | ||||||
|  | 		case errwrap.Contains(err, ErrLeaseCountQuotaExceeded.Error()): | ||||||
|  | 			statusCode = http.StatusTooManyRequests | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -163,6 +163,17 @@ These metrics cover measurement of token, identity, and lease operations, and co | |||||||
| | `vault.token.revoke-tree`    | Time taken to revoke a token tree                                                             | ms     | summary | | | `vault.token.revoke-tree`    | Time taken to revoke a token tree                                                             | ms     | summary | | ||||||
| | `vault.token.store`          | Time taken to store an updated token entry without writing to the secondary index             | ms     | summary | | | `vault.token.store`          | Time taken to store an updated token entry without writing to the secondary index             | ms     | summary | | ||||||
|  |  | ||||||
|  | ## Resource Quota Metrics | ||||||
|  |  | ||||||
|  | These metrics relate to rate limit and lease count quotas. Each metric comes with a label "name" identifying the specific quota. | ||||||
|  |  | ||||||
|  | | Metric                        | Description                                                       | Unit  | Type    | | ||||||
|  | | :---------------------------- | :---------------------------------------------------------------- | :---- | :------ | | ||||||
|  | | `quota.rate_limit.violation`  | Total number of rate limit quota violations                       | quota | counter | | ||||||
|  | | `quota.lease_count.violation` | Total number of lease count quota violations                      | quota | counter | | ||||||
|  | | `quota.lease_count.max`       | Total maximum amount of leases allowed by the lease count quota   | lease | gauge   | | ||||||
|  | | `quota.lease_count.counter`   | Total current amount of leases generated by the lease count quota | lease | gauge   | | ||||||
|  |  | ||||||
| ## Merkle Tree and Write Ahead Log Metrics | ## Merkle Tree and Write Ahead Log Metrics | ||||||
|  |  | ||||||
| These metrics relate to internal operations on Merkle Trees and Write Ahead Logs (WAL) | These metrics relate to internal operations on Merkle Trees and Write Ahead Logs (WAL) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Vishal Nayak
					Vishal Nayak