Merge pull request #1940 from smallstep/mariano/self-trust

Allow to use private  IdPs with the OIDC provisioner
This commit is contained in:
Herman Slatman
2024-08-13 09:59:56 +02:00
committed by GitHub
14 changed files with 276 additions and 44 deletions

View File

@@ -49,6 +49,7 @@ type Authority struct {
templates *templates.Templates
linkedCAToken string
webhookClient *http.Client
httpClient *http.Client
// X509 CA
password []byte
@@ -491,6 +492,15 @@ func (a *Authority) init() error {
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
}
// Initialize HTTPClient with all root certs
clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs))
clientRoots = append(clientRoots, a.rootX509Certs...)
clientRoots = append(clientRoots, a.federatedX509Certs...)
a.httpClient, err = newHTTPClient(clientRoots...)
if err != nil {
return err
}
// Decrypt and load SSH keys
var tmplVars templates.Step
if a.config.SSH != nil {

34
authority/http_client.go Normal file
View File

@@ -0,0 +1,34 @@
package authority
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
)
// newHTTPClient returns an HTTP client that trusts the system cert pool and the
// given roots.
func newHTTPClient(roots ...*x509.Certificate) (*http.Client, error) {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("error initializing http client: %w", err)
}
for _, crt := range roots {
pool.AddCert(crt)
}
tr, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("error initializing http client: type is not *http.Transport")
}
tr = tr.Clone()
tr.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: pool,
}
return &http.Client{
Transport: tr,
}, nil
}

View File

@@ -0,0 +1,105 @@
package authority
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
)
func mustCertificate(t *testing.T, a *Authority, csr *x509.CertificateRequest) []*x509.Certificate {
t.Helper()
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
now := time.Now()
signOpts := provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(now),
NotAfter: provisioner.NewTimeDuration(now.Add(5 * time.Minute)),
Backdate: 1 * time.Minute,
}
sans := []string{}
sans = append(sans, csr.DNSNames...)
sans = append(sans, csr.EmailAddresses...)
for _, s := range csr.IPAddresses {
sans = append(sans, s.String())
}
for _, s := range csr.URIs {
sans = append(sans, s.String())
}
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
require.NoError(t, err)
token, err := generateToken(csr.Subject.CommonName, "step-cli", testAudiences.Sign[0], sans, now, key)
require.NoError(t, err)
extraOpts, err := a.Authorize(ctx, token)
require.NoError(t, err)
chain, err := a.SignWithContext(ctx, csr, signOpts, extraOpts...)
require.NoError(t, err)
return chain
}
func Test_newHTTPClient(t *testing.T) {
signer, err := keyutil.GenerateDefaultSigner()
require.NoError(t, err)
csr, err := x509util.CreateCertificateRequest("test", []string{"localhost", "127.0.0.1", "[::1]"}, signer)
require.NoError(t, err)
auth := testAuthority(t)
chain := mustCertificate(t, auth, csr)
t.Run("SystemCertPool", func(t *testing.T) {
resp, err := auth.httpClient.Get("https://smallstep.com")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.NotEmpty(t, b)
assert.NoError(t, resp.Body.Close())
})
t.Run("LocalCertPool", func(t *testing.T) {
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "ok")
}))
srv.TLS = &tls.Config{
Certificates: []tls.Certificate{
{Certificate: [][]byte{chain[0].Raw, chain[1].Raw}, PrivateKey: signer, Leaf: chain[0]},
},
}
srv.StartTLS()
defer srv.Close()
resp, err := auth.httpClient.Get(srv.URL)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, []byte("ok"), b)
assert.NoError(t, resp.Body.Close())
t.Run("DefaultClient", func(t *testing.T) {
client := &http.Client{}
_, err := client.Get(srv.URL)
assert.Error(t, err)
})
})
}

View File

@@ -251,14 +251,14 @@ func (p *Azure) Init(config Config) (err error) {
p.assertConfig()
// Decode and validate openid-configuration endpoint
if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
if err = getAndDecode(http.DefaultClient, p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
return
}
if err := p.oidcConfig.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
}
// Get JWK key set
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
if p.keyStore, err = newKeyStore(http.DefaultClient, p.oidcConfig.JWKSetURI); err != nil {
return
}

View File

@@ -26,6 +26,7 @@ type Controller struct {
policy *policyEngine
webhookClient *http.Client
webhooks []*Webhook
httpClient *http.Client
}
// NewController initializes a new provisioner controller.
@@ -48,9 +49,19 @@ func NewController(p Interface, claims *Claims, config Config, options *Options)
policy: policy,
webhookClient: config.WebhookClient,
webhooks: options.GetWebhooks(),
httpClient: config.HTTPClient,
}, nil
}
// GetHTTPClient returns the configured HTTP client or the default one if none
// is configured.
func (c *Controller) GetHTTPClient() *http.Client {
if c.httpClient != nil {
return c.httpClient
}
return &http.Client{}
}
// GetIdentity returns the identity for a given email.
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
if c.IdentityFunc != nil {

View File

@@ -9,13 +9,13 @@ import (
"testing"
"time"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/webhook"
)
var trueValue = true
@@ -79,12 +79,14 @@ func TestNewController(t *testing.T) {
wantErr bool
}{
{"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
Claims: globalProvisionerClaims,
Audiences: testAudiences,
HTTPClient: &http.Client{},
}, nil}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
httpClient: &http.Client{},
}, false},
{"ok with claims", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
@@ -145,6 +147,30 @@ func TestNewController(t *testing.T) {
}
}
func TestController_GetHTTPClient(t *testing.T) {
srv := generateTLSJWKServer(2)
defer srv.Close()
type fields struct {
httpClient *http.Client
}
tests := []struct {
name string
fields fields
want *http.Client
}{
{"ok custom", fields{srv.Client()}, srv.Client()},
{"ok default", fields{http.DefaultClient}, http.DefaultClient},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
httpClient: tt.fields.httpClient,
}
assert.Equal(t, tt.want, c.GetHTTPClient())
})
}
}
func TestController_GetIdentity(t *testing.T) {
ctx := context.Background()
type fields struct {

View File

@@ -228,7 +228,7 @@ func (p *GCP) Init(config Config) (err error) {
p.assertConfig()
// Initialize key store
if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
if p.keyStore, err = newKeyStore(http.DefaultClient, p.config.CertsURL); err != nil {
return
}

View File

@@ -22,6 +22,7 @@ var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)
type keyStore struct {
sync.RWMutex
client *http.Client
uri string
keySet jose.JSONWebKeySet
timer *time.Timer
@@ -29,12 +30,13 @@ type keyStore struct {
jitter time.Duration
}
func newKeyStore(uri string) (*keyStore, error) {
keys, age, err := getKeysFromJWKsURI(uri)
func newKeyStore(client *http.Client, uri string) (*keyStore, error) {
keys, age, err := getKeysFromJWKsURI(client, uri)
if err != nil {
return nil, err
}
ks := &keyStore{
client: client,
uri: uri,
keySet: keys,
expiry: getExpirationTime(age),
@@ -64,7 +66,7 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
func (ks *keyStore) reload() {
var next time.Duration
keys, age, err := getKeysFromJWKsURI(ks.uri)
keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri)
if err != nil {
next = ks.nextReloadDuration(ks.jitter / 2)
} else {
@@ -90,9 +92,9 @@ func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
return abs(age)
}
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
func getKeysFromJWKsURI(client *http.Client, uri string) (jose.JSONWebKeySet, time.Duration, error) {
var keys jose.JSONWebKeySet
resp, err := http.Get(uri) //nolint:gosec // openid-configuration jwks_uri
resp, err := client.Get(uri)
if err != nil {
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
}

View File

@@ -3,6 +3,8 @@ package provisioner
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
@@ -12,14 +14,19 @@ import (
)
func Test_newKeyStore(t *testing.T) {
srv := generateJWKServer(2)
srv := generateTLSJWKServer(2)
srv.Close()
srv = httptest.NewTLSServer(srv.Config.Handler)
defer srv.Close()
ks, err := newKeyStore(srv.URL)
ks, err := newKeyStore(srv.Client(), srv.URL)
assert.FatalError(t, err)
defer ks.Close()
type args struct {
uri string
client *http.Client
uri string
}
tests := []struct {
name string
@@ -27,12 +34,13 @@ func Test_newKeyStore(t *testing.T) {
want jose.JSONWebKeySet
wantErr bool
}{
{"ok", args{srv.URL}, ks.keySet, false},
{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
{"ok", args{srv.Client(), srv.URL}, ks.keySet, false},
{"fail", args{srv.Client(), srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
{"fail client", args{http.DefaultClient, srv.URL}, jose.JSONWebKeySet{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newKeyStore(tt.args.uri)
got, err := newKeyStore(tt.args.client, tt.args.uri)
if (err != nil) != tt.wantErr {
t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -51,7 +59,7 @@ func Test_keyStore(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL + "/random")
ks, err := newKeyStore(srv.Client(), srv.URL+"/random")
assert.FatalError(t, err)
defer ks.Close()
ks.RLock()
@@ -95,7 +103,7 @@ func Test_keyStore_noCache(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL + "/no-cache")
ks, err := newKeyStore(srv.Client(), srv.URL+"/no-cache")
assert.FatalError(t, err)
defer ks.Close()
ks.RLock()
@@ -137,7 +145,7 @@ func Test_keyStore_noCache(t *testing.T) {
func Test_keyStore_Get(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL)
ks, err := newKeyStore(srv.Client(), srv.URL)
assert.FatalError(t, err)
defer ks.Close()

View File

@@ -184,23 +184,29 @@ func (o *OIDC) Init(config Config) (err error) {
if !strings.Contains(u.Path, "/.well-known/openid-configuration") {
u.Path = path.Join(u.Path, "/.well-known/openid-configuration")
}
if err := getAndDecode(u.String(), &o.configuration); err != nil {
// Initialize the common provisioner controller
o.ctl, err = NewController(o, o.Claims, config, o.Options)
if err != nil {
return err
}
// Decode and validate openid-configuration
httpClient := o.ctl.GetHTTPClient()
if err := getAndDecode(httpClient, u.String(), &o.configuration); err != nil {
return err
}
if err := o.configuration.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
}
// Replace {tenantid} with the configured one
if o.TenantID != "" {
o.configuration.Issuer = strings.ReplaceAll(o.configuration.Issuer, "{tenantid}", o.TenantID)
}
// Get JWK key set
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
if err != nil {
return err
}
o.ctl, err = NewController(o, o.Claims, config, o.Options)
// Get JWK key set
o.keyStore, err = newKeyStore(httpClient, o.configuration.JWKSetURI)
return
}
@@ -479,8 +485,8 @@ func (o *OIDC) AuthorizeSSHRevoke(_ context.Context, token string) error {
return errs.Unauthorized("oidc.AuthorizeSSHRevoke; cannot revoke with non-admin oidc token")
}
func getAndDecode(uri string, v interface{}) error {
resp, err := http.Get(uri) //nolint:gosec // openid-configuration uri
func getAndDecode(client *http.Client, uri string, v interface{}) error {
resp, err := client.Get(uri)
if err != nil {
return errors.Wrapf(err, "failed to connect to %s", uri)
}

View File

@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"testing"
"time"
@@ -70,8 +71,17 @@ func TestOIDC_Getters(t *testing.T) {
func TestOIDC_Init(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
tlsSrv := generateTLSJWKServer(2)
defer tlsSrv.Close()
config := Config{
Claims: globalProvisionerClaims,
Claims: globalProvisionerClaims,
HTTPClient: tlsSrv.Client(),
}
badHTTPClientConfig := Config{
Claims: globalProvisionerClaims,
HTTPClient: http.DefaultClient,
}
badClaims := &Claims{
DefaultTLSDur: &Duration{0},
@@ -98,6 +108,7 @@ func TestOIDC_Init(t *testing.T) {
wantErr bool
}{
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ""}, args{config}, false},
{"ok tls", fields{"oidc", "name", "client-id", "client-secret", tlsSrv.URL, nil, nil, nil, ""}, args{config}, false},
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/.well-known/openid-configuration", nil, []string{"foo@smallstep.com"}, nil, ""}, args{config}, false},
{"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, []string{"smallstep.com"}, ""}, args{config}, false},
{"ok-listen-port", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, ":10000"}, args{config}, false},
@@ -112,6 +123,7 @@ func TestOIDC_Init(t *testing.T) {
{"bad-parse-url", fields{"oidc", "name", "client-id", "client-secret", ":", nil, nil, nil, ""}, args{config}, true},
{"bad-get-url", fields{"oidc", "name", "client-id", "client-secret", "https://", nil, nil, nil, ""}, args{config}, true},
{"bad-listen-address", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil, "127.0.0.1"}, args{config}, true},
{"bad-http-client", fields{"oidc", "name", "client-id", "client-secret", tlsSrv.URL, nil, nil, nil, ""}, args{badHTTPClientConfig}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -131,9 +143,13 @@ func TestOIDC_Init(t *testing.T) {
}
if tt.wantErr == false {
assert.Len(t, 2, p.keyStore.keySet.Keys)
u, err := url.Parse(tt.fields.ConfigurationEndpoint)
require.NoError(t, err)
assert.Equals(t, openIDConfiguration{
Issuer: "the-issuer",
JWKSetURI: srv.URL + "/jwks_uri",
JWKSetURI: u.ResolveReference(&url.URL{Path: "/jwks_uri"}).String(),
}, p.configuration)
}
})
@@ -145,7 +161,7 @@ func TestOIDC_authorizeToken(t *testing.T) {
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys))
issuer := "the-issuer"
tenantID := "ab800f7d-2c87-45fb-b1d0-f90d0bc5ec25"
@@ -263,7 +279,7 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
@@ -356,7 +372,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
@@ -467,7 +483,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
@@ -645,7 +661,7 @@ func TestOIDC_AuthorizeSSHRevoke(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
assert.FatalError(t, getAndDecode(srv.Client(), srv.URL+"/private", &keys))
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"

View File

@@ -261,6 +261,9 @@ type Config struct {
WebhookClient *http.Client
// SCEPKeyManager, if defined, is the interface used by SCEP provisioners.
SCEPKeyManager SCEPKeyManager
// HTTPClient is an HTTP client that trust the system cert pool and the CA
// roots.
HTTPClient *http.Client
}
type provisioner struct {

View File

@@ -1099,7 +1099,7 @@ func parseAWSToken(token string) (*jose.JSONWebToken, *awsPayload, error) {
return tok, claims, nil
}
func generateJWKServer(n int) *httptest.Server {
func generateJWKServerHandler(n int, srv *httptest.Server) http.Handler {
hits := struct {
Hits int `json:"hits"`
}{}
@@ -1122,8 +1122,7 @@ func generateJWKServer(n int) *httptest.Server {
}
defaultKeySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
srv := httptest.NewUnstartedServer(nil)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits.Hits++
switch r.RequestURI {
case "/error":
@@ -1149,11 +1148,22 @@ func generateJWKServer(n int) *httptest.Server {
writeJSON(w, getPublic(defaultKeySet))
}
})
}
func generateJWKServer(n int) *httptest.Server {
srv := httptest.NewUnstartedServer(nil)
srv.Config.Handler = generateJWKServerHandler(n, srv)
srv.Start()
return srv
}
func generateTLSJWKServer(n int) *httptest.Server {
srv := httptest.NewUnstartedServer(nil)
srv.Config.Handler = generateJWKServerHandler(n, srv)
srv.StartTLS()
return srv
}
func generateACME() (*ACME, error) {
// Initialize provisioners
p := &ACME{

View File

@@ -201,6 +201,7 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.
AuthorizeRenewFunc: a.authorizeRenewFunc,
AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
WebhookClient: a.webhookClient,
HTTPClient: a.httpClient,
SCEPKeyManager: a.scepKeyManager,
}, nil
}