Improve Robustness of Custom Context Values Types (#23697)

This commit is contained in:
Marc Boudreau
2023-10-18 09:30:00 -04:00
committed by GitHub
parent e2c7cc2b18
commit 1ebbf449b4
9 changed files with 297 additions and 28 deletions

View File

@@ -7,7 +7,7 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"errors"
"io"
"io/ioutil"
"net/http"
@@ -65,11 +65,10 @@ func GenerateForwardedHTTPRequest(req *http.Request, addr string) (*http.Request
func GenerateForwardedRequest(req *http.Request) (*Request, error) {
var reader io.Reader = req.Body
ctx := req.Context()
maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{})
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if logical.ContextContainsMaxRequestSize(ctx) {
max, ok := logical.ContextMaxRequestSizeValue(ctx)
if !ok {
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
return nil, errors.New("could not parse max request size from request context")
}
if max > 0 {
reader = io.LimitReader(req.Body, max)

View File

@@ -381,9 +381,9 @@ func wrapGenericHandler(core *vault.Core, h http.Handler, props *vault.HandlerPr
// if maxRequestSize < 0, no need to set context value
// Add a size limiter if desired
if maxRequestSize > 0 {
ctx = context.WithValue(ctx, logical.CtxKeyMaxRequestSize{}, maxRequestSize)
ctx = logical.CreateContextMaxRequestSize(ctx, maxRequestSize)
}
ctx = context.WithValue(ctx, logical.CtxKeyOriginalRequestPath{}, r.URL.Path)
ctx = logical.CreateContextOriginalRequestPath(ctx, r.URL.Path)
r = r.WithContext(ctx)
r = r.WithContext(namespace.ContextWithNamespace(r.Context(), namespace.RootNamespace))
@@ -710,11 +710,10 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
// against an indefinite amount of data being read.
reader := r.Body
ctx := r.Context()
maxRequestSize := ctx.Value(logical.CtxKeyMaxRequestSize{})
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if logical.ContextContainsMaxRequestSize(ctx) {
max, ok := logical.ContextMaxRequestSizeValue(ctx)
if !ok {
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
return nil, errors.New("could not parse max request size from request context")
}
if max > 0 {
// MaxBytesReader won't do all the internal stuff it must unless it's
@@ -728,7 +727,9 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
reader = http.MaxBytesReader(inw, r.Body, max)
}
}
var origBody io.ReadWriter
if perfStandby {
// Since we're checking PerfStandby here we key on origBody being nil
// or not later, so we need to always allocate so it's non-nil
@@ -749,16 +750,16 @@ func parseJSONRequest(perfStandby bool, r *http.Request, w http.ResponseWriter,
//
// A nil map will be returned if the format is empty or invalid.
func parseFormRequest(r *http.Request) (map[string]interface{}, error) {
maxRequestSize := r.Context().Value(logical.CtxKeyMaxRequestSize{})
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if logical.ContextContainsMaxRequestSize(r.Context()) {
max, ok := logical.ContextMaxRequestSizeValue(r.Context())
if !ok {
return nil, fmt.Errorf("could not parse %s from request context", logical.CtxKeyMaxRequestSize{})
return nil, errors.New("could not parse max request size from request context")
}
if max > 0 {
r.Body = ioutil.NopCloser(io.LimitReader(r.Body, max))
r.Body = io.NopCloser(io.LimitReader(r.Body, max))
}
}
if err := r.ParseForm(); err != nil {
return nil, err
}

View File

@@ -120,7 +120,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler
func disableReplicationStatusEndpointWrapping(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request := r.WithContext(context.WithValue(r.Context(), logical.CtxKeyDisableReplicationStatusEndpoints{}, true))
request := r.WithContext(logical.CreateContextDisableReplicationStatusEndpoints(r.Context(), true))
h.ServeHTTP(w, request)
})

View File

@@ -156,3 +156,7 @@ func (m *mockRunnerUtil) MlockEnabled() bool {
args := m.Called()
return args.Bool(0)
}
func (m *mockRunnerUtil) ClusterID(ctx context.Context) (string, error) {
return "clusterid", nil
}

View File

@@ -458,28 +458,83 @@ func (c CtxKeyRequestRole) String() string {
// context.Context to store the value `true` when the
// disable_replication_status_endpoints configuration parameter is set to true
// for the listener through which a request was received.
type CtxKeyDisableReplicationStatusEndpoints struct{}
type ctxKeyDisableReplicationStatusEndpoints struct{}
// String returns a string representation of the receiver type.
func (c CtxKeyDisableReplicationStatusEndpoints) String() string {
func (c ctxKeyDisableReplicationStatusEndpoints) String() string {
return "disable-replication-status-endpoints"
}
// ContextDisableReplicationStatusEndpointsValue examines the provided
// context.Context for the disable replication status endpoints value and
// returns it as a bool value if it's found along with the ok return value set
// to true; otherwise the ok return value is false.
func ContextDisableReplicationStatusEndpointsValue(ctx context.Context) (value, ok bool) {
value, ok = ctx.Value(ctxKeyDisableReplicationStatusEndpoints{}).(bool)
return
}
// CreateContextDisableReplicationStatusEndpoints creates a new context.Context
// based on the provided parent that also includes the provided value for the
// ctxKeyDisableReplicationStatusEndpoints key.
func CreateContextDisableReplicationStatusEndpoints(parent context.Context, value bool) context.Context {
return context.WithValue(parent, ctxKeyDisableReplicationStatusEndpoints{}, value)
}
// CtxKeyMaxRequestSize is a custom type used as a key in context.Context to
// store the value of the max_request_size set for the listener through which
// a request was received.
type CtxKeyMaxRequestSize struct{}
type ctxKeyMaxRequestSize struct{}
// String returns a string representation of the receiver type.
func (c CtxKeyMaxRequestSize) String() string {
func (c ctxKeyMaxRequestSize) String() string {
return "max_request_size"
}
// ContextMaxRequestSizeValue examines the provided context.Context for the max
// request size value and returns it as an int64 value if it's found along with
// the ok value set to true; otherwise the ok return value is false.
func ContextMaxRequestSizeValue(ctx context.Context) (value int64, ok bool) {
value, ok = ctx.Value(ctxKeyMaxRequestSize{}).(int64)
return
}
// CreateContextMaxRequestSize creates a new context.Context based on the
// provided parent that also includes the provided max request size value for
// the ctxKeyMaxRequestSize key.
func CreateContextMaxRequestSize(parent context.Context, value int64) context.Context {
return context.WithValue(parent, ctxKeyMaxRequestSize{}, value)
}
// ContextContainsMaxRequestSize returns a bool value that indicates if the
// provided Context contains a value for the ctxKeyMaxRequestSize key.
func ContextContainsMaxRequestSize(ctx context.Context) bool {
return ctx.Value(ctxKeyMaxRequestSize{}) != nil
}
// CtxKeyOriginalRequestPath is a custom type used as a key in context.Context
// to store the original request path.
type CtxKeyOriginalRequestPath struct{}
type ctxKeyOriginalRequestPath struct{}
// String returns a string representation of the receiver type.
func (c CtxKeyOriginalRequestPath) String() string {
func (c ctxKeyOriginalRequestPath) String() string {
return "original_request_path"
}
// ContextOriginalRequestPathValue examines the provided context.Context for the
// original request path value and returns it as a string value if it's found
// along with the ok value set to true; otherwise the ok return value is false.
func ContextOriginalRequestPathValue(ctx context.Context) (value string, ok bool) {
value, ok = ctx.Value(ctxKeyOriginalRequestPath{}).(string)
return
}
// CreateContextOriginalRequestPath creates a new context.Context based on the
// provided parent that also includes the provided original request path value
// for the ctxKeyOriginalRequestPath key.
func CreateContextOriginalRequestPath(parent context.Context, value string) context.Context {
return context.WithValue(parent, ctxKeyOriginalRequestPath{}, value)
}

206
sdk/logical/request_test.go Normal file
View File

@@ -0,0 +1,206 @@
package logical
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestContextDisableReplicationStatusEndpointsValue(t *testing.T) {
testcases := []struct {
name string
ctx context.Context
expectedValue bool
expectedOk bool
}{
{
name: "without-value",
ctx: context.Background(),
expectedValue: false,
expectedOk: false,
},
{
name: "with-nil",
ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, nil),
expectedValue: false,
expectedOk: false,
},
{
name: "with-incompatible-value",
ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, "true"),
expectedValue: false,
expectedOk: false,
},
{
name: "with-bool-true",
ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, true),
expectedValue: true,
expectedOk: true,
},
{
name: "with-bool-false",
ctx: context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, false),
expectedValue: false,
expectedOk: true,
},
}
for _, testcase := range testcases {
value, ok := ContextDisableReplicationStatusEndpointsValue(testcase.ctx)
assert.Equal(t, testcase.expectedValue, value, testcase.name)
assert.Equal(t, testcase.expectedOk, ok, testcase.name)
}
}
func TestCreateContextDisableReplicationStatusEndpoints(t *testing.T) {
ctx := CreateContextDisableReplicationStatusEndpoints(context.Background(), true)
value := ctx.Value(ctxKeyDisableReplicationStatusEndpoints{})
assert.NotNil(t, ctx)
assert.NotNil(t, value)
assert.IsType(t, bool(false), value)
assert.Equal(t, true, value.(bool))
ctx = CreateContextDisableReplicationStatusEndpoints(context.Background(), false)
value = ctx.Value(ctxKeyDisableReplicationStatusEndpoints{})
assert.NotNil(t, ctx)
assert.NotNil(t, value)
assert.IsType(t, bool(false), value)
assert.Equal(t, false, value.(bool))
}
func TestContextMaxRequestSizeValue(t *testing.T) {
testcases := []struct {
name string
ctx context.Context
expectedValue int64
expectedOk bool
}{
{
name: "without-value",
ctx: context.Background(),
expectedValue: 0,
expectedOk: false,
},
{
name: "with-nil",
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, nil),
expectedValue: 0,
expectedOk: false,
},
{
name: "with-incompatible-value",
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, "6666"),
expectedValue: 0,
expectedOk: false,
},
{
name: "with-int64-8888",
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(8888)),
expectedValue: 8888,
expectedOk: true,
},
{
name: "with-int64-zero",
ctx: context.WithValue(context.Background(), ctxKeyMaxRequestSize{}, int64(0)),
expectedValue: 0,
expectedOk: true,
},
}
for _, testcase := range testcases {
value, ok := ContextMaxRequestSizeValue(testcase.ctx)
assert.Equal(t, testcase.expectedValue, value, testcase.name)
assert.Equal(t, testcase.expectedOk, ok, testcase.name)
}
}
func TestCreateContextMaxRequestSize(t *testing.T) {
ctx := CreateContextMaxRequestSize(context.Background(), int64(8888))
value := ctx.Value(ctxKeyMaxRequestSize{})
assert.NotNil(t, ctx)
assert.NotNil(t, value)
assert.IsType(t, int64(0), value)
assert.Equal(t, int64(8888), value.(int64))
ctx = CreateContextMaxRequestSize(context.Background(), int64(0))
value = ctx.Value(ctxKeyMaxRequestSize{})
assert.NotNil(t, ctx)
assert.NotNil(t, value)
assert.IsType(t, int64(0), value)
assert.Equal(t, int64(0), value.(int64))
}
func TestContextOriginalRequestPathValue(t *testing.T) {
testcases := []struct {
name string
ctx context.Context
expectedValue string
expectedOk bool
}{
{
name: "without-value",
ctx: context.Background(),
expectedValue: "",
expectedOk: false,
},
{
name: "with-nil",
ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, nil),
expectedValue: "",
expectedOk: false,
},
{
name: "with-incompatible-value",
ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, 6666),
expectedValue: "",
expectedOk: false,
},
{
name: "with-string-value",
ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, "test"),
expectedValue: "test",
expectedOk: true,
},
{
name: "with-empty-string",
ctx: context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, ""),
expectedValue: "",
expectedOk: true,
},
}
for _, testcase := range testcases {
value, ok := ContextOriginalRequestPathValue(testcase.ctx)
assert.Equal(t, testcase.expectedValue, value, testcase.name)
assert.Equal(t, testcase.expectedOk, ok, testcase.name)
}
}
func TestCreateContextOriginalRequestPath(t *testing.T) {
ctx := CreateContextOriginalRequestPath(context.Background(), "test")
value := ctx.Value(ctxKeyOriginalRequestPath{})
assert.NotNil(t, ctx)
assert.NotNil(t, value)
assert.IsType(t, string(""), value)
assert.Equal(t, "test", value.(string))
ctx = CreateContextOriginalRequestPath(context.Background(), "")
value = ctx.Value(ctxKeyOriginalRequestPath{})
assert.NotNil(t, ctx)
assert.NotNil(t, value)
assert.IsType(t, string(""), value)
assert.Equal(t, "", value.(string))
}

