From c76dad8a22b81482994b3599280607a5cb990c84 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 8 Feb 2024 15:03:46 +0100 Subject: [PATCH] Improve tests for CRL HTTP handler --- api/api_test.go | 39 --------------------- api/crl.go | 13 ++++++- api/crl_test.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 40 deletions(-) create mode 100644 api/crl_test.go diff --git a/api/api_test.go b/api/api_test.go index a62b34e8..28944a1e 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -789,45 +789,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) ( return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err } -func Test_CRLGeneration(t *testing.T) { - tests := []struct { - name string - err error - statusCode int - expected []byte - }{ - {"empty", nil, http.StatusOK, nil}, - } - - chiCtx := chi.NewRouteContext() - req := httptest.NewRequest("GET", "http://example.com/crl", http.NoBody) - req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockMustAuthority(t, &mockAuthority{ret1: tt.expected, err: tt.err}) - w := httptest.NewRecorder() - CRL(w, req) - res := w.Result() - - if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.CRL StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) - } - - body, err := io.ReadAll(res.Body) - res.Body.Close() - if err != nil { - t.Errorf("caHandler.Root unexpected error = %v", err) - } - if tt.statusCode == 200 { - if !bytes.Equal(bytes.TrimSpace(body), tt.expected) { - t.Errorf("caHandler.Root CRL = %s, wants %s", body, tt.expected) - } - } - }) - } -} - func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority diff --git a/api/crl.go b/api/crl.go index 7f12c6f8..a94056ad 100644 --- a/api/crl.go +++ b/api/crl.go @@ -6,6 +6,7 @@ import ( "time" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/errs" ) // CRL is an HTTP handler that returns the current CRL in DER or PEM format @@ -16,7 +17,17 @@ func CRL(w http.ResponseWriter, r *http.Request) { return } - w.Header().Add("Expires", crlInfo.ExpiresAt.Format(time.RFC1123)) + if crlInfo == nil { + render.Error(w, errs.New(http.StatusInternalServerError, "no CRL available")) + return + } + + expires := crlInfo.ExpiresAt + if expires.IsZero() { + expires = time.Now() + } + + w.Header().Add("Expires", expires.Format(time.RFC1123)) _, formatAsPEM := r.URL.Query()["pem"] if formatAsPEM { diff --git a/api/crl_test.go b/api/crl_test.go new file mode 100644 index 00000000..c1c7a4b0 --- /dev/null +++ b/api/crl_test.go @@ -0,0 +1,93 @@ +package api + +import ( + "bytes" + "context" + "encoding/pem" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_CRL(t *testing.T) { + data := []byte{1, 2, 3, 4} + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "X509 CRL", + Bytes: data, + }) + pemData = bytes.TrimSpace(pemData) + emptyPEMData := pem.EncodeToMemory(&pem.Block{ + Type: "X509 CRL", + Bytes: nil, + }) + emptyPEMData = bytes.TrimSpace(emptyPEMData) + tests := []struct { + name string + url string + err error + statusCode int + crlInfo *authority.CertificateRevocationListInfo + expectedBody []byte + expectedHeaders http.Header + expectedErrorJSON string + }{ + {"ok", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, data, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""}, + {"ok/pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, pemData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""}, + {"ok/empty", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, nil, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""}, + {"ok/empty-pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, emptyPEMData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""}, + {"fail/internal", "http://example.com/crl", errs.Wrap(http.StatusInternalServerError, errors.New("failure"), "authority.GetCertificateRevocationList"), http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"The certificate authority encountered an Internal Server Error. Please see the certificate authority logs for more info."}`}, + {"fail/nil", "http://example.com/crl", nil, http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"no CRL available"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockMustAuthority(t, &mockAuthority{ret1: tt.crlInfo, err: tt.err}) + + chiCtx := chi.NewRouteContext() + req := httptest.NewRequest("GET", tt.url, http.NoBody) + req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) + w := httptest.NewRecorder() + CRL(w, req) + res := w.Result() + + assert.Equal(t, tt.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + require.NoError(t, err) + + if tt.statusCode >= 300 { + assert.JSONEq(t, tt.expectedErrorJSON, string(bytes.TrimSpace(body))) + return + } + + // check expected header values + for _, h := range []string{"content-type", "content-disposition"} { + v := tt.expectedHeaders.Get(h) + require.NotEmpty(t, v) + + actual := res.Header.Get(h) + assert.Equal(t, v, actual) + } + + // check expires header value + assert.NotEmpty(t, res.Header.Get("expires")) + t1, err := time.Parse(time.RFC1123, res.Header.Get("expires")) + if assert.NoError(t, err) { + assert.False(t, t1.IsZero()) + } + + // check body contents + assert.Equal(t, tt.expectedBody, bytes.TrimSpace(body)) + }) + } +}