Transport wrappers (#2103)

* internal/httptransport: implemented Wrapper & NoopWrapper

* added transport wrappers

* addressed review comments
This commit is contained in:
Panagiotis Siatras
2024-12-12 19:51:36 +02:00
committed by GitHub
parent c986962154
commit 809c7023c9
14 changed files with 143 additions and 68 deletions

View File

@@ -33,6 +33,7 @@ import (
"github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/scep"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/nosql"
@@ -48,6 +49,7 @@ type Authority struct {
adminDB admin.DB
templates *templates.Templates
linkedCAToken string
wrapTransport httptransport.Wrapper
webhookClient *http.Client
httpClient *http.Client
@@ -128,10 +130,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) {
}
var a = &Authority{
config: cfg,
certificates: new(sync.Map),
validateSCEP: true,
meter: noopMeter{},
config: cfg,
certificates: new(sync.Map),
validateSCEP: true,
meter: noopMeter{},
wrapTransport: httptransport.NoopWrapper(),
}
// Apply options.
@@ -158,9 +161,10 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) {
// project without the limitations of the config.
func NewEmbedded(opts ...Option) (*Authority, error) {
a := &Authority{
config: &config.Config{},
certificates: new(sync.Map),
meter: noopMeter{},
config: &config.Config{},
certificates: new(sync.Map),
meter: noopMeter{},
wrapTransport: httptransport.NoopWrapper(),
}
// Apply options.
@@ -496,7 +500,7 @@ func (a *Authority) init() error {
clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs))
clientRoots = append(clientRoots, a.rootX509Certs...)
clientRoots = append(clientRoots, a.federatedX509Certs...)
a.httpClient, err = newHTTPClient(clientRoots...)
a.httpClient, err = newHTTPClient(a.wrapTransport, clientRoots...)
if err != nil {
return err
}

View File

@@ -5,30 +5,34 @@ import (
"crypto/x509"
"fmt"
"net/http"
"github.com/smallstep/certificates/internal/httptransport"
)
// newHTTPClient will return an HTTP client that trusts the system cert pool and
// the given roots, but only if the http.DefaultTransport is an *http.Transport.
// If not, it will return the default HTTP client.
func newHTTPClient(roots ...*x509.Certificate) (*http.Client, error) {
if tr, ok := http.DefaultTransport.(*http.Transport); ok {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("error initializing http client: %w", err)
}
for _, crt := range roots {
pool.AddCert(crt)
}
tr = tr.Clone()
tr.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: pool,
}
return &http.Client{
Transport: tr,
}, nil
// the given roots.
func newHTTPClient(wt httptransport.Wrapper, roots ...*x509.Certificate) (*http.Client, error) {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("error initializing http client: %w", err)
}
for _, crt := range roots {
pool.AddCert(crt)
}
return &http.Client{}, nil
tr, ok := http.DefaultTransport.(*http.Transport)
if !ok {
tr = httptransport.New()
} else {
tr = tr.Clone()
}
tr.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: pool,
}
return &http.Client{
Transport: wt(tr),
}, nil
}

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
@@ -113,8 +114,8 @@ func Test_newHTTPClient(t *testing.T) {
}{http.DefaultTransport}
http.DefaultTransport = transport
client, err := newHTTPClient(auth.rootX509Certs...)
client, err := newHTTPClient(httptransport.NoopWrapper(), auth.rootX509Certs...)
assert.NoError(t, err)
assert.Equal(t, &http.Client{}, client)
assert.NotNil(t, client)
})
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/scep"
)
@@ -103,6 +104,22 @@ func WithWebhookClient(c *http.Client) Option {
}
}
// Wrapper wraps the set of functions mapping [http.Transport] references to [http.RoundTripper].
type TransportWrapper = httptransport.Wrapper
// WithTransportWrapper sets the transport wrapper of the authority to the provided one or, in case
// that one is nil, to a noop one.
func WithTransportWrapper(tw httptransport.Wrapper) Option {
if tw == nil {
tw = httptransport.NoopWrapper()
}
return func(a *Authority) error {
a.wrapTransport = tw
return nil
}
}
// WithGetIdentityFunc sets a custom function to retrieve the identity from
// an external resource.
func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {

View File

@@ -9,6 +9,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
@@ -27,6 +28,7 @@ type Controller struct {
webhookClient *http.Client
webhooks []*Webhook
httpClient *http.Client
wrapTransport httptransport.Wrapper
}
// NewController initializes a new provisioner controller.
@@ -50,6 +52,7 @@ func NewController(p Interface, claims *Claims, config Config, options *Options)
webhookClient: config.WebhookClient,
webhooks: options.GetWebhooks(),
httpClient: config.HTTPClient,
wrapTransport: config.WrapTransport,
}, nil
}
@@ -89,16 +92,25 @@ func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificat
}
func (c *Controller) newWebhookController(templateData WebhookSetter, certType linkedca.Webhook_CertType, opts ...webhook.RequestBodyOption) *WebhookController {
wt := c.wrapTransport
if wt == nil {
wt = httptransport.NoopWrapper()
}
client := c.webhookClient
if client == nil {
client = http.DefaultClient
client = &http.Client{
Transport: wt(httptransport.New()),
}
}
return &WebhookController{
TemplateData: templateData,
client: client,
webhooks: c.webhooks,
certType: certType,
options: opts,
TemplateData: templateData,
client: client,
wrapTransport: wt,
webhooks: c.webhooks,
certType: certType,
options: opts,
}
}

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"go.step.sm/crypto/pemutil"
@@ -512,11 +513,18 @@ func Test_newWebhookController(t *testing.T) {
options: opts,
}},
}
for _, tt := range tests {
c := &Controller{}
got := c.newWebhookController(tt.args.templateData, tt.args.certType, tt.args.opts...)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("newWebhookController() = %v, want %v", got, tt.want)
c := Controller{
webhookClient: new(http.Client),
wrapTransport: httptransport.NoopWrapper(),
}
got := c.newWebhookController(tt.args.templateData, tt.args.certType, tt.args.opts...)
assert.Equal(t, tt.args.templateData, got.TemplateData)
assert.Same(t, c.webhookClient, got.client)
assert.Equal(t, c.webhooks, got.webhooks)
assert.Equal(t, tt.args.opts, got.options)
assert.Equal(t, tt.args.certType, got.certType)
}
}

