mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	[VAULT-5003] Use net/http client in Sys().RaftSnapshotRestore (#14269)
Use net/http client when body could be too big for retryablehttp client
This commit is contained in:
		
							
								
								
									
										154
									
								
								api/client.go
									
									
									
									
									
								
							
							
						
						
									
										154
									
								
								api/client.go
									
									
									
									
									
								
							| @@ -53,6 +53,14 @@ const ( | ||||
| 	HeaderIndex           = "X-Vault-Index" | ||||
| 	HeaderForward         = "X-Vault-Forward" | ||||
| 	HeaderInconsistent    = "X-Vault-Inconsistent" | ||||
| 	TLSErrorString        = "This error usually means that the server is running with TLS disabled\n" + | ||||
| 		"but the client is configured to use TLS. Please either enable TLS\n" + | ||||
| 		"on the server or run the client with -address set to an address\n" + | ||||
| 		"that uses the http protocol:\n\n" + | ||||
| 		"    vault <command> -address http://<address>\n\n" + | ||||
| 		"You can also set the VAULT_ADDR environment variable:\n\n\n" + | ||||
| 		"    VAULT_ADDR=http://<address> vault <command>\n\n" + | ||||
| 		"where <address> is replaced by the actual address to the server." | ||||
| ) | ||||
|  | ||||
| // Deprecated values | ||||
| @@ -1127,12 +1135,9 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon | ||||
| 		limiter.Wait(ctx) | ||||
| 	} | ||||
|  | ||||
| 	// Sanity check the token before potentially erroring from the API | ||||
| 	idx := strings.IndexFunc(token, func(c rune) bool { | ||||
| 		return !unicode.IsPrint(c) | ||||
| 	}) | ||||
| 	if idx != -1 { | ||||
| 		return nil, fmt.Errorf("configured Vault token contains non-printable characters and cannot be used") | ||||
| 	// check the token before potentially erroring from the API | ||||
| 	if err := validateToken(token); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	redirectCount := 0 | ||||
| @@ -1192,17 +1197,7 @@ START: | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		if strings.Contains(err.Error(), "tls: oversized") { | ||||
| 			err = errwrap.Wrapf( | ||||
| 				"{{err}}\n\n"+ | ||||
| 					"This error usually means that the server is running with TLS disabled\n"+ | ||||
| 					"but the client is configured to use TLS. Please either enable TLS\n"+ | ||||
| 					"on the server or run the client with -address set to an address\n"+ | ||||
| 					"that uses the http protocol:\n\n"+ | ||||
| 					"    vault <command> -address http://<address>\n\n"+ | ||||
| 					"You can also set the VAULT_ADDR environment variable:\n\n\n"+ | ||||
| 					"    VAULT_ADDR=http://<address> vault <command>\n\n"+ | ||||
| 					"where <address> is replaced by the actual address to the server.", | ||||
| 				err) | ||||
| 			err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err) | ||||
| 		} | ||||
| 		return result, err | ||||
| 	} | ||||
| @@ -1249,6 +1244,120 @@ START: | ||||
| 	return result, nil | ||||
| } | ||||
|  | ||||
| // httpRequestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is | ||||
| // useful when making calls where a net/http client is desirable. A single redirect (status code 301, 302, | ||||
| // or 307) will be followed but all retry and timeout logic is the responsibility of the caller as is | ||||
| // closing the Response body. | ||||
| func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Response, error) { | ||||
| 	req, err := http.NewRequestWithContext(ctx, r.Method, r.URL.RequestURI(), r.Body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	c.modifyLock.RLock() | ||||
| 	token := c.token | ||||
|  | ||||
| 	c.config.modifyLock.RLock() | ||||
| 	limiter := c.config.Limiter | ||||
| 	httpClient := c.config.HttpClient | ||||
| 	outputCurlString := c.config.OutputCurlString | ||||
| 	if c.headers != nil { | ||||
| 		for header, vals := range c.headers { | ||||
| 			for _, val := range vals { | ||||
| 				req.Header.Add(header, val) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	c.config.modifyLock.RUnlock() | ||||
| 	c.modifyLock.RUnlock() | ||||
|  | ||||
| 	// OutputCurlString logic relies on the request type to be retryable.Request as | ||||
| 	if outputCurlString { | ||||
| 		return nil, fmt.Errorf("output-curl-string is not implemented for this request") | ||||
| 	} | ||||
|  | ||||
| 	req.URL.User = r.URL.User | ||||
| 	req.URL.Scheme = r.URL.Scheme | ||||
| 	req.URL.Host = r.URL.Host | ||||
| 	req.Host = r.URL.Host | ||||
|  | ||||
| 	if len(r.ClientToken) != 0 { | ||||
| 		req.Header.Set(consts.AuthHeaderName, r.ClientToken) | ||||
| 	} | ||||
|  | ||||
| 	if len(r.WrapTTL) != 0 { | ||||
| 		req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) | ||||
| 	} | ||||
|  | ||||
| 	if len(r.MFAHeaderVals) != 0 { | ||||
| 		for _, mfaHeaderVal := range r.MFAHeaderVals { | ||||
| 			req.Header.Add("X-Vault-MFA", mfaHeaderVal) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if r.PolicyOverride { | ||||
| 		req.Header.Set("X-Vault-Policy-Override", "true") | ||||
| 	} | ||||
|  | ||||
| 	if limiter != nil { | ||||
| 		limiter.Wait(ctx) | ||||
| 	} | ||||
|  | ||||
| 	// check the token before potentially erroring from the API | ||||
| 	if err := validateToken(token); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var result *Response | ||||
|  | ||||
| 	resp, err := httpClient.Do(req) | ||||
|  | ||||
| 	if resp != nil { | ||||
| 		result = &Response{Response: resp} | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		if strings.Contains(err.Error(), "tls: oversized") { | ||||
| 			err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err) | ||||
| 		} | ||||
| 		return result, err | ||||
| 	} | ||||
|  | ||||
| 	// Check for a redirect, only allowing for a single redirect | ||||
| 	if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { | ||||
| 		// Parse the updated location | ||||
| 		respLoc, err := resp.Location() | ||||
| 		if err != nil { | ||||
| 			return result, fmt.Errorf("redirect failed: %s", err) | ||||
| 		} | ||||
|  | ||||
| 		// Ensure a protocol downgrade doesn't happen | ||||
| 		if req.URL.Scheme == "https" && respLoc.Scheme != "https" { | ||||
| 			return result, fmt.Errorf("redirect would cause protocol downgrade") | ||||
| 		} | ||||
|  | ||||
| 		// Update the request | ||||
| 		req.URL = respLoc | ||||
|  | ||||
| 		// Reset the request body if any | ||||
| 		if err := r.ResetJSONBody(); err != nil { | ||||
| 			return result, fmt.Errorf("redirect failed: %s", err) | ||||
| 		} | ||||
|  | ||||
| 		// Retry the request | ||||
| 		resp, err = httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return result, fmt.Errorf("redirect failed: %s", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := result.Error(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return result, nil | ||||
| } | ||||
|  | ||||
| type ( | ||||
| 	RequestCallback  func(*Request) | ||||
| 	ResponseCallback func(*Response) | ||||
| @@ -1466,3 +1575,14 @@ func (w *replicationStateStore) states() []string { | ||||
| 	copy(c, w.store) | ||||
| 	return c | ||||
| } | ||||
|  | ||||
| // validateToken will check for non-printable characters to prevent a call that will fail at the api | ||||
| func validateToken(t string) error { | ||||
| 	idx := strings.IndexFunc(t, func(c rune) bool { | ||||
| 		return !unicode.IsPrint(c) | ||||
| 	}) | ||||
| 	if idx != -1 { | ||||
| 		return fmt.Errorf("configured Vault token contains non-printable characters and cannot be used") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
							
								
								
									
										106
									
								
								api/sys_raft.go
									
									
									
									
									
								
							
							
						
						
									
										106
									
								
								api/sys_raft.go
									
									
									
									
									
								
							| @@ -6,7 +6,6 @@ import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| @@ -14,7 +13,6 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/go-secure-stdlib/parseutil" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/consts" | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| ) | ||||
|  | ||||
| @@ -132,87 +130,25 @@ func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) { | ||||
| 	return &result, err | ||||
| } | ||||
|  | ||||
| // RaftSnapshot invokes the API that takes the snapshot of the raft cluster and | ||||
| // writes it to the supplied io.Writer. | ||||
| // RaftSnapshot is a thin wrapper around RaftSnapshotWithContext | ||||
| func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { | ||||
| 	ctx, cancelFunc := context.WithCancel(context.Background()) | ||||
| 	defer cancelFunc() | ||||
|  | ||||
| 	return c.RaftSnapshotWithContext(ctx, snapWriter) | ||||
| } | ||||
|  | ||||
| // RaftSnapshotWithContext invokes the API that takes the snapshot of the raft cluster and | ||||
| // writes it to the supplied io.Writer. | ||||
| func (c *Sys) RaftSnapshotWithContext(ctx context.Context, snapWriter io.Writer) error { | ||||
| 	r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") | ||||
| 	r.URL.RawQuery = r.Params.Encode() | ||||
|  | ||||
| 	req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil) | ||||
| 	resp, err := c.c.httpRequestWithContext(ctx, r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	req.URL.User = r.URL.User | ||||
| 	req.URL.Scheme = r.URL.Scheme | ||||
| 	req.URL.Host = r.URL.Host | ||||
| 	req.Host = r.URL.Host | ||||
|  | ||||
| 	if r.Headers != nil { | ||||
| 		for header, vals := range r.Headers { | ||||
| 			for _, val := range vals { | ||||
| 				req.Header.Add(header, val) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(r.ClientToken) != 0 { | ||||
| 		req.Header.Set(consts.AuthHeaderName, r.ClientToken) | ||||
| 	} | ||||
|  | ||||
| 	if len(r.WrapTTL) != 0 { | ||||
| 		req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) | ||||
| 	} | ||||
|  | ||||
| 	if len(r.MFAHeaderVals) != 0 { | ||||
| 		for _, mfaHeaderVal := range r.MFAHeaderVals { | ||||
| 			req.Header.Add("X-Vault-MFA", mfaHeaderVal) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if r.PolicyOverride { | ||||
| 		req.Header.Set("X-Vault-Policy-Override", "true") | ||||
| 	} | ||||
|  | ||||
| 	// Avoiding the use of RawRequestWithContext which reads the response body | ||||
| 	// to determine if the body contains error message. | ||||
| 	var result *Response | ||||
| 	resp, err := c.c.config.HttpClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if resp == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// Check for a redirect, only allowing for a single redirect | ||||
| 	if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { | ||||
| 		// Parse the updated location | ||||
| 		respLoc, err := resp.Location() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// Ensure a protocol downgrade doesn't happen | ||||
| 		if req.URL.Scheme == "https" && respLoc.Scheme != "https" { | ||||
| 			return fmt.Errorf("redirect would cause protocol downgrade") | ||||
| 		} | ||||
|  | ||||
| 		// Update the request | ||||
| 		req.URL = respLoc | ||||
|  | ||||
| 		// Retry the request | ||||
| 		resp, err = c.c.config.HttpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	result = &Response{Response: resp} | ||||
| 	if err := result.Error(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	// Make sure that the last file in the archive, SHA256SUMS.sealed, is present | ||||
| 	// and non-empty.  This is to catch cases where the snapshot failed midstream, | ||||
| @@ -271,20 +207,26 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // RaftSnapshotRestore reads the snapshot from the io.Reader and installs that | ||||
| // snapshot, returning the cluster to the state defined by it. | ||||
| // RaftSnapshotRestore is a thin wrapper around RaftSnapshotRestoreWithContext | ||||
| func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	return c.RaftSnapshotRestoreWithContext(ctx, snapReader, force) | ||||
| } | ||||
|  | ||||
| // RaftSnapshotRestoreWithContext reads the snapshot from the io.Reader and installs that | ||||
| // snapshot, returning the cluster to the state defined by it. | ||||
| func (c *Sys) RaftSnapshotRestoreWithContext(ctx context.Context, snapReader io.Reader, force bool) error { | ||||
| 	path := "/v1/sys/storage/raft/snapshot" | ||||
| 	if force { | ||||
| 		path = "/v1/sys/storage/raft/snapshot-force" | ||||
| 	} | ||||
| 	r := c.c.NewRequest("POST", path) | ||||
|  | ||||
| 	r := c.c.NewRequest(http.MethodPost, path) | ||||
| 	r.Body = snapReader | ||||
|  | ||||
| 	ctx, cancelFunc := context.WithCancel(context.Background()) | ||||
| 	defer cancelFunc() | ||||
| 	resp, err := c.c.RawRequestWithContext(ctx, r) | ||||
| 	resp, err := c.c.httpRequestWithContext(ctx, r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|   | ||||
							
								
								
									
										3
									
								
								changelog/14269.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/14269.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| ```release-note:bug | ||||
|  api/sys/raft: Update RaftSnapshotRestore to use net/http client allowing bodies larger than allocated memory to be streamed | ||||
| ``` | ||||
| @@ -489,28 +489,10 @@ func TestRaft_SnapshotAPI(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	transport := cleanhttp.DefaultPooledTransport() | ||||
| 	transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone() | ||||
| 	if err := http2.ConfigureTransport(transport); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	client := &http.Client{ | ||||
| 		Transport: transport, | ||||
| 	} | ||||
|  | ||||
| 	// Take a snapshot | ||||
| 	req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot") | ||||
| 	httpReq, err := req.ToHTTP() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	resp, err := client.Do(httpReq) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	snap, err := ioutil.ReadAll(resp.Body) | ||||
| 	buf := new(bytes.Buffer) | ||||
| 	err := leaderClient.Sys().RaftSnapshot(buf) | ||||
| 	snap, err := io.ReadAll(buf) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -527,15 +509,8 @@ func TestRaft_SnapshotAPI(t *testing.T) { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Restore snapshot | ||||
| 	req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot") | ||||
| 	req.Body = bytes.NewBuffer(snap) | ||||
| 	httpReq, err = req.ToHTTP() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	resp, err = client.Do(httpReq) | ||||
| 	err = leaderClient.Sys().RaftSnapshotRestore(bytes.NewReader(snap), false) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Vinny Mannello
					Vinny Mannello