mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 17:52:32 +00:00
Buffer body read up to MaxRequestSize (#24354)
This commit is contained in:
68
http/util.go
68
http/util.go
@@ -6,13 +6,13 @@ package http
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
@@ -22,6 +22,27 @@ import (
|
||||
|
||||
var nonVotersAllowed = false
|
||||
|
||||
func wrapMaxRequestSizeHandler(handler http.Handler, props *vault.HandlerProperties) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var maxRequestSize int64
|
||||
if props.ListenerConfig != nil {
|
||||
maxRequestSize = props.ListenerConfig.MaxRequestSize
|
||||
}
|
||||
if maxRequestSize == 0 {
|
||||
maxRequestSize = DefaultMaxRequestSize
|
||||
}
|
||||
ctx := r.Context()
|
||||
originalBody := r.Body
|
||||
if maxRequestSize > 0 {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||
}
|
||||
ctx = logical.CreateContextOriginalBody(ctx, originalBody)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ns, err := namespace.FromContext(r.Context())
|
||||
@@ -40,14 +61,6 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
|
||||
}
|
||||
mountPath := strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path)
|
||||
|
||||
// Clone body, so we do not close the request body reader
|
||||
bodyBytes, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
|
||||
return
|
||||
}
|
||||
r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
quotaReq := "as.Request{
|
||||
Type: quotas.TypeRateLimit,
|
||||
Path: path,
|
||||
@@ -67,7 +80,18 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
|
||||
// If any role-based quotas are enabled for this namespace/mount, just
|
||||
// do the role resolution once here.
|
||||
if requiresResolveRole {
|
||||
role := core.DetermineRoleFromLoginRequestFromBytes(r.Context(), mountPath, bodyBytes)
|
||||
buf := bytes.Buffer{}
|
||||
teeReader := io.TeeReader(r.Body, &buf)
|
||||
role := core.DetermineRoleFromLoginRequestFromReader(r.Context(), mountPath, teeReader)
|
||||
|
||||
// Reset the body if it was read
|
||||
if buf.Len() > 0 {
|
||||
r.Body = io.NopCloser(&buf)
|
||||
originalBody, ok := logical.ContextOriginalBodyValue(r.Context())
|
||||
if ok {
|
||||
r = r.WithContext(logical.CreateContextOriginalBody(r.Context(), newMultiReaderCloser(&buf, originalBody)))
|
||||
}
|
||||
}
|
||||
// add an entry to the context to prevent recalculating request role unnecessarily
|
||||
r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role))
|
||||
quotaReq.Role = role
|
||||
@@ -134,3 +158,25 @@ func parseRemoteIPAddress(r *http.Request) string {
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
type multiReaderCloser struct {
|
||||
readers []io.Reader
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func newMultiReaderCloser(readers ...io.Reader) *multiReaderCloser {
|
||||
return &multiReaderCloser{
|
||||
readers: readers,
|
||||
Reader: io.MultiReader(readers...),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *multiReaderCloser) Close() error {
|
||||
var err error
|
||||
for _, r := range m.readers {
|
||||
if c, ok := r.(io.Closer); ok {
|
||||
err = multierror.Append(err, c.Close())
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user