View File

@@ -338,7 +338,7 @@ func testCluster_ForwardRequests(t *testing.T, c *TestClusterCore, rootToken, re
t.Fatal(err)
}
req.Header.Add(consts.AuthHeaderName, rootToken)
req = req.WithContext(context.WithValue(req.Context(), logical.CtxKeyOriginalRequestPath{}, req.URL.Path))
req = req.WithContext(logical.CreateContextOriginalRequestPath(req.Context(), req.URL.Path))
statusCode, header, respBytes, err := c.ForwardRequest(req)
if err != nil {

View File

@@ -350,7 +350,12 @@ func (c *Core) ForwardRequest(req *http.Request) (int, http.Header, []byte, erro
req.URL.Path = origPath
}()
req.URL.Path = req.Context().Value(logical.CtxKeyOriginalRequestPath{}).(string)
path, ok := logical.ContextOriginalRequestPathValue(req.Context())
if !ok {
return 0, nil, nil, errors.New("error extracting request path for forwarding RPC request")
}
req.URL.Path = path
freq, err := forwarding.GenerateForwardedRequest(req)
if err != nil {

View File

@@ -566,9 +566,8 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R
if ok {
ctx = context.WithValue(ctx, logical.CtxKeyRequestRole{}, requestRole)
}
disable_repl_status, ok := httpCtx.Value(logical.CtxKeyDisableReplicationStatusEndpoints{}).(string)
if ok {
ctx = context.WithValue(ctx, logical.CtxKeyDisableReplicationStatusEndpoints{}, disable_repl_status)
if disable_repl_status, ok := logical.ContextDisableReplicationStatusEndpointsValue(httpCtx); ok {
ctx = logical.CreateContextDisableReplicationStatusEndpoints(ctx, disable_repl_status)
}
resp, err = c.handleCancelableRequest(ctx, req)
req.SetTokenEntry(nil)