mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	Clean up request logic and use retryable's more efficient handling (#4670)
This commit is contained in:
		| @@ -635,7 +635,7 @@ func (c *Client) RawRequest(r *Request) (*Response, error) { | ||||
|  | ||||
| 	redirectCount := 0 | ||||
| START: | ||||
| 	req, err := r.toRetryableHTTP(false) | ||||
| 	req, err := r.toRetryableHTTP() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|   | ||||
| @@ -22,6 +22,12 @@ type Request struct { | ||||
| 	MFAHeaderVals []string | ||||
| 	WrapTTL       string | ||||
| 	Obj           interface{} | ||||
|  | ||||
| 	// When possible, use BodyBytes as it is more efficient due to how the | ||||
| 	// retry logic works | ||||
| 	BodyBytes []byte | ||||
|  | ||||
| 	// Fallback | ||||
| 	Body     io.Reader | ||||
| 	BodySize int64 | ||||
|  | ||||
| @@ -33,68 +39,76 @@ type Request struct { | ||||
|  | ||||
| // SetJSONBody is used to set a request body that is a JSON-encoded value. | ||||
| func (r *Request) SetJSONBody(val interface{}) error { | ||||
| 	buf := bytes.NewBuffer(nil) | ||||
| 	enc := json.NewEncoder(buf) | ||||
| 	if err := enc.Encode(val); err != nil { | ||||
| 	buf, err := json.Marshal(val) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	r.Obj = val | ||||
| 	r.Body = buf | ||||
| 	r.BodySize = int64(buf.Len()) | ||||
| 	r.BodyBytes = buf | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ResetJSONBody is used to reset the body for a redirect | ||||
| func (r *Request) ResetJSONBody() error { | ||||
| 	if r.Body == nil { | ||||
| 	if r.BodyBytes == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return r.SetJSONBody(r.Obj) | ||||
| } | ||||
|  | ||||
| // ToHTTP turns this request into a valid *http.Request for use with the | ||||
| // net/http package. | ||||
| // DEPRECATED: ToHTTP turns this request into a valid *http.Request for use | ||||
| // with the net/http package. | ||||
| func (r *Request) ToHTTP() (*http.Request, error) { | ||||
| 	req, err := r.toRetryableHTTP(true) | ||||
| 	req, err := r.toRetryableHTTP() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	switch { | ||||
| 	case r.BodyBytes == nil && r.Body == nil: | ||||
| 		// No body | ||||
|  | ||||
| 	case r.BodyBytes != nil: | ||||
| 		req.Request.Body = ioutil.NopCloser(bytes.NewReader(r.BodyBytes)) | ||||
|  | ||||
| 	default: | ||||
| 		if c, ok := r.Body.(io.ReadCloser); ok { | ||||
| 			req.Request.Body = c | ||||
| 		} else { | ||||
| 			req.Request.Body = ioutil.NopCloser(r.Body) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return req.Request, nil | ||||
| } | ||||
|  | ||||
| // legacy indicates whether we want to return a request derived from | ||||
| // http.NewRequest instead of retryablehttp.NewRequest, so that legacy clents | ||||
| // that might be using the public ToHTTP method still work | ||||
| func (r *Request) toRetryableHTTP(legacy bool) (*retryablehttp.Request, error) { | ||||
| func (r *Request) toRetryableHTTP() (*retryablehttp.Request, error) { | ||||
| 	// Encode the query parameters | ||||
| 	r.URL.RawQuery = r.Params.Encode() | ||||
|  | ||||
| 	// Create the HTTP request, defaulting to retryable | ||||
| 	var req *retryablehttp.Request | ||||
|  | ||||
| 	if legacy { | ||||
| 		regReq, err := http.NewRequest(r.Method, r.URL.RequestURI(), r.Body) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		req = &retryablehttp.Request{ | ||||
| 			Request: regReq, | ||||
| 		} | ||||
| 	} else { | ||||
| 		var buf []byte | ||||
| 	var err error | ||||
| 		if r.Body != nil { | ||||
| 			buf, err = ioutil.ReadAll(r.Body) | ||||
| 	var body interface{} | ||||
|  | ||||
| 	switch { | ||||
| 	case r.BodyBytes == nil && r.Body == nil: | ||||
| 		// No body | ||||
|  | ||||
| 	case r.BodyBytes != nil: | ||||
| 		// Use bytes, it's more efficient | ||||
| 		body = r.BodyBytes | ||||
|  | ||||
| 	default: | ||||
| 		body = r.Body | ||||
| 	} | ||||
|  | ||||
| 	req, err = retryablehttp.NewRequest(r.Method, r.URL.RequestURI(), body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 		} | ||||
| 		req, err = retryablehttp.NewRequest(r.Method, r.URL.RequestURI(), bytes.NewReader(buf)) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	req.URL.User = r.URL.User | ||||
| 	req.URL.Scheme = r.URL.Scheme | ||||
|   | ||||
| @@ -1,8 +1,6 @@ | ||||
| package api | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
| @@ -14,20 +12,11 @@ func TestRequestSetJSONBody(t *testing.T) { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	if _, err := io.Copy(&buf, r.Body); err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	expected := `{"foo":"bar"}` | ||||
| 	actual := strings.TrimSpace(buf.String()) | ||||
| 	actual := strings.TrimSpace(string(r.BodyBytes)) | ||||
| 	if actual != expected { | ||||
| 		t.Fatalf("bad: %s", actual) | ||||
| 	} | ||||
|  | ||||
| 	if int64(len(buf.String())) != r.BodySize { | ||||
| 		t.Fatalf("bad: %d", len(actual)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRequestResetJSONBody(t *testing.T) { | ||||
| @@ -37,27 +26,16 @@ func TestRequestResetJSONBody(t *testing.T) { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	if _, err := io.Copy(&buf, r.Body); err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	if err := r.ResetJSONBody(); err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	var buf2 bytes.Buffer | ||||
| 	if _, err := io.Copy(&buf2, r.Body); err != nil { | ||||
| 		t.Fatalf("err: %s", err) | ||||
| 	} | ||||
| 	buf := make([]byte, len(r.BodyBytes)) | ||||
| 	copy(buf, r.BodyBytes) | ||||
|  | ||||
| 	expected := `{"foo":"bar"}` | ||||
| 	actual := strings.TrimSpace(buf2.String()) | ||||
| 	actual := strings.TrimSpace(string(buf)) | ||||
| 	if actual != expected { | ||||
| 		t.Fatalf("bad: %s", actual) | ||||
| 	} | ||||
|  | ||||
| 	if int64(len(buf2.String())) != r.BodySize { | ||||
| 		t.Fatalf("bad: %d", len(actual)) | ||||
| 		t.Fatalf("bad: actual %s, expected %s", actual, expected) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jeff Mitchell
					Jeff Mitchell