Add ability to optionally clone an api.Client's headers (#12117)

This commit is contained in:
Ben Ash
2021-07-19 17:15:31 -04:00
committed by GitHub
parent 522cfdf20a
commit 6b31c12b0a
3 changed files with 126 additions and 49 deletions

View File

@@ -125,6 +125,9 @@ type Config struct {
// SRVLookup enables the client to lookup the host through DNS SRV lookup // SRVLookup enables the client to lookup the host through DNS SRV lookup
SRVLookup bool SRVLookup bool
// CloneHeaders ensures that the source client's headers are copied to its clone.
CloneHeaders bool
} }
// TLSConfig contains the parameters needed to configure TLS on the HTTP client // TLSConfig contains the parameters needed to configure TLS on the HTTP client
@@ -504,6 +507,7 @@ func (c *Client) CloneConfig() *Config {
newConfig.Limiter = c.config.Limiter newConfig.Limiter = c.config.Limiter
newConfig.OutputCurlString = c.config.OutputCurlString newConfig.OutputCurlString = c.config.OutputCurlString
newConfig.SRVLookup = c.config.SRVLookup newConfig.SRVLookup = c.config.SRVLookup
newConfig.CloneHeaders = c.config.CloneHeaders
// we specifically want a _copy_ of the client here, not a pointer to the original one // we specifically want a _copy_ of the client here, not a pointer to the original one
newClient := *c.config.HttpClient newClient := *c.config.HttpClient
@@ -809,6 +813,26 @@ func (c *Client) SetLogger(logger retryablehttp.LeveledLogger) {
c.config.Logger = logger c.config.Logger = logger
} }
// SetCloneHeaders to allow headers to be copied whenever the client is cloned.
func (c *Client) SetCloneHeaders(cloneHeaders bool) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.config.modifyLock.Lock()
defer c.config.modifyLock.Unlock()
c.config.CloneHeaders = cloneHeaders
}
// CloneHeaders gets the configured CloneHeaders value.
func (c *Client) CloneHeaders() bool {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()
return c.config.CloneHeaders
}
// Clone creates a new client with the same configuration. Note that the same // Clone creates a new client with the same configuration. Note that the same
// underlying http.Client is used; modifying the client from more than one // underlying http.Client is used; modifying the client from more than one
// goroutine at once may not be safe, so modify the client as needed and then // goroutine at once may not be safe, so modify the client as needed and then
@@ -839,12 +863,17 @@ func (c *Client) Clone() (*Client, error) {
OutputCurlString: config.OutputCurlString, OutputCurlString: config.OutputCurlString,
AgentAddress: config.AgentAddress, AgentAddress: config.AgentAddress,
SRVLookup: config.SRVLookup, SRVLookup: config.SRVLookup,
CloneHeaders: config.CloneHeaders,
} }
client, err := NewClient(newConfig) client, err := NewClient(newConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if config.CloneHeaders {
client.SetHeaders(c.Headers().Clone())
}
return client, nil return client, nil
} }

View File

@@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -409,63 +410,107 @@ func TestClientNonTransportRoundTripper(t *testing.T) {
} }
func TestClone(t *testing.T) { func TestClone(t *testing.T) {
client1, err := NewClient(DefaultConfig()) type fields struct {
if err != nil { }
t.Fatalf("NewClient failed: %v", err) tests := []struct {
name string
config *Config
headers *http.Header
}{
{
name: "default",
config: DefaultConfig(),
},
{
name: "cloneHeaders",
config: &Config{
CloneHeaders: true,
},
headers: &http.Header{
"X-foo": []string{"bar"},
"X-baz": []string{"qux"},
},
},
} }
// Set all of the things that we provide setter methods for, which modify config values for _, tt := range tests {
err = client1.SetAddress("http://example.com:8080") t.Run(tt.name, func(t *testing.T) {
if err != nil { client1, err := NewClient(tt.config)
t.Fatalf("SetAddress failed: %v", err) if err != nil {
} t.Fatalf("NewClient failed: %v", err)
}
clientTimeout := time.Until(time.Now().AddDate(0, 0, 1)) // Set all of the things that we provide setter methods for, which modify config values
client1.SetClientTimeout(clientTimeout) err = client1.SetAddress("http://example.com:8080")
if err != nil {
t.Fatalf("SetAddress failed: %v", err)
}
checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) { clientTimeout := time.Until(time.Now().AddDate(0, 0, 1))
return true, nil client1.SetClientTimeout(clientTimeout)
}
client1.SetCheckRetry(checkRetry)
client1.SetLogger(hclog.NewNullLogger()) checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) {
return true, nil
}
client1.SetCheckRetry(checkRetry)
client1.SetLimiter(5.0, 10) client1.SetLogger(hclog.NewNullLogger())
client1.SetMaxRetries(5)
client1.SetOutputCurlString(true)
client1.SetSRVLookup(true)
client2, err := client1.Clone() client1.SetLimiter(5.0, 10)
if err != nil { client1.SetMaxRetries(5)
t.Fatalf("Clone failed: %v", err) client1.SetOutputCurlString(true)
} client1.SetSRVLookup(true)
if client1.Address() != client2.Address() { if tt.headers != nil {
t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address()) client1.SetHeaders(*tt.headers)
} }
if client1.ClientTimeout() != client2.ClientTimeout() {
t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout()) client2, err := client1.Clone()
} if err != nil {
if client1.CheckRetry() != nil && client2.CheckRetry() == nil { t.Fatalf("Clone failed: %v", err)
t.Fatal("checkRetry functions don't match. client2 is nil.") }
}
if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) { if client1.Address() != client2.Address() {
t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter()) t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address())
} }
if client1.Limiter().Limit() != client2.Limiter().Limit() { if client1.ClientTimeout() != client2.ClientTimeout() {
t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit()) t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout())
} }
if client1.Limiter().Burst() != client2.Limiter().Burst() { if client1.CheckRetry() != nil && client2.CheckRetry() == nil {
t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst()) t.Fatal("checkRetry functions don't match. client2 is nil.")
} }
if client1.MaxRetries() != client2.MaxRetries() { if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) {
t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries()) t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter())
} }
if client1.OutputCurlString() != client2.OutputCurlString() { if client1.Limiter().Limit() != client2.Limiter().Limit() {
t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString()) t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit())
} }
if client1.SRVLookup() != client2.SRVLookup() { if client1.Limiter().Burst() != client2.Limiter().Burst() {
t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup()) t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst())
}
if client1.MaxRetries() != client2.MaxRetries() {
t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries())
}
if client1.OutputCurlString() != client2.OutputCurlString() {
t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString())
}
if client1.SRVLookup() != client2.SRVLookup() {
t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup())
}
if tt.config.CloneHeaders {
if !reflect.DeepEqual(client1.Headers(), client2.Headers()) {
t.Fatalf("Headers() don't match: %v vs %v", client1.Headers(), client2.Headers())
}
if client1.config.CloneHeaders != client2.config.CloneHeaders {
t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", client1.config.CloneHeaders, client2.config.CloneHeaders)
}
if tt.headers != nil {
if !reflect.DeepEqual(*tt.headers, client2.Headers()) {
t.Fatalf("expected headers %v, actual %v", *tt.headers, client2.Headers())
}
}
}
})
} }
} }

3
changelog/12117.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:improvement
api: Allow cloning `api.Client` HTTP headers via `api.Config.CloneHeaders` or `api.Client.SetCloneHeaders`.
```