mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-11-01 11:08:10 +00:00
Add ability to optionally clone an api.Client's headers (#12117)
This commit is contained in:
@@ -125,6 +125,9 @@ type Config struct {
|
|||||||
|
|
||||||
// SRVLookup enables the client to lookup the host through DNS SRV lookup
|
// SRVLookup enables the client to lookup the host through DNS SRV lookup
|
||||||
SRVLookup bool
|
SRVLookup bool
|
||||||
|
|
||||||
|
// CloneHeaders ensures that the source client's headers are copied to its clone.
|
||||||
|
CloneHeaders bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSConfig contains the parameters needed to configure TLS on the HTTP client
|
// TLSConfig contains the parameters needed to configure TLS on the HTTP client
|
||||||
@@ -504,6 +507,7 @@ func (c *Client) CloneConfig() *Config {
|
|||||||
newConfig.Limiter = c.config.Limiter
|
newConfig.Limiter = c.config.Limiter
|
||||||
newConfig.OutputCurlString = c.config.OutputCurlString
|
newConfig.OutputCurlString = c.config.OutputCurlString
|
||||||
newConfig.SRVLookup = c.config.SRVLookup
|
newConfig.SRVLookup = c.config.SRVLookup
|
||||||
|
newConfig.CloneHeaders = c.config.CloneHeaders
|
||||||
|
|
||||||
// we specifically want a _copy_ of the client here, not a pointer to the original one
|
// we specifically want a _copy_ of the client here, not a pointer to the original one
|
||||||
newClient := *c.config.HttpClient
|
newClient := *c.config.HttpClient
|
||||||
@@ -809,6 +813,26 @@ func (c *Client) SetLogger(logger retryablehttp.LeveledLogger) {
|
|||||||
c.config.Logger = logger
|
c.config.Logger = logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCloneHeaders to allow headers to be copied whenever the client is cloned.
|
||||||
|
func (c *Client) SetCloneHeaders(cloneHeaders bool) {
|
||||||
|
c.modifyLock.Lock()
|
||||||
|
defer c.modifyLock.Unlock()
|
||||||
|
c.config.modifyLock.Lock()
|
||||||
|
defer c.config.modifyLock.Unlock()
|
||||||
|
|
||||||
|
c.config.CloneHeaders = cloneHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneHeaders gets the configured CloneHeaders value.
|
||||||
|
func (c *Client) CloneHeaders() bool {
|
||||||
|
c.modifyLock.RLock()
|
||||||
|
defer c.modifyLock.RUnlock()
|
||||||
|
c.config.modifyLock.RLock()
|
||||||
|
defer c.config.modifyLock.RUnlock()
|
||||||
|
|
||||||
|
return c.config.CloneHeaders
|
||||||
|
}
|
||||||
|
|
||||||
// Clone creates a new client with the same configuration. Note that the same
|
// Clone creates a new client with the same configuration. Note that the same
|
||||||
// underlying http.Client is used; modifying the client from more than one
|
// underlying http.Client is used; modifying the client from more than one
|
||||||
// goroutine at once may not be safe, so modify the client as needed and then
|
// goroutine at once may not be safe, so modify the client as needed and then
|
||||||
@@ -839,12 +863,17 @@ func (c *Client) Clone() (*Client, error) {
|
|||||||
OutputCurlString: config.OutputCurlString,
|
OutputCurlString: config.OutputCurlString,
|
||||||
AgentAddress: config.AgentAddress,
|
AgentAddress: config.AgentAddress,
|
||||||
SRVLookup: config.SRVLookup,
|
SRVLookup: config.SRVLookup,
|
||||||
|
CloneHeaders: config.CloneHeaders,
|
||||||
}
|
}
|
||||||
client, err := NewClient(newConfig)
|
client, err := NewClient(newConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.CloneHeaders {
|
||||||
|
client.SetHeaders(c.Headers().Clone())
|
||||||
|
}
|
||||||
|
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -409,7 +410,32 @@ func TestClientNonTransportRoundTripper(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClone(t *testing.T) {
|
func TestClone(t *testing.T) {
|
||||||
client1, err := NewClient(DefaultConfig())
|
type fields struct {
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *Config
|
||||||
|
headers *http.Header
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
config: DefaultConfig(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cloneHeaders",
|
||||||
|
config: &Config{
|
||||||
|
CloneHeaders: true,
|
||||||
|
},
|
||||||
|
headers: &http.Header{
|
||||||
|
"X-foo": []string{"bar"},
|
||||||
|
"X-baz": []string{"qux"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
client1, err := NewClient(tt.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewClient failed: %v", err)
|
t.Fatalf("NewClient failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -435,6 +461,10 @@ func TestClone(t *testing.T) {
|
|||||||
client1.SetOutputCurlString(true)
|
client1.SetOutputCurlString(true)
|
||||||
client1.SetSRVLookup(true)
|
client1.SetSRVLookup(true)
|
||||||
|
|
||||||
|
if tt.headers != nil {
|
||||||
|
client1.SetHeaders(*tt.headers)
|
||||||
|
}
|
||||||
|
|
||||||
client2, err := client1.Clone()
|
client2, err := client1.Clone()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Clone failed: %v", err)
|
t.Fatalf("Clone failed: %v", err)
|
||||||
@@ -467,6 +497,21 @@ func TestClone(t *testing.T) {
|
|||||||
if client1.SRVLookup() != client2.SRVLookup() {
|
if client1.SRVLookup() != client2.SRVLookup() {
|
||||||
t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup())
|
t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup())
|
||||||
}
|
}
|
||||||
|
if tt.config.CloneHeaders {
|
||||||
|
if !reflect.DeepEqual(client1.Headers(), client2.Headers()) {
|
||||||
|
t.Fatalf("Headers() don't match: %v vs %v", client1.Headers(), client2.Headers())
|
||||||
|
}
|
||||||
|
if client1.config.CloneHeaders != client2.config.CloneHeaders {
|
||||||
|
t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", client1.config.CloneHeaders, client2.config.CloneHeaders)
|
||||||
|
}
|
||||||
|
if tt.headers != nil {
|
||||||
|
if !reflect.DeepEqual(*tt.headers, client2.Headers()) {
|
||||||
|
t.Fatalf("expected headers %v, actual %v", *tt.headers, client2.Headers())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetHeadersRaceSafe(t *testing.T) {
|
func TestSetHeadersRaceSafe(t *testing.T) {
|
||||||
|
|||||||
3
changelog/12117.txt
Normal file
3
changelog/12117.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
```release-note:improvement
|
||||||
|
api: Allow cloning `api.Client` HTTP headers via `api.Config.CloneHeaders` or `api.Client.SetCloneHeaders`.
|
||||||
|
```
|
||||||
Reference in New Issue
Block a user