diff --git a/audit/entry_formatter.go b/audit/entry_formatter.go index 81113c9a18..7e349513e5 100644 --- a/audit/entry_formatter.go +++ b/audit/entry_formatter.go @@ -107,6 +107,13 @@ func (f *EntryFormatter) Process(ctx context.Context, e *eventlogger.Event) (*ev data.Request.Headers = adjustedHeaders } + // If the request contains a Server-Side Consistency Token (SSCT), and we + // have an auth response, overwrite the existing client token with the SSCT, + // so that the SSCT appears in the audit log for this entry. + if data.Request != nil && data.Request.InboundSSCToken != "" && data.Auth != nil { + data.Auth.ClientToken = data.Request.InboundSSCToken + } + var result []byte switch a.Subtype { diff --git a/changelog/25443.txt b/changelog/25443.txt new file mode 100644 index 0000000000..301824d810 --- /dev/null +++ b/changelog/25443.txt @@ -0,0 +1,3 @@ +```release-note:bug +audit: Resolve potential race condition when auditing entries which use SSCT. +``` \ No newline at end of file diff --git a/vault/audit_broker.go b/vault/audit_broker.go index f783106fd4..19438643fb 100644 --- a/vault/audit_broker.go +++ b/vault/audit_broker.go @@ -209,6 +209,9 @@ func (a *AuditBroker) GetHash(ctx context.Context, name string, input string) (s // LogRequest is used to ensure all the audit backends have an opportunity to // log the given request and that *at least one* succeeds. func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput) (ret error) { + a.RLock() + defer a.RUnlock() + // If no backends are registered then we have no devices to log the request. if len(a.backends) < 1 { return nil @@ -216,19 +219,6 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput) (ret defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now()) - a.RLock() - defer a.RUnlock() - - if in.Request.InboundSSCToken != "" { - if in.Auth != nil { - reqAuthToken := in.Auth.ClientToken - in.Auth.ClientToken = in.Request.InboundSSCToken - defer func() { - in.Auth.ClientToken = reqAuthToken - }() - } - } - var retErr *multierror.Error defer func() { @@ -245,11 +235,6 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput) (ret metrics.IncrCounter([]string{"audit", "log_request_failure"}, failure) }() - headers := in.Request.Headers - defer func() { - in.Request.Headers = headers - }() - e, err := audit.NewEvent(audit.RequestType) if err != nil { retErr = multierror.Append(retErr, err) @@ -299,6 +284,9 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput) (ret // LogResponse is used to ensure all the audit backends have an opportunity to // log the given response and that *at least one* succeeds. func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput) (ret error) { + a.RLock() + defer a.RUnlock() + // If no backends are registered then we have no devices to send audit entries to. if len(a.backends) < 1 { return nil @@ -306,15 +294,6 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput) (re defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now()) - a.RLock() - defer a.RUnlock() - - if in.Request.InboundSSCToken != "" && in.Auth != nil { - reqAuthToken := in.Auth.ClientToken - in.Auth.ClientToken = in.Request.InboundSSCToken - defer func() { in.Auth.ClientToken = reqAuthToken }() - } - var retErr *multierror.Error defer func() { @@ -331,11 +310,6 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput) (re metrics.IncrCounter([]string{"audit", "log_response_failure"}, failure) }() - headers := in.Request.Headers - defer func() { - in.Request.Headers = headers - }() - e, err := audit.NewEvent(audit.ResponseType) if err != nil { retErr = multierror.Append(retErr, err)