Improve validation in authorization path

This commit is contained in:
Mariano Cano
2025-12-02 16:54:44 -08:00
parent 48ed3a5d17
commit 1011f5f540
31 changed files with 464 additions and 108 deletions

View File

@@ -792,6 +792,9 @@ type attestationObject struct {
// TODO(bweeks): move attestation verification to a shared package.
func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error {
// Update challenge with the payload
ch.Payload = payload
// Load authorization to store the key fingerprint.
az, err := db.GetAuthorization(ctx, ch.AuthorizationID)
if err != nil {
@@ -946,7 +949,6 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
ch.Status = StatusValid
ch.Error = nil
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
ch.Payload = payload
ch.PayloadFormat = format
// Store the fingerprint in the authorization.

View File

@@ -878,7 +878,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error")
@@ -4077,7 +4077,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, errorPayload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error")
@@ -4117,7 +4117,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, errorBase64Payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "failed base64 decoding attObj %q", "?!")
@@ -4157,7 +4157,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, emptyPayload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty")
@@ -4197,7 +4197,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, emptyObjectPayload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty")
@@ -4237,7 +4237,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, errorNonWellformedCBORPayload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "attObj is not well formed CBOR: unexpected EOF")
@@ -4279,7 +4279,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, errorUnsupportedFormat, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "unsupported attestation object format %q", "unsupported-format")
@@ -4326,7 +4326,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewError(ErrorBadAttestationStatementType, "attestation format %q is not enabled", "step")
@@ -4383,7 +4383,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "x5c not present")
@@ -4432,7 +4432,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "serial-number", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "challenge token does not match")
@@ -4480,7 +4480,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "non-matching-value", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
subproblem := NewSubproblemWithIdentifier(
@@ -4560,7 +4560,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "x5c not present")
@@ -4616,7 +4616,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, "permanent identifier does not match").
@@ -4713,7 +4713,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
assert.Equal(t, StatusInvalid, updch.Status)
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
assert.Equal(t, "12345678", updch.Value)
assert.Nil(t, updch.Payload)
assert.Equal(t, payload, updch.Payload)
assert.Empty(t, updch.PayloadFormat)
err := NewDetailedError(ErrorBadAttestationStatementType, `unsupported attestation object format "bogus-format"`)

View File

@@ -7,6 +7,7 @@ import (
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
@@ -60,7 +61,7 @@ func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
}
if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
ctx := r.Context()
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RenewMethod)
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
return peer, parts[1], err
}

View File

