From 3ca33976762c3ea250367baf25325c0432bbaaa5 Mon Sep 17 00:00:00 2001 From: Ben Ash <32777270+benashz@users.noreply.github.com> Date: Wed, 28 Jun 2023 14:56:39 -0700 Subject: [PATCH] Add support for cloning a Client's tls.Config (#21424) Additional fixes: - handle a failed type assert in api.Config.configureTLS() Co-authored-by: Anton Averchenkov <84287187+averche@users.noreply.github.com> --- api/client.go | 37 ++++++++++++++++++++++++++++++++++++- api/client_test.go | 35 +++++++++++++++++++++++++++++++++++ changelog/21424.txt | 3 +++ 3 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 changelog/21424.txt diff --git a/api/client.go b/api/client.go index d20477e1d9..4be03826bb 100644 --- a/api/client.go +++ b/api/client.go @@ -185,6 +185,9 @@ type Config struct { // CloneToken from parent. CloneToken bool + // CloneTLSConfig from parent (tls.Config). + CloneTLSConfig bool + // ReadYourWrites ensures isolated read-after-write semantics by // providing discovered cluster replication states in each request. // The shared state is automatically propagated to all Client clones. @@ -290,7 +293,14 @@ func (c *Config) configureTLS(t *TLSConfig) error { if c.HttpClient == nil { c.HttpClient = DefaultConfig().HttpClient } - clientTLSConfig := c.HttpClient.Transport.(*http.Transport).TLSClientConfig + + transport, ok := c.HttpClient.Transport.(*http.Transport) + if !ok { + return fmt.Errorf( + "unsupported HTTPClient transport type %T", c.HttpClient.Transport) + } + + clientTLSConfig := transport.TLSClientConfig var clientCert tls.Certificate foundClientCert := false @@ -1143,6 +1153,26 @@ func (c *Client) ReadYourWrites() bool { return c.config.ReadYourWrites } +// SetCloneTLSConfig from parent. +func (c *Client) SetCloneTLSConfig(clone bool) { + c.modifyLock.Lock() + defer c.modifyLock.Unlock() + c.config.modifyLock.Lock() + defer c.config.modifyLock.Unlock() + + c.config.CloneTLSConfig = clone +} + +// CloneTLSConfig gets the configured CloneTLSConfig value. +func (c *Client) CloneTLSConfig() bool { + c.modifyLock.RLock() + defer c.modifyLock.RUnlock() + c.config.modifyLock.RLock() + defer c.config.modifyLock.RUnlock() + + return c.config.CloneTLSConfig +} + // Clone creates a new client with the same configuration. Note that the same // underlying http.Client is used; modifying the client from more than one // goroutine at once may not be safe, so modify the client as needed and then @@ -1189,6 +1219,11 @@ func (c *Client) clone(cloneHeaders bool) (*Client, error) { CloneToken: config.CloneToken, ReadYourWrites: config.ReadYourWrites, } + + if config.CloneTLSConfig { + newConfig.clientTLSConfig = config.clientTLSConfig + } + client, err := NewClient(newConfig) if err != nil { return nil, err diff --git a/api/client_test.go b/api/client_test.go index a23c0c19e7..0de2c17c2b 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -6,6 +6,7 @@ package api import ( "bytes" "context" + "crypto/tls" "crypto/x509" "encoding/base64" "fmt" @@ -591,6 +592,24 @@ func TestClone(t *testing.T) { }, token: "cloneToken", }, + { + name: "cloneTLSConfig-enabled", + config: &Config{ + CloneTLSConfig: true, + clientTLSConfig: &tls.Config{ + ServerName: "foo.bar.baz", + }, + }, + }, + { + name: "cloneTLSConfig-disabled", + config: &Config{ + CloneTLSConfig: false, + clientTLSConfig: &tls.Config{ + ServerName: "foo.bar.baz", + }, + }, + }, } for _, tt := range tests { @@ -699,6 +718,22 @@ func TestClone(t *testing.T) { t.Fatalf("expected replicationStateStore %v, actual %v", parent.replicationStateStore, clone.replicationStateStore) } + if tt.config.CloneTLSConfig { + if !reflect.DeepEqual(parent.config.TLSConfig(), clone.config.TLSConfig()) { + t.Fatalf("config.clientTLSConfig doesn't match: %v vs %v", + parent.config.TLSConfig(), clone.config.TLSConfig()) + } + } else if tt.config.clientTLSConfig != nil { + if reflect.DeepEqual(parent.config.TLSConfig(), clone.config.TLSConfig()) { + t.Fatalf("config.clientTLSConfig should not match: %v vs %v", + parent.config.TLSConfig(), clone.config.TLSConfig()) + } + } else { + if !reflect.DeepEqual(parent.config.TLSConfig(), clone.config.TLSConfig()) { + t.Fatalf("config.clientTLSConfig doesn't match: %v vs %v", + parent.config.TLSConfig(), clone.config.TLSConfig()) + } + } }) } } diff --git a/changelog/21424.txt b/changelog/21424.txt new file mode 100644 index 0000000000..229e97e4d3 --- /dev/null +++ b/changelog/21424.txt @@ -0,0 +1,3 @@ +```release-note:improvement +api: add support for cloning a Client's tls.Config. +```