mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 20:17:59 +00:00 
			
		
		
		
	Add ability to optionally clone a Client's token (#13515)
This commit is contained in:
		@@ -139,6 +139,9 @@ type Config struct {
 | 
				
			|||||||
	// its clone.
 | 
						// its clone.
 | 
				
			||||||
	CloneHeaders bool
 | 
						CloneHeaders bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// CloneToken from parent.
 | 
				
			||||||
 | 
						CloneToken bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ReadYourWrites ensures isolated read-after-write semantics by
 | 
						// ReadYourWrites ensures isolated read-after-write semantics by
 | 
				
			||||||
	// providing discovered cluster replication states in each request.
 | 
						// providing discovered cluster replication states in each request.
 | 
				
			||||||
	// The shared state is automatically propagated to all Client clones.
 | 
						// The shared state is automatically propagated to all Client clones.
 | 
				
			||||||
@@ -547,6 +550,7 @@ func (c *Client) CloneConfig() *Config {
 | 
				
			|||||||
	newConfig.OutputCurlString = c.config.OutputCurlString
 | 
						newConfig.OutputCurlString = c.config.OutputCurlString
 | 
				
			||||||
	newConfig.SRVLookup = c.config.SRVLookup
 | 
						newConfig.SRVLookup = c.config.SRVLookup
 | 
				
			||||||
	newConfig.CloneHeaders = c.config.CloneHeaders
 | 
						newConfig.CloneHeaders = c.config.CloneHeaders
 | 
				
			||||||
 | 
						newConfig.CloneToken = c.config.CloneToken
 | 
				
			||||||
	newConfig.ReadYourWrites = c.config.ReadYourWrites
 | 
						newConfig.ReadYourWrites = c.config.ReadYourWrites
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// we specifically want a _copy_ of the client here, not a pointer to the original one
 | 
						// we specifically want a _copy_ of the client here, not a pointer to the original one
 | 
				
			||||||
@@ -873,6 +877,26 @@ func (c *Client) CloneHeaders() bool {
 | 
				
			|||||||
	return c.config.CloneHeaders
 | 
						return c.config.CloneHeaders
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetCloneToken from parent
 | 
				
			||||||
 | 
					func (c *Client) SetCloneToken(cloneToken bool) {
 | 
				
			||||||
 | 
						c.modifyLock.Lock()
 | 
				
			||||||
 | 
						defer c.modifyLock.Unlock()
 | 
				
			||||||
 | 
						c.config.modifyLock.Lock()
 | 
				
			||||||
 | 
						defer c.config.modifyLock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c.config.CloneToken = cloneToken
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CloneToken gets the configured CloneToken value.
 | 
				
			||||||
 | 
					func (c *Client) CloneToken() bool {
 | 
				
			||||||
 | 
						c.modifyLock.RLock()
 | 
				
			||||||
 | 
						defer c.modifyLock.RUnlock()
 | 
				
			||||||
 | 
						c.config.modifyLock.RLock()
 | 
				
			||||||
 | 
						defer c.config.modifyLock.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return c.config.CloneToken
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SetReadYourWrites to prevent reading stale cluster replication state.
 | 
					// SetReadYourWrites to prevent reading stale cluster replication state.
 | 
				
			||||||
func (c *Client) SetReadYourWrites(preventStaleReads bool) {
 | 
					func (c *Client) SetReadYourWrites(preventStaleReads bool) {
 | 
				
			||||||
	c.modifyLock.Lock()
 | 
						c.modifyLock.Lock()
 | 
				
			||||||
@@ -932,6 +956,7 @@ func (c *Client) Clone() (*Client, error) {
 | 
				
			|||||||
		AgentAddress:     config.AgentAddress,
 | 
							AgentAddress:     config.AgentAddress,
 | 
				
			||||||
		SRVLookup:        config.SRVLookup,
 | 
							SRVLookup:        config.SRVLookup,
 | 
				
			||||||
		CloneHeaders:     config.CloneHeaders,
 | 
							CloneHeaders:     config.CloneHeaders,
 | 
				
			||||||
 | 
							CloneToken:       config.CloneToken,
 | 
				
			||||||
		ReadYourWrites:   config.ReadYourWrites,
 | 
							ReadYourWrites:   config.ReadYourWrites,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	client, err := NewClient(newConfig)
 | 
						client, err := NewClient(newConfig)
 | 
				
			||||||
@@ -943,6 +968,10 @@ func (c *Client) Clone() (*Client, error) {
 | 
				
			|||||||
		client.SetHeaders(c.Headers().Clone())
 | 
							client.SetHeaders(c.Headers().Clone())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if config.CloneToken {
 | 
				
			||||||
 | 
							client.SetToken(c.token)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	client.replicationStateStore = c.replicationStateStore
 | 
						client.replicationStateStore = c.replicationStateStore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return client, nil
 | 
						return client, nil
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -420,6 +420,7 @@ func TestClone(t *testing.T) {
 | 
				
			|||||||
		name    string
 | 
							name    string
 | 
				
			||||||
		config  *Config
 | 
							config  *Config
 | 
				
			||||||
		headers *http.Header
 | 
							headers *http.Header
 | 
				
			||||||
 | 
							token   string
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:   "default",
 | 
								name:   "default",
 | 
				
			||||||
@@ -441,91 +442,119 @@ func TestClone(t *testing.T) {
 | 
				
			|||||||
				ReadYourWrites: true,
 | 
									ReadYourWrites: true,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "cloneToken",
 | 
				
			||||||
 | 
								config: &Config{
 | 
				
			||||||
 | 
									CloneToken: true,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								token: "cloneToken",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			client1, err := NewClient(tt.config)
 | 
								parent, err := NewClient(tt.config)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				t.Fatalf("NewClient failed: %v", err)
 | 
									t.Fatalf("NewClient failed: %v", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Set all of the things that we provide setter methods for, which modify config values
 | 
								// Set all of the things that we provide setter methods for, which modify config values
 | 
				
			||||||
			err = client1.SetAddress("http://example.com:8080")
 | 
								err = parent.SetAddress("http://example.com:8080")
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				t.Fatalf("SetAddress failed: %v", err)
 | 
									t.Fatalf("SetAddress failed: %v", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			clientTimeout := time.Until(time.Now().AddDate(0, 0, 1))
 | 
								clientTimeout := time.Until(time.Now().AddDate(0, 0, 1))
 | 
				
			||||||
			client1.SetClientTimeout(clientTimeout)
 | 
								parent.SetClientTimeout(clientTimeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) {
 | 
								checkRetry := func(ctx context.Context, resp *http.Response, err error) (bool, error) {
 | 
				
			||||||
				return true, nil
 | 
									return true, nil
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			client1.SetCheckRetry(checkRetry)
 | 
								parent.SetCheckRetry(checkRetry)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			client1.SetLogger(hclog.NewNullLogger())
 | 
								parent.SetLogger(hclog.NewNullLogger())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			client1.SetLimiter(5.0, 10)
 | 
								parent.SetLimiter(5.0, 10)
 | 
				
			||||||
			client1.SetMaxRetries(5)
 | 
								parent.SetMaxRetries(5)
 | 
				
			||||||
			client1.SetOutputCurlString(true)
 | 
								parent.SetOutputCurlString(true)
 | 
				
			||||||
			client1.SetSRVLookup(true)
 | 
								parent.SetSRVLookup(true)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if tt.headers != nil {
 | 
								if tt.headers != nil {
 | 
				
			||||||
				client1.SetHeaders(*tt.headers)
 | 
									parent.SetHeaders(*tt.headers)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			client2, err := client1.Clone()
 | 
								if tt.token != "" {
 | 
				
			||||||
 | 
									parent.SetToken(tt.token)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								clone, err := parent.Clone()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				t.Fatalf("Clone failed: %v", err)
 | 
									t.Fatalf("Clone failed: %v", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if client1.Address() != client2.Address() {
 | 
								if parent.Address() != clone.Address() {
 | 
				
			||||||
				t.Fatalf("addresses don't match: %v vs %v", client1.Address(), client2.Address())
 | 
									t.Fatalf("addresses don't match: %v vs %v", parent.Address(), clone.Address())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if client1.ClientTimeout() != client2.ClientTimeout() {
 | 
								if parent.ClientTimeout() != clone.ClientTimeout() {
 | 
				
			||||||
				t.Fatalf("timeouts don't match: %v vs %v", client1.ClientTimeout(), client2.ClientTimeout())
 | 
									t.Fatalf("timeouts don't match: %v vs %v", parent.ClientTimeout(), clone.ClientTimeout())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if client1.CheckRetry() != nil && client2.CheckRetry() == nil {
 | 
								if parent.CheckRetry() != nil && clone.CheckRetry() == nil {
 | 
				
			||||||
				t.Fatal("checkRetry functions don't match. client2 is nil.")
 | 
									t.Fatal("checkRetry functions don't match. clone is nil.")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if (client1.Limiter() != nil && client2.Limiter() == nil) || (client1.Limiter() == nil && client2.Limiter() != nil) {
 | 
								if (parent.Limiter() != nil && clone.Limiter() == nil) || (parent.Limiter() == nil && clone.Limiter() != nil) {
 | 
				
			||||||
				t.Fatalf("limiters don't match: %v vs %v", client1.Limiter(), client2.Limiter())
 | 
									t.Fatalf("limiters don't match: %v vs %v", parent.Limiter(), clone.Limiter())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if client1.Limiter().Limit() != client2.Limiter().Limit() {
 | 
								if parent.Limiter().Limit() != clone.Limiter().Limit() {
 | 
				
			||||||
				t.Fatalf("limiter limits don't match: %v vs %v", client1.Limiter().Limit(), client2.Limiter().Limit())
 | 
									t.Fatalf("limiter limits don't match: %v vs %v", parent.Limiter().Limit(), clone.Limiter().Limit())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if client1.Limiter().Burst() != client2.Limiter().Burst() {
 | 
								if parent.Limiter().Burst() != clone.Limiter().Burst() {
 | 
				
			||||||
				t.Fatalf("limiter bursts don't match: %v vs %v", client1.Limiter().Burst(), client2.Limiter().Burst())
 | 
									t.Fatalf("limiter bursts don't match: %v vs %v", parent.Limiter().Burst(), clone.Limiter().Burst())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if client1.MaxRetries() != client2.MaxRetries() {
 | 
								if parent.MaxRetries() != clone.MaxRetries() {
 | 
				
			||||||
				t.Fatalf("maxRetries don't match: %v vs %v", client1.MaxRetries(), client2.MaxRetries())
 | 
									t.Fatalf("maxRetries don't match: %v vs %v", parent.MaxRetries(), clone.MaxRetries())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if client1.OutputCurlString() != client2.OutputCurlString() {
 | 
								if parent.OutputCurlString() != clone.OutputCurlString() {
 | 
				
			||||||
				t.Fatalf("outputCurlString doesn't match: %v vs %v", client1.OutputCurlString(), client2.OutputCurlString())
 | 
									t.Fatalf("outputCurlString doesn't match: %v vs %v", parent.OutputCurlString(), clone.OutputCurlString())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if client1.SRVLookup() != client2.SRVLookup() {
 | 
								if parent.SRVLookup() != clone.SRVLookup() {
 | 
				
			||||||
				t.Fatalf("SRVLookup doesn't match: %v vs %v", client1.SRVLookup(), client2.SRVLookup())
 | 
									t.Fatalf("SRVLookup doesn't match: %v vs %v", parent.SRVLookup(), clone.SRVLookup())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if tt.config.CloneHeaders {
 | 
								if tt.config.CloneHeaders {
 | 
				
			||||||
				if !reflect.DeepEqual(client1.Headers(), client2.Headers()) {
 | 
									if !reflect.DeepEqual(parent.Headers(), clone.Headers()) {
 | 
				
			||||||
					t.Fatalf("Headers() don't match: %v vs %v", client1.Headers(), client2.Headers())
 | 
										t.Fatalf("Headers() don't match: %v vs %v", parent.Headers(), clone.Headers())
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if client1.config.CloneHeaders != client2.config.CloneHeaders {
 | 
									if parent.config.CloneHeaders != clone.config.CloneHeaders {
 | 
				
			||||||
					t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", client1.config.CloneHeaders, client2.config.CloneHeaders)
 | 
										t.Fatalf("config.CloneHeaders doesn't match: %v vs %v", parent.config.CloneHeaders, clone.config.CloneHeaders)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				if tt.headers != nil {
 | 
									if tt.headers != nil {
 | 
				
			||||||
					if !reflect.DeepEqual(*tt.headers, client2.Headers()) {
 | 
										if !reflect.DeepEqual(*tt.headers, clone.Headers()) {
 | 
				
			||||||
						t.Fatalf("expected headers %v, actual %v", *tt.headers, client2.Headers())
 | 
											t.Fatalf("expected headers %v, actual %v", *tt.headers, clone.Headers())
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if tt.config.ReadYourWrites && client1.replicationStateStore == nil {
 | 
								if tt.config.ReadYourWrites && parent.replicationStateStore == nil {
 | 
				
			||||||
				t.Fatalf("replicationStateStore is nil")
 | 
									t.Fatalf("replicationStateStore is nil")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if !reflect.DeepEqual(client1.replicationStateStore, client2.replicationStateStore) {
 | 
								if tt.config.CloneToken {
 | 
				
			||||||
				t.Fatalf("expected replicationStateStore %v, actual %v", client1.replicationStateStore,
 | 
									if tt.token == "" {
 | 
				
			||||||
					client2.replicationStateStore)
 | 
										t.Fatalf("test requires a non-empty token")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if parent.config.CloneToken != clone.config.CloneToken {
 | 
				
			||||||
 | 
										t.Fatalf("config.CloneToken doesn't match: %v vs %v", parent.config.CloneToken, clone.config.CloneToken)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if parent.token != clone.token {
 | 
				
			||||||
 | 
										t.Fatalf("tokens do not match: %v vs %v", parent.token, clone.token)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									// assumes `VAULT_TOKEN` is unset or has an empty value.
 | 
				
			||||||
 | 
									expected := ""
 | 
				
			||||||
 | 
									if clone.token != expected {
 | 
				
			||||||
 | 
										t.Fatalf("expected clone's token %q, actual %q", expected, clone.token)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if !reflect.DeepEqual(parent.replicationStateStore, clone.replicationStateStore) {
 | 
				
			||||||
 | 
									t.Fatalf("expected replicationStateStore %v, actual %v", parent.replicationStateStore,
 | 
				
			||||||
 | 
										clone.replicationStateStore)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -1052,3 +1081,45 @@ func TestClient_SetReadYourWrites(t *testing.T) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestClient_SetCloneToken(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name  string
 | 
				
			||||||
 | 
							calls []bool
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "false",
 | 
				
			||||||
 | 
								calls: []bool{false},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "true",
 | 
				
			||||||
 | 
								calls: []bool{true},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:  "multi",
 | 
				
			||||||
 | 
								calls: []bool{true, false, true},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								c := &Client{
 | 
				
			||||||
 | 
									config: &Config{},
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								var expected bool
 | 
				
			||||||
 | 
								for _, v := range tt.calls {
 | 
				
			||||||
 | 
									actual := c.CloneToken()
 | 
				
			||||||
 | 
									if expected != actual {
 | 
				
			||||||
 | 
										t.Fatalf("expected %v, actual %v", expected, actual)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									expected = v
 | 
				
			||||||
 | 
									c.SetCloneToken(expected)
 | 
				
			||||||
 | 
									actual = c.CloneToken()
 | 
				
			||||||
 | 
									if actual != expected {
 | 
				
			||||||
 | 
										t.Fatalf("SetCloneToken(): expected %v, actual %v", expected, actual)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										3
									
								
								changelog/13515.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								changelog/13515.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					```release-note:improvement
 | 
				
			||||||
 | 
					api: Allow cloning `api.Client` tokens via `api.Config.CloneToken` or `api.Client.SetCloneToken()`.
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
		Reference in New Issue
	
	Block a user