mirror of
https://github.com/outbackdingo/certificates.git
synced 2026-01-27 10:18:34 +00:00
Log errors using slog.Logger
This commit allows logging errors in a slog.Logger injected in the context. This type of logger is not currently used directly in step-ca, but this will change in the future.
This commit is contained in:
32
api/api.go
32
api/api.go
@@ -353,15 +353,15 @@ func Route(r Router) {
|
||||
// Version is an HTTP handler that returns the version of the server.
|
||||
func Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := mustAuthority(r.Context()).Version()
|
||||
render.JSON(w, VersionResponse{
|
||||
render.JSON(w, r, VersionResponse{
|
||||
Version: v.Version,
|
||||
RequireClientAuthentication: v.RequireClientAuthentication,
|
||||
})
|
||||
}
|
||||
|
||||
// Health is an HTTP handler that returns the status of the server.
|
||||
func Health(w http.ResponseWriter, _ *http.Request) {
|
||||
render.JSON(w, HealthResponse{Status: "ok"})
|
||||
func Health(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSON(w, r, HealthResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
||||
@@ -372,11 +372,11 @@ func Root(w http.ResponseWriter, r *http.Request) {
|
||||
// Load root certificate with the
|
||||
cert, err := mustAuthority(r.Context()).Root(sum)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
render.Error(w, r, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &RootResponse{RootPEM: Certificate{cert}})
|
||||
render.JSON(w, r, &RootResponse{RootPEM: Certificate{cert}})
|
||||
}
|
||||
|
||||
func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
||||
@@ -391,17 +391,17 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
||||
func Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionersResponse{
|
||||
render.JSON(w, r, &ProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
})
|
||||
@@ -412,18 +412,18 @@ func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
kid := chi.URLParam(r, "kid")
|
||||
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
|
||||
if err != nil {
|
||||
render.Error(w, errs.NotFoundErr(err))
|
||||
render.Error(w, r, errs.NotFoundErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionerKeyResponse{key})
|
||||
render.JSON(w, r, &ProvisionerKeyResponse{key})
|
||||
}
|
||||
|
||||
// Roots returns all the root certificates for the CA.
|
||||
func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error getting roots"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
certs[i] = Certificate{roots[i]}
|
||||
}
|
||||
|
||||
render.JSONStatus(w, &RootsResponse{
|
||||
render.JSONStatus(w, r, &RootsResponse{
|
||||
Certificates: certs,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
@@ -441,7 +441,7 @@ func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -454,7 +454,7 @@ func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
|
||||
if _, err := w.Write(block); err != nil {
|
||||
log.Error(w, err)
|
||||
log.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -464,7 +464,7 @@ func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
func Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := mustAuthority(r.Context()).GetFederation()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -473,7 +473,7 @@ func Federation(w http.ResponseWriter, r *http.Request) {
|
||||
certs[i] = Certificate{federated[i]}
|
||||
}
|
||||
|
||||
render.JSONStatus(w, &FederationResponse{
|
||||
render.JSONStatus(w, r, &FederationResponse{
|
||||
Certificates: certs,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
func CRL(w http.ResponseWriter, r *http.Request) {
|
||||
crlInfo, err := mustAuthority(r.Context()).GetCertificateRevocationList()
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if crlInfo == nil {
|
||||
render.Error(w, errs.New(http.StatusNotFound, "no CRL available"))
|
||||
render.Error(w, r, errs.New(http.StatusNotFound, "no CRL available"))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,31 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ErrorKey is the logging attribute key for error values.
|
||||
const ErrorKey = "error"
|
||||
|
||||
type loggerKey struct{}
|
||||
|
||||
// NewContext creates a new context with the given slog.Logger.
|
||||
func NewContext(ctx context.Context, logger *slog.Logger) context.Context {
|
||||
return context.WithValue(ctx, loggerKey{}, logger)
|
||||
}
|
||||
|
||||
// FromContext returns the logger from the given context.
|
||||
func FromContext(ctx context.Context) (l *slog.Logger, ok bool) {
|
||||
l, ok = ctx.Value(loggerKey{}).(*slog.Logger)
|
||||
return
|
||||
}
|
||||
|
||||
// StackTracedError is the set of errors implementing the StackTrace function.
|
||||
//
|
||||
// Errors implementing this interface have their stack traces logged when passed
|
||||
@@ -27,7 +45,12 @@ type fieldCarrier interface {
|
||||
// Error adds to the response writer the given error if it implements
|
||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||
// using the log package.
|
||||
func Error(rw http.ResponseWriter, err error) {
|
||||
func Error(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
ctx := r.Context()
|
||||
if logger, ok := FromContext(ctx); ok && err != nil {
|
||||
logger.ErrorContext(ctx, "request failed", slog.Any(ErrorKey, err))
|
||||
}
|
||||
|
||||
fc, ok := rw.(fieldCarrier)
|
||||
if !ok {
|
||||
return
|
||||
@@ -51,7 +74,7 @@ func Error(rw http.ResponseWriter, err error) {
|
||||
|
||||
// EnabledResponse log the response object if it implements the EnableLogger
|
||||
// interface.
|
||||
func EnabledResponse(rw http.ResponseWriter, v any) {
|
||||
func EnabledResponse(rw http.ResponseWriter, r *http.Request, v any) {
|
||||
type enableLogger interface {
|
||||
ToLog() (any, error)
|
||||
}
|
||||
@@ -59,7 +82,7 @@ func EnabledResponse(rw http.ResponseWriter, v any) {
|
||||
if el, ok := v.(enableLogger); ok {
|
||||
out, err := el.ToLog()
|
||||
if err != nil {
|
||||
Error(rw, err)
|
||||
Error(rw, r, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -27,21 +30,30 @@ func (stackTracedError) StackTrace() pkgerrors.StackTrace {
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{}))
|
||||
req := httptest.NewRequest("GET", "/test", http.NoBody)
|
||||
reqWithLogger := req.WithContext(NewContext(req.Context(), logger))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
error
|
||||
rw http.ResponseWriter
|
||||
r *http.Request
|
||||
isFieldCarrier bool
|
||||
isSlogLogger bool
|
||||
stepDebug bool
|
||||
expectStackTrace bool
|
||||
}{
|
||||
{"noLogger", nil, nil, false, false, false},
|
||||
{"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
|
||||
{"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
|
||||
{"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
|
||||
{"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
|
||||
{"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
|
||||
{"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
|
||||
{"noLogger", nil, nil, req, false, false, false, false},
|
||||
{"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false},
|
||||
{"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false},
|
||||
{"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false},
|
||||
{"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false},
|
||||
{"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true},
|
||||
{"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true},
|
||||
{"slogWithNoError", nil, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false},
|
||||
{"slogWithError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -52,27 +64,41 @@ func TestError(t *testing.T) {
|
||||
t.Setenv("STEPDEBUG", "0")
|
||||
}
|
||||
|
||||
Error(tt.rw, tt.error)
|
||||
Error(tt.rw, tt.r, tt.error)
|
||||
|
||||
// return early if test case doesn't use logger
|
||||
if !tt.isFieldCarrier {
|
||||
if !tt.isFieldCarrier && !tt.isSlogLogger {
|
||||
return
|
||||
}
|
||||
|
||||
fields := tt.rw.(logging.ResponseLogger).Fields()
|
||||
if tt.isFieldCarrier {
|
||||
fields := tt.rw.(logging.ResponseLogger).Fields()
|
||||
|
||||
// expect the error field to be (not) set and to be the same error that was fed to Error
|
||||
if tt.error == nil {
|
||||
assert.Nil(t, fields["error"])
|
||||
} else {
|
||||
assert.Same(t, tt.error, fields["error"])
|
||||
// expect the error field to be (not) set and to be the same error that was fed to Error
|
||||
if tt.error == nil {
|
||||
assert.Nil(t, fields["error"])
|
||||
} else {
|
||||
assert.Same(t, tt.error, fields["error"])
|
||||
}
|
||||
|
||||
// check if stack-trace is set when expected
|
||||
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] not set`)
|
||||
} else if !tt.expectStackTrace && hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] was set`)
|
||||
}
|
||||
}
|
||||
|
||||
// check if stack-trace is set when expected
|
||||
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] not set`)
|
||||
} else if !tt.expectStackTrace && hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] was set`)
|
||||
if tt.isSlogLogger {
|
||||
b := buf.Bytes()
|
||||
if tt.error == nil {
|
||||
assert.Empty(t, b)
|
||||
} else if assert.NotEmpty(t, b) {
|
||||
var m map[string]any
|
||||
assert.NoError(t, json.Unmarshal(b, &m))
|
||||
assert.Equal(t, tt.error.Error(), m["error"])
|
||||
}
|
||||
buf.Reset()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func (e badProtoJSONError) Error() string {
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for badProtoJSONError
|
||||
func (e badProtoJSONError) Render(w http.ResponseWriter) {
|
||||
func (e badProtoJSONError) Render(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
@@ -62,5 +62,5 @@ func (e badProtoJSONError) Render(w http.ResponseWriter) {
|
||||
// trim the proto prefix for the message
|
||||
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
|
||||
}
|
||||
render.JSONStatus(w, v, http.StatusBadRequest)
|
||||
render.JSONStatus(w, r, v, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
@@ -142,7 +142,8 @@ func Test_badProtoJSONError_Render(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tt.e.Render(w)
|
||||
r := httptest.NewRequest("POST", "/test", http.NoBody)
|
||||
tt.e.Render(w, r)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
|
||||
10
api/rekey.go
10
api/rekey.go
@@ -29,25 +29,25 @@ func (s *RekeyRequest) Validate() error {
|
||||
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
||||
func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing client certificate"))
|
||||
render.Error(w, r, errs.BadRequest("missing client certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
var body RekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||
render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
@@ -57,7 +57,7 @@ func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
render.JSONStatus(w, r, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
|
||||
func JSON(w http.ResponseWriter, v interface{}) {
|
||||
JSONStatus(w, v, http.StatusOK)
|
||||
func JSON(w http.ResponseWriter, r *http.Request, v interface{}) {
|
||||
JSONStatus(w, r, v, http.StatusOK)
|
||||
}
|
||||
|
||||
// JSONStatus marshals v into w. It additionally sets the status code of
|
||||
@@ -22,7 +22,7 @@ func JSON(w http.ResponseWriter, v interface{}) {
|
||||
//
|
||||
// JSONStatus sets the Content-Type of w to application/json unless one is
|
||||
// specified.
|
||||
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||
func JSONStatus(w http.ResponseWriter, r *http.Request, v interface{}, status int) {
|
||||
setContentTypeUnlessPresent(w, "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
@@ -43,7 +43,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||
}
|
||||
}
|
||||
|
||||
log.EnabledResponse(w, v)
|
||||
log.EnabledResponse(w, r, v)
|
||||
}
|
||||
|
||||
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
|
||||
@@ -80,22 +80,22 @@ func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
|
||||
type RenderableError interface {
|
||||
error
|
||||
|
||||
Render(http.ResponseWriter)
|
||||
Render(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
// Error marshals the JSON representation of err to w. In case err implements
|
||||
// RenderableError its own Render method will be called instead.
|
||||
func Error(w http.ResponseWriter, err error) {
|
||||
log.Error(w, err)
|
||||
func Error(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Error(rw, r, err)
|
||||
|
||||
var r RenderableError
|
||||
if errors.As(err, &r) {
|
||||
r.Render(w)
|
||||
var re RenderableError
|
||||
if errors.As(err, &re) {
|
||||
re.Render(rw, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
JSONStatus(w, err, statusCodeFromError(err))
|
||||
JSONStatus(rw, r, err, statusCodeFromError(err))
|
||||
}
|
||||
|
||||
// StatusCodedError is the set of errors that implement the basic StatusCode
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
func TestJSON(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := logging.NewResponseLogger(rec)
|
||||
|
||||
JSON(rw, map[string]interface{}{"foo": "bar"})
|
||||
r := httptest.NewRequest("POST", "/test", http.NoBody)
|
||||
JSON(rw, r, map[string]interface{}{"foo": "bar"})
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
@@ -64,7 +64,8 @@ func jsonPanicTest[T json.UnsupportedTypeError | json.UnsupportedValueError | js
|
||||
assert.ErrorAs(t, err, &e)
|
||||
}()
|
||||
|
||||
JSON(httptest.NewRecorder(), v)
|
||||
r := httptest.NewRequest("POST", "/test", http.NoBody)
|
||||
JSON(httptest.NewRecorder(), r, v)
|
||||
}
|
||||
|
||||
type renderableError struct {
|
||||
@@ -76,10 +77,9 @@ func (err renderableError) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
func (err renderableError) Render(w http.ResponseWriter) {
|
||||
func (err renderableError) Render(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "something/custom")
|
||||
|
||||
JSONStatus(w, err, err.Code)
|
||||
JSONStatus(w, r, err, err.Code)
|
||||
}
|
||||
|
||||
type statusedError struct {
|
||||
@@ -116,8 +116,8 @@ func TestError(t *testing.T) {
|
||||
|
||||
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
Error(rec, kase.err)
|
||||
r := httptest.NewRequest("POST", "/test", http.NoBody)
|
||||
Error(rec, r, kase.err)
|
||||
|
||||
assert.Equal(t, kase.code, rec.Result().StatusCode)
|
||||
assert.Equal(t, kase.body, rec.Body.String())
|
||||
|
||||
@@ -23,19 +23,20 @@ func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the leaf certificate from the peer or the token.
|
||||
cert, token, err := getPeerCertificate(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// The token can be used by RAs to renew a certificate.
|
||||
if token != "" {
|
||||
ctx = authority.NewTokenContext(ctx, token)
|
||||
logOtt(w, token)
|
||||
}
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
certChain, err := a.RenewContext(ctx, cert, nil)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
@@ -45,7 +46,7 @@ func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
render.JSONStatus(w, r, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
|
||||
@@ -57,12 +57,12 @@ func (r *RevokeRequest) Validate() (err error) {
|
||||
func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
if body.OTT != "" {
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
@@ -90,12 +90,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
// the client certificate Serial Number must match the serial number
|
||||
// being revoked.
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing ott or client certificate"))
|
||||
render.Error(w, r, errs.BadRequest("missing ott or client certificate"))
|
||||
return
|
||||
}
|
||||
opts.Crt = r.TLS.PeerCertificates[0]
|
||||
if opts.Crt.SerialNumber.String() != opts.Serial {
|
||||
render.Error(w, errs.BadRequest("serial number in client certificate different than body"))
|
||||
render.Error(w, r, errs.BadRequest("serial number in client certificate different than body"))
|
||||
return
|
||||
}
|
||||
// TODO: should probably be checking if the certificate was revoked here.
|
||||
@@ -106,12 +106,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
logRevoke(w, opts)
|
||||
render.JSON(w, &RevokeResponse{Status: "ok"})
|
||||
render.JSON(w, r, &RevokeResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||
|
||||
10
api/sign.go
10
api/sign.go
@@ -52,13 +52,13 @@ type SignResponse struct {
|
||||
func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -74,13 +74,13 @@ func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error signing certificate"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
@@ -90,7 +90,7 @@ func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
render.JSONStatus(w, r, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
|
||||
62
api/ssh.go
62
api/ssh.go
@@ -253,19 +253,19 @@ type SSHBastionResponse struct {
|
||||
func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -273,7 +273,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
if body.AddUserPublicKey != nil {
|
||||
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -293,13 +293,13 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -307,7 +307,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
}
|
||||
addUserCertificate = &SSHCertificate{addUserCert}
|
||||
@@ -320,7 +320,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -332,14 +332,14 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||
return
|
||||
}
|
||||
identityCertificate = certChainToPEM(certChain)
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, cert)
|
||||
render.JSONStatus(w, &SSHSignResponse{
|
||||
render.JSONStatus(w, r, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{cert},
|
||||
AddUserCertificate: addUserCertificate,
|
||||
IdentityCertificate: identityCertificate,
|
||||
@@ -352,12 +352,12 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
render.Error(w, errs.NotFound("no keys found"))
|
||||
render.Error(w, r, errs.NotFound("no keys found"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -369,7 +369,7 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
||||
}
|
||||
|
||||
render.JSON(w, resp)
|
||||
render.JSON(w, r, resp)
|
||||
}
|
||||
|
||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||
@@ -378,12 +378,12 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
render.Error(w, errs.NotFound("no keys found"))
|
||||
render.Error(w, r, errs.NotFound("no keys found"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -395,7 +395,7 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
||||
}
|
||||
|
||||
render.JSON(w, resp)
|
||||
render.JSON(w, r, resp)
|
||||
}
|
||||
|
||||
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
||||
@@ -403,18 +403,18 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -425,32 +425,32 @@ func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
case provisioner.SSHHostCert:
|
||||
cfg.HostTemplates = ts
|
||||
default:
|
||||
render.Error(w, errs.InternalServer("it should hot get here"))
|
||||
render.Error(w, r, errs.InternalServer("it should hot get here"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, cfg)
|
||||
render.JSON(w, r, cfg)
|
||||
}
|
||||
|
||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||
func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHCheckPrincipalRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
render.JSON(w, &SSHCheckPrincipalResponse{
|
||||
render.JSON(w, r, &SSHCheckPrincipalResponse{
|
||||
Exists: exists,
|
||||
})
|
||||
}
|
||||
@@ -465,10 +465,10 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
render.JSON(w, &SSHGetHostsResponse{
|
||||
render.JSON(w, r, &SSHGetHostsResponse{
|
||||
Hosts: hosts,
|
||||
})
|
||||
}
|
||||
@@ -477,22 +477,22 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &SSHBastionResponse{
|
||||
render.JSON(w, r, &SSHBastionResponse{
|
||||
Hostname: body.Hostname,
|
||||
Bastion: bastion,
|
||||
})
|
||||
|
||||
@@ -42,19 +42,19 @@ type SSHRekeyResponse struct {
|
||||
func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -64,18 +64,18 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -85,12 +85,12 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, newCert)
|
||||
render.JSONStatus(w, &SSHRekeyResponse{
|
||||
render.JSONStatus(w, r, &SSHRekeyResponse{
|
||||
Certificate: SSHCertificate{newCert},
|
||||
IdentityCertificate: identity,
|
||||
}, http.StatusCreated)
|
||||
|
||||
@@ -40,13 +40,13 @@ type SSHRenewResponse struct {
|
||||
func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -56,18 +56,18 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
a := mustAuthority(ctx)
|
||||
_, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
newCert, err := a.RenewSSH(ctx, oldCert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,12 +77,12 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, newCert)
|
||||
render.JSONStatus(w, &SSHSignResponse{
|
||||
render.JSONStatus(w, r, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{newCert},
|
||||
IdentityCertificate: identity,
|
||||
}, http.StatusCreated)
|
||||
|
||||
@@ -51,12 +51,12 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
||||
func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,18 +75,18 @@ func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
logOtt(w, body.OTT)
|
||||
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
logSSHRevoke(w, opts)
|
||||
render.JSON(w, &SSHRevokeResponse{Status: "ok"})
|
||||
render.JSON(w, r, &SSHRevokeResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||
|
||||
Reference in New Issue
Block a user