From 1011f5f5408b470a636f583bf74c0d7bbaf75d72 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 2 Dec 2025 16:54:44 -0800 Subject: [PATCH] Improve validation in authorization path --- acme/challenge.go | 4 +- acme/challenge_test.go | 28 ++--- api/renew.go | 3 +- api/revoke.go | 5 +- api/sshRevoke.go | 3 +- authority/authority_test.go | 4 + authority/authorize.go | 49 +++++---- authority/authorize_test.go | 19 ++++ authority/provisioner/acme.go | 5 +- authority/provisioner/acme_test.go | 3 + authority/provisioner/collection.go | 2 +- authority/provisioner/collection_test.go | 107 ++++++++++++++------ authority/provisioner/k8sSA.go | 2 +- authority/provisioner/nebula.go | 23 ++--- authority/provisioner/nebula_test.go | 4 +- authority/provisioner/provisioner.go | 7 ++ authority/provisioner/scep.go | 5 +- authority/provisioner/scep_test.go | 52 ++++++++++ authority/provisioner/sshpop.go | 7 +- authority/provisioner/sshpop_test.go | 4 +- authority/provisioner/x5c.go | 2 +- authority/provisioner/x5c_test.go | 20 +++- authority/tls.go | 10 ++ authority/tls_test.go | 38 ++++++- errs/error.go | 58 +++++++++-- errs/errors_test.go | 98 ++++++++++++++++++ test/integration/requestid_test.go | 2 + test/integration/scep/common_test.go | 2 + test/integration/scep/decrypter_cas_test.go | 2 + test/integration/scep/decrypter_test.go | 2 + test/integration/scep/regular_cas_test.go | 2 + 31 files changed, 464 insertions(+), 108 deletions(-) diff --git a/acme/challenge.go b/acme/challenge.go index c0f9425c..17ca0ab8 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -792,6 +792,9 @@ type attestationObject struct { // TODO(bweeks): move attestation verification to a shared package. func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, payload []byte) error { + // Update challenge with the payload + ch.Payload = payload + // Load authorization to store the key fingerprint. az, err := db.GetAuthorization(ctx, ch.AuthorizationID) if err != nil { @@ -946,7 +949,6 @@ func deviceAttest01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose ch.Status = StatusValid ch.Error = nil ch.ValidatedAt = clock.Now().Format(time.RFC3339) - ch.Payload = payload ch.PayloadFormat = format // Store the fingerprint in the authorization. diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 0e630637..5702e42c 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -878,7 +878,7 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k= assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error") @@ -4077,7 +4077,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, errorPayload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewError(ErrorRejectedIdentifierType, "payload contained error: an error") @@ -4117,7 +4117,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, errorBase64Payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "failed base64 decoding attObj %q", "?!") @@ -4157,7 +4157,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, emptyPayload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty") @@ -4197,7 +4197,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, emptyObjectPayload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "attObj must not be empty") @@ -4237,7 +4237,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, errorNonWellformedCBORPayload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "attObj is not well formed CBOR: unexpected EOF") @@ -4279,7 +4279,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, errorUnsupportedFormat, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "unsupported attestation object format %q", "unsupported-format") @@ -4326,7 +4326,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewError(ErrorBadAttestationStatementType, "attestation format %q is not enabled", "step") @@ -4383,7 +4383,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "x5c not present") @@ -4432,7 +4432,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "serial-number", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "challenge token does not match") @@ -4480,7 +4480,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "non-matching-value", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) subproblem := NewSubproblemWithIdentifier( @@ -4560,7 +4560,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "x5c not present") @@ -4616,7 +4616,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, "permanent identifier does not match"). @@ -4713,7 +4713,7 @@ func Test_deviceAttest01Validate(t *testing.T) { assert.Equal(t, StatusInvalid, updch.Status) assert.Equal(t, ChallengeType("device-attest-01"), updch.Type) assert.Equal(t, "12345678", updch.Value) - assert.Nil(t, updch.Payload) + assert.Equal(t, payload, updch.Payload) assert.Empty(t, updch.PayloadFormat) err := NewDetailedError(ErrorBadAttestationStatementType, `unsupported attestation object format "bogus-format"`) diff --git a/api/renew.go b/api/renew.go index 7cd3707d..d1cab914 100644 --- a/api/renew.go +++ b/api/renew.go @@ -7,6 +7,7 @@ import ( "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" ) @@ -60,7 +61,7 @@ func getPeerCertificate(r *http.Request) (*x509.Certificate, string, error) { } if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { - ctx := r.Context() + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RenewMethod) peer, err := mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) return peer, parts[1], err } diff --git a/api/revoke.go b/api/revoke.go index 41969c08..7d87646b 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -94,8 +94,9 @@ func Revoke(w http.ResponseWriter, r *http.Request) { return } opts.Crt = r.TLS.PeerCertificates[0] - if opts.Crt.SerialNumber.String() != opts.Serial { - render.Error(w, r, errs.BadRequest("serial number in client certificate different than body")) + if serialNumber := opts.Crt.SerialNumber.String(); opts.Serial != serialNumber { + render.Error(w, r, errs.Forbidden( + "request serial number %q and certificate serial number %q do not match", opts.Serial, serialNumber)) return } // TODO: should probably be checking if the certificate was revoked here. diff --git a/api/sshRevoke.go b/api/sshRevoke.go index 2fe49199..68e9a2be 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -70,14 +70,13 @@ func SSHRevoke(w http.ResponseWriter, r *http.Request) { ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) a := mustAuthority(ctx) - // A token indicates that we are using the api via a provisioner token, - // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, r, errs.UnauthorizedErr(err)) return } + opts.OTT = body.OTT if err := a.Revoke(ctx, opts); err != nil { diff --git a/authority/authority_test.go b/authority/authority_test.go index 5ad7b747..dca6cf14 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -80,6 +80,10 @@ func testAuthority(t *testing.T, opts ...Option) *Authority { EnableSSHCA: &enableSSHCA, }, }, + &provisioner.ACME{ + Name: "acme", + Type: "ACME", + }, &provisioner.JWK{ Name: "uninitialized", Type: "JWK", diff --git a/authority/authorize.go b/authority/authorize.go index 7e85bfc2..74ed8936 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -93,7 +93,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // Store the token to protect against reuse unless it's skipped. // If we cannot get a token id from the provisioner, just hash the token. if !SkipTokenReuseFromContext(ctx) { - if err := a.UseToken(token, p); err != nil { + if err := a.UseToken(ctx, token, p); err != nil { return nil, err } } @@ -138,7 +138,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc } // Check that the token has not been used. - if err := a.UseToken(token, prov); err != nil { + if err := a.UseToken(r.Context(), token, prov); err != nil { return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token") } @@ -193,22 +193,35 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc // UseToken stores the token to protect against reuse. // -// This method currently ignores any error coming from the GetTokenID, but it -// should specifically ignore the error provisioner.ErrAllowTokenReuse. -func (a *Authority) UseToken(token string, prov provisioner.Interface) error { - if reuseKey, err := prov.GetTokenID(token); err == nil { - if reuseKey == "" { - sum := sha256.Sum256([]byte(token)) - reuseKey = strings.ToLower(hex.EncodeToString(sum[:])) - } - ok, err := a.db.UseToken(reuseKey, token) - if err != nil { - return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token") - } - if !ok { - return errs.Unauthorized("token already used") +// This method currently ignores most errors coming from the GetTokenID because +// the token is already validated. But it should specifically ignore the errors +// provisioner.ErrAllowTokenReuse, provisioner.ErrNotImplemented, and +// provisioner.ErrTokenFlowNotSupported unless this latter one used in a renewal +// flow without mTLS. +func (a *Authority) UseToken(ctx context.Context, token string, prov provisioner.Interface) error { + reuseKey, err := prov.GetTokenID(token) + if err != nil { + // Fail on ErrTokenFlowNotSupported but allow x5cInsecure renew token + if errors.Is(err, provisioner.ErrTokenFlowNotSupported) && provisioner.RenewMethod != provisioner.MethodFromContext(ctx) { + return errs.BadRequest("token flow is not supported") } + + return nil } + + if reuseKey == "" { + sum := sha256.Sum256([]byte(token)) + reuseKey = strings.ToLower(hex.EncodeToString(sum[:])) + } + + ok, err := a.db.UseToken(reuseKey, token) + if err != nil { + return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token") + } + if !ok { + return errs.Unauthorized("token already used") + } + return nil } @@ -398,7 +411,7 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error // AuthorizeRenewToken validates the renew token and returns the leaf // certificate in the x5cInsecure header. -func (a *Authority) AuthorizeRenewToken(_ context.Context, ott string) (*x509.Certificate, error) { +func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) { var claims jose.Claims jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs) if err != nil { @@ -413,7 +426,7 @@ func (a *Authority) AuthorizeRenewToken(_ context.Context, ott string) (*x509.Ce if err != nil { return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate") } - if err := a.UseToken(ott, p); err != nil { + if err := a.UseToken(ctx, ott, p); err != nil { return nil, err } diff --git a/authority/authorize_test.go b/authority/authorize_test.go index f7287e7a..48033195 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -193,6 +193,25 @@ func TestAuthority_authorizeToken(t *testing.T) { code: http.StatusUnauthorized, } }, + "fail/token-flow-not-supported": func(t *testing.T) *authorizeTest { + cl := jose.Claims{ + Subject: "test.smallstep.com", + Issuer: validIssuer, + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(time.Minute)), + IssuedAt: jose.NewNumericDate(now), + Audience: []string{"acme/acme"}, + ID: "45", + } + raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() + assert.FatalError(t, err) + return &authorizeTest{ + auth: a, + token: raw, + err: errors.New("token flow is not supported"), + code: http.StatusBadRequest, + } + }, "ok/simpledb": func(t *testing.T) *authorizeTest { cl := jose.Claims{ Subject: "test.smallstep.com", diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 32a0bdf0..b014e675 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -138,9 +138,10 @@ func (p *ACME) GetIDForToken() string { return "acme/" + p.Name } -// GetTokenID returns the identifier of the token. +// GetTokenID returns the identifier of the token. This provisioner will always +// return [ErrTokenFlowNotSupported]. func (p *ACME) GetTokenID(string) (string, error) { - return "", errors.New("acme provisioner does not implement GetTokenID") + return "", ErrTokenFlowNotSupported } // GetName returns the name of the provisioner. diff --git a/authority/provisioner/acme_test.go b/authority/provisioner/acme_test.go index 96f4bd8b..f5169809 100644 --- a/authority/provisioner/acme_test.go +++ b/authority/provisioner/acme_test.go @@ -82,6 +82,9 @@ func TestACME_Getters(t *testing.T) { t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) } + tokenID, err := p.GetTokenID("token") + assert.Empty(t, tokenID) + assert.Equal(t, ErrTokenFlowNotSupported, err) } func TestACME_Init(t *testing.T) { diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index d10d0135..a28ab6d0 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -77,7 +77,7 @@ func (c *Collection) LoadByName(name string) (Interface, bool) { } // LoadByTokenID a provisioner by identifier found in token. -// For different provisioner types this identifier may be found in in different +// For different provisioner types this identifier may be found in different // attributes of the token. func (c *Collection) LoadByTokenID(tokenProvisionerID string) (Interface, bool) { return loadProvisioner(c.byTokenID, tokenProvisionerID) diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go index 24db4593..68e926a8 100644 --- a/authority/provisioner/collection_test.go +++ b/authority/provisioner/collection_test.go @@ -8,13 +8,15 @@ import ( "sync" "testing" - "github.com/smallstep/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" ) func TestCollection_Load(t *testing.T) { p, err := generateJWK() - assert.FatalError(t, err) + require.NoError(t, err) byID := new(sync.Map) byID.Store(p.GetID(), p) byID.Store("string", "a-string") @@ -52,15 +54,60 @@ func TestCollection_Load(t *testing.T) { } } +func TestCollection_LoadByTokenID(t *testing.T) { + p1, err := generateJWK() + require.NoError(t, err) + p2, err := generateACME() + require.NoError(t, err) + + byTokenID := new(sync.Map) + byTokenID.Store(p1.GetIDForToken(), p1) + byTokenID.Store(p2.GetIDForToken(), p2) + byTokenID.Store("string", "a-string") + + type fields struct { + byTokenID *sync.Map + } + type args struct { + id string + } + tests := []struct { + name string + fields fields + args args + want Interface + want1 bool + }{ + {"ok jwk", fields{byTokenID}, args{p1.GetIDForToken()}, p1, true}, + {"ok acme", fields{byTokenID}, args{p2.GetIDForToken()}, p2, true}, + {"fail missing", fields{byTokenID}, args{"missing"}, nil, false}, + {"invalid", fields{byTokenID}, args{"string"}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Collection{ + byTokenID: tt.fields.byTokenID, + } + got, got1 := c.LoadByTokenID(tt.args.id) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.Load() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.Load() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + func TestCollection_LoadByToken(t *testing.T) { p1, err := generateJWK() - assert.FatalError(t, err) + require.NoError(t, err) p2, err := generateJWK() - assert.FatalError(t, err) + require.NoError(t, err) p3, err := generateOIDC() - assert.FatalError(t, err) + require.NoError(t, err) p4, err := generateK8sSA(nil) - assert.FatalError(t, err) + require.NoError(t, err) byID := new(sync.Map) byID.Store(p1.GetID(), p1) @@ -75,35 +122,35 @@ func TestCollection_LoadByToken(t *testing.T) { byID2.Store(p3.GetID(), p3) jwk, err := decryptJSONWebKey(p1.EncryptedKey) - assert.FatalError(t, err) + require.NoError(t, err) token, err := generateSimpleToken(p1.Name, testAudiences.Sign[0], jwk) - assert.FatalError(t, err) + require.NoError(t, err) t1, c1, err := parseToken(token) - assert.FatalError(t, err) + require.NoError(t, err) jwk, err = decryptJSONWebKey(p2.EncryptedKey) - assert.FatalError(t, err) + require.NoError(t, err) token, err = generateSimpleToken(p2.Name, testAudiences.Sign[1], jwk) - assert.FatalError(t, err) + require.NoError(t, err) t2, c2, err := parseToken(token) - assert.FatalError(t, err) + require.NoError(t, err) token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0]) - assert.FatalError(t, err) + require.NoError(t, err) t3, c3, err := parseToken(token) - assert.FatalError(t, err) + require.NoError(t, err) token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0]) - assert.FatalError(t, err) + require.NoError(t, err) t4, c4, err := parseToken(token) - assert.FatalError(t, err) + require.NoError(t, err) jwk, err = jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) + require.NoError(t, err) token, err = generateK8sSAToken(jwk, nil) - assert.FatalError(t, err) + require.NoError(t, err) t5, c5, err := parseToken(token) - assert.FatalError(t, err) + require.NoError(t, err) type fields struct { byID *sync.Map @@ -159,11 +206,11 @@ func TestCollection_LoadByCertificate(t *testing.T) { } p1, err := generateJWK() - assert.FatalError(t, err) + require.NoError(t, err) p2, err := generateOIDC() - assert.FatalError(t, err) + require.NoError(t, err) p3, err := generateACME() - assert.FatalError(t, err) + require.NoError(t, err) byName := new(sync.Map) byName.Store(p1.GetName(), p1) @@ -229,11 +276,11 @@ func TestCollection_LoadByCertificate(t *testing.T) { func TestCollection_LoadEncryptedKey(t *testing.T) { c := NewCollection(testAudiences) p1, err := generateJWK() - assert.FatalError(t, err) - assert.FatalError(t, c.Store(p1)) + require.NoError(t, err) + require.NoError(t, c.Store(p1)) p2, err := generateOIDC() - assert.FatalError(t, err) - assert.FatalError(t, c.Store(p2)) + require.NoError(t, err) + require.NoError(t, c.Store(p2)) // Add oidc in byKey. // It should not happen. @@ -269,9 +316,9 @@ func TestCollection_LoadEncryptedKey(t *testing.T) { func TestCollection_Store(t *testing.T) { c := NewCollection(testAudiences) p1, err := generateJWK() - assert.FatalError(t, err) + require.NoError(t, err) p2, err := generateOIDC() - assert.FatalError(t, err) + require.NoError(t, err) type args struct { p Interface @@ -297,7 +344,7 @@ func TestCollection_Store(t *testing.T) { func TestCollection_Find(t *testing.T) { c, err := generateCollection(10, 10) - assert.FatalError(t, err) + require.NoError(t, err) trim := func(s string) string { return strings.TrimLeft(s, "0") @@ -391,7 +438,7 @@ func Test_matchesAudience(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - assert.Equals(t, tc.exp, matchesAudience(tc.a, tc.b)) + assert.Equal(t, tc.exp, matchesAudience(tc.a, tc.b)) }) } } diff --git a/authority/provisioner/k8sSA.go b/authority/provisioner/k8sSA.go index fe60e95a..54d098c6 100644 --- a/authority/provisioner/k8sSA.go +++ b/authority/provisioner/k8sSA.go @@ -73,7 +73,7 @@ func (p *K8sSA) GetIDForToken() string { // GetTokenID returns an unimplemented error and does not use the input ott. func (p *K8sSA) GetTokenID(string) (string, error) { - return "", errors.New("not implemented") + return "", ErrNotImplemented } // GetName returns the name of the provisioner. diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index e6de126e..1c9c9754 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -283,20 +283,14 @@ func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) erro return p.ctl.AuthorizeRenew(ctx, crt) } -// AuthorizeRevoke returns an error if the token is not valid. -func (p *Nebula) AuthorizeRevoke(_ context.Context, token string) error { - return p.validateToken(token, p.ctl.Audiences.Revoke) +// AuthorizeRevoke returns an unauthorized error. +func (p *Nebula) AuthorizeRevoke(context.Context, string) error { + return errs.Unauthorized("nebula provisioner does not support revoke") } -// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. -func (p *Nebula) AuthorizeSSHRevoke(_ context.Context, token string) error { - if !p.ctl.Claimer.IsSSHCAEnabled() { - return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) - } - if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil { - return err - } - return nil +// AuthorizeSSHRevoke returns an unauthorized error. +func (p *Nebula) AuthorizeSSHRevoke(context.Context, string) error { + return errs.Unauthorized("nebula provisioner does not support SSH revoke") } // AuthorizeSSHRenew returns an unauthorized error. @@ -309,11 +303,6 @@ func (p *Nebula) AuthorizeSSHRekey(context.Context, string) (*ssh.Certificate, [ return nil, nil, errs.Unauthorized("nebula provisioner does not support SSH rekey") } -func (p *Nebula) validateToken(token string, audiences []string) error { - _, _, err := p.authorizeToken(token, audiences) - return err -} - func (p *Nebula) authorizeToken(token string, audiences []string) (*nebula.NebulaCertificate, *jwtPayload, error) { jwt, err := jose.ParseSigned(token) if err != nil { diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index 3e2d9780..5be7bfe7 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -709,7 +709,7 @@ func TestNebula_AuthorizeRevoke(t *testing.T) { args args wantErr bool }{ - {"ok", p, args{ctx, ok}, false}, + {"fail unauthorized", p, args{ctx, ok}, true}, {"fail token", p, args{ctx, failToken}, true}, } for _, tt := range tests { @@ -749,7 +749,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) { args args wantErr bool }{ - {"ok", p, args{ctx, ok}, false}, + {"fail unauthorized", p, args{ctx, ok}, true}, {"fail token", p, args{ctx, failToken}, true}, {"fail disabled", pDisabled, args{ctx, ok}, true}, } diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 25a8f23a..33d75fe9 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -77,6 +77,13 @@ func (p Uninitialized) MarshalJSON() ([]byte, error) { // the understanding that we are not following security best practices var ErrAllowTokenReuse = stderrors.New("allow token reuse") +// ErrTokenFlowNotSupported is an error that is returned by provisioners on +// GetTokenID when the use of tokens is not supported. +var ErrTokenFlowNotSupported = stderrors.New("token flow is not supported") + +// ErrNotImplemented is an error returned when one method is not implemented. +var ErrNotImplemented = stderrors.New("not implemented") + // Audiences stores all supported audiences by request type. type Audiences struct { Sign []string diff --git a/authority/provisioner/scep.go b/authority/provisioner/scep.go index a97ff8e5..3ada91d0 100644 --- a/authority/provisioner/scep.go +++ b/authority/provisioner/scep.go @@ -95,9 +95,10 @@ func (s *SCEP) GetEncryptedKey() (string, string, bool) { return "", "", false } -// GetTokenID returns the identifier of the token. +// GetTokenID returns the identifier of the token. This provisioner will always +// return [ErrTokenFlowNotSupported]. func (s *SCEP) GetTokenID(string) (string, error) { - return "", errors.New("scep provisioner does not implement GetTokenID") + return "", ErrTokenFlowNotSupported } // GetOptions returns the configured provisioner options. diff --git a/authority/provisioner/scep_test.go b/authority/provisioner/scep_test.go index f520c931..b4e47a46 100644 --- a/authority/provisioner/scep_test.go +++ b/authority/provisioner/scep_test.go @@ -26,6 +26,44 @@ import ( "go.step.sm/crypto/x509util" ) +func generateSCEP(t *testing.T) *SCEP { + t.Helper() + + ca, err := minica.New() + require.NoError(t, err) + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + cert, err := ca.Sign(&x509.Certificate{ + Subject: pkix.Name{CommonName: "SCEP decrypter"}, + PublicKey: key.Public(), + }) + require.NoError(t, err) + + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", Bytes: cert.Raw, + }) + + block, err := pemutil.Serialize(key, pemutil.WithPassword([]byte("password"))) + require.NoError(t, err) + keyPEM := pem.EncodeToMemory(block) + + p := &SCEP{ + Type: "SCEP", + Name: "scep", + ChallengePassword: "password123", + MinimumPublicKeyLength: 0, + DecrypterCertificate: certPEM, + DecrypterKeyPEM: keyPEM, + DecrypterKeyPassword: "password", + EncryptionAlgorithmIdentifier: 0, + } + require.NoError(t, p.Init(Config{Claims: globalProvisionerClaims})) + return p + +} + func Test_challengeValidationController_Validate(t *testing.T) { dummyCSR := &x509.CertificateRequest{ Raw: []byte{1}, @@ -722,3 +760,17 @@ func TestSCEP_Init(t *testing.T) { }) } } + +func TestSCEP_Getters(t *testing.T) { + p := generateSCEP(t) + assert.Equal(t, "scep/scep", p.GetID()) + assert.Equal(t, "scep", p.GetName()) + assert.Equal(t, TypeSCEP, p.GetType()) + kid, key, ok := p.GetEncryptedKey() + if kid != "" || key != "" || ok == true { + t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)", kid, key, ok, "", "", false) + } + tokenID, err := p.GetTokenID("token") + assert.Empty(t, tokenID) + assert.Equal(t, ErrTokenFlowNotSupported, err) +} diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index 7149dc95..33c1e8a0 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -193,8 +193,11 @@ func (p *SSHPOP) AuthorizeSSHRevoke(_ context.Context, token string) error { if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") } - if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { - return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number") + if serial := strconv.FormatUint(claims.sshCert.Serial, 10); claims.Subject != serial { + return errs.Forbidden( + "token subject %q and sshpop certificate serial number %q do not match", + claims.Subject, serial, + ) } return nil } diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index ae75b349..83001bfb 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -268,8 +268,8 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { return test{ p: p, token: tok, - code: http.StatusBadRequest, - err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"), + code: http.StatusForbidden, + err: errors.New(`token subject "foo" and sshpop certificate serial number "0" do not match`), } }, "ok": func(t *testing.T) test { diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 4dda1c69..eb6f88b1 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -154,7 +154,7 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err return nil, errs.Unauthorized("x5c.authorizeToken; certificate used to sign x5c token cannot be used for digital signature") } - // Using the leaf certificates key to validate the claims accomplishes two + // Using the leaf certificate's key to validate the claims accomplishes two // things: // 1. Asserts that the private key used to sign the token corresponds // to the public certificate in the `x5c` header of the token. diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index f394bc05..1e64bc0b 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -568,12 +568,30 @@ func TestX5C_AuthorizeRevoke(t *testing.T) { "ok": func(t *testing.T) test { certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") require.NoError(t, err) + serialNumber := certs[0].SerialNumber.String() jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") require.NoError(t, err) p, err := generateX5C(nil) require.NoError(t, err) - tok, err := generateToken("foo", p.GetName(), testAudiences.Revoke[0], "", + tok, err := generateToken(serialNumber, p.GetName(), testAudiences.Revoke[0], "", + []string{"test.smallstep.com"}, time.Now(), jwk, + withX5CHdr(certs)) + require.NoError(t, err) + return test{ + p: p, + token: tok, + } + }, + "ok/different-serial-number": func(t *testing.T) test { + certs, err := pemutil.ReadCertificateBundle("./testdata/certs/x5c-leaf.crt") + require.NoError(t, err) + jwk, err := jose.ReadKey("./testdata/secrets/x5c-leaf.key") + require.NoError(t, err) + + p, err := generateX5C(nil) + require.NoError(t, err) + tok, err := generateToken("123456789", p.GetName(), testAudiences.Revoke[0], "", []string{"test.smallstep.com"}, time.Now(), jwk, withX5CHdr(certs)) require.NoError(t, err) diff --git a/authority/tls.go b/authority/tls.go index ed599879..b5d3f549 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -622,6 +622,16 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error return errs.Wrap(http.StatusUnauthorized, err, "authority.Revoke", opts...) } + // Verify that the serial in the token matches the serial from the request. + if revokeOpts.Serial != claims.Subject { + return errs.ApplyOptions( + errs.Forbidden( + "request serial number %q and token subject %q do not match", + revokeOpts.Serial, claims.Subject, + ), opts..., + ) + } + // This method will also validate the audiences for JWK provisioners. p, err := a.LoadProvisionerByToken(token, &claims.Claims) if err != nil { diff --git a/authority/tls_test.go b/authority/tls_test.go index c7bd6f10..2fc66b4e 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -1652,6 +1652,42 @@ func TestAuthority_Revoke(t *testing.T) { }, } }, + "fail/serial-number": func() test { + _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MGetCertificate: func(sn string) (*x509.Certificate, error) { + return nil, errors.New("not found") + }, + })) + + cl := jose.Claims{ + Subject: "token-sn", + Issuer: validIssuer, + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(time.Minute)), + Audience: validAudience, + ID: "44", + } + raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() + require.NoError(t, err) + return test{ + auth: _a, + ctx: tlsRevokeCtx, + opts: &RevokeOptions{ + Serial: "request-sn", + ReasonCode: reasonCode, + Reason: reason, + OTT: raw, + }, + err: errors.New(`request serial number "request-sn" and token subject "token-sn" do not match`), + code: http.StatusForbidden, + checkErrDetails: func(err *errs.Error) { + assert.Equal(t, raw, err.Details["token"]) + }, + } + }, "ok/token": func() test { _a := testAuthority(t, WithDatabase(&db.MockAuthDB{ MUseToken: func(id, tok string) (bool, error) { @@ -1980,7 +2016,7 @@ func TestAuthority_CRL(t *testing.T) { sn := fmt.Sprintf("%v", i) cl := jose.Claims{ - Subject: fmt.Sprintf("sn-%v", i), + Subject: sn, Issuer: validIssuer, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(time.Minute)), diff --git a/errs/error.go b/errs/error.go index f01cb4d8..04de21dd 100644 --- a/errs/error.go +++ b/errs/error.go @@ -14,9 +14,23 @@ import ( // Option modifies the Error type. type Option func(e *Error) error -// withDefaultMessage returns an Option that modifies the error by overwriting the -// message only if it is empty. -func withDefaultMessage(format string, args ...interface{}) Option { +// withDefaultMessage returns an Option that modifies the error by overwriting +// the message only if it is empty. Having withDefaultMessage and +// withFormattedMessage avoid vet errors when the "format" passed to +// "fmt.Sprintf" is not a constant. +func withDefaultMessage(message string) Option { + return func(e *Error) error { + if e.Msg != "" { + return e + } + e.Msg = message + return e + } +} + +// withFormattedMessage returns an Option that modifies the error by overwriting +// the formatted message only if it is empty. +func withFormattedMessage(format string, args ...interface{}) Option { return func(e *Error) error { if e.Msg != "" { return e @@ -27,7 +41,7 @@ func withDefaultMessage(format string, args ...interface{}) Option { } // WithMessage returns an Option that modifies the error by overwriting the -// message only if it is empty. +// message with the formatted string. func WithMessage(format string, args ...interface{}) Option { return func(e *Error) error { e.Msg = fmt.Sprintf(format, args...) @@ -35,6 +49,15 @@ func WithMessage(format string, args ...interface{}) Option { } } +// WithErrorMessage returns an Option that modifies the error by overwriting the +// message with the error string. +func WithErrorMessage() Option { + return func(e *Error) error { + e.Msg = e.Error() + return e + } +} + // WithKeyVal returns an Option that adds the given key-value pair to the // Error details. This is helpful for debugging errors. func WithKeyVal(key string, val interface{}) Option { @@ -183,7 +206,8 @@ func StatusCodeError(code int, e error, opts ...Option) error { } const ( - seeLogs = "Please see the certificate authority logs for more info." + seeLogs = "Please see the certificate authority logs for more info." + defaultMsg = "The requested could not be completed. " + seeLogs // BadRequestDefaultMsg 400 default msg BadRequestDefaultMsg = "The request could not be completed; malformed or missing data. " + seeLogs // UnauthorizedDefaultMsg 401 default msg @@ -198,6 +222,25 @@ const ( NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs ) +func defaultMessage(status int) string { + switch status { + case http.StatusBadRequest: + return BadRequestDefaultMsg + case http.StatusUnauthorized: + return UnauthorizedDefaultMsg + case http.StatusForbidden: + return ForbiddenDefaultMsg + case http.StatusNotFound: + return NotFoundDefaultMsg + case http.StatusInternalServerError: + return InternalServerErrorDefaultMsg + case http.StatusNotImplemented: + return NotImplementedDefaultMsg + default: + return defaultMsg + } +} + const ( // BadRequestPrefix is the prefix added to the bad request messages that are // directly sent to the cli. @@ -292,7 +335,7 @@ func NewErr(status int, err error, opts ...Option) error { // Errorf creates a new error using the given format and status code. func Errorf(code int, format string, args ...interface{}) error { as, opts := splitOptionArgs(args) - opts = append(opts, withDefaultMessage(NotImplementedDefaultMsg)) + opts = append(opts, withDefaultMessage(defaultMessage(code))) e := &Error{Status: code, Err: fmt.Errorf(format, as...)} for _, o := range opts { o(e) @@ -384,7 +427,6 @@ func NotFoundErr(err error, opts ...Option) error { // UnexpectedErr will be used when the certificate authority makes an outgoing // request and receives an unhandled status code. func UnexpectedErr(code int, err error, opts ...Option) error { - opts = append(opts, withDefaultMessage("The certificate authority received an "+ - "unexpected HTTP status code - '%d'. "+seeLogs, code)) + opts = append(opts, withFormattedMessage("The certificate authority received an unexpected HTTP status code - '%d'. "+seeLogs, code)) return NewErr(code, err, opts...) } diff --git a/errs/errors_test.go b/errs/errors_test.go index 5836c592..a5eb4af7 100644 --- a/errs/errors_test.go +++ b/errs/errors_test.go @@ -127,3 +127,101 @@ func TestError_Unwrap_As(t *testing.T) { }) } } + +func TestErrorf(t *testing.T) { + tests := []struct { + name string + code int + format string + args []any + want error + }{ + {"bad request", 400, "test error string", nil, &Error{ + Status: 400, + Err: errors.New("test error string"), + Msg: BadRequestDefaultMsg, + }}, + {"unauthorized", 401, "test error string", nil, &Error{ + Status: 401, + Err: errors.New("test error string"), + Msg: UnauthorizedDefaultMsg, + }}, + {"forbidden", 403, "test error string", nil, &Error{ + Status: 403, + Err: errors.New("test error string"), + Msg: ForbiddenDefaultMsg, + }}, + {"not found", 404, "test error string", nil, &Error{ + Status: 404, + Err: errors.New("test error string"), + Msg: NotFoundDefaultMsg, + }}, + {"internal server error", 500, "test error string", nil, &Error{ + Status: 500, + Err: errors.New("test error string"), + Msg: InternalServerErrorDefaultMsg, + }}, + {"not implemented", 501, "test error string", nil, &Error{ + Status: 501, + Err: errors.New("test error string"), + Msg: NotImplementedDefaultMsg, + }}, + {"other", 502, "test error string", nil, &Error{ + Status: 502, + Err: errors.New("test error string"), + Msg: defaultMsg, + }}, + {"formatted args", 401, "test error string: %s", []any{"some reason"}, &Error{ + Status: 401, + Err: errors.New("test error string: some reason"), + Msg: UnauthorizedDefaultMsg, + }}, + {"WithMessage", 403, "test error string", []any{WithMessage("%s failed", "something")}, &Error{ + Status: 403, + Err: errors.New("test error string"), + Msg: "something failed", + }}, + {"WithErrorMessage", 404, "test error string", []any{WithErrorMessage()}, &Error{ + Status: 404, + Err: errors.New("test error string"), + Msg: "test error string", + }}, + {"WithKeyValue", 500, "test error string", []any{WithKeyVal("foo", 1), WithKeyVal("bar", "zar")}, &Error{ + Status: 500, + Err: errors.New("test error string"), + Msg: InternalServerErrorDefaultMsg, + Details: map[string]interface{}{"foo": 1, "bar": "zar"}, + }}, + {"withDefaultMessage", 501, "test error string", []any{withDefaultMessage("some message")}, &Error{ + Status: 501, + Err: errors.New("test error string"), + Msg: "some message", + }}, + {"withFormattedMessage", 502, "test error string", []any{withFormattedMessage("some message: %s", "the reason")}, &Error{ + Status: 502, + Err: errors.New("test error string"), + Msg: "some message: the reason", + }}, + {"WithMessage and withDefaultMessage", 500, "test error string", []any{WithMessage("the message"), withDefaultMessage("some message")}, &Error{ + Status: 500, + Err: errors.New("test error string"), + Msg: "the message", + }}, + {"WithErrorMessage and withFormattedMessage", 500, "test error string", []any{WithErrorMessage(), withFormattedMessage("some message: %s", "the reason")}, &Error{ + Status: 500, + Err: errors.New("test error string"), + Msg: "test error string", + }}, + {"formatted args and withMessage", 500, "test error string: %s, code %d", []any{"reason", 1234, WithMessage("the message")}, &Error{ + Status: 500, + Err: errors.New("test error string: reason, code 1234"), + Msg: "the message", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := Errorf(tt.code, tt.format, tt.args...) + assert.Equal(t, tt.want, gotErr) + }) + } +} diff --git a/test/integration/requestid_test.go b/test/integration/requestid_test.go index 8801dc45..7ac4a9a8 100644 --- a/test/integration/requestid_test.go +++ b/test/integration/requestid_test.go @@ -53,6 +53,8 @@ func Test_reflectRequestID(t *testing.T) { ctx := context.Background() dir := t.TempDir() + t.Setenv("STEPPATH", dir) + m, err := minica.New(minica.WithName("Step E2E")) require.NoError(t, err) diff --git a/test/integration/scep/common_test.go b/test/integration/scep/common_test.go index 7cf7f97f..037f33b5 100644 --- a/test/integration/scep/common_test.go +++ b/test/integration/scep/common_test.go @@ -103,6 +103,8 @@ func newTestCA(t *testing.T, name string) *testCA { require.NoError(t, err) dir := t.TempDir() + t.Setenv("STEPPATH", dir) + m, err := minica.New(minica.WithName(name), minica.WithGetSignerFunc(func() (crypto.Signer, error) { return signer, nil })) diff --git a/test/integration/scep/decrypter_cas_test.go b/test/integration/scep/decrypter_cas_test.go index f19a2c91..43e217d2 100644 --- a/test/integration/scep/decrypter_cas_test.go +++ b/test/integration/scep/decrypter_cas_test.go @@ -34,6 +34,8 @@ func TestIssuesCertificateUsingSCEPWithDecrypterAndUpstreamCAS(t *testing.T) { require.NoError(t, err) dir := t.TempDir() + t.Setenv("STEPPATH", dir) + m, err := minica.New(minica.WithName("Step E2E | SCEP Decrypter w/ Upstream CAS"), minica.WithGetSignerFunc(func() (crypto.Signer, error) { return signer, nil })) diff --git a/test/integration/scep/decrypter_test.go b/test/integration/scep/decrypter_test.go index f59ae8b1..4ad9fdd8 100644 --- a/test/integration/scep/decrypter_test.go +++ b/test/integration/scep/decrypter_test.go @@ -32,6 +32,8 @@ func TestIssuesCertificateUsingSCEPWithDecrypter(t *testing.T) { require.NoError(t, err) dir := t.TempDir() + t.Setenv("STEPPATH", dir) + m, err := minica.New(minica.WithName("Step E2E | SCEP Decrypter"), minica.WithGetSignerFunc(func() (crypto.Signer, error) { return signer, nil })) diff --git a/test/integration/scep/regular_cas_test.go b/test/integration/scep/regular_cas_test.go index ae5ebbfd..ebca51c2 100644 --- a/test/integration/scep/regular_cas_test.go +++ b/test/integration/scep/regular_cas_test.go @@ -29,6 +29,8 @@ func TestFailsIssuingCertificateUsingRegularSCEPWithUpstreamCAS(t *testing.T) { require.NoError(t, err) dir := t.TempDir() + t.Setenv("STEPPATH", dir) + m, err := minica.New(minica.WithName("Step E2E | SCEP Regular w/ Upstream CAS"), minica.WithGetSignerFunc(func() (crypto.Signer, error) { return signer, nil }))