Buffer body read up to MaxRequestSize (#24354)

This commit is contained in:
Hamid Ghaf
2023-12-04 13:22:22 -08:00
committed by GitHub
parent cb217388d4
commit aeb817dfba
9 changed files with 163 additions and 178 deletions

View File

@@ -6,13 +6,13 @@ package http
import (
"bytes"
"context"
"errors"
"fmt"
"io/ioutil"
"io"
"net"
"net/http"
"strings"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/helper/namespace"
@@ -22,6 +22,27 @@ import (
var nonVotersAllowed = false
func wrapMaxRequestSizeHandler(handler http.Handler, props *vault.HandlerProperties) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var maxRequestSize int64
if props.ListenerConfig != nil {
maxRequestSize = props.ListenerConfig.MaxRequestSize
}
if maxRequestSize == 0 {
maxRequestSize = DefaultMaxRequestSize
}
ctx := r.Context()
originalBody := r.Body
if maxRequestSize > 0 {
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
}
ctx = logical.CreateContextOriginalBody(ctx, originalBody)
r = r.WithContext(ctx)
handler.ServeHTTP(w, r)
})
}
func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ns, err := namespace.FromContext(r.Context())
@@ -40,14 +61,6 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
}
mountPath := strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path)
// Clone body, so we do not close the request body reader
bodyBytes, err := ioutil.ReadAll(r.Body)
if err != nil {
respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
return
}
r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
quotaReq := &quotas.Request{
Type: quotas.TypeRateLimit,
Path: path,
@@ -67,7 +80,18 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
// If any role-based quotas are enabled for this namespace/mount, just
// do the role resolution once here.
if requiresResolveRole {
role := core.DetermineRoleFromLoginRequestFromBytes(r.Context(), mountPath, bodyBytes)
buf := bytes.Buffer{}
teeReader := io.TeeReader(r.Body, &buf)
role := core.DetermineRoleFromLoginRequestFromReader(r.Context(), mountPath, teeReader)
// Reset the body if it was read
if buf.Len() > 0 {
r.Body = io.NopCloser(&buf)
originalBody, ok := logical.ContextOriginalBodyValue(r.Context())
if ok {
r = r.WithContext(logical.CreateContextOriginalBody(r.Context(), newMultiReaderCloser(&buf, originalBody)))
}
}
// add an entry to the context to prevent recalculating request role unnecessarily
r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role))
quotaReq.Role = role
@@ -134,3 +158,25 @@ func parseRemoteIPAddress(r *http.Request) string {
return ip
}
type multiReaderCloser struct {
readers []io.Reader
io.Reader
}
func newMultiReaderCloser(readers ...io.Reader) *multiReaderCloser {
return &multiReaderCloser{
readers: readers,
Reader: io.MultiReader(readers...),
}
}
func (m *multiReaderCloser) Close() error {
var err error
for _, r := range m.readers {
if c, ok := r.(io.Closer); ok {
err = multierror.Append(err, c.Close())
}
}
return err
}