mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 02:28:09 +00:00 
			
		
		
		
	Add ability to optionally clone an api.Client's headers (#12117)
This commit is contained in:
		| @@ -125,6 +125,9 @@ type Config struct { | |||||||
|  |  | ||||||
| 	// SRVLookup enables the client to lookup the host through DNS SRV lookup | 	// SRVLookup enables the client to lookup the host through DNS SRV lookup | ||||||
| 	SRVLookup bool | 	SRVLookup bool | ||||||
|  |  | ||||||
|  | 	// CloneHeaders ensures that the source client's headers are copied to its clone. | ||||||
|  | 	CloneHeaders bool | ||||||
| } | } | ||||||
|  |  | ||||||
| // TLSConfig contains the parameters needed to configure TLS on the HTTP client | // TLSConfig contains the parameters needed to configure TLS on the HTTP client | ||||||
| @@ -504,6 +507,7 @@ func (c *Client) CloneConfig() *Config { | |||||||
| 	newConfig.Limiter = c.config.Limiter | 	newConfig.Limiter = c.config.Limiter | ||||||
| 	newConfig.OutputCurlString = c.config.OutputCurlString | 	newConfig.OutputCurlString = c.config.OutputCurlString | ||||||
| 	newConfig.SRVLookup = c.config.SRVLookup | 	newConfig.SRVLookup = c.config.SRVLookup | ||||||
|  | 	newConfig.CloneHeaders = c.config.CloneHeaders | ||||||
|  |  | ||||||
| 	// we specifically want a _copy_ of the client here, not a pointer to the original one | 	// we specifically want a _copy_ of the client here, not a pointer to the original one | ||||||
| 	newClient := *c.config.HttpClient | 	newClient := *c.config.HttpClient | ||||||
| @@ -809,6 +813,26 @@ func (c *Client) SetLogger(logger retryablehttp.LeveledLogger) { | |||||||
| 	c.config.Logger = logger | 	c.config.Logger = logger | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // SetCloneHeaders to allow headers to be copied whenever the client is cloned. | ||||||
|  | func (c *Client) SetCloneHeaders(cloneHeaders bool) { | ||||||
|  | 	c.modifyLock.Lock() | ||||||
|  | 	defer c.modifyLock.Unlock() | ||||||
|  | 	c.config.modifyLock.Lock() | ||||||
|  | 	defer c.config.modifyLock.Unlock() | ||||||
|  |  | ||||||
|  | 	c.config.CloneHeaders = cloneHeaders | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // CloneHeaders gets the configured CloneHeaders value. | ||||||
|  | func (c *Client) CloneHeaders() bool { | ||||||
|  | 	c.modifyLock.RLock() | ||||||
|  | 	defer c.modifyLock.RUnlock() | ||||||
|  | 	c.config.modifyLock.RLock() | ||||||
|  | 	defer c.config.modifyLock.RUnlock() | ||||||
|  |  | ||||||
|  | 	return c.config.CloneHeaders | ||||||
|  | } | ||||||
|  |  | ||||||
| // Clone creates a new client with the same configuration. Note that the same | // Clone creates a new client with the same configuration. Note that the same | ||||||
| // underlying http.Client is used; modifying the client from more than one | // underlying http.Client is used; modifying the client from more than one | ||||||
| // goroutine at once may not be safe, so modify the client as needed and then | // goroutine at once may not be safe, so modify the client as needed and then | ||||||
| @@ -839,12 +863,17 @@ func (c *Client) Clone() (*Client, error) { | |||||||
| 		OutputCurlString: config.OutputCurlString, | 		OutputCurlString: config.OutputCurlString, | ||||||
| 		AgentAddress:     config.AgentAddress, | 		AgentAddress:     config.AgentAddress, | ||||||
| 		SRVLookup:        config.SRVLookup, | 		SRVLookup:        config.SRVLookup, | ||||||
|  | 		CloneHeaders:     config.CloneHeaders, | ||||||
| 	} | 	} | ||||||
| 	client, err := NewClient(newConfig) | 	client, err := NewClient(newConfig) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if config.CloneHeaders { | ||||||
|  | 		client.SetHeaders(c.Headers().Clone()) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return client, nil | 	return client, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -409,63 +410,107 @@ func TestClientNonTransportRoundTripper(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestClone(t *testing.T) { | func TestClone(t *testing.T) { | ||||||
| 	client1, err := NewClient(DefaultConfig()) | 	type fields struct { | ||||||
| 	if err != nil { | 	} | ||||||
| 		t.Fatalf("NewClient failed: %v", err) | 	tests := []struct { | ||||||
|  | 		name    string | ||||||
|  | 		config  *Config | ||||||
|  | 		headers *http.Header | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:   "default", | ||||||
|  | 			config: DefaultConfig(), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "cloneHeaders", | ||||||
|  | 			config: &Config{ | ||||||
|  | 				CloneHeaders: true, | ||||||
|  | 			}, | ||||||
|  | 			headers: &http.Header{ | ||||||
|  | 				"X-foo": []string{"bar"}, | ||||||
|  | 				"X-baz": []string{"qux"}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Set all of the things that we provide setter methods for, which modify config values | 	for _, tt := range tests { | ||||||
| 	err = client1.SetAddress("http://example.com:8080") | 		t.Run(tt.name, func(t *testing.T) { | ||||||
| 	if err != nil { | 			client1, err := NewClient(tt.config) | ||||||
| 		t.Fatalf("SetAddress failed: %v", err) | 			if err != nil { | ||||||
| 	} | 				t.Fatalf("NewClient failed: %v", err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 	clientTimeout := time.Until(time.Now().AddDate(0, 0, 1)) | 			// Set all of the things that we provide setter methods for, which modify config values | ||||||
| 	client1.SetClientTimeout(clientTimeout) | 			err = client1.SetAddress("http://example.com:8080") | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatalf("SetAddress failed: %v", err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 	checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) { | 			clientTimeout := time.Until(time.Now().AddDate(0, 0, 1)) | ||||||
| 		return true, nil | 			client1.SetClientTimeout(clientTimeout) | ||||||
| 	} |  | ||||||
| 	client1.SetCheckRetry(checkRetry) |  | ||||||
|  |  | ||||||
| 	client1.SetLogger(hclog.NewNullLogger()) | 			checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) { | ||||||
|  | 				return true, nil | ||||||
|  | 			} | ||||||
|  | 			client1.SetCheckRetry(checkRetry) | ||||||
|  |  | ||||||
| 	client1.SetLimiter(5.0, 10) | 			client1.SetLogger(hclog.NewNullLogger()) | ||||||
| 	client1.SetMaxRetries(5) |  | ||||||
| 	client1.SetOutputCurlString(true) |  | ||||||
| 	client1.SetSRVLookup(true) |  | ||||||
|  |  | ||||||
| 	client2, err := client1.Clone() | 			client1.SetLimiter(5.0, 10) | ||||||
| 	if err != nil { | 			client1.SetMaxRetries(5) | ||||||
| 		t.Fatalf("Clone failed: %v", err) | 			client1.SetOutputCurlString(true) | ||||||
| 	} | 			client1.SetSRVLookup(true) | ||||||
|  |  | ||||||
| 	if client1.Address() != client2.Address() { | 			if tt.headers != nil { | ||||||
| 		t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address()) | 				client1.SetHeaders(*tt.headers) | ||||||
| 	} | 			} | ||||||
| 	if client1.ClientTimeout() != client2.ClientTimeout() { |  | ||||||
| 		t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout()) | 			client2, err := client1.Clone() | ||||||
| 	} | 			if err != nil { | ||||||
| 	if client1.CheckRetry() != nil && client2.CheckRetry() == nil { | 				t.Fatalf("Clone failed: %v", err) | ||||||
| 		t.Fatal("checkRetry functions don't match. client2 is nil.") | 			} | ||||||
| 	} |  | ||||||
| 	if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) { | 			if client1.Address() != client2.Address() { | ||||||
| 		t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter()) | 				t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address()) | ||||||
| 	} | 			} | ||||||
| 	if client1.Limiter().Limit() != client2.Limiter().Limit() { | 			if client1.ClientTimeout() != client2.ClientTimeout() { | ||||||
| 		t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit()) | 				t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout()) | ||||||
| 	} | 			} | ||||||
| 	if client1.Limiter().Burst() != client2.Limiter().Burst() { | 			if client1.CheckRetry() != nil && client2.CheckRetry() == nil { | ||||||
| 		t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst()) | 				t.Fatal("checkRetry functions don't match. client2 is nil.") | ||||||
| 	} | 			} | ||||||
| 	if client1.MaxRetries() != client2.MaxRetries() { | 			if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) { | ||||||
| 		t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries()) | 				t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter()) | ||||||
| 	} | 			} | ||||||
| 	if client1.OutputCurlString() != client2.OutputCurlString() { | 			if client1.Limiter().Limit() != client2.Limiter().Limit() { | ||||||
| 		t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString()) | 				t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit()) | ||||||
| 	} | 			} | ||||||
| 	if client1.SRVLookup() != client2.SRVLookup() { | 			if client1.Limiter().Burst() != client2.Limiter().Burst() { | ||||||
| 		t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup()) | 				t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst()) | ||||||
|  | 			} | ||||||
|  | 			if client1.MaxRetries() != client2.MaxRetries() { | ||||||
|  | 				t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries()) | ||||||
|  | 			} | ||||||
|  | 			if client1.OutputCurlString() != client2.OutputCurlString() { | ||||||
|  | 				t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString()) | ||||||
|  | 			} | ||||||
|  | 			if client1.SRVLookup() != client2.SRVLookup() { | ||||||
|  | 				t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup()) | ||||||
|  | 			} | ||||||
|  | 			if tt.config.CloneHeaders { | ||||||
|  | 				if !reflect.DeepEqual(client1.Headers(), client2.Headers()) { | ||||||
|  | 					t.Fatalf("Headers() don't match: %v vs %v", client1.Headers(), client2.Headers()) | ||||||
|  | 				} | ||||||
|  | 				if client1.config.CloneHeaders != client2.config.CloneHeaders { | ||||||
|  | 					t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", client1.config.CloneHeaders, client2.config.CloneHeaders) | ||||||
|  | 				} | ||||||
|  | 				if tt.headers != nil { | ||||||
|  | 					if !reflect.DeepEqual(*tt.headers, client2.Headers()) { | ||||||
|  | 						t.Fatalf("expected headers %v, actual %v", *tt.headers, client2.Headers()) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								changelog/12117.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/12117.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | ```release-note:improvement | ||||||
|  | api: Allow cloning `api.Client` HTTP headers via `api.Config.CloneHeaders` or `api.Client.SetCloneHeaders`. | ||||||
|  | ``` | ||||||
		Reference in New Issue
	
	Block a user
	 Ben Ash
					Ben Ash