mirror of
https://github.com/outbackdingo/certificates.git
synced 2026-01-27 10:18:34 +00:00
Add support for user provider X-Request-Id header value
This commit is contained in:
16
ca/client.go
16
ca/client.go
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/ca/client"
|
||||
"github.com/smallstep/certificates/ca/identity"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/cli-utils/step"
|
||||
@@ -105,10 +106,19 @@ func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, b
|
||||
const requestIDHeader = "X-Request-Id"
|
||||
|
||||
// enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's
|
||||
// empty, it'll generate a new request ID and set the header.
|
||||
// empty, the context is searched for a request ID. If that's also empty, a new
|
||||
// request ID is generated.
|
||||
func enforceRequestID(r *http.Request) {
|
||||
if r.Header.Get(requestIDHeader) == "" {
|
||||
r.Header.Set(requestIDHeader, xid.New().String())
|
||||
requestID := r.Header.Get(requestIDHeader)
|
||||
if requestID == "" {
|
||||
if reqID, ok := client.GetRequestID(r.Context()); ok && reqID != "" {
|
||||
// TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been
|
||||
// used before by the client (unless it's a retry for the same request)?
|
||||
requestID = reqID
|
||||
} else {
|
||||
requestID = xid.New().String()
|
||||
}
|
||||
r.Header.Set(requestIDHeader, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
17
ca/client/requestid.go
Normal file
17
ca/client/requestid.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package client
|
||||
|
||||
import "context"
|
||||
|
||||
type requestIDKey struct{}
|
||||
|
||||
// WithRequestID returns a new context with the given requestID added to the
|
||||
// context.
|
||||
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
||||
return context.WithValue(ctx, requestIDKey{}, requestID)
|
||||
}
|
||||
|
||||
// GetRequestID returns the request id from the context if it exists.
|
||||
func GetRequestID(ctx context.Context) (string, bool) {
|
||||
v, ok := ctx.Value(requestIDKey{}).(string)
|
||||
return v, ok
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/ca"
|
||||
"github.com/smallstep/certificates/ca/client"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -57,8 +58,8 @@ func TestXxx(t *testing.T) {
|
||||
c, err := ca.New(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// instantiate a client for the CA
|
||||
client, err := ca.NewClient(
|
||||
// instantiate a client for the CA running at the random address
|
||||
caClient, err := ca.NewClient(
|
||||
fmt.Sprintf("https://%s", randomAddress),
|
||||
ca.WithRootFile(rootFilepath),
|
||||
)
|
||||
@@ -75,12 +76,12 @@ func TestXxx(t *testing.T) {
|
||||
|
||||
// require OK health response as the baseline
|
||||
ctx := context.Background()
|
||||
healthResponse, err := client.HealthWithContext(ctx)
|
||||
healthResponse, err := caClient.HealthWithContext(ctx)
|
||||
assert.NoError(t, err)
|
||||
require.Equal(t, "ok", healthResponse.Status)
|
||||
assert.Equal(t, "ok", healthResponse.Status)
|
||||
|
||||
// expect an error when retrieving an invalid root
|
||||
rootResponse, err := client.RootWithContext(ctx, "invalid")
|
||||
rootResponse, err := caClient.RootWithContext(ctx, "invalid")
|
||||
if assert.Error(t, err) {
|
||||
apiErr := &errs.Error{}
|
||||
if assert.ErrorAs(t, err, &apiErr) {
|
||||
@@ -94,6 +95,21 @@ func TestXxx(t *testing.T) {
|
||||
}
|
||||
assert.Nil(t, rootResponse)
|
||||
|
||||
// expect an error when retrieving an invalid root and provided request ID
|
||||
rootResponse, err = caClient.RootWithContext(client.WithRequestID(ctx, "reqID"), "invalid")
|
||||
if assert.Error(t, err) {
|
||||
apiErr := &errs.Error{}
|
||||
if assert.ErrorAs(t, err, &apiErr) {
|
||||
assert.Equal(t, 404, apiErr.StatusCode())
|
||||
assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error())
|
||||
assert.Equal(t, "reqID", apiErr.RequestID)
|
||||
|
||||
// TODO: include the below error in the JSON? It's currently only output to the CA logs
|
||||
//assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg)
|
||||
}
|
||||
}
|
||||
assert.Nil(t, rootResponse)
|
||||
|
||||
// done testing; stop and wait for the server to quit
|
||||
err = c.Stop()
|
||||
require.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user