diff --git a/helper/forwarding/util.go b/helper/forwarding/util.go index 1bc1c5e68f..e9d06de048 100644 --- a/helper/forwarding/util.go +++ b/helper/forwarding/util.go @@ -7,9 +7,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" - "errors" "io" - "io/ioutil" "net/http" "net/url" "os" @@ -17,7 +15,6 @@ import ( "github.com/golang/protobuf/proto" "github.com/hashicorp/vault/sdk/helper/compressutil" "github.com/hashicorp/vault/sdk/helper/jsonutil" - "github.com/hashicorp/vault/sdk/logical" ) type bufCloser struct { @@ -64,18 +61,7 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request func GenerateForwardedRequest(req *http.Request) (*Request, error) { var reader io.Reader = req.Body - ctx := req.Context() - if logical.ContextContainsMaxRequestSize(ctx) { - max, ok := logical.ContextMaxRequestSizeValue(ctx) - if !ok { - return nil, errors.New("could not parse max request size from request context") - } - if max > 0 { - reader = io.LimitReader(req.Body, max) - } - } - - body, err := ioutil.ReadAll(reader) + body, err := io.ReadAll(reader) if err != nil { return nil, err } diff --git a/http/handler.go b/http/handler.go index 37e769aa5e..c3f49d32b6 100644 --- a/http/handler.go +++ b/http/handler.go @@ -242,6 +242,7 @@ func handler(props *vault.HandlerProperties) http.Handler { wrappedHandler = wrapCORSHandler(wrappedHandler, core) wrappedHandler = rateLimitQuotaWrapping(wrappedHandler, core) wrappedHandler = entWrapGenericHandler(core, wrappedHandler, props) + wrappedHandler = wrapMaxRequestSizeHandler(wrappedHandler, props) // Add an extra wrapping handler if the DisablePrintableCheck listener // setting isn't true that checks for non-printable characters in the @@ -332,18 +333,12 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { // are performed. func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerProperties) http.Handler { var maxRequestDuration time.Duration - var maxRequestSize int64 if props.ListenerConfig != nil { maxRequestDuration = props.ListenerConfig.MaxRequestDuration - maxRequestSize = props.ListenerConfig.MaxRequestSize } if maxRequestDuration == 0 { maxRequestDuration = vault.DefaultMaxRequestDuration } - if maxRequestSize == 0 { - maxRequestSize = DefaultMaxRequestSize - } - // Swallow this error since we don't want to pollute the logs and we also don't want to // return an HTTP error here. This information is best effort. hostname, _ := os.Hostname() @@ -378,11 +373,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration) } - // if maxRequestSize < 0, no need to set context value - // Add a size limiter if desired - if maxRequestSize > 0 { - ctx = logical.CreateContextMaxRequestSize(ctx, maxRequestSize) - } ctx = logical.CreateContextOriginalRequestPath(ctx, r.URL.Path) r = r.WithContext(ctx) r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) @@ -738,24 +728,6 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, // Limit the maximum number of bytes to MaxRequestSize to protect // against an indefinite amount of data being read. reader := r.Body - ctx := r.Context() - if logical.ContextContainsMaxRequestSize(ctx) { - max, ok := logical.ContextMaxRequestSizeValue(ctx) - if !ok { - return nil, errors.New("could not parse max request size from request context") - } - if max > 0 { - // MaxBytesReader won't do all the internal stuff it must unless it's - // given a ResponseWriter that implements the internal http interface - // requestTooLarger. So we let it have access to the underlying - // ResponseWriter. - inw := w - if myw, ok := inw.(logical.WrappingResponseWriter); ok { - inw = myw.Wrapped() - } - reader = http.MaxBytesReader(inw, r.Body, max) - } - } var origBody io.ReadWriter @@ -779,16 +751,6 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, // // A nil map will be returned if the format is empty or invalid. func parseFormRequest(r *http.Request) (map[string]interface{}, error) { - if logical.ContextContainsMaxRequestSize(r.Context()) { - max, ok := logical.ContextMaxRequestSizeValue(r.Context()) - if !ok { - return nil, errors.New("could not parse max request size from request context") - } - if max > 0 { - r.Body = io.NopCloser(io.LimitReader(r.Body, max)) - } - } - if err := r.ParseForm(); err != nil { return nil, err } diff --git a/http/handler_test.go b/http/handler_test.go index a669672b60..92e1eacc46 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -4,6 +4,7 @@ package http import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -14,10 +15,12 @@ import ( "net/textproto" "net/url" "reflect" + "runtime" "strings" "testing" "github.com/hashicorp/vault/internalshared/configutil" + "github.com/stretchr/testify/require" "github.com/go-test/deep" "github.com/hashicorp/go-cleanhttp" @@ -892,3 +895,59 @@ func TestHandler_Parse_Form(t *testing.T) { t.Fatal(diff) } } + +// TestHandler_MaxRequestSize verifies that a request larger than the +// MaxRequestSize fails +func TestHandler_MaxRequestSize(t *testing.T) { + t.Parallel() + cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{ + DefaultHandlerProperties: vault.HandlerProperties{ + ListenerConfig: &configutil.Listener{ + MaxRequestSize: 1024, + }, + }, + HandlerFunc: Handler, + NumCores: 1, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + _, err := client.KVv2("secret").Put(context.Background(), "foo", map[string]interface{}{ + "bar": strings.Repeat("a", 1025), + }) + + require.ErrorContains(t, err, "error parsing JSON") +} + +// TestHandler_MaxRequestSize_Memory sets the max request size to 1024 bytes, +// and creates a 1MB request. The test verifies that less than 1MB of memory is +// allocated when the request is sent. This test shouldn't be run in parallel, +// because it modifies GOMAXPROCS +func TestHandler_MaxRequestSize_Memory(t *testing.T) { + ln, addr := TestListener(t) + core, _, token := vault.TestCoreUnsealed(t) + TestServerWithListenerAndProperties(t, ln, addr, core, &vault.HandlerProperties{ + Core: core, + ListenerConfig: &configutil.Listener{ + Address: addr, + MaxRequestSize: 1024, + }, + }) + defer ln.Close() + + data := bytes.Repeat([]byte{0x1}, 1024*1024) + + req, err := http.NewRequest("POST", addr+"/v1/sys/unseal", bytes.NewReader(data)) + require.NoError(t, err) + req.Header.Set(consts.AuthHeaderName, token) + + client := cleanhttp.DefaultClient() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + var start, end runtime.MemStats + runtime.GC() + runtime.ReadMemStats(&start) + client.Do(req) + runtime.ReadMemStats(&end) + require.Less(t, end.TotalAlloc-start.TotalAlloc, uint64(1024*1024)) +} diff --git a/http/util.go b/http/util.go index abf03b29ef..4de8f81326 100644 --- a/http/util.go +++ b/http/util.go @@ -6,13 +6,13 @@ package http import ( "bytes" "context" - "errors" "fmt" - "io/ioutil" + "io" "net" "net/http" "strings" + "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/helper/namespace" @@ -22,6 +22,27 @@ import ( var nonVotersAllowed = false +func wrapMaxRequestSizeHandler(handler http.Handler, props *vault.HandlerProperties) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var maxRequestSize int64 + if props.ListenerConfig != nil { + maxRequestSize = props.ListenerConfig.MaxRequestSize + } + if maxRequestSize == 0 { + maxRequestSize = DefaultMaxRequestSize + } + ctx := r.Context() + originalBody := r.Body + if maxRequestSize > 0 { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + } + ctx = logical.CreateContextOriginalBody(ctx, originalBody) + r = r.WithContext(ctx) + + handler.ServeHTTP(w, r) + }) +} + func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ns, err := namespace.FromContext(r.Context()) @@ -40,14 +61,6 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler } mountPath := strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path) - // Clone body, so we do not close the request body reader - bodyBytes, err := ioutil.ReadAll(r.Body) - if err != nil { - respondError(w, http.StatusInternalServerError, errors.New("failed to read request body")) - return - } - r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) - quotaReq := "as.Request{ Type: quotas.TypeRateLimit, Path: path, @@ -67,7 +80,18 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler // If any role-based quotas are enabled for this namespace/mount, just // do the role resolution once here. if requiresResolveRole { - role := core.DetermineRoleFromLoginRequestFromBytes(r.Context(), mountPath, bodyBytes) + buf := bytes.Buffer{} + teeReader := io.TeeReader(r.Body, &buf) + role := core.DetermineRoleFromLoginRequestFromReader(r.Context(), mountPath, teeReader) + + // Reset the body if it was read + if buf.Len() > 0 { + r.Body = io.NopCloser(&buf) + originalBody, ok := logical.ContextOriginalBodyValue(r.Context()) + if ok { + r = r.WithContext(logical.CreateContextOriginalBody(r.Context(), newMultiReaderCloser(&buf, originalBody))) + } + } // add an entry to the context to prevent recalculating request role unnecessarily r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role)) quotaReq.Role = role @@ -134,3 +158,25 @@ func parseRemoteIPAddress(r *http.Request) string { return ip } + +type multiReaderCloser struct { + readers []io.Reader + io.Reader +} + +func newMultiReaderCloser(readers ...io.Reader) *multiReaderCloser { + return &multiReaderCloser{ + readers: readers, + Reader: io.MultiReader(readers...), + } +} + +func (m *multiReaderCloser) Close() error { + var err error + for _, r := range m.readers { + if c, ok := r.(io.Closer); ok { + err = multierror.Append(err, c.Close()) + } + } + return err +} diff --git a/sdk/logical/request.go b/sdk/logical/request.go index a4850e0eb5..176b7013a3 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -6,6 +6,7 @@ package logical import ( "context" "fmt" + "io" "net/http" "strings" "time" @@ -483,38 +484,6 @@ func CreateContextDisableReplicationStatusEndpoints(parent context.Context, valu return context.WithValue(parent, ctxKeyDisableReplicationStatusEndpoints{}, value) } -// CtxKeyMaxRequestSize is a custom type used as a key in context.Context to -// store the value of the max_request_size set for the listener through which -// a request was received. -type ctxKeyMaxRequestSize struct{} - -// String returns a string representation of the receiver type. -func (c ctxKeyMaxRequestSize) String() string { - return "max_request_size" -} - -// ContextMaxRequestSizeValue examines the provided context.Context for the max -// request size value and returns it as an int64 value if it's found along with -// the ok value set to true; otherwise the ok return value is false. -func ContextMaxRequestSizeValue(ctx context.Context) (value int64, ok bool) { - value, ok = ctx.Value(ctxKeyMaxRequestSize{}).(int64) - - return -} - -// CreateContextMaxRequestSize creates a new context.Context based on the -// provided parent that also includes the provided max request size value for -// the ctxKeyMaxRequestSize key. -func CreateContextMaxRequestSize(parent context.Context, value int64) context.Context { - return context.WithValue(parent, ctxKeyMaxRequestSize{}, value) -} - -// ContextContainsMaxRequestSize returns a bool value that indicates if the -// provided Context contains a value for the ctxKeyMaxRequestSize key. -func ContextContainsMaxRequestSize(ctx context.Context) bool { - return ctx.Value(ctxKeyMaxRequestSize{}) != nil -} - // CtxKeyOriginalRequestPath is a custom type used as a key in context.Context // to store the original request path. type ctxKeyOriginalRequestPath struct{} @@ -539,3 +508,14 @@ func ContextOriginalRequestPathValue(ctx context.Context) (value string, ok bool func CreateContextOriginalRequestPath(parent context.Context, value string) context.Context { return context.WithValue(parent, ctxKeyOriginalRequestPath{}, value) } + +type ctxKeyOriginalBody struct{} + +func ContextOriginalBodyValue(ctx context.Context) (io.ReadCloser, bool) { + value, ok := ctx.Value(ctxKeyOriginalBody{}).(io.ReadCloser) + return value, ok +} + +func CreateContextOriginalBody(parent context.Context, body io.ReadCloser) context.Context { + return context.WithValue(parent, ctxKeyOriginalBody{}, body) +} diff --git a/sdk/logical/request_test.go b/sdk/logical/request_test.go index 1b07cc1b50..4e05471035 100644 --- a/sdk/logical/request_test.go +++ b/sdk/logical/request_test.go @@ -76,72 +76,6 @@ func TestCreateContextDisableReplicationStatusEndpoints(t *testing.T) { assert.Equal(t, false, value.(bool)) } -func TestContextMaxRequestSizeValue(t *testing.T) { - testcases := []struct { - name string - ctx context.Context - expectedValue int64 - expectedOk bool - }{ - { - name: "without-value", - ctx: context.Background(), - expectedValue: 0, - expectedOk: false, - }, - { - name: "with-nil", - ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, nil), - expectedValue: 0, - expectedOk: false, - }, - { - name: "with-incompatible-value", - ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, "6666"), - expectedValue: 0, - expectedOk: false, - }, - { - name: "with-int64-8888", - ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(8888)), - expectedValue: 8888, - expectedOk: true, - }, - { - name: "with-int64-zero", - ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(0)), - expectedValue: 0, - expectedOk: true, - }, - } - - for _, testcase := range testcases { - value, ok := ContextMaxRequestSizeValue(testcase.ctx) - assert.Equal(t, testcase.expectedValue, value, testcase.name) - assert.Equal(t, testcase.expectedOk, ok, testcase.name) - } -} - -func TestCreateContextMaxRequestSize(t *testing.T) { - ctx := CreateContextMaxRequestSize(context.Background(), int64(8888)) - - value := ctx.Value(ctxKeyMaxRequestSize{}) - - assert.NotNil(t, ctx) - assert.NotNil(t, value) - assert.IsType(t, int64(0), value) - assert.Equal(t, int64(8888), value.(int64)) - - ctx = CreateContextMaxRequestSize(context.Background(), int64(0)) - - value = ctx.Value(ctxKeyMaxRequestSize{}) - - assert.NotNil(t, ctx) - assert.NotNil(t, value) - assert.IsType(t, int64(0), value) - assert.Equal(t, int64(0), value.(int64)) -} - func TestContextOriginalRequestPathValue(t *testing.T) { testcases := []struct { name string diff --git a/vault/core.go b/vault/core.go index 1d7573c1b8..e02e53a82e 100644 --- a/vault/core.go +++ b/vault/core.go @@ -3999,19 +3999,6 @@ func (c *Core) LoadNodeID() (string, error) { return hostname, nil } -// DetermineRoleFromLoginRequestFromBytes will determine the role that should be applied to a quota for a given -// login request, accepting a byte payload -func (c *Core) DetermineRoleFromLoginRequestFromBytes(ctx context.Context, mountPoint string, payload []byte) string { - data := make(map[string]interface{}) - err := jsonutil.DecodeJSON(payload, &data) - if err != nil { - // Cannot discern a role from a request we cannot parse - return "" - } - - return c.DetermineRoleFromLoginRequest(ctx, mountPoint, data) -} - // DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given // login request func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string { @@ -4022,7 +4009,33 @@ func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint str // Role based quotas do not apply to this request return "" } + return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data) +} +// DetermineRoleFromLoginRequestFromReader will determine the role that should +// be applied to a quota for a given login request. The reader will only be +// consumed if the matching backend for the mount point exists and is a secret +// backend +func (c *Core) DetermineRoleFromLoginRequestFromReader(ctx context.Context, mountPoint string, reader io.Reader) string { + c.authLock.RLock() + defer c.authLock.RUnlock() + matchingBackend := c.router.MatchingBackend(ctx, mountPoint) + if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential { + // Role based quotas do not apply to this request + return "" + } + + data := make(map[string]interface{}) + err := jsonutil.DecodeJSONFromReader(reader, &data) + if err != nil { + return "" + } + return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data) +} + +// doResolveRoleLocked does a login and resolve role request on the matching +// backend. Callers should have a read lock on c.authLock +func (c *Core) doResolveRoleLocked(ctx context.Context, mountPoint string, matchingBackend logical.Backend, data map[string]interface{}) string { resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{ MountPoint: mountPoint, Path: "login", diff --git a/vault/logical_system_raft.go b/vault/logical_system_raft.go index d270d1d39b..7b25407605 100644 --- a/vault/logical_system_raft.go +++ b/vault/logical_system_raft.go @@ -570,7 +570,8 @@ func (b *SystemBackend) handleStorageRaftSnapshotWrite(force bool, makeSealer fu if !ok { return logical.ErrorResponse("raft storage is not in use"), logical.ErrInvalidRequest } - if req.HTTPRequest == nil || req.HTTPRequest.Body == nil { + body, ok := logical.ContextOriginalBodyValue(ctx) + if !ok { return nil, errors.New("no reader for request") } @@ -583,7 +584,7 @@ func (b *SystemBackend) handleStorageRaftSnapshotWrite(force bool, makeSealer fu // don't have to hold the full snapshot in memory. We also want to do // the restore in two parts so we can restore the snapshot while the // stateLock is write locked. - snapFile, cleanup, metadata, err := raftStorage.WriteSnapshotToTemp(req.HTTPRequest.Body, sealer) + snapFile, cleanup, metadata, err := raftStorage.WriteSnapshotToTemp(body, sealer) switch { case err == nil: case strings.Contains(err.Error(), "failed to open the sealed hashes"): diff --git a/vault/request_handling.go b/vault/request_handling.go index b8296f34f1..034146b5c0 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -581,6 +581,10 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R if disable_repl_status, ok := logical.ContextDisableReplicationStatusEndpointsValue(httpCtx); ok { ctx = logical.CreateContextDisableReplicationStatusEndpoints(ctx, disable_repl_status) } + body, ok := logical.ContextOriginalBodyValue(httpCtx) + if ok { + ctx = logical.CreateContextOriginalBody(ctx, body) + } resp, err = c.handleCancelableRequest(ctx, req) req.SetTokenEntry(nil) cancel()