ParseAddress test fix (#17382)

* check prefix of previous address

* add tests for dial context switching

---------

Co-authored-by: Violet Hynes <violet.hynes@hashicorp.com>
This commit is contained in:
Stephen Wodecki
2024-01-10 10:35:23 -05:00
committed by GitHub
parent 75846bc58f
commit d3c790a495
2 changed files with 7 additions and 1 deletions

View File

@@ -530,6 +530,7 @@ func (c *Config) ParseAddress(address string) (*url.URL, error) {
return nil, err
}
previousAddress := c.Address
c.Address = address
if strings.HasPrefix(address, "unix://") {
@@ -552,7 +553,7 @@ func (c *Config) ParseAddress(address string) (*url.URL, error) {
} else {
return nil, fmt.Errorf("attempting to specify unix:// address with non-transport transport")
}
} else if strings.HasPrefix(c.Address, "unix://") {
} else if strings.HasPrefix(previousAddress, "unix://") {
// When the address being set does not begin with unix:// but the previous
// address in the Config did, change the transport's DialContext back to
// use the default configuration that cleanhttp uses.

View File

@@ -104,6 +104,7 @@ func TestClientSetAddress(t *testing.T) {
t.Fatalf("bad: expected: '172.168.2.1:8300' actual: %q", client.addr.Host)
}
// Test switching to Unix Socket address from TCP address
client.config.HttpClient.Transport.(*http.Transport).DialContext = nil
if err := client.SetAddress("unix:///var/run/vault.sock"); err != nil {
t.Fatal(err)
}
@@ -120,6 +121,7 @@ func TestClientSetAddress(t *testing.T) {
t.Fatal("bad: expected DialContext to not be nil")
}
// Test switching to TCP address from Unix Socket address
client.config.HttpClient.Transport.(*http.Transport).DialContext = nil
if err := client.SetAddress("http://172.168.2.1:8300"); err != nil {
t.Fatal(err)
}
@@ -129,6 +131,9 @@ func TestClientSetAddress(t *testing.T) {
if client.addr.Scheme != "http" {
t.Fatalf("bad: expected: 'http' actual: %q", client.addr.Scheme)
}
if client.config.HttpClient.Transport.(*http.Transport).DialContext == nil {
t.Fatal("bad: expected DialContext to not be nil")
}
}
func TestClientToken(t *testing.T) {