mirror of
https://github.com/outbackdingo/certificates.git
synced 2026-01-27 02:18:27 +00:00
Memory improvements
This commit replaces the client in provisioners and webhooks with an interface. Then it implements the interface using the new poolhttp package. This package implements the HTTPClient interface but it is backed by a sync.Pool, this improves memory, allowing the GC to clean more memory. It also removes the timer in the keystore to avoid having extra goroutines if a provisioner goes away. This commit avoids creating the templates func multiple times, reducing some memory in the heap.
This commit is contained in:
@@ -8,8 +8,8 @@ import (
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -50,8 +50,8 @@ type Authority struct {
|
||||
templates *templates.Templates
|
||||
linkedCAToken string
|
||||
wrapTransport httptransport.Wrapper
|
||||
webhookClient *http.Client
|
||||
httpClient *http.Client
|
||||
webhookClient provisioner.HTTPClient
|
||||
httpClient provisioner.HTTPClient
|
||||
|
||||
// X509 CA
|
||||
password []byte
|
||||
@@ -147,6 +147,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) {
|
||||
a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter)
|
||||
}
|
||||
|
||||
// Initialize system cert pool
|
||||
if err := initializeSystemCertPool(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize the system cert pool: %w", err)
|
||||
}
|
||||
|
||||
if !a.skipInit {
|
||||
// Initialize authority from options or configuration.
|
||||
if err := a.init(); err != nil {
|
||||
@@ -177,6 +182,11 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
|
||||
a.keyManager = newInstrumentedKeyManager(a.keyManager, a.meter)
|
||||
}
|
||||
|
||||
// Initialize system cert pool
|
||||
if err := initializeSystemCertPool(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize the system cert pool: %w", err)
|
||||
}
|
||||
|
||||
// Validate required options
|
||||
switch {
|
||||
case a.config == nil:
|
||||
@@ -500,7 +510,7 @@ func (a *Authority) init() error {
|
||||
clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs))
|
||||
clientRoots = append(clientRoots, a.rootX509Certs...)
|
||||
clientRoots = append(clientRoots, a.federatedX509Certs...)
|
||||
a.httpClient, err = newHTTPClient(a.wrapTransport, clientRoots...)
|
||||
a.httpClient = newHTTPClient(a.wrapTransport, clientRoots...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -25,6 +26,16 @@ import (
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if err := initializeSystemCertPool(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "failed to initialize system cert pool:", err)
|
||||
fmt.Fprintln(os.Stderr, "See https://pkg.go.dev/github.com/tjfoc/gmsm/x509#SystemCertPool\n", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func testAuthority(t *testing.T, opts ...Option) *Authority {
|
||||
maxjwk, err := jose.ReadKey("testdata/secrets/max_pub.jwk")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
@@ -3,36 +3,54 @@ package authority
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/smallstep/certificates/authority/poolhttp"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/internal/httptransport"
|
||||
)
|
||||
|
||||
// systemCertPool holds a copy of the system cert pool. This cert pool must be
|
||||
// initialized when the authority is created and we should always get a clone of
|
||||
// this pool.
|
||||
var systemCertPool atomic.Pointer[x509.CertPool]
|
||||
|
||||
// initializeSystemCertPool initializes the system cert pool if necessary.
|
||||
func initializeSystemCertPool() error {
|
||||
if systemCertPool.Load() == nil {
|
||||
pool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
systemCertPool.Store(pool)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newHTTPClient will return an HTTP client that trusts the system cert pool and
|
||||
// the given roots.
|
||||
func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) (*http.Client, error) {
|
||||
pool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error initializing http client: %w", err)
|
||||
}
|
||||
for _, crt := range roots {
|
||||
pool.AddCert(crt)
|
||||
}
|
||||
func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) provisioner.HTTPClient {
|
||||
return poolhttp.New(func() *http.Client {
|
||||
pool := systemCertPool.Load().Clone()
|
||||
for _, crt := range roots {
|
||||
pool.AddCert(crt)
|
||||
}
|
||||
|
||||
tr, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
tr = httptransport.New()
|
||||
} else {
|
||||
tr = tr.Clone()
|
||||
}
|
||||
tr, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
tr = httptransport.New()
|
||||
} else {
|
||||
tr = tr.Clone()
|
||||
}
|
||||
|
||||
tr.TLSClientConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
RootCAs: pool,
|
||||
}
|
||||
tr.TLSClientConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
RootCAs: pool,
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: wt(tr),
|
||||
}, nil
|
||||
rr := wt(tr)
|
||||
|
||||
return &http.Client{Transport: rr}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -114,8 +114,7 @@ func Test_newHTTPClient(t *testing.T) {
|
||||
}{http.DefaultTransport}
|
||||
http.DefaultTransport = transport
|
||||
|
||||
client, err := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...)
|
||||
assert.NoError(t, err)
|
||||
client := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...)
|
||||
assert.NotNil(t, client)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
@@ -97,7 +96,7 @@ func WithQuietInit() Option {
|
||||
}
|
||||
|
||||
// WithWebhookClient sets the http.Client to be used for outbound requests.
|
||||
func WithWebhookClient(c *http.Client) Option {
|
||||
func WithWebhookClient(c provisioner.HTTPClient) Option {
|
||||
return func(a *Authority) error {
|
||||
a.webhookClient = c
|
||||
return nil
|
||||
|
||||
80
authority/poolhttp/poolhttp.go
Normal file
80
authority/poolhttp/poolhttp.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package poolhttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/smallstep/certificates/internal/httptransport"
|
||||
)
|
||||
|
||||
// Transporter is the implemented by custom HTTP clients with a method that
|
||||
// returns an [*http.Transport].
|
||||
type Transporter interface {
|
||||
Transport() *http.Transport
|
||||
}
|
||||
|
||||
// Client returns an HTTP client that uses a [sync.Pool] to create new HTTP
|
||||
// client. It implements the [provisioner.HTTPClient] and [Transporter]
|
||||
// interfaces. This is the HTTP client used by the provisioners.
|
||||
type Client struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
// New creates a new poolhttp [Client], the [sync.Pool] will initialize a new
|
||||
// [*http.Client] with the given function.
|
||||
func New(fn func() *http.Client) *Client {
|
||||
return &Client{
|
||||
pool: sync.Pool{
|
||||
New: func() any { return fn() },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetNew replaces the inner pool with a new [sync.Pool] with the given New
|
||||
// function. This method should not be used concurrently with other methods.
|
||||
func (c *Client) SetNew(fn func() *http.Client) {
|
||||
c.pool = sync.Pool{
|
||||
New: func() any { return fn() },
|
||||
}
|
||||
}
|
||||
|
||||
// Get issues a GET to the specified URL. If the response is one of the
|
||||
// following redirect codes, Get follows the redirect after calling the
|
||||
// [Client.CheckRedirect] function:
|
||||
func (c *Client) Get(u string) (resp *http.Response, err error) {
|
||||
if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil {
|
||||
resp, err = hc.Get(u)
|
||||
c.pool.Put(hc)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Get(u)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Do sends an HTTP request and returns an HTTP response, following policy (such
|
||||
// as redirects, cookies, auth) as configured on the client.
|
||||
func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
|
||||
if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil {
|
||||
resp, err = hc.Do(req)
|
||||
c.pool.Put(hc)
|
||||
} else {
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Transport() returns a clone of the http.Client Transport or returns the
|
||||
// default transport.
|
||||
func (c *Client) Transport() *http.Transport {
|
||||
if hc, ok := c.pool.Get().(*http.Client); ok && hc != nil {
|
||||
tr, ok := hc.Transport.(*http.Transport)
|
||||
c.pool.Put(hc)
|
||||
if ok {
|
||||
return tr.Clone()
|
||||
}
|
||||
}
|
||||
|
||||
return httptransport.New()
|
||||
}
|
||||
124
authority/poolhttp/poolhttp_test.go
Normal file
124
authority/poolhttp/poolhttp_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package poolhttp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func requireBody(t *testing.T, want string, r io.ReadCloser) {
|
||||
t.Helper()
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.Close())
|
||||
})
|
||||
|
||||
b, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, string(b))
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello World")
|
||||
}))
|
||||
t.Cleanup(httpSrv.Close)
|
||||
tlsSrv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello World")
|
||||
}))
|
||||
t.Cleanup(tlsSrv.Close)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
client *Client
|
||||
srv *httptest.Server
|
||||
}{
|
||||
{"http", New(func() *http.Client { return httpSrv.Client() }), httpSrv},
|
||||
{"tls", New(func() *http.Client { return tlsSrv.Client() }), tlsSrv},
|
||||
{"nil", New(func() *http.Client { return nil }), httpSrv},
|
||||
{"empty", &Client{}, httpSrv},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := tc.client.Get(tc.srv.URL)
|
||||
require.NoError(t, err)
|
||||
requireBody(t, "Hello World\n", resp.Body)
|
||||
|
||||
req, err := http.NewRequest("GET", tc.srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err = tc.client.Do(req)
|
||||
require.NoError(t, err)
|
||||
requireBody(t, "Hello World\n", resp.Body)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: tc.client.Transport(),
|
||||
}
|
||||
resp, err = client.Get(tc.srv.URL)
|
||||
require.NoError(t, err)
|
||||
requireBody(t, "Hello World\n", resp.Body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_SetNew(t *testing.T) {
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello World")
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
c := New(func() *http.Client {
|
||||
return srv.Client()
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
client *http.Client
|
||||
assertion assert.ErrorAssertionFunc
|
||||
}{
|
||||
{"ok", srv.Client(), assert.NoError},
|
||||
{"fail", http.DefaultClient, assert.Error},
|
||||
{"ok again", srv.Client(), assert.NoError},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c.SetNew(func() *http.Client {
|
||||
return tc.client
|
||||
})
|
||||
_, err := c.Get(srv.URL)
|
||||
tc.assertion(t, err)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_parallel(t *testing.T) {
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Hello World")
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
c := New(func() *http.Client {
|
||||
return srv.Client()
|
||||
})
|
||||
req, err := http.NewRequest("GET", srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := range 10 {
|
||||
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp, err := c.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
requireBody(t, "Hello World\n", resp.Body)
|
||||
|
||||
resp, err = c.Do(req)
|
||||
require.NoError(t, err)
|
||||
requireBody(t, "Hello World\n", resp.Body)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -28,9 +28,9 @@ type Controller struct {
|
||||
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||
policy *policyEngine
|
||||
webhookClient *http.Client
|
||||
httpClient HTTPClient
|
||||
webhookClient HTTPClient
|
||||
webhooks []*Webhook
|
||||
httpClient *http.Client
|
||||
wrapTransport httptransport.Wrapper
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options)
|
||||
if wt == nil {
|
||||
wt = httptransport.NoopWrapper()
|
||||
}
|
||||
|
||||
return &Controller{
|
||||
Interface: p,
|
||||
Audiences: &config.Audiences,
|
||||
@@ -65,7 +66,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options)
|
||||
|
||||
// GetHTTPClient returns the configured HTTP client or the default one if none
|
||||
// is configured.
|
||||
func (c *Controller) GetHTTPClient() *http.Client {
|
||||
func (c *Controller) GetHTTPClient() HTTPClient {
|
||||
if c.httpClient != nil {
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package provisioner
|
||||
import (
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -22,33 +21,26 @@ var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)
|
||||
|
||||
type keyStore struct {
|
||||
sync.RWMutex
|
||||
client *http.Client
|
||||
client HTTPClient
|
||||
uri string
|
||||
keySet jose.JSONWebKeySet
|
||||
timer *time.Timer
|
||||
expiry time.Time
|
||||
jitter time.Duration
|
||||
}
|
||||
|
||||
func newKeyStore(client *http.Client, uri string) (*keyStore, error) {
|
||||
func newKeyStore(client HTTPClient, uri string) (*keyStore, error) {
|
||||
keys, age, err := getKeysFromJWKsURI(client, uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ks := &keyStore{
|
||||
jitter := getCacheJitter(age)
|
||||
return &keyStore{
|
||||
client: client,
|
||||
uri: uri,
|
||||
keySet: keys,
|
||||
expiry: getExpirationTime(age),
|
||||
jitter: getCacheJitter(age),
|
||||
}
|
||||
next := ks.nextReloadDuration(age)
|
||||
ks.timer = time.AfterFunc(next, ks.reload)
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
func (ks *keyStore) Close() {
|
||||
ks.timer.Stop()
|
||||
expiry: getExpirationTime(age, jitter),
|
||||
jitter: jitter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
||||
@@ -65,34 +57,16 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
||||
}
|
||||
|
||||
func (ks *keyStore) reload() {
|
||||
var next time.Duration
|
||||
keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri)
|
||||
if err != nil {
|
||||
next = ks.nextReloadDuration(ks.jitter / 2)
|
||||
} else {
|
||||
if keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri); err == nil {
|
||||
ks.Lock()
|
||||
ks.keySet = keys
|
||||
ks.expiry = getExpirationTime(age)
|
||||
ks.jitter = getCacheJitter(age)
|
||||
next = ks.nextReloadDuration(age)
|
||||
ks.expiry = getExpirationTime(age, ks.jitter)
|
||||
ks.Unlock()
|
||||
}
|
||||
|
||||
ks.Lock()
|
||||
ks.timer.Reset(next)
|
||||
ks.Unlock()
|
||||
}
|
||||
|
||||
// nextReloadDuration would return the duration for the next rotation. If age is
|
||||
// 0 it will randomly rotate between 0-12 hours, but every time we call to Get
|
||||
// it will automatically rotate.
|
||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||
n := rand.Int63n(int64(ks.jitter)) //nolint:gosec // not used for cryptographic security
|
||||
age -= time.Duration(n)
|
||||
return abs(age)
|
||||
}
|
||||
|
||||
func getKeysFromJWKsURI(client *http.Client, uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||
func getKeysFromJWKsURI(client HTTPClient, uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||
var keys jose.JSONWebKeySet
|
||||
resp, err := client.Get(uri)
|
||||
if err != nil {
|
||||
@@ -136,8 +110,12 @@ func getCacheJitter(age time.Duration) time.Duration {
|
||||
}
|
||||
}
|
||||
|
||||
func getExpirationTime(age time.Duration) time.Time {
|
||||
return time.Now().Truncate(time.Second).Add(age)
|
||||
func getExpirationTime(age, jitter time.Duration) time.Time {
|
||||
if age > 0 {
|
||||
n := rand.Int63n(int64(jitter)) //nolint:gosec // not used for cryptographic security
|
||||
age -= time.Duration(n)
|
||||
}
|
||||
return time.Now().Truncate(time.Second).Add(abs(age))
|
||||
}
|
||||
|
||||
// abs returns the absolute value of n.
|
||||
|
||||
@@ -22,7 +22,6 @@ func Test_newKeyStore(t *testing.T) {
|
||||
|
||||
ks, err := newKeyStore(srv.Client(), srv.URL)
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
|
||||
type args struct {
|
||||
client *http.Client
|
||||
@@ -49,7 +48,6 @@ func Test_newKeyStore(t *testing.T) {
|
||||
if !reflect.DeepEqual(got.keySet, tt.want) {
|
||||
t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -61,7 +59,6 @@ func Test_keyStore(t *testing.T) {
|
||||
|
||||
ks, err := newKeyStore(srv.Client(), srv.URL+"/random")
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
ks.RLock()
|
||||
keySet1 := ks.keySet
|
||||
ks.RUnlock()
|
||||
@@ -73,6 +70,7 @@ func Test_keyStore(t *testing.T) {
|
||||
|
||||
// Wait for rotation
|
||||
time.Sleep(5 * time.Second)
|
||||
assert.Len(t, 0, ks.Get("foobar")) // force refresh
|
||||
|
||||
ks.RLock()
|
||||
keySet2 := ks.keySet
|
||||
@@ -105,7 +103,6 @@ func Test_keyStore_noCache(t *testing.T) {
|
||||
|
||||
ks, err := newKeyStore(srv.Client(), srv.URL+"/no-cache")
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
ks.RLock()
|
||||
keySet1 := ks.keySet
|
||||
ks.RUnlock()
|
||||
@@ -116,20 +113,6 @@ func Test_keyStore_noCache(t *testing.T) {
|
||||
assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
ks.RLock()
|
||||
keySet2 := ks.keySet
|
||||
ks.RUnlock()
|
||||
if reflect.DeepEqual(keySet1, keySet2) {
|
||||
t.Error("keyStore did not rotated")
|
||||
}
|
||||
|
||||
// The keys will rotate on Get.
|
||||
// So we won't be able to find the cached ones
|
||||
assert.Len(t, 2, keySet2.Keys)
|
||||
assert.Len(t, 0, ks.Get(keySet2.Keys[0].KeyID))
|
||||
assert.Len(t, 0, ks.Get(keySet2.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
// Check hits
|
||||
resp, err := srv.Client().Get(srv.URL + "/hits")
|
||||
assert.FatalError(t, err)
|
||||
@@ -147,7 +130,6 @@ func Test_keyStore_Get(t *testing.T) {
|
||||
defer srv.Close()
|
||||
ks, err := newKeyStore(srv.Client(), srv.URL)
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
|
||||
type args struct {
|
||||
kid string
|
||||
|
||||
@@ -485,7 +485,7 @@ func (o *OIDC) AuthorizeSSHRevoke(_ context.Context, token string) error {
|
||||
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
|
||||
}
|
||||
|
||||
func getAndDecode(client *http.Client, uri string, v interface{}) error {
|
||||
func getAndDecode(client HTTPClient, uri string, v interface{}) error {
|
||||
resp, err := client.Get(uri)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to connect to %s", uri)
|
||||
|
||||
@@ -34,6 +34,13 @@ type Interface interface {
|
||||
AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error)
|
||||
}
|
||||
|
||||
// HTTPClient is the interface implemented by the HTTP clients used by the
|
||||
// provisioners.
|
||||
type HTTPClient interface {
|
||||
Get(string) (*http.Response, error)
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// Uninitialized represents a disabled provisioner. Uninitialized provisioners
|
||||
// are created when the Init methods fails.
|
||||
type Uninitialized struct {
|
||||
@@ -251,6 +258,16 @@ type Config struct {
|
||||
// GetIdentityFunc is a function that returns an identity that will be
|
||||
// used by the provisioner to populate certificate attributes.
|
||||
GetIdentityFunc GetIdentityFunc
|
||||
/*
|
||||
// GetHttpClientFunc is a function that returns an HTTP client that trusts
|
||||
// the system cert pool and the CA roots.
|
||||
GetHttpClientFunc func() *http.Client
|
||||
// GetWebhookClientFunc ris a function that returns the HTTP client used
|
||||
// when performing webhook requests, this is an HTTP client that trusts the
|
||||
// system cert pool, the CA roots and uses the CA certificate as a client
|
||||
// certificate.
|
||||
GetWebhookClientFunc func() *http.Client
|
||||
*/
|
||||
// AuthorizeRenewFunc is a function that returns nil if a given X.509
|
||||
// certificate can be renewed.
|
||||
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||
@@ -258,12 +275,12 @@ type Config struct {
|
||||
// certificate can be renewed.
|
||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||
// WebhookClient is an HTTP client used when performing webhook requests.
|
||||
WebhookClient *http.Client
|
||||
WebhookClient HTTPClient
|
||||
// SCEPKeyManager, if defined, is the interface used by SCEP provisioners.
|
||||
SCEPKeyManager SCEPKeyManager
|
||||
// HTTPClient is an HTTP client that trusts the system cert pool and the CA
|
||||
// roots.
|
||||
HTTPClient *http.Client
|
||||
HTTPClient HTTPClient
|
||||
// WrapTransport references the function that should wrap any [http.Transport] initialized
|
||||
// down the Config's chain.
|
||||
WrapTransport TransportWrapper
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -113,14 +112,14 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration {
|
||||
}
|
||||
|
||||
type challengeValidationController struct {
|
||||
client *http.Client
|
||||
client HTTPClient
|
||||
wrapTransport httptransport.Wrapper
|
||||
webhooks []*Webhook
|
||||
}
|
||||
|
||||
// newChallengeValidationController creates a new challengeValidationController
|
||||
// that performs challenge validation through webhooks.
|
||||
func newChallengeValidationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController {
|
||||
func newChallengeValidationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController {
|
||||
scepHooks := []*Webhook{}
|
||||
for _, wh := range webhooks {
|
||||
if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() {
|
||||
@@ -179,14 +178,14 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509.
|
||||
}
|
||||
|
||||
type notificationController struct {
|
||||
client *http.Client
|
||||
client HTTPClient
|
||||
wrapTransport httptransport.Wrapper
|
||||
webhooks []*Webhook
|
||||
}
|
||||
|
||||
// newNotificationController creates a new notificationController
|
||||
// that performs SCEP notifications through webhooks.
|
||||
func newNotificationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController {
|
||||
func newNotificationController(client HTTPClient, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController {
|
||||
scepHooks := []*Webhook{}
|
||||
for _, wh := range webhooks {
|
||||
if wh.Kind != linkedca.Webhook_NOTIFYING.String() {
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/smallstep/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/authority/poolhttp"
|
||||
"github.com/smallstep/certificates/internal/httptransport"
|
||||
"github.com/smallstep/certificates/middleware/requestid"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
@@ -31,7 +32,7 @@ type WebhookSetter interface {
|
||||
}
|
||||
|
||||
type WebhookController struct {
|
||||
client *http.Client
|
||||
client HTTPClient
|
||||
wrapTransport httptransport.Wrapper
|
||||
webhooks []*Webhook
|
||||
certType linkedca.Webhook_CertType
|
||||
@@ -146,7 +147,7 @@ type Webhook struct {
|
||||
// [http.RoundTripper].
|
||||
type TransportWrapper = httptransport.Wrapper
|
||||
|
||||
func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
|
||||
func (w *Webhook) DoWithContext(ctx context.Context, client HTTPClient, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
|
||||
tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -206,11 +207,11 @@ retry:
|
||||
}
|
||||
|
||||
if w.DisableTLSClientAuth {
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
transport = httptransport.New()
|
||||
var transport *http.Transport
|
||||
if ct, ok := client.(poolhttp.Transporter); ok {
|
||||
transport = ct.Transport()
|
||||
} else {
|
||||
transport = transport.Clone()
|
||||
transport = httptransport.New()
|
||||
}
|
||||
|
||||
if transport.TLSClientConfig != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
@@ -29,6 +30,22 @@ const (
|
||||
Directory TemplateType = "directory"
|
||||
)
|
||||
|
||||
var (
|
||||
stepFuncMap template.FuncMap
|
||||
stepFuncOnce sync.Once
|
||||
)
|
||||
|
||||
// StepFuncMap returns sprig.TxtFuncMap but removing the "env" and "expandenv"
|
||||
// functions to avoid any leak of information.
|
||||
func StepFuncMap() template.FuncMap {
|
||||
stepFuncOnce.Do(func() {
|
||||
stepFuncMap = sprig.TxtFuncMap()
|
||||
delete(stepFuncMap, "env")
|
||||
delete(stepFuncMap, "expandenv")
|
||||
})
|
||||
return stepFuncMap
|
||||
}
|
||||
|
||||
// Templates is a collection of templates and variables.
|
||||
type Templates struct {
|
||||
SSH *SSHTemplates `json:"ssh,omitempty"`
|
||||
@@ -282,12 +299,3 @@ func mkdir(path string, perm os.FileMode) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// StepFuncMap returns sprig.TxtFuncMap but removing the "env" and "expandenv"
|
||||
// functions to avoid any leak of information.
|
||||
func StepFuncMap() template.FuncMap {
|
||||
m := sprig.TxtFuncMap()
|
||||
delete(m, "env")
|
||||
delete(m, "expandenv")
|
||||
return m
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user