From 5c2572c44397bbaf77a4e744e22a43aeb3dc30cf Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 01:55:35 +0100 Subject: [PATCH] Add support for user provider `X-Request-Id` header value --- ca/client.go | 16 +++++++++++++--- ca/client/requestid.go | 17 +++++++++++++++++ test/e2e/requestid_test.go | 26 +++++++++++++++++++++----- 3 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 ca/client/requestid.go diff --git a/ca/client.go b/ca/client.go index 8930d8ee..d7ec2875 100644 --- a/ca/client.go +++ b/ca/client.go @@ -28,6 +28,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/errs" "go.step.sm/cli-utils/step" @@ -105,10 +106,19 @@ func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, b const requestIDHeader = "X-Request-Id" // enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's -// empty, it'll generate a new request ID and set the header. +// empty, the context is searched for a request ID. If that's also empty, a new +// request ID is generated. func enforceRequestID(r *http.Request) { - if r.Header.Get(requestIDHeader) == "" { - r.Header.Set(requestIDHeader, xid.New().String()) + requestID := r.Header.Get(requestIDHeader) + if requestID == "" { + if reqID, ok := client.GetRequestID(r.Context()); ok && reqID != "" { + // TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been + // used before by the client (unless it's a retry for the same request)? + requestID = reqID + } else { + requestID = xid.New().String() + } + r.Header.Set(requestIDHeader, requestID) } } diff --git a/ca/client/requestid.go b/ca/client/requestid.go new file mode 100644 index 00000000..de92f8c0 --- /dev/null +++ b/ca/client/requestid.go @@ -0,0 +1,17 @@ +package client + +import "context" + +type requestIDKey struct{} + +// WithRequestID returns a new context with the given requestID added to the +// context. +func WithRequestID(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, requestIDKey{}, requestID) +} + +// GetRequestID returns the request id from the context if it exists. +func GetRequestID(ctx context.Context) (string, bool) { + v, ok := ctx.Value(requestIDKey{}).(string) + return v, ok +} diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go index 7eccb4f4..a1afd423 100644 --- a/test/e2e/requestid_test.go +++ b/test/e2e/requestid_test.go @@ -11,6 +11,7 @@ import ( "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/ca" + "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/errs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -57,8 +58,8 @@ func TestXxx(t *testing.T) { c, err := ca.New(cfg) require.NoError(t, err) - // instantiate a client for the CA - client, err := ca.NewClient( + // instantiate a client for the CA running at the random address + caClient, err := ca.NewClient( fmt.Sprintf("https://%s", randomAddress), ca.WithRootFile(rootFilepath), ) @@ -75,12 +76,12 @@ func TestXxx(t *testing.T) { // require OK health response as the baseline ctx := context.Background() - healthResponse, err := client.HealthWithContext(ctx) + healthResponse, err := caClient.HealthWithContext(ctx) assert.NoError(t, err) - require.Equal(t, "ok", healthResponse.Status) + assert.Equal(t, "ok", healthResponse.Status) // expect an error when retrieving an invalid root - rootResponse, err := client.RootWithContext(ctx, "invalid") + rootResponse, err := caClient.RootWithContext(ctx, "invalid") if assert.Error(t, err) { apiErr := &errs.Error{} if assert.ErrorAs(t, err, &apiErr) { @@ -94,6 +95,21 @@ func TestXxx(t *testing.T) { } assert.Nil(t, rootResponse) + // expect an error when retrieving an invalid root and provided request ID + rootResponse, err = caClient.RootWithContext(client.WithRequestID(ctx, "reqID"), "invalid") + if assert.Error(t, err) { + apiErr := &errs.Error{} + if assert.ErrorAs(t, err, &apiErr) { + assert.Equal(t, 404, apiErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) + assert.Equal(t, "reqID", apiErr.RequestID) + + // TODO: include the below error in the JSON? It's currently only output to the CA logs + //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) + } + } + assert.Nil(t, rootResponse) + // done testing; stop and wait for the server to quit err = c.Stop() require.NoError(t, err)