mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 19:47:54 +00:00
Improve Robustness of Custom Context Values Types (#23697)
This commit is contained in:
@@ -7,7 +7,7 @@ import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -65,11 +65,10 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request
|
||||
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
|
||||
var reader io.Reader = req.Body
|
||||
ctx := req.Context()
|
||||
maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{})
|
||||
if maxRequestSize != nil {
|
||||
max, ok := maxRequestSize.(int64)
|
||||
if logical.ContextContainsMaxRequestSize(ctx) {
|
||||
max, ok := logical.ContextMaxRequestSizeValue(ctx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
|
||||
return nil, errors.New("could not parse max request size from request context")
|
||||
}
|
||||
if max > 0 {
|
||||
reader = io.LimitReader(req.Body, max)
|
||||
|
||||
@@ -381,9 +381,9 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
|
||||
// if maxRequestSize < 0, no need to set context value
|
||||
// Add a size limiter if desired
|
||||
if maxRequestSize > 0 {
|
||||
ctx = context.WithValue(ctx, logical.CtxKeyMaxRequestSize{}, maxRequestSize)
|
||||
ctx = logical.CreateContextMaxRequestSize(ctx, maxRequestSize)
|
||||
}
|
||||
ctx = context.WithValue(ctx, logical.CtxKeyOriginalRequestPath{}, r.URL.Path)
|
||||
ctx = logical.CreateContextOriginalRequestPath(ctx, r.URL.Path)
|
||||
r = r.WithContext(ctx)
|
||||
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
|
||||
|
||||
@@ -710,11 +710,10 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
|
||||
// against an indefinite amount of data being read.
|
||||
reader := r.Body
|
||||
ctx := r.Context()
|
||||
maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{})
|
||||
if maxRequestSize != nil {
|
||||
max, ok := maxRequestSize.(int64)
|
||||
if logical.ContextContainsMaxRequestSize(ctx) {
|
||||
max, ok := logical.ContextMaxRequestSizeValue(ctx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
|
||||
return nil, errors.New("could not parse max request size from request context")
|
||||
}
|
||||
if max > 0 {
|
||||
// MaxBytesReader won't do all the internal stuff it must unless it's
|
||||
@@ -728,7 +727,9 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
|
||||
reader = http.MaxBytesReader(inw, r.Body, max)
|
||||
}
|
||||
}
|
||||
|
||||
var origBody io.ReadWriter
|
||||
|
||||
if perfStandby {
|
||||
// Since we're checking PerfStandby here we key on origBody being nil
|
||||
// or not later, so we need to always allocate so it's non-nil
|
||||
@@ -749,16 +750,16 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
|
||||
//
|
||||
// A nil map will be returned if the format is empty or invalid.
|
||||
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
|
||||
maxRequestSize := r.Context().Value(logical.CtxKeyMaxRequestSize{})
|
||||
if maxRequestSize != nil {
|
||||
max, ok := maxRequestSize.(int64)
|
||||
if logical.ContextContainsMaxRequestSize(r.Context()) {
|
||||
max, ok := logical.ContextMaxRequestSizeValue(r.Context())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
|
||||
return nil, errors.New("could not parse max request size from request context")
|
||||
}
|
||||
if max > 0 {
|
||||
r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max))
|
||||
r.Body = io.NopCloser(io.LimitReader(r.Body, max))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
|
||||
|
||||
func disableReplicationStatusEndpointWrapping(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
request := r.WithContext(context.WithValue(r.Context(), logical.CtxKeyDisableReplicationStatusEndpoints{}, true))
|
||||
request := r.WithContext(logical.CreateContextDisableReplicationStatusEndpoints(r.Context(), true))
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
})
|
||||
|
||||
@@ -156,3 +156,7 @@ func (m *mockRunnerUtil) MlockEnabled() bool {
|
||||
args := m.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *mockRunnerUtil) ClusterID(ctx context.Context) (string, error) {
|
||||
return "clusterid", nil
|
||||
}
|
||||
|
||||
@@ -458,28 +458,83 @@ func (c CtxKeyRequestRole) String() string {
|
||||
// context.Context to store the value `true` when the
|
||||
// disable_replication_status_endpoints configuration parameter is set to true
|
||||
// for the listener through which a request was received.
|
||||
type CtxKeyDisableReplicationStatusEndpoints struct{}
|
||||
type ctxKeyDisableReplicationStatusEndpoints struct{}
|
||||
|
||||
// String returns a string representation of the receiver type.
|
||||
func (c CtxKeyDisableReplicationStatusEndpoints) String() string {
|
||||
func (c ctxKeyDisableReplicationStatusEndpoints) String() string {
|
||||
return "disable-replication-status-endpoints"
|
||||
}
|
||||
|
||||
// ContextDisableReplicationStatusEndpointsValue examines the provided
|
||||
// context.Context for the disable replication status endpoints value and
|
||||
// returns it as a bool value if it's found along with the ok return value set
|
||||
// to true; otherwise the ok return value is false.
|
||||
func ContextDisableReplicationStatusEndpointsValue(ctx context.Context) (value, ok bool) {
|
||||
value, ok = ctx.Value(ctxKeyDisableReplicationStatusEndpoints{}).(bool)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// CreateContextDisableReplicationStatusEndpoints creates a new context.Context
|
||||
// based on the provided parent that also includes the provided value for the
|
||||
// ctxKeyDisableReplicationStatusEndpoints key.
|
||||
func CreateContextDisableReplicationStatusEndpoints(parent context.Context, value bool) context.Context {
|
||||
return context.WithValue(parent, ctxKeyDisableReplicationStatusEndpoints{}, value)
|
||||
}
|
||||
|
||||
// CtxKeyMaxRequestSize is a custom type used as a key in context.Context to
|
||||
// store the value of the max_request_size set for the listener through which
|
||||
// a request was received.
|
||||
type CtxKeyMaxRequestSize struct{}
|
||||
type ctxKeyMaxRequestSize struct{}
|
||||
|
||||
// String returns a string representation of the receiver type.
|
||||
func (c CtxKeyMaxRequestSize) String() string {
|
||||
func (c ctxKeyMaxRequestSize) String() string {
|
||||
return "max_request_size"
|
||||
}
|
||||
|
||||
// ContextMaxRequestSizeValue examines the provided context.Context for the max
|
||||
// request size value and returns it as an int64 value if it's found along with
|
||||
// the ok value set to true; otherwise the ok return value is false.
|
||||
func ContextMaxRequestSizeValue(ctx context.Context) (value int64, ok bool) {
|
||||
value, ok = ctx.Value(ctxKeyMaxRequestSize{}).(int64)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// CreateContextMaxRequestSize creates a new context.Context based on the
|
||||
// provided parent that also includes the provided max request size value for
|
||||
// the ctxKeyMaxRequestSize key.
|
||||
func CreateContextMaxRequestSize(parent context.Context, value int64) context.Context {
|
||||
return context.WithValue(parent, ctxKeyMaxRequestSize{}, value)
|
||||
}
|
||||
|
||||
// ContextContainsMaxRequestSize returns a bool value that indicates if the
|
||||
// provided Context contains a value for the ctxKeyMaxRequestSize key.
|
||||
func ContextContainsMaxRequestSize(ctx context.Context) bool {
|
||||
return ctx.Value(ctxKeyMaxRequestSize{}) != nil
|
||||
}
|
||||
|
||||
// CtxKeyOriginalRequestPath is a custom type used as a key in context.Context
|
||||
// to store the original request path.
|
||||
type CtxKeyOriginalRequestPath struct{}
|
||||
type ctxKeyOriginalRequestPath struct{}
|
||||
|
||||
// String returns a string representation of the receiver type.
|
||||
func (c CtxKeyOriginalRequestPath) String() string {
|
||||
func (c ctxKeyOriginalRequestPath) String() string {
|
||||
return "original_request_path"
|
||||
}
|
||||
|
||||
// ContextOriginalRequestPathValue examines the provided context.Context for the
|
||||
// original request path value and returns it as a string value if it's found
|
||||
// along with the ok value set to true; otherwise the ok return value is false.
|
||||
func ContextOriginalRequestPathValue(ctx context.Context) (value string, ok bool) {
|
||||
value, ok = ctx.Value(ctxKeyOriginalRequestPath{}).(string)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// CreateContextOriginalRequestPath creates a new context.Context based on the
|
||||
// provided parent that also includes the provided original request path value
|
||||
// for the ctxKeyOriginalRequestPath key.
|
||||
func CreateContextOriginalRequestPath(parent context.Context, value string) context.Context {
|
||||
return context.WithValue(parent, ctxKeyOriginalRequestPath{}, value)
|
||||
}
|
||||
|
||||
206
sdk/logical/request_test.go
Normal file
206
sdk/logical/request_test.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package logical
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestContextDisableReplicationStatusEndpointsValue(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expectedValue bool
|
||||
expectedOk bool
|
||||
}{
|
||||
{
|
||||
name: "without-value",
|
||||
ctx: context.Background(),
|
||||
expectedValue: false,
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-nil",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, nil),
|
||||
expectedValue: false,
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-incompatible-value",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, "true"),
|
||||
expectedValue: false,
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-bool-true",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, true),
|
||||
expectedValue: true,
|
||||
expectedOk: true,
|
||||
},
|
||||
{
|
||||
name: "with-bool-false",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, false),
|
||||
expectedValue: false,
|
||||
expectedOk: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testcase := range testcases {
|
||||
value, ok := ContextDisableReplicationStatusEndpointsValue(testcase.ctx)
|
||||
assert.Equal(t, testcase.expectedValue, value, testcase.name)
|
||||
assert.Equal(t, testcase.expectedOk, ok, testcase.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateContextDisableReplicationStatusEndpoints(t *testing.T) {
|
||||
ctx := CreateContextDisableReplicationStatusEndpoints(context.Background(), true)
|
||||
|
||||
value := ctx.Value(ctxKeyDisableReplicationStatusEndpoints{})
|
||||
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, value)
|
||||
assert.IsType(t, bool(false), value)
|
||||
assert.Equal(t, true, value.(bool))
|
||||
|
||||
ctx = CreateContextDisableReplicationStatusEndpoints(context.Background(), false)
|
||||
|
||||
value = ctx.Value(ctxKeyDisableReplicationStatusEndpoints{})
|
||||
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, value)
|
||||
assert.IsType(t, bool(false), value)
|
||||
assert.Equal(t, false, value.(bool))
|
||||
}
|
||||
|
||||
func TestContextMaxRequestSizeValue(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expectedValue int64
|
||||
expectedOk bool
|
||||
}{
|
||||
{
|
||||
name: "without-value",
|
||||
ctx: context.Background(),
|
||||
expectedValue: 0,
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-nil",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, nil),
|
||||
expectedValue: 0,
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-incompatible-value",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, "6666"),
|
||||
expectedValue: 0,
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-int64-8888",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(8888)),
|
||||
expectedValue: 8888,
|
||||
expectedOk: true,
|
||||
},
|
||||
{
|
||||
name: "with-int64-zero",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(0)),
|
||||
expectedValue: 0,
|
||||
expectedOk: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testcase := range testcases {
|
||||
value, ok := ContextMaxRequestSizeValue(testcase.ctx)
|
||||
assert.Equal(t, testcase.expectedValue, value, testcase.name)
|
||||
assert.Equal(t, testcase.expectedOk, ok, testcase.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateContextMaxRequestSize(t *testing.T) {
|
||||
ctx := CreateContextMaxRequestSize(context.Background(), int64(8888))
|
||||
|
||||
value := ctx.Value(ctxKeyMaxRequestSize{})
|
||||
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, value)
|
||||
assert.IsType(t, int64(0), value)
|
||||
assert.Equal(t, int64(8888), value.(int64))
|
||||
|
||||
ctx = CreateContextMaxRequestSize(context.Background(), int64(0))
|
||||
|
||||
value = ctx.Value(ctxKeyMaxRequestSize{})
|
||||
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, value)
|
||||
assert.IsType(t, int64(0), value)
|
||||
assert.Equal(t, int64(0), value.(int64))
|
||||
}
|
||||
|
||||
func TestContextOriginalRequestPathValue(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expectedValue string
|
||||
expectedOk bool
|
||||
}{
|
||||
{
|
||||
name: "without-value",
|
||||
ctx: context.Background(),
|
||||
expectedValue: "",
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-nil",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, nil),
|
||||
expectedValue: "",
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-incompatible-value",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, 6666),
|
||||
expectedValue: "",
|
||||
expectedOk: false,
|
||||
},
|
||||
{
|
||||
name: "with-string-value",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, "test"),
|
||||
expectedValue: "test",
|
||||
expectedOk: true,
|
||||
},
|
||||
{
|
||||
name: "with-empty-string",
|
||||
ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, ""),
|
||||
expectedValue: "",
|
||||
expectedOk: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testcase := range testcases {
|
||||
value, ok := ContextOriginalRequestPathValue(testcase.ctx)
|
||||
assert.Equal(t, testcase.expectedValue, value, testcase.name)
|
||||
assert.Equal(t, testcase.expectedOk, ok, testcase.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateContextOriginalRequestPath(t *testing.T) {
|
||||
ctx := CreateContextOriginalRequestPath(context.Background(), "test")
|
||||
|
||||
value := ctx.Value(ctxKeyOriginalRequestPath{})
|
||||
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, value)
|
||||
assert.IsType(t, string(""), value)
|
||||
assert.Equal(t, "test", value.(string))
|
||||
|
||||
ctx = CreateContextOriginalRequestPath(context.Background(), "")
|
||||
|
||||
value = ctx.Value(ctxKeyOriginalRequestPath{})
|
||||
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, value)
|
||||
assert.IsType(t, string(""), value)
|
||||
assert.Equal(t, "", value.(string))
|
||||
}
|
||||
@@ -338,7 +338,7 @@ func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, rootToken, re
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Add(consts.AuthHeaderName, rootToken)
|
||||
req = req.WithContext(context.WithValue(req.Context(), logical.CtxKeyOriginalRequestPath{}, req.URL.Path))
|
||||
req = req.WithContext(logical.CreateContextOriginalRequestPath(req.Context(), req.URL.Path))
|
||||
|
||||
statusCode, header, respBytes, err := c.ForwardRequest(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -350,7 +350,12 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
|
||||
req.URL.Path = origPath
|
||||
}()
|
||||
|
||||
req.URL.Path = req.Context().Value(logical.CtxKeyOriginalRequestPath{}).(string)
|
||||
path, ok := logical.ContextOriginalRequestPathValue(req.Context())
|
||||
if !ok {
|
||||
return 0, nil, nil, errors.New("error extracting request path for forwarding RPC request")
|
||||
}
|
||||
|
||||
req.URL.Path = path
|
||||
|
||||
freq, err := forwarding.GenerateForwardedRequest(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -566,9 +566,8 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R
|
||||
if ok {
|
||||
ctx = context.WithValue(ctx, logical.CtxKeyRequestRole{}, requestRole)
|
||||
}
|
||||
disable_repl_status, ok := httpCtx.Value(logical.CtxKeyDisableReplicationStatusEndpoints{}).(string)
|
||||
if ok {
|
||||
ctx = context.WithValue(ctx, logical.CtxKeyDisableReplicationStatusEndpoints{}, disable_repl_status)
|
||||
if disable_repl_status, ok := logical.ContextDisableReplicationStatusEndpointsValue(httpCtx); ok {
|
||||
ctx = logical.CreateContextDisableReplicationStatusEndpoints(ctx, disable_repl_status)
|
||||
}
|
||||
resp, err = c.handleCancelableRequest(ctx, req)
|
||||
req.SetTokenEntry(nil)
|
||||
|
||||
Reference in New Issue
Block a user