mirror of
https://github.com/outbackdingo/kubernetes.git
synced 2026-02-05 00:26:04 +00:00
474 lines
13 KiB
Go
474 lines
13 KiB
Go
/*
|
|
Copyright 2024 The Kubernetes Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package plugin
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
"gopkg.in/go-jose/go-jose.v2/jwt"
|
|
|
|
"k8s.io/kubernetes/pkg/serviceaccount"
|
|
|
|
utilnettesting "k8s.io/apimachinery/pkg/util/net/testing"
|
|
externaljwtv1 "k8s.io/externaljwt/apis/v1"
|
|
)
|
|
|
|
// rsaKey1 and rsaKey2 are package-level 2048-bit RSA key pairs generated once
// in init. rsaKey1's public half is advertised by the fake signer's key set in
// TestExternalTokenGenerator; rsaKey2 is not referenced by that test and is
// presumably used by other tests in this package (verify).
var (
	rsaKey1 *rsa.PrivateKey
	rsaKey2 *rsa.PrivateKey
)

// init generates the shared RSA test keys. A key-generation failure leaves the
// package unusable for testing, so it panics; the underlying error is included
// in the panic message so the cause is visible in the test output.
func init() {
	var err error

	rsaKey1, err = rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(fmt.Sprintf("Error while generating first RSA key: %v", err))
	}

	rsaKey2, err = rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(fmt.Sprintf("Error while generating second RSA key: %v", err))
	}
}
|
|
|
|
func TestExternalTokenGenerator(t *testing.T) {
|
|
testCases := []struct {
|
|
desc string
|
|
|
|
publicClaims jwt.Claims
|
|
privateClaims privateClaimsT
|
|
|
|
iss string
|
|
backendSetKeyID string
|
|
backendSetAlgorithm string
|
|
backendHeaderType string
|
|
supportedKeys map[string]supportedKeyT
|
|
allowSigningWithNonOIDCKeys bool
|
|
|
|
wantClaims unifiedClaimsT
|
|
wantErr error
|
|
}{
|
|
{
|
|
desc: "correct token with correct claims returned",
|
|
publicClaims: jwt.Claims{
|
|
Subject: "some-subject",
|
|
Audience: jwt.Audience{
|
|
"some-audience-1",
|
|
"some-audience-2",
|
|
},
|
|
ID: "id-1",
|
|
},
|
|
privateClaims: privateClaimsT{
|
|
Kubernetes: kubernetesT{
|
|
Namespace: "foo",
|
|
Svcacct: refT{
|
|
Name: "default",
|
|
UID: "abcdef",
|
|
},
|
|
},
|
|
},
|
|
iss: "some-issuer",
|
|
backendSetKeyID: "key-id-1",
|
|
backendHeaderType: "JWT",
|
|
backendSetAlgorithm: "RS256",
|
|
supportedKeys: map[string]supportedKeyT{
|
|
"key-id-1": {
|
|
key: &rsaKey1.PublicKey,
|
|
},
|
|
},
|
|
|
|
wantClaims: unifiedClaimsT{
|
|
Issuer: "some-issuer",
|
|
Subject: "some-subject",
|
|
Audience: jwt.Audience{
|
|
"some-audience-1",
|
|
"some-audience-2",
|
|
},
|
|
ID: "id-1",
|
|
Kubernetes: kubernetesT{
|
|
Namespace: "foo",
|
|
Svcacct: refT{
|
|
Name: "default",
|
|
UID: "abcdef",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
desc: "correct token with correct claims signed by key that's excluded from OIDC",
|
|
publicClaims: jwt.Claims{
|
|
Subject: "some-subject",
|
|
Audience: jwt.Audience{
|
|
"some-audience-1",
|
|
"some-audience-2",
|
|
},
|
|
ID: "id-1",
|
|
},
|
|
privateClaims: privateClaimsT{
|
|
Kubernetes: kubernetesT{
|
|
Namespace: "foo",
|
|
Svcacct: refT{
|
|
Name: "default",
|
|
UID: "abcdef",
|
|
},
|
|
},
|
|
},
|
|
iss: "some-issuer",
|
|
backendSetKeyID: "key-id-1",
|
|
backendHeaderType: "JWT",
|
|
backendSetAlgorithm: "RS256",
|
|
supportedKeys: map[string]supportedKeyT{
|
|
"key-id-1": {
|
|
key: &rsaKey1.PublicKey,
|
|
excludeFromOidc: true,
|
|
},
|
|
},
|
|
|
|
wantErr: fmt.Errorf("while validating header: key used for signing JWT (kid: key-id-1) is excluded from OIDC discovery docs"),
|
|
},
|
|
{
|
|
desc: "token signed with key that's excluded from OIDC but validation is disabled",
|
|
publicClaims: jwt.Claims{
|
|
Subject: "some-subject",
|
|
Audience: jwt.Audience{
|
|
"some-audience-1",
|
|
"some-audience-2",
|
|
},
|
|
ID: "key-id-1",
|
|
},
|
|
privateClaims: privateClaimsT{
|
|
Kubernetes: kubernetesT{
|
|
Namespace: "foo",
|
|
Svcacct: refT{
|
|
Name: "default",
|
|
UID: "abcdef",
|
|
},
|
|
},
|
|
},
|
|
iss: "some-issuer",
|
|
backendSetKeyID: "key-id-1",
|
|
backendHeaderType: "JWT",
|
|
backendSetAlgorithm: "RS256",
|
|
supportedKeys: map[string]supportedKeyT{
|
|
"key-id-1": {
|
|
key: &rsaKey1.PublicKey,
|
|
excludeFromOidc: true,
|
|
},
|
|
},
|
|
allowSigningWithNonOIDCKeys: true,
|
|
|
|
wantClaims: unifiedClaimsT{
|
|
Issuer: "some-issuer",
|
|
Subject: "some-subject",
|
|
Audience: jwt.Audience{
|
|
"some-audience-1",
|
|
"some-audience-2",
|
|
},
|
|
ID: "key-id-1",
|
|
Kubernetes: kubernetesT{
|
|
Namespace: "foo",
|
|
Svcacct: refT{
|
|
Name: "default",
|
|
UID: "abcdef",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
desc: "empty key ID returned from signer",
|
|
iss: "some-issuer",
|
|
backendSetKeyID: "",
|
|
backendHeaderType: "JWT",
|
|
backendSetAlgorithm: "RS256",
|
|
supportedKeys: map[string]supportedKeyT{
|
|
"key-id-1": {
|
|
key: &rsaKey1.PublicKey,
|
|
excludeFromOidc: true,
|
|
},
|
|
},
|
|
wantErr: fmt.Errorf("while validating header: key id missing"),
|
|
},
|
|
{
|
|
desc: "key id longer than 1024 bytes returned from signer",
|
|
iss: "some-issuer",
|
|
backendSetKeyID: string(make([]byte, 1025)),
|
|
backendHeaderType: "JWT",
|
|
backendSetAlgorithm: "RS256",
|
|
supportedKeys: map[string]supportedKeyT{
|
|
"key-id-1": {
|
|
key: &rsaKey1.PublicKey,
|
|
excludeFromOidc: true,
|
|
},
|
|
},
|
|
wantErr: fmt.Errorf("while validating header: key id longer than 1 kb"),
|
|
},
|
|
{
|
|
desc: "unsupported alg returned from signer",
|
|
iss: "some-issuer",
|
|
backendSetKeyID: "key-id-1",
|
|
backendHeaderType: "JWT",
|
|
backendSetAlgorithm: "something-unsupported",
|
|
supportedKeys: map[string]supportedKeyT{
|
|
"key-id-1": {
|
|
key: &rsaKey1.PublicKey,
|
|
excludeFromOidc: true,
|
|
},
|
|
},
|
|
wantErr: fmt.Errorf("while validating header: bad signing algorithm \"something-unsupported\""),
|
|
},
|
|
{
|
|
desc: "empty alg returned from signer",
|
|
iss: "some-issuer",
|
|
backendSetKeyID: "key-id-1",
|
|
backendHeaderType: "JWT",
|
|
backendSetAlgorithm: "",
|
|
supportedKeys: map[string]supportedKeyT{
|
|
"key-id-1": {
|
|
key: &rsaKey1.PublicKey,
|
|
excludeFromOidc: true,
|
|
},
|
|
},
|
|
wantErr: fmt.Errorf("while validating header: bad signing algorithm \"\""),
|
|
},
|
|
{
|
|
desc: "Invalid backend header type",
|
|
iss: "some-issuer",
|
|
backendSetKeyID: "key-id-1",
|
|
backendHeaderType: "WHAT",
|
|
backendSetAlgorithm: "RS256",
|
|
supportedKeys: map[string]supportedKeyT{
|
|
"key-id-1": {
|
|
key: &rsaKey1.PublicKey,
|
|
excludeFromOidc: true,
|
|
},
|
|
},
|
|
wantErr: fmt.Errorf("while validating header: bad type"),
|
|
},
|
|
}
|
|
|
|
for i, tc := range testCases {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
sockname := utilnettesting.MakeSocketNameForTest(t, fmt.Sprintf("test-external-token-generator-%d-%d.sock", time.Now().Nanosecond(), i))
|
|
|
|
addr := &net.UnixAddr{Name: sockname, Net: "unix"}
|
|
listener, err := net.ListenUnix(addr.Network(), addr)
|
|
if err != nil {
|
|
t.Fatalf("Failed to start fake backend: %v", err)
|
|
}
|
|
|
|
grpcServer := grpc.NewServer()
|
|
|
|
backend := &dummyExtrnalSigner{
|
|
keyID: tc.backendSetKeyID,
|
|
signingAlgorithm: tc.backendSetAlgorithm,
|
|
signature: "abcdef",
|
|
supportedKeys: tc.supportedKeys,
|
|
refreshHintSeconds: 10,
|
|
DataTimeStamp: timestamppb.New(time.Time{}),
|
|
headerType: tc.backendHeaderType,
|
|
}
|
|
externaljwtv1.RegisterExternalJWTSignerServer(grpcServer, backend)
|
|
|
|
go func() {
|
|
if err := grpcServer.Serve(listener); err != nil {
|
|
panic(fmt.Errorf("error returned from grpcServer: %w", err))
|
|
}
|
|
}()
|
|
defer grpcServer.Stop()
|
|
|
|
clientConn, err := grpc.DialContext(
|
|
ctx,
|
|
sockname,
|
|
grpc.WithContextDialer(func(ctx context.Context, path string) (net.Conn, error) {
|
|
return (&net.Dialer{}).DialContext(ctx, "unix", path)
|
|
}),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Failed to dial buffconn client: %v", err)
|
|
}
|
|
defer func() { _ = clientConn.Close() }()
|
|
|
|
plugin := newPlugin(tc.iss, clientConn, tc.allowSigningWithNonOIDCKeys)
|
|
err = plugin.keyCache.initialFill(ctx)
|
|
if err != nil {
|
|
t.Fatalf("initial fill failed: %v", err)
|
|
}
|
|
|
|
gotToken, err := plugin.GenerateToken(ctx, &tc.publicClaims, tc.privateClaims)
|
|
if err != nil && tc.wantErr != nil {
|
|
if err.Error() != tc.wantErr.Error() {
|
|
t.Fatalf("want error: %v, got error: %v", tc.wantErr, err)
|
|
}
|
|
return
|
|
} else if err != nil && tc.wantErr == nil {
|
|
t.Fatalf("Unexpected error generating token: %v", err)
|
|
} else if err == nil && tc.wantErr != nil {
|
|
t.Fatalf("Wanted error %q, but got nil", tc.wantErr)
|
|
}
|
|
|
|
tokenPieces := strings.Split(gotToken, ".")
|
|
payloadBase64 := tokenPieces[1]
|
|
|
|
gotClaimBytes, err := base64.RawURLEncoding.DecodeString(payloadBase64)
|
|
if err != nil {
|
|
t.Fatalf("error converting received tokens to bytes: %v", err)
|
|
}
|
|
|
|
gotClaims := unifiedClaimsT{}
|
|
if err := json.Unmarshal(gotClaimBytes, &gotClaims); err != nil {
|
|
t.Fatalf("Error while unmarshaling claims from backend: %v", err)
|
|
}
|
|
|
|
if diff := cmp.Diff(gotClaims, tc.wantClaims); diff != "" {
|
|
t.Fatalf("Bad claims; diff (-got +want):\n%s", diff)
|
|
}
|
|
|
|
// Don't check header or signature values since we're not testing
|
|
// our (fake) backends.
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
func sortPublicKeySlice(a, b serviceaccount.PublicKey) bool {
|
|
return a.KeyID < b.KeyID
|
|
}
|
|
|
|
// headerT is the JWT (JOSE) header the fake signer returns from Sign. Sign
// marshals it to JSON and base64url-encodes it without padding.
// encoding/json emits keys in declaration order, so the field order here fixes
// the exact bytes of the encoded header.
type headerT struct {
	// Algorithm is the "alg" header value.
	Algorithm string `json:"alg"`
	// KeyID is the "kid" header value; omitted from the JSON when empty.
	KeyID string `json:"kid,omitempty"`
	// Type is the "typ" header value.
	Type string `json:"typ"`
}
|
|
|
|
// unifiedClaimsT is a flat view of a token payload. It holds the registered
// JWT claims (iss, sub, aud, exp, nbf, iat, jti) plus the "kubernetes.io"
// private claim. TestExternalTokenGenerator decodes generated token payloads
// into it and uses it for the expected claims.
type unifiedClaimsT struct {
	Issuer    string           `json:"iss,omitempty"`
	Subject   string           `json:"sub,omitempty"`
	Audience  jwt.Audience     `json:"aud,omitempty"`
	Expiry    *jwt.NumericDate `json:"exp,omitempty"`
	NotBefore *jwt.NumericDate `json:"nbf,omitempty"`
	IssuedAt  *jwt.NumericDate `json:"iat,omitempty"`
	ID        string           `json:"jti,omitempty"`
	// Kubernetes is a struct value, and encoding/json ignores omitempty for
	// struct-typed fields. On marshal this key is therefore always emitted;
	// the tag has no effect on unmarshal.
	Kubernetes kubernetesT `json:"kubernetes.io,omitempty"`
}
|
|
|
|
// privateClaimsT carries the non-registered claims passed to GenerateToken in
// tests. Only the "kubernetes.io" claim is modeled.
type privateClaimsT struct {
	// Struct-valued field: omitempty has no effect on it in encoding/json.
	Kubernetes kubernetesT `json:"kubernetes.io,omitempty"`
}
|
|
|
|
// kubernetesT models the contents of the "kubernetes.io" private claim as used
// by these tests. Pointer fields are omitted from JSON when nil.
type kubernetesT struct {
	Namespace string `json:"namespace,omitempty"`
	// Svcacct is a struct value, so omitempty does not drop it in encoding/json.
	Svcacct   refT             `json:"serviceaccount,omitempty"`
	Pod       *refT            `json:"pod,omitempty"`
	Secret    *refT            `json:"secret,omitempty"`
	Node      *refT            `json:"node,omitempty"`
	WarnAfter *jwt.NumericDate `json:"warnafter,omitempty"`
}
|
|
|
|
// refT is a name/UID reference to an object inside the "kubernetes.io" claim,
// e.g. the service account in these tests. Empty fields are omitted from JSON.
type refT struct {
	Name string `json:"name,omitempty"`
	UID  string `json:"uid,omitempty"`
}
|
|
|
|
// supportedKeyT is one entry in the fake signer's key set (keyed by key ID in
// dummyExtrnalSigner.supportedKeys).
type supportedKeyT struct {
	// key is PKIX/DER-marshaled and returned as the Key bytes in FetchKeys.
	key *rsa.PublicKey
	// excludeFromOidc is returned as ExcludeFromOidcDiscovery in FetchKeys.
	excludeFromOidc bool
}
|
|
|
|
// dummyExtrnalSigner is a fake ExternalJWTSigner gRPC server used as the
// plugin's backend in tests. Its Sign and FetchKeys responses are driven
// entirely by the fields below, not by the incoming requests.
//
// NOTE(review): "Extrnal" is a typo for "External". The name is kept because
// other tests in this package may reference it.
type dummyExtrnalSigner struct {
	// Embedded generated type. Presumably this supplies default handlers for
	// RPCs this fake does not implement, per the grpc-go generated-code
	// convention.
	externaljwtv1.UnimplementedExternalJWTSignerServer

	// required for Sign()
	// keyID, signingAlgorithm and headerType become the kid, alg and typ header
	// fields; signature is returned verbatim. Sign reads these without keyLock.
	keyID            string
	signingAlgorithm string
	signature        string
	headerType       string

	// required for FetchKeys()
	// keyLock is held by FetchKeys for its whole execution while it reads the
	// fields below. Exported fields are presumably mutated by other tests
	// under this lock (verify).
	keyLock sync.Mutex
	// supportedKeys maps key ID to the key advertised by FetchKeys.
	supportedKeys map[string]supportedKeyT
	// refreshHintSeconds is returned as FetchKeysResponse.RefreshHintSeconds.
	refreshHintSeconds int
	// DataTimeStamp is returned as FetchKeysResponse.DataTimestamp.
	DataTimeStamp *timestamppb.Timestamp
	// SupportedKeysOverride, when non-nil, is returned by FetchKeys as-is
	// instead of the keys built from supportedKeys.
	SupportedKeysOverride []*externaljwtv1.Key
}
|
|
|
|
func (des *dummyExtrnalSigner) Sign(ctx context.Context, r *externaljwtv1.SignJWTRequest) (*externaljwtv1.SignJWTResponse, error) {
|
|
header := &headerT{
|
|
Type: des.headerType,
|
|
Algorithm: des.signingAlgorithm,
|
|
KeyID: des.keyID,
|
|
}
|
|
|
|
headerJSON, err := json.Marshal(header)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create header for JWT response")
|
|
}
|
|
|
|
resp := &externaljwtv1.SignJWTResponse{
|
|
Header: base64.RawURLEncoding.EncodeToString(headerJSON),
|
|
Signature: des.signature,
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func (des *dummyExtrnalSigner) FetchKeys(ctx context.Context, r *externaljwtv1.FetchKeysRequest) (*externaljwtv1.FetchKeysResponse, error) {
|
|
des.keyLock.Lock()
|
|
defer des.keyLock.Unlock()
|
|
|
|
pbKeys := []*externaljwtv1.Key{}
|
|
if des.SupportedKeysOverride != nil {
|
|
pbKeys = des.SupportedKeysOverride
|
|
} else {
|
|
for kid, k := range des.supportedKeys {
|
|
keyBytes, err := x509.MarshalPKIXPublicKey(k.key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("while marshaling key: %w", err)
|
|
}
|
|
pbKey := &externaljwtv1.Key{
|
|
KeyId: kid,
|
|
Key: keyBytes,
|
|
ExcludeFromOidcDiscovery: k.excludeFromOidc,
|
|
}
|
|
pbKeys = append(pbKeys, pbKey)
|
|
}
|
|
}
|
|
|
|
return &externaljwtv1.FetchKeysResponse{
|
|
Keys: pbKeys,
|
|
DataTimestamp: des.DataTimeStamp,
|
|
RefreshHintSeconds: int64(des.refreshHintSeconds),
|
|
}, nil
|
|
}
|