diff --git a/internal/observability/event/audit_sink_socket.go b/internal/observability/event/audit_sink_socket.go new file mode 100644 index 0000000000..f76de6c634 --- /dev/null +++ b/internal/observability/event/audit_sink_socket.go @@ -0,0 +1,196 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package event + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + + "github.com/hashicorp/eventlogger" +) + +// AuditSocketSink is a sink node which handles writing audit events to socket. +type AuditSocketSink struct { + format auditFormat + address string + socketType string + maxDuration time.Duration + socketLock sync.RWMutex + connection net.Conn +} + +// NewAuditSocketSink should be used to create a new AuditSocketSink. +// Accepted options: WithMaxDuration and WithSocketType. +func NewAuditSocketSink(format auditFormat, address string, opt ...Option) (*AuditSocketSink, error) { + const op = "event.NewAuditSocketSink" + + opts, err := getOpts(opt...) + if err != nil { + return nil, fmt.Errorf("%s: error applying options: %w", op, err) + } + + sink := &AuditSocketSink{ + format: format, + address: address, + socketType: opts.withSocketType, + maxDuration: opts.withMaxDuration, + socketLock: sync.RWMutex{}, + connection: nil, + } + + return sink, nil +} + +// Process handles writing the event to the socket. +func (s *AuditSocketSink) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) { + const op = "event.(AuditSocketSink).Process" + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + s.socketLock.Lock() + defer s.socketLock.Unlock() + + if e == nil { + return nil, fmt.Errorf("%s: event is nil: %w", op, ErrInvalidParameter) + } + + formatted, found := e.Format(s.format.String()) + if !found { + return nil, fmt.Errorf("%s: unable to retrieve event formatted as %q", op, s.format) + } + + // Try writing and return early if successful. + err := s.write(ctx, formatted) + if err == nil { + return nil, nil + } + + // We will try to reconnect and retry a single write. + reconErr := s.reconnect(ctx) + switch { + case reconErr != nil: + // Add the reconnection error to the existing error. + err = multierror.Append(err, reconErr) + default: + err = s.write(ctx, formatted) + } + + // Format the error nicely if we need to return one. + if err != nil { + err = fmt.Errorf("%s: error writing to socket: %w", op, err) + } + + // return nil for the event to indicate the pipeline is complete. + return nil, err +} + +// Reopen handles reopening the connection for the socket sink. +func (s *AuditSocketSink) Reopen() error { + const op = "event.(AuditSocketSink).Reopen" + + s.socketLock.Lock() + defer s.socketLock.Unlock() + + err := s.reconnect(nil) + if err != nil { + return fmt.Errorf("%s: error reconnecting: %w", op, err) + } + + return nil +} + +// Type describes the type of this node (sink). +func (s *AuditSocketSink) Type() eventlogger.NodeType { + return eventlogger.NodeTypeSink +} + +// connect attempts to establish a connection using the socketType and address. +func (s *AuditSocketSink) connect(ctx context.Context) error { + const op = "event.(AuditSocketSink).connect" + + // If we're already connected, we should have disconnected first. + if s.connection != nil { + return nil + } + + timeoutContext, cancel := context.WithTimeout(ctx, s.maxDuration) + defer cancel() + + dialer := net.Dialer{} + conn, err := dialer.DialContext(timeoutContext, s.socketType, s.address) + if err != nil { + return fmt.Errorf("%s: error connecting to %q address %q: %w", op, s.socketType, s.address, err) + } + + s.connection = conn + + return nil +} + +// disconnect attempts to close and clear an existing connection. +func (s *AuditSocketSink) disconnect() error { + const op = "event.(AuditSocketSink).disconnect" + + // If we're already disconnected, we can return early. + if s.connection == nil { + return nil + } + + err := s.connection.Close() + if err != nil { + return fmt.Errorf("%s: error closing connection: %w", op, err) + } + s.connection = nil + + return nil +} + +// reconnect attempts to disconnect and then connect to the configured socketType and address. +func (s *AuditSocketSink) reconnect(ctx context.Context) error { + const op = "event.(AuditSocketSink).reconnect" + + err := s.disconnect() + if err != nil { + return fmt.Errorf("%s: error disconnecting: %w", op, err) + } + + err = s.connect(ctx) + if err != nil { + return fmt.Errorf("%s: error connecting: %w", op, err) + } + + return nil +} + +// write attempts to write the specified data using the established connection. +func (s *AuditSocketSink) write(ctx context.Context, data []byte) error { + const op = "event.(AuditSocketSink).write" + + // Ensure we're connected. + err := s.connect(ctx) + if err != nil { + return fmt.Errorf("%s: connection error: %w", op, err) + } + + err = s.connection.SetWriteDeadline(time.Now().Add(s.maxDuration)) + if err != nil { + return fmt.Errorf("%s: unable to set write deadline: %w", op, err) + } + + _, err = s.connection.Write(data) + if err != nil { + return fmt.Errorf("%s: unable to write to socket: %w", op, err) + } + + return nil +} diff --git a/internal/observability/event/options.go b/internal/observability/event/options.go index f43f55af9e..64f6147d32 100644 --- a/internal/observability/event/options.go +++ b/internal/observability/event/options.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "github.com/hashicorp/go-secure-stdlib/parseutil" + "github.com/hashicorp/go-uuid" ) @@ -19,22 +21,26 @@ type Option func(*options) error // options are used to represent configuration for an Event. type options struct { - withID string - withNow time.Time - withSubtype auditSubtype - withFormat auditFormat - withFileMode *os.FileMode - withPrefix string - withFacility string - withTag string + withID string + withNow time.Time + withSubtype auditSubtype + withFormat auditFormat + withFileMode *os.FileMode + withPrefix string + withFacility string + withTag string + withSocketType string + withMaxDuration time.Duration } // getDefaultOptions returns options with their default values. func getDefaultOptions() options { return options{ - withNow: time.Now(), - withFacility: "AUTH", - withTag: "vault", + withNow: time.Now(), + withFacility: "AUTH", + withTag: "vault", + withSocketType: "tcp", + withMaxDuration: 2 * time.Second, } } @@ -204,3 +210,36 @@ func WithTag(tag string) Option { return nil } } + +// WithSocketType provides an option to represent the socket type for a socket sink. +func WithSocketType(socketType string) Option { + return func(o *options) error { + socketType = strings.TrimSpace(socketType) + + if socketType != "" { + o.withSocketType = socketType + } + + return nil + } +} + +// WithMaxDuration provides an option to represent the max duration for writing to a socket sink. +func WithMaxDuration(duration string) Option { + return func(o *options) error { + duration = strings.TrimSpace(duration) + + if duration == "" { + return nil + } + + parsed, err := parseutil.ParseDurationSecond(duration) + if err != nil { + return err + } + + o.withMaxDuration = parsed + + return nil + } +} diff --git a/internal/observability/event/options_test.go b/internal/observability/event/options_test.go index c0181e3a8d..0917beff87 100644 --- a/internal/observability/event/options_test.go +++ b/internal/observability/event/options_test.go @@ -273,6 +273,98 @@ func TestOptions_WithTag(t *testing.T) { } } +// TestOptions_WithSocketType exercises WithSocketType option to ensure it performs as expected. +func TestOptions_WithSocketType(t *testing.T) { + tests := map[string]struct { + Value string + ExpectedValue string + }{ + "empty": { + Value: "", + ExpectedValue: "", + }, + "whitespace": { + Value: " ", + ExpectedValue: "", + }, + "value": { + Value: "juan", + ExpectedValue: "juan", + }, + "spacey-value": { + Value: " juan ", + ExpectedValue: "juan", + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + options := &options{} + applyOption := WithSocketType(tc.Value) + err := applyOption(options) + require.NoError(t, err) + require.Equal(t, tc.ExpectedValue, options.withSocketType) + }) + } +} + +// TestOptions_WithMaxDuration exercises WithMaxDuration option to ensure it performs as expected. +func TestOptions_WithMaxDuration(t *testing.T) { + tests := map[string]struct { + Value string + ExpectedValue time.Duration + IsErrorExpected bool + ExpectedErrorMessage string + }{ + "empty-gives-default": { + Value: "", + }, + "whitespace-give-default": { + Value: " ", + }, + "bad-value": { + Value: "juan", + IsErrorExpected: true, + ExpectedErrorMessage: "time: invalid duration \"juan\"", + }, + "bad-spacey-value": { + Value: " juan ", + IsErrorExpected: true, + ExpectedErrorMessage: "time: invalid duration \"juan\"", + }, + "duration-2s": { + Value: "2s", + ExpectedValue: 2 * time.Second, + }, + "duration-2m": { + Value: "2m", + ExpectedValue: 2 * time.Minute, + }, + } + + for name, tc := range tests { + name := name + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + options := &options{} + applyOption := WithMaxDuration(tc.Value) + err := applyOption(options) + switch { + case tc.IsErrorExpected: + require.Error(t, err) + require.EqualError(t, err, tc.ExpectedErrorMessage) + default: + require.NoError(t, err) + require.Equal(t, tc.ExpectedValue, options.withMaxDuration) + } + }) + } +} + // TestOptions_WithFileMode exercises WithFileMode option to ensure it performs as expected. func TestOptions_WithFileMode(t *testing.T) { tests := map[string]struct { @@ -344,6 +436,7 @@ func TestOptions_Default(t *testing.T) { require.False(t, opts.withNow.IsZero()) require.Equal(t, "AUTH", opts.withFacility) require.Equal(t, "vault", opts.withTag) + require.Equal(t, 2*time.Second, opts.withMaxDuration) } // TestOptions_Opts exercises getOpts with various Option values.