Save the original request body for forwarding (#6538)

* Save the original request body for forwarding

If we are forwarding a request after initial parsing the request body is
already consumed. As a result a forwarded call containing a request body
will have the body be nil. This saves the original request body for a
given request via a TeeReader and uses that in cases of forwarding past
body consumption.
This commit is contained in:
Jeff Mitchell
2019-04-05 14:36:34 -04:00
committed by GitHub
parent 87b11cd949
commit b1df69d8d5
6 changed files with 46 additions and 29 deletions

View File

@@ -1,11 +1,13 @@
package http
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/textproto"
@@ -444,7 +446,7 @@ func (fs *UIAssetWrapper) Open(name string) (http.File, error) {
return nil, err
}
func parseRequest(r *http.Request, w http.ResponseWriter, out interface{}) error {
func parseRequest(core *vault.Core, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) {
// Limit the maximum number of bytes to MaxRequestSize to protect
// against an indefinite amount of data being read.
reader := r.Body
@@ -453,17 +455,27 @@ func parseRequest(r *http.Request, w http.ResponseWriter, out interface{}) error
if maxRequestSize != nil {
max, ok := maxRequestSize.(int64)
if !ok {
return errors.New("could not parse max_request_size from request context")
return nil, errors.New("could not parse max_request_size from request context")
}
if max > 0 {
reader = http.MaxBytesReader(w, r.Body, max)
}
}
var origBody io.ReadWriter
if core.PerfStandby() {
// Since we're checking PerfStandby here we key on origBody being nil
// or not later, so we need to always allocate so it's non-nil
origBody = new(bytes.Buffer)
reader = ioutil.NopCloser(io.TeeReader(reader, origBody))
}
err := jsonutil.DecodeJSONFromReader(reader, out)
if err != nil && err != io.EOF {
return errwrap.Wrapf("failed to parse JSON input: {{err}}", err)
return nil, errwrap.Wrapf("failed to parse JSON input: {{err}}", err)
}
return err
if origBody != nil {
return ioutil.NopCloser(origBody), err
}
return nil, err
}
// handleRequestForwarding determines whether to forward a request or not,

View File

@@ -18,14 +18,15 @@ import (
"github.com/hashicorp/vault/vault"
)
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) {
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
ns, err := namespace.FromContext(r.Context())
if err != nil {
return nil, http.StatusBadRequest, nil
return nil, nil, http.StatusBadRequest, nil
}
path := ns.TrimmedPath(r.URL.Path[len("/v1/"):])
var data map[string]interface{}
var origBody io.ReadCloser
// Determine the operation
var op logical.Operation
@@ -42,7 +43,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
if listStr != "" {
list, err = strconv.ParseBool(listStr)
if err != nil {
return nil, http.StatusBadRequest, nil
return nil, nil, http.StatusBadRequest, nil
}
if list {
op = logical.ListOperation
@@ -79,13 +80,13 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
op = logical.UpdateOperation
// Parse the request if we can
if op == logical.UpdateOperation {
err := parseRequest(r, w, &data)
origBody, err = parseRequest(core, r, w, &data)
if err == io.EOF {
data = nil
err = nil
}
if err != nil {
return nil, http.StatusBadRequest, err
return nil, nil, http.StatusBadRequest, err
}
}
@@ -97,12 +98,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
case "OPTIONS":
default:
return nil, http.StatusMethodNotAllowed, nil
return nil, nil, http.StatusMethodNotAllowed, nil
}
request_id, err := uuid.GenerateUUID()
if err != nil {
return nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err)
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err)
}
req, err := requestAuth(core, r, &logical.Request{
@@ -115,27 +116,27 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
})
if err != nil {
if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
return nil, http.StatusForbidden, nil
return nil, nil, http.StatusForbidden, nil
}
return nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err)
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err)
}
req, err = requestWrapInfo(r, req)
if err != nil {
return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
}
err = parseMFAHeader(req)
if err != nil {
return nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err)
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err)
}
err = requestPolicyOverride(r, req)
if err != nil {
return nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
}
return req, 0, nil
return req, origBody, 0, nil
}
func handleLogical(core *vault.Core) http.Handler {
@@ -148,14 +149,17 @@ func handleLogicalWithInjector(core *vault.Core) http.Handler {
func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(core, w, r)
req, origBody, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return
}
// Always forward requests that are using a limited use count token
if core.PerfStandby() && req.ClientTokenRemainingUses > 0 {
// Always forward requests that are using a limited use count token.
// origBody will not be nil if it's a perf standby as it checks
// PerfStandby() but will be nil otherwise.
if origBody != nil && req.ClientTokenRemainingUses > 0 {
r.Body = origBody
forwardRequest(core, w, r)
return
}
@@ -271,6 +275,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool) http.H
// success.
resp, ok, needsForward := request(core, w, r, req)
if needsForward {
r.Body = origBody
forwardRequest(core, w, r)
return
}

View File

@@ -86,7 +86,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r
func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) {
// Parse the request
var req GenerateRootInitRequest
if err := parseRequest(r, w, &req); err != nil && err != io.EOF {
if _, err := parseRequest(core, r, w, &req); err != nil && err != io.EOF {
respondError(w, http.StatusBadRequest, err)
return
}
@@ -132,7 +132,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Parse the request
var req GenerateRootUpdateRequest
if err := parseRequest(r, w, &req); err != nil {
if _, err := parseRequest(core, r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}

View File

@@ -40,7 +40,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request)
// Parse the request
var req InitRequest
if err := parseRequest(r, w, &req); err != nil {
if _, err := parseRequest(core, r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}

View File

@@ -108,7 +108,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool,
func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
// Parse the request
var req RekeyRequest
if err := parseRequest(r, w, &req); err != nil {
if _, err := parseRequest(core, r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
@@ -158,7 +158,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler {
// Parse the request
var req RekeyUpdateRequest
if err := parseRequest(r, w, &req); err != nil {
if _, err := parseRequest(core, r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}
@@ -306,7 +306,7 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery
func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
// Parse the request
var req RekeyVerificationUpdateRequest
if err := parseRequest(r, w, &req); err != nil {
if _, err := parseRequest(core, r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}

View File

@@ -17,7 +17,7 @@ import (
func handleSysSeal(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(core, w, r)
req, _, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return
@@ -47,7 +47,7 @@ func handleSysSeal(core *vault.Core) http.Handler {
func handleSysStepDown(core *vault.Core) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req, statusCode, err := buildLogicalRequest(core, w, r)
req, _, statusCode, err := buildLogicalRequest(core, w, r)
if err != nil || statusCode != 0 {
respondError(w, statusCode, err)
return
@@ -86,7 +86,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {
// Parse the request
var req UnsealRequest
if err := parseRequest(r, w, &req); err != nil {
if _, err := parseRequest(core, r, w, &req); err != nil {
respondError(w, http.StatusBadRequest, err)
return
}