From 1ebbf449b4b68793b2242738f5bcabc677f41651 Mon Sep 17 00:00:00 2001 From: Marc Boudreau Date: Wed, 18 Oct 2023 09:30:00 -0400 Subject: [PATCH] Improve Robustness of Custom Context Values Types (#23697) --- helper/forwarding/util.go | 9 +- http/handler.go | 23 +- http/util.go | 2 +- .../dbplugin/v5/plugin_client_test.go | 4 + sdk/logical/request.go | 67 +++++- sdk/logical/request_test.go | 206 ++++++++++++++++++ vault/cluster_test.go | 2 +- vault/request_forwarding.go | 7 +- vault/request_handling.go | 5 +- 9 files changed, 297 insertions(+), 28 deletions(-) create mode 100644 sdk/logical/request_test.go diff --git a/helper/forwarding/util.go b/helper/forwarding/util.go index a712c11885..1bc1c5e68f 100644 --- a/helper/forwarding/util.go +++ b/helper/forwarding/util.go @@ -7,7 +7,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" - "fmt" + "errors" "io" "io/ioutil" "net/http" @@ -65,11 +65,10 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request func GenerateForwardedRequest(req *http.Request) (*Request, error) { var reader io.Reader = req.Body ctx := req.Context() - maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{}) - if maxRequestSize != nil { - max, ok := maxRequestSize.(int64) + if logical.ContextContainsMaxRequestSize(ctx) { + max, ok := logical.ContextMaxRequestSizeValue(ctx) if !ok { - return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{}) + return nil, errors.New("could not parse max request size from request context") } if max > 0 { reader = io.LimitReader(req.Body, max) diff --git a/http/handler.go b/http/handler.go index 43a92aa9e5..e0eedd883a 100644 --- a/http/handler.go +++ b/http/handler.go @@ -381,9 +381,9 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr // if maxRequestSize < 0, no need to set context value // Add a size limiter if desired if maxRequestSize > 0 { - ctx = context.WithValue(ctx, logical.CtxKeyMaxRequestSize{}, maxRequestSize) + ctx = logical.CreateContextMaxRequestSize(ctx, maxRequestSize) } - ctx = context.WithValue(ctx, logical.CtxKeyOriginalRequestPath{}, r.URL.Path) + ctx = logical.CreateContextOriginalRequestPath(ctx, r.URL.Path) r = r.WithContext(ctx) r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace)) @@ -710,11 +710,10 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, // against an indefinite amount of data being read. reader := r.Body ctx := r.Context() - maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{}) - if maxRequestSize != nil { - max, ok := maxRequestSize.(int64) + if logical.ContextContainsMaxRequestSize(ctx) { + max, ok := logical.ContextMaxRequestSizeValue(ctx) if !ok { - return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{}) + return nil, errors.New("could not parse max request size from request context") } if max > 0 { // MaxBytesReader won't do all the internal stuff it must unless it's @@ -728,7 +727,9 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, reader = http.MaxBytesReader(inw, r.Body, max) } } + var origBody io.ReadWriter + if perfStandby { // Since we're checking PerfStandby here we key on origBody being nil // or not later, so we need to always allocate so it's non-nil @@ -749,16 +750,16 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter, // // A nil map will be returned if the format is empty or invalid. func parseFormRequest(r *http.Request) (map[string]interface{}, error) { - maxRequestSize := r.Context().Value(logical.CtxKeyMaxRequestSize{}) - if maxRequestSize != nil { - max, ok := maxRequestSize.(int64) + if logical.ContextContainsMaxRequestSize(r.Context()) { + max, ok := logical.ContextMaxRequestSizeValue(r.Context()) if !ok { - return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{}) + return nil, errors.New("could not parse max request size from request context") } if max > 0 { - r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max)) + r.Body = io.NopCloser(io.LimitReader(r.Body, max)) } } + if err := r.ParseForm(); err != nil { return nil, err } diff --git a/http/util.go b/http/util.go index dc824c1269..a5e3cf0776 100644 --- a/http/util.go +++ b/http/util.go @@ -120,7 +120,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler func disableReplicationStatusEndpointWrapping(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - request := r.WithContext(context.WithValue(r.Context(), logical.CtxKeyDisableReplicationStatusEndpoints{}, true)) + request := r.WithContext(logical.CreateContextDisableReplicationStatusEndpoints(r.Context(), true)) h.ServeHTTP(w, request) }) diff --git a/sdk/database/dbplugin/v5/plugin_client_test.go b/sdk/database/dbplugin/v5/plugin_client_test.go index 10f02b7bec..fb6852d1a4 100644 --- a/sdk/database/dbplugin/v5/plugin_client_test.go +++ b/sdk/database/dbplugin/v5/plugin_client_test.go @@ -156,3 +156,7 @@ func (m *mockRunnerUtil) MlockEnabled() bool { args := m.Called() return args.Bool(0) } + +func (m *mockRunnerUtil) ClusterID(ctx context.Context) (string, error) { + return "clusterid", nil +} diff --git a/sdk/logical/request.go b/sdk/logical/request.go index bb42db7dd3..4b617dbc10 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -458,28 +458,83 @@ func (c CtxKeyRequestRole) String() string { // context.Context to store the value `true` when the // disable_replication_status_endpoints configuration parameter is set to true // for the listener through which a request was received. -type CtxKeyDisableReplicationStatusEndpoints struct{} +type ctxKeyDisableReplicationStatusEndpoints struct{} // String returns a string representation of the receiver type. -func (c CtxKeyDisableReplicationStatusEndpoints) String() string { +func (c ctxKeyDisableReplicationStatusEndpoints) String() string { return "disable-replication-status-endpoints" } +// ContextDisableReplicationStatusEndpointsValue examines the provided +// context.Context for the disable replication status endpoints value and +// returns it as a bool value if it's found along with the ok return value set +// to true; otherwise the ok return value is false. +func ContextDisableReplicationStatusEndpointsValue(ctx context.Context) (value, ok bool) { + value, ok = ctx.Value(ctxKeyDisableReplicationStatusEndpoints{}).(bool) + + return +} + +// CreateContextDisableReplicationStatusEndpoints creates a new context.Context +// based on the provided parent that also includes the provided value for the +// ctxKeyDisableReplicationStatusEndpoints key. +func CreateContextDisableReplicationStatusEndpoints(parent context.Context, value bool) context.Context { + return context.WithValue(parent, ctxKeyDisableReplicationStatusEndpoints{}, value) +} + // CtxKeyMaxRequestSize is a custom type used as a key in context.Context to // store the value of the max_request_size set for the listener through which // a request was received. -type CtxKeyMaxRequestSize struct{} +type ctxKeyMaxRequestSize struct{} // String returns a string representation of the receiver type. -func (c CtxKeyMaxRequestSize) String() string { +func (c ctxKeyMaxRequestSize) String() string { return "max_request_size" } +// ContextMaxRequestSizeValue examines the provided context.Context for the max +// request size value and returns it as an int64 value if it's found along with +// the ok value set to true; otherwise the ok return value is false. +func ContextMaxRequestSizeValue(ctx context.Context) (value int64, ok bool) { + value, ok = ctx.Value(ctxKeyMaxRequestSize{}).(int64) + + return +} + +// CreateContextMaxRequestSize creates a new context.Context based on the +// provided parent that also includes the provided max request size value for +// the ctxKeyMaxRequestSize key. +func CreateContextMaxRequestSize(parent context.Context, value int64) context.Context { + return context.WithValue(parent, ctxKeyMaxRequestSize{}, value) +} + +// ContextContainsMaxRequestSize returns a bool value that indicates if the +// provided Context contains a value for the ctxKeyMaxRequestSize key. +func ContextContainsMaxRequestSize(ctx context.Context) bool { + return ctx.Value(ctxKeyMaxRequestSize{}) != nil +} + // CtxKeyOriginalRequestPath is a custom type used as a key in context.Context // to store the original request path. -type CtxKeyOriginalRequestPath struct{} +type ctxKeyOriginalRequestPath struct{} // String returns a string representation of the receiver type. -func (c CtxKeyOriginalRequestPath) String() string { +func (c ctxKeyOriginalRequestPath) String() string { return "original_request_path" } + +// ContextOriginalRequestPathValue examines the provided context.Context for the +// original request path value and returns it as a string value if it's found +// along with the ok value set to true; otherwise the ok return value is false. +func ContextOriginalRequestPathValue(ctx context.Context) (value string, ok bool) { + value, ok = ctx.Value(ctxKeyOriginalRequestPath{}).(string) + + return +} + +// CreateContextOriginalRequestPath creates a new context.Context based on the +// provided parent that also includes the provided original request path value +// for the ctxKeyOriginalRequestPath key. +func CreateContextOriginalRequestPath(parent context.Context, value string) context.Context { + return context.WithValue(parent, ctxKeyOriginalRequestPath{}, value) +} diff --git a/sdk/logical/request_test.go b/sdk/logical/request_test.go new file mode 100644 index 0000000000..bb5d4bb6ad --- /dev/null +++ b/sdk/logical/request_test.go @@ -0,0 +1,206 @@ +package logical + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContextDisableReplicationStatusEndpointsValue(t *testing.T) { + testcases := []struct { + name string + ctx context.Context + expectedValue bool + expectedOk bool + }{ + { + name: "without-value", + ctx: context.Background(), + expectedValue: false, + expectedOk: false, + }, + { + name: "with-nil", + ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, nil), + expectedValue: false, + expectedOk: false, + }, + { + name: "with-incompatible-value", + ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, "true"), + expectedValue: false, + expectedOk: false, + }, + { + name: "with-bool-true", + ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, true), + expectedValue: true, + expectedOk: true, + }, + { + name: "with-bool-false", + ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, false), + expectedValue: false, + expectedOk: true, + }, + } + + for _, testcase := range testcases { + value, ok := ContextDisableReplicationStatusEndpointsValue(testcase.ctx) + assert.Equal(t, testcase.expectedValue, value, testcase.name) + assert.Equal(t, testcase.expectedOk, ok, testcase.name) + } +} + +func TestCreateContextDisableReplicationStatusEndpoints(t *testing.T) { + ctx := CreateContextDisableReplicationStatusEndpoints(context.Background(), true) + + value := ctx.Value(ctxKeyDisableReplicationStatusEndpoints{}) + + assert.NotNil(t, ctx) + assert.NotNil(t, value) + assert.IsType(t, bool(false), value) + assert.Equal(t, true, value.(bool)) + + ctx = CreateContextDisableReplicationStatusEndpoints(context.Background(), false) + + value = ctx.Value(ctxKeyDisableReplicationStatusEndpoints{}) + + assert.NotNil(t, ctx) + assert.NotNil(t, value) + assert.IsType(t, bool(false), value) + assert.Equal(t, false, value.(bool)) +} + +func TestContextMaxRequestSizeValue(t *testing.T) { + testcases := []struct { + name string + ctx context.Context + expectedValue int64 + expectedOk bool + }{ + { + name: "without-value", + ctx: context.Background(), + expectedValue: 0, + expectedOk: false, + }, + { + name: "with-nil", + ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, nil), + expectedValue: 0, + expectedOk: false, + }, + { + name: "with-incompatible-value", + ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, "6666"), + expectedValue: 0, + expectedOk: false, + }, + { + name: "with-int64-8888", + ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(8888)), + expectedValue: 8888, + expectedOk: true, + }, + { + name: "with-int64-zero", + ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(0)), + expectedValue: 0, + expectedOk: true, + }, + } + + for _, testcase := range testcases { + value, ok := ContextMaxRequestSizeValue(testcase.ctx) + assert.Equal(t, testcase.expectedValue, value, testcase.name) + assert.Equal(t, testcase.expectedOk, ok, testcase.name) + } +} + +func TestCreateContextMaxRequestSize(t *testing.T) { + ctx := CreateContextMaxRequestSize(context.Background(), int64(8888)) + + value := ctx.Value(ctxKeyMaxRequestSize{}) + + assert.NotNil(t, ctx) + assert.NotNil(t, value) + assert.IsType(t, int64(0), value) + assert.Equal(t, int64(8888), value.(int64)) + + ctx = CreateContextMaxRequestSize(context.Background(), int64(0)) + + value = ctx.Value(ctxKeyMaxRequestSize{}) + + assert.NotNil(t, ctx) + assert.NotNil(t, value) + assert.IsType(t, int64(0), value) + assert.Equal(t, int64(0), value.(int64)) +} + +func TestContextOriginalRequestPathValue(t *testing.T) { + testcases := []struct { + name string + ctx context.Context + expectedValue string + expectedOk bool + }{ + { + name: "without-value", + ctx: context.Background(), + expectedValue: "", + expectedOk: false, + }, + { + name: "with-nil", + ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, nil), + expectedValue: "", + expectedOk: false, + }, + { + name: "with-incompatible-value", + ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, 6666), + expectedValue: "", + expectedOk: false, + }, + { + name: "with-string-value", + ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, "test"), + expectedValue: "test", + expectedOk: true, + }, + { + name: "with-empty-string", + ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, ""), + expectedValue: "", + expectedOk: true, + }, + } + + for _, testcase := range testcases { + value, ok := ContextOriginalRequestPathValue(testcase.ctx) + assert.Equal(t, testcase.expectedValue, value, testcase.name) + assert.Equal(t, testcase.expectedOk, ok, testcase.name) + } +} + +func TestCreateContextOriginalRequestPath(t *testing.T) { + ctx := CreateContextOriginalRequestPath(context.Background(), "test") + + value := ctx.Value(ctxKeyOriginalRequestPath{}) + + assert.NotNil(t, ctx) + assert.NotNil(t, value) + assert.IsType(t, string(""), value) + assert.Equal(t, "test", value.(string)) + + ctx = CreateContextOriginalRequestPath(context.Background(), "") + + value = ctx.Value(ctxKeyOriginalRequestPath{}) + + assert.NotNil(t, ctx) + assert.NotNil(t, value) + assert.IsType(t, string(""), value) + assert.Equal(t, "", value.(string)) +} diff --git a/vault/cluster_test.go b/vault/cluster_test.go index c890582bea..97f27fd4f4 100644 --- a/vault/cluster_test.go +++ b/vault/cluster_test.go @@ -338,7 +338,7 @@ func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, rootToken, re t.Fatal(err) } req.Header.Add(consts.AuthHeaderName, rootToken) - req = req.WithContext(context.WithValue(req.Context(), logical.CtxKeyOriginalRequestPath{}, req.URL.Path)) + req = req.WithContext(logical.CreateContextOriginalRequestPath(req.Context(), req.URL.Path)) statusCode, header, respBytes, err := c.ForwardRequest(req) if err != nil { diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index 440de62b34..6e75e0b2b2 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -350,7 +350,12 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro req.URL.Path = origPath }() - req.URL.Path = req.Context().Value(logical.CtxKeyOriginalRequestPath{}).(string) + path, ok := logical.ContextOriginalRequestPathValue(req.Context()) + if !ok { + return 0, nil, nil, errors.New("error extracting request path for forwarding RPC request") + } + + req.URL.Path = path freq, err := forwarding.GenerateForwardedRequest(req) if err != nil { diff --git a/vault/request_handling.go b/vault/request_handling.go index a3938ba6b7..8985e51f82 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -566,9 +566,8 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R if ok { ctx = context.WithValue(ctx, logical.CtxKeyRequestRole{}, requestRole) } - disable_repl_status, ok := httpCtx.Value(logical.CtxKeyDisableReplicationStatusEndpoints{}).(string) - if ok { - ctx = context.WithValue(ctx, logical.CtxKeyDisableReplicationStatusEndpoints{}, disable_repl_status) + if disable_repl_status, ok := logical.ContextDisableReplicationStatusEndpointsValue(httpCtx); ok { + ctx = logical.CreateContextDisableReplicationStatusEndpoints(ctx, disable_repl_status) } resp, err = c.handleCancelableRequest(ctx, req) req.SetTokenEntry(nil)