@@ -94,8 +94,9 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
return
}
opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial {
render.Error(w, r, errs.BadRequest("serial number in client certificate different than body"))
if serialNumber := opts.Crt.SerialNumber.String(); opts.Serial != serialNumber {
render.Error(w, r, errs.Forbidden(
"request serial number %q and certificate serial number %q do not match", opts.Serial, serialNumber))
return
}
// TODO: should probably be checking if the certificate was revoked here.

View File

@@ -70,14 +70,13 @@ func SSHRevoke(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
a := mustAuthority(ctx)
// A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT)
if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, r, errs.UnauthorizedErr(err))
return
}
opts.OTT = body.OTT
if err := a.Revoke(ctx, opts); err != nil {

View File

@@ -80,6 +80,10 @@ func testAuthority(t *testing.T, opts ...Option) *Authority {
EnableSSHCA: &enableSSHCA,
},
},
&provisioner.ACME{
Name: "acme",
Type: "ACME",
},
&provisioner.JWK{
Name: "uninitialized",
Type: "JWK",

View File

@@ -93,7 +93,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// Store the token to protect against reuse unless it's skipped.
// If we cannot get a token id from the provisioner, just hash the token.
if !SkipTokenReuseFromContext(ctx) {
if err := a.UseToken(token, p); err != nil {
if err := a.UseToken(ctx, token, p); err != nil {
return nil, err
}
}
@@ -138,7 +138,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
}
// Check that the token has not been used.
if err := a.UseToken(token, prov); err != nil {
if err := a.UseToken(r.Context(), token, prov); err != nil {
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
}
@@ -193,22 +193,35 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
// UseToken stores the token to protect against reuse.
//
// This method currently ignores any error coming from the GetTokenID, but it
// should specifically ignore the error provisioner.ErrAllowTokenReuse.
func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
if reuseKey, err := prov.GetTokenID(token); err == nil {
if reuseKey == "" {
sum := sha256.Sum256([]byte(token))
reuseKey = strings.ToLower(hex.EncodeToString(sum[:]))
}
ok, err := a.db.UseToken(reuseKey, token)
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token")
}
if !ok {
return errs.Unauthorized("token already used")
// This method currently ignores most errors coming from the GetTokenID because
// the token is already validated. But it should specifically ignore the errors
// provisioner.ErrAllowTokenReuse, provisioner.ErrNotImplemented, and
// provisioner.ErrTokenFlowNotSupported unless this latter one used in a renewal
// flow without mTLS.
func (a *Authority) UseToken(ctx context.Context, token string, prov provisioner.Interface) error {
reuseKey, err := prov.GetTokenID(token)
if err != nil {
// Fail on ErrTokenFlowNotSupported but allow x5cInsecure renew token
if errors.Is(err, provisioner.ErrTokenFlowNotSupported) && provisioner.RenewMethod != provisioner.MethodFromContext(ctx) {
return errs.BadRequest("token flow is not supported")
}
return nil
}
if reuseKey == "" {
sum := sha256.Sum256([]byte(token))
reuseKey = strings.ToLower(hex.EncodeToString(sum[:]))
}
ok, err := a.db.UseToken(reuseKey, token)
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token")
}
if !ok {
return errs.Unauthorized("token already used")
}
return nil
}
@@ -398,7 +411,7 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
// AuthorizeRenewToken validates the renew token and returns the leaf
// certificate in the x5cInsecure header.
func (a *Authority) AuthorizeRenewToken(_ context.Context, ott string) (*x509.Certificate, error) {
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
var claims jose.Claims
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
if err != nil {
@@ -413,7 +426,7 @@ func (a *Authority) AuthorizeRenewToken(_ context.Context, ott string) (*x509.Ce
if err != nil {
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
}
if err := a.UseToken(ott, p); err != nil {
if err := a.UseToken(ctx, ott, p); err != nil {
return nil, err
}

View File

@@ -193,6 +193,25 @@ func TestAuthority_authorizeToken(t *testing.T) {
code: http.StatusUnauthorized,
}
},
"fail/token-flow-not-supported": func(t *testing.T) *authorizeTest {
cl := jose.Claims{
Subject: "test.smallstep.com",
Issuer: validIssuer,
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
IssuedAt: jose.NewNumericDate(now),
Audience: []string{"acme/acme"},
ID: "45",
}
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
token: raw,
err: errors.New("token flow is not supported"),
code: http.StatusBadRequest,
}
},
"ok/simpledb": func(t *testing.T) *authorizeTest {
cl := jose.Claims{
Subject: "test.smallstep.com",

View File

@@ -138,9 +138,10 @@ func (p *ACME) GetIDForToken() string {
return "acme/" + p.Name
}
// GetTokenID returns the identifier of the token.
// GetTokenID returns the identifier of the token. This provisioner will always
// return [ErrTokenFlowNotSupported].
func (p *ACME) GetTokenID(string) (string, error) {
return "", errors.New("acme provisioner does not implement GetTokenID")
return "", ErrTokenFlowNotSupported
}
// GetName returns the name of the provisioner.

View File

@@ -82,6 +82,9 @@ func TestACME_Getters(t *testing.T) {
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
tokenID, err := p.GetTokenID("token")
assert.Empty(t, tokenID)
assert.Equal(t, ErrTokenFlowNotSupported, err)
}
func TestACME_Init(t *testing.T) {

View File

@@ -77,7 +77,7 @@ func (c *Collection) LoadByName(name string) (Interface, bool) {
}
// LoadByTokenID a provisioner by identifier found in token.
// For different provisioner types this identifier may be found in in different
// For different provisioner types this identifier may be found in different
// attributes of the token.
func (c *Collection) LoadByTokenID(tokenProvisionerID string) (Interface, bool) {
return loadProvisioner(c.byTokenID, tokenProvisionerID)

View File

@@ -8,13 +8,15 @@ import (
"sync"
"testing"
"github.com/smallstep/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
)
func TestCollection_Load(t *testing.T) {
p, err := generateJWK()
assert.FatalError(t, err)
require.NoError(t, err)
byID := new(sync.Map)
byID.Store(p.GetID(), p)
byID.Store("string", "a-string")
@@ -52,15 +54,60 @@ func TestCollection_Load(t *testing.T) {
}
}
func TestCollection_LoadByTokenID(t *testing.T) {
p1, err := generateJWK()
require.NoError(t, err)
p2, err := generateACME()
require.NoError(t, err)
byTokenID := new(sync.Map)
byTokenID.Store(p1.GetIDForToken(), p1)
byTokenID.Store(p2.GetIDForToken(), p2)
byTokenID.Store("string", "a-string")
type fields struct {
byTokenID *sync.Map
}
type args struct {
id string
}
tests := []struct {
name string
fields fields
args args
want Interface
want1 bool
}{
{"ok jwk", fields{byTokenID}, args{p1.GetIDForToken()}, p1, true},
{"ok acme", fields{byTokenID}, args{p2.GetIDForToken()}, p2, true},
{"fail missing", fields{byTokenID}, args{"missing"}, nil, false},
{"invalid", fields{byTokenID}, args{"string"}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Collection{
byTokenID: tt.fields.byTokenID,
}
got, got1 := c.LoadByTokenID(tt.args.id)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Collection.Load() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestCollection_LoadByToken(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
require.NoError(t, err)
p2, err := generateJWK()
assert.FatalError(t, err)
require.NoError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
require.NoError(t, err)
p4, err := generateK8sSA(nil)
assert.FatalError(t, err)
require.NoError(t, err)
byID := new(sync.Map)
byID.Store(p1.GetID(), p1)
@@ -75,35 +122,35 @@ func TestCollection_LoadByToken(t *testing.T) {
byID2.Store(p3.GetID(), p3)
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
assert.FatalError(t, err)
require.NoError(t, err)
token, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], jwk)
assert.FatalError(t, err)
require.NoError(t, err)
t1, c1, err := parseToken(token)
assert.FatalError(t, err)
require.NoError(t, err)
jwk, err = decryptJSONWebKey(p2.EncryptedKey)
assert.FatalError(t, err)
require.NoError(t, err)
token, err = generateSimpleToken(p2.Name, testAudiences.Sign[1], jwk)
assert.FatalError(t, err)
require.NoError(t, err)
t2, c2, err := parseToken(token)
assert.FatalError(t, err)
require.NoError(t, err)
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
require.NoError(t, err)
t3, c3, err := parseToken(token)
assert.FatalError(t, err)
require.NoError(t, err)
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
require.NoError(t, err)
t4, c4, err := parseToken(token)
assert.FatalError(t, err)
require.NoError(t, err)
jwk, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
require.NoError(t, err)
token, err = generateK8sSAToken(jwk, nil)
assert.FatalError(t, err)
require.NoError(t, err)
t5, c5, err := parseToken(token)
assert.FatalError(t, err)
require.NoError(t, err)
type fields struct {
byID *sync.Map
@@ -159,11 +206,11 @@ func TestCollection_LoadByCertificate(t *testing.T) {
}
p1, err := generateJWK()
assert.FatalError(t, err)
require.NoError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
require.NoError(t, err)
p3, err := generateACME()
assert.FatalError(t, err)
require.NoError(t, err)
byName := new(sync.Map)
byName.Store(p1.GetName(), p1)
@@ -229,11 +276,11 @@ func TestCollection_LoadByCertificate(t *testing.T) {
func TestCollection_LoadEncryptedKey(t *testing.T) {
c := NewCollection(testAudiences)
p1, err := generateJWK()
assert.FatalError(t, err)
assert.FatalError(t, c.Store(p1))
require.NoError(t, err)
require.NoError(t, c.Store(p1))
p2, err := generateOIDC()
assert.FatalError(t, err)
assert.FatalError(t, c.Store(p2))
require.NoError(t, err)
require.NoError(t, c.Store(p2))
// Add oidc in byKey.
// It should not happen.
@@ -269,9 +316,9 @@ func TestCollection_LoadEncryptedKey(t *testing.T) {
func TestCollection_Store(t *testing.T) {
c := NewCollection(testAudiences)
p1, err := generateJWK()
assert.FatalError(t, err)
require.NoError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
require.NoError(t, err)
type args struct {
p Interface
@@ -297,7 +344,7 @@ func TestCollection_Store(t *testing.T) {
func TestCollection_Find(t *testing.T) {
c, err := generateCollection(10, 10)
assert.FatalError(t, err)
require.NoError(t, err)
trim := func(s string) string {
return strings.TrimLeft(s, "0")
@@ -391,7 +438,7 @@ func Test_matchesAudience(t *testing.T) {
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
assert.Equal(t, tc.exp, matchesAudience(tc.a, tc.b))
})
}
}

View File

@@ -73,7 +73,7 @@ func (p *K8sSA) GetIDForToken() string {
// GetTokenID returns an unimplemented error and does not use the input ott.
func (p *K8sSA) GetTokenID(string) (string, error) {
return "", errors.New("not implemented")
return "", ErrNotImplemented
}
// GetName returns the name of the provisioner.

View File

@@ -283,20 +283,14 @@ func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) erro
return p.ctl.AuthorizeRenew(ctx, crt)
}
// AuthorizeRevoke returns an error if the token is not valid.
func (p *Nebula) AuthorizeRevoke(_ context.Context, token string) error {
return p.validateToken(token, p.ctl.Audiences.Revoke)
// AuthorizeRevoke returns an unauthorized error.
func (p *Nebula) AuthorizeRevoke(context.Context, string) error {
return errs.Unauthorized("nebula provisioner does not support revoke")
}
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
func (p *Nebula) AuthorizeSSHRevoke(_ context.Context, token string) error {
if !p.ctl.Claimer.IsSSHCAEnabled() {
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
}
if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil {
return err
}
return nil
// AuthorizeSSHRevoke returns an unauthorized error.
func (p *Nebula) AuthorizeSSHRevoke(context.Context, string) error {
return errs.Unauthorized("nebula provisioner does not support SSH revoke")
}
// AuthorizeSSHRenew returns an unauthorized error.
@@ -309,11 +303,6 @@ func (p *Nebula) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, [
return nil, nil, errs.Unauthorized("nebula provisioner does not support SSH rekey")
}
func (p *Nebula) validateToken(token string, audiences []string) error {
_, _, err := p.authorizeToken(token, audiences)
return err
}
func (p *Nebula) authorizeToken(token string, audiences []string) (*nebula.NebulaCertificate, *jwtPayload, error) {
jwt, err := jose.ParseSigned(token)
if err != nil {

View File

@@ -709,7 +709,7 @@ func TestNebula_AuthorizeRevoke(t *testing.T) {
args args
wantErr bool
}{
{"ok", p, args{ctx, ok}, false},
{"fail unauthorized", p, args{ctx, ok}, true},
{"fail token", p, args{ctx, failToken}, true},
}
for _, tt := range tests {
@@ -749,7 +749,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
args args
wantErr bool
}{
{"ok", p, args{ctx, ok}, false},
{"fail unauthorized", p, args{ctx, ok}, true},
{"fail token", p, args{ctx, failToken}, true},
{"fail disabled", pDisabled, args{ctx, ok}, true},
}

View File

@@ -77,6 +77,13 @@ func (p Uninitialized) MarshalJSON() ([]byte, error) {
// the understanding that we are not following security best practices
var ErrAllowTokenReuse = stderrors.New("allow token reuse")
// ErrTokenFlowNotSupported is an error that is returned by provisioners on
// GetTokenID when the use of tokens is not supported.
var ErrTokenFlowNotSupported = stderrors.New("token flow is not supported")
// ErrNotImplemented is an error returned when one method is not implemented.
var ErrNotImplemented = stderrors.New("not implemented")
// Audiences stores all supported audiences by request type.
type Audiences struct {
Sign []string

View File

@@ -95,9 +95,10 @@ func (s *SCEP) GetEncryptedKey() (string, string, bool) {
return "", "", false
}
// GetTokenID returns the identifier of the token.
// GetTokenID returns the identifier of the token. This provisioner will always
// return [ErrTokenFlowNotSupported].
func (s *SCEP) GetTokenID(string) (string, error) {
return "", errors.New("scep provisioner does not implement GetTokenID")
return "", ErrTokenFlowNotSupported
}
// GetOptions returns the configured provisioner options.

View File

@@ -26,6 +26,44 @@ import (
"go.step.sm/crypto/x509util"
)
func generateSCEP(t *testing.T) *SCEP {
t.Helper()
ca, err := minica.New()
require.NoError(t, err)
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
cert, err := ca.Sign(&x509.Certificate{
Subject: pkix.Name{CommonName: "SCEP decrypter"},
PublicKey: key.Public(),
})
require.NoError(t, err)
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Bytes: cert.Raw,
})
block, err := pemutil.Serialize(key, pemutil.WithPassword([]byte("password")))
require.NoError(t, err)
keyPEM := pem.EncodeToMemory(block)
p := &SCEP{
Type: "SCEP",
Name: "scep",
ChallengePassword: "password123",
MinimumPublicKeyLength: 0,
DecrypterCertificate: certPEM,
DecrypterKeyPEM: keyPEM,
DecrypterKeyPassword: "password",
EncryptionAlgorithmIdentifier: 0,
}
require.NoError(t, p.Init(Config{Claims: globalProvisionerClaims}))
return p
}
func Test_challengeValidationController_Validate(t *testing.T) {
dummyCSR := &x509.CertificateRequest{
Raw: []byte{1},
@@ -722,3 +760,17 @@ func TestSCEP_Init(t *testing.T) {
})
}
}
func TestSCEP_Getters(t *testing.T) {
p := generateSCEP(t)
assert.Equal(t, "scep/scep", p.GetID())
assert.Equal(t, "scep", p.GetName())
assert.Equal(t, TypeSCEP, p.GetType())
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false)
}
tokenID, err := p.GetTokenID("token")
assert.Empty(t, tokenID)
assert.Equal(t, ErrTokenFlowNotSupported, err)
}

View File

@@ -193,8 +193,11 @@ func (p *SSHPOP) AuthorizeSSHRevoke(_ context.Context, token string) error {
if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
}
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number")
if serial := strconv.FormatUint(claims.sshCert.Serial, 10); claims.Subject != serial {
return errs.Forbidden(
"token subject %q and sshpop certificate serial number %q do not match",
claims.Subject, serial,
)
}
return nil
}

View File

@@ -268,8 +268,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
return test{
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
code: http.StatusForbidden,
err: errors.New(`token subject "foo" and sshpop certificate serial number "0" do not match`),
}
},
"ok": func(t *testing.T) test {

View File

@@ -154,7 +154,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
}
// Using the leaf certificates key to validate the claims accomplishes two
// Using the leaf certificate's key to validate the claims accomplishes two
// things:
// 1. Asserts that the private key used to sign the token corresponds
// to the public certificate in the `x5c` header of the token.

View File

@@ -568,12 +568,30 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
"ok": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
require.NoError(t, err)
serialNumber := certs[0].SerialNumber.String()
jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key")
require.NoError(t, err)
p, err := generateX5C(nil)
require.NoError(t, err)
tok, err := generateToken("foo", p.GetName(), testAudiences.Revoke[0], "",
tok, err := generateToken(serialNumber, p.GetName(), testAudiences.Revoke[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
require.NoError(t, err)
return test{
p: p,
token: tok,
}
},
"ok/different-serial-number": func(t *testing.T) test {
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
require.NoError(t, err)
jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key")
require.NoError(t, err)
p, err := generateX5C(nil)
require.NoError(t, err)
tok, err := generateToken("123456789", p.GetName(), testAudiences.Revoke[0], "",
[]string{"test.smallstep.com"}, time.Now(), jwk,
withX5CHdr(certs))
require.NoError(t, err)

View File

@@ -622,6 +622,16 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke", opts...)
}
// Verify that the serial in the token matches the serial from the request.
if revokeOpts.Serial != claims.Subject {
return errs.ApplyOptions(
errs.Forbidden(
"request serial number %q and token subject %q do not match",
revokeOpts.Serial, claims.Subject,
), opts...,
)
}
// This method will also validate the audiences for JWK provisioners.
p, err := a.LoadProvisionerByToken(token, &claims.Claims)
if err != nil {

View File

@@ -1652,6 +1652,42 @@ func TestAuthority_Revoke(t *testing.T) {
},
}
},
"fail/serial-number": func() test {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificate: func(sn string) (*x509.Certificate, error) {
return nil, errors.New("not found")
},
}))
cl := jose.Claims{
Subject: "token-sn",
Issuer: validIssuer,
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: "44",
}
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
require.NoError(t, err)
return test{
auth: _a,
ctx: tlsRevokeCtx,
opts: &RevokeOptions{
Serial: "request-sn",
ReasonCode: reasonCode,
Reason: reason,
OTT: raw,
},
err: errors.New(`request serial number "request-sn" and token subject "token-sn" do not match`),
code: http.StatusForbidden,
checkErrDetails: func(err *errs.Error) {
assert.Equal(t, raw, err.Details["token"])
},
}
},
"ok/token": func() test {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
@@ -1980,7 +2016,7 @@ func TestAuthority_CRL(t *testing.T) {
sn := fmt.Sprintf("%v", i)
cl := jose.Claims{
Subject: fmt.Sprintf("sn-%v", i),
Subject: sn,
Issuer: validIssuer,
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(time.Minute)),

View File

@@ -14,9 +14,23 @@ import (
// Option modifies the Error type.
type Option func(e *Error) error
// withDefaultMessage returns an Option that modifies the error by overwriting the
// message only if it is empty.
func withDefaultMessage(format string, args ...interface{}) Option {
// withDefaultMessage returns an Option that modifies the error by overwriting
// the message only if it is empty. Having withDefaultMessage and
// withFormattedMessage avoid vet errors when the "format" passed to
// "fmt.Sprintf" is not a constant.
func withDefaultMessage(message string) Option {
return func(e *Error) error {
if e.Msg != "" {
return e
}
e.Msg = message
return e
}
}
// withFormattedMessage returns an Option that modifies the error by overwriting
// the formatted message only if it is empty.
func withFormattedMessage(format string, args ...interface{}) Option {
return func(e *Error) error {
if e.Msg != "" {
return e
@@ -27,7 +41,7 @@ func withDefaultMessage(format string, args ...interface{}) Option {
}
// WithMessage returns an Option that modifies the error by overwriting the
// message only if it is empty.
// message with the formatted string.
func WithMessage(format string, args ...interface{}) Option {
return func(e *Error) error {
e.Msg = fmt.Sprintf(format, args...)
@@ -35,6 +49,15 @@ func WithMessage(format string, args ...interface{}) Option {
}
}
// WithErrorMessage returns an Option that modifies the error by overwriting the
// message with the error string.
func WithErrorMessage() Option {
return func(e *Error) error {
e.Msg = e.Error()
return e
}
}
// WithKeyVal returns an Option that adds the given key-value pair to the
// Error details. This is helpful for debugging errors.
func WithKeyVal(key string, val interface{}) Option {
@@ -183,7 +206,8 @@ func StatusCodeError(code int, e error, opts ...Option) error {
}
const (
seeLogs = "Please see the certificate authority logs for more info."
seeLogs = "Please see the certificate authority logs for more info."
defaultMsg = "The requested could not be completed. " + seeLogs
// BadRequestDefaultMsg 400 default msg
BadRequestDefaultMsg = "The request could not be completed; malformed or missing data. " + seeLogs
// UnauthorizedDefaultMsg 401 default msg
@@ -198,6 +222,25 @@ const (
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
)
func defaultMessage(status int) string {
switch status {
case http.StatusBadRequest:
return BadRequestDefaultMsg
case http.StatusUnauthorized:
return UnauthorizedDefaultMsg
case http.StatusForbidden:
return ForbiddenDefaultMsg
case http.StatusNotFound:
return NotFoundDefaultMsg
case http.StatusInternalServerError:
return InternalServerErrorDefaultMsg
case http.StatusNotImplemented:
return NotImplementedDefaultMsg
default:
return defaultMsg
}
}
const (
// BadRequestPrefix is the prefix added to the bad request messages that are
// directly sent to the cli.
@@ -292,7 +335,7 @@ func NewErr(status int, err error, opts ...Option) error {
// Errorf creates a new error using the given format and status code.
func Errorf(code int, format string, args ...interface{}) error {
as, opts := splitOptionArgs(args)
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
opts = append(opts, withDefaultMessage(defaultMessage(code)))
e := &Error{Status: code, Err: fmt.Errorf(format, as...)}
for _, o := range opts {
o(e)
@@ -384,7 +427,6 @@ func NotFoundErr(err error, opts ...Option) error {
// UnexpectedErr will be used when the certificate authority makes an outgoing
// request and receives an unhandled status code.
func UnexpectedErr(code int, err error, opts ...Option) error {
opts = append(opts, withDefaultMessage("The certificate authority received an "+
"unexpected HTTP status code - '%d'. "+seeLogs, code))
opts = append(opts, withFormattedMessage("The certificate authority received an unexpected HTTP status code - '%d'. "+seeLogs, code))
return NewErr(code, err, opts...)
}

View File

@@ -127,3 +127,101 @@ func TestError_Unwrap_As(t *testing.T) {
})
}
}
func TestErrorf(t *testing.T) {
tests := []struct {
name string
code int
format string
args []any
want error
}{
{"bad request", 400, "test error string", nil, &Error{
Status: 400,
Err: errors.New("test error string"),
Msg: BadRequestDefaultMsg,
}},
{"unauthorized", 401, "test error string", nil, &Error{
Status: 401,
Err: errors.New("test error string"),
Msg: UnauthorizedDefaultMsg,
}},
{"forbidden", 403, "test error string", nil, &Error{
Status: 403,
Err: errors.New("test error string"),
Msg: ForbiddenDefaultMsg,
}},
{"not found", 404, "test error string", nil, &Error{
Status: 404,
Err: errors.New("test error string"),
Msg: NotFoundDefaultMsg,
}},
{"internal server error", 500, "test error string", nil, &Error{
Status: 500,
Err: errors.New("test error string"),
Msg: InternalServerErrorDefaultMsg,
}},
{"not implemented", 501, "test error string", nil, &Error{
Status: 501,
Err: errors.New("test error string"),
Msg: NotImplementedDefaultMsg,
}},
{"other", 502, "test error string", nil, &Error{
Status: 502,
Err: errors.New("test error string"),
Msg: defaultMsg,
}},
{"formatted args", 401, "test error string: %s", []any{"some reason"}, &Error{
Status: 401,
Err: errors.New("test error string: some reason"),
Msg: UnauthorizedDefaultMsg,
}},
{"WithMessage", 403, "test error string", []any{WithMessage("%s failed", "something")}, &Error{
Status: 403,
Err: errors.New("test error string"),
Msg: "something failed",
}},
{"WithErrorMessage", 404, "test error string", []any{WithErrorMessage()}, &Error{
Status: 404,
Err: errors.New("test error string"),
Msg: "test error string",
}},
{"WithKeyValue", 500, "test error string", []any{WithKeyVal("foo", 1), WithKeyVal("bar", "zar")}, &Error{
Status: 500,
Err: errors.New("test error string"),
Msg: InternalServerErrorDefaultMsg,
Details: map[string]interface{}{"foo": 1, "bar": "zar"},
}},
{"withDefaultMessage", 501, "test error string", []any{withDefaultMessage("some message")}, &Error{
Status: 501,
Err: errors.New("test error string"),
Msg: "some message",
}},
{"withFormattedMessage", 502, "test error string", []any{withFormattedMessage("some message: %s", "the reason")}, &Error{
Status: 502,
Err: errors.New("test error string"),
Msg: "some message: the reason",
}},
{"WithMessage and withDefaultMessage", 500, "test error string", []any{WithMessage("the message"), withDefaultMessage("some message")}, &Error{
Status: 500,
Err: errors.New("test error string"),
Msg: "the message",
}},
{"WithErrorMessage and withFormattedMessage", 500, "test error string", []any{WithErrorMessage(), withFormattedMessage("some message: %s", "the reason")}, &Error{
Status: 500,
Err: errors.New("test error string"),
Msg: "test error string",
}},
{"formatted args and withMessage", 500, "test error string: %s, code %d", []any{"reason", 1234, WithMessage("the message")}, &Error{
Status: 500,
Err: errors.New("test error string: reason, code 1234"),
Msg: "the message",
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotErr := Errorf(tt.code, tt.format, tt.args...)
assert.Equal(t, tt.want, gotErr)
})
}
}

View File

@@ -53,6 +53,8 @@ func Test_reflectRequestID(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
t.Setenv("STEPPATH", dir)
m, err := minica.New(minica.WithName("Step E2E"))
require.NoError(t, err)

View File

@@ -103,6 +103,8 @@ func newTestCA(t *testing.T, name string) *testCA {
require.NoError(t, err)
dir := t.TempDir()
t.Setenv("STEPPATH", dir)
m, err := minica.New(minica.WithName(name), minica.WithGetSignerFunc(func() (crypto.Signer, error) {
return signer, nil
}))

View File

@@ -34,6 +34,8 @@ func TestIssuesCertificateUsingSCEPWithDecrypterAndUpstreamCAS(t *testing.T) {
require.NoError(t, err)
dir := t.TempDir()
t.Setenv("STEPPATH", dir)
m, err := minica.New(minica.WithName("Step E2E | SCEP Decrypter w/ Upstream CAS"), minica.WithGetSignerFunc(func() (crypto.Signer, error) {
return signer, nil
}))

View File

@@ -32,6 +32,8 @@ func TestIssuesCertificateUsingSCEPWithDecrypter(t *testing.T) {
require.NoError(t, err)
dir := t.TempDir()
t.Setenv("STEPPATH", dir)
m, err := minica.New(minica.WithName("Step E2E | SCEP Decrypter"), minica.WithGetSignerFunc(func() (crypto.Signer, error) {
return signer, nil
}))

View File

@@ -29,6 +29,8 @@ func TestFailsIssuingCertificateUsingRegularSCEPWithUpstreamCAS(t *testing.T) {
require.NoError(t, err)
dir := t.TempDir()
t.Setenv("STEPPATH", dir)
m, err := minica.New(minica.WithName("Step E2E | SCEP Regular w/ Upstream CAS"), minica.WithGetSignerFunc(func() (crypto.Signer, error) {
return signer, nil
}))