mirror of
https://github.com/optim-enterprises-bv/kubernetes.git
synced 2025-11-02 11:18:16 +00:00
rework oidc client auth provider
* Cache OpenID Connect clients to prevent reinitialization * Don't retry requests in the http.RoundTripper. * Don't rely on the server not reading POST bodies. * Don't leak response body FDs. * Formerly ignored any throttling requests by the server. * Determine if the id token's expired by inspecting it. * Similar to logic in golang.org/x/oauth2 * Synchronize around refreshing tokens and persisting the new config.
This commit is contained in:
@@ -19,13 +19,10 @@ package oidc
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -33,11 +30,41 @@ import (
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
|
||||
"k8s.io/kubernetes/pkg/util/diff"
|
||||
"k8s.io/kubernetes/pkg/util/wait"
|
||||
oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing"
|
||||
)
|
||||
|
||||
func clearCache() {
|
||||
cache = newClientCache()
|
||||
}
|
||||
|
||||
type persister struct{}
|
||||
|
||||
// we don't need to actually persist anything because there's no way for us to
|
||||
// read from a persister.
|
||||
func (p *persister) Persist(map[string]string) error { return nil }
|
||||
|
||||
type noRefreshOIDCClient struct{}
|
||||
|
||||
func (c *noRefreshOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||
return oauth2.TokenResponse{}, errors.New("alwaysErrOIDCClient: cannot refresh token")
|
||||
}
|
||||
|
||||
func (c *noRefreshOIDCClient) verifyJWT(jwt *jose.JWT) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockOIDCClient struct {
|
||||
tokenResponse oauth2.TokenResponse
|
||||
}
|
||||
|
||||
func (c *mockOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||
return c.tokenResponse, nil
|
||||
}
|
||||
|
||||
func (c *mockOIDCClient) verifyJWT(jwt *jose.JWT) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewOIDCAuthProvider(t *testing.T) {
|
||||
tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test")
|
||||
if err != nil {
|
||||
@@ -60,127 +87,211 @@ func TestNewOIDCAuthProvider(t *testing.T) {
|
||||
t.Fatalf("Could not read cert bytes %v", err)
|
||||
}
|
||||
|
||||
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||
"test": "jwt",
|
||||
}), op.PrivKey.Signer())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create signed JWT %v", err)
|
||||
makeToken := func(exp time.Time) *jose.JWT {
|
||||
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||
"exp": exp.UTC().Unix(),
|
||||
}), op.PrivKey.Signer())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create signed JWT %v", err)
|
||||
}
|
||||
return jwt
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
cfg map[string]string
|
||||
t0 := time.Now()
|
||||
|
||||
wantErr bool
|
||||
wantInitialIDToken jose.JWT
|
||||
goodToken := makeToken(t0.Add(time.Hour)).Encode()
|
||||
expiredToken := makeToken(t0.Add(-time.Hour)).Encode()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
cfg map[string]string
|
||||
wantInitErr bool
|
||||
|
||||
client OIDCClient
|
||||
wantCfg map[string]string
|
||||
wantTokenErr bool
|
||||
}{
|
||||
{
|
||||
// A Valid configuration
|
||||
name: "no id token and no refresh token",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
wantTokenErr: true,
|
||||
},
|
||||
{
|
||||
// A Valid configuration with an Initial JWT
|
||||
name: "valid config with an initial token",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgIDToken: jwt.Encode(),
|
||||
cfgIDToken: goodToken,
|
||||
},
|
||||
client: new(noRefreshOIDCClient),
|
||||
wantCfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgIDToken: goodToken,
|
||||
},
|
||||
wantInitialIDToken: *jwt,
|
||||
},
|
||||
{
|
||||
// Valid config, but using cfgCertificateAuthorityData
|
||||
name: "invalid ID token with a refresh token",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgRefreshToken: "foo",
|
||||
cfgIDToken: expiredToken,
|
||||
},
|
||||
client: &mockOIDCClient{
|
||||
tokenResponse: oauth2.TokenResponse{
|
||||
IDToken: goodToken,
|
||||
},
|
||||
},
|
||||
wantCfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgRefreshToken: "foo",
|
||||
cfgIDToken: goodToken,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid ID token with a refresh token, server returns new refresh token",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgRefreshToken: "foo",
|
||||
cfgIDToken: expiredToken,
|
||||
},
|
||||
client: &mockOIDCClient{
|
||||
tokenResponse: oauth2.TokenResponse{
|
||||
IDToken: goodToken,
|
||||
RefreshToken: "bar",
|
||||
},
|
||||
},
|
||||
wantCfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgRefreshToken: "bar",
|
||||
cfgIDToken: goodToken,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired token and no refresh otken",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgIDToken: expiredToken,
|
||||
},
|
||||
wantTokenErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid base64d ca",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData),
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
client: new(noRefreshOIDCClient),
|
||||
wantTokenErr: true,
|
||||
},
|
||||
{
|
||||
// Missing client id
|
||||
name: "missing client ID",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
wantErr: true,
|
||||
wantInitErr: true,
|
||||
},
|
||||
{
|
||||
// Missing client secret
|
||||
name: "missing client secret",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
},
|
||||
wantErr: true,
|
||||
wantInitErr: true,
|
||||
},
|
||||
{
|
||||
// Missing issuer url.
|
||||
name: "missing issuer URL",
|
||||
cfg: map[string]string{
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "secret",
|
||||
},
|
||||
wantErr: true,
|
||||
wantInitErr: true,
|
||||
},
|
||||
{
|
||||
// No TLS config
|
||||
name: "missing TLS config",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "secret",
|
||||
},
|
||||
wantErr: true,
|
||||
wantInitErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
ap, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, nil)
|
||||
if tt.wantErr {
|
||||
for _, tt := range tests {
|
||||
clearCache()
|
||||
|
||||
p, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, new(persister))
|
||||
if tt.wantInitErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
t.Errorf("%s: want non-nil err", tt.name)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error on newOIDCAuthProvider: %v", i, err)
|
||||
t.Errorf("%s: unexpected error on newOIDCAuthProvider: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
oidcAP, ok := ap.(*oidcAuthProvider)
|
||||
if !ok {
|
||||
t.Errorf("case %d: expected ap to be an oidcAuthProvider", i)
|
||||
provider := p.(*oidcAuthProvider)
|
||||
provider.client = tt.client
|
||||
provider.now = func() time.Time { return t0 }
|
||||
|
||||
if _, err := provider.idToken(); err != nil {
|
||||
if !tt.wantTokenErr {
|
||||
t.Errorf("%s: failed to get id token: %v", tt.name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tt.wantTokenErr {
|
||||
t.Errorf("%s: expected to not get id token: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if diff := compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken); diff != "" {
|
||||
t.Errorf("case %d: compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken)=%v", i, diff)
|
||||
if !reflect.DeepEqual(tt.wantCfg, provider.cfg) {
|
||||
t.Errorf("%s: expected config %#v got %#v", tt.name, tt.wantCfg, provider.cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapTranport(t *testing.T) {
|
||||
oldBackoff := backoff
|
||||
defer func() {
|
||||
backoff = oldBackoff
|
||||
}()
|
||||
backoff = wait.Backoff{
|
||||
Duration: 1 * time.Nanosecond,
|
||||
Steps: 3,
|
||||
}
|
||||
|
||||
func TestVerifyJWTExpiry(t *testing.T) {
|
||||
privKey, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("can't generate private key: %v", err)
|
||||
}
|
||||
|
||||
makeToken := func(s string, exp time.Time, count int) *jose.JWT {
|
||||
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||
"test": s,
|
||||
@@ -193,451 +304,81 @@ func TestWrapTranport(t *testing.T) {
|
||||
return jwt
|
||||
}
|
||||
|
||||
goodToken := makeToken("good", time.Now().Add(time.Hour), 0)
|
||||
goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1)
|
||||
expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0)
|
||||
t0 := time.Now()
|
||||
|
||||
str := func(s string) *string {
|
||||
return &s
|
||||
}
|
||||
tests := []struct {
|
||||
cfgIDToken *jose.JWT
|
||||
cfgRefreshToken *string
|
||||
|
||||
expectRequests []testRoundTrip
|
||||
|
||||
expectRefreshes []testRefresh
|
||||
|
||||
expectPersists []testPersist
|
||||
|
||||
wantStatus int
|
||||
wantErr bool
|
||||
name string
|
||||
jwt *jose.JWT
|
||||
now time.Time
|
||||
wantErr bool
|
||||
wantExpired bool
|
||||
}{
|
||||
{
|
||||
// Initial JWT is set, it is good, it is set as bearer.
|
||||
cfgIDToken: goodToken,
|
||||
|
||||
expectRequests: []testRoundTrip{
|
||||
{
|
||||
expectBearerToken: goodToken.Encode(),
|
||||
returnHTTPStatus: 200,
|
||||
},
|
||||
},
|
||||
|
||||
wantStatus: 200,
|
||||
name: "valid jwt",
|
||||
jwt: makeToken("foo", t0.Add(time.Hour), 1),
|
||||
now: t0,
|
||||
},
|
||||
{
|
||||
// Initial JWT is set, but it's expired, so it gets refreshed.
|
||||
cfgIDToken: expiredToken,
|
||||
cfgRefreshToken: str("rt1"),
|
||||
|
||||
expectRefreshes: []testRefresh{
|
||||
{
|
||||
expectRefreshToken: "rt1",
|
||||
returnTokens: oauth2.TokenResponse{
|
||||
IDToken: goodToken.Encode(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
expectRequests: []testRoundTrip{
|
||||
{
|
||||
expectBearerToken: goodToken.Encode(),
|
||||
returnHTTPStatus: 200,
|
||||
},
|
||||
},
|
||||
|
||||
expectPersists: []testPersist{
|
||||
{
|
||||
cfg: map[string]string{
|
||||
cfgIDToken: goodToken.Encode(),
|
||||
cfgRefreshToken: "rt1",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
wantStatus: 200,
|
||||
},
|
||||
{
|
||||
// Initial JWT is set, but it's expired, so it gets refreshed - this
|
||||
// time the refresh token itself is also refreshed
|
||||
cfgIDToken: expiredToken,
|
||||
cfgRefreshToken: str("rt1"),
|
||||
|
||||
expectRefreshes: []testRefresh{
|
||||
{
|
||||
expectRefreshToken: "rt1",
|
||||
returnTokens: oauth2.TokenResponse{
|
||||
IDToken: goodToken.Encode(),
|
||||
RefreshToken: "rt2",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
expectRequests: []testRoundTrip{
|
||||
{
|
||||
expectBearerToken: goodToken.Encode(),
|
||||
returnHTTPStatus: 200,
|
||||
},
|
||||
},
|
||||
|
||||
expectPersists: []testPersist{
|
||||
{
|
||||
cfg: map[string]string{
|
||||
cfgIDToken: goodToken.Encode(),
|
||||
cfgRefreshToken: "rt2",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
wantStatus: 200,
|
||||
},
|
||||
{
|
||||
// Initial JWT is not set, so it gets refreshed.
|
||||
cfgRefreshToken: str("rt1"),
|
||||
|
||||
expectRefreshes: []testRefresh{
|
||||
{
|
||||
expectRefreshToken: "rt1",
|
||||
returnTokens: oauth2.TokenResponse{
|
||||
IDToken: goodToken.Encode(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
expectRequests: []testRoundTrip{
|
||||
{
|
||||
expectBearerToken: goodToken.Encode(),
|
||||
returnHTTPStatus: 200,
|
||||
},
|
||||
},
|
||||
|
||||
expectPersists: []testPersist{
|
||||
{
|
||||
cfg: map[string]string{
|
||||
cfgIDToken: goodToken.Encode(),
|
||||
cfgRefreshToken: "rt1",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
wantStatus: 200,
|
||||
},
|
||||
{
|
||||
// Expired token, but no refresh token.
|
||||
cfgIDToken: expiredToken,
|
||||
|
||||
name: "invalid jwt",
|
||||
jwt: &jose.JWT{},
|
||||
now: t0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
// Initial JWT is not set, so it gets refreshed, but the server
|
||||
// rejects it when it is used, so it refreshes again, which
|
||||
// succeeds.
|
||||
cfgRefreshToken: str("rt1"),
|
||||
|
||||
expectRefreshes: []testRefresh{
|
||||
{
|
||||
expectRefreshToken: "rt1",
|
||||
returnTokens: oauth2.TokenResponse{
|
||||
IDToken: goodToken.Encode(),
|
||||
},
|
||||
},
|
||||
{
|
||||
expectRefreshToken: "rt1",
|
||||
returnTokens: oauth2.TokenResponse{
|
||||
IDToken: goodToken2.Encode(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
expectRequests: []testRoundTrip{
|
||||
{
|
||||
expectBearerToken: goodToken.Encode(),
|
||||
returnHTTPStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
expectBearerToken: goodToken2.Encode(),
|
||||
returnHTTPStatus: http.StatusOK,
|
||||
},
|
||||
},
|
||||
|
||||
expectPersists: []testPersist{
|
||||
{
|
||||
cfg: map[string]string{
|
||||
cfgIDToken: goodToken.Encode(),
|
||||
cfgRefreshToken: "rt1",
|
||||
},
|
||||
},
|
||||
{
|
||||
cfg: map[string]string{
|
||||
cfgIDToken: goodToken2.Encode(),
|
||||
cfgRefreshToken: "rt1",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
wantStatus: 200,
|
||||
name: "expired jwt",
|
||||
jwt: makeToken("foo", t0.Add(-time.Hour), 1),
|
||||
now: t0,
|
||||
wantExpired: true,
|
||||
},
|
||||
{
|
||||
// Initial JWT is but the server rejects it when it is used, so it
|
||||
// refreshes again, which succeeds.
|
||||
cfgRefreshToken: str("rt1"),
|
||||
cfgIDToken: goodToken,
|
||||
|
||||
expectRefreshes: []testRefresh{
|
||||
{
|
||||
expectRefreshToken: "rt1",
|
||||
returnTokens: oauth2.TokenResponse{
|
||||
IDToken: goodToken2.Encode(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
expectRequests: []testRoundTrip{
|
||||
{
|
||||
expectBearerToken: goodToken.Encode(),
|
||||
returnHTTPStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
expectBearerToken: goodToken2.Encode(),
|
||||
returnHTTPStatus: http.StatusOK,
|
||||
},
|
||||
},
|
||||
|
||||
expectPersists: []testPersist{
|
||||
{
|
||||
cfg: map[string]string{
|
||||
cfgIDToken: goodToken2.Encode(),
|
||||
cfgRefreshToken: "rt1",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStatus: 200,
|
||||
name: "jwt expires soon enough to be marked expired",
|
||||
jwt: makeToken("foo", t0, 1),
|
||||
now: t0,
|
||||
wantExpired: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
client := &testOIDCClient{
|
||||
refreshes: tt.expectRefreshes,
|
||||
}
|
||||
|
||||
persister := &testPersister{
|
||||
tt.expectPersists,
|
||||
}
|
||||
|
||||
cfg := map[string]string{}
|
||||
if tt.cfgIDToken != nil {
|
||||
cfg[cfgIDToken] = tt.cfgIDToken.Encode()
|
||||
}
|
||||
|
||||
if tt.cfgRefreshToken != nil {
|
||||
cfg[cfgRefreshToken] = *tt.cfgRefreshToken
|
||||
}
|
||||
|
||||
ap := &oidcAuthProvider{
|
||||
refresher: &idTokenRefresher{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
persister: persister,
|
||||
},
|
||||
}
|
||||
|
||||
if tt.cfgIDToken != nil {
|
||||
ap.initialIDToken = *tt.cfgIDToken
|
||||
}
|
||||
|
||||
tstRT := &testRoundTripper{
|
||||
tt.expectRequests,
|
||||
}
|
||||
|
||||
rt := ap.WrapTransport(tstRT)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://cluster.example.com", nil)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error making request: %v", i, err)
|
||||
}
|
||||
|
||||
res, err := rt.RoundTrip(req)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: Expected non-nil error", i)
|
||||
for _, tc := range tests {
|
||||
func() {
|
||||
valid, err := verifyJWTExpiry(tc.now, tc.jwt.Encode())
|
||||
if err != nil {
|
||||
if !tc.wantErr {
|
||||
t.Errorf("%s: %v", tc.name, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("case %d: unexpected error making round trip: %v", i, err)
|
||||
|
||||
} else {
|
||||
if res.StatusCode != tt.wantStatus {
|
||||
t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode)
|
||||
if tc.wantErr {
|
||||
t.Errorf("%s: expected error", tc.name)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err = client.verify(); err != nil {
|
||||
t.Errorf("case %d: %v", i, err)
|
||||
}
|
||||
|
||||
if err = persister.verify(); err != nil {
|
||||
t.Errorf("case %d: %v", i, err)
|
||||
}
|
||||
|
||||
if err = tstRT.verify(); err != nil {
|
||||
t.Errorf("case %d: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if valid && tc.wantExpired {
|
||||
t.Errorf("%s: expected token to be expired", tc.name)
|
||||
}
|
||||
if !valid && !tc.wantExpired {
|
||||
t.Errorf("%s: expected token to be valid", tc.name)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
type testRoundTrip struct {
|
||||
expectBearerToken string
|
||||
returnHTTPStatus int
|
||||
}
|
||||
|
||||
type testRoundTripper struct {
|
||||
trips []testRoundTrip
|
||||
}
|
||||
|
||||
func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if len(t.trips) == 0 {
|
||||
return nil, errors.New("unexpected RoundTrip call")
|
||||
}
|
||||
|
||||
var trip testRoundTrip
|
||||
trip, t.trips = t.trips[0], t.trips[1:]
|
||||
|
||||
var bt string
|
||||
var parts []string
|
||||
auth := strings.TrimSpace(req.Header.Get("Authorization"))
|
||||
if auth == "" {
|
||||
goto Compare
|
||||
}
|
||||
|
||||
parts = strings.Split(auth, " ")
|
||||
if len(parts) < 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
goto Compare
|
||||
}
|
||||
|
||||
bt = parts[1]
|
||||
|
||||
Compare:
|
||||
if trip.expectBearerToken != bt {
|
||||
return nil, fmt.Errorf("want bearerToken=%v, got=%v", trip.expectBearerToken, bt)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: trip.returnHTTPStatus,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *testRoundTripper) verify() error {
|
||||
if l := len(t.trips); l > 0 {
|
||||
return fmt.Errorf("%d uncalled round trips", l)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type testPersist struct {
|
||||
cfg map[string]string
|
||||
returnErr error
|
||||
}
|
||||
|
||||
type testPersister struct {
|
||||
persists []testPersist
|
||||
}
|
||||
|
||||
func (t *testPersister) Persist(cfg map[string]string) error {
|
||||
if len(t.persists) == 0 {
|
||||
return errors.New("unexpected persist call")
|
||||
}
|
||||
|
||||
var persist testPersist
|
||||
persist, t.persists = t.persists[0], t.persists[1:]
|
||||
|
||||
if !reflect.DeepEqual(persist.cfg, cfg) {
|
||||
return fmt.Errorf("Unexpected cfg: %v", diff.ObjectDiff(persist.cfg, cfg))
|
||||
}
|
||||
|
||||
return persist.returnErr
|
||||
}
|
||||
|
||||
func (t *testPersister) verify() error {
|
||||
if l := len(t.persists); l > 0 {
|
||||
return fmt.Errorf("%d uncalled persists", l)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type testRefresh struct {
|
||||
expectRefreshToken string
|
||||
|
||||
returnErr error
|
||||
returnTokens oauth2.TokenResponse
|
||||
}
|
||||
|
||||
type testOIDCClient struct {
|
||||
refreshes []testRefresh
|
||||
}
|
||||
|
||||
func (o *testOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||
if len(o.refreshes) == 0 {
|
||||
return oauth2.TokenResponse{}, errors.New("unexpected refresh request")
|
||||
}
|
||||
|
||||
var refresh testRefresh
|
||||
refresh, o.refreshes = o.refreshes[0], o.refreshes[1:]
|
||||
|
||||
if rt != refresh.expectRefreshToken {
|
||||
return oauth2.TokenResponse{}, fmt.Errorf("want rt=%v, got=%v",
|
||||
refresh.expectRefreshToken,
|
||||
rt)
|
||||
}
|
||||
|
||||
if refresh.returnErr != nil {
|
||||
return oauth2.TokenResponse{}, refresh.returnErr
|
||||
}
|
||||
|
||||
return refresh.returnTokens, nil
|
||||
}
|
||||
|
||||
func (o *testOIDCClient) verifyJWT(jwt jose.JWT) error {
|
||||
claims, err := jwt.Claims()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
claim, _, _ := claims.StringClaim("test")
|
||||
if claim != "good" {
|
||||
return errors.New("bad token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testOIDCClient) verify() error {
|
||||
if l := len(t.refreshes); l > 0 {
|
||||
return fmt.Errorf("%d uncalled refreshes", l)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func compareJWTs(a, b jose.JWT) string {
|
||||
if a.Encode() == b.Encode() {
|
||||
return ""
|
||||
}
|
||||
|
||||
var aClaims, bClaims jose.Claims
|
||||
for _, j := range []struct {
|
||||
claims *jose.Claims
|
||||
jwt jose.JWT
|
||||
}{
|
||||
{&aClaims, a},
|
||||
{&bClaims, b},
|
||||
} {
|
||||
var err error
|
||||
*j.claims, err = j.jwt.Claims()
|
||||
if err != nil {
|
||||
*j.claims = jose.Claims(map[string]interface{}{
|
||||
"msg": "bad claims",
|
||||
"err": err,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return diff.ObjectDiff(aClaims, bClaims)
|
||||
func TestClientCache(t *testing.T) {
|
||||
cache := newClientCache()
|
||||
|
||||
if _, ok := cache.getClient("issuer1", "id1", "secret1"); ok {
|
||||
t.Fatalf("got client before putting one in the cache")
|
||||
}
|
||||
|
||||
cli1 := new(oidcAuthProvider)
|
||||
cli2 := new(oidcAuthProvider)
|
||||
|
||||
gotcli := cache.setClient("issuer1", "id1", "secret1", cli1)
|
||||
if cli1 != gotcli {
|
||||
t.Fatalf("set first client and got a different one")
|
||||
}
|
||||
|
||||
gotcli = cache.setClient("issuer1", "id1", "secret1", cli2)
|
||||
if cli1 != gotcli {
|
||||
t.Fatalf("set a second client and didn't get the first")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user