From 1ae78b803deb7afae4b98dd7507040593a0a1ea6 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 2 Jul 2025 12:45:52 -0700 Subject: [PATCH] Make poolhttp thread safe. --- authority/poolhttp/poolhttp.go | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/authority/poolhttp/poolhttp.go b/authority/poolhttp/poolhttp.go index 10e14194..dd8c5a11 100644 --- a/authority/poolhttp/poolhttp.go +++ b/authority/poolhttp/poolhttp.go @@ -17,6 +17,7 @@ type Transporter interface { // clients. It implements the [provisioner.HTTPClient] and [Transporter] // interfaces. This is the HTTP client used by the provisioners. type Client struct { + rw sync.RWMutex pool sync.Pool } @@ -31,18 +32,31 @@ func New(fn func() *http.Client) *Client { } // SetNew replaces the inner pool with a new [sync.Pool] with the given New -// function. This method should not be used concurrently with other methods. +// function. This method can be use concurrently with other methods of this +// package. func (c *Client) SetNew(fn func() *http.Client) { + c.rw.Lock() c.pool = sync.Pool{ New: func() any { return fn() }, } + c.rw.Unlock() +} + +// getClient gets a client from the pool. +func (c *Client) getClient() *http.Client { + c.rw.RLock() + defer c.rw.RUnlock() + if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + return hc + } + return nil } // Get issues a GET request to the specified URL. If the response is one of the // following redirect codes, Get follows the redirect after calling the // [Client.CheckRedirect] function: func (c *Client) Get(u string) (resp *http.Response, err error) { - if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + if hc := c.getClient(); hc != nil { resp, err = hc.Get(u) c.pool.Put(hc) } else { @@ -55,7 +69,7 @@ func (c *Client) Get(u string) (resp *http.Response, err error) { // Do sends an HTTP request and returns an HTTP response, following policy (such // as redirects, cookies, auth) as configured on the client. func (c *Client) Do(req *http.Request) (resp *http.Response, err error) { - if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + if hc := c.getClient(); hc != nil { resp, err = hc.Do(req) c.pool.Put(hc) } else { @@ -68,7 +82,7 @@ func (c *Client) Do(req *http.Request) (resp *http.Response, err error) { // Transport() returns a clone of the http.Client Transport or returns the // default transport. func (c *Client) Transport() *http.Transport { - if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + if hc := c.getClient(); hc != nil { tr, ok := hc.Transport.(*http.Transport) c.pool.Put(hc) if ok {