diff --git a/authority/authority.go b/authority/authority.go index 2f66aeb0..f48a888b 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -8,8 +8,8 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "fmt" "log" - "net/http" "strings" "sync" "time" @@ -50,8 +50,8 @@ type Authority struct { templates *templates.Templates linkedCAToken string wrapTransport httptransport.Wrapper - webhookClient *http.Client - httpClient *http.Client + webhookClient provisioner.HTTPClient + httpClient provisioner.HTTPClient // X509 CA password []byte @@ -147,6 +147,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) { a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter) } + // Initialize system cert pool + if err := initializeSystemCertPool(); err != nil { + return nil, fmt.Errorf("failed to initialize the system cert pool: %w", err) + } + if !a.skipInit { // Initialize authority from options or configuration. if err := a.init(); err != nil { @@ -177,6 +182,11 @@ func NewEmbedded(opts ...Option) (*Authority, error) { a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter) } + // Initialize system cert pool + if err := initializeSystemCertPool(); err != nil { + return nil, fmt.Errorf("failed to initialize the system cert pool: %w", err) + } + // Validate required options switch { case a.config == nil: @@ -500,7 +510,7 @@ func (a *Authority) init() error { clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs)) clientRoots = append(clientRoots, a.rootX509Certs...) clientRoots = append(clientRoots, a.federatedX509Certs...) - a.httpClient, err = newHTTPClient(a.wrapTransport, clientRoots...) + a.httpClient = newHTTPClient(a.wrapTransport, clientRoots...) if err != nil { return err } diff --git a/authority/authority_test.go b/authority/authority_test.go index 387f7beb..973e7a9c 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "encoding/hex" "encoding/pem" + "fmt" "net" "os" "path/filepath" @@ -25,6 +26,16 @@ import ( "go.step.sm/crypto/pemutil" ) +func TestMain(m *testing.M) { + if err := initializeSystemCertPool(); err != nil { + fmt.Fprintln(os.Stderr, "failed to initialize system cert pool:", err) + fmt.Fprintln(os.Stderr, "See https://pkg.go.dev/github.com/tjfoc/gmsm/x509#SystemCertPool\n", err) + os.Exit(2) + } + + os.Exit(m.Run()) +} + func testAuthority(t *testing.T, opts ...Option) *Authority { maxjwk, err := jose.ReadKey("testdata/secrets/max_pub.jwk") assert.FatalError(t, err) diff --git a/authority/http_client.go b/authority/http_client.go index d06464b3..3c696564 100644 --- a/authority/http_client.go +++ b/authority/http_client.go @@ -3,36 +3,54 @@ package authority import ( "crypto/tls" "crypto/x509" - "fmt" "net/http" + "sync/atomic" + "github.com/smallstep/certificates/authority/poolhttp" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/internal/httptransport" ) +// systemCertPool holds a copy of the system cert pool. This cert pool must be +// initialized when the authority is created and we should always get a clone of +// this pool. +var systemCertPool atomic.Pointer[x509.CertPool] + +// initializeSystemCertPool initializes the system cert pool if necessary. +func initializeSystemCertPool() error { + if systemCertPool.Load() == nil { + pool, err := x509.SystemCertPool() + if err != nil { + return err + } + systemCertPool.Store(pool) + } + return nil +} + // newHTTPClient will return an HTTP client that trusts the system cert pool and // the given roots. -func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) (*http.Client, error) { - pool, err := x509.SystemCertPool() - if err != nil { - return nil, fmt.Errorf("error initializing http client: %w", err) - } - for _, crt := range roots { - pool.AddCert(crt) - } +func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) provisioner.HTTPClient { + return poolhttp.New(func() *http.Client { + pool := systemCertPool.Load().Clone() + for _, crt := range roots { + pool.AddCert(crt) + } - tr, ok := http.DefaultTransport.(*http.Transport) - if !ok { - tr = httptransport.New() - } else { - tr = tr.Clone() - } + tr, ok := http.DefaultTransport.(*http.Transport) + if !ok { + tr = httptransport.New() + } else { + tr = tr.Clone() + } - tr.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - RootCAs: pool, - } + tr.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + RootCAs: pool, + } - return &http.Client{ - Transport: wt(tr), - }, nil + rr := wt(tr) + + return &http.Client{Transport: rr} + }) } diff --git a/authority/http_client_test.go b/authority/http_client_test.go index 5a77331a..cb751019 100644 --- a/authority/http_client_test.go +++ b/authority/http_client_test.go @@ -114,8 +114,7 @@ func Test_newHTTPClient(t *testing.T) { }{http.DefaultTransport} http.DefaultTransport = transport - client, err := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...) - assert.NoError(t, err) + client := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...) assert.NotNil(t, client) }) } diff --git a/authority/options.go b/authority/options.go index 6a75bdd9..d85a611d 100644 --- a/authority/options.go +++ b/authority/options.go @@ -5,7 +5,6 @@ import ( "crypto" "crypto/x509" "encoding/pem" - "net/http" "github.com/pkg/errors" "golang.org/x/crypto/ssh" @@ -97,7 +96,7 @@ func WithQuietInit() Option { } // WithWebhookClient sets the http.Client to be used for outbound requests. -func WithWebhookClient(c *http.Client) Option { +func WithWebhookClient(c provisioner.HTTPClient) Option { return func(a *Authority) error { a.webhookClient = c return nil diff --git a/authority/poolhttp/poolhttp.go b/authority/poolhttp/poolhttp.go new file mode 100644 index 00000000..d639bcf4 --- /dev/null +++ b/authority/poolhttp/poolhttp.go @@ -0,0 +1,80 @@ +package poolhttp + +import ( + "net/http" + "sync" + + "github.com/smallstep/certificates/internal/httptransport" +) + +// Transporter is the implemented by custom HTTP clients with a method that +// returns an [*http.Transport]. +type Transporter interface { + Transport() *http.Transport +} + +// Client returns an HTTP client that uses a [sync.Pool] to create new HTTP +// client. It implements the [provisioner.HTTPClient] and [Transporter] +// interfaces. This is the HTTP client used by the provisioners. +type Client struct { + pool sync.Pool +} + +// New creates a new poolhttp [Client], the [sync.Pool] will initialize a new +// [*http.Client] with the given function. +func New(fn func() *http.Client) *Client { + return &Client{ + pool: sync.Pool{ + New: func() any { return fn() }, + }, + } +} + +// SetNew replaces the inner pool with a new [sync.Pool] with the given New +// function. This method should not be used concurrently with other methods. +func (c *Client) SetNew(fn func() *http.Client) { + c.pool = sync.Pool{ + New: func() any { return fn() }, + } +} + +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// [Client.CheckRedirect] function: +func (c *Client) Get(u string) (resp *http.Response, err error) { + if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + resp, err = hc.Get(u) + c.pool.Put(hc) + } else { + resp, err = http.DefaultClient.Get(u) + } + + return +} + +// Do sends an HTTP request and returns an HTTP response, following policy (such +// as redirects, cookies, auth) as configured on the client. +func (c *Client) Do(req *http.Request) (resp *http.Response, err error) { + if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + resp, err = hc.Do(req) + c.pool.Put(hc) + } else { + resp, err = http.DefaultClient.Do(req) + } + + return +} + +// Transport() returns a clone of the http.Client Transport or returns the +// default transport. +func (c *Client) Transport() *http.Transport { + if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + tr, ok := hc.Transport.(*http.Transport) + c.pool.Put(hc) + if ok { + return tr.Clone() + } + } + + return httptransport.New() +} diff --git a/authority/poolhttp/poolhttp_test.go b/authority/poolhttp/poolhttp_test.go new file mode 100644 index 00000000..ba4dcb30 --- /dev/null +++ b/authority/poolhttp/poolhttp_test.go @@ -0,0 +1,124 @@ +package poolhttp + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func requireBody(t *testing.T, want string, r io.ReadCloser) { + t.Helper() + t.Cleanup(func() { + require.NoError(t, r.Close()) + }) + + b, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, want, string(b)) +} + +func TestClient(t *testing.T) { + httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello World") + })) + t.Cleanup(httpSrv.Close) + tlsSrv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello World") + })) + t.Cleanup(tlsSrv.Close) + + tests := []struct { + name string + client *Client + srv *httptest.Server + }{ + {"http", New(func() *http.Client { return httpSrv.Client() }), httpSrv}, + {"tls", New(func() *http.Client { return tlsSrv.Client() }), tlsSrv}, + {"nil", New(func() *http.Client { return nil }), httpSrv}, + {"empty", &Client{}, httpSrv}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + resp, err := tc.client.Get(tc.srv.URL) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + + req, err := http.NewRequest("GET", tc.srv.URL, http.NoBody) + require.NoError(t, err) + + resp, err = tc.client.Do(req) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + + client := &http.Client{ + Transport: tc.client.Transport(), + } + resp, err = client.Get(tc.srv.URL) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + }) + } +} + +func TestClient_SetNew(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello World") + })) + t.Cleanup(srv.Close) + + c := New(func() *http.Client { + return srv.Client() + }) + + tests := []struct { + name string + client *http.Client + assertion assert.ErrorAssertionFunc + }{ + {"ok", srv.Client(), assert.NoError}, + {"fail", http.DefaultClient, assert.Error}, + {"ok again", srv.Client(), assert.NoError}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c.SetNew(func() *http.Client { + return tc.client + }) + _, err := c.Get(srv.URL) + tc.assertion(t, err) + + }) + } +} + +func TestClient_parallel(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello World") + })) + t.Cleanup(srv.Close) + + c := New(func() *http.Client { + return srv.Client() + }) + req, err := http.NewRequest("GET", srv.URL, http.NoBody) + require.NoError(t, err) + + for i := range 10 { + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + resp, err := c.Get(srv.URL) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + + resp, err = c.Do(req) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + }) + } +} diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go index 3d828d76..93439e0c 100644 --- a/authority/provisioner/controller.go +++ b/authority/provisioner/controller.go @@ -28,9 +28,9 @@ type Controller struct { AuthorizeRenewFunc AuthorizeRenewFunc AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc policy *policyEngine - webhookClient *http.Client + httpClient HTTPClient + webhookClient HTTPClient webhooks []*Webhook - httpClient *http.Client wrapTransport httptransport.Wrapper } @@ -48,6 +48,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options) if wt == nil { wt = httptransport.NoopWrapper() } + return &Controller{ Interface: p, Audiences: &config.Audiences, @@ -65,7 +66,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options) // GetHTTPClient returns the configured HTTP client or the default one if none // is configured. -func (c *Controller) GetHTTPClient() *http.Client { +func (c *Controller) GetHTTPClient() HTTPClient { if c.httpClient != nil { return c.httpClient } diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go index 23a4e9aa..0fba67e4 100644 --- a/authority/provisioner/keystore.go +++ b/authority/provisioner/keystore.go @@ -3,7 +3,6 @@ package provisioner import ( "encoding/json" "math/rand" - "net/http" "regexp" "strconv" "sync" @@ -22,33 +21,26 @@ var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`) type keyStore struct { sync.RWMutex - client *http.Client + client HTTPClient uri string keySet jose.JSONWebKeySet - timer *time.Timer expiry time.Time jitter time.Duration } -func newKeyStore(client *http.Client, uri string) (*keyStore, error) { +func newKeyStore(client HTTPClient, uri string) (*keyStore, error) { keys, age, err := getKeysFromJWKsURI(client, uri) if err != nil { return nil, err } - ks := &keyStore{ + jitter := getCacheJitter(age) + return &keyStore{ client: client, uri: uri, keySet: keys, - expiry: getExpirationTime(age), - jitter: getCacheJitter(age), - } - next := ks.nextReloadDuration(age) - ks.timer = time.AfterFunc(next, ks.reload) - return ks, nil -} - -func (ks *keyStore) Close() { - ks.timer.Stop() + expiry: getExpirationTime(age, jitter), + jitter: jitter, + }, nil } func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) { @@ -65,34 +57,16 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) { } func (ks *keyStore) reload() { - var next time.Duration - keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri) - if err != nil { - next = ks.nextReloadDuration(ks.jitter / 2) - } else { + if keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri); err == nil { ks.Lock() ks.keySet = keys - ks.expiry = getExpirationTime(age) ks.jitter = getCacheJitter(age) - next = ks.nextReloadDuration(age) + ks.expiry = getExpirationTime(age, ks.jitter) ks.Unlock() } - - ks.Lock() - ks.timer.Reset(next) - ks.Unlock() } -// nextReloadDuration would return the duration for the next rotation. If age is -// 0 it will randomly rotate between 0-12 hours, but every time we call to Get -// it will automatically rotate. -func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { - n := rand.Int63n(int64(ks.jitter)) //nolint:gosec // not used for cryptographic security - age -= time.Duration(n) - return abs(age) -} - -func getKeysFromJWKsURI(client *http.Client, uri string) (jose.JSONWebKeySet, time.Duration, error) { +func getKeysFromJWKsURI(client HTTPClient, uri string) (jose.JSONWebKeySet, time.Duration, error) { var keys jose.JSONWebKeySet resp, err := client.Get(uri) if err != nil { @@ -136,8 +110,12 @@ func getCacheJitter(age time.Duration) time.Duration { } } -func getExpirationTime(age time.Duration) time.Time { - return time.Now().Truncate(time.Second).Add(age) +func getExpirationTime(age, jitter time.Duration) time.Time { + if age > 0 { + n := rand.Int63n(int64(jitter)) //nolint:gosec // not used for cryptographic security + age -= time.Duration(n) + } + return time.Now().Truncate(time.Second).Add(abs(age)) } // abs returns the absolute value of n. diff --git a/authority/provisioner/keystore_test.go b/authority/provisioner/keystore_test.go index 85d4260f..b17dc870 100644 --- a/authority/provisioner/keystore_test.go +++ b/authority/provisioner/keystore_test.go @@ -22,7 +22,6 @@ func Test_newKeyStore(t *testing.T) { ks, err := newKeyStore(srv.Client(), srv.URL) assert.FatalError(t, err) - defer ks.Close() type args struct { client *http.Client @@ -49,7 +48,6 @@ func Test_newKeyStore(t *testing.T) { if !reflect.DeepEqual(got.keySet, tt.want) { t.Errorf("newKeyStore() = %v, want %v", got, tt.want) } - got.Close() } }) } @@ -61,7 +59,6 @@ func Test_keyStore(t *testing.T) { ks, err := newKeyStore(srv.Client(), srv.URL+"/random") assert.FatalError(t, err) - defer ks.Close() ks.RLock() keySet1 := ks.keySet ks.RUnlock() @@ -73,6 +70,7 @@ func Test_keyStore(t *testing.T) { // Wait for rotation time.Sleep(5 * time.Second) + assert.Len(t, 0, ks.Get("foobar")) // force refresh ks.RLock() keySet2 := ks.keySet @@ -105,7 +103,6 @@ func Test_keyStore_noCache(t *testing.T) { ks, err := newKeyStore(srv.Client(), srv.URL+"/no-cache") assert.FatalError(t, err) - defer ks.Close() ks.RLock() keySet1 := ks.keySet ks.RUnlock() @@ -116,20 +113,6 @@ func Test_keyStore_noCache(t *testing.T) { assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID)) assert.Len(t, 0, ks.Get("foobar")) - ks.RLock() - keySet2 := ks.keySet - ks.RUnlock() - if reflect.DeepEqual(keySet1, keySet2) { - t.Error("keyStore did not rotated") - } - - // The keys will rotate on Get. - // So we won't be able to find the cached ones - assert.Len(t, 2, keySet2.Keys) - assert.Len(t, 0, ks.Get(keySet2.Keys[0].KeyID)) - assert.Len(t, 0, ks.Get(keySet2.Keys[1].KeyID)) - assert.Len(t, 0, ks.Get("foobar")) - // Check hits resp, err := srv.Client().Get(srv.URL + "/hits") assert.FatalError(t, err) @@ -147,7 +130,6 @@ func Test_keyStore_Get(t *testing.T) { defer srv.Close() ks, err := newKeyStore(srv.Client(), srv.URL) assert.FatalError(t, err) - defer ks.Close() type args struct { kid string diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 16f9caf4..044971bf 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -485,7 +485,7 @@ func (o *OIDC) AuthorizeSSHRevoke(_ context.Context, token string) error { return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token") } -func getAndDecode(client *http.Client, uri string, v interface{}) error { +func getAndDecode(client HTTPClient, uri string, v interface{}) error { resp, err := client.Get(uri) if err != nil { return errors.Wrapf(err, "failed to connect to %s", uri) diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 86fdd5ec..d57e3582 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -34,6 +34,13 @@ type Interface interface { AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) } +// HTTPClient is the interface implemented by the HTTP clients used by the +// provisioners. +type HTTPClient interface { + Get(string) (*http.Response, error) + Do(*http.Request) (*http.Response, error) +} + // Uninitialized represents a disabled provisioner. Uninitialized provisioners // are created when the Init methods fails. type Uninitialized struct { @@ -251,6 +258,16 @@ type Config struct { // GetIdentityFunc is a function that returns an identity that will be // used by the provisioner to populate certificate attributes. GetIdentityFunc GetIdentityFunc + /* + // GetHttpClientFunc is a function that returns an HTTP client that trusts + // the system cert pool and the CA roots. + GetHttpClientFunc func() *http.Client + // GetWebhookClientFunc ris a function that returns the HTTP client used + // when performing webhook requests, this is an HTTP client that trusts the + // system cert pool, the CA roots and uses the CA certificate as a client + // certificate. + GetWebhookClientFunc func() *http.Client + */ // AuthorizeRenewFunc is a function that returns nil if a given X.509 // certificate can be renewed. AuthorizeRenewFunc AuthorizeRenewFunc @@ -258,12 +275,12 @@ type Config struct { // certificate can be renewed. AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc // WebhookClient is an HTTP client used when performing webhook requests. - WebhookClient *http.Client + WebhookClient HTTPClient // SCEPKeyManager, if defined, is the interface used by SCEP provisioners. SCEPKeyManager SCEPKeyManager // HTTPClient is an HTTP client that trusts the system cert pool and the CA // roots. - HTTPClient *http.Client + HTTPClient HTTPClient // WrapTransport references the function that should wrap any [http.Transport] initialized // down the Config's chain. WrapTransport TransportWrapper diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index b6e8b925..a97ff8e5 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "net/http" "time" "github.com/pkg/errors" @@ -113,14 +112,14 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration { } type challengeValidationController struct { - client *http.Client + client HTTPClient wrapTransport httptransport.Wrapper webhooks []*Webhook } // newChallengeValidationController creates a new challengeValidationController // that performs challenge validation through webhooks. -func newChallengeValidationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController { +func newChallengeValidationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController { scepHooks := []*Webhook{} for _, wh := range webhooks { if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() { @@ -179,14 +178,14 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509. } type notificationController struct { - client *http.Client + client HTTPClient wrapTransport httptransport.Wrapper webhooks []*Webhook } // newNotificationController creates a new notificationController // that performs SCEP notifications through webhooks. -func newNotificationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController { +func newNotificationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController { scepHooks := []*Webhook{} for _, wh := range webhooks { if wh.Kind != linkedca.Webhook_NOTIFYING.String() { diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index dba4a4c8..b78b90b0 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -18,6 +18,7 @@ import ( "github.com/smallstep/linkedca" + "github.com/smallstep/certificates/authority/poolhttp" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/middleware/requestid" "github.com/smallstep/certificates/templates" @@ -31,7 +32,7 @@ type WebhookSetter interface { } type WebhookController struct { - client *http.Client + client HTTPClient wrapTransport httptransport.Wrapper webhooks []*Webhook certType linkedca.Webhook_CertType @@ -146,7 +147,7 @@ type Webhook struct { // [http.RoundTripper]. type TransportWrapper = httptransport.Wrapper -func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { +func (w *Webhook) DoWithContext(ctx context.Context, client HTTPClient, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) if err != nil { return nil, err @@ -206,11 +207,11 @@ retry: } if w.DisableTLSClientAuth { - transport, ok := client.Transport.(*http.Transport) - if !ok { - transport = httptransport.New() + var transport *http.Transport + if ct, ok := client.(poolhttp.Transporter); ok { + transport = ct.Transport() } else { - transport = transport.Clone() + transport = httptransport.New() } if transport.TLSClientConfig != nil { diff --git a/templates/templates.go b/templates/templates.go index e58f5fbc..a8cd1df2 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "strings" + "sync" "text/template" "github.com/Masterminds/sprig/v3" @@ -29,6 +30,22 @@ const ( Directory TemplateType = "directory" ) +var ( + stepFuncMap template.FuncMap + stepFuncOnce sync.Once +) + +// StepFuncMap returns sprig.TxtFuncMap but removing the "env" and "expandenv" +// functions to avoid any leak of information. +func StepFuncMap() template.FuncMap { + stepFuncOnce.Do(func() { + stepFuncMap = sprig.TxtFuncMap() + delete(stepFuncMap, "env") + delete(stepFuncMap, "expandenv") + }) + return stepFuncMap +} + // Templates is a collection of templates and variables. type Templates struct { SSH *SSHTemplates `json:"ssh,omitempty"` @@ -282,12 +299,3 @@ func mkdir(path string, perm os.FileMode) error { } return nil } - -// StepFuncMap returns sprig.TxtFuncMap but removing the "env" and "expandenv" -// functions to avoid any leak of information. -func StepFuncMap() template.FuncMap { - m := sprig.TxtFuncMap() - delete(m, "env") - delete(m, "expandenv") - return m -}