mirror of
https://github.com/outbackdingo/certificates.git
synced 2026-01-27 02:18:27 +00:00
Improve validation in authorization path
This commit is contained in:
@@ -792,6 +792,9 @@ type attestationObject struct {
|
||||
|
||||
// TODO(bweeks): move attestation verification to a shared package.
|
||||
func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error {
|
||||
// Update challenge with the payload
|
||||
ch.Payload = payload
|
||||
|
||||
// Load authorization to store the key fingerprint.
|
||||
az, err := db.GetAuthorization(ctx, ch.AuthorizationID)
|
||||
if err != nil {
|
||||
@@ -946,7 +949,6 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose
|
||||
ch.Status = StatusValid
|
||||
ch.Error = nil
|
||||
ch.ValidatedAt = clock.Now().Format(time.RFC3339)
|
||||
ch.Payload = payload
|
||||
ch.PayloadFormat = format
|
||||
|
||||
// Store the fingerprint in the authorization.
|
||||
|
||||
@@ -878,7 +878,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error")
|
||||
@@ -4077,7 +4077,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, errorPayload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error")
|
||||
@@ -4117,7 +4117,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, errorBase64Payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "failed base64 decoding attObj %q", "?!")
|
||||
@@ -4157,7 +4157,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, emptyPayload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty")
|
||||
@@ -4197,7 +4197,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, emptyObjectPayload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty")
|
||||
@@ -4237,7 +4237,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, errorNonWellformedCBORPayload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "attObj is not well formed CBOR: unexpected EOF")
|
||||
@@ -4279,7 +4279,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, errorUnsupportedFormat, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "unsupported attestation object format %q", "unsupported-format")
|
||||
@@ -4326,7 +4326,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewError(ErrorBadAttestationStatementType, "attestation format %q is not enabled", "step")
|
||||
@@ -4383,7 +4383,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "x5c not present")
|
||||
@@ -4432,7 +4432,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "serial-number", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "challenge token does not match")
|
||||
@@ -4480,7 +4480,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "non-matching-value", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
subproblem := NewSubproblemWithIdentifier(
|
||||
@@ -4560,7 +4560,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "x5c not present")
|
||||
@@ -4616,7 +4616,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, "permanent identifier does not match").
|
||||
@@ -4713,7 +4713,7 @@ func Test_deviceAttest01Validate(t *testing.T) {
|
||||
assert.Equal(t, StatusInvalid, updch.Status)
|
||||
assert.Equal(t, ChallengeType("device-attest-01"), updch.Type)
|
||||
assert.Equal(t, "12345678", updch.Value)
|
||||
assert.Nil(t, updch.Payload)
|
||||
assert.Equal(t, payload, updch.Payload)
|
||||
assert.Empty(t, updch.PayloadFormat)
|
||||
|
||||
err := NewDetailedError(ErrorBadAttestationStatementType, `unsupported attestation object format "bogus-format"`)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
@@ -60,7 +61,7 @@ func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) {
|
||||
}
|
||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||
ctx := r.Context()
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RenewMethod)
|
||||
peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
|
||||
return peer, parts[1], err
|
||||
}
|
||||
|
||||
@@ -94,8 +94,9 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
opts.Crt = r.TLS.PeerCertificates[0]
|
||||
if opts.Crt.SerialNumber.String() != opts.Serial {
|
||||
render.Error(w, r, errs.BadRequest("serial number in client certificate different than body"))
|
||||
if serialNumber := opts.Crt.SerialNumber.String(); opts.Serial != serialNumber {
|
||||
render.Error(w, r, errs.Forbidden(
|
||||
"request serial number %q and certificate serial number %q do not match", opts.Serial, serialNumber))
|
||||
return
|
||||
}
|
||||
// TODO: should probably be checking if the certificate was revoked here.
|
||||
|
||||
@@ -70,14 +70,13 @@ func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
|
||||
a := mustAuthority(ctx)
|
||||
|
||||
// A token indicates that we are using the api via a provisioner token,
|
||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||
logOtt(w, body.OTT)
|
||||
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
|
||||
@@ -80,6 +80,10 @@ func testAuthority(t *testing.T, opts ...Option) *Authority {
|
||||
EnableSSHCA: &enableSSHCA,
|
||||
},
|
||||
},
|
||||
&provisioner.ACME{
|
||||
Name: "acme",
|
||||
Type: "ACME",
|
||||
},
|
||||
&provisioner.JWK{
|
||||
Name: "uninitialized",
|
||||
Type: "JWK",
|
||||
|
||||
@@ -93,7 +93,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
||||
// Store the token to protect against reuse unless it's skipped.
|
||||
// If we cannot get a token id from the provisioner, just hash the token.
|
||||
if !SkipTokenReuseFromContext(ctx) {
|
||||
if err := a.UseToken(token, p); err != nil {
|
||||
if err := a.UseToken(ctx, token, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -138,7 +138,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
||||
}
|
||||
|
||||
// Check that the token has not been used.
|
||||
if err := a.UseToken(token, prov); err != nil {
|
||||
if err := a.UseToken(r.Context(), token, prov); err != nil {
|
||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
|
||||
}
|
||||
|
||||
@@ -193,22 +193,35 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
||||
|
||||
// UseToken stores the token to protect against reuse.
|
||||
//
|
||||
// This method currently ignores any error coming from the GetTokenID, but it
|
||||
// should specifically ignore the error provisioner.ErrAllowTokenReuse.
|
||||
func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
|
||||
if reuseKey, err := prov.GetTokenID(token); err == nil {
|
||||
if reuseKey == "" {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
reuseKey = strings.ToLower(hex.EncodeToString(sum[:]))
|
||||
}
|
||||
ok, err := a.db.UseToken(reuseKey, token)
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token")
|
||||
}
|
||||
if !ok {
|
||||
return errs.Unauthorized("token already used")
|
||||
// This method currently ignores most errors coming from the GetTokenID because
|
||||
// the token is already validated. But it should specifically ignore the errors
|
||||
// provisioner.ErrAllowTokenReuse, provisioner.ErrNotImplemented, and
|
||||
// provisioner.ErrTokenFlowNotSupported unless this latter one used in a renewal
|
||||
// flow without mTLS.
|
||||
func (a *Authority) UseToken(ctx context.Context, token string, prov provisioner.Interface) error {
|
||||
reuseKey, err := prov.GetTokenID(token)
|
||||
if err != nil {
|
||||
// Fail on ErrTokenFlowNotSupported but allow x5cInsecure renew token
|
||||
if errors.Is(err, provisioner.ErrTokenFlowNotSupported) && provisioner.RenewMethod != provisioner.MethodFromContext(ctx) {
|
||||
return errs.BadRequest("token flow is not supported")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if reuseKey == "" {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
reuseKey = strings.ToLower(hex.EncodeToString(sum[:]))
|
||||
}
|
||||
|
||||
ok, err := a.db.UseToken(reuseKey, token)
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token")
|
||||
}
|
||||
if !ok {
|
||||
return errs.Unauthorized("token already used")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -398,7 +411,7 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
|
||||
|
||||
// AuthorizeRenewToken validates the renew token and returns the leaf
|
||||
// certificate in the x5cInsecure header.
|
||||
func (a *Authority) AuthorizeRenewToken(_ context.Context, ott string) (*x509.Certificate, error) {
|
||||
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||
var claims jose.Claims
|
||||
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
|
||||
if err != nil {
|
||||
@@ -413,7 +426,7 @@ func (a *Authority) AuthorizeRenewToken(_ context.Context, ott string) (*x509.Ce
|
||||
if err != nil {
|
||||
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
|
||||
}
|
||||
if err := a.UseToken(ott, p); err != nil {
|
||||
if err := a.UseToken(ctx, ott, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -193,6 +193,25 @@ func TestAuthority_authorizeToken(t *testing.T) {
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
"fail/token-flow-not-supported": func(t *testing.T) *authorizeTest {
|
||||
cl := jose.Claims{
|
||||
Subject: "test.smallstep.com",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
|
||||
IssuedAt: jose.NewNumericDate(now),
|
||||
Audience: []string{"acme/acme"},
|
||||
ID: "45",
|
||||
}
|
||||
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: raw,
|
||||
err: errors.New("token flow is not supported"),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
},
|
||||
"ok/simpledb": func(t *testing.T) *authorizeTest {
|
||||
cl := jose.Claims{
|
||||
Subject: "test.smallstep.com",
|
||||
|
||||
@@ -138,9 +138,10 @@ func (p *ACME) GetIDForToken() string {
|
||||
return "acme/" + p.Name
|
||||
}
|
||||
|
||||
// GetTokenID returns the identifier of the token.
|
||||
// GetTokenID returns the identifier of the token. This provisioner will always
|
||||
// return [ErrTokenFlowNotSupported].
|
||||
func (p *ACME) GetTokenID(string) (string, error) {
|
||||
return "", errors.New("acme provisioner does not implement GetTokenID")
|
||||
return "", ErrTokenFlowNotSupported
|
||||
}
|
||||
|
||||
// GetName returns the name of the provisioner.
|
||||
|
||||
@@ -82,6 +82,9 @@ func TestACME_Getters(t *testing.T) {
|
||||
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
|
||||
kid, key, ok, "", "", false)
|
||||
}
|
||||
tokenID, err := p.GetTokenID("token")
|
||||
assert.Empty(t, tokenID)
|
||||
assert.Equal(t, ErrTokenFlowNotSupported, err)
|
||||
}
|
||||
|
||||
func TestACME_Init(t *testing.T) {
|
||||
|
||||
@@ -77,7 +77,7 @@ func (c *Collection) LoadByName(name string) (Interface, bool) {
|
||||
}
|
||||
|
||||
// LoadByTokenID a provisioner by identifier found in token.
|
||||
// For different provisioner types this identifier may be found in in different
|
||||
// For different provisioner types this identifier may be found in different
|
||||
// attributes of the token.
|
||||
func (c *Collection) LoadByTokenID(tokenProvisionerID string) (Interface, bool) {
|
||||
return loadProvisioner(c.byTokenID, tokenProvisionerID)
|
||||
|
||||
@@ -8,13 +8,15 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
)
|
||||
|
||||
func TestCollection_Load(t *testing.T) {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p.GetID(), p)
|
||||
byID.Store("string", "a-string")
|
||||
@@ -52,15 +54,60 @@ func TestCollection_Load(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadByTokenID(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
require.NoError(t, err)
|
||||
p2, err := generateACME()
|
||||
require.NoError(t, err)
|
||||
|
||||
byTokenID := new(sync.Map)
|
||||
byTokenID.Store(p1.GetIDForToken(), p1)
|
||||
byTokenID.Store(p2.GetIDForToken(), p2)
|
||||
byTokenID.Store("string", "a-string")
|
||||
|
||||
type fields struct {
|
||||
byTokenID *sync.Map
|
||||
}
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok jwk", fields{byTokenID}, args{p1.GetIDForToken()}, p1, true},
|
||||
{"ok acme", fields{byTokenID}, args{p2.GetIDForToken()}, p2, true},
|
||||
{"fail missing", fields{byTokenID}, args{"missing"}, nil, false},
|
||||
{"invalid", fields{byTokenID}, args{"string"}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byTokenID: tt.fields.byTokenID,
|
||||
}
|
||||
got, got1 := c.LoadByTokenID(tt.args.id)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.Load() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadByToken(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
p2, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
p3, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
p4, err := generateK8sSA(nil)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
@@ -75,35 +122,35 @@ func TestCollection_LoadByToken(t *testing.T) {
|
||||
byID2.Store(p3.GetID(), p3)
|
||||
|
||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
token, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], jwk)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
t1, c1, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwk, err = decryptJSONWebKey(p2.EncryptedKey)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
token, err = generateSimpleToken(p2.Name, testAudiences.Sign[1], jwk)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
t2, c2, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
t3, c3, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
t4, c4, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwk, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
token, err = generateK8sSAToken(jwk, nil)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
t5, c5, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
@@ -159,11 +206,11 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
||||
}
|
||||
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
p3, err := generateACME()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
byName := new(sync.Map)
|
||||
byName.Store(p1.GetName(), p1)
|
||||
@@ -229,11 +276,11 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
||||
func TestCollection_LoadEncryptedKey(t *testing.T) {
|
||||
c := NewCollection(testAudiences)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, c.Store(p1))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.Store(p1))
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, c.Store(p2))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, c.Store(p2))
|
||||
|
||||
// Add oidc in byKey.
|
||||
// It should not happen.
|
||||
@@ -269,9 +316,9 @@ func TestCollection_LoadEncryptedKey(t *testing.T) {
|
||||
func TestCollection_Store(t *testing.T) {
|
||||
c := NewCollection(testAudiences)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
p Interface
|
||||
@@ -297,7 +344,7 @@ func TestCollection_Store(t *testing.T) {
|
||||
|
||||
func TestCollection_Find(t *testing.T) {
|
||||
c, err := generateCollection(10, 10)
|
||||
assert.FatalError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
trim := func(s string) string {
|
||||
return strings.TrimLeft(s, "0")
|
||||
@@ -391,7 +438,7 @@ func Test_matchesAudience(t *testing.T) {
|
||||
}
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b))
|
||||
assert.Equal(t, tc.exp, matchesAudience(tc.a, tc.b))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ func (p *K8sSA) GetIDForToken() string {
|
||||
|
||||
// GetTokenID returns an unimplemented error and does not use the input ott.
|
||||
func (p *K8sSA) GetTokenID(string) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
return "", ErrNotImplemented
|
||||
}
|
||||
|
||||
// GetName returns the name of the provisioner.
|
||||
|
||||
@@ -283,20 +283,14 @@ func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) erro
|
||||
return p.ctl.AuthorizeRenew(ctx, crt)
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an error if the token is not valid.
|
||||
func (p *Nebula) AuthorizeRevoke(_ context.Context, token string) error {
|
||||
return p.validateToken(token, p.ctl.Audiences.Revoke)
|
||||
// AuthorizeRevoke returns an unauthorized error.
|
||||
func (p *Nebula) AuthorizeRevoke(context.Context, string) error {
|
||||
return errs.Unauthorized("nebula provisioner does not support revoke")
|
||||
}
|
||||
|
||||
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
|
||||
func (p *Nebula) AuthorizeSSHRevoke(_ context.Context, token string) error {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
|
||||
}
|
||||
if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
// AuthorizeSSHRevoke returns an unauthorized error.
|
||||
func (p *Nebula) AuthorizeSSHRevoke(context.Context, string) error {
|
||||
return errs.Unauthorized("nebula provisioner does not support SSH revoke")
|
||||
}
|
||||
|
||||
// AuthorizeSSHRenew returns an unauthorized error.
|
||||
@@ -309,11 +303,6 @@ func (p *Nebula) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, [
|
||||
return nil, nil, errs.Unauthorized("nebula provisioner does not support SSH rekey")
|
||||
}
|
||||
|
||||
func (p *Nebula) validateToken(token string, audiences []string) error {
|
||||
_, _, err := p.authorizeToken(token, audiences)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Nebula) authorizeToken(token string, audiences []string) (*nebula.NebulaCertificate, *jwtPayload, error) {
|
||||
jwt, err := jose.ParseSigned(token)
|
||||
if err != nil {
|
||||
|
||||
@@ -709,7 +709,7 @@ func TestNebula_AuthorizeRevoke(t *testing.T) {
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p, args{ctx, ok}, false},
|
||||
{"fail unauthorized", p, args{ctx, ok}, true},
|
||||
{"fail token", p, args{ctx, failToken}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -749,7 +749,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p, args{ctx, ok}, false},
|
||||
{"fail unauthorized", p, args{ctx, ok}, true},
|
||||
{"fail token", p, args{ctx, failToken}, true},
|
||||
{"fail disabled", pDisabled, args{ctx, ok}, true},
|
||||
}
|
||||
|
||||
@@ -77,6 +77,13 @@ func (p Uninitialized) MarshalJSON() ([]byte, error) {
|
||||
// the understanding that we are not following security best practices
|
||||
var ErrAllowTokenReuse = stderrors.New("allow token reuse")
|
||||
|
||||
// ErrTokenFlowNotSupported is an error that is returned by provisioners on
|
||||
// GetTokenID when the use of tokens is not supported.
|
||||
var ErrTokenFlowNotSupported = stderrors.New("token flow is not supported")
|
||||
|
||||
// ErrNotImplemented is an error returned when one method is not implemented.
|
||||
var ErrNotImplemented = stderrors.New("not implemented")
|
||||
|
||||
// Audiences stores all supported audiences by request type.
|
||||
type Audiences struct {
|
||||
Sign []string
|
||||
|
||||
@@ -95,9 +95,10 @@ func (s *SCEP) GetEncryptedKey() (string, string, bool) {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// GetTokenID returns the identifier of the token.
|
||||
// GetTokenID returns the identifier of the token. This provisioner will always
|
||||
// return [ErrTokenFlowNotSupported].
|
||||
func (s *SCEP) GetTokenID(string) (string, error) {
|
||||
return "", errors.New("scep provisioner does not implement GetTokenID")
|
||||
return "", ErrTokenFlowNotSupported
|
||||
}
|
||||
|
||||
// GetOptions returns the configured provisioner options.
|
||||
|
||||
@@ -26,6 +26,44 @@ import (
|
||||
"go.step.sm/crypto/x509util"
|
||||
)
|
||||
|
||||
func generateSCEP(t *testing.T) *SCEP {
|
||||
t.Helper()
|
||||
|
||||
ca, err := minica.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := ca.Sign(&x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "SCEP decrypter"},
|
||||
PublicKey: key.Public(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE", Bytes: cert.Raw,
|
||||
})
|
||||
|
||||
block, err := pemutil.Serialize(key, pemutil.WithPassword([]byte("password")))
|
||||
require.NoError(t, err)
|
||||
keyPEM := pem.EncodeToMemory(block)
|
||||
|
||||
p := &SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "scep",
|
||||
ChallengePassword: "password123",
|
||||
MinimumPublicKeyLength: 0,
|
||||
DecrypterCertificate: certPEM,
|
||||
DecrypterKeyPEM: keyPEM,
|
||||
DecrypterKeyPassword: "password",
|
||||
EncryptionAlgorithmIdentifier: 0,
|
||||
}
|
||||
require.NoError(t, p.Init(Config{Claims: globalProvisionerClaims}))
|
||||
return p
|
||||
|
||||
}
|
||||
|
||||
func Test_challengeValidationController_Validate(t *testing.T) {
|
||||
dummyCSR := &x509.CertificateRequest{
|
||||
Raw: []byte{1},
|
||||
@@ -722,3 +760,17 @@ func TestSCEP_Init(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCEP_Getters(t *testing.T) {
|
||||
p := generateSCEP(t)
|
||||
assert.Equal(t, "scep/scep", p.GetID())
|
||||
assert.Equal(t, "scep", p.GetName())
|
||||
assert.Equal(t, TypeSCEP, p.GetType())
|
||||
kid, key, ok := p.GetEncryptedKey()
|
||||
if kid != "" || key != "" || ok == true {
|
||||
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false)
|
||||
}
|
||||
tokenID, err := p.GetTokenID("token")
|
||||
assert.Empty(t, tokenID)
|
||||
assert.Equal(t, ErrTokenFlowNotSupported, err)
|
||||
}
|
||||
|
||||
@@ -193,8 +193,11 @@ func (p *SSHPOP) AuthorizeSSHRevoke(_ context.Context, token string) error {
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
|
||||
}
|
||||
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
|
||||
return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number")
|
||||
if serial := strconv.FormatUint(claims.sshCert.Serial, 10); claims.Subject != serial {
|
||||
return errs.Forbidden(
|
||||
"token subject %q and sshpop certificate serial number %q do not match",
|
||||
claims.Subject, serial,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -268,8 +268,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusBadRequest,
|
||||
err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
|
||||
code: http.StatusForbidden,
|
||||
err: errors.New(`token subject "foo" and sshpop certificate serial number "0" do not match`),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
||||
@@ -154,7 +154,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
|
||||
return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature")
|
||||
}
|
||||
|
||||
// Using the leaf certificates key to validate the claims accomplishes two
|
||||
// Using the leaf certificate's key to validate the claims accomplishes two
|
||||
// things:
|
||||
// 1. Asserts that the private key used to sign the token corresponds
|
||||
// to the public certificate in the `x5c` header of the token.
|
||||
|
||||
@@ -568,12 +568,30 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
|
||||
"ok": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
|
||||
require.NoError(t, err)
|
||||
serialNumber := certs[0].SerialNumber.String()
|
||||
jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key")
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
require.NoError(t, err)
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.Revoke[0], "",
|
||||
tok, err := generateToken(serialNumber, p.GetName(), testAudiences.Revoke[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk,
|
||||
withX5CHdr(certs))
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
token: tok,
|
||||
}
|
||||
},
|
||||
"ok/different-serial-number": func(t *testing.T) test {
|
||||
certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt")
|
||||
require.NoError(t, err)
|
||||
jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key")
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := generateX5C(nil)
|
||||
require.NoError(t, err)
|
||||
tok, err := generateToken("123456789", p.GetName(), testAudiences.Revoke[0], "",
|
||||
[]string{"test.smallstep.com"}, time.Now(), jwk,
|
||||
withX5CHdr(certs))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -622,6 +622,16 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
||||
return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke", opts...)
|
||||
}
|
||||
|
||||
// Verify that the serial in the token matches the serial from the request.
|
||||
if revokeOpts.Serial != claims.Subject {
|
||||
return errs.ApplyOptions(
|
||||
errs.Forbidden(
|
||||
"request serial number %q and token subject %q do not match",
|
||||
revokeOpts.Serial, claims.Subject,
|
||||
), opts...,
|
||||
)
|
||||
}
|
||||
|
||||
// This method will also validate the audiences for JWK provisioners.
|
||||
p, err := a.LoadProvisionerByToken(token, &claims.Claims)
|
||||
if err != nil {
|
||||
|
||||
@@ -1652,6 +1652,42 @@ func TestAuthority_Revoke(t *testing.T) {
|
||||
},
|
||||
}
|
||||
},
|
||||
"fail/serial-number": func() test {
|
||||
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
MGetCertificate: func(sn string) (*x509.Certificate, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
}))
|
||||
|
||||
cl := jose.Claims{
|
||||
Subject: "token-sn",
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
|
||||
Audience: validAudience,
|
||||
ID: "44",
|
||||
}
|
||||
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
return test{
|
||||
auth: _a,
|
||||
ctx: tlsRevokeCtx,
|
||||
opts: &RevokeOptions{
|
||||
Serial: "request-sn",
|
||||
ReasonCode: reasonCode,
|
||||
Reason: reason,
|
||||
OTT: raw,
|
||||
},
|
||||
err: errors.New(`request serial number "request-sn" and token subject "token-sn" do not match`),
|
||||
code: http.StatusForbidden,
|
||||
checkErrDetails: func(err *errs.Error) {
|
||||
assert.Equal(t, raw, err.Details["token"])
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/token": func() test {
|
||||
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
@@ -1980,7 +2016,7 @@ func TestAuthority_CRL(t *testing.T) {
|
||||
sn := fmt.Sprintf("%v", i)
|
||||
|
||||
cl := jose.Claims{
|
||||
Subject: fmt.Sprintf("sn-%v", i),
|
||||
Subject: sn,
|
||||
Issuer: validIssuer,
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
|
||||
|
||||
@@ -14,9 +14,23 @@ import (
|
||||
// Option modifies the Error type.
|
||||
type Option func(e *Error) error
|
||||
|
||||
// withDefaultMessage returns an Option that modifies the error by overwriting the
|
||||
// message only if it is empty.
|
||||
func withDefaultMessage(format string, args ...interface{}) Option {
|
||||
// withDefaultMessage returns an Option that modifies the error by overwriting
|
||||
// the message only if it is empty. Having withDefaultMessage and
|
||||
// withFormattedMessage avoid vet errors when the "format" passed to
|
||||
// "fmt.Sprintf" is not a constant.
|
||||
func withDefaultMessage(message string) Option {
|
||||
return func(e *Error) error {
|
||||
if e.Msg != "" {
|
||||
return e
|
||||
}
|
||||
e.Msg = message
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
||||
// withFormattedMessage returns an Option that modifies the error by overwriting
|
||||
// the formatted message only if it is empty.
|
||||
func withFormattedMessage(format string, args ...interface{}) Option {
|
||||
return func(e *Error) error {
|
||||
if e.Msg != "" {
|
||||
return e
|
||||
@@ -27,7 +41,7 @@ func withDefaultMessage(format string, args ...interface{}) Option {
|
||||
}
|
||||
|
||||
// WithMessage returns an Option that modifies the error by overwriting the
|
||||
// message only if it is empty.
|
||||
// message with the formatted string.
|
||||
func WithMessage(format string, args ...interface{}) Option {
|
||||
return func(e *Error) error {
|
||||
e.Msg = fmt.Sprintf(format, args...)
|
||||
@@ -35,6 +49,15 @@ func WithMessage(format string, args ...interface{}) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithErrorMessage returns an Option that modifies the error by overwriting the
|
||||
// message with the error string.
|
||||
func WithErrorMessage() Option {
|
||||
return func(e *Error) error {
|
||||
e.Msg = e.Error()
|
||||
return e
|
||||
}
|
||||
}
|
||||
|
||||
// WithKeyVal returns an Option that adds the given key-value pair to the
|
||||
// Error details. This is helpful for debugging errors.
|
||||
func WithKeyVal(key string, val interface{}) Option {
|
||||
@@ -183,7 +206,8 @@ func StatusCodeError(code int, e error, opts ...Option) error {
|
||||
}
|
||||
|
||||
const (
|
||||
seeLogs = "Please see the certificate authority logs for more info."
|
||||
seeLogs = "Please see the certificate authority logs for more info."
|
||||
defaultMsg = "The requested could not be completed. " + seeLogs
|
||||
// BadRequestDefaultMsg 400 default msg
|
||||
BadRequestDefaultMsg = "The request could not be completed; malformed or missing data. " + seeLogs
|
||||
// UnauthorizedDefaultMsg 401 default msg
|
||||
@@ -198,6 +222,25 @@ const (
|
||||
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
|
||||
)
|
||||
|
||||
func defaultMessage(status int) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return BadRequestDefaultMsg
|
||||
case http.StatusUnauthorized:
|
||||
return UnauthorizedDefaultMsg
|
||||
case http.StatusForbidden:
|
||||
return ForbiddenDefaultMsg
|
||||
case http.StatusNotFound:
|
||||
return NotFoundDefaultMsg
|
||||
case http.StatusInternalServerError:
|
||||
return InternalServerErrorDefaultMsg
|
||||
case http.StatusNotImplemented:
|
||||
return NotImplementedDefaultMsg
|
||||
default:
|
||||
return defaultMsg
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// BadRequestPrefix is the prefix added to the bad request messages that are
|
||||
// directly sent to the cli.
|
||||
@@ -292,7 +335,7 @@ func NewErr(status int, err error, opts ...Option) error {
|
||||
// Errorf creates a new error using the given format and status code.
|
||||
func Errorf(code int, format string, args ...interface{}) error {
|
||||
as, opts := splitOptionArgs(args)
|
||||
opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg))
|
||||
opts = append(opts, withDefaultMessage(defaultMessage(code)))
|
||||
e := &Error{Status: code, Err: fmt.Errorf(format, as...)}
|
||||
for _, o := range opts {
|
||||
o(e)
|
||||
@@ -384,7 +427,6 @@ func NotFoundErr(err error, opts ...Option) error {
|
||||
// UnexpectedErr will be used when the certificate authority makes an outgoing
|
||||
// request and receives an unhandled status code.
|
||||
func UnexpectedErr(code int, err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage("The certificate authority received an "+
|
||||
"unexpected HTTP status code - '%d'. "+seeLogs, code))
|
||||
opts = append(opts, withFormattedMessage("The certificate authority received an unexpected HTTP status code - '%d'. "+seeLogs, code))
|
||||
return NewErr(code, err, opts...)
|
||||
}
|
||||
|
||||
@@ -127,3 +127,101 @@ func TestError_Unwrap_As(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorf(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
format string
|
||||
args []any
|
||||
want error
|
||||
}{
|
||||
{"bad request", 400, "test error string", nil, &Error{
|
||||
Status: 400,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: BadRequestDefaultMsg,
|
||||
}},
|
||||
{"unauthorized", 401, "test error string", nil, &Error{
|
||||
Status: 401,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: UnauthorizedDefaultMsg,
|
||||
}},
|
||||
{"forbidden", 403, "test error string", nil, &Error{
|
||||
Status: 403,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: ForbiddenDefaultMsg,
|
||||
}},
|
||||
{"not found", 404, "test error string", nil, &Error{
|
||||
Status: 404,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: NotFoundDefaultMsg,
|
||||
}},
|
||||
{"internal server error", 500, "test error string", nil, &Error{
|
||||
Status: 500,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: InternalServerErrorDefaultMsg,
|
||||
}},
|
||||
{"not implemented", 501, "test error string", nil, &Error{
|
||||
Status: 501,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: NotImplementedDefaultMsg,
|
||||
}},
|
||||
{"other", 502, "test error string", nil, &Error{
|
||||
Status: 502,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: defaultMsg,
|
||||
}},
|
||||
{"formatted args", 401, "test error string: %s", []any{"some reason"}, &Error{
|
||||
Status: 401,
|
||||
Err: errors.New("test error string: some reason"),
|
||||
Msg: UnauthorizedDefaultMsg,
|
||||
}},
|
||||
{"WithMessage", 403, "test error string", []any{WithMessage("%s failed", "something")}, &Error{
|
||||
Status: 403,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: "something failed",
|
||||
}},
|
||||
{"WithErrorMessage", 404, "test error string", []any{WithErrorMessage()}, &Error{
|
||||
Status: 404,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: "test error string",
|
||||
}},
|
||||
{"WithKeyValue", 500, "test error string", []any{WithKeyVal("foo", 1), WithKeyVal("bar", "zar")}, &Error{
|
||||
Status: 500,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: InternalServerErrorDefaultMsg,
|
||||
Details: map[string]interface{}{"foo": 1, "bar": "zar"},
|
||||
}},
|
||||
{"withDefaultMessage", 501, "test error string", []any{withDefaultMessage("some message")}, &Error{
|
||||
Status: 501,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: "some message",
|
||||
}},
|
||||
{"withFormattedMessage", 502, "test error string", []any{withFormattedMessage("some message: %s", "the reason")}, &Error{
|
||||
Status: 502,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: "some message: the reason",
|
||||
}},
|
||||
{"WithMessage and withDefaultMessage", 500, "test error string", []any{WithMessage("the message"), withDefaultMessage("some message")}, &Error{
|
||||
Status: 500,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: "the message",
|
||||
}},
|
||||
{"WithErrorMessage and withFormattedMessage", 500, "test error string", []any{WithErrorMessage(), withFormattedMessage("some message: %s", "the reason")}, &Error{
|
||||
Status: 500,
|
||||
Err: errors.New("test error string"),
|
||||
Msg: "test error string",
|
||||
}},
|
||||
{"formatted args and withMessage", 500, "test error string: %s, code %d", []any{"reason", 1234, WithMessage("the message")}, &Error{
|
||||
Status: 500,
|
||||
Err: errors.New("test error string: reason, code 1234"),
|
||||
Msg: "the message",
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotErr := Errorf(tt.code, tt.format, tt.args...)
|
||||
assert.Equal(t, tt.want, gotErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,8 @@ func Test_reflectRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
dir := t.TempDir()
|
||||
t.Setenv("STEPPATH", dir)
|
||||
|
||||
m, err := minica.New(minica.WithName("Step E2E"))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -103,6 +103,8 @@ func newTestCA(t *testing.T, name string) *testCA {
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
t.Setenv("STEPPATH", dir)
|
||||
|
||||
m, err := minica.New(minica.WithName(name), minica.WithGetSignerFunc(func() (crypto.Signer, error) {
|
||||
return signer, nil
|
||||
}))
|
||||
|
||||
@@ -34,6 +34,8 @@ func TestIssuesCertificateUsingSCEPWithDecrypterAndUpstreamCAS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
t.Setenv("STEPPATH", dir)
|
||||
|
||||
m, err := minica.New(minica.WithName("Step E2E | SCEP Decrypter w/ Upstream CAS"), minica.WithGetSignerFunc(func() (crypto.Signer, error) {
|
||||
return signer, nil
|
||||
}))
|
||||
|
||||
@@ -32,6 +32,8 @@ func TestIssuesCertificateUsingSCEPWithDecrypter(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
t.Setenv("STEPPATH", dir)
|
||||
|
||||
m, err := minica.New(minica.WithName("Step E2E | SCEP Decrypter"), minica.WithGetSignerFunc(func() (crypto.Signer, error) {
|
||||
return signer, nil
|
||||
}))
|
||||
|
||||
@@ -29,6 +29,8 @@ func TestFailsIssuingCertificateUsingRegularSCEPWithUpstreamCAS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
t.Setenv("STEPPATH", dir)
|
||||
|
||||
m, err := minica.New(minica.WithName("Step E2E | SCEP Regular w/ Upstream CAS"), minica.WithGetSignerFunc(func() (crypto.Signer, error) {
|
||||
return signer, nil
|
||||
}))
|
||||
|
||||
Reference in New Issue
Block a user