mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 10:37:56 +00:00 
			
		
		
		
	[VAULT-5003] Use net/http client in Sys().RaftSnapshotRestore (#14269)
Use net/http client when body could be too big for retryablehttp client
This commit is contained in:
		
							
								
								
									
										154
									
								
								api/client.go
									
									
									
									
									
								
							
							
						
						
									
										154
									
								
								api/client.go
									
									
									
									
									
								
							| @@ -53,6 +53,14 @@ const ( | |||||||
| 	HeaderIndex           = "X-Vault-Index" | 	HeaderIndex           = "X-Vault-Index" | ||||||
| 	HeaderForward         = "X-Vault-Forward" | 	HeaderForward         = "X-Vault-Forward" | ||||||
| 	HeaderInconsistent    = "X-Vault-Inconsistent" | 	HeaderInconsistent    = "X-Vault-Inconsistent" | ||||||
|  | 	TLSErrorString        = "This error usually means that the server is running with TLS disabled\n" + | ||||||
|  | 		"but the client is configured to use TLS. Please either enable TLS\n" + | ||||||
|  | 		"on the server or run the client with -address set to an address\n" + | ||||||
|  | 		"that uses the http protocol:\n\n" + | ||||||
|  | 		"    vault <command> -address http://<address>\n\n" + | ||||||
|  | 		"You can also set the VAULT_ADDR environment variable:\n\n\n" + | ||||||
|  | 		"    VAULT_ADDR=http://<address> vault <command>\n\n" + | ||||||
|  | 		"where <address> is replaced by the actual address to the server." | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Deprecated values | // Deprecated values | ||||||
| @@ -1127,12 +1135,9 @@ func (c *Client) RawRequestWithContext(ctx context.Context, r *Request) (*Respon | |||||||
| 		limiter.Wait(ctx) | 		limiter.Wait(ctx) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Sanity check the token before potentially erroring from the API | 	// check the token before potentially erroring from the API | ||||||
| 	idx := strings.IndexFunc(token, func(c rune) bool { | 	if err := validateToken(token); err != nil { | ||||||
| 		return !unicode.IsPrint(c) | 		return nil, err | ||||||
| 	}) |  | ||||||
| 	if idx != -1 { |  | ||||||
| 		return nil, fmt.Errorf("configured Vault token contains non-printable characters and cannot be used") |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	redirectCount := 0 | 	redirectCount := 0 | ||||||
| @@ -1192,17 +1197,7 @@ START: | |||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if strings.Contains(err.Error(), "tls: oversized") { | 		if strings.Contains(err.Error(), "tls: oversized") { | ||||||
| 			err = errwrap.Wrapf( | 			err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err) | ||||||
| 				"{{err}}\n\n"+ |  | ||||||
| 					"This error usually means that the server is running with TLS disabled\n"+ |  | ||||||
| 					"but the client is configured to use TLS. Please either enable TLS\n"+ |  | ||||||
| 					"on the server or run the client with -address set to an address\n"+ |  | ||||||
| 					"that uses the http protocol:\n\n"+ |  | ||||||
| 					"    vault <command> -address http://<address>\n\n"+ |  | ||||||
| 					"You can also set the VAULT_ADDR environment variable:\n\n\n"+ |  | ||||||
| 					"    VAULT_ADDR=http://<address> vault <command>\n\n"+ |  | ||||||
| 					"where <address> is replaced by the actual address to the server.", |  | ||||||
| 				err) |  | ||||||
| 		} | 		} | ||||||
| 		return result, err | 		return result, err | ||||||
| 	} | 	} | ||||||
| @@ -1249,6 +1244,120 @@ START: | |||||||
| 	return result, nil | 	return result, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // httpRequestWithContext avoids the use of the go-retryable library found in RawRequestWithContext and is | ||||||
|  | // useful when making calls where a net/http client is desirable. A single redirect (status code 301, 302, | ||||||
|  | // or 307) will be followed but all retry and timeout logic is the responsibility of the caller as is | ||||||
|  | // closing the Response body. | ||||||
|  | func (c *Client) httpRequestWithContext(ctx context.Context, r *Request) (*Response, error) { | ||||||
|  | 	req, err := http.NewRequestWithContext(ctx, r.Method, r.URL.RequestURI(), r.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.modifyLock.RLock() | ||||||
|  | 	token := c.token | ||||||
|  |  | ||||||
|  | 	c.config.modifyLock.RLock() | ||||||
|  | 	limiter := c.config.Limiter | ||||||
|  | 	httpClient := c.config.HttpClient | ||||||
|  | 	outputCurlString := c.config.OutputCurlString | ||||||
|  | 	if c.headers != nil { | ||||||
|  | 		for header, vals := range c.headers { | ||||||
|  | 			for _, val := range vals { | ||||||
|  | 				req.Header.Add(header, val) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	c.config.modifyLock.RUnlock() | ||||||
|  | 	c.modifyLock.RUnlock() | ||||||
|  |  | ||||||
|  | 	// OutputCurlString logic relies on the request type to be retryable.Request as | ||||||
|  | 	if outputCurlString { | ||||||
|  | 		return nil, fmt.Errorf("output-curl-string is not implemented for this request") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.URL.User = r.URL.User | ||||||
|  | 	req.URL.Scheme = r.URL.Scheme | ||||||
|  | 	req.URL.Host = r.URL.Host | ||||||
|  | 	req.Host = r.URL.Host | ||||||
|  |  | ||||||
|  | 	if len(r.ClientToken) != 0 { | ||||||
|  | 		req.Header.Set(consts.AuthHeaderName, r.ClientToken) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(r.WrapTTL) != 0 { | ||||||
|  | 		req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(r.MFAHeaderVals) != 0 { | ||||||
|  | 		for _, mfaHeaderVal := range r.MFAHeaderVals { | ||||||
|  | 			req.Header.Add("X-Vault-MFA", mfaHeaderVal) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if r.PolicyOverride { | ||||||
|  | 		req.Header.Set("X-Vault-Policy-Override", "true") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if limiter != nil { | ||||||
|  | 		limiter.Wait(ctx) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// check the token before potentially erroring from the API | ||||||
|  | 	if err := validateToken(token); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var result *Response | ||||||
|  |  | ||||||
|  | 	resp, err := httpClient.Do(req) | ||||||
|  |  | ||||||
|  | 	if resp != nil { | ||||||
|  | 		result = &Response{Response: resp} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		if strings.Contains(err.Error(), "tls: oversized") { | ||||||
|  | 			err = errwrap.Wrapf("{{err}}\n\n"+TLSErrorString, err) | ||||||
|  | 		} | ||||||
|  | 		return result, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check for a redirect, only allowing for a single redirect | ||||||
|  | 	if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { | ||||||
|  | 		// Parse the updated location | ||||||
|  | 		respLoc, err := resp.Location() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return result, fmt.Errorf("redirect failed: %s", err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Ensure a protocol downgrade doesn't happen | ||||||
|  | 		if req.URL.Scheme == "https" && respLoc.Scheme != "https" { | ||||||
|  | 			return result, fmt.Errorf("redirect would cause protocol downgrade") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Update the request | ||||||
|  | 		req.URL = respLoc | ||||||
|  |  | ||||||
|  | 		// Reset the request body if any | ||||||
|  | 		if err := r.ResetJSONBody(); err != nil { | ||||||
|  | 			return result, fmt.Errorf("redirect failed: %s", err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Retry the request | ||||||
|  | 		resp, err = httpClient.Do(req) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return result, fmt.Errorf("redirect failed: %s", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := result.Error(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return result, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| type ( | type ( | ||||||
| 	RequestCallback  func(*Request) | 	RequestCallback  func(*Request) | ||||||
| 	ResponseCallback func(*Response) | 	ResponseCallback func(*Response) | ||||||
| @@ -1466,3 +1575,14 @@ func (w *replicationStateStore) states() []string { | |||||||
| 	copy(c, w.store) | 	copy(c, w.store) | ||||||
| 	return c | 	return c | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // validateToken will check for non-printable characters to prevent a call that will fail at the api | ||||||
|  | func validateToken(t string) error { | ||||||
|  | 	idx := strings.IndexFunc(t, func(c rune) bool { | ||||||
|  | 		return !unicode.IsPrint(c) | ||||||
|  | 	}) | ||||||
|  | 	if idx != -1 { | ||||||
|  | 		return fmt.Errorf("configured Vault token contains non-printable characters and cannot be used") | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										106
									
								
								api/sys_raft.go
									
									
									
									
									
								
							
							
						
						
									
										106
									
								
								api/sys_raft.go
									
									
									
									
									
								
							| @@ -6,7 +6,6 @@ import ( | |||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| @@ -14,7 +13,6 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-secure-stdlib/parseutil" | 	"github.com/hashicorp/go-secure-stdlib/parseutil" | ||||||
| 	"github.com/hashicorp/vault/sdk/helper/consts" |  | ||||||
| 	"github.com/mitchellh/mapstructure" | 	"github.com/mitchellh/mapstructure" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -132,87 +130,25 @@ func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) { | |||||||
| 	return &result, err | 	return &result, err | ||||||
| } | } | ||||||
|  |  | ||||||
| // RaftSnapshot invokes the API that takes the snapshot of the raft cluster and | // RaftSnapshot is a thin wrapper around RaftSnapshotWithContext | ||||||
| // writes it to the supplied io.Writer. |  | ||||||
| func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { | func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { | ||||||
|  | 	ctx, cancelFunc := context.WithCancel(context.Background()) | ||||||
|  | 	defer cancelFunc() | ||||||
|  |  | ||||||
|  | 	return c.RaftSnapshotWithContext(ctx, snapWriter) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // RaftSnapshotWithContext invokes the API that takes the snapshot of the raft cluster and | ||||||
|  | // writes it to the supplied io.Writer. | ||||||
|  | func (c *Sys) RaftSnapshotWithContext(ctx context.Context, snapWriter io.Writer) error { | ||||||
| 	r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") | 	r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot") | ||||||
| 	r.URL.RawQuery = r.Params.Encode() | 	r.URL.RawQuery = r.Params.Encode() | ||||||
|  |  | ||||||
| 	req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil) | 	resp, err := c.c.httpRequestWithContext(ctx, r) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	defer resp.Body.Close() | ||||||
| 	req.URL.User = r.URL.User |  | ||||||
| 	req.URL.Scheme = r.URL.Scheme |  | ||||||
| 	req.URL.Host = r.URL.Host |  | ||||||
| 	req.Host = r.URL.Host |  | ||||||
|  |  | ||||||
| 	if r.Headers != nil { |  | ||||||
| 		for header, vals := range r.Headers { |  | ||||||
| 			for _, val := range vals { |  | ||||||
| 				req.Header.Add(header, val) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(r.ClientToken) != 0 { |  | ||||||
| 		req.Header.Set(consts.AuthHeaderName, r.ClientToken) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(r.WrapTTL) != 0 { |  | ||||||
| 		req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(r.MFAHeaderVals) != 0 { |  | ||||||
| 		for _, mfaHeaderVal := range r.MFAHeaderVals { |  | ||||||
| 			req.Header.Add("X-Vault-MFA", mfaHeaderVal) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if r.PolicyOverride { |  | ||||||
| 		req.Header.Set("X-Vault-Policy-Override", "true") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Avoiding the use of RawRequestWithContext which reads the response body |  | ||||||
| 	// to determine if the body contains error message. |  | ||||||
| 	var result *Response |  | ||||||
| 	resp, err := c.c.config.HttpClient.Do(req) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if resp == nil { |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check for a redirect, only allowing for a single redirect |  | ||||||
| 	if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { |  | ||||||
| 		// Parse the updated location |  | ||||||
| 		respLoc, err := resp.Location() |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Ensure a protocol downgrade doesn't happen |  | ||||||
| 		if req.URL.Scheme == "https" && respLoc.Scheme != "https" { |  | ||||||
| 			return fmt.Errorf("redirect would cause protocol downgrade") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Update the request |  | ||||||
| 		req.URL = respLoc |  | ||||||
|  |  | ||||||
| 		// Retry the request |  | ||||||
| 		resp, err = c.c.config.HttpClient.Do(req) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	result = &Response{Response: resp} |  | ||||||
| 	if err := result.Error(); err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Make sure that the last file in the archive, SHA256SUMS.sealed, is present | 	// Make sure that the last file in the archive, SHA256SUMS.sealed, is present | ||||||
| 	// and non-empty.  This is to catch cases where the snapshot failed midstream, | 	// and non-empty.  This is to catch cases where the snapshot failed midstream, | ||||||
| @@ -271,20 +207,26 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // RaftSnapshotRestore reads the snapshot from the io.Reader and installs that | // RaftSnapshotRestore is a thin wrapper around RaftSnapshotRestoreWithContext | ||||||
| // snapshot, returning the cluster to the state defined by it. |  | ||||||
| func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { | func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { | ||||||
|  | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	return c.RaftSnapshotRestoreWithContext(ctx, snapReader, force) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // RaftSnapshotRestoreWithContext reads the snapshot from the io.Reader and installs that | ||||||
|  | // snapshot, returning the cluster to the state defined by it. | ||||||
|  | func (c *Sys) RaftSnapshotRestoreWithContext(ctx context.Context, snapReader io.Reader, force bool) error { | ||||||
| 	path := "/v1/sys/storage/raft/snapshot" | 	path := "/v1/sys/storage/raft/snapshot" | ||||||
| 	if force { | 	if force { | ||||||
| 		path = "/v1/sys/storage/raft/snapshot-force" | 		path = "/v1/sys/storage/raft/snapshot-force" | ||||||
| 	} | 	} | ||||||
| 	r := c.c.NewRequest("POST", path) |  | ||||||
|  |  | ||||||
|  | 	r := c.c.NewRequest(http.MethodPost, path) | ||||||
| 	r.Body = snapReader | 	r.Body = snapReader | ||||||
|  |  | ||||||
| 	ctx, cancelFunc := context.WithCancel(context.Background()) | 	resp, err := c.c.httpRequestWithContext(ctx, r) | ||||||
| 	defer cancelFunc() |  | ||||||
| 	resp, err := c.c.RawRequestWithContext(ctx, r) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								changelog/14269.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/14269.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | ```release-note:bug | ||||||
|  |  api/sys/raft: Update RaftSnapshotRestore to use net/http client allowing bodies larger than allocated memory to be streamed | ||||||
|  | ``` | ||||||
| @@ -489,28 +489,10 @@ func TestRaft_SnapshotAPI(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	transport := cleanhttp.DefaultPooledTransport() |  | ||||||
| 	transport.TLSClientConfig = cluster.Cores[0].TLSConfig.Clone() |  | ||||||
| 	if err := http2.ConfigureTransport(transport); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	client := &http.Client{ |  | ||||||
| 		Transport: transport, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Take a snapshot | 	// Take a snapshot | ||||||
| 	req := leaderClient.NewRequest("GET", "/v1/sys/storage/raft/snapshot") | 	buf := new(bytes.Buffer) | ||||||
| 	httpReq, err := req.ToHTTP() | 	err := leaderClient.Sys().RaftSnapshot(buf) | ||||||
| 	if err != nil { | 	snap, err := io.ReadAll(buf) | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	resp, err := client.Do(httpReq) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer resp.Body.Close() |  | ||||||
|  |  | ||||||
| 	snap, err := ioutil.ReadAll(resp.Body) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| @@ -527,15 +509,8 @@ func TestRaft_SnapshotAPI(t *testing.T) { | |||||||
| 			t.Fatal(err) | 			t.Fatal(err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Restore snapshot | 	// Restore snapshot | ||||||
| 	req = leaderClient.NewRequest("POST", "/v1/sys/storage/raft/snapshot") | 	err = leaderClient.Sys().RaftSnapshotRestore(bytes.NewReader(snap), false) | ||||||
| 	req.Body = bytes.NewBuffer(snap) |  | ||||||
| 	httpReq, err = req.ToHTTP() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	resp, err = client.Do(httpReq) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Vinny Mannello
					Vinny Mannello