mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 17:52:32 +00:00
Buffer body read up to MaxRequestSize (#24354)
This commit is contained in:
@@ -7,9 +7,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -17,7 +15,6 @@ import (
|
|||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"github.com/hashicorp/vault/sdk/helper/compressutil"
|
"github.com/hashicorp/vault/sdk/helper/compressutil"
|
||||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type bufCloser struct {
|
type bufCloser struct {
|
||||||
@@ -64,18 +61,7 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request
|
|||||||
|
|
||||||
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
|
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
|
||||||
var reader io.Reader = req.Body
|
var reader io.Reader = req.Body
|
||||||
ctx := req.Context()
|
body, err := io.ReadAll(reader)
|
||||||
if logical.ContextContainsMaxRequestSize(ctx) {
|
|
||||||
max, ok := logical.ContextMaxRequestSizeValue(ctx)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("could not parse max request size from request context")
|
|
||||||
}
|
|
||||||
if max > 0 {
|
|
||||||
reader = io.LimitReader(req.Body, max)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(reader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -242,6 +242,7 @@ func handler(props *vault.HandlerProperties) http.Handler {
|
|||||||
wrappedHandler = wrapCORSHandler(wrappedHandler, core)
|
wrappedHandler = wrapCORSHandler(wrappedHandler, core)
|
||||||
wrappedHandler = rateLimitQuotaWrapping(wrappedHandler, core)
|
wrappedHandler = rateLimitQuotaWrapping(wrappedHandler, core)
|
||||||
wrappedHandler = entWrapGenericHandler(core, wrappedHandler, props)
|
wrappedHandler = entWrapGenericHandler(core, wrappedHandler, props)
|
||||||
|
wrappedHandler = wrapMaxRequestSizeHandler(wrappedHandler, props)
|
||||||
|
|
||||||
// Add an extra wrapping handler if the DisablePrintableCheck listener
|
// Add an extra wrapping handler if the DisablePrintableCheck listener
|
||||||
// setting isn't true that checks for non-printable characters in the
|
// setting isn't true that checks for non-printable characters in the
|
||||||
@@ -332,18 +333,12 @@ func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler {
|
|||||||
// are performed.
|
// are performed.
|
||||||
func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerProperties) http.Handler {
|
func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerProperties) http.Handler {
|
||||||
var maxRequestDuration time.Duration
|
var maxRequestDuration time.Duration
|
||||||
var maxRequestSize int64
|
|
||||||
if props.ListenerConfig != nil {
|
if props.ListenerConfig != nil {
|
||||||
maxRequestDuration = props.ListenerConfig.MaxRequestDuration
|
maxRequestDuration = props.ListenerConfig.MaxRequestDuration
|
||||||
maxRequestSize = props.ListenerConfig.MaxRequestSize
|
|
||||||
}
|
}
|
||||||
if maxRequestDuration == 0 {
|
if maxRequestDuration == 0 {
|
||||||
maxRequestDuration = vault.DefaultMaxRequestDuration
|
maxRequestDuration = vault.DefaultMaxRequestDuration
|
||||||
}
|
}
|
||||||
if maxRequestSize == 0 {
|
|
||||||
maxRequestSize = DefaultMaxRequestSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// Swallow this error since we don't want to pollute the logs and we also don't want to
|
// Swallow this error since we don't want to pollute the logs and we also don't want to
|
||||||
// return an HTTP error here. This information is best effort.
|
// return an HTTP error here. This information is best effort.
|
||||||
hostname, _ := os.Hostname()
|
hostname, _ := os.Hostname()
|
||||||
@@ -378,11 +373,6 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
|
|||||||
ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration)
|
ctx, cancelFunc = context.WithTimeout(ctx, maxRequestDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if maxRequestSize < 0, no need to set context value
|
|
||||||
// Add a size limiter if desired
|
|
||||||
if maxRequestSize > 0 {
|
|
||||||
ctx = logical.CreateContextMaxRequestSize(ctx, maxRequestSize)
|
|
||||||
}
|
|
||||||
ctx = logical.CreateContextOriginalRequestPath(ctx, r.URL.Path)
|
ctx = logical.CreateContextOriginalRequestPath(ctx, r.URL.Path)
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
|
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
|
||||||
@@ -738,24 +728,6 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
|
|||||||
// Limit the maximum number of bytes to MaxRequestSize to protect
|
// Limit the maximum number of bytes to MaxRequestSize to protect
|
||||||
// against an indefinite amount of data being read.
|
// against an indefinite amount of data being read.
|
||||||
reader := r.Body
|
reader := r.Body
|
||||||
ctx := r.Context()
|
|
||||||
if logical.ContextContainsMaxRequestSize(ctx) {
|
|
||||||
max, ok := logical.ContextMaxRequestSizeValue(ctx)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("could not parse max request size from request context")
|
|
||||||
}
|
|
||||||
if max > 0 {
|
|
||||||
// MaxBytesReader won't do all the internal stuff it must unless it's
|
|
||||||
// given a ResponseWriter that implements the internal http interface
|
|
||||||
// requestTooLarger. So we let it have access to the underlying
|
|
||||||
// ResponseWriter.
|
|
||||||
inw := w
|
|
||||||
if myw, ok := inw.(logical.WrappingResponseWriter); ok {
|
|
||||||
inw = myw.Wrapped()
|
|
||||||
}
|
|
||||||
reader = http.MaxBytesReader(inw, r.Body, max)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var origBody io.ReadWriter
|
var origBody io.ReadWriter
|
||||||
|
|
||||||
@@ -779,16 +751,6 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
|
|||||||
//
|
//
|
||||||
// A nil map will be returned if the format is empty or invalid.
|
// A nil map will be returned if the format is empty or invalid.
|
||||||
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
|
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
|
||||||
if logical.ContextContainsMaxRequestSize(r.Context()) {
|
|
||||||
max, ok := logical.ContextMaxRequestSizeValue(r.Context())
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("could not parse max request size from request context")
|
|
||||||
}
|
|
||||||
if max > 0 {
|
|
||||||
r.Body = io.NopCloser(io.LimitReader(r.Body, max))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -14,10 +15,12 @@ import (
|
|||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/internalshared/configutil"
|
"github.com/hashicorp/vault/internalshared/configutil"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/go-test/deep"
|
"github.com/go-test/deep"
|
||||||
"github.com/hashicorp/go-cleanhttp"
|
"github.com/hashicorp/go-cleanhttp"
|
||||||
@@ -892,3 +895,59 @@ func TestHandler_Parse_Form(t *testing.T) {
|
|||||||
t.Fatal(diff)
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestHandler_MaxRequestSize verifies that a request larger than the
|
||||||
|
// MaxRequestSize fails
|
||||||
|
func TestHandler_MaxRequestSize(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{
|
||||||
|
DefaultHandlerProperties: vault.HandlerProperties{
|
||||||
|
ListenerConfig: &configutil.Listener{
|
||||||
|
MaxRequestSize: 1024,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
HandlerFunc: Handler,
|
||||||
|
NumCores: 1,
|
||||||
|
})
|
||||||
|
cluster.Start()
|
||||||
|
defer cluster.Cleanup()
|
||||||
|
|
||||||
|
client := cluster.Cores[0].Client
|
||||||
|
_, err := client.KVv2("secret").Put(context.Background(), "foo", map[string]interface{}{
|
||||||
|
"bar": strings.Repeat("a", 1025),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.ErrorContains(t, err, "error parsing JSON")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandler_MaxRequestSize_Memory sets the max request size to 1024 bytes,
|
||||||
|
// and creates a 1MB request. The test verifies that less than 1MB of memory is
|
||||||
|
// allocated when the request is sent. This test shouldn't be run in parallel,
|
||||||
|
// because it modifies GOMAXPROCS
|
||||||
|
func TestHandler_MaxRequestSize_Memory(t *testing.T) {
|
||||||
|
ln, addr := TestListener(t)
|
||||||
|
core, _, token := vault.TestCoreUnsealed(t)
|
||||||
|
TestServerWithListenerAndProperties(t, ln, addr, core, &vault.HandlerProperties{
|
||||||
|
Core: core,
|
||||||
|
ListenerConfig: &configutil.Listener{
|
||||||
|
Address: addr,
|
||||||
|
MaxRequestSize: 1024,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
data := bytes.Repeat([]byte{0x1}, 1024*1024)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", addr+"/v1/sys/unseal", bytes.NewReader(data))
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set(consts.AuthHeaderName, token)
|
||||||
|
|
||||||
|
client := cleanhttp.DefaultClient()
|
||||||
|
defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
|
||||||
|
var start, end runtime.MemStats
|
||||||
|
runtime.GC()
|
||||||
|
runtime.ReadMemStats(&start)
|
||||||
|
client.Do(req)
|
||||||
|
runtime.ReadMemStats(&end)
|
||||||
|
require.Less(t, end.TotalAlloc-start.TotalAlloc, uint64(1024*1024))
|
||||||
|
}
|
||||||
|
|||||||
68
http/util.go
68
http/util.go
@@ -6,13 +6,13 @@ package http
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/helper/namespace"
|
"github.com/hashicorp/vault/helper/namespace"
|
||||||
@@ -22,6 +22,27 @@ import (
|
|||||||
|
|
||||||
var nonVotersAllowed = false
|
var nonVotersAllowed = false
|
||||||
|
|
||||||
|
func wrapMaxRequestSizeHandler(handler http.Handler, props *vault.HandlerProperties) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var maxRequestSize int64
|
||||||
|
if props.ListenerConfig != nil {
|
||||||
|
maxRequestSize = props.ListenerConfig.MaxRequestSize
|
||||||
|
}
|
||||||
|
if maxRequestSize == 0 {
|
||||||
|
maxRequestSize = DefaultMaxRequestSize
|
||||||
|
}
|
||||||
|
ctx := r.Context()
|
||||||
|
originalBody := r.Body
|
||||||
|
if maxRequestSize > 0 {
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
|
||||||
|
}
|
||||||
|
ctx = logical.CreateContextOriginalBody(ctx, originalBody)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
|
func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ns, err := namespace.FromContext(r.Context())
|
ns, err := namespace.FromContext(r.Context())
|
||||||
@@ -40,14 +61,6 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
|
|||||||
}
|
}
|
||||||
mountPath := strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path)
|
mountPath := strings.TrimPrefix(core.MatchingMount(r.Context(), path), ns.Path)
|
||||||
|
|
||||||
// Clone body, so we do not close the request body reader
|
|
||||||
bodyBytes, err := ioutil.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
respondError(w, http.StatusInternalServerError, errors.New("failed to read request body"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
|
|
||||||
quotaReq := "as.Request{
|
quotaReq := "as.Request{
|
||||||
Type: quotas.TypeRateLimit,
|
Type: quotas.TypeRateLimit,
|
||||||
Path: path,
|
Path: path,
|
||||||
@@ -67,7 +80,18 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
|
|||||||
// If any role-based quotas are enabled for this namespace/mount, just
|
// If any role-based quotas are enabled for this namespace/mount, just
|
||||||
// do the role resolution once here.
|
// do the role resolution once here.
|
||||||
if requiresResolveRole {
|
if requiresResolveRole {
|
||||||
role := core.DetermineRoleFromLoginRequestFromBytes(r.Context(), mountPath, bodyBytes)
|
buf := bytes.Buffer{}
|
||||||
|
teeReader := io.TeeReader(r.Body, &buf)
|
||||||
|
role := core.DetermineRoleFromLoginRequestFromReader(r.Context(), mountPath, teeReader)
|
||||||
|
|
||||||
|
// Reset the body if it was read
|
||||||
|
if buf.Len() > 0 {
|
||||||
|
r.Body = io.NopCloser(&buf)
|
||||||
|
originalBody, ok := logical.ContextOriginalBodyValue(r.Context())
|
||||||
|
if ok {
|
||||||
|
r = r.WithContext(logical.CreateContextOriginalBody(r.Context(), newMultiReaderCloser(&buf, originalBody)))
|
||||||
|
}
|
||||||
|
}
|
||||||
// add an entry to the context to prevent recalculating request role unnecessarily
|
// add an entry to the context to prevent recalculating request role unnecessarily
|
||||||
r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role))
|
r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role))
|
||||||
quotaReq.Role = role
|
quotaReq.Role = role
|
||||||
@@ -134,3 +158,25 @@ func parseRemoteIPAddress(r *http.Request) string {
|
|||||||
|
|
||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type multiReaderCloser struct {
|
||||||
|
readers []io.Reader
|
||||||
|
io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMultiReaderCloser(readers ...io.Reader) *multiReaderCloser {
|
||||||
|
return &multiReaderCloser{
|
||||||
|
readers: readers,
|
||||||
|
Reader: io.MultiReader(readers...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *multiReaderCloser) Close() error {
|
||||||
|
var err error
|
||||||
|
for _, r := range m.readers {
|
||||||
|
if c, ok := r.(io.Closer); ok {
|
||||||
|
err = multierror.Append(err, c.Close())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package logical
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -483,38 +484,6 @@ func CreateContextDisableReplicationStatusEndpoints(parent context.Context, valu
|
|||||||
return context.WithValue(parent, ctxKeyDisableReplicationStatusEndpoints{}, value)
|
return context.WithValue(parent, ctxKeyDisableReplicationStatusEndpoints{}, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CtxKeyMaxRequestSize is a custom type used as a key in context.Context to
|
|
||||||
// store the value of the max_request_size set for the listener through which
|
|
||||||
// a request was received.
|
|
||||||
type ctxKeyMaxRequestSize struct{}
|
|
||||||
|
|
||||||
// String returns a string representation of the receiver type.
|
|
||||||
func (c ctxKeyMaxRequestSize) String() string {
|
|
||||||
return "max_request_size"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContextMaxRequestSizeValue examines the provided context.Context for the max
|
|
||||||
// request size value and returns it as an int64 value if it's found along with
|
|
||||||
// the ok value set to true; otherwise the ok return value is false.
|
|
||||||
func ContextMaxRequestSizeValue(ctx context.Context) (value int64, ok bool) {
|
|
||||||
value, ok = ctx.Value(ctxKeyMaxRequestSize{}).(int64)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateContextMaxRequestSize creates a new context.Context based on the
|
|
||||||
// provided parent that also includes the provided max request size value for
|
|
||||||
// the ctxKeyMaxRequestSize key.
|
|
||||||
func CreateContextMaxRequestSize(parent context.Context, value int64) context.Context {
|
|
||||||
return context.WithValue(parent, ctxKeyMaxRequestSize{}, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContextContainsMaxRequestSize returns a bool value that indicates if the
|
|
||||||
// provided Context contains a value for the ctxKeyMaxRequestSize key.
|
|
||||||
func ContextContainsMaxRequestSize(ctx context.Context) bool {
|
|
||||||
return ctx.Value(ctxKeyMaxRequestSize{}) != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CtxKeyOriginalRequestPath is a custom type used as a key in context.Context
|
// CtxKeyOriginalRequestPath is a custom type used as a key in context.Context
|
||||||
// to store the original request path.
|
// to store the original request path.
|
||||||
type ctxKeyOriginalRequestPath struct{}
|
type ctxKeyOriginalRequestPath struct{}
|
||||||
@@ -539,3 +508,14 @@ func ContextOriginalRequestPathValue(ctx context.Context) (value string, ok bool
|
|||||||
func CreateContextOriginalRequestPath(parent context.Context, value string) context.Context {
|
func CreateContextOriginalRequestPath(parent context.Context, value string) context.Context {
|
||||||
return context.WithValue(parent, ctxKeyOriginalRequestPath{}, value)
|
return context.WithValue(parent, ctxKeyOriginalRequestPath{}, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ctxKeyOriginalBody struct{}
|
||||||
|
|
||||||
|
func ContextOriginalBodyValue(ctx context.Context) (io.ReadCloser, bool) {
|
||||||
|
value, ok := ctx.Value(ctxKeyOriginalBody{}).(io.ReadCloser)
|
||||||
|
return value, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateContextOriginalBody(parent context.Context, body io.ReadCloser) context.Context {
|
||||||
|
return context.WithValue(parent, ctxKeyOriginalBody{}, body)
|
||||||
|
}
|
||||||
|
|||||||
@@ -76,72 +76,6 @@ func TestCreateContextDisableReplicationStatusEndpoints(t *testing.T) {
|
|||||||
assert.Equal(t, false, value.(bool))
|
assert.Equal(t, false, value.(bool))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextMaxRequestSizeValue(t *testing.T) {
|
|
||||||
testcases := []struct {
|
|
||||||
name string
|
|
||||||
ctx context.Context
|
|
||||||
expectedValue int64
|
|
||||||
expectedOk bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "without-value",
|
|
||||||
ctx: context.Background(),
|
|
||||||
expectedValue: 0,
|
|
||||||
expectedOk: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with-nil",
|
|
||||||
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, nil),
|
|
||||||
expectedValue: 0,
|
|
||||||
expectedOk: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with-incompatible-value",
|
|
||||||
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, "6666"),
|
|
||||||
expectedValue: 0,
|
|
||||||
expectedOk: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with-int64-8888",
|
|
||||||
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(8888)),
|
|
||||||
expectedValue: 8888,
|
|
||||||
expectedOk: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with-int64-zero",
|
|
||||||
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(0)),
|
|
||||||
expectedValue: 0,
|
|
||||||
expectedOk: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testcase := range testcases {
|
|
||||||
value, ok := ContextMaxRequestSizeValue(testcase.ctx)
|
|
||||||
assert.Equal(t, testcase.expectedValue, value, testcase.name)
|
|
||||||
assert.Equal(t, testcase.expectedOk, ok, testcase.name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateContextMaxRequestSize(t *testing.T) {
|
|
||||||
ctx := CreateContextMaxRequestSize(context.Background(), int64(8888))
|
|
||||||
|
|
||||||
value := ctx.Value(ctxKeyMaxRequestSize{})
|
|
||||||
|
|
||||||
assert.NotNil(t, ctx)
|
|
||||||
assert.NotNil(t, value)
|
|
||||||
assert.IsType(t, int64(0), value)
|
|
||||||
assert.Equal(t, int64(8888), value.(int64))
|
|
||||||
|
|
||||||
ctx = CreateContextMaxRequestSize(context.Background(), int64(0))
|
|
||||||
|
|
||||||
value = ctx.Value(ctxKeyMaxRequestSize{})
|
|
||||||
|
|
||||||
assert.NotNil(t, ctx)
|
|
||||||
assert.NotNil(t, value)
|
|
||||||
assert.IsType(t, int64(0), value)
|
|
||||||
assert.Equal(t, int64(0), value.(int64))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestContextOriginalRequestPathValue(t *testing.T) {
|
func TestContextOriginalRequestPathValue(t *testing.T) {
|
||||||
testcases := []struct {
|
testcases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -3999,19 +3999,6 @@ func (c *Core) LoadNodeID() (string, error) {
|
|||||||
return hostname, nil
|
return hostname, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DetermineRoleFromLoginRequestFromBytes will determine the role that should be applied to a quota for a given
|
|
||||||
// login request, accepting a byte payload
|
|
||||||
func (c *Core) DetermineRoleFromLoginRequestFromBytes(ctx context.Context, mountPoint string, payload []byte) string {
|
|
||||||
data := make(map[string]interface{})
|
|
||||||
err := jsonutil.DecodeJSON(payload, &data)
|
|
||||||
if err != nil {
|
|
||||||
// Cannot discern a role from a request we cannot parse
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.DetermineRoleFromLoginRequest(ctx, mountPoint, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given
|
// DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given
|
||||||
// login request
|
// login request
|
||||||
func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string {
|
func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint string, data map[string]interface{}) string {
|
||||||
@@ -4022,7 +4009,33 @@ func (c *Core) DetermineRoleFromLoginRequest(ctx context.Context, mountPoint str
|
|||||||
// Role based quotas do not apply to this request
|
// Role based quotas do not apply to this request
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetermineRoleFromLoginRequestFromReader will determine the role that should
|
||||||
|
// be applied to a quota for a given login request. The reader will only be
|
||||||
|
// consumed if the matching backend for the mount point exists and is a secret
|
||||||
|
// backend
|
||||||
|
func (c *Core) DetermineRoleFromLoginRequestFromReader(ctx context.Context, mountPoint string, reader io.Reader) string {
|
||||||
|
c.authLock.RLock()
|
||||||
|
defer c.authLock.RUnlock()
|
||||||
|
matchingBackend := c.router.MatchingBackend(ctx, mountPoint)
|
||||||
|
if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential {
|
||||||
|
// Role based quotas do not apply to this request
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make(map[string]interface{})
|
||||||
|
err := jsonutil.DecodeJSONFromReader(reader, &data)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.doResolveRoleLocked(ctx, mountPoint, matchingBackend, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// doResolveRoleLocked does a login and resolve role request on the matching
|
||||||
|
// backend. Callers should have a read lock on c.authLock
|
||||||
|
func (c *Core) doResolveRoleLocked(ctx context.Context, mountPoint string, matchingBackend logical.Backend, data map[string]interface{}) string {
|
||||||
resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{
|
resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{
|
||||||
MountPoint: mountPoint,
|
MountPoint: mountPoint,
|
||||||
Path: "login",
|
Path: "login",
|
||||||
|
|||||||
@@ -570,7 +570,8 @@ func (b *SystemBackend) handleStorageRaftSnapshotWrite(force bool, makeSealer fu
|
|||||||
if !ok {
|
if !ok {
|
||||||
return logical.ErrorResponse("raft storage is not in use"), logical.ErrInvalidRequest
|
return logical.ErrorResponse("raft storage is not in use"), logical.ErrInvalidRequest
|
||||||
}
|
}
|
||||||
if req.HTTPRequest == nil || req.HTTPRequest.Body == nil {
|
body, ok := logical.ContextOriginalBodyValue(ctx)
|
||||||
|
if !ok {
|
||||||
return nil, errors.New("no reader for request")
|
return nil, errors.New("no reader for request")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -583,7 +584,7 @@ func (b *SystemBackend) handleStorageRaftSnapshotWrite(force bool, makeSealer fu
|
|||||||
// don't have to hold the full snapshot in memory. We also want to do
|
// don't have to hold the full snapshot in memory. We also want to do
|
||||||
// the restore in two parts so we can restore the snapshot while the
|
// the restore in two parts so we can restore the snapshot while the
|
||||||
// stateLock is write locked.
|
// stateLock is write locked.
|
||||||
snapFile, cleanup, metadata, err := raftStorage.WriteSnapshotToTemp(req.HTTPRequest.Body, sealer)
|
snapFile, cleanup, metadata, err := raftStorage.WriteSnapshotToTemp(body, sealer)
|
||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case err == nil:
|
||||||
case strings.Contains(err.Error(), "failed to open the sealed hashes"):
|
case strings.Contains(err.Error(), "failed to open the sealed hashes"):
|
||||||
|
|||||||
@@ -581,6 +581,10 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R
|
|||||||
if disable_repl_status, ok := logical.ContextDisableReplicationStatusEndpointsValue(httpCtx); ok {
|
if disable_repl_status, ok := logical.ContextDisableReplicationStatusEndpointsValue(httpCtx); ok {
|
||||||
ctx = logical.CreateContextDisableReplicationStatusEndpoints(ctx, disable_repl_status)
|
ctx = logical.CreateContextDisableReplicationStatusEndpoints(ctx, disable_repl_status)
|
||||||
}
|
}
|
||||||
|
body, ok := logical.ContextOriginalBodyValue(httpCtx)
|
||||||
|
if ok {
|
||||||
|
ctx = logical.CreateContextOriginalBody(ctx, body)
|
||||||
|
}
|
||||||
resp, err = c.handleCancelableRequest(ctx, req)
|
resp, err = c.handleCancelableRequest(ctx, req)
|
||||||
req.SetTokenEntry(nil)
|
req.SetTokenEntry(nil)
|
||||||
cancel()
|
cancel()
|
||||||
|
|||||||
Reference in New Issue
Block a user