AOP: Force reject header (enterprise) (#26702)

This PR introduces the CE plumbing for a new HTTP header, called
X-Vault-AOP-Force-Reject, which will force any associated request to
reject storage writes as if Vault were overloaded.

This flag is intended to test end-to-end functionality of write
rejection in Vault. This is specifically useful for testing 503 -
Service Unavailable HTTP response codes during load shedding.
This commit is contained in:
Mike Palmiotto
2024-05-01 14:11:24 -04:00
committed by GitHub
parent b4a2e40124
commit c5fac98d2d
6 changed files with 112 additions and 5 deletions

View File

@@ -32,6 +32,7 @@ import (
"github.com/hashicorp/go-sockaddr"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/http/priority"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/sdk/helper/consts"
@@ -251,6 +252,7 @@ func handler(props *vault.HandlerProperties) http.Handler {
wrappedHandler = rateLimitQuotaWrapping(wrappedHandler, core)
wrappedHandler = entWrapGenericHandler(core, wrappedHandler, props)
wrappedHandler = wrapMaxRequestSizeHandler(wrappedHandler, props)
wrappedHandler = priority.WrapRequestPriorityHandler(wrappedHandler)
// Add an extra wrapping handler if the DisablePrintableCheck listener
// setting isn't true that checks for non-printable characters in the
@@ -1021,7 +1023,14 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l
}
resp.AddWarning("Timeout hit while waiting for local replicated cluster to apply primary's write; this client may encounter stale reads of values written during this operation.")
}
if errwrap.Contains(err, consts.ErrOverloaded.Error()) {
// We need to rely on string comparison here because the error could be
// returned from an RPC client call with a non-ReplicatedResponse return
// value (see: PersistAlias). In these cases, the error we get back will
// contain the non-wrapped error message string we're looking for. We would
// love to clean up all error wrapping to be consistent in Vault but we
// considered that too high risk for now.
if err != nil && strings.Contains(err.Error(), consts.ErrOverloaded.Error()) {
logical.RespondWithStatusCode(resp, r, http.StatusServiceUnavailable)
respondError(w, http.StatusServiceUnavailable, err)
return resp, false, false

View File

@@ -7,7 +7,6 @@ import (
"bufio"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
@@ -17,9 +16,9 @@ import (
"strings"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
@@ -386,8 +385,8 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool, noForw
// success.
resp, ok, needsForward := request(core, w, r, req)
switch {
case errors.Is(resp.Error(), limits.ErrCapacity):
respondError(w, http.StatusServiceUnavailable, limits.ErrCapacity)
case errwrap.Contains(resp.Error(), consts.ErrOverloaded.Error()):
respondError(w, http.StatusServiceUnavailable, consts.ErrOverloaded)
return
case needsForward && noForward:
respondError(w, http.StatusBadRequest, vault.ErrCannotForwardLocalOnly)

85
http/priority/priority.go Normal file
View File

@@ -0,0 +1,85 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package priority
import (
"context"
"net/http"
"strconv"
"github.com/hashicorp/vault/sdk/helper/parseutil"
"github.com/hashicorp/vault/sdk/logical"
)
const (
// VaultAOPForceRejectHeaderName is the name of an HTTP header that is used primarily
// for testing (it's not documented publicly). If set to "true" in a request
// that is subject to any form of Adaptive Overload Protection, the request
// will be rejected as if there is an overload. This is useful for
// deterministically testing the error handling plumbing as there are many
// possible code paths that need to be tested.
VaultAOPForceRejectHeaderName = "X-Vault-AOP-Force-Reject"
)
// Priorities are limited to 256 levels to keep the state space small making
// enforcement data structures much more efficient.
type AOPWritePriority uint8
const (
// AlwaysDrop is intended for testing only and will cause the request to be
// rejected with a 503 even if the server is not overloaded.
AlwaysDrop AOPWritePriority = 0
// StandardHTTP is the default AOPWritePriority for HTTP requests.
StandardHTTP AOPWritePriority = 128
)
// String returns the string representation of the AOPWritePriority.
func (p AOPWritePriority) String() string {
switch p {
case AlwaysDrop:
return strconv.FormatUint(uint64(p), 8)
}
return ""
}
// StringToAOPWritePriority converts a string to an AOPWritePriority.
func StringToAOPWritePriority(s string) AOPWritePriority {
// Just swallow the error and fall back to the standard priority
p, err := strconv.ParseUint(s, 8, 8)
if err != nil {
return StandardHTTP
}
return AOPWritePriority(p)
}
// WrapRequestPriorityHandler provides special handling for headers with
// X-Vault-AOP-Force-Reject set to `true`. This is useful for testing status
// codes and return values related to Adaptive Overload Protection without
// overloading Vault.
func WrapRequestPriorityHandler(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if raw := req.Header.Get(VaultAOPForceRejectHeaderName); raw != "" {
if fail, _ := parseutil.ParseBool(raw); fail {
// Make the request fail as if Vault was overloaded. We don't
// explicitly error out here, but rather attach some context
// indicating that the PID controller should perform a
// rejection. This allows us to test errors propagated from the
// WAL backend.
req = req.WithContext(ContextWithRequestPriority(req.Context(), AlwaysDrop))
}
}
handler.ServeHTTP(w, req)
})
}
// ContextWithRequestPriority returns a new context derived from ctx with the
// given priority set.
func ContextWithRequestPriority(ctx context.Context, priority AOPWritePriority) context.Context {
if _, ok := ctx.Value(logical.CtxKeyInFlightRequestPriority{}).(AOPWritePriority); ok {
return ctx
}
return context.WithValue(ctx, logical.CtxKeyInFlightRequestPriority{}, priority)
}

View File

@@ -478,6 +478,12 @@ func (c CtxKeyInFlightRequestID) String() string {
return "in-flight-request-ID"
}
type CtxKeyInFlightRequestPriority struct{}
func (c CtxKeyInFlightRequestPriority) String() string {
return "in-flight-request-priority"
}
type CtxKeyRequestRole struct{}
func (c CtxKeyRequestRole) String() string {

View File

@@ -112,6 +112,8 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
// appropriate code
if err != nil {
switch {
case errwrap.Contains(err, consts.ErrOverloaded.Error()):
statusCode = http.StatusServiceUnavailable
case errwrap.ContainsType(err, new(StatusBadRequest)):
statusCode = http.StatusBadRequest
case errwrap.Contains(err, ErrPermissionDenied.Error()):

View File

@@ -30,6 +30,7 @@ import (
"github.com/hashicorp/vault/helper/identity/mfa"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/http/priority"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
@@ -588,6 +589,11 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R
if ok {
ctx = logical.CreateContextRedactionSettings(ctx, redactVersion, redactAddresses, redactClusterName)
}
inFlightRequestPriority, ok := httpCtx.Value(logical.CtxKeyInFlightRequestPriority{}).(priority.AOPWritePriority)
if ok {
ctx = context.WithValue(ctx, logical.CtxKeyInFlightRequestPriority{}, inFlightRequestPriority)
}
resp, err = c.handleCancelableRequest(ctx, req)
req.SetTokenEntry(nil)
cancel()