View File

@@ -264,6 +264,9 @@ type Config struct {
// HTTPClient is an HTTP client that trusts the system cert pool and the CA
// roots.
HTTPClient *http.Client
// WrapTransport references the function that should wrap any [http.Transport] initialized
// down the Config's chain.
WrapTransport TransportWrapper
}
type provisioner struct {

View File

@@ -18,6 +18,7 @@ import (
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/internal/httptransport"
"github.com/smallstep/certificates/webhook"
)
@@ -112,13 +113,14 @@ func (s *SCEP) DefaultTLSCertDuration() time.Duration {
}
type challengeValidationController struct {
client *http.Client
webhooks []*Webhook
client *http.Client
wrapTransport httptransport.Wrapper
webhooks []*Webhook
}
// newChallengeValidationController creates a new challengeValidationController
// that performs challenge validation through webhooks.
func newChallengeValidationController(client *http.Client, webhooks []*Webhook) *challengeValidationController {
func newChallengeValidationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *challengeValidationController {
scepHooks := []*Webhook{}
for _, wh := range webhooks {
if wh.Kind != linkedca.Webhook_SCEPCHALLENGE.String() {
@@ -130,8 +132,9 @@ func newChallengeValidationController(client *http.Client, webhooks []*Webhook)
scepHooks = append(scepHooks, wh)
}
return &challengeValidationController{
client: client,
webhooks: scepHooks,
client: client,
wrapTransport: tw,
webhooks: scepHooks,
}
}
@@ -157,7 +160,7 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509.
req.ProvisionerName = provisionerName
req.SCEPChallenge = challenge
req.SCEPTransactionID = transactionID
resp, err := wh.DoWithContext(ctx, c.client, req, nil) // TODO(hs): support templated URL? Requires some refactoring
resp, err := wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil) // TODO(hs): support templated URL? Requires some refactoring
if err != nil {
return nil, fmt.Errorf("failed executing webhook request: %w", err)
}
@@ -176,13 +179,14 @@ func (c *challengeValidationController) Validate(ctx context.Context, csr *x509.
}
type notificationController struct {
client *http.Client
webhooks []*Webhook
client *http.Client
wrapTransport httptransport.Wrapper
webhooks []*Webhook
}
// newNotificationController creates a new notificationController
// that performs SCEP notifications through webhooks.
func newNotificationController(client *http.Client, webhooks []*Webhook) *notificationController {
func newNotificationController(client *http.Client, tw httptransport.Wrapper, webhooks []*Webhook) *notificationController {
scepHooks := []*Webhook{}
for _, wh := range webhooks {
if wh.Kind != linkedca.Webhook_NOTIFYING.String() {
@@ -194,8 +198,9 @@ func newNotificationController(client *http.Client, webhooks []*Webhook) *notifi
scepHooks = append(scepHooks, wh)
}
return &notificationController{
client: client,
webhooks: scepHooks,
client: client,
wrapTransport: tw,
webhooks: scepHooks,
}
}
@@ -207,7 +212,7 @@ func (c *notificationController) Success(ctx context.Context, csr *x509.Certific
}
req.X509Certificate.Raw = cert.Raw // adding the full certificate DER bytes
req.SCEPTransactionID = transactionID
if _, err = wh.DoWithContext(ctx, c.client, req, nil); err != nil {
if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil {
return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err)
}
}
@@ -224,7 +229,7 @@ func (c *notificationController) Failure(ctx context.Context, csr *x509.Certific
req.SCEPTransactionID = transactionID
req.SCEPErrorCode = errorCode
req.SCEPErrorDescription = errorDescription
if _, err = wh.DoWithContext(ctx, c.client, req, nil); err != nil {
if _, err = wh.DoWithContext(ctx, c.client, c.wrapTransport, req, nil); err != nil {
return fmt.Errorf("failed executing webhook request: %w: %w", ErrSCEPNotificationFailed, err)
}
}
@@ -267,12 +272,14 @@ func (s *SCEP) Init(config Config) (err error) {
// Prepare the SCEP challenge validator
s.challengeValidationController = newChallengeValidationController(
config.WebhookClient,
config.WrapTransport,
s.GetOptions().GetWebhooks(),
)
// Prepare the SCEP notification controller
s.notificationController = newNotificationController(
config.WebhookClient,
config.WrapTransport,
s.GetOptions().GetWebhooks(),
)

View File

@@ -201,7 +201,7 @@ func Test_challengeValidationController_Validate(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := newChallengeValidationController(tt.fields.client, tt.fields.webhooks)
c := newChallengeValidationController(tt.fields.client, nil, tt.fields.webhooks)
ctx := context.Background()
got, err := c.Validate(ctx, dummyCSR, tt.args.provisionerName, tt.args.challenge, tt.args.transactionID)
if tt.expErr != nil {

View File

@@ -31,11 +31,12 @@ type WebhookSetter interface {
}
type WebhookController struct {
client *http.Client
webhooks []*Webhook
certType linkedca.Webhook_CertType
options []webhook.RequestBodyOption
TemplateData WebhookSetter
client *http.Client
wrapTransport httptransport.Wrapper
webhooks []*Webhook
certType linkedca.Webhook_CertType
options []webhook.RequestBodyOption
TemplateData WebhookSetter
}
// Enrich fetches data from remote servers and adds returned data to the
@@ -63,7 +64,7 @@ func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBod
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
resp, err := wh.DoWithContext(whCtx, wc.client, wc.wrapTransport, req, wc.TemplateData)
if err != nil {
return err
}
@@ -102,7 +103,7 @@ func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.Request
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
resp, err := wh.DoWithContext(whCtx, wc.client, wc.wrapTransport, req, wc.TemplateData)
if err != nil {
return err
}
@@ -141,7 +142,11 @@ type Webhook struct {
} `json:"-"`
}
func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
// TransportWrapper wraps the set of functions mapping [http.Transport] references to
// [http.RoundTripper].
type TransportWrapper = httptransport.Wrapper
func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, tw TransportWrapper, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL)
if err != nil {
return nil, err
@@ -214,7 +219,7 @@ retry:
}
client = &http.Client{
Transport: transport,
Transport: tw(transport),
}
}
resp, err := client.Do(req)

View File

@@ -627,7 +627,7 @@ func TestWebhook_Do(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg)
got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, httptransport.NoopWrapper(), reqBody, tc.dataArg)
if tc.expectErr != nil {
assert.Equal(t, tc.expectErr.Error(), err.Error())
return
@@ -663,14 +663,14 @@ func TestWebhook_Do(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
_, err = wh.DoWithContext(ctx, client, reqBody, nil)
_, err = wh.DoWithContext(ctx, client, httptransport.NoopWrapper(), reqBody, nil)
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
wh.DisableTLSClientAuth = true
_, err = wh.DoWithContext(ctx, client, reqBody, nil)
_, err = wh.DoWithContext(ctx, client, httptransport.NoopWrapper(), reqBody, nil)
require.Error(t, err)
})
}

View File

@@ -202,6 +202,7 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.
AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
WebhookClient: a.webhookClient,
HTTPClient: a.httpClient,
WrapTransport: a.wrapTransport,
SCEPKeyManager: a.scepKeyManager,
}, nil
}

View File

@@ -198,7 +198,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
}
webhookTransport := httptransport.New()
opts = append(opts, authority.WithWebhookClient(&http.Client{Transport: webhookTransport}))
opts = append(opts,
authority.WithWebhookClient(&http.Client{Transport: webhookTransport}),
)
auth, err := authority.New(cfg, opts...)
if err != nil {

View File

@@ -8,6 +8,17 @@ import (
"time"
)
// Wrapper wraps the set of functions mapping [http.Transport] references to [http.RoundTripper].
type Wrapper func(*http.Transport) http.RoundTripper
// NoopWrapper returns a [Wrapper] that simply casts its provided [http.Transport] to an
// [http.RoundTripper].
func NoopWrapper() Wrapper {
return func(t *http.Transport) http.RoundTripper {
return t
}
}
// New returns a reference to an [http.Transport] that's initialized just like the
// [http.DefaultTransport] is by the standard library.
func New() *http.Transport {