From d4fd467e8ae3ab0b7e3e03e6c94df0e5fbd76691 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 1 Jul 2025 13:53:52 -0700 Subject: [PATCH] Memory improvements This commit replaces the client in provisioners and webhooks with an interface. Then it implements the interface using the new poolhttp package. This package implements the HTTPClient interface but it is backed by a sync.Pool, this improves memory, allowing the GC to clean more memory. It also removes the timer in the keystore to avoid having extra goroutines if a provisioner goes away. This commit avoids creating the templates func multiple times, reducing some memory in the heap. --- authority/authority.go | 18 +++- authority/authority_test.go | 11 +++ authority/http_client.go | 62 ++++++++----- authority/http_client_test.go | 3 +- authority/options.go | 3 +- authority/poolhttp/poolhttp.go | 80 ++++++++++++++++ authority/poolhttp/poolhttp_test.go | 124 +++++++++++++++++++++++++ authority/provisioner/controller.go | 7 +- authority/provisioner/keystore.go | 54 ++++------- authority/provisioner/keystore_test.go | 20 +--- authority/provisioner/oidc.go | 2 +- authority/provisioner/provisioner.go | 21 ++++- authority/provisioner/scep.go | 9 +- authority/provisioner/webhook.go | 13 +-- templates/templates.go | 26 ++++-- 15 files changed, 340 insertions(+), 113 deletions(-) create mode 100644 authority/poolhttp/poolhttp.go create mode 100644 authority/poolhttp/poolhttp_test.go diff --git a/authority/authority.go b/authority/authority.go index 2f66aeb0..f48a888b 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -8,8 +8,8 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "fmt" "log" - "net/http" "strings" "sync" "time" @@ -50,8 +50,8 @@ type Authority struct { templates *templates.Templates linkedCAToken string wrapTransport httptransport.Wrapper - webhookClient *http.Client - httpClient *http.Client + webhookClient provisioner.HTTPClient + httpClient provisioner.HTTPClient // X509 CA password []byte @@ -147,6 +147,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) { a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter) } + // Initialize system cert pool + if err := initializeSystemCertPool(); err != nil { + return nil, fmt.Errorf("failed to initialize the system cert pool: %w", err) + } + if !a.skipInit { // Initialize authority from options or configuration. if err := a.init(); err != nil { @@ -177,6 +182,11 @@ func NewEmbedded(opts ...Option) (*Authority, error) { a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter) } + // Initialize system cert pool + if err := initializeSystemCertPool(); err != nil { + return nil, fmt.Errorf("failed to initialize the system cert pool: %w", err) + } + // Validate required options switch { case a.config == nil: @@ -500,7 +510,7 @@ func (a *Authority) init() error { clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs)) clientRoots = append(clientRoots, a.rootX509Certs...) clientRoots = append(clientRoots, a.federatedX509Certs...) - a.httpClient, err = newHTTPClient(a.wrapTransport, clientRoots...) + a.httpClient = newHTTPClient(a.wrapTransport, clientRoots...) if err != nil { return err } diff --git a/authority/authority_test.go b/authority/authority_test.go index 387f7beb..973e7a9c 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "encoding/hex" "encoding/pem" + "fmt" "net" "os" "path/filepath" @@ -25,6 +26,16 @@ import ( "go.step.sm/crypto/pemutil" ) +func TestMain(m *testing.M) { + if err := initializeSystemCertPool(); err != nil { + fmt.Fprintln(os.Stderr, "failed to initialize system cert pool:", err) + fmt.Fprintln(os.Stderr, "See https://pkg.go.dev/github.com/tjfoc/gmsm/x509#SystemCertPool\n", err) + os.Exit(2) + } + + os.Exit(m.Run()) +} + func testAuthority(t *testing.T, opts ...Option) *Authority { maxjwk, err := jose.ReadKey("testdata/secrets/max_pub.jwk") assert.FatalError(t, err) diff --git a/authority/http_client.go b/authority/http_client.go index d06464b3..3c696564 100644 --- a/authority/http_client.go +++ b/authority/http_client.go @@ -3,36 +3,54 @@ package authority import ( "crypto/tls" "crypto/x509" - "fmt" "net/http" + "sync/atomic" + "github.com/smallstep/certificates/authority/poolhttp" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/internal/httptransport" ) +// systemCertPool holds a copy of the system cert pool. This cert pool must be +// initialized when the authority is created and we should always get a clone of +// this pool. +var systemCertPool atomic.Pointer[x509.CertPool] + +// initializeSystemCertPool initializes the system cert pool if necessary. +func initializeSystemCertPool() error { + if systemCertPool.Load() == nil { + pool, err := x509.SystemCertPool() + if err != nil { + return err + } + systemCertPool.Store(pool) + } + return nil +} + // newHTTPClient will return an HTTP client that trusts the system cert pool and // the given roots. -func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) (*http.Client, error) { - pool, err := x509.SystemCertPool() - if err != nil { - return nil, fmt.Errorf("error initializing http client: %w", err) - } - for _, crt := range roots { - pool.AddCert(crt) - } +func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) provisioner.HTTPClient { + return poolhttp.New(func() *http.Client { + pool := systemCertPool.Load().Clone() + for _, crt := range roots { + pool.AddCert(crt) + } - tr, ok := http.DefaultTransport.(*http.Transport) - if !ok { - tr = httptransport.New() - } else { - tr = tr.Clone() - } + tr, ok := http.DefaultTransport.(*http.Transport) + if !ok { + tr = httptransport.New() + } else { + tr = tr.Clone() + } - tr.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - RootCAs: pool, - } + tr.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + RootCAs: pool, + } - return &http.Client{ - Transport: wt(tr), - }, nil + rr := wt(tr) + + return &http.Client{Transport: rr} + }) } diff --git a/authority/http_client_test.go b/authority/http_client_test.go index 5a77331a..cb751019 100644 --- a/authority/http_client_test.go +++ b/authority/http_client_test.go @@ -114,8 +114,7 @@ func Test_newHTTPClient(t *testing.T) { }{http.DefaultTransport} http.DefaultTransport = transport - client, err := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...) - assert.NoError(t, err) + client := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...) assert.NotNil(t, client) }) } diff --git a/authority/options.go b/authority/options.go index 6a75bdd9..d85a611d 100644 --- a/authority/options.go +++ b/authority/options.go @@ -5,7 +5,6 @@ import ( "crypto" "crypto/x509" "encoding/pem" - "net/http" "github.com/pkg/errors" "golang.org/x/crypto/ssh" @@ -97,7 +96,7 @@ func WithQuietInit() Option { } // WithWebhookClient sets the http.Client to be used for outbound requests. -func WithWebhookClient(c *http.Client) Option { +func WithWebhookClient(c provisioner.HTTPClient) Option { return func(a *Authority) error { a.webhookClient = c return nil diff --git a/authority/poolhttp/poolhttp.go b/authority/poolhttp/poolhttp.go new file mode 100644 index 00000000..d639bcf4 --- /dev/null +++ b/authority/poolhttp/poolhttp.go @@ -0,0 +1,80 @@ +package poolhttp + +import ( + "net/http" + "sync" + + "github.com/smallstep/certificates/internal/httptransport" +) + +// Transporter is the implemented by custom HTTP clients with a method that +// returns an [*http.Transport]. +type Transporter interface { + Transport() *http.Transport +} + +// Client returns an HTTP client that uses a [sync.Pool] to create new HTTP +// client. It implements the [provisioner.HTTPClient] and [Transporter] +// interfaces. This is the HTTP client used by the provisioners. +type Client struct { + pool sync.Pool +} + +// New creates a new poolhttp [Client], the [sync.Pool] will initialize a new +// [*http.Client] with the given function. +func New(fn func() *http.Client) *Client { + return &Client{ + pool: sync.Pool{ + New: func() any { return fn() }, + }, + } +} + +// SetNew replaces the inner pool with a new [sync.Pool] with the given New +// function. This method should not be used concurrently with other methods. +func (c *Client) SetNew(fn func() *http.Client) { + c.pool = sync.Pool{ + New: func() any { return fn() }, + } +} + +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// [Client.CheckRedirect] function: +func (c *Client) Get(u string) (resp *http.Response, err error) { + if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + resp, err = hc.Get(u) + c.pool.Put(hc) + } else { + resp, err = http.DefaultClient.Get(u) + } + + return +} + +// Do sends an HTTP request and returns an HTTP response, following policy (such +// as redirects, cookies, auth) as configured on the client. +func (c *Client) Do(req *http.Request) (resp *http.Response, err error) { + if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + resp, err = hc.Do(req) + c.pool.Put(hc) + } else { + resp, err = http.DefaultClient.Do(req) + } + + return +} + +// Transport() returns a clone of the http.Client Transport or returns the +// default transport. +func (c *Client) Transport() *http.Transport { + if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil { + tr, ok := hc.Transport.(*http.Transport) + c.pool.Put(hc) + if ok { + return tr.Clone() + } + } + + return httptransport.New() +} diff --git a/authority/poolhttp/poolhttp_test.go b/authority/poolhttp/poolhttp_test.go new file mode 100644 index 00000000..ba4dcb30 --- /dev/null +++ b/authority/poolhttp/poolhttp_test.go @@ -0,0 +1,124 @@ +package poolhttp + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func requireBody(t *testing.T, want string, r io.ReadCloser) { + t.Helper() + t.Cleanup(func() { + require.NoError(t, r.Close()) + }) + + b, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, want, string(b)) +} + +func TestClient(t *testing.T) { + httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello World") + })) + t.Cleanup(httpSrv.Close) + tlsSrv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello World") + })) + t.Cleanup(tlsSrv.Close) + + tests := []struct { + name string + client *Client + srv *httptest.Server + }{ + {"http", New(func() *http.Client { return httpSrv.Client() }), httpSrv}, + {"tls", New(func() *http.Client { return tlsSrv.Client() }), tlsSrv}, + {"nil", New(func() *http.Client { return nil }), httpSrv}, + {"empty", &Client{}, httpSrv}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + resp, err := tc.client.Get(tc.srv.URL) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + + req, err := http.NewRequest("GET", tc.srv.URL, http.NoBody) + require.NoError(t, err) + + resp, err = tc.client.Do(req) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + + client := &http.Client{ + Transport: tc.client.Transport(), + } + resp, err = client.Get(tc.srv.URL) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + }) + } +} + +func TestClient_SetNew(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello World") + })) + t.Cleanup(srv.Close) + + c := New(func() *http.Client { + return srv.Client() + }) + + tests := []struct { + name string + client *http.Client + assertion assert.ErrorAssertionFunc + }{ + {"ok", srv.Client(), assert.NoError}, + {"fail", http.DefaultClient, assert.Error}, + {"ok again", srv.Client(), assert.NoError}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c.SetNew(func() *http.Client { + return tc.client + }) + _, err := c.Get(srv.URL) + tc.assertion(t, err) + + }) + } +} + +func TestClient_parallel(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello World") + })) + t.Cleanup(srv.Close) + + c := New(func() *http.Client { + return srv.Client() + }) + req, err := http.NewRequest("GET", srv.URL, http.NoBody) + require.NoError(t, err) + + for i := range 10 { + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + resp, err := c.Get(srv.URL) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + + resp, err = c.Do(req) + require.NoError(t, err) + requireBody(t, "Hello World\n", resp.Body) + }) + } +} diff --git a/authority/provisioner/controller.go b/authority/provisioner/controller.go index 3d828d76..93439e0c 100644 --- a/authority/provisioner/controller.go +++ b/authority/provisioner/controller.go @@ -28,9 +28,9 @@ type Controller struct { AuthorizeRenewFunc AuthorizeRenewFunc AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc policy *policyEngine - webhookClient *http.Client + httpClient HTTPClient + webhookClient HTTPClient webhooks []*Webhook - httpClient *http.Client wrapTransport httptransport.Wrapper } @@ -48,6 +48,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options) if wt == nil { wt = httptransport.NoopWrapper() } + return &Controller{ Interface: p, Audiences: &config.Audiences, @@ -65,7 +66,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options) // GetHTTPClient returns the configured HTTP client or the default one if none // is configured. -func (c *Controller) GetHTTPClient() *http.Client { +func (c *Controller) GetHTTPClient() HTTPClient { if c.httpClient != nil { return c.httpClient } diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go index 23a4e9aa..0fba67e4 100644 --- a/authority/provisioner/keystore.go +++ b/authority/provisioner/keystore.go @@ -3,7 +3,6 @@ package provisioner import ( "encoding/json" "math/rand" - "net/http" "regexp" "strconv" "sync" @@ -22,33 +21,26 @@ var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`) type keyStore struct { sync.RWMutex - client *http.Client + client HTTPClient uri string keySet jose.JSONWebKeySet - timer *time.Timer expiry time.Time jitter time.Duration } -func newKeyStore(client *http.Client, uri string) (*keyStore, error) { +func newKeyStore(client HTTPClient, uri string) (*keyStore, error) { keys, age, err := getKeysFromJWKsURI(client, uri) if err != nil { return nil, err } - ks := &keyStore{ + jitter := getCacheJitter(age) + return &keyStore{ client: client, uri: uri, keySet: keys, - expiry: getExpirationTime(age), - jitter: getCacheJitter(age), - } - next := ks.nextReloadDuration(age) - ks.timer = time.AfterFunc(next, ks.reload) - return ks, nil -} - -func (ks *keyStore) Close() { - ks.timer.Stop() + expiry: getExpirationTime(age, jitter), + jitter: jitter, + }, nil } func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) { @@ -65,34 +57,16 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) { } func (ks *keyStore) reload() { - var next time.Duration - keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri) - if err != nil { - next = ks.nextReloadDuration(ks.jitter / 2) - } else { + if keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri); err == nil { ks.Lock() ks.keySet = keys - ks.expiry = getExpirationTime(age) ks.jitter = getCacheJitter(age) - next = ks.nextReloadDuration(age) + ks.expiry = getExpirationTime(age, ks.jitter) ks.Unlock() } - - ks.Lock() - ks.timer.Reset(next) - ks.Unlock() } -// nextReloadDuration would return the duration for the next rotation. If age is -// 0 it will randomly rotate between 0-12 hours, but every time we call to Get -// it will automatically rotate. -func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { - n := rand.Int63n(int64(ks.jitter)) //nolint:gosec // not used for cryptographic security - age -= time.Duration(n) - return abs(age) -} - -func getKeysFromJWKsURI(client *http.Client, uri string) (jose.JSONWebKeySet, time.Duration, error) { +func getKeysFromJWKsURI(client HTTPClient, uri string) (jose.JSONWebKeySet, time.Duration, error) { var keys jose.JSONWebKeySet resp, err := client.Get(uri) if err != nil { @@ -136,8 +110,12 @@ func getCacheJitter(age time.Duration) time.Duration { } } -func getExpirationTime(age time.Duration) time.Time { - return time.Now().Truncate(time.Second).Add(age) +func getExpirationTime(age, jitter time.Duration) time.Time { + if age > 0 { + n := rand.Int63n(int64(jitter)) //nolint:gosec // not used for cryptographic security + age -= time.Duration(n) + } + return time.Now().Truncate(time.Second).Add(abs(age)) } // abs returns the absolute value of n. diff --git a/authority/provisioner/keystore_test.go b/authority/provisioner/keystore_test.go index 85d4260f..b17dc870 100644 --- a/authority/provisioner/keystore_test.go +++ b/authority/provisioner/keystore_test.go @@ -22,7 +22,6 @@ func Test_newKeyStore(t *testing.T) { ks, err := newKeyStore(srv.Client(), srv.URL) assert.FatalError(t, err) - defer ks.Close() type args struct { client *http.Client @@ -49,7 +48,6 @@ func Test_newKeyStore(t *testing.T) { if !reflect.DeepEqual(got.keySet, tt.want) { t.Errorf("newKeyStore() = %v, want %v", got, tt.want) } - got.Close() } }) } @@ -61,7 +59,6 @@ func Test_keyStore(t *testing.T) { ks, err := newKeyStore(srv.Client(), srv.URL+"/random") assert.FatalError(t, err) - defer ks.Close() ks.RLock() keySet1 := ks.keySet ks.RUnlock() @@ -73,6 +70,7 @@ func Test_keyStore(t *testing.T) { // Wait for rotation time.Sleep(5 * time.Second) + assert.Len(t, 0, ks.Get("foobar")) // force refresh ks.RLock() keySet2 := ks.keySet @@ -105,7 +103,6 @@ func Test_keyStore_noCache(t *testing.T) { ks, err := newKeyStore(srv.Client(), srv.URL+"/no-cache") assert.FatalError(t, err) - defer ks.Close() ks.RLock() keySet1 := ks.keySet ks.RUnlock() @@ -116,20 +113,6 @@ func Test_keyStore_noCache(t *testing.T) { assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID)) assert.Len(t, 0, ks.Get("foobar")) - ks.RLock() - keySet2 := ks.keySet - ks.RUnlock() - if reflect.DeepEqual(keySet1, keySet2) { - t.Error("keyStore did not rotated") - } - - // The keys will rotate on Get. - // So we won't be able to find the cached ones - assert.Len(t, 2, keySet2.Keys) - assert.Len(t, 0, ks.Get(keySet2.Keys[0].KeyID)) - assert.Len(t, 0, ks.Get(keySet2.Keys[1].KeyID)) - assert.Len(t, 0, ks.Get("foobar")) - // Check hits resp, err := srv.Client().Get(srv.URL + "/hits") assert.FatalError(t, err) @@ -147,7 +130,6 @@ func Test_keyStore_Get(t *testing.T) { defer srv.Close() ks, err := newKeyStore(srv.Client(), srv.URL) assert.FatalError(t, err) - defer ks.Close() type args struct { kid string diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 16f9caf4..044971bf 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -485,7 +485,7 @@ func (o *OIDC) AuthorizeSSHRevoke(_ context.Context, token string) error { return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token") } -func getAndDecode(client *http.Client, uri string, v interface{}) error { +func getAndDecode(client HTTPClient, uri string, v interface{}) error { resp, err := client.Get(uri) if err != nil { return errors.Wrapf(err, "failed to connect to %s", uri) diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 86fdd5ec..d57e3582 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -34,6 +34,13 @@ type Interface interface { AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) } +// HTTPClient is the interface implemented by the HTTP clients used by the +// provisioners. +type HTTPClient interface { + Get(string) (*http.Response, error) + Do(*http.Request) (*http.Response, error) +} + // Uninitialized represents a disabled provisioner. Uninitialized provisioners // are created when the Init methods fails. type Uninitialized struct { @@ -251,6 +258,16 @@ type Config struct { // GetIdentityFunc is a function that returns an identity that will be // used by the provisioner to populate certificate attributes. GetIdentityFunc GetIdentityFunc + /* + // GetHttpClientFunc is a function that returns an HTTP client that trusts + // the system cert pool and the CA roots. + GetHttpClientFunc func() *http.Client + // GetWebhookClientFunc ris a function that returns the HTTP client used + // when performing webhook requests, this is an HTTP client that trusts the + // system cert pool, the CA roots and uses the CA certificate as a client + // certificate. + GetWebhookClientFunc func() *http.Client + */ // AuthorizeRenewFunc is a function that returns nil if a given X.509 // certificate can be renewed. AuthorizeRenewFunc AuthorizeRenewFunc @@ -258,12 +275,12 @@ type Config struct { // certificate can be renewed. AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc // WebhookClient is an HTTP client used when performing webhook requests. - WebhookClient *http.Client + WebhookClient HTTPClient // SCEPKeyManager, if defined, is the interface used by SCEP provisioners. SCEPKeyManager SCEPKeyManager // HTTPClient is an HTTP client that trusts the system cert pool and the CA // roots. - HTTPClient *http.Client + HTTPClient HTTPClient // WrapTransport references the function that should wrap any [http.Transport] initialized // down the Config's chain. WrapTransport TransportWrapper diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index b6e8b925..a97ff8e5 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "net/http" "time" "github.com/pkg/errors" @@ -113,14 +112,14 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration { } type challengeValidationController struct { - client *http.Client + client HTTPClient wrapTransport httptransport.Wrapper webhooks []*Webhook } // newChallengeValidationController creates a new challengeValidationController // that performs challenge validation through webhooks. -func newChallengeValidationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController { +func newChallengeValidationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController { scepHooks := []*Webhook{} for _, wh := range webhooks { if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() { @@ -179,14 +178,14 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509. } type notificationController struct { - client *http.Client + client HTTPClient wrapTransport httptransport.Wrapper webhooks []*Webhook } // newNotificationController creates a new notificationController // that performs SCEP notifications through webhooks. -func newNotificationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController { +func newNotificationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController { scepHooks := []*Webhook{} for _, wh := range webhooks { if wh.Kind != linkedca.Webhook_NOTIFYING.String() { diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index dba4a4c8..b78b90b0 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -18,6 +18,7 @@ import ( "github.com/smallstep/linkedca" + "github.com/smallstep/certificates/authority/poolhttp" "github.com/smallstep/certificates/internal/httptransport" "github.com/smallstep/certificates/middleware/requestid" "github.com/smallstep/certificates/templates" @@ -31,7 +32,7 @@ type WebhookSetter interface { } type WebhookController struct { - client *http.Client + client HTTPClient wrapTransport httptransport.Wrapper webhooks []*Webhook certType linkedca.Webhook_CertType @@ -146,7 +147,7 @@ type Webhook struct { // [http.RoundTripper]. type TransportWrapper = httptransport.Wrapper -func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { +func (w *Webhook) DoWithContext(ctx context.Context, client HTTPClient, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) if err != nil { return nil, err @@ -206,11 +207,11 @@ retry: } if w.DisableTLSClientAuth { - transport, ok := client.Transport.(*http.Transport) - if !ok { - transport = httptransport.New() + var transport *http.Transport + if ct, ok := client.(poolhttp.Transporter); ok { + transport = ct.Transport() } else { - transport = transport.Clone() + transport = httptransport.New() } if transport.TLSClientConfig != nil { diff --git a/templates/templates.go b/templates/templates.go index e58f5fbc..a8cd1df2 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "strings" + "sync" "text/template" "github.com/Masterminds/sprig/v3" @@ -29,6 +30,22 @@ const ( Directory TemplateType = "directory" ) +var ( + stepFuncMap template.FuncMap + stepFuncOnce sync.Once +) + +// StepFuncMap returns sprig.TxtFuncMap but removing the "env" and "expandenv" +// functions to avoid any leak of information. +func StepFuncMap() template.FuncMap { + stepFuncOnce.Do(func() { + stepFuncMap = sprig.TxtFuncMap() + delete(stepFuncMap, "env") + delete(stepFuncMap, "expandenv") + }) + return stepFuncMap +} + // Templates is a collection of templates and variables. type Templates struct { SSH *SSHTemplates `json:"ssh,omitempty"` @@ -282,12 +299,3 @@ func mkdir(path string, perm os.FileMode) error { } return nil } - -// StepFuncMap returns sprig.TxtFuncMap but removing the "env" and "expandenv" -// functions to avoid any leak of information. -func StepFuncMap() template.FuncMap { - m := sprig.TxtFuncMap() - delete(m, "env") - delete(m, "expandenv") - return m -}