mirror of
https://github.com/outbackdingo/certificates.git
synced 2026-01-27 10:18:34 +00:00
Log errors using slog.Logger
This commit allows logging errors in a slog.Logger injected in the context. This type of logger is not currently used directly in step-ca, but this will change in the future.
This commit is contained in:
@@ -82,23 +82,23 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
var nar NewAccountRequest
|
||||
if err := json.Unmarshal(payload.value, &nar); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
"failed to unmarshal new-account request payload"))
|
||||
return
|
||||
}
|
||||
if err := nar.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -108,26 +108,26 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
var acmeErr *acme.Error
|
||||
if !errors.As(err, &acmeErr) || acmeErr.Status != http.StatusBadRequest {
|
||||
// Something went wrong ...
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Account does not exist //
|
||||
if nar.OnlyReturnExisting {
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType,
|
||||
"account does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
eak, err := validateExternalAccountBinding(ctx, &nar)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,17 +140,17 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
ProvisionerName: prov.Name,
|
||||
}
|
||||
if err := db.CreateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error creating account"))
|
||||
return
|
||||
}
|
||||
|
||||
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
|
||||
if err := eak.BindTo(acc); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error updating external account binding key"))
|
||||
return
|
||||
}
|
||||
acc.ExternalAccountBinding = nar.ExternalAccountBinding
|
||||
@@ -163,7 +163,7 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID))
|
||||
render.JSONStatus(w, acc, httpStatus)
|
||||
render.JSONStatus(w, r, acc, httpStatus)
|
||||
}
|
||||
|
||||
// GetOrUpdateAccount is the api for updating an ACME account.
|
||||
@@ -174,12 +174,12 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -188,12 +188,12 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
if !payload.isPostAsGet {
|
||||
var uar UpdateAccountRequest
|
||||
if err := json.Unmarshal(payload.value, &uar); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
"failed to unmarshal new-account request payload"))
|
||||
return
|
||||
}
|
||||
if err := uar.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
if len(uar.Status) > 0 || len(uar.Contact) > 0 {
|
||||
@@ -204,7 +204,7 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := db.UpdateAccount(ctx, acc); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating account"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error updating account"))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -213,7 +213,7 @@ func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
linker.LinkAccount(ctx, acc)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
|
||||
render.JSON(w, acc)
|
||||
render.JSON(w, r, acc)
|
||||
}
|
||||
|
||||
func logOrdersByAccount(w http.ResponseWriter, oids []string) {
|
||||
@@ -233,23 +233,23 @@ func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
accID := chi.URLParam(r, "accID")
|
||||
if acc.ID != accID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
|
||||
return
|
||||
}
|
||||
|
||||
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrdersByAccountID(ctx, orders)
|
||||
|
||||
render.JSON(w, orders)
|
||||
render.JSON(w, r, orders)
|
||||
logOrdersByAccount(w, orders)
|
||||
}
|
||||
|
||||
@@ -223,13 +223,13 @@ func GetDirectory(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
|
||||
render.JSON(w, &Directory{
|
||||
render.JSON(w, r, &Directory{
|
||||
NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
|
||||
NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
|
||||
NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
|
||||
@@ -273,8 +273,8 @@ func shouldAddMetaObject(p *provisioner.ACME) bool {
|
||||
|
||||
// NotImplemented returns a 501 and is generally a placeholder for functionality which
|
||||
// MAY be added at some point in the future but is not in any way a guarantee of such.
|
||||
func NotImplemented(w http.ResponseWriter, _ *http.Request) {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||
func NotImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
|
||||
}
|
||||
|
||||
// GetAuthorization ACME api for retrieving an Authz.
|
||||
@@ -285,28 +285,28 @@ func GetAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error retrieving authorization"))
|
||||
return
|
||||
}
|
||||
if acc.ID != az.AccountID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
|
||||
return
|
||||
}
|
||||
if err = az.UpdateStatus(ctx, db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error updating authorization status"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkAuthorization(ctx, az)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
|
||||
render.JSON(w, az)
|
||||
render.JSON(w, r, az)
|
||||
}
|
||||
|
||||
// GetChallenge ACME api for retrieving a Challenge.
|
||||
@@ -317,13 +317,13 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -336,22 +336,22 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
azID := chi.URLParam(r, "authzID")
|
||||
ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error retrieving challenge"))
|
||||
return
|
||||
}
|
||||
ch.AuthorizationID = azID
|
||||
if acc.ID != ch.AccountID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own challenge '%s'", acc.ID, ch.ID))
|
||||
return
|
||||
}
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
if err = ch.Validate(ctx, db, jwk, payload.value); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error validating challenge"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -359,7 +359,7 @@ func GetChallenge(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
|
||||
render.JSON(w, ch)
|
||||
render.JSON(w, r, ch)
|
||||
}
|
||||
|
||||
// GetCertificate ACME api for retrieving a Certificate.
|
||||
@@ -369,18 +369,18 @@ func GetCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
certID := chi.URLParam(r, "certID")
|
||||
cert, err := db.GetCertificate(ctx, certID)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error retrieving certificate"))
|
||||
return
|
||||
}
|
||||
if cert.AccountID != acc.ID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own certificate '%s'", acc.ID, certID))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func addNonce(next nextHTTP) nextHTTP {
|
||||
db := acme.MustDatabaseFromContext(r.Context())
|
||||
nonce, err := db.CreateNonce(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Replay-Nonce", string(nonce))
|
||||
@@ -64,7 +64,7 @@ func verifyContentType(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
p, err := provisionerFromContext(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ func verifyContentType(next nextHTTP) nextHTTP {
|
||||
return
|
||||
}
|
||||
}
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
|
||||
"expected content-type to be in %s, but got %s", expected, ct))
|
||||
}
|
||||
}
|
||||
@@ -98,12 +98,12 @@ func parseJWS(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "failed to read request body"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "failed to read request body"))
|
||||
return
|
||||
}
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), jwsContextKey, jws)
|
||||
@@ -133,15 +133,15 @@ func validateJWS(next nextHTTP) nextHTTP {
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
|
||||
return
|
||||
}
|
||||
if len(jws.Signatures) > 1 {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ func validateJWS(next nextHTTP) nextHTTP {
|
||||
uh.Algorithm != "" ||
|
||||
uh.Nonce != "" ||
|
||||
len(uh.ExtraHeaders) > 0 {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
|
||||
return
|
||||
}
|
||||
hdr := sig.Protected
|
||||
@@ -162,13 +162,13 @@ func validateJWS(next nextHTTP) nextHTTP {
|
||||
switch k := hdr.JSONWebKey.Key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if k.Size() < keyutil.MinRSAKeyBytes {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
|
||||
"rsa keys must be at least %d bits (%d bytes) in size",
|
||||
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
|
||||
return
|
||||
}
|
||||
default:
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
|
||||
"jws key type and algorithm do not match"))
|
||||
return
|
||||
}
|
||||
@@ -176,35 +176,35 @@ func validateJWS(next nextHTTP) nextHTTP {
|
||||
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
|
||||
// we good
|
||||
default:
|
||||
render.Error(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
|
||||
return
|
||||
}
|
||||
|
||||
// Check the validity/freshness of the Nonce.
|
||||
if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the JWS url matches the requested url.
|
||||
jwsURL, ok := hdr.ExtraHeaders["url"].(string)
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
|
||||
return
|
||||
}
|
||||
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
|
||||
if jwsURL != reqURL.String() {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
|
||||
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
|
||||
return
|
||||
}
|
||||
|
||||
if hdr.JSONWebKey != nil && hdr.KeyID != "" {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||
return
|
||||
}
|
||||
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
@@ -221,23 +221,23 @@ func extractJWK(next nextHTTP) nextHTTP {
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
jwk := jws.Signatures[0].Protected.JSONWebKey
|
||||
if jwk == nil {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
|
||||
return
|
||||
}
|
||||
if !jwk.Valid() {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
||||
return
|
||||
}
|
||||
|
||||
// Overwrite KeyID with the JWK thumbprint.
|
||||
jwk.KeyID, err = acme.KeyToID(jwk)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -251,11 +251,11 @@ func extractJWK(next nextHTTP) nextHTTP {
|
||||
// For NewAccount and Revoke requests ...
|
||||
break
|
||||
case err != nil:
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||
@@ -274,11 +274,11 @@ func checkPrerequisites(next nextHTTP) nextHTTP {
|
||||
if ok {
|
||||
ok, err := checkFunc(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -296,13 +296,13 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
kid := jws.Signatures[0].Protected.KeyID
|
||||
if kid == "" {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -310,14 +310,14 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
||||
acc, err := db.GetAccount(ctx, accID)
|
||||
switch {
|
||||
case acme.IsErrNotFound(err):
|
||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||
return
|
||||
case err != nil:
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
default:
|
||||
if !acc.IsValid() {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -325,7 +325,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
||||
if kid != storedLocation {
|
||||
// ACME accounts should have a stored location equivalent to the
|
||||
// kid in the ACME request.
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"kid does not match stored account location; expected %s, but got %s",
|
||||
storedLocation, kid))
|
||||
return
|
||||
@@ -339,7 +339,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
||||
if reqProvName != accProvName {
|
||||
// Provisioner in the URL must match the provisioner with
|
||||
// which the account was created.
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s",
|
||||
accProvName, reqProvName))
|
||||
return
|
||||
@@ -353,7 +353,7 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
||||
linker := acme.MustLinkerFromContext(ctx)
|
||||
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
||||
if !strings.HasPrefix(kid, kidPrefix) {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType,
|
||||
"kid does not have required prefix; expected %s, but got %s",
|
||||
kidPrefix, kid))
|
||||
return
|
||||
@@ -374,7 +374,7 @@ func extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -410,16 +410,16 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
ctx := r.Context()
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
jwk, err := jwkFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -428,11 +428,11 @@ func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||
case errors.Is(err, jose.ErrCryptoFailure):
|
||||
payload, err = retryVerificationWithPatchedSignatures(jws, jwk)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)"))
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws with patched signature(s)"))
|
||||
return
|
||||
}
|
||||
case err != nil:
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -549,11 +549,11 @@ func isPostAsGet(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
payload, err := payloadFromContext(r.Context())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
if !payload.isPostAsGet {
|
||||
render.Error(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
|
||||
@@ -99,29 +99,29 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
var nor NewOrderRequest
|
||||
if err := json.Unmarshal(payload.value, &nor); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
"failed to unmarshal new-order request payload"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := nor.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -130,39 +130,39 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acmeProv, err := acmeProvisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
var eak *acme.ExternalAccountKey
|
||||
if acmeProv.RequireEAB {
|
||||
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error retrieving external account binding key"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
acmePolicy, err := newACMEPolicyEngine(eak)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error creating ACME policy engine"))
|
||||
return
|
||||
}
|
||||
|
||||
for _, identifier := range nor.Identifiers {
|
||||
// evaluate the ACME account level policy
|
||||
if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
// evaluate the provisioner level policy
|
||||
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
|
||||
if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
// evaluate the authority level policy
|
||||
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -188,7 +188,7 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
Status: acme.StatusPending,
|
||||
}
|
||||
if err := newAuthorization(ctx, az); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
o.AuthorizationIDs[i] = az.ID
|
||||
@@ -207,14 +207,14 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := db.CreateOrder(ctx, o); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error creating order"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error creating order"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSONStatus(w, o, http.StatusCreated)
|
||||
render.JSONStatus(w, r, o, http.StatusCreated)
|
||||
}
|
||||
|
||||
func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
|
||||
@@ -288,39 +288,39 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
}
|
||||
if acc.ID != o.AccountID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
||||
return
|
||||
}
|
||||
if prov.GetID() != o.ProvisionerID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
if err = o.UpdateStatus(ctx, db); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error updating order status"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
render.JSON(w, r, o)
|
||||
}
|
||||
|
||||
// FinalizeOrder attempts to finalize an order and create a certificate.
|
||||
@@ -331,56 +331,56 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acc, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
var fr FinalizeRequest
|
||||
if err := json.Unmarshal(payload.value, &fr); err != nil {
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err,
|
||||
"failed to unmarshal finalize-order request payload"))
|
||||
return
|
||||
}
|
||||
if err := fr.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error retrieving order"))
|
||||
return
|
||||
}
|
||||
if acc.ID != o.AccountID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"account '%s' does not own order '%s'", acc.ID, o.ID))
|
||||
return
|
||||
}
|
||||
if prov.GetID() != o.ProvisionerID {
|
||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
render.Error(w, r, acme.NewError(acme.ErrorUnauthorizedType,
|
||||
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
|
||||
return
|
||||
}
|
||||
|
||||
ca := mustAuthority(ctx)
|
||||
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error finalizing order"))
|
||||
return
|
||||
}
|
||||
|
||||
linker.LinkOrder(ctx, o)
|
||||
|
||||
w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
|
||||
render.JSON(w, o)
|
||||
render.JSON(w, r, o)
|
||||
}
|
||||
|
||||
// challengeTypes determines the types of challenges that should be used
|
||||
|
||||
@@ -33,65 +33,65 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
jws, err := jwsFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := provisionerFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := payloadFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
var p revokePayload
|
||||
err = json.Unmarshal(payload.value, &p)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error unmarshaling payload"))
|
||||
return
|
||||
}
|
||||
|
||||
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
|
||||
if err != nil {
|
||||
// in this case the most likely cause is a client that didn't properly encode the certificate
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
|
||||
return
|
||||
}
|
||||
|
||||
certToBeRevoked, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
// in this case a client may have encoded something different than a certificate
|
||||
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
|
||||
render.Error(w, r, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
serial := certToBeRevoked.SerialNumber.String()
|
||||
dbCert, err := db.GetCertificateBySerial(ctx, serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
|
||||
// this should never happen
|
||||
render.Error(w, acme.NewErrorISE("certificate raw bytes are not equal"))
|
||||
render.Error(w, r, acme.NewErrorISE("certificate raw bytes are not equal"))
|
||||
return
|
||||
}
|
||||
|
||||
if shouldCheckAccountFrom(jws) {
|
||||
account, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
if acmeErr != nil {
|
||||
render.Error(w, acmeErr)
|
||||
render.Error(w, r, acmeErr)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
@@ -100,7 +100,7 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := jws.Verify(certToBeRevoked.PublicKey)
|
||||
if err != nil {
|
||||
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
|
||||
render.Error(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
|
||||
render.Error(w, r, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -108,19 +108,19 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
ca := mustAuthority(ctx)
|
||||
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
if hasBeenRevokedBefore {
|
||||
render.Error(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
|
||||
render.Error(w, r, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
|
||||
return
|
||||
}
|
||||
|
||||
reasonCode := p.ReasonCode
|
||||
acmeErr := validateReasonCode(reasonCode)
|
||||
if acmeErr != nil {
|
||||
render.Error(w, acmeErr)
|
||||
render.Error(w, r, acmeErr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -128,14 +128,14 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
|
||||
err = prov.AuthorizeRevoke(ctx, "")
|
||||
if err != nil {
|
||||
render.Error(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
|
||||
render.Error(w, r, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
|
||||
return
|
||||
}
|
||||
|
||||
options := revokeOptions(serial, certToBeRevoked, reasonCode)
|
||||
err = ca.Revoke(ctx, options)
|
||||
if err != nil {
|
||||
render.Error(w, wrapRevokeErr(err))
|
||||
render.Error(w, r, wrapRevokeErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -424,7 +424,7 @@ func (e *Error) ToLog() (interface{}, error) {
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for Error.
|
||||
func (e *Error) Render(w http.ResponseWriter) {
|
||||
func (e *Error) Render(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/problem+json")
|
||||
render.JSONStatus(w, e, e.StatusCode())
|
||||
render.JSONStatus(w, r, e, e.StatusCode())
|
||||
}
|
||||
|
||||
@@ -186,19 +186,19 @@ func (l *linker) Middleware(next http.Handler) http.Handler {
|
||||
nameEscaped := chi.URLParam(r, "provisionerID")
|
||||
name, err := url.PathUnescape(nameEscaped)
|
||||
if err != nil {
|
||||
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
render.Error(w, r, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
|
||||
return
|
||||
}
|
||||
|
||||
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
acmeProv, ok := p.(*provisioner.ACME)
|
||||
if !ok {
|
||||
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
render.Error(w, r, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
32
api/api.go
32
api/api.go
@@ -353,15 +353,15 @@ func Route(r Router) {
|
||||
// Version is an HTTP handler that returns the version of the server.
|
||||
func Version(w http.ResponseWriter, r *http.Request) {
|
||||
v := mustAuthority(r.Context()).Version()
|
||||
render.JSON(w, VersionResponse{
|
||||
render.JSON(w, r, VersionResponse{
|
||||
Version: v.Version,
|
||||
RequireClientAuthentication: v.RequireClientAuthentication,
|
||||
})
|
||||
}
|
||||
|
||||
// Health is an HTTP handler that returns the status of the server.
|
||||
func Health(w http.ResponseWriter, _ *http.Request) {
|
||||
render.JSON(w, HealthResponse{Status: "ok"})
|
||||
func Health(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSON(w, r, HealthResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
|
||||
@@ -372,11 +372,11 @@ func Root(w http.ResponseWriter, r *http.Request) {
|
||||
// Load root certificate with the
|
||||
cert, err := mustAuthority(r.Context()).Root(sum)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
render.Error(w, r, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &RootResponse{RootPEM: Certificate{cert}})
|
||||
render.JSON(w, r, &RootResponse{RootPEM: Certificate{cert}})
|
||||
}
|
||||
|
||||
func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
||||
@@ -391,17 +391,17 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
||||
func Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionersResponse{
|
||||
render.JSON(w, r, &ProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
})
|
||||
@@ -412,18 +412,18 @@ func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
|
||||
kid := chi.URLParam(r, "kid")
|
||||
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
|
||||
if err != nil {
|
||||
render.Error(w, errs.NotFoundErr(err))
|
||||
render.Error(w, r, errs.NotFoundErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &ProvisionerKeyResponse{key})
|
||||
render.JSON(w, r, &ProvisionerKeyResponse{key})
|
||||
}
|
||||
|
||||
// Roots returns all the root certificates for the CA.
|
||||
func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error getting roots"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -432,7 +432,7 @@ func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
certs[i] = Certificate{roots[i]}
|
||||
}
|
||||
|
||||
render.JSONStatus(w, &RootsResponse{
|
||||
render.JSONStatus(w, r, &RootsResponse{
|
||||
Certificates: certs,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
@@ -441,7 +441,7 @@ func Roots(w http.ResponseWriter, r *http.Request) {
|
||||
func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
roots, err := mustAuthority(r.Context()).GetRoots()
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -454,7 +454,7 @@ func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
|
||||
if _, err := w.Write(block); err != nil {
|
||||
log.Error(w, err)
|
||||
log.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -464,7 +464,7 @@ func RootsPEM(w http.ResponseWriter, r *http.Request) {
|
||||
func Federation(w http.ResponseWriter, r *http.Request) {
|
||||
federated, err := mustAuthority(r.Context()).GetFederation()
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error getting federated roots"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -473,7 +473,7 @@ func Federation(w http.ResponseWriter, r *http.Request) {
|
||||
certs[i] = Certificate{federated[i]}
|
||||
}
|
||||
|
||||
render.JSONStatus(w, &FederationResponse{
|
||||
render.JSONStatus(w, r, &FederationResponse{
|
||||
Certificates: certs,
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
func CRL(w http.ResponseWriter, r *http.Request) {
|
||||
crlInfo, err := mustAuthority(r.Context()).GetCertificateRevocationList()
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if crlInfo == nil {
|
||||
render.Error(w, errs.New(http.StatusNotFound, "no CRL available"))
|
||||
render.Error(w, r, errs.New(http.StatusNotFound, "no CRL available"))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,31 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ErrorKey is the logging attribute key for error values.
|
||||
const ErrorKey = "error"
|
||||
|
||||
type loggerKey struct{}
|
||||
|
||||
// NewContext creates a new context with the given slog.Logger.
|
||||
func NewContext(ctx context.Context, logger *slog.Logger) context.Context {
|
||||
return context.WithValue(ctx, loggerKey{}, logger)
|
||||
}
|
||||
|
||||
// FromContext returns the logger from the given context.
|
||||
func FromContext(ctx context.Context) (l *slog.Logger, ok bool) {
|
||||
l, ok = ctx.Value(loggerKey{}).(*slog.Logger)
|
||||
return
|
||||
}
|
||||
|
||||
// StackTracedError is the set of errors implementing the StackTrace function.
|
||||
//
|
||||
// Errors implementing this interface have their stack traces logged when passed
|
||||
@@ -27,7 +45,12 @@ type fieldCarrier interface {
|
||||
// Error adds to the response writer the given error if it implements
|
||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||
// using the log package.
|
||||
func Error(rw http.ResponseWriter, err error) {
|
||||
func Error(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
ctx := r.Context()
|
||||
if logger, ok := FromContext(ctx); ok && err != nil {
|
||||
logger.ErrorContext(ctx, "request failed", slog.Any(ErrorKey, err))
|
||||
}
|
||||
|
||||
fc, ok := rw.(fieldCarrier)
|
||||
if !ok {
|
||||
return
|
||||
@@ -51,7 +74,7 @@ func Error(rw http.ResponseWriter, err error) {
|
||||
|
||||
// EnabledResponse log the response object if it implements the EnableLogger
|
||||
// interface.
|
||||
func EnabledResponse(rw http.ResponseWriter, v any) {
|
||||
func EnabledResponse(rw http.ResponseWriter, r *http.Request, v any) {
|
||||
type enableLogger interface {
|
||||
ToLog() (any, error)
|
||||
}
|
||||
@@ -59,7 +82,7 @@ func EnabledResponse(rw http.ResponseWriter, v any) {
|
||||
if el, ok := v.(enableLogger); ok {
|
||||
out, err := el.ToLog()
|
||||
if err != nil {
|
||||
Error(rw, err)
|
||||
Error(rw, r, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -27,21 +30,30 @@ func (stackTracedError) StackTrace() pkgerrors.StackTrace {
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{}))
|
||||
req := httptest.NewRequest("GET", "/test", http.NoBody)
|
||||
reqWithLogger := req.WithContext(NewContext(req.Context(), logger))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
error
|
||||
rw http.ResponseWriter
|
||||
r *http.Request
|
||||
isFieldCarrier bool
|
||||
isSlogLogger bool
|
||||
stepDebug bool
|
||||
expectStackTrace bool
|
||||
}{
|
||||
{"noLogger", nil, nil, false, false, false},
|
||||
{"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
|
||||
{"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
|
||||
{"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, false, false},
|
||||
{"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), true, true, false},
|
||||
{"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
|
||||
{"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), true, true, true},
|
||||
{"noLogger", nil, nil, req, false, false, false, false},
|
||||
{"noError", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false},
|
||||
{"noErrorDebug", nil, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false},
|
||||
{"anError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, false, false},
|
||||
{"anErrorDebug", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, false},
|
||||
{"stackTracedError", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true},
|
||||
{"stackTracedErrorDebug", new(stackTracedError), logging.NewResponseLogger(httptest.NewRecorder()), req, true, false, true, true},
|
||||
{"slogWithNoError", nil, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false},
|
||||
{"slogWithError", assert.AnError, logging.NewResponseLogger(httptest.NewRecorder()), reqWithLogger, true, true, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -52,27 +64,41 @@ func TestError(t *testing.T) {
|
||||
t.Setenv("STEPDEBUG", "0")
|
||||
}
|
||||
|
||||
Error(tt.rw, tt.error)
|
||||
Error(tt.rw, tt.r, tt.error)
|
||||
|
||||
// return early if test case doesn't use logger
|
||||
if !tt.isFieldCarrier {
|
||||
if !tt.isFieldCarrier && !tt.isSlogLogger {
|
||||
return
|
||||
}
|
||||
|
||||
fields := tt.rw.(logging.ResponseLogger).Fields()
|
||||
if tt.isFieldCarrier {
|
||||
fields := tt.rw.(logging.ResponseLogger).Fields()
|
||||
|
||||
// expect the error field to be (not) set and to be the same error that was fed to Error
|
||||
if tt.error == nil {
|
||||
assert.Nil(t, fields["error"])
|
||||
} else {
|
||||
assert.Same(t, tt.error, fields["error"])
|
||||
// expect the error field to be (not) set and to be the same error that was fed to Error
|
||||
if tt.error == nil {
|
||||
assert.Nil(t, fields["error"])
|
||||
} else {
|
||||
assert.Same(t, tt.error, fields["error"])
|
||||
}
|
||||
|
||||
// check if stack-trace is set when expected
|
||||
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] not set`)
|
||||
} else if !tt.expectStackTrace && hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] was set`)
|
||||
}
|
||||
}
|
||||
|
||||
// check if stack-trace is set when expected
|
||||
if _, hasStackTrace := fields["stack-trace"]; tt.expectStackTrace && !hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] not set`)
|
||||
} else if !tt.expectStackTrace && hasStackTrace {
|
||||
t.Error(`ResponseLogger["stack-trace"] was set`)
|
||||
if tt.isSlogLogger {
|
||||
b := buf.Bytes()
|
||||
if tt.error == nil {
|
||||
assert.Empty(t, b)
|
||||
} else if assert.NotEmpty(t, b) {
|
||||
var m map[string]any
|
||||
assert.NoError(t, json.Unmarshal(b, &m))
|
||||
assert.Equal(t, tt.error.Error(), m["error"])
|
||||
}
|
||||
buf.Reset()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func (e badProtoJSONError) Error() string {
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for badProtoJSONError
|
||||
func (e badProtoJSONError) Render(w http.ResponseWriter) {
|
||||
func (e badProtoJSONError) Render(w http.ResponseWriter, r *http.Request) {
|
||||
v := struct {
|
||||
Type string `json:"type"`
|
||||
Detail string `json:"detail"`
|
||||
@@ -62,5 +62,5 @@ func (e badProtoJSONError) Render(w http.ResponseWriter) {
|
||||
// trim the proto prefix for the message
|
||||
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
|
||||
}
|
||||
render.JSONStatus(w, v, http.StatusBadRequest)
|
||||
render.JSONStatus(w, r, v, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
@@ -142,7 +142,8 @@ func Test_badProtoJSONError_Render(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
tt.e.Render(w)
|
||||
r := httptest.NewRequest("POST", "/test", http.NoBody)
|
||||
tt.e.Render(w, r)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
|
||||
10
api/rekey.go
10
api/rekey.go
@@ -29,25 +29,25 @@ func (s *RekeyRequest) Validate() error {
|
||||
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
|
||||
func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing client certificate"))
|
||||
render.Error(w, r, errs.BadRequest("missing client certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
var body RekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
a := mustAuthority(r.Context())
|
||||
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||
render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
@@ -57,7 +57,7 @@ func Rekey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
render.JSONStatus(w, r, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
)
|
||||
|
||||
// JSON is shorthand for JSONStatus(w, v, http.StatusOK).
|
||||
func JSON(w http.ResponseWriter, v interface{}) {
|
||||
JSONStatus(w, v, http.StatusOK)
|
||||
func JSON(w http.ResponseWriter, r *http.Request, v interface{}) {
|
||||
JSONStatus(w, r, v, http.StatusOK)
|
||||
}
|
||||
|
||||
// JSONStatus marshals v into w. It additionally sets the status code of
|
||||
@@ -22,7 +22,7 @@ func JSON(w http.ResponseWriter, v interface{}) {
|
||||
//
|
||||
// JSONStatus sets the Content-Type of w to application/json unless one is
|
||||
// specified.
|
||||
func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||
func JSONStatus(w http.ResponseWriter, r *http.Request, v interface{}, status int) {
|
||||
setContentTypeUnlessPresent(w, "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
@@ -43,7 +43,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
||||
}
|
||||
}
|
||||
|
||||
log.EnabledResponse(w, v)
|
||||
log.EnabledResponse(w, r, v)
|
||||
}
|
||||
|
||||
// ProtoJSON is shorthand for ProtoJSONStatus(w, m, http.StatusOK).
|
||||
@@ -80,22 +80,22 @@ func setContentTypeUnlessPresent(w http.ResponseWriter, contentType string) {
|
||||
type RenderableError interface {
|
||||
error
|
||||
|
||||
Render(http.ResponseWriter)
|
||||
Render(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
// Error marshals the JSON representation of err to w. In case err implements
|
||||
// RenderableError its own Render method will be called instead.
|
||||
func Error(w http.ResponseWriter, err error) {
|
||||
log.Error(w, err)
|
||||
func Error(rw http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Error(rw, r, err)
|
||||
|
||||
var r RenderableError
|
||||
if errors.As(err, &r) {
|
||||
r.Render(w)
|
||||
var re RenderableError
|
||||
if errors.As(err, &re) {
|
||||
re.Render(rw, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
JSONStatus(w, err, statusCodeFromError(err))
|
||||
JSONStatus(rw, r, err, statusCodeFromError(err))
|
||||
}
|
||||
|
||||
// StatusCodedError is the set of errors that implement the basic StatusCode
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
func TestJSON(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
rw := logging.NewResponseLogger(rec)
|
||||
|
||||
JSON(rw, map[string]interface{}{"foo": "bar"})
|
||||
r := httptest.NewRequest("POST", "/test", http.NoBody)
|
||||
JSON(rw, r, map[string]interface{}{"foo": "bar"})
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Result().StatusCode)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
@@ -64,7 +64,8 @@ func jsonPanicTest[T json.UnsupportedTypeError | json.UnsupportedValueError | js
|
||||
assert.ErrorAs(t, err, &e)
|
||||
}()
|
||||
|
||||
JSON(httptest.NewRecorder(), v)
|
||||
r := httptest.NewRequest("POST", "/test", http.NoBody)
|
||||
JSON(httptest.NewRecorder(), r, v)
|
||||
}
|
||||
|
||||
type renderableError struct {
|
||||
@@ -76,10 +77,9 @@ func (err renderableError) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
func (err renderableError) Render(w http.ResponseWriter) {
|
||||
func (err renderableError) Render(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "something/custom")
|
||||
|
||||
JSONStatus(w, err, err.Code)
|
||||
JSONStatus(w, r, err, err.Code)
|
||||
}
|
||||
|
||||
type statusedError struct {
|
||||
@@ -116,8 +116,8 @@ func TestError(t *testing.T) {
|
||||
|
||||
t.Run(strconv.Itoa(caseIndex), func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
Error(rec, kase.err)
|
||||
r := httptest.NewRequest("POST", "/test", http.NoBody)
|
||||
Error(rec, r, kase.err)
|
||||
|
||||
assert.Equal(t, kase.code, rec.Result().StatusCode)
|
||||
assert.Equal(t, kase.body, rec.Body.String())
|
||||
|
||||
@@ -23,19 +23,20 @@ func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
// Get the leaf certificate from the peer or the token.
|
||||
cert, token, err := getPeerCertificate(r)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// The token can be used by RAs to renew a certificate.
|
||||
if token != "" {
|
||||
ctx = authority.NewTokenContext(ctx, token)
|
||||
logOtt(w, token)
|
||||
}
|
||||
|
||||
a := mustAuthority(ctx)
|
||||
certChain, err := a.RenewContext(ctx, cert, nil)
|
||||
if err != nil {
|
||||
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
render.Error(w, r, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
@@ -45,7 +46,7 @@ func Renew(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
render.JSONStatus(w, r, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
|
||||
@@ -57,12 +57,12 @@ func (r *RevokeRequest) Validate() (err error) {
|
||||
func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
if body.OTT != "" {
|
||||
logOtt(w, body.OTT)
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
@@ -90,12 +90,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
// the client certificate Serial Number must match the serial number
|
||||
// being revoked.
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
render.Error(w, errs.BadRequest("missing ott or client certificate"))
|
||||
render.Error(w, r, errs.BadRequest("missing ott or client certificate"))
|
||||
return
|
||||
}
|
||||
opts.Crt = r.TLS.PeerCertificates[0]
|
||||
if opts.Crt.SerialNumber.String() != opts.Serial {
|
||||
render.Error(w, errs.BadRequest("serial number in client certificate different than body"))
|
||||
render.Error(w, r, errs.BadRequest("serial number in client certificate different than body"))
|
||||
return
|
||||
}
|
||||
// TODO: should probably be checking if the certificate was revoked here.
|
||||
@@ -106,12 +106,12 @@ func Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error revoking certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
logRevoke(w, opts)
|
||||
render.JSON(w, &RevokeResponse{Status: "ok"})
|
||||
render.JSON(w, r, &RevokeResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||
|
||||
10
api/sign.go
10
api/sign.go
@@ -52,13 +52,13 @@ type SignResponse struct {
|
||||
func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -74,13 +74,13 @@ func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error signing certificate"))
|
||||
return
|
||||
}
|
||||
certChainPEM := certChainToPEM(certChain)
|
||||
@@ -90,7 +90,7 @@ func Sign(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
LogCertificate(w, certChain[0])
|
||||
render.JSONStatus(w, &SignResponse{
|
||||
render.JSONStatus(w, r, &SignResponse{
|
||||
ServerPEM: certChainPEM[0],
|
||||
CaPEM: caPEM,
|
||||
CertChainPEM: certChainPEM,
|
||||
|
||||
62
api/ssh.go
62
api/ssh.go
@@ -253,19 +253,19 @@ type SSHBastionResponse struct {
|
||||
func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -273,7 +273,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
if body.AddUserPublicKey != nil {
|
||||
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -293,13 +293,13 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -307,7 +307,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
|
||||
addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error signing ssh certificate"))
|
||||
return
|
||||
}
|
||||
addUserCertificate = &SSHCertificate{addUserCert}
|
||||
@@ -320,7 +320,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -332,14 +332,14 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error signing identity certificate"))
|
||||
return
|
||||
}
|
||||
identityCertificate = certChainToPEM(certChain)
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, cert)
|
||||
render.JSONStatus(w, &SSHSignResponse{
|
||||
render.JSONStatus(w, r, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{cert},
|
||||
AddUserCertificate: addUserCertificate,
|
||||
IdentityCertificate: identityCertificate,
|
||||
@@ -352,12 +352,12 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
render.Error(w, errs.NotFound("no keys found"))
|
||||
render.Error(w, r, errs.NotFound("no keys found"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -369,7 +369,7 @@ func SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
||||
}
|
||||
|
||||
render.JSON(w, resp)
|
||||
render.JSON(w, r, resp)
|
||||
}
|
||||
|
||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||
@@ -378,12 +378,12 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
|
||||
render.Error(w, errs.NotFound("no keys found"))
|
||||
render.Error(w, r, errs.NotFound("no keys found"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -395,7 +395,7 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
resp.UserKeys = append(resp.UserKeys, SSHPublicKey{PublicKey: k})
|
||||
}
|
||||
|
||||
render.JSON(w, resp)
|
||||
render.JSON(w, r, resp)
|
||||
}
|
||||
|
||||
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients
|
||||
@@ -403,18 +403,18 @@ func SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -425,32 +425,32 @@ func SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
case provisioner.SSHHostCert:
|
||||
cfg.HostTemplates = ts
|
||||
default:
|
||||
render.Error(w, errs.InternalServer("it should hot get here"))
|
||||
render.Error(w, r, errs.InternalServer("it should hot get here"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, cfg)
|
||||
render.JSON(w, r, cfg)
|
||||
}
|
||||
|
||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||
func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHCheckPrincipalRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
render.JSON(w, &SSHCheckPrincipalResponse{
|
||||
render.JSON(w, r, &SSHCheckPrincipalResponse{
|
||||
Exists: exists,
|
||||
})
|
||||
}
|
||||
@@ -465,10 +465,10 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
render.JSON(w, &SSHGetHostsResponse{
|
||||
render.JSON(w, r, &SSHGetHostsResponse{
|
||||
Hosts: hosts,
|
||||
})
|
||||
}
|
||||
@@ -477,22 +477,22 @@ func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
func SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &SSHBastionResponse{
|
||||
render.JSON(w, r, &SSHBastionResponse{
|
||||
Hostname: body.Hostname,
|
||||
Bastion: bastion,
|
||||
})
|
||||
|
||||
@@ -42,19 +42,19 @@ type SSHRekeyResponse struct {
|
||||
func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
|
||||
if err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error parsing publicKey"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -64,18 +64,18 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
a := mustAuthority(ctx)
|
||||
signOpts, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -85,12 +85,12 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, newCert)
|
||||
render.JSONStatus(w, &SSHRekeyResponse{
|
||||
render.JSONStatus(w, r, &SSHRekeyResponse{
|
||||
Certificate: SSHCertificate{newCert},
|
||||
IdentityCertificate: identity,
|
||||
}, http.StatusCreated)
|
||||
|
||||
@@ -40,13 +40,13 @@ type SSHRenewResponse struct {
|
||||
func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -56,18 +56,18 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
a := mustAuthority(ctx)
|
||||
_, err := a.Authorize(ctx, body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
|
||||
newCert, err := a.RenewSSH(ctx, oldCert)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error renewing ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,12 +77,12 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
identity, err := renewIdentityCertificate(r, notBefore, notAfter)
|
||||
if err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error renewing identity certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
LogSSHCertificate(w, newCert)
|
||||
render.JSONStatus(w, &SSHSignResponse{
|
||||
render.JSONStatus(w, r, &SSHSignResponse{
|
||||
Certificate: SSHCertificate{newCert},
|
||||
IdentityCertificate: identity,
|
||||
}, http.StatusCreated)
|
||||
|
||||
@@ -51,12 +51,12 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
||||
func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
render.Error(w, r, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,18 +75,18 @@ func SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
logOtt(w, body.OTT)
|
||||
|
||||
if _, err := a.Authorize(ctx, body.OTT); err != nil {
|
||||
render.Error(w, errs.UnauthorizedErr(err))
|
||||
render.Error(w, r, errs.UnauthorizedErr(err))
|
||||
return
|
||||
}
|
||||
opts.OTT = body.OTT
|
||||
|
||||
if err := a.Revoke(ctx, opts); err != nil {
|
||||
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||
render.Error(w, r, errs.ForbiddenErr(err, "error revoking ssh certificate"))
|
||||
return
|
||||
}
|
||||
|
||||
logSSHRevoke(w, opts)
|
||||
render.JSON(w, &SSHRevokeResponse{Status: "ok"})
|
||||
render.JSON(w, r, &SSHRevokeResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
func logSSHRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
|
||||
|
||||
@@ -40,12 +40,12 @@ func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
|
||||
acmeProvisioner := prov.GetDetails().GetACME()
|
||||
if acmeProvisioner == nil {
|
||||
render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
|
||||
render.Error(w, r, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
if !acmeProvisioner.RequireEab {
|
||||
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -69,18 +69,18 @@ func NewACMEAdminResponder() ACMEAdminResponder {
|
||||
}
|
||||
|
||||
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint
|
||||
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, _ *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
// CreateExternalAccountKey writes the response for the EAB key POST endpoint
|
||||
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, _ *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
|
||||
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, _ *http.Request) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
|
||||
}
|
||||
|
||||
func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey {
|
||||
|
||||
@@ -90,7 +90,7 @@ func GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
|
||||
if !ok {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType,
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType,
|
||||
"admin %s not found", id))
|
||||
return
|
||||
}
|
||||
@@ -101,17 +101,17 @@ func GetAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
func GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
"error parsing cursor and limit from query params"))
|
||||
return
|
||||
}
|
||||
|
||||
admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error retrieving paginated admins"))
|
||||
return
|
||||
}
|
||||
render.JSON(w, &GetAdminsResponse{
|
||||
render.JSON(w, r, &GetAdminsResponse{
|
||||
Admins: admins,
|
||||
NextCursor: nextCursor,
|
||||
})
|
||||
@@ -121,19 +121,19 @@ func GetAdmins(w http.ResponseWriter, r *http.Request) {
|
||||
func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body CreateAdminRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
auth := mustAuthority(r.Context())
|
||||
p, err := auth.LoadProvisionerByName(body.Provisioner)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
|
||||
return
|
||||
}
|
||||
adm := &linkedca.Admin{
|
||||
@@ -143,7 +143,7 @@ func CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
// Store to authority collection.
|
||||
if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error storing admin"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -155,23 +155,23 @@ func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error deleting admin %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &DeleteResponse{Status: "ok"})
|
||||
render.JSON(w, r, &DeleteResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// UpdateAdmin updates an existing admin.
|
||||
func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body UpdateAdminRequest
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := body.Validate(); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -179,7 +179,7 @@ func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
auth := mustAuthority(r.Context())
|
||||
adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error updating admin %s", id))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
@@ -31,7 +31,7 @@ func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
tok := r.Header.Get("Authorization")
|
||||
if tok == "" {
|
||||
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
render.Error(w, r, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
"missing authorization header token"))
|
||||
return
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||
ctx := r.Context()
|
||||
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -64,13 +64,13 @@ func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
|
||||
|
||||
// TODO(hs): distinguish 404 vs. 500
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.Handler
|
||||
// when an action is not supported in standalone mode and when
|
||||
// using a nosql.DB backend, actions are not supported
|
||||
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"operation not supported in standalone mode"))
|
||||
return
|
||||
}
|
||||
@@ -125,15 +125,15 @@ func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
|
||||
|
||||
if err != nil {
|
||||
if acme.IsErrNotFound(err) {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
|
||||
return
|
||||
}
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error retrieving ACME External Account Key"))
|
||||
return
|
||||
}
|
||||
|
||||
if eak == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ func NewPolicyAdminResponder() PolicyAdminResponder {
|
||||
func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -52,12 +52,12 @@ func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht
|
||||
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht
|
||||
func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,26 +77,26 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
|
||||
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
|
||||
render.Error(w, adminErr)
|
||||
render.Error(w, r, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -105,11 +105,11 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
|
||||
var createdPolicy *linkedca.Policy
|
||||
if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing authority policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error storing authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
|
||||
func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -129,25 +129,25 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
|
||||
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -156,11 +156,11 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
|
||||
var updatedPolicy *linkedca.Policy
|
||||
if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating authority policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error updating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -171,7 +171,7 @@ func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
|
||||
func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -180,35 +180,35 @@ func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r
|
||||
|
||||
var ae *admin.Error
|
||||
if errors.As(err, &ae) && !ae.IsType(admin.ErrorNotFoundType) {
|
||||
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if authorityPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error deleting authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
|
||||
func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -219,7 +219,7 @@ func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *
|
||||
func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -227,20 +227,20 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
|
||||
render.Error(w, adminErr)
|
||||
render.Error(w, r, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -248,11 +248,11 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error creating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -263,27 +263,27 @@ func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
|
||||
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
provisionerPolicy := prov.GetPolicy()
|
||||
if provisionerPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -291,11 +291,11 @@ func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error updating provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -306,13 +306,13 @@ func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
|
||||
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov := linkedca.MustProvisionerFromContext(ctx)
|
||||
if prov.Policy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -321,24 +321,24 @@ func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter,
|
||||
|
||||
auth := mustAuthority(ctx)
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error deleting provisioner policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -348,7 +348,7 @@ func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *
|
||||
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -357,20 +357,20 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
|
||||
render.Error(w, adminErr)
|
||||
render.Error(w, r, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -379,7 +379,7 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -389,7 +389,7 @@ func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
|
||||
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -397,20 +397,20 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
newPolicy.Deduplicate()
|
||||
|
||||
if err := validatePolicy(newPolicy); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -418,7 +418,7 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -428,7 +428,7 @@ func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
|
||||
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if err := blockLinkedCA(ctx); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -436,7 +436,7 @@ func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
|
||||
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
|
||||
eakPolicy := eak.GetPolicy()
|
||||
if eakPolicy == nil {
|
||||
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
render.Error(w, r, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -446,11 +446,11 @@ func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
|
||||
acmeEAK := linkedEAKToCertificates(eak)
|
||||
acmeDB := acme.MustDatabaseFromContext(ctx)
|
||||
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
// blockLinkedCA blocks all API operations on linked deployments
|
||||
|
||||
@@ -40,19 +40,19 @@ func GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if id != "" {
|
||||
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
prov, err := db.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
render.ProtoJSON(w, prov)
|
||||
@@ -62,17 +62,17 @@ func GetProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
func GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := api.ParseCursor(r)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err,
|
||||
"error parsing cursor and limit from query params"))
|
||||
return
|
||||
}
|
||||
|
||||
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
|
||||
if err != nil {
|
||||
render.Error(w, errs.InternalServerErr(err))
|
||||
render.Error(w, r, errs.InternalServerErr(err))
|
||||
return
|
||||
}
|
||||
render.JSON(w, &GetProvisionersResponse{
|
||||
render.JSON(w, r, &GetProvisionersResponse{
|
||||
Provisioners: p,
|
||||
NextCursor: next,
|
||||
})
|
||||
@@ -82,24 +82,24 @@ func GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
||||
func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var prov = new(linkedca.Provisioner)
|
||||
if err := read.ProtoJSON(r.Body, prov); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Validate inputs
|
||||
if err := authority.ValidateClaims(prov.Claims); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// validate the templates and template data
|
||||
if err := validateTemplates(prov.X509Template, prov.SshTemplate); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
|
||||
return
|
||||
}
|
||||
render.ProtoJSONStatus(w, prov, http.StatusCreated)
|
||||
@@ -118,29 +118,29 @@ func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if id != "" {
|
||||
if p, err = auth.LoadProvisionerByID(id); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", id))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if p, err = auth.LoadProvisionerByName(name); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSON(w, &DeleteResponse{Status: "ok"})
|
||||
render.JSON(w, r, &DeleteResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
// UpdateProvisioner updates an existing prov.
|
||||
func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var nu = new(linkedca.Provisioner)
|
||||
if err := read.ProtoJSON(r.Body, nu); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -151,51 +151,51 @@ func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
p, err := auth.LoadProvisionerByName(name)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
|
||||
return
|
||||
}
|
||||
|
||||
old, err := db.GetProvisioner(r.Context(), p.GetID())
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
|
||||
return
|
||||
}
|
||||
|
||||
if nu.Id != old.Id {
|
||||
render.Error(w, admin.NewErrorISE("cannot change provisioner ID"))
|
||||
render.Error(w, r, admin.NewErrorISE("cannot change provisioner ID"))
|
||||
return
|
||||
}
|
||||
if nu.Type != old.Type {
|
||||
render.Error(w, admin.NewErrorISE("cannot change provisioner type"))
|
||||
render.Error(w, r, admin.NewErrorISE("cannot change provisioner type"))
|
||||
return
|
||||
}
|
||||
if nu.AuthorityId != old.AuthorityId {
|
||||
render.Error(w, admin.NewErrorISE("cannot change provisioner authorityID"))
|
||||
render.Error(w, r, admin.NewErrorISE("cannot change provisioner authorityID"))
|
||||
return
|
||||
}
|
||||
if !nu.CreatedAt.AsTime().Equal(old.CreatedAt.AsTime()) {
|
||||
render.Error(w, admin.NewErrorISE("cannot change provisioner createdAt"))
|
||||
render.Error(w, r, admin.NewErrorISE("cannot change provisioner createdAt"))
|
||||
return
|
||||
}
|
||||
if !nu.DeletedAt.AsTime().Equal(old.DeletedAt.AsTime()) {
|
||||
render.Error(w, admin.NewErrorISE("cannot change provisioner deletedAt"))
|
||||
render.Error(w, r, admin.NewErrorISE("cannot change provisioner deletedAt"))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Validate inputs
|
||||
if err := authority.ValidateClaims(nu.Claims); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
// validate the templates and template data
|
||||
if err := validateTemplates(nu.X509Template, nu.SshTemplate); err != nil {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "invalid template"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
render.ProtoJSON(w, nu)
|
||||
|
||||
@@ -71,28 +71,28 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter
|
||||
|
||||
var newWebhook = new(linkedca.Webhook)
|
||||
if err := read.ProtoJSON(r.Body, newWebhook); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateWebhook(newWebhook); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
if newWebhook.Secret != "" {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "webhook secret must not be set")
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
if newWebhook.Id != "" {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "webhook ID must not be set")
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
id, err := randutil.UUIDv4()
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error generating webhook id"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error generating webhook id"))
|
||||
return
|
||||
}
|
||||
newWebhook.Id = id
|
||||
@@ -101,14 +101,14 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter
|
||||
for _, wh := range prov.Webhooks {
|
||||
if wh.Name == newWebhook.Name {
|
||||
err := admin.NewError(admin.ErrorConflictType, "provisioner %q already has a webhook with the name %q", prov.Name, newWebhook.Name)
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
secret, err := randutil.Bytes(64)
|
||||
if err != nil {
|
||||
render.Error(w, admin.WrapErrorISE(err, "error generating webhook secret"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error generating webhook secret"))
|
||||
return
|
||||
}
|
||||
newWebhook.Secret = base64.StdEncoding.EncodeToString(secret)
|
||||
@@ -117,11 +117,11 @@ func (war *webhookAdminResponder) CreateProvisionerWebhook(w http.ResponseWriter
|
||||
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner webhook"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner webhook"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error creating provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -145,21 +145,21 @@ func (war *webhookAdminResponder) DeleteProvisionerWebhook(w http.ResponseWriter
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error deleting provisioner webhook"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error deleting provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner webhook"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error deleting provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
render.JSONStatus(w, r, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -170,12 +170,12 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter
|
||||
|
||||
var newWebhook = new(linkedca.Webhook)
|
||||
if err := read.ProtoJSON(r.Body, newWebhook); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateWebhook(newWebhook); err != nil {
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -186,13 +186,13 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter
|
||||
}
|
||||
if newWebhook.Secret != "" && newWebhook.Secret != wh.Secret {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "webhook secret cannot be updated")
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
newWebhook.Secret = wh.Secret
|
||||
if newWebhook.Id != "" && newWebhook.Id != wh.Id {
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "webhook ID cannot be updated")
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
newWebhook.Id = wh.Id
|
||||
@@ -203,17 +203,17 @@ func (war *webhookAdminResponder) UpdateProvisionerWebhook(w http.ResponseWriter
|
||||
if !found {
|
||||
msg := fmt.Sprintf("provisioner %q has no webhook with the name %q", prov.Name, newWebhook.Name)
|
||||
err := admin.NewError(admin.ErrorNotFoundType, msg)
|
||||
render.Error(w, err)
|
||||
render.Error(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
|
||||
if isBadRequest(err) {
|
||||
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner webhook"))
|
||||
render.Error(w, r, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner webhook"))
|
||||
render.Error(w, r, admin.WrapErrorISE(err, "error updating provisioner webhook"))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -205,8 +205,8 @@ func (e *Error) ToLog() (interface{}, error) {
|
||||
}
|
||||
|
||||
// Render implements render.RenderableError for Error.
|
||||
func (e *Error) Render(w http.ResponseWriter) {
|
||||
func (e *Error) Render(w http.ResponseWriter, r *http.Request) {
|
||||
e.Message = e.Err.Error()
|
||||
|
||||
render.JSONStatus(w, e, e.StatusCode())
|
||||
render.JSONStatus(w, r, e, e.StatusCode())
|
||||
}
|
||||
|
||||
@@ -108,19 +108,19 @@ func TestNewACMEClient(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
switch {
|
||||
case i == 0:
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
case i == 1:
|
||||
w.Header().Set("Replay-Nonce", "abc123")
|
||||
render.JSONStatus(w, []byte{}, 200)
|
||||
render.JSONStatus(w, r, []byte{}, 200)
|
||||
i++
|
||||
default:
|
||||
w.Header().Set("Location", accLocation)
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -203,10 +203,10 @@ func TestACMEClient_GetNonce(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
})
|
||||
|
||||
if nonce, err := ac.GetNonce(); err != nil {
|
||||
@@ -310,18 +310,18 @@ func TestACMEClient_post(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -338,7 +338,7 @@ func TestACMEClient_post(t *testing.T) {
|
||||
assert.Equals(t, hdr.KeyID, ac.kid)
|
||||
}
|
||||
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
})
|
||||
|
||||
if resp, err := tc.client.post(tc.payload, url, tc.ops...); err != nil {
|
||||
@@ -450,18 +450,18 @@ func TestACMEClient_NewOrder(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -477,7 +477,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, payload, norb)
|
||||
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
})
|
||||
|
||||
if res, err := ac.NewOrder(norb); err != nil {
|
||||
@@ -572,18 +572,18 @@ func TestACMEClient_GetOrder(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -599,7 +599,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, len(payload), 0)
|
||||
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
})
|
||||
|
||||
if res, err := ac.GetOrder(url); err != nil {
|
||||
@@ -694,18 +694,18 @@ func TestACMEClient_GetAuthz(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -721,7 +721,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, len(payload), 0)
|
||||
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
})
|
||||
|
||||
if res, err := ac.GetAuthz(url); err != nil {
|
||||
@@ -816,18 +816,18 @@ func TestACMEClient_GetChallenge(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -844,7 +844,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
|
||||
|
||||
assert.Equals(t, len(payload), 0)
|
||||
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
})
|
||||
|
||||
if res, err := ac.GetChallenge(url); err != nil {
|
||||
@@ -939,18 +939,18 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -967,7 +967,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
|
||||
|
||||
assert.Equals(t, payload, []byte("{}"))
|
||||
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
})
|
||||
|
||||
if err := ac.ValidateChallenge(url); err != nil {
|
||||
@@ -983,22 +983,22 @@ func TestACMEClient_ValidateWithPayload(t *testing.T) {
|
||||
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
t.Log(req.RequestURI)
|
||||
t.Log(r.RequestURI)
|
||||
w.Header().Set("Replay-Nonce", "nonce")
|
||||
switch req.RequestURI {
|
||||
switch r.RequestURI {
|
||||
case "/nonce":
|
||||
render.JSONStatus(w, []byte{}, 200)
|
||||
render.JSONStatus(w, r, []byte{}, 200)
|
||||
return
|
||||
case "/fail-nonce":
|
||||
render.JSONStatus(w, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400)
|
||||
render.JSONStatus(w, r, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400)
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
@@ -1015,15 +1015,15 @@ func TestACMEClient_ValidateWithPayload(t *testing.T) {
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, payload, []byte("the-payload"))
|
||||
|
||||
switch req.RequestURI {
|
||||
switch r.RequestURI {
|
||||
case "/ok":
|
||||
render.JSONStatus(w, acme.Challenge{
|
||||
render.JSONStatus(w, r, acme.Challenge{
|
||||
Type: "device-attestation-01",
|
||||
Status: "valid",
|
||||
Token: "foo",
|
||||
}, 200)
|
||||
case "/fail":
|
||||
render.JSONStatus(w, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400)
|
||||
render.JSONStatus(w, r, acme.NewError(acme.ErrorMalformedType, "malformed request"), 400)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
@@ -1160,18 +1160,18 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -1187,7 +1187,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, payload, frb)
|
||||
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
})
|
||||
|
||||
if err := ac.FinalizeOrder(url, csr); err != nil {
|
||||
@@ -1289,18 +1289,18 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -1316,7 +1316,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, len(payload), 0)
|
||||
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
})
|
||||
|
||||
if res, err := tc.client.GetAccountOrders(); err != nil {
|
||||
@@ -1420,18 +1420,18 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||
tc := run(t)
|
||||
|
||||
i := 0
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", req.Header.Get("User-Agent")) // check default User-Agent header
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equals(t, "step-http-client/1.0", r.Header.Get("User-Agent")) // check default User-Agent header
|
||||
|
||||
w.Header().Set("Replay-Nonce", expectedNonce)
|
||||
if i == 0 {
|
||||
render.JSONStatus(w, tc.r1, tc.rc1)
|
||||
render.JSONStatus(w, r, tc.r1, tc.rc1)
|
||||
i++
|
||||
return
|
||||
}
|
||||
|
||||
// validate jws request protected headers and body
|
||||
body, err := io.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := jose.ParseJWS(string(body))
|
||||
assert.FatalError(t, err)
|
||||
@@ -1450,7 +1450,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||
if tc.certBytes != nil {
|
||||
w.Write(tc.certBytes)
|
||||
} else {
|
||||
render.JSONStatus(w, tc.r2, tc.rc2)
|
||||
render.JSONStatus(w, r, tc.r2, tc.rc2)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ func startCAServer(configFile string) (*CA, string, error) {
|
||||
func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/version" {
|
||||
render.JSON(w, api.VersionResponse{
|
||||
render.JSON(w, r, api.VersionResponse{
|
||||
Version: "test",
|
||||
RequireClientAuthentication: true,
|
||||
})
|
||||
@@ -102,7 +102,7 @@ func mTLSMiddleware(next http.Handler, nonAuthenticatedPaths ...string) http.Han
|
||||
}
|
||||
isMTLS := r.TLS != nil && len(r.TLS.PeerCertificates) > 0
|
||||
if !isMTLS {
|
||||
render.Error(w, errs.Unauthorized("missing peer certificate"))
|
||||
render.Error(w, r, errs.Unauthorized("missing peer certificate"))
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
@@ -412,7 +412,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
|
||||
//nolint:gosec // insecure test server
|
||||
server, err := BootstrapServer(context.Background(), token, &http.Server{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("ok"))
|
||||
}),
|
||||
}, RequireAndVerifyClientCert())
|
||||
@@ -531,7 +531,7 @@ func TestBootstrapClientServerFederation(t *testing.T) {
|
||||
//nolint:gosec // insecure test server
|
||||
server, err := BootstrapServer(context.Background(), token, &http.Server{
|
||||
Addr: ":0",
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("ok"))
|
||||
}),
|
||||
}, RequireAndVerifyClientCert(), AddFederationToClientCAs())
|
||||
|
||||
@@ -177,8 +177,8 @@ func TestClient_Version(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Version()
|
||||
@@ -218,8 +218,8 @@ func TestClient_Health(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Health()
|
||||
@@ -262,12 +262,12 @@ func TestClient_Root(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
expected := "/root/" + tt.shasum
|
||||
if req.RequestURI != expected {
|
||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
||||
if r.RequestURI != expected {
|
||||
t.Errorf("RequestURI = %s, want %s", r.RequestURI, expected)
|
||||
}
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Root(tt.shasum)
|
||||
@@ -323,12 +323,12 @@ func TestClient_Sign(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body := new(api.SignRequest)
|
||||
if err := read.JSON(req.Body, body); err != nil {
|
||||
if err := read.JSON(r.Body, body); err != nil {
|
||||
e, ok := tt.response.(error)
|
||||
require.True(t, ok, "response expected to be error type")
|
||||
render.Error(w, e)
|
||||
render.Error(w, r, e)
|
||||
return
|
||||
} else if !equalJSON(t, body, tt.request) {
|
||||
if tt.request == nil {
|
||||
@@ -339,7 +339,7 @@ func TestClient_Sign(t *testing.T) {
|
||||
t.Errorf("Client.Sign() request = %v, wants %v", body, tt.request)
|
||||
}
|
||||
}
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Sign(tt.request)
|
||||
@@ -385,12 +385,12 @@ func TestClient_Revoke(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body := new(api.RevokeRequest)
|
||||
if err := read.JSON(req.Body, body); err != nil {
|
||||
if err := read.JSON(r.Body, body); err != nil {
|
||||
e, ok := tt.response.(error)
|
||||
require.True(t, ok, "response expected to be error type")
|
||||
render.Error(w, e)
|
||||
render.Error(w, r, e)
|
||||
return
|
||||
} else if !equalJSON(t, body, tt.request) {
|
||||
if tt.request == nil {
|
||||
@@ -401,7 +401,7 @@ func TestClient_Revoke(t *testing.T) {
|
||||
t.Errorf("Client.Revoke() request = %v, wants %v", body, tt.request)
|
||||
}
|
||||
}
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Revoke(tt.request, nil)
|
||||
@@ -450,8 +450,8 @@ func TestClient_Renew(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Renew(nil)
|
||||
@@ -504,11 +504,11 @@ func TestClient_RenewWithToken(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("Authorization") != "Bearer token" {
|
||||
render.JSONStatus(w, errs.InternalServer("force"), 500)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "Bearer token" {
|
||||
render.JSONStatus(w, r, errs.InternalServer("force"), 500)
|
||||
} else {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -567,8 +567,8 @@ func TestClient_Rekey(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Rekey(tt.request, nil)
|
||||
@@ -619,11 +619,11 @@ func TestClient_Provisioners(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.RequestURI != tt.expectedURI {
|
||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI != tt.expectedURI {
|
||||
t.Errorf("RequestURI = %s, want %s", r.RequestURI, tt.expectedURI)
|
||||
}
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Provisioners(tt.args...)
|
||||
@@ -666,12 +666,12 @@ func TestClient_ProvisionerKey(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
expected := "/provisioners/" + tt.kid + "/encrypted-key"
|
||||
if req.RequestURI != expected {
|
||||
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
|
||||
if r.RequestURI != expected {
|
||||
t.Errorf("RequestURI = %s, want %s", r.RequestURI, expected)
|
||||
}
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.ProvisionerKey(tt.kid)
|
||||
@@ -720,8 +720,8 @@ func TestClient_Roots(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Roots()
|
||||
@@ -769,8 +769,8 @@ func TestClient_Federation(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.Federation()
|
||||
@@ -820,8 +820,8 @@ func TestClient_SSHRoots(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.SSHRoots()
|
||||
@@ -912,8 +912,8 @@ func TestClient_RootFingerprint(t *testing.T) {
|
||||
c, err := NewClient(tt.server.URL, WithTransport(tr))
|
||||
require.NoError(t, err)
|
||||
|
||||
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.RootFingerprint()
|
||||
@@ -970,8 +970,8 @@ func TestClient_SSHBastion(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
render.JSONStatus(w, tt.response, tt.responseCode)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
render.JSONStatus(w, r, tt.response, tt.responseCode)
|
||||
})
|
||||
|
||||
got, err := c.SSHBastion(tt.request)
|
||||
|
||||
@@ -97,7 +97,7 @@ func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc
|
||||
func Get(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
fail(w, fmt.Errorf("invalid scep get request: %w", err))
|
||||
fail(w, r, fmt.Errorf("invalid scep get request: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -116,18 +116,18 @@ func Get(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fail(w, fmt.Errorf("scep get request failed: %w", err))
|
||||
fail(w, r, fmt.Errorf("scep get request failed: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
writeResponse(w, res)
|
||||
writeResponse(w, r, res)
|
||||
}
|
||||
|
||||
// Post handles all SCEP POST requests
|
||||
func Post(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
fail(w, fmt.Errorf("invalid scep post request: %w", err))
|
||||
fail(w, r, fmt.Errorf("invalid scep post request: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,11 +140,11 @@ func Post(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fail(w, fmt.Errorf("scep post request failed: %w", err))
|
||||
fail(w, r, fmt.Errorf("scep post request failed: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
writeResponse(w, res)
|
||||
writeResponse(w, r, res)
|
||||
}
|
||||
|
||||
func decodeRequest(r *http.Request) (request, error) {
|
||||
@@ -274,7 +274,7 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
||||
name := chi.URLParam(r, "provisionerName")
|
||||
provisionerName, err := url.PathUnescape(name)
|
||||
if err != nil {
|
||||
fail(w, fmt.Errorf("error url unescaping provisioner name '%s'", name))
|
||||
fail(w, r, fmt.Errorf("error url unescaping provisioner name '%s'", name))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -282,13 +282,13 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
||||
auth := authority.MustFromContext(ctx)
|
||||
p, err := auth.LoadProvisionerByName(provisionerName)
|
||||
if err != nil {
|
||||
fail(w, err)
|
||||
fail(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov, ok := p.(*provisioner.SCEP)
|
||||
if !ok {
|
||||
fail(w, errors.New("provisioner must be of type SCEP"))
|
||||
fail(w, r, errors.New("provisioner must be of type SCEP"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -430,9 +430,9 @@ func formatCapabilities(caps []string) []byte {
|
||||
}
|
||||
|
||||
// writeResponse writes a SCEP response back to the SCEP client.
|
||||
func writeResponse(w http.ResponseWriter, res Response) {
|
||||
func writeResponse(w http.ResponseWriter, r *http.Request, res Response) {
|
||||
if res.Error != nil {
|
||||
log.Error(w, res.Error)
|
||||
log.Error(w, r, res.Error)
|
||||
}
|
||||
|
||||
if res.Certificate != nil {
|
||||
@@ -443,8 +443,8 @@ func writeResponse(w http.ResponseWriter, res Response) {
|
||||
_, _ = w.Write(res.Data)
|
||||
}
|
||||
|
||||
func fail(w http.ResponseWriter, err error) {
|
||||
log.Error(w, err)
|
||||
func fail(w http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Error(w, r, err)
|
||||
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user