mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-02 03:27:54 +00:00
Save the original request body for forwarding (#6538)
* Save the original request body for forwarding If we are forwarding a request after initial parsing the request body is already consumed. As a result a forwarded call containing a request body will have the body be nil. This saves the original request body for a given request via a TeeReader and uses that in cases of forwarding past body consumption.
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
@@ -444,7 +446,7 @@ func (fs *UIAssetWrapper) Open(name string) (http.File, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func parseRequest(r *http.Request, w http.ResponseWriter, out interface{}) error {
|
||||
func parseRequest(core *vault.Core, r *http.Request, w http.ResponseWriter, out interface{}) (io.ReadCloser, error) {
|
||||
// Limit the maximum number of bytes to MaxRequestSize to protect
|
||||
// against an indefinite amount of data being read.
|
||||
reader := r.Body
|
||||
@@ -453,17 +455,27 @@ func parseRequest(r *http.Request, w http.ResponseWriter, out interface{}) error
|
||||
if maxRequestSize != nil {
|
||||
max, ok := maxRequestSize.(int64)
|
||||
if !ok {
|
||||
return errors.New("could not parse max_request_size from request context")
|
||||
return nil, errors.New("could not parse max_request_size from request context")
|
||||
}
|
||||
if max > 0 {
|
||||
reader = http.MaxBytesReader(w, r.Body, max)
|
||||
}
|
||||
}
|
||||
var origBody io.ReadWriter
|
||||
if core.PerfStandby() {
|
||||
// Since we're checking PerfStandby here we key on origBody being nil
|
||||
// or not later, so we need to always allocate so it's non-nil
|
||||
origBody = new(bytes.Buffer)
|
||||
reader = ioutil.NopCloser(io.TeeReader(reader, origBody))
|
||||
}
|
||||
err := jsonutil.DecodeJSONFromReader(reader, out)
|
||||
if err != nil && err != io.EOF {
|
||||
return errwrap.Wrapf("failed to parse JSON input: {{err}}", err)
|
||||
return nil, errwrap.Wrapf("failed to parse JSON input: {{err}}", err)
|
||||
}
|
||||
return err
|
||||
if origBody != nil {
|
||||
return ioutil.NopCloser(origBody), err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// handleRequestForwarding determines whether to forward a request or not,
|
||||
|
||||
@@ -18,14 +18,15 @@ import (
|
||||
"github.com/hashicorp/vault/vault"
|
||||
)
|
||||
|
||||
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, int, error) {
|
||||
func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) (*logical.Request, io.ReadCloser, int, error) {
|
||||
ns, err := namespace.FromContext(r.Context())
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, nil
|
||||
return nil, nil, http.StatusBadRequest, nil
|
||||
}
|
||||
path := ns.TrimmedPath(r.URL.Path[len("/v1/"):])
|
||||
|
||||
var data map[string]interface{}
|
||||
var origBody io.ReadCloser
|
||||
|
||||
// Determine the operation
|
||||
var op logical.Operation
|
||||
@@ -42,7 +43,7 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
||||
if listStr != "" {
|
||||
list, err = strconv.ParseBool(listStr)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, nil
|
||||
return nil, nil, http.StatusBadRequest, nil
|
||||
}
|
||||
if list {
|
||||
op = logical.ListOperation
|
||||
@@ -79,13 +80,13 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
||||
op = logical.UpdateOperation
|
||||
// Parse the request if we can
|
||||
if op == logical.UpdateOperation {
|
||||
err := parseRequest(r, w, &data)
|
||||
origBody, err = parseRequest(core, r, w, &data)
|
||||
if err == io.EOF {
|
||||
data = nil
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
return nil, nil, http.StatusBadRequest, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,12 +98,12 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
||||
|
||||
case "OPTIONS":
|
||||
default:
|
||||
return nil, http.StatusMethodNotAllowed, nil
|
||||
return nil, nil, http.StatusMethodNotAllowed, nil
|
||||
}
|
||||
|
||||
request_id, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err)
|
||||
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to generate identifier for the request: {{err}}", err)
|
||||
}
|
||||
|
||||
req, err := requestAuth(core, r, &logical.Request{
|
||||
@@ -115,27 +116,27 @@ func buildLogicalRequest(core *vault.Core, w http.ResponseWriter, r *http.Reques
|
||||
})
|
||||
if err != nil {
|
||||
if errwrap.Contains(err, logical.ErrPermissionDenied.Error()) {
|
||||
return nil, http.StatusForbidden, nil
|
||||
return nil, nil, http.StatusForbidden, nil
|
||||
}
|
||||
return nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err)
|
||||
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error performing token check: {{err}}", err)
|
||||
}
|
||||
|
||||
req, err = requestWrapInfo(r, req)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
|
||||
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("error parsing X-Vault-Wrap-TTL header: {{err}}", err)
|
||||
}
|
||||
|
||||
err = parseMFAHeader(req)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err)
|
||||
return nil, nil, http.StatusBadRequest, errwrap.Wrapf("failed to parse X-Vault-MFA header: {{err}}", err)
|
||||
}
|
||||
|
||||
err = requestPolicyOverride(r, req)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
|
||||
return nil, nil, http.StatusBadRequest, errwrap.Wrapf(fmt.Sprintf(`failed to parse %s header: {{err}}`, PolicyOverrideHeaderName), err)
|
||||
}
|
||||
|
||||
return req, 0, nil
|
||||
return req, origBody, 0, nil
|
||||
}
|
||||
|
||||
func handleLogical(core *vault.Core) http.Handler {
|
||||
@@ -148,14 +149,17 @@ func handleLogicalWithInjector(core *vault.Core) http.Handler {
|
||||
|
||||
func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
req, statusCode, err := buildLogicalRequest(core, w, r)
|
||||
req, origBody, statusCode, err := buildLogicalRequest(core, w, r)
|
||||
if err != nil || statusCode != 0 {
|
||||
respondError(w, statusCode, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Always forward requests that are using a limited use count token
|
||||
if core.PerfStandby() && req.ClientTokenRemainingUses > 0 {
|
||||
// Always forward requests that are using a limited use count token.
|
||||
// origBody will not be nil if it's a perf standby as it checks
|
||||
// PerfStandby() but will be nil otherwise.
|
||||
if origBody != nil && req.ClientTokenRemainingUses > 0 {
|
||||
r.Body = origBody
|
||||
forwardRequest(core, w, r)
|
||||
return
|
||||
}
|
||||
@@ -271,6 +275,7 @@ func handleLogicalInternal(core *vault.Core, injectDataIntoTopLevel bool) http.H
|
||||
// success.
|
||||
resp, ok, needsForward := request(core, w, r, req)
|
||||
if needsForward {
|
||||
r.Body = origBody
|
||||
forwardRequest(core, w, r)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func handleSysGenerateRootAttemptGet(core *vault.Core, w http.ResponseWriter, r
|
||||
func handleSysGenerateRootAttemptPut(core *vault.Core, w http.ResponseWriter, r *http.Request, generateStrategy vault.GenerateRootStrategy) {
|
||||
// Parse the request
|
||||
var req GenerateRootInitRequest
|
||||
if err := parseRequest(r, w, &req); err != nil && err != io.EOF {
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil && err != io.EOF {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func handleSysGenerateRootUpdate(core *vault.Core, generateStrategy vault.Genera
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request
|
||||
var req GenerateRootUpdateRequest
|
||||
if err := parseRequest(r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func handleSysInitPut(core *vault.Core, w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// Parse the request
|
||||
var req InitRequest
|
||||
if err := parseRequest(r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func handleSysRekeyInitGet(ctx context.Context, core *vault.Core, recovery bool,
|
||||
func handleSysRekeyInitPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request
|
||||
var req RekeyRequest
|
||||
if err := parseRequest(r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func handleSysRekeyUpdate(core *vault.Core, recovery bool) http.Handler {
|
||||
|
||||
// Parse the request
|
||||
var req RekeyUpdateRequest
|
||||
if err := parseRequest(r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
@@ -306,7 +306,7 @@ func handleSysRekeyVerifyDelete(ctx context.Context, core *vault.Core, recovery
|
||||
func handleSysRekeyVerifyPut(ctx context.Context, core *vault.Core, recovery bool, w http.ResponseWriter, r *http.Request) {
|
||||
// Parse the request
|
||||
var req RekeyVerificationUpdateRequest
|
||||
if err := parseRequest(r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
|
||||
func handleSysSeal(core *vault.Core) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
req, statusCode, err := buildLogicalRequest(core, w, r)
|
||||
req, _, statusCode, err := buildLogicalRequest(core, w, r)
|
||||
if err != nil || statusCode != 0 {
|
||||
respondError(w, statusCode, err)
|
||||
return
|
||||
@@ -47,7 +47,7 @@ func handleSysSeal(core *vault.Core) http.Handler {
|
||||
|
||||
func handleSysStepDown(core *vault.Core) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
req, statusCode, err := buildLogicalRequest(core, w, r)
|
||||
req, _, statusCode, err := buildLogicalRequest(core, w, r)
|
||||
if err != nil || statusCode != 0 {
|
||||
respondError(w, statusCode, err)
|
||||
return
|
||||
@@ -86,7 +86,7 @@ func handleSysUnseal(core *vault.Core) http.Handler {
|
||||
|
||||
// Parse the request
|
||||
var req UnsealRequest
|
||||
if err := parseRequest(r, w, &req); err != nil {
|
||||
if _, err := parseRequest(core, r, w, &req); err != nil {
|
||||
respondError(w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user