mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-11-03 19:58:17 +00:00 
			
		
		
		
	rework oidc client auth provider
* Cache OpenID Connect clients to prevent reinitialization * Don't retry requests in the http.RoundTripper. * Don't rely on the server not reading POST bodies. * Don't leak response body FDs. * Formerly ignored any throttling requests by the server. * Determine if the id token's expired by inspecting it. * Similar to logic in golang.org/x/oauth2 * Synchronize around refreshing tokens and persisting the new config.
This commit is contained in:
		@@ -22,6 +22,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/coreos/go-oidc/jose"
 | 
			
		||||
@@ -30,7 +31,6 @@ import (
 | 
			
		||||
	"github.com/golang/glog"
 | 
			
		||||
 | 
			
		||||
	"k8s.io/kubernetes/pkg/client/restclient"
 | 
			
		||||
	"k8s.io/kubernetes/pkg/util/wait"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -44,21 +44,68 @@ const (
 | 
			
		||||
	cfgRefreshToken             = "refresh-token"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	backoff = wait.Backoff{
 | 
			
		||||
		Duration: 1 * time.Second,
 | 
			
		||||
		Factor:   2,
 | 
			
		||||
		Jitter:   .1,
 | 
			
		||||
		Steps:    5,
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil {
 | 
			
		||||
		glog.Fatalf("Failed to register oidc auth plugin: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// expiryDelta determines how earlier a token should be considered
 | 
			
		||||
// expired than its actual expiration time. It is used to avoid late
 | 
			
		||||
// expirations due to client-server time mismatches.
 | 
			
		||||
//
 | 
			
		||||
// NOTE(ericchiang): this is take from golang.org/x/oauth2
 | 
			
		||||
const expiryDelta = 10 * time.Second
 | 
			
		||||
 | 
			
		||||
var cache = newClientCache()
 | 
			
		||||
 | 
			
		||||
// Like TLS transports, keep a cache of OIDC clients indexed by issuer URL.
 | 
			
		||||
type clientCache struct {
 | 
			
		||||
	mu    sync.RWMutex
 | 
			
		||||
	cache map[cacheKey]*oidcAuthProvider
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newClientCache() *clientCache {
 | 
			
		||||
	return &clientCache{cache: make(map[cacheKey]*oidcAuthProvider)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type cacheKey struct {
 | 
			
		||||
	// Canonical issuer URL string of the provider.
 | 
			
		||||
	issuerURL string
 | 
			
		||||
 | 
			
		||||
	clientID     string
 | 
			
		||||
	clientSecret string
 | 
			
		||||
 | 
			
		||||
	// Don't use CA as cache key because we only add a cache entry if we can connect
 | 
			
		||||
	// to the issuer in the first place. A valid CA is a prerequisite.
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *clientCache) getClient(issuer, clientID, clientSecret string) (*oidcAuthProvider, bool) {
 | 
			
		||||
	c.mu.RLock()
 | 
			
		||||
	defer c.mu.RUnlock()
 | 
			
		||||
	client, ok := c.cache[cacheKey{issuer, clientID, clientSecret}]
 | 
			
		||||
	return client, ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// setClient attempts to put the client in the cache but may return any clients
 | 
			
		||||
// with the same keys set before. This is so there's only ever one client for a provider.
 | 
			
		||||
func (c *clientCache) setClient(issuer, clientID, clientSecret string, client *oidcAuthProvider) *oidcAuthProvider {
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
	defer c.mu.Unlock()
 | 
			
		||||
	key := cacheKey{issuer, clientID, clientSecret}
 | 
			
		||||
 | 
			
		||||
	// If another client has already initialized a client for the given provider we want
 | 
			
		||||
	// to use that client instead of the one we're trying to set. This is so all transports
 | 
			
		||||
	// share a client and can coordinate around the same mutex when refreshing and writing
 | 
			
		||||
	// to the kubeconfig.
 | 
			
		||||
	if oldClient, ok := c.cache[key]; ok {
 | 
			
		||||
		return oldClient
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.cache[key] = client
 | 
			
		||||
	return client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
 | 
			
		||||
	issuer := cfg[cfgIssuerUrl]
 | 
			
		||||
	if issuer == "" {
 | 
			
		||||
@@ -75,6 +122,11 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
 | 
			
		||||
		return nil, fmt.Errorf("Must provide %s", cfgClientSecret)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check cache for existing provider.
 | 
			
		||||
	if provider, ok := cache.getClient(issuer, clientID, clientSecret); ok {
 | 
			
		||||
		return provider, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var certAuthData []byte
 | 
			
		||||
	var err error
 | 
			
		||||
	if cfg[cfgCertificateAuthorityData] != "" {
 | 
			
		||||
@@ -112,146 +164,134 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
 | 
			
		||||
		ProviderConfig: providerCfg,
 | 
			
		||||
		Scope:          append(scopes, oidc.DefaultScope...),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := oidc.NewClient(oidcCfg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("error creating OIDC Client: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	oClient := &oidcClient{client}
 | 
			
		||||
 | 
			
		||||
	var initialIDToken jose.JWT
 | 
			
		||||
	if cfg[cfgIDToken] != "" {
 | 
			
		||||
		initialIDToken, err = jose.ParseJWT(cfg[cfgIDToken])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	provider := &oidcAuthProvider{
 | 
			
		||||
		client:    &oidcClient{client},
 | 
			
		||||
		cfg:       cfg,
 | 
			
		||||
		persister: persister,
 | 
			
		||||
		now:       time.Now,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &oidcAuthProvider{
 | 
			
		||||
		initialIDToken: initialIDToken,
 | 
			
		||||
		refresher: &idTokenRefresher{
 | 
			
		||||
			client:    oClient,
 | 
			
		||||
			cfg:       cfg,
 | 
			
		||||
			persister: persister,
 | 
			
		||||
		},
 | 
			
		||||
	}, nil
 | 
			
		||||
	return cache.setClient(issuer, clientID, clientSecret, provider), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type oidcAuthProvider struct {
 | 
			
		||||
	refresher      *idTokenRefresher
 | 
			
		||||
	initialIDToken jose.JWT
 | 
			
		||||
	// Interface rather than a raw *oidc.Client for testing.
 | 
			
		||||
	client OIDCClient
 | 
			
		||||
 | 
			
		||||
	// Stubbed out for testing.
 | 
			
		||||
	now func() time.Time
 | 
			
		||||
 | 
			
		||||
	// Mutex guards persisting to the kubeconfig file and allows synchronized
 | 
			
		||||
	// updates to the in-memory config. It also ensures concurrent calls to
 | 
			
		||||
	// the RoundTripper only trigger a single refresh request.
 | 
			
		||||
	mu        sync.Mutex
 | 
			
		||||
	cfg       map[string]string
 | 
			
		||||
	persister restclient.AuthProviderConfigPersister
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
 | 
			
		||||
	at := &oidc.AuthenticatedTransport{
 | 
			
		||||
		TokenRefresher: g.refresher,
 | 
			
		||||
		RoundTripper:   rt,
 | 
			
		||||
	}
 | 
			
		||||
	at.SetJWT(g.initialIDToken)
 | 
			
		||||
func (p *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
 | 
			
		||||
	return &roundTripper{
 | 
			
		||||
		wrapped:   at,
 | 
			
		||||
		refresher: g.refresher,
 | 
			
		||||
		wrapped:  rt,
 | 
			
		||||
		provider: p,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (g *oidcAuthProvider) Login() error {
 | 
			
		||||
func (p *oidcAuthProvider) Login() error {
 | 
			
		||||
	return errors.New("not yet implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OIDCClient interface {
 | 
			
		||||
	refreshToken(rt string) (oauth2.TokenResponse, error)
 | 
			
		||||
	verifyJWT(jwt jose.JWT) error
 | 
			
		||||
	verifyJWT(jwt *jose.JWT) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type roundTripper struct {
 | 
			
		||||
	refresher *idTokenRefresher
 | 
			
		||||
	wrapped   *oidc.AuthenticatedTransport
 | 
			
		||||
	provider *oidcAuthProvider
 | 
			
		||||
	wrapped  http.RoundTripper
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
 | 
			
		||||
	var res *http.Response
 | 
			
		||||
	var err error
 | 
			
		||||
	firstTime := true
 | 
			
		||||
	wait.ExponentialBackoff(backoff, func() (bool, error) {
 | 
			
		||||
		if !firstTime {
 | 
			
		||||
			var jwt jose.JWT
 | 
			
		||||
			jwt, err = r.refresher.Refresh()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return true, nil
 | 
			
		||||
			}
 | 
			
		||||
			r.wrapped.SetJWT(jwt)
 | 
			
		||||
		} else {
 | 
			
		||||
			firstTime = false
 | 
			
		||||
		}
 | 
			
		||||
	token, err := r.provider.idToken()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
		res, err = r.wrapped.RoundTrip(req)
 | 
			
		||||
	// shallow copy of the struct
 | 
			
		||||
	r2 := new(http.Request)
 | 
			
		||||
	*r2 = *req
 | 
			
		||||
	// deep copy of the Header so we don't modify the original
 | 
			
		||||
	// request's Header (as per RoundTripper contract).
 | 
			
		||||
	r2.Header = make(http.Header)
 | 
			
		||||
	for k, s := range req.Header {
 | 
			
		||||
		r2.Header[k] = s
 | 
			
		||||
	}
 | 
			
		||||
	r2.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
 | 
			
		||||
 | 
			
		||||
	return r.wrapped.RoundTrip(r2)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *oidcAuthProvider) idToken() (string, error) {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if idToken, ok := p.cfg[cfgIDToken]; ok && len(idToken) > 0 {
 | 
			
		||||
		valid, err := verifyJWTExpiry(p.now(), idToken)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return true, nil
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
		if res.StatusCode == http.StatusUnauthorized {
 | 
			
		||||
			return false, nil
 | 
			
		||||
		if valid {
 | 
			
		||||
			// If the cached id token is still valid use it.
 | 
			
		||||
			return idToken, nil
 | 
			
		||||
		}
 | 
			
		||||
		return true, nil
 | 
			
		||||
	})
 | 
			
		||||
	return res, err
 | 
			
		||||
}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
type idTokenRefresher struct {
 | 
			
		||||
	cfg           map[string]string
 | 
			
		||||
	client        OIDCClient
 | 
			
		||||
	persister     restclient.AuthProviderConfigPersister
 | 
			
		||||
	intialIDToken jose.JWT
 | 
			
		||||
}
 | 
			
		||||
	// Try to request a new token using the refresh token.
 | 
			
		||||
	rt, ok := p.cfg[cfgRefreshToken]
 | 
			
		||||
	if !ok || len(rt) == 0 {
 | 
			
		||||
		return "", errors.New("No valid id-token, and cannot refresh without refresh-token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
func (r *idTokenRefresher) Verify(jwt jose.JWT) error {
 | 
			
		||||
	claims, err := jwt.Claims()
 | 
			
		||||
	tokens, err := p.client.refreshToken(rt)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	exp, ok, err := claims.TimeClaim("exp")
 | 
			
		||||
	switch {
 | 
			
		||||
	case err != nil:
 | 
			
		||||
		return fmt.Errorf("failed to parse 'exp' claim: %v", err)
 | 
			
		||||
	case !ok:
 | 
			
		||||
		return errors.New("missing required 'exp' claim")
 | 
			
		||||
	case exp.Before(now):
 | 
			
		||||
		return fmt.Errorf("token already expired at: %v", exp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *idTokenRefresher) Refresh() (jose.JWT, error) {
 | 
			
		||||
	rt, ok := r.cfg[cfgRefreshToken]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return jose.JWT{}, errors.New("No valid id-token, and cannot refresh without refresh-token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tokens, err := r.client.refreshToken(rt)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return jose.JWT{}, fmt.Errorf("could not refresh token: %v", err)
 | 
			
		||||
		return "", fmt.Errorf("could not refresh token: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	jwt, err := jose.ParseJWT(tokens.IDToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return jose.JWT{}, err
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := p.client.verifyJWT(&jwt); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create a new config to persist.
 | 
			
		||||
	newCfg := make(map[string]string)
 | 
			
		||||
	for key, val := range p.cfg {
 | 
			
		||||
		newCfg[key] = val
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if tokens.RefreshToken != "" && tokens.RefreshToken != rt {
 | 
			
		||||
		r.cfg[cfgRefreshToken] = tokens.RefreshToken
 | 
			
		||||
	}
 | 
			
		||||
	r.cfg[cfgIDToken] = jwt.Encode()
 | 
			
		||||
 | 
			
		||||
	err = r.persister.Persist(r.cfg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return jose.JWT{}, fmt.Errorf("could not perist new tokens: %v", err)
 | 
			
		||||
		newCfg[cfgRefreshToken] = tokens.RefreshToken
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return jwt, r.client.verifyJWT(jwt)
 | 
			
		||||
	newCfg[cfgIDToken] = tokens.IDToken
 | 
			
		||||
	if err = p.persister.Persist(newCfg); err != nil {
 | 
			
		||||
		return "", fmt.Errorf("could not perist new tokens: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Update the in memory config to reflect the on disk one.
 | 
			
		||||
	p.cfg = newCfg
 | 
			
		||||
 | 
			
		||||
	return tokens.IDToken, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// oidcClient is the real implementation of the OIDCClient interface, which is
 | 
			
		||||
// used for testing.
 | 
			
		||||
type oidcClient struct {
 | 
			
		||||
	client *oidc.Client
 | 
			
		||||
}
 | 
			
		||||
@@ -265,6 +305,29 @@ func (o *oidcClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
 | 
			
		||||
	return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *oidcClient) verifyJWT(jwt jose.JWT) error {
 | 
			
		||||
	return o.client.VerifyJWT(jwt)
 | 
			
		||||
func (o *oidcClient) verifyJWT(jwt *jose.JWT) error {
 | 
			
		||||
	return o.client.VerifyJWT(*jwt)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func verifyJWTExpiry(now time.Time, s string) (valid bool, err error) {
 | 
			
		||||
	jwt, err := jose.ParseJWT(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, fmt.Errorf("invalid %q", cfgIDToken)
 | 
			
		||||
	}
 | 
			
		||||
	claims, err := jwt.Claims()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	exp, ok, err := claims.TimeClaim("exp")
 | 
			
		||||
	switch {
 | 
			
		||||
	case err != nil:
 | 
			
		||||
		return false, fmt.Errorf("failed to parse 'exp' claim: %v", err)
 | 
			
		||||
	case !ok:
 | 
			
		||||
		return false, errors.New("missing required 'exp' claim")
 | 
			
		||||
	case exp.After(now.Add(expiryDelta)):
 | 
			
		||||
		return true, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -19,13 +19,10 @@ package oidc
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
@@ -33,11 +30,41 @@ import (
 | 
			
		||||
	"github.com/coreos/go-oidc/key"
 | 
			
		||||
	"github.com/coreos/go-oidc/oauth2"
 | 
			
		||||
 | 
			
		||||
	"k8s.io/kubernetes/pkg/util/diff"
 | 
			
		||||
	"k8s.io/kubernetes/pkg/util/wait"
 | 
			
		||||
	oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func clearCache() {
 | 
			
		||||
	cache = newClientCache()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type persister struct{}
 | 
			
		||||
 | 
			
		||||
// we don't need to actually persist anything because there's no way for us to
 | 
			
		||||
// read from a persister.
 | 
			
		||||
func (p *persister) Persist(map[string]string) error { return nil }
 | 
			
		||||
 | 
			
		||||
type noRefreshOIDCClient struct{}
 | 
			
		||||
 | 
			
		||||
func (c *noRefreshOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
 | 
			
		||||
	return oauth2.TokenResponse{}, errors.New("alwaysErrOIDCClient: cannot refresh token")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *noRefreshOIDCClient) verifyJWT(jwt *jose.JWT) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mockOIDCClient struct {
 | 
			
		||||
	tokenResponse oauth2.TokenResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *mockOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
 | 
			
		||||
	return c.tokenResponse, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *mockOIDCClient) verifyJWT(jwt *jose.JWT) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewOIDCAuthProvider(t *testing.T) {
 | 
			
		||||
	tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -60,127 +87,211 @@ func TestNewOIDCAuthProvider(t *testing.T) {
 | 
			
		||||
		t.Fatalf("Could not read cert bytes %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
 | 
			
		||||
		"test": "jwt",
 | 
			
		||||
	}), op.PrivKey.Signer())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("Could not create signed JWT %v", err)
 | 
			
		||||
	makeToken := func(exp time.Time) *jose.JWT {
 | 
			
		||||
		jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
 | 
			
		||||
			"exp": exp.UTC().Unix(),
 | 
			
		||||
		}), op.PrivKey.Signer())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("Could not create signed JWT %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		return jwt
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		cfg map[string]string
 | 
			
		||||
	t0 := time.Now()
 | 
			
		||||
 | 
			
		||||
		wantErr            bool
 | 
			
		||||
		wantInitialIDToken jose.JWT
 | 
			
		||||
	goodToken := makeToken(t0.Add(time.Hour)).Encode()
 | 
			
		||||
	expiredToken := makeToken(t0.Add(-time.Hour)).Encode()
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name string
 | 
			
		||||
 | 
			
		||||
		cfg         map[string]string
 | 
			
		||||
		wantInitErr bool
 | 
			
		||||
 | 
			
		||||
		client       OIDCClient
 | 
			
		||||
		wantCfg      map[string]string
 | 
			
		||||
		wantTokenErr bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			// A Valid configuration
 | 
			
		||||
			name: "no id token and no refresh token",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
			},
 | 
			
		||||
			wantTokenErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// A Valid configuration with an Initial JWT
 | 
			
		||||
			name: "valid config with an initial token",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
				cfgIDToken:              jwt.Encode(),
 | 
			
		||||
				cfgIDToken:              goodToken,
 | 
			
		||||
			},
 | 
			
		||||
			client: new(noRefreshOIDCClient),
 | 
			
		||||
			wantCfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
				cfgIDToken:              goodToken,
 | 
			
		||||
			},
 | 
			
		||||
			wantInitialIDToken: *jwt,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Valid config, but using cfgCertificateAuthorityData
 | 
			
		||||
			name: "invalid ID token with a refresh token",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
				cfgRefreshToken:         "foo",
 | 
			
		||||
				cfgIDToken:              expiredToken,
 | 
			
		||||
			},
 | 
			
		||||
			client: &mockOIDCClient{
 | 
			
		||||
				tokenResponse: oauth2.TokenResponse{
 | 
			
		||||
					IDToken: goodToken,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			wantCfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
				cfgRefreshToken:         "foo",
 | 
			
		||||
				cfgIDToken:              goodToken,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "invalid ID token with a refresh token, server returns new refresh token",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
				cfgRefreshToken:         "foo",
 | 
			
		||||
				cfgIDToken:              expiredToken,
 | 
			
		||||
			},
 | 
			
		||||
			client: &mockOIDCClient{
 | 
			
		||||
				tokenResponse: oauth2.TokenResponse{
 | 
			
		||||
					IDToken:      goodToken,
 | 
			
		||||
					RefreshToken: "bar",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			wantCfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
				cfgRefreshToken:         "bar",
 | 
			
		||||
				cfgIDToken:              goodToken,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "expired token and no refresh otken",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
				cfgIDToken:              expiredToken,
 | 
			
		||||
			},
 | 
			
		||||
			wantTokenErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid base64d ca",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:                srv.URL,
 | 
			
		||||
				cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData),
 | 
			
		||||
				cfgClientID:                 "client-id",
 | 
			
		||||
				cfgClientSecret:             "client-secret",
 | 
			
		||||
			},
 | 
			
		||||
			client:       new(noRefreshOIDCClient),
 | 
			
		||||
			wantTokenErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Missing client id
 | 
			
		||||
			name: "missing client ID",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientSecret:         "client-secret",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
			wantInitErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Missing client secret
 | 
			
		||||
			name: "missing client secret",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:            srv.URL,
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
			wantInitErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Missing issuer url.
 | 
			
		||||
			name: "missing issuer URL",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgCertificateAuthority: cert,
 | 
			
		||||
				cfgClientID:             "client-id",
 | 
			
		||||
				cfgClientSecret:         "secret",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
			wantInitErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// No TLS config
 | 
			
		||||
			name: "missing TLS config",
 | 
			
		||||
			cfg: map[string]string{
 | 
			
		||||
				cfgIssuerUrl:    srv.URL,
 | 
			
		||||
				cfgClientID:     "client-id",
 | 
			
		||||
				cfgClientSecret: "secret",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
			wantInitErr: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i, tt := range tests {
 | 
			
		||||
		ap, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, nil)
 | 
			
		||||
		if tt.wantErr {
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		clearCache()
 | 
			
		||||
 | 
			
		||||
		p, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, new(persister))
 | 
			
		||||
		if tt.wantInitErr {
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				t.Errorf("case %d: want non-nil err", i)
 | 
			
		||||
				t.Errorf("%s: want non-nil err", tt.name)
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("case %d: unexpected error on newOIDCAuthProvider: %v", i, err)
 | 
			
		||||
			t.Errorf("%s: unexpected error on newOIDCAuthProvider: %v", tt.name, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		oidcAP, ok := ap.(*oidcAuthProvider)
 | 
			
		||||
		if !ok {
 | 
			
		||||
			t.Errorf("case %d: expected ap to be an oidcAuthProvider", i)
 | 
			
		||||
		provider := p.(*oidcAuthProvider)
 | 
			
		||||
		provider.client = tt.client
 | 
			
		||||
		provider.now = func() time.Time { return t0 }
 | 
			
		||||
 | 
			
		||||
		if _, err := provider.idToken(); err != nil {
 | 
			
		||||
			if !tt.wantTokenErr {
 | 
			
		||||
				t.Errorf("%s: failed to get id token: %v", tt.name, err)
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if tt.wantTokenErr {
 | 
			
		||||
			t.Errorf("%s: expected to not get id token: %v", tt.name, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if diff := compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken); diff != "" {
 | 
			
		||||
			t.Errorf("case %d: compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken)=%v", i, diff)
 | 
			
		||||
		if !reflect.DeepEqual(tt.wantCfg, provider.cfg) {
 | 
			
		||||
			t.Errorf("%s: expected config %#v got %#v", tt.name, tt.wantCfg, provider.cfg)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestWrapTranport(t *testing.T) {
 | 
			
		||||
	oldBackoff := backoff
 | 
			
		||||
	defer func() {
 | 
			
		||||
		backoff = oldBackoff
 | 
			
		||||
	}()
 | 
			
		||||
	backoff = wait.Backoff{
 | 
			
		||||
		Duration: 1 * time.Nanosecond,
 | 
			
		||||
		Steps:    3,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
func TestVerifyJWTExpiry(t *testing.T) {
 | 
			
		||||
	privKey, err := key.GeneratePrivateKey()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("can't generate private key: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	makeToken := func(s string, exp time.Time, count int) *jose.JWT {
 | 
			
		||||
		jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
 | 
			
		||||
			"test":  s,
 | 
			
		||||
@@ -193,451 +304,81 @@ func TestWrapTranport(t *testing.T) {
 | 
			
		||||
		return jwt
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	goodToken := makeToken("good", time.Now().Add(time.Hour), 0)
 | 
			
		||||
	goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1)
 | 
			
		||||
	expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0)
 | 
			
		||||
	t0 := time.Now()
 | 
			
		||||
 | 
			
		||||
	str := func(s string) *string {
 | 
			
		||||
		return &s
 | 
			
		||||
	}
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		cfgIDToken      *jose.JWT
 | 
			
		||||
		cfgRefreshToken *string
 | 
			
		||||
 | 
			
		||||
		expectRequests []testRoundTrip
 | 
			
		||||
 | 
			
		||||
		expectRefreshes []testRefresh
 | 
			
		||||
 | 
			
		||||
		expectPersists []testPersist
 | 
			
		||||
 | 
			
		||||
		wantStatus int
 | 
			
		||||
		wantErr    bool
 | 
			
		||||
		name        string
 | 
			
		||||
		jwt         *jose.JWT
 | 
			
		||||
		now         time.Time
 | 
			
		||||
		wantErr     bool
 | 
			
		||||
		wantExpired bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			// Initial JWT is set, it is good, it is set as bearer.
 | 
			
		||||
			cfgIDToken: goodToken,
 | 
			
		||||
 | 
			
		||||
			expectRequests: []testRoundTrip{
 | 
			
		||||
				{
 | 
			
		||||
					expectBearerToken: goodToken.Encode(),
 | 
			
		||||
					returnHTTPStatus:  200,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			wantStatus: 200,
 | 
			
		||||
			name: "valid jwt",
 | 
			
		||||
			jwt:  makeToken("foo", t0.Add(time.Hour), 1),
 | 
			
		||||
			now:  t0,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Initial JWT is set, but it's expired, so it gets refreshed.
 | 
			
		||||
			cfgIDToken:      expiredToken,
 | 
			
		||||
			cfgRefreshToken: str("rt1"),
 | 
			
		||||
 | 
			
		||||
			expectRefreshes: []testRefresh{
 | 
			
		||||
				{
 | 
			
		||||
					expectRefreshToken: "rt1",
 | 
			
		||||
					returnTokens: oauth2.TokenResponse{
 | 
			
		||||
						IDToken: goodToken.Encode(),
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectRequests: []testRoundTrip{
 | 
			
		||||
				{
 | 
			
		||||
					expectBearerToken: goodToken.Encode(),
 | 
			
		||||
					returnHTTPStatus:  200,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectPersists: []testPersist{
 | 
			
		||||
				{
 | 
			
		||||
					cfg: map[string]string{
 | 
			
		||||
						cfgIDToken:      goodToken.Encode(),
 | 
			
		||||
						cfgRefreshToken: "rt1",
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			wantStatus: 200,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Initial JWT is set, but it's expired, so it gets refreshed - this
 | 
			
		||||
			// time the refresh token itself is also refreshed
 | 
			
		||||
			cfgIDToken:      expiredToken,
 | 
			
		||||
			cfgRefreshToken: str("rt1"),
 | 
			
		||||
 | 
			
		||||
			expectRefreshes: []testRefresh{
 | 
			
		||||
				{
 | 
			
		||||
					expectRefreshToken: "rt1",
 | 
			
		||||
					returnTokens: oauth2.TokenResponse{
 | 
			
		||||
						IDToken:      goodToken.Encode(),
 | 
			
		||||
						RefreshToken: "rt2",
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectRequests: []testRoundTrip{
 | 
			
		||||
				{
 | 
			
		||||
					expectBearerToken: goodToken.Encode(),
 | 
			
		||||
					returnHTTPStatus:  200,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectPersists: []testPersist{
 | 
			
		||||
				{
 | 
			
		||||
					cfg: map[string]string{
 | 
			
		||||
						cfgIDToken:      goodToken.Encode(),
 | 
			
		||||
						cfgRefreshToken: "rt2",
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			wantStatus: 200,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Initial JWT is not set, so it gets refreshed.
 | 
			
		||||
			cfgRefreshToken: str("rt1"),
 | 
			
		||||
 | 
			
		||||
			expectRefreshes: []testRefresh{
 | 
			
		||||
				{
 | 
			
		||||
					expectRefreshToken: "rt1",
 | 
			
		||||
					returnTokens: oauth2.TokenResponse{
 | 
			
		||||
						IDToken: goodToken.Encode(),
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectRequests: []testRoundTrip{
 | 
			
		||||
				{
 | 
			
		||||
					expectBearerToken: goodToken.Encode(),
 | 
			
		||||
					returnHTTPStatus:  200,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectPersists: []testPersist{
 | 
			
		||||
				{
 | 
			
		||||
					cfg: map[string]string{
 | 
			
		||||
						cfgIDToken:      goodToken.Encode(),
 | 
			
		||||
						cfgRefreshToken: "rt1",
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			wantStatus: 200,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Expired token, but no refresh token.
 | 
			
		||||
			cfgIDToken: expiredToken,
 | 
			
		||||
 | 
			
		||||
			name:    "invalid jwt",
 | 
			
		||||
			jwt:     &jose.JWT{},
 | 
			
		||||
			now:     t0,
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Initial JWT is not set, so it gets refreshed, but the server
 | 
			
		||||
			// rejects it when it is used, so it refreshes again, which
 | 
			
		||||
			// succeeds.
 | 
			
		||||
			cfgRefreshToken: str("rt1"),
 | 
			
		||||
 | 
			
		||||
			expectRefreshes: []testRefresh{
 | 
			
		||||
				{
 | 
			
		||||
					expectRefreshToken: "rt1",
 | 
			
		||||
					returnTokens: oauth2.TokenResponse{
 | 
			
		||||
						IDToken: goodToken.Encode(),
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
				{
 | 
			
		||||
					expectRefreshToken: "rt1",
 | 
			
		||||
					returnTokens: oauth2.TokenResponse{
 | 
			
		||||
						IDToken: goodToken2.Encode(),
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectRequests: []testRoundTrip{
 | 
			
		||||
				{
 | 
			
		||||
					expectBearerToken: goodToken.Encode(),
 | 
			
		||||
					returnHTTPStatus:  http.StatusUnauthorized,
 | 
			
		||||
				},
 | 
			
		||||
				{
 | 
			
		||||
					expectBearerToken: goodToken2.Encode(),
 | 
			
		||||
					returnHTTPStatus:  http.StatusOK,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectPersists: []testPersist{
 | 
			
		||||
				{
 | 
			
		||||
					cfg: map[string]string{
 | 
			
		||||
						cfgIDToken:      goodToken.Encode(),
 | 
			
		||||
						cfgRefreshToken: "rt1",
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
				{
 | 
			
		||||
					cfg: map[string]string{
 | 
			
		||||
						cfgIDToken:      goodToken2.Encode(),
 | 
			
		||||
						cfgRefreshToken: "rt1",
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			wantStatus: 200,
 | 
			
		||||
			name:        "expired jwt",
 | 
			
		||||
			jwt:         makeToken("foo", t0.Add(-time.Hour), 1),
 | 
			
		||||
			now:         t0,
 | 
			
		||||
			wantExpired: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			// Initial JWT is but the server rejects it when it is used, so it
 | 
			
		||||
			// refreshes again, which succeeds.
 | 
			
		||||
			cfgRefreshToken: str("rt1"),
 | 
			
		||||
			cfgIDToken:      goodToken,
 | 
			
		||||
 | 
			
		||||
			expectRefreshes: []testRefresh{
 | 
			
		||||
				{
 | 
			
		||||
					expectRefreshToken: "rt1",
 | 
			
		||||
					returnTokens: oauth2.TokenResponse{
 | 
			
		||||
						IDToken: goodToken2.Encode(),
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectRequests: []testRoundTrip{
 | 
			
		||||
				{
 | 
			
		||||
					expectBearerToken: goodToken.Encode(),
 | 
			
		||||
					returnHTTPStatus:  http.StatusUnauthorized,
 | 
			
		||||
				},
 | 
			
		||||
				{
 | 
			
		||||
					expectBearerToken: goodToken2.Encode(),
 | 
			
		||||
					returnHTTPStatus:  http.StatusOK,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
 | 
			
		||||
			expectPersists: []testPersist{
 | 
			
		||||
				{
 | 
			
		||||
					cfg: map[string]string{
 | 
			
		||||
						cfgIDToken:      goodToken2.Encode(),
 | 
			
		||||
						cfgRefreshToken: "rt1",
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			wantStatus: 200,
 | 
			
		||||
			name:        "jwt expires soon enough to be marked expired",
 | 
			
		||||
			jwt:         makeToken("foo", t0, 1),
 | 
			
		||||
			now:         t0,
 | 
			
		||||
			wantExpired: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i, tt := range tests {
 | 
			
		||||
		client := &testOIDCClient{
 | 
			
		||||
			refreshes: tt.expectRefreshes,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		persister := &testPersister{
 | 
			
		||||
			tt.expectPersists,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		cfg := map[string]string{}
 | 
			
		||||
		if tt.cfgIDToken != nil {
 | 
			
		||||
			cfg[cfgIDToken] = tt.cfgIDToken.Encode()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if tt.cfgRefreshToken != nil {
 | 
			
		||||
			cfg[cfgRefreshToken] = *tt.cfgRefreshToken
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ap := &oidcAuthProvider{
 | 
			
		||||
			refresher: &idTokenRefresher{
 | 
			
		||||
				client:    client,
 | 
			
		||||
				cfg:       cfg,
 | 
			
		||||
				persister: persister,
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if tt.cfgIDToken != nil {
 | 
			
		||||
			ap.initialIDToken = *tt.cfgIDToken
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		tstRT := &testRoundTripper{
 | 
			
		||||
			tt.expectRequests,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		rt := ap.WrapTransport(tstRT)
 | 
			
		||||
 | 
			
		||||
		req, err := http.NewRequest("GET", "http://cluster.example.com", nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("case %d: unexpected error making request: %v", i, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		res, err := rt.RoundTrip(req)
 | 
			
		||||
		if tt.wantErr {
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				t.Errorf("case %d: Expected non-nil error", i)
 | 
			
		||||
	for _, tc := range tests {
 | 
			
		||||
		func() {
 | 
			
		||||
			valid, err := verifyJWTExpiry(tc.now, tc.jwt.Encode())
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if !tc.wantErr {
 | 
			
		||||
					t.Errorf("%s: %v", tc.name, err)
 | 
			
		||||
				}
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else if err != nil {
 | 
			
		||||
			t.Errorf("case %d: unexpected error making round trip: %v", i, err)
 | 
			
		||||
 | 
			
		||||
		} else {
 | 
			
		||||
			if res.StatusCode != tt.wantStatus {
 | 
			
		||||
				t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode)
 | 
			
		||||
			if tc.wantErr {
 | 
			
		||||
				t.Errorf("%s: expected error", tc.name)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err = client.verify(); err != nil {
 | 
			
		||||
			t.Errorf("case %d: %v", i, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err = persister.verify(); err != nil {
 | 
			
		||||
			t.Errorf("case %d: %v", i, err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err = tstRT.verify(); err != nil {
 | 
			
		||||
			t.Errorf("case %d: %v", i, err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
			if valid && tc.wantExpired {
 | 
			
		||||
				t.Errorf("%s: expected token to be expired", tc.name)
 | 
			
		||||
			}
 | 
			
		||||
			if !valid && !tc.wantExpired {
 | 
			
		||||
				t.Errorf("%s: expected token to be valid", tc.name)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testRoundTrip struct {
 | 
			
		||||
	expectBearerToken string
 | 
			
		||||
	returnHTTPStatus  int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testRoundTripper struct {
 | 
			
		||||
	trips []testRoundTrip
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
 | 
			
		||||
	if len(t.trips) == 0 {
 | 
			
		||||
		return nil, errors.New("unexpected RoundTrip call")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var trip testRoundTrip
 | 
			
		||||
	trip, t.trips = t.trips[0], t.trips[1:]
 | 
			
		||||
 | 
			
		||||
	var bt string
 | 
			
		||||
	var parts []string
 | 
			
		||||
	auth := strings.TrimSpace(req.Header.Get("Authorization"))
 | 
			
		||||
	if auth == "" {
 | 
			
		||||
		goto Compare
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	parts = strings.Split(auth, " ")
 | 
			
		||||
	if len(parts) < 2 || strings.ToLower(parts[0]) != "bearer" {
 | 
			
		||||
		goto Compare
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	bt = parts[1]
 | 
			
		||||
 | 
			
		||||
Compare:
 | 
			
		||||
	if trip.expectBearerToken != bt {
 | 
			
		||||
		return nil, fmt.Errorf("want bearerToken=%v, got=%v", trip.expectBearerToken, bt)
 | 
			
		||||
	}
 | 
			
		||||
	return &http.Response{
 | 
			
		||||
		StatusCode: trip.returnHTTPStatus,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *testRoundTripper) verify() error {
 | 
			
		||||
	if l := len(t.trips); l > 0 {
 | 
			
		||||
		return fmt.Errorf("%d uncalled round trips", l)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testPersist struct {
 | 
			
		||||
	cfg       map[string]string
 | 
			
		||||
	returnErr error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testPersister struct {
 | 
			
		||||
	persists []testPersist
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *testPersister) Persist(cfg map[string]string) error {
 | 
			
		||||
	if len(t.persists) == 0 {
 | 
			
		||||
		return errors.New("unexpected persist call")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var persist testPersist
 | 
			
		||||
	persist, t.persists = t.persists[0], t.persists[1:]
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(persist.cfg, cfg) {
 | 
			
		||||
		return fmt.Errorf("Unexpected cfg: %v", diff.ObjectDiff(persist.cfg, cfg))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return persist.returnErr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *testPersister) verify() error {
 | 
			
		||||
	if l := len(t.persists); l > 0 {
 | 
			
		||||
		return fmt.Errorf("%d uncalled persists", l)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testRefresh struct {
 | 
			
		||||
	expectRefreshToken string
 | 
			
		||||
 | 
			
		||||
	returnErr    error
 | 
			
		||||
	returnTokens oauth2.TokenResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testOIDCClient struct {
 | 
			
		||||
	refreshes []testRefresh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *testOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
 | 
			
		||||
	if len(o.refreshes) == 0 {
 | 
			
		||||
		return oauth2.TokenResponse{}, errors.New("unexpected refresh request")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var refresh testRefresh
 | 
			
		||||
	refresh, o.refreshes = o.refreshes[0], o.refreshes[1:]
 | 
			
		||||
 | 
			
		||||
	if rt != refresh.expectRefreshToken {
 | 
			
		||||
		return oauth2.TokenResponse{}, fmt.Errorf("want rt=%v, got=%v",
 | 
			
		||||
			refresh.expectRefreshToken,
 | 
			
		||||
			rt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if refresh.returnErr != nil {
 | 
			
		||||
		return oauth2.TokenResponse{}, refresh.returnErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return refresh.returnTokens, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *testOIDCClient) verifyJWT(jwt jose.JWT) error {
 | 
			
		||||
	claims, err := jwt.Claims()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	claim, _, _ := claims.StringClaim("test")
 | 
			
		||||
	if claim != "good" {
 | 
			
		||||
		return errors.New("bad token")
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *testOIDCClient) verify() error {
 | 
			
		||||
	if l := len(t.refreshes); l > 0 {
 | 
			
		||||
		return fmt.Errorf("%d uncalled refreshes", l)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func compareJWTs(a, b jose.JWT) string {
 | 
			
		||||
	if a.Encode() == b.Encode() {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var aClaims, bClaims jose.Claims
 | 
			
		||||
	for _, j := range []struct {
 | 
			
		||||
		claims *jose.Claims
 | 
			
		||||
		jwt    jose.JWT
 | 
			
		||||
	}{
 | 
			
		||||
		{&aClaims, a},
 | 
			
		||||
		{&bClaims, b},
 | 
			
		||||
	} {
 | 
			
		||||
		var err error
 | 
			
		||||
		*j.claims, err = j.jwt.Claims()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			*j.claims = jose.Claims(map[string]interface{}{
 | 
			
		||||
				"msg": "bad claims",
 | 
			
		||||
				"err": err,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return diff.ObjectDiff(aClaims, bClaims)
 | 
			
		||||
func TestClientCache(t *testing.T) {
 | 
			
		||||
	cache := newClientCache()
 | 
			
		||||
 | 
			
		||||
	if _, ok := cache.getClient("issuer1", "id1", "secret1"); ok {
 | 
			
		||||
		t.Fatalf("got client before putting one in the cache")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cli1 := new(oidcAuthProvider)
 | 
			
		||||
	cli2 := new(oidcAuthProvider)
 | 
			
		||||
 | 
			
		||||
	gotcli := cache.setClient("issuer1", "id1", "secret1", cli1)
 | 
			
		||||
	if cli1 != gotcli {
 | 
			
		||||
		t.Fatalf("set first client and got a different one")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	gotcli = cache.setClient("issuer1", "id1", "secret1", cli2)
 | 
			
		||||
	if cli1 != gotcli {
 | 
			
		||||
		t.Fatalf("set a second client and didn't get the first")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user