mirror of
https://github.com/outbackdingo/certificates.git
synced 2026-01-27 10:18:34 +00:00
Decouple request ID middleware from logging middleware
This commit is contained in:
@@ -15,7 +15,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/internal/requestid"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"github.com/smallstep/certificates/webhook"
|
||||
"go.step.sm/linkedca"
|
||||
@@ -171,9 +171,8 @@ retry:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestID, ok := logging.GetRequestID(ctx)
|
||||
if ok {
|
||||
req.Header.Set("X-Request-ID", requestID)
|
||||
if requestID, ok := requestid.FromContext(ctx); ok {
|
||||
req.Header.Set("X-Request-Id", requestID)
|
||||
}
|
||||
|
||||
secret, err := base64.StdEncoding.DecodeString(w.Secret)
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/internal/requestid"
|
||||
"github.com/smallstep/certificates/webhook"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -101,10 +101,10 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// withRequestID is a helper that calls into [logging.WithRequestID] and returns
|
||||
// a new context with the requestID added to the provided context.
|
||||
// withRequestID is a helper that calls into [requestid.NewContext] and returns
|
||||
// a new context with the requestID added.
|
||||
func withRequestID(ctx context.Context, requestID string) context.Context {
|
||||
return logging.WithRequestID(ctx, requestID)
|
||||
return requestid.NewContext(ctx, requestID)
|
||||
}
|
||||
|
||||
func TestWebhookController_Enrich(t *testing.T) {
|
||||
|
||||
7
ca/ca.go
7
ca/ca.go
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/internal/metrix"
|
||||
"github.com/smallstep/certificates/internal/requestid"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/monitoring"
|
||||
"github.com/smallstep/certificates/scep"
|
||||
@@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||
}
|
||||
|
||||
// Add logger if configured
|
||||
var legacyTraceHeader string
|
||||
if len(cfg.Logger) > 0 {
|
||||
logger, err := logging.New("ca", cfg.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
legacyTraceHeader = logger.GetTraceHeader()
|
||||
handler = logger.Middleware(handler)
|
||||
insecureHandler = logger.Middleware(insecureHandler)
|
||||
}
|
||||
|
||||
// always use request ID middleware; traceHeader is provided for backwards compatibility (for now)
|
||||
handler = requestid.New(legacyTraceHeader).Middleware(handler)
|
||||
insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler)
|
||||
|
||||
// Create context with all the necessary values.
|
||||
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@ package errs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestError_MarshalJSON(t *testing.T) {
|
||||
@@ -27,13 +28,14 @@ func TestError_MarshalJSON(t *testing.T) {
|
||||
Err: tt.fields.Err,
|
||||
}
|
||||
got, err := e.MarshalJSON()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Error.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, got)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Error.MarshalJSON() = %s, want %s", got, tt.want)
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -54,13 +56,14 @@ func TestError_UnmarshalJSON(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := new(Error)
|
||||
if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
//nolint:govet // best option
|
||||
if !reflect.DeepEqual(tt.expected, e) {
|
||||
t.Errorf("Error.UnmarshalJSON() wants = %+v, got %+v", tt.expected, e)
|
||||
err := e.UnmarshalJSON(tt.args.data)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, e)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
82
internal/requestid/requestid.go
Normal file
82
internal/requestid/requestid.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
const (
|
||||
// requestIDHeader is the header name used for propagating request IDs. If
|
||||
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
|
||||
// header. It'll always be used in response and set to the request ID.
|
||||
requestIDHeader = "X-Request-Id"
|
||||
|
||||
// defaultTraceHeader is the default Smallstep tracing header that's currently
|
||||
// in use. It is used as a fallback to retrieve a request ID from, if the
|
||||
// "X-Request-Id" request header is not set.
|
||||
defaultTraceHeader = "X-Smallstep-Id"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
legacyTraceHeader string
|
||||
}
|
||||
|
||||
// New creates a new request ID [handler]. It takes a trace header,
|
||||
// which is used keep the legacy behavior intact, which relies on the
|
||||
// X-Smallstep-Id header instead of X-Request-Id.
|
||||
func New(legacyTraceHeader string) *Handler {
|
||||
if legacyTraceHeader == "" {
|
||||
legacyTraceHeader = defaultTraceHeader
|
||||
}
|
||||
|
||||
return &Handler{legacyTraceHeader: legacyTraceHeader}
|
||||
}
|
||||
|
||||
// Middleware wraps an [http.Handler] with request ID extraction
|
||||
// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id
|
||||
// header if not set. If both are not set, a new request ID is generated.
|
||||
// In all cases, the request ID is added to the request context, and
|
||||
// set to be reflected in the response.
|
||||
func (h *Handler) Middleware(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, req *http.Request) {
|
||||
requestID := req.Header.Get(requestIDHeader)
|
||||
if requestID == "" {
|
||||
requestID = req.Header.Get(h.legacyTraceHeader)
|
||||
}
|
||||
|
||||
if requestID == "" {
|
||||
requestID = newRequestID()
|
||||
req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior
|
||||
}
|
||||
|
||||
// immediately set the request ID to be reflected in the response
|
||||
w.Header().Set(requestIDHeader, requestID)
|
||||
|
||||
// continue down the handler chain
|
||||
ctx := NewContext(req.Context(), requestID)
|
||||
next.ServeHTTP(w, req.WithContext(ctx))
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
// newRequestID creates a new request ID using github.com/rs/xid.
|
||||
func newRequestID() string {
|
||||
return xid.New().String()
|
||||
}
|
||||
|
||||
type requestIDKey struct{}
|
||||
|
||||
// NewContext returns a new context with the given request ID added to the
|
||||
// context.
|
||||
func NewContext(ctx context.Context, requestID string) context.Context {
|
||||
return context.WithValue(ctx, requestIDKey{}, requestID)
|
||||
}
|
||||
|
||||
// FromContext returns the request ID from the context if it exists and
|
||||
// is not the empty value.
|
||||
func FromContext(ctx context.Context) (string, bool) {
|
||||
v, ok := ctx.Value(requestIDKey{}).(string)
|
||||
return v, ok && v != ""
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package logging
|
||||
package requestid
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -10,12 +10,13 @@ import (
|
||||
)
|
||||
|
||||
func newRequest(t *testing.T) *http.Request {
|
||||
t.Helper()
|
||||
r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
|
||||
require.NoError(t, err)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
func Test_Middleware(t *testing.T) {
|
||||
requestWithID := newRequest(t)
|
||||
requestWithID.Header.Set("X-Request-Id", "reqID")
|
||||
requestWithoutID := newRequest(t)
|
||||
@@ -23,20 +24,19 @@ func TestRequestID(t *testing.T) {
|
||||
requestWithEmptyHeader.Header.Set("X-Request-Id", "")
|
||||
requestWithSmallstepID := newRequest(t)
|
||||
requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headerName string
|
||||
handler http.HandlerFunc
|
||||
req *http.Request
|
||||
name string
|
||||
traceHeader string
|
||||
next http.HandlerFunc
|
||||
req *http.Request
|
||||
}{
|
||||
{
|
||||
name: "default-request-id",
|
||||
headerName: defaultTraceIDHeader,
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
name: "default-request-id",
|
||||
traceHeader: defaultTraceHeader,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
|
||||
assert.Equal(t, "reqID", r.Header.Get("X-Request-Id"))
|
||||
reqID, ok := GetRequestID(r.Context())
|
||||
reqID, ok := FromContext(r.Context())
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, "reqID", reqID)
|
||||
}
|
||||
@@ -45,13 +45,13 @@ func TestRequestID(t *testing.T) {
|
||||
req: requestWithID,
|
||||
},
|
||||
{
|
||||
name: "no-request-id",
|
||||
headerName: "X-Request-Id",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
name: "no-request-id",
|
||||
traceHeader: "X-Request-Id",
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
|
||||
value := r.Header.Get("X-Request-Id")
|
||||
assert.NotEmpty(t, value)
|
||||
reqID, ok := GetRequestID(r.Context())
|
||||
reqID, ok := FromContext(r.Context())
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, value, reqID)
|
||||
}
|
||||
@@ -60,13 +60,13 @@ func TestRequestID(t *testing.T) {
|
||||
req: requestWithoutID,
|
||||
},
|
||||
{
|
||||
name: "empty-header-name",
|
||||
headerName: "",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
name: "empty-header",
|
||||
traceHeader: "",
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Empty(t, r.Header.Get("X-Request-Id"))
|
||||
value := r.Header.Get("X-Smallstep-Id")
|
||||
assert.NotEmpty(t, value)
|
||||
reqID, ok := GetRequestID(r.Context())
|
||||
reqID, ok := FromContext(r.Context())
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, value, reqID)
|
||||
}
|
||||
@@ -75,12 +75,12 @@ func TestRequestID(t *testing.T) {
|
||||
req: requestWithEmptyHeader,
|
||||
},
|
||||
{
|
||||
name: "fallback-header-name",
|
||||
headerName: defaultTraceIDHeader,
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
name: "fallback-header-name",
|
||||
traceHeader: defaultTraceHeader,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Empty(t, r.Header.Get("X-Request-Id"))
|
||||
assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id"))
|
||||
reqID, ok := GetRequestID(r.Context())
|
||||
reqID, ok := FromContext(r.Context())
|
||||
if assert.True(t, ok) {
|
||||
assert.Equal(t, "smallstepID", reqID)
|
||||
}
|
||||
@@ -91,8 +91,11 @@ func TestRequestID(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := RequestID(tt.headerName)
|
||||
h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req)
|
||||
handler := New(tt.traceHeader).Middleware(tt.next)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, tt.req)
|
||||
assert.NotEmpty(t, w.Header().Get("X-Request-Id"))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,82 +2,18 @@ package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
// RequestIDKey is the context key that should store the request identifier.
|
||||
RequestIDKey key = iota
|
||||
// UserIDKey is the context key that should store the user identifier.
|
||||
UserIDKey
|
||||
)
|
||||
|
||||
// NewRequestID creates a new request id using github.com/rs/xid.
|
||||
func NewRequestID() string {
|
||||
return xid.New().String()
|
||||
}
|
||||
|
||||
// requestIDHeader is the header name used for propagating request IDs. If
|
||||
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
|
||||
// header. It'll always be used in response and set to the request ID.
|
||||
const requestIDHeader = "X-Request-Id"
|
||||
|
||||
// RequestID returns a new middleware that obtains the current request ID
|
||||
// and sets it in the context. It first tries to read the request ID from
|
||||
// the "X-Request-Id" header. If that's not set, it tries to read it from
|
||||
// the provided header name. If the header does not exist or its value is
|
||||
// the empty string, it uses github.com/rs/xid to create a new one.
|
||||
func RequestID(headerName string) func(next http.Handler) http.Handler {
|
||||
if headerName == "" {
|
||||
headerName = defaultTraceIDHeader
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, req *http.Request) {
|
||||
requestID := req.Header.Get(requestIDHeader)
|
||||
if requestID == "" {
|
||||
requestID = req.Header.Get(headerName)
|
||||
}
|
||||
|
||||
if requestID == "" {
|
||||
requestID = NewRequestID()
|
||||
req.Header.Set(headerName, requestID)
|
||||
}
|
||||
|
||||
// immediately set the request ID to be reflected in the response
|
||||
w.Header().Set(requestIDHeader, requestID)
|
||||
|
||||
// continue down the handler chain
|
||||
ctx := WithRequestID(req.Context(), requestID)
|
||||
next.ServeHTTP(w, req.WithContext(ctx))
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
}
|
||||
|
||||
// WithRequestID returns a new context with the given requestID added to the
|
||||
// context.
|
||||
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
||||
return context.WithValue(ctx, RequestIDKey, requestID)
|
||||
}
|
||||
|
||||
// GetRequestID returns the request id from the context if it exists.
|
||||
func GetRequestID(ctx context.Context) (string, bool) {
|
||||
v, ok := ctx.Value(RequestIDKey).(string)
|
||||
return v, ok
|
||||
}
|
||||
type userIDKey struct{}
|
||||
|
||||
// WithUserID decodes the token, extracts the user from the payload and stores
|
||||
// it in the context.
|
||||
func WithUserID(ctx context.Context, userID string) context.Context {
|
||||
return context.WithValue(ctx, UserIDKey, userID)
|
||||
return context.WithValue(ctx, userIDKey{}, userID)
|
||||
}
|
||||
|
||||
// GetUserID returns the request id from the context if it exists.
|
||||
func GetUserID(ctx context.Context) (string, bool) {
|
||||
v, ok := ctx.Value(UserIDKey).(string)
|
||||
return v, ok
|
||||
v, ok := ctx.Value(userIDKey{}).(string)
|
||||
return v, ok && v != ""
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/smallstep/certificates/internal/requestid"
|
||||
)
|
||||
|
||||
// LoggerHandler creates a logger handler
|
||||
@@ -29,16 +30,15 @@ type options struct {
|
||||
|
||||
// NewLoggerHandler returns the given http.Handler with the logger integrated.
|
||||
func NewLoggerHandler(name string, logger *Logger, next http.Handler) http.Handler {
|
||||
h := RequestID(logger.GetTraceHeader())
|
||||
onlyTraceHealthEndpoint, _ := strconv.ParseBool(os.Getenv("STEP_LOGGER_ONLY_TRACE_HEALTH_ENDPOINT"))
|
||||
return h(&LoggerHandler{
|
||||
return &LoggerHandler{
|
||||
name: name,
|
||||
logger: logger.GetImpl(),
|
||||
options: options{
|
||||
onlyTraceHealthEndpoint: onlyTraceHealthEndpoint,
|
||||
},
|
||||
next: next,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler and call to the handler to log with a
|
||||
@@ -54,14 +54,14 @@ func (l *LoggerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// writeEntry writes to the Logger writer the request information in the logger.
|
||||
func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Time, d time.Duration) {
|
||||
var reqID, user string
|
||||
var requestID, userID string
|
||||
|
||||
ctx := r.Context()
|
||||
if v, ok := ctx.Value(RequestIDKey).(string); ok && v != "" {
|
||||
reqID = v
|
||||
if v, ok := requestid.FromContext(ctx); ok {
|
||||
requestID = v
|
||||
}
|
||||
if v, ok := ctx.Value(UserIDKey).(string); ok && v != "" {
|
||||
user = v
|
||||
if v, ok := GetUserID(ctx); ok && v != "" {
|
||||
userID = v
|
||||
}
|
||||
|
||||
// Remote hostname
|
||||
@@ -85,10 +85,10 @@ func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Tim
|
||||
status := w.StatusCode()
|
||||
|
||||
fields := logrus.Fields{
|
||||
"request-id": reqID,
|
||||
"request-id": requestID,
|
||||
"remote-address": addr,
|
||||
"name": l.name,
|
||||
"user-id": user,
|
||||
"user-id": userID,
|
||||
"time": t.Format(time.RFC3339),
|
||||
"duration-ns": d.Nanoseconds(),
|
||||
"duration": d.String(),
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/newrelic/go-agent/v3/newrelic"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/internal/requestid"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
@@ -82,7 +83,7 @@ func newRelicMiddleware(app *newrelic.Application) Middleware {
|
||||
txn.AddAttribute("httpResponseCode", strconv.Itoa(status))
|
||||
|
||||
// Add custom attributes
|
||||
if v, ok := logging.GetRequestID(r.Context()); ok {
|
||||
if v, ok := requestid.FromContext(r.Context()); ok {
|
||||
txn.AddAttribute("request.id", v)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user