diff --git a/http/handler.go b/http/handler.go index 69ed6bdc3a..1023384d26 100644 --- a/http/handler.go +++ b/http/handler.go @@ -32,6 +32,7 @@ import ( "github.com/hashicorp/go-sockaddr" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/http/priority" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/limits" "github.com/hashicorp/vault/sdk/helper/consts" @@ -251,6 +252,7 @@ func handler(props *vault.HandlerProperties) http.Handler { wrappedHandler = rateLimitQuotaWrapping(wrappedHandler, core) wrappedHandler = entWrapGenericHandler(core, wrappedHandler, props) wrappedHandler = wrapMaxRequestSizeHandler(wrappedHandler, props) + wrappedHandler = priority.WrapRequestPriorityHandler(wrappedHandler) // Add an extra wrapping handler if the DisablePrintableCheck listener // setting isn't true that checks for non-printable characters in the @@ -1021,7 +1023,14 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l } resp.AddWarning("Timeout hit while waiting for local replicated cluster to apply primary's write; this client may encounter stale reads of values written during this operation.") } - if errwrap.Contains(err, consts.ErrOverloaded.Error()) { + + // We need to rely on string comparison here because the error could be + // returned from an RPC client call with a non-ReplicatedResponse return + // value (see: PersistAlias). In these cases, the error we get back will + // contain the non-wrapped error message string we're looking for. We would + // love to clean up all error wrapping to be consistent in Vault but we + // considered that too high risk for now. + if err != nil && strings.Contains(err.Error(), consts.ErrOverloaded.Error()) { logical.RespondWithStatusCode(resp, r, http.StatusServiceUnavailable) respondError(w, http.StatusServiceUnavailable, err) return resp, false, false diff --git a/http/logical.go b/http/logical.go index cf80df2b0f..089a24b9c9 100644 --- a/http/logical.go +++ b/http/logical.go @@ -7,7 +7,6 @@ import ( "bufio" "encoding/base64" "encoding/json" - "errors" "fmt" "io" "mime" @@ -17,9 +16,9 @@ import ( "strings" "time" + "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/namespace" - "github.com/hashicorp/vault/limits" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" @@ -386,8 +385,8 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw // success. resp, ok, needsForward := request(core, w, r, req) switch { - case errors.Is(resp.Error(), limits.ErrCapacity): - respondError(w, http.StatusServiceUnavailable, limits.ErrCapacity) + case errwrap.Contains(resp.Error(), consts.ErrOverloaded.Error()): + respondError(w, http.StatusServiceUnavailable, consts.ErrOverloaded) return case needsForward && noForward: respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly) diff --git a/http/priority/priority.go b/http/priority/priority.go new file mode 100644 index 0000000000..31bcbebeae --- /dev/null +++ b/http/priority/priority.go @@ -0,0 +1,85 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package priority + +import ( + "context" + "net/http" + "strconv" + + "github.com/hashicorp/vault/sdk/helper/parseutil" + "github.com/hashicorp/vault/sdk/logical" +) + +const ( + // VaultAOPForceRejectHeaderName is the name of an HTTP header that is used primarily + // for testing (it's not documented publicly). If set to "true" in a request + // that is subject to any form of Adaptive Overload Protection, the request + // will be rejected as if there is an overload. This is useful for + // deterministically testing the error handling plumbing as there are many + // possible code paths that need to be tested. + VaultAOPForceRejectHeaderName = "X-Vault-AOP-Force-Reject" +) + +// Priorities are limited to 256 levels to keep the state space small making +// enforcement data structures much more efficient. +type AOPWritePriority uint8 + +const ( + // AlwaysDrop is intended for testing only and will cause the request to be + // rejected with a 503 even if the server is not overloaded. + AlwaysDrop AOPWritePriority = 0 + + // StandardHTTP is the default AOPWritePriority for HTTP requests. + StandardHTTP AOPWritePriority = 128 +) + +// String returns the string representation of the AOPWritePriority. +func (p AOPWritePriority) String() string { + switch p { + case AlwaysDrop: + return strconv.FormatUint(uint64(p), 8) + } + return "" +} + +// StringToAOPWritePriority converts a string to an AOPWritePriority. +func StringToAOPWritePriority(s string) AOPWritePriority { + // Just swallow the error and fall back to the standard priority + p, err := strconv.ParseUint(s, 8, 8) + if err != nil { + return StandardHTTP + } + return AOPWritePriority(p) +} + +// WrapRequestPriorityHandler provides special handling for headers with +// X-Vault-AOP-Force-Reject set to `true`. This is useful for testing status +// codes and return values related to Adaptive Overload Protection without +// overloading Vault. +func WrapRequestPriorityHandler(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if raw := req.Header.Get(VaultAOPForceRejectHeaderName); raw != "" { + if fail, _ := parseutil.ParseBool(raw); fail { + // Make the request fail as if Vault was overloaded. We don't + // explicitly error out here, but rather attach some context + // indicating that the PID controller should perform a + // rejection. This allows us to test errors propagated from the + // WAL backend. + req = req.WithContext(ContextWithRequestPriority(req.Context(), AlwaysDrop)) + } + } + handler.ServeHTTP(w, req) + }) +} + +// ContextWithRequestPriority returns a new context derived from ctx with the +// given priority set. +func ContextWithRequestPriority(ctx context.Context, priority AOPWritePriority) context.Context { + if _, ok := ctx.Value(logical.CtxKeyInFlightRequestPriority{}).(AOPWritePriority); ok { + return ctx + } + + return context.WithValue(ctx, logical.CtxKeyInFlightRequestPriority{}, priority) +} diff --git a/sdk/logical/request.go b/sdk/logical/request.go index 33bd850d49..01ae7b948e 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -478,6 +478,12 @@ func (c CtxKeyInFlightRequestID) String() string { return "in-flight-request-ID" } +type CtxKeyInFlightRequestPriority struct{} + +func (c CtxKeyInFlightRequestPriority) String() string { + return "in-flight-request-priority" +} + type CtxKeyRequestRole struct{} func (c CtxKeyRequestRole) String() string { diff --git a/sdk/logical/response_util.go b/sdk/logical/response_util.go index 2aebe3a43b..f47663c582 100644 --- a/sdk/logical/response_util.go +++ b/sdk/logical/response_util.go @@ -112,6 +112,8 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { // appropriate code if err != nil { switch { + case errwrap.Contains(err, consts.ErrOverloaded.Error()): + statusCode = http.StatusServiceUnavailable case errwrap.ContainsType(err, new(StatusBadRequest)): statusCode = http.StatusBadRequest case errwrap.Contains(err, ErrPermissionDenied.Error()): diff --git a/vault/request_handling.go b/vault/request_handling.go index e4e84627f6..d447aff781 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -30,6 +30,7 @@ import ( "github.com/hashicorp/vault/helper/identity/mfa" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/namespace" + "github.com/hashicorp/vault/http/priority" "github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" @@ -588,6 +589,11 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R if ok { ctx = logical.CreateContextRedactionSettings(ctx, redactVersion, redactAddresses, redactClusterName) } + inFlightRequestPriority, ok := httpCtx.Value(logical.CtxKeyInFlightRequestPriority{}).(priority.AOPWritePriority) + if ok { + ctx = context.WithValue(ctx, logical.CtxKeyInFlightRequestPriority{}, inFlightRequestPriority) + } + resp, err = c.handleCancelableRequest(ctx, req) req.SetTokenEntry(nil) cancel()