VAULT-24452: audit refactor (#26460)

* Refactor audit code into audit package
* remove builtin/audit
* removed unrequired files
Peter Wilson
2024-04-18 08:25:04 +01:00
committed by GitHub
parent 961bf20bdb
commit 8bee54c89d
60 changed files with 2638 additions and 3214 deletions

332
audit/backend.go Normal file
View File

@@ -0,0 +1,332 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"context"
"fmt"
"reflect"
"strconv"
"sync"
"sync/atomic"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
const (
optionElideListResponses = "elide_list_responses"
optionFallback = "fallback"
optionFilter = "filter"
optionFormat = "format"
optionHMACAccessor = "hmac_accessor"
optionLogRaw = "log_raw"
optionPrefix = "prefix"
)
var _ Backend = (*backend)(nil)
// Factory is the factory function to create an audit backend.
type Factory func(*BackendConfig, HeaderFormatter) (Backend, error)
// Backend interface must be implemented for an audit
// mechanism to be made available. Audit backends can be enabled to
// sink information to different backends such as logs, files, databases,
// or other external services.
type Backend interface {
// Salter interface must be implemented by anything implementing Backend.
Salter
// The PipelineReader interface allows backends to surface information about their
// nodes for node and pipeline registration.
event.PipelineReader
// IsFallback can be used to determine if this audit backend device is intended to
// be used as a fallback to catch all events that are not written when only using
// filtered pipelines.
IsFallback() bool
// LogTestMessage is used to check an audit backend before adding it
// permanently. It should attempt to synchronously log the given test
// message, WITHOUT using the normal Salt (which would require a storage
// operation on creation).
LogTestMessage(context.Context, *logical.LogInput) error
// Reload is called on SIGHUP for supporting backends.
Reload() error
// Invalidate is called for path invalidation
Invalidate(context.Context)
}
// Salter is an interface that provides a way to obtain a Salt for hashing.
type Salter interface {
// Salt returns a non-nil salt or an error.
Salt(context.Context) (*salt.Salt, error)
}
// backend represents an audit backend's shared fields across supported devices (file, socket, syslog).
// NOTE: Use newBackend to initialize the backend.
// e.g. within NewFileBackend, NewSocketBackend, NewSyslogBackend.
type backend struct {
*backendEnt
name string
nodeIDList []eventlogger.NodeID
nodeMap map[eventlogger.NodeID]eventlogger.Node
salt *atomic.Value
saltConfig *salt.Config
saltMutex sync.RWMutex
saltView logical.Storage
}
// newBackend will create the common backend which should be used by supported audit
// backend types (file, socket, syslog) to which they can create and add their sink.
// It handles basic validation of config and creates the required pipeline nodes that
// precede the sink node.
func newBackend(headersConfig HeaderFormatter, conf *BackendConfig) (*backend, error) {
b := &backend{
backendEnt: newBackendEnt(conf.Config),
name: conf.MountPath,
saltConfig: conf.SaltConfig,
saltView: conf.SaltView,
salt: new(atomic.Value),
nodeIDList: []eventlogger.NodeID{},
nodeMap: make(map[eventlogger.NodeID]eventlogger.Node),
}
// Ensure we are working with the right type by explicitly storing a nil of the right type.
b.salt.Store((*salt.Salt)(nil))
if err := b.configureFilterNode(conf.Config[optionFilter]); err != nil {
return nil, err
}
cfg, err := newFormatterConfig(headersConfig, conf.Config)
if err != nil {
return nil, err
}
if err := b.configureFormatterNode(conf.MountPath, cfg, conf.Logger); err != nil {
return nil, err
}
return b, nil
}
// newFormatterConfig creates the configuration required by a formatter node using the config map supplied to the factory.
func newFormatterConfig(headerFormatter HeaderFormatter, config map[string]string) (formatterConfig, error) {
if headerFormatter == nil || reflect.ValueOf(headerFormatter).IsNil() {
return formatterConfig{}, fmt.Errorf("header formatter is required: %w", ErrInvalidParameter)
}
var opt []Option
if format, ok := config[optionFormat]; ok {
if !IsValidFormat(format) {
return formatterConfig{}, fmt.Errorf("unsupported %q: %w", optionFormat, ErrExternalOptions)
}
opt = append(opt, WithFormat(format))
}
// Check if hashing of accessor is disabled
if hmacAccessorRaw, ok := config[optionHMACAccessor]; ok {
v, err := strconv.ParseBool(hmacAccessorRaw)
if err != nil {
return formatterConfig{}, fmt.Errorf("unable to parse %q: %w", optionHMACAccessor, ErrExternalOptions)
}
opt = append(opt, WithHMACAccessor(v))
}
// Check if raw logging is enabled
if raw, ok := config[optionLogRaw]; ok {
v, err := strconv.ParseBool(raw)
if err != nil {
return formatterConfig{}, fmt.Errorf("unable to parse %q: %w", optionLogRaw, ErrExternalOptions)
}
opt = append(opt, WithRaw(v))
}
if elideListResponsesRaw, ok := config[optionElideListResponses]; ok {
v, err := strconv.ParseBool(elideListResponsesRaw)
if err != nil {
return formatterConfig{}, fmt.Errorf("unable to parse %q: %w", optionElideListResponses, ErrExternalOptions)
}
opt = append(opt, WithElision(v))
}
if prefix, ok := config[optionPrefix]; ok {
opt = append(opt, WithPrefix(prefix))
}
err := ValidateOptions()
if err != nil {
return formatterConfig{}, err
}
opts, err := getOpts(opt...)
if err != nil {
return formatterConfig{}, err
}
return formatterConfig{
headerFormatter: headerFormatter,
elideListResponses: opts.withElision,
hmacAccessor: opts.withHMACAccessor,
omitTime: opts.withOmitTime, // This must be set in code after creation.
prefix: opts.withPrefix,
raw: opts.withRaw,
requiredFormat: opts.withFormat,
}, nil
}
// configureFormatterNode is used to configure a formatter node and associated ID on the Backend.
func (b *backend) configureFormatterNode(name string, formatConfig formatterConfig, logger hclog.Logger) error {
formatterNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for formatter node: %w: %w", ErrInternal, err)
}
formatterNode, err := newEntryFormatter(name, formatConfig, b, logger)
if err != nil {
return fmt.Errorf("error creating formatter: %w", err)
}
b.nodeIDList = append(b.nodeIDList, formatterNodeID)
b.nodeMap[formatterNodeID] = formatterNode
return nil
}
// wrapMetrics takes a sink node and augments it by wrapping it with metrics nodes.
// The metrics capture the time taken by the sink and counts of events it handles.
func (b *backend) wrapMetrics(name string, id eventlogger.NodeID, n eventlogger.Node) error {
if n.Type() != eventlogger.NodeTypeSink {
return fmt.Errorf("unable to wrap node with metrics. %q is not a sink node: %w", name, ErrInvalidParameter)
}
// Wrap the sink node with metrics middleware
sinkMetricTimer, err := newSinkMetricTimer(name, n)
if err != nil {
return fmt.Errorf("unable to add timing metrics to sink for path %q: %w", name, err)
}
sinkMetricCounter, err := event.NewMetricsCounter(name, sinkMetricTimer, b.getMetricLabeler())
if err != nil {
return fmt.Errorf("unable to add counting metrics to sink for path %q: %w", name, err)
}
b.nodeIDList = append(b.nodeIDList, id)
b.nodeMap[id] = sinkMetricCounter
return nil
}
// Salt is used to provide a salt for HMAC'ing data. If the salt is not currently
// cached, it is loaded from (or created in) storage, cached, and returned; subsequent
// calls return the cached salt.
// NOTE: If invalidation occurs the salt will likely be cleared, forcing reload
// from storage.
func (b *backend) Salt(ctx context.Context) (*salt.Salt, error) {
s := b.salt.Load().(*salt.Salt)
if s != nil {
return s, nil
}
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
s = b.salt.Load().(*salt.Salt)
if s != nil {
return s, nil
}
newSalt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig)
if err != nil {
b.salt.Store((*salt.Salt)(nil))
return nil, err
}
b.salt.Store(newSalt)
return newSalt, nil
}
// EventType returns the event type for the backend.
func (b *backend) EventType() eventlogger.EventType {
return event.AuditType.AsEventType()
}
// HasFiltering determines if the first node for the pipeline is an eventlogger.NodeTypeFilter.
func (b *backend) HasFiltering() bool {
if b.nodeMap == nil {
return false
}
return len(b.nodeIDList) > 0 && b.nodeMap[b.nodeIDList[0]].Type() == eventlogger.NodeTypeFilter
}
// Name returns the name of this backend, which must correspond to the mount path for the audit device.
func (b *backend) Name() string {
return b.name
}
// NodeIDs returns the IDs of the nodes, in the order they are required.
func (b *backend) NodeIDs() []eventlogger.NodeID {
return b.nodeIDList
}
// Nodes returns the nodes which should be used by the event framework to process audit entries.
func (b *backend) Nodes() map[eventlogger.NodeID]eventlogger.Node {
return b.nodeMap
}
func (b *backend) LogTestMessage(ctx context.Context, input *logical.LogInput) error {
if len(b.nodeIDList) > 0 {
return processManual(ctx, input, b.nodeIDList, b.nodeMap)
}
return nil
}
func (b *backend) Reload() error {
for _, n := range b.nodeMap {
if n.Type() == eventlogger.NodeTypeSink {
return n.Reopen()
}
}
return nil
}
func (b *backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt.Store((*salt.Salt)(nil))
}
// HasInvalidAuditOptions is used to determine if a non-Enterprise version of Vault
// is being used when supplying options that contain options exclusive to Enterprise.
func HasInvalidAuditOptions(options map[string]string) bool {
return !constants.IsEnterprise && hasEnterpriseAuditOptions(options)
}
// hasEnterpriseAuditOptions is used to check if any of the options supplied
// are only for use in the Enterprise version of Vault.
func hasEnterpriseAuditOptions(options map[string]string) bool {
enterpriseAuditOptions := []string{
optionFallback,
optionFilter,
}
for _, o := range enterpriseAuditOptions {
if _, ok := options[o]; ok {
return true
}
}
return false
}
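As a brief illustration (not part of this commit's diff) of how a caller might use the exported HasInvalidAuditOptions helper above to pre-check options before invoking a Factory; the option values are hypothetical:

package main

import (
	"fmt"

	"github.com/hashicorp/vault/audit"
)

func main() {
	// Hypothetical device options supplied at mount time.
	options := map[string]string{
		"file_path": "/var/log/vault_audit.log",
		"filter":    `operation == "update"`, // Enterprise-only option
	}

	// On a community-edition build this reports true, because "filter"
	// (like "fallback") is listed in hasEnterpriseAuditOptions.
	if audit.HasInvalidAuditOptions(options) {
		fmt.Println("enterprise-only audit options supplied")
	}
}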

27
audit/backend_ce.go Normal file
View File

@@ -0,0 +1,27 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package audit
import "github.com/hashicorp/vault/internal/observability/event"
type backendEnt struct{}
func newBackendEnt(_ map[string]string) *backendEnt {
return &backendEnt{}
}
func (b *backendEnt) IsFallback() bool {
return false
}
// configureFilterNode is a no-op as filters are an Enterprise-only feature.
func (b *backend) configureFilterNode(_ string) error {
return nil
}
func (b *backend) getMetricLabeler() event.Labeler {
return &metricLabelerAuditSink{}
}

59
audit/backend_ce_test.go Normal file
View File

@@ -0,0 +1,59 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package audit
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/stretchr/testify/require"
)
// TestBackend_configureFilterNode ensures that configureFilterNode handles various
// filter values as expected. Empty (including whitespace) strings should return
// no error but skip configuration of the node.
// NOTE: Audit filtering is an Enterprise feature and behaves differently in the
// community edition of Vault.
func TestBackend_configureFilterNode(t *testing.T) {
t.Parallel()
tests := map[string]struct {
filter string
}{
"happy": {
filter: "operation == \"update\"",
},
"empty": {
filter: "",
},
"spacey": {
filter: " ",
},
"bad": {
filter: "___qwerty",
},
"unsupported-field": {
filter: "foo == bar",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
b := &backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
err := b.configureFilterNode(tc.filter)
require.NoError(t, err)
require.Len(t, b.nodeIDList, 0)
require.Len(t, b.nodeMap, 0)
})
}
}

63
audit/backend_config.go Normal file
View File

@@ -0,0 +1,63 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"fmt"
"reflect"
"strings"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
// BackendConfig contains configuration parameters used in the factory func to
// instantiate audit backends
type BackendConfig struct {
// The view to store the salt
SaltView logical.Storage
// The salt config that should be used for any secret obfuscation
SaltConfig *salt.Config
// Config is the opaque user configuration provided when mounting
Config map[string]string
// MountPath is the path where this Backend is mounted
MountPath string
// Logger is used to emit log messages usually captured in the server logs.
Logger hclog.Logger
}
// Validate ensures that we have the required configuration to create audit backends.
func (c *BackendConfig) Validate() error {
if c.SaltConfig == nil {
return fmt.Errorf("nil salt config: %w", ErrInvalidParameter)
}
if c.SaltView == nil {
return fmt.Errorf("nil salt view: %w", ErrInvalidParameter)
}
if c.Logger == nil || reflect.ValueOf(c.Logger).IsNil() {
return fmt.Errorf("nil logger: %w", ErrInvalidParameter)
}
if c.Config == nil {
return fmt.Errorf("config cannot be nil: %w", ErrInvalidParameter)
}
if strings.TrimSpace(c.MountPath) == "" {
return fmt.Errorf("mount path cannot be empty: %w", ErrExternalOptions)
}
// Validate actual config specific to Vault version (Enterprise/CE).
if err := c.validate(); err != nil {
return err
}
return nil
}
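A minimal sketch of exercising Validate on a BackendConfig (illustrative only; the in-memory salt view and null logger mirror the test fixtures used later in this diff):

package main

import (
	"fmt"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/audit"
	"github.com/hashicorp/vault/sdk/helper/salt"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	cfg := &audit.BackendConfig{
		SaltView:   &logical.InmemStorage{},
		SaltConfig: &salt.Config{},
		Config:     map[string]string{"file_path": "stdout"},
		MountPath:  "", // an empty mount path fails validation
		Logger:     hclog.NewNullLogger(),
	}

	// Expected output: "mount path cannot be empty: invalid configuration"
	if err := cfg.Validate(); err != nil {
		fmt.Println(err)
	}
}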

View File

@@ -0,0 +1,18 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package audit
import "fmt"
// validate ensures that if we're not running Vault Enterprise, we cannot
// supply Enterprise-only audit configuration options.
func (c *BackendConfig) validate() error {
if HasInvalidAuditOptions(c.Config) {
return fmt.Errorf("enterprise-only options supplied: %w", ErrExternalOptions)
}
return nil
}

154
audit/backend_file.go Normal file
View File

@@ -0,0 +1,154 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"fmt"
"reflect"
"strings"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/vault/internal/observability/event"
)
const (
stdout = "stdout"
discard = "discard"
optionFilePath = "file_path"
optionMode = "mode"
)
var _ Backend = (*FileBackend)(nil)
type FileBackend struct {
*backend
}
// NewFileBackend provides a wrapper to support the expectation elsewhere in Vault that
// all audit backends can be created via a factory that returns an interface (Backend).
func NewFileBackend(conf *BackendConfig, headersConfig HeaderFormatter) (be Backend, err error) {
be, err = newFileBackend(conf, headersConfig)
return
}
// newFileBackend creates a backend and configures all nodes including a file sink.
func newFileBackend(conf *BackendConfig, headersConfig HeaderFormatter) (*FileBackend, error) {
if headersConfig == nil || reflect.ValueOf(headersConfig).IsNil() {
return nil, fmt.Errorf("nil header formatter: %w", ErrInvalidParameter)
}
if conf == nil {
return nil, fmt.Errorf("nil config: %w", ErrInvalidParameter)
}
if err := conf.Validate(); err != nil {
return nil, err
}
// Get file path from config or fall back to the old option ('path') for compatibility
// (see commit bac4fe0799a372ba1245db642f3f6cd1f1d02669).
var filePath string
if p, ok := conf.Config[optionFilePath]; ok {
filePath = p
} else if p, ok = conf.Config["path"]; ok {
filePath = p
} else {
return nil, fmt.Errorf("%q is required: %w", optionFilePath, ErrExternalOptions)
}
bec, err := newBackend(headersConfig, conf)
if err != nil {
return nil, err
}
b := &FileBackend{backend: bec}
// normalize file path if configured for stdout or discard
if strings.EqualFold(filePath, stdout) {
filePath = stdout
}
if strings.EqualFold(filePath, discard) {
filePath = discard
}
// Configure the sink.
cfg, err := newFormatterConfig(headersConfig, conf.Config)
if err != nil {
return nil, err
}
var opt []event.Option
if mode, ok := conf.Config[optionMode]; ok {
opt = append(opt, event.WithFileMode(mode))
}
err = b.configureSinkNode(conf.MountPath, filePath, cfg.requiredFormat, opt...)
if err != nil {
return nil, err
}
return b, nil
}
// configureSinkNode is used internally by FileBackend to create and configure the
// sink node on the backend.
func (b *FileBackend) configureSinkNode(name string, filePath string, format format, opt ...event.Option) error {
name = strings.TrimSpace(name)
if name == "" {
return fmt.Errorf("name is required: %w", ErrExternalOptions)
}
filePath = strings.TrimSpace(filePath)
if filePath == "" {
return fmt.Errorf("file path is required: %w", ErrExternalOptions)
}
sinkNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for sink node: %w: %w", ErrInternal, err)
}
// normalize file path if configured for stdout or discard
if strings.EqualFold(filePath, stdout) {
filePath = stdout
} else if strings.EqualFold(filePath, discard) {
filePath = discard
}
var sinkNode eventlogger.Node
var sinkName string
switch filePath {
case stdout:
sinkName = stdout
sinkNode, err = event.NewStdoutSinkNode(format.String())
case discard:
sinkName = discard
sinkNode = event.NewNoopSink()
default:
// The NewFileSink function attempts to open the file and will return an error if it can't.
sinkName = name
sinkNode, err = event.NewFileSink(filePath, format.String(), opt...)
}
if err != nil {
return fmt.Errorf("file sink creation failed for path %q: %w", filePath, err)
}
// Wrap the sink node with metrics middleware
err = b.wrapMetrics(sinkName, sinkNodeID, sinkNode)
if err != nil {
return err
}
return nil
}
// Reload will trigger the reload action on the sink node for this backend.
func (b *FileBackend) Reload() error {
for _, n := range b.nodeMap {
if n.Type() == eventlogger.NodeTypeSink {
return n.Reopen()
}
}
return nil
}
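A hedged sketch (not part of this diff) of creating a stdout file device through NewFileBackend; NoopHeaderFormatter and the in-memory storage are the same stand-ins the tests below use, and Name/Nodes come from the PipelineReader methods shown above:

package main

import (
	"log"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/audit"
	"github.com/hashicorp/vault/sdk/helper/salt"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	be, err := audit.NewFileBackend(&audit.BackendConfig{
		MountPath:  "file/",
		SaltConfig: &salt.Config{},
		SaltView:   &logical.InmemStorage{},
		Logger:     hclog.NewNullLogger(),
		Config: map[string]string{
			"file_path": "STDOUT", // normalized case-insensitively to "stdout"
			"format":    "json",
		},
	}, &audit.NoopHeaderFormatter{})
	if err != nil {
		log.Fatal(err)
	}

	// A formatter node and a metrics-wrapped sink node are registered.
	log.Printf("device %q has %d nodes", be.Name(), len(be.Nodes()))
}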

View File

@@ -0,0 +1,147 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package audit
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestFileBackend_newFileBackend_fallback ensures that we get the correct errors
// in CE when we try to enable a FileBackend with enterprise options like fallback
// and filter.
func TestFileBackend_newFileBackend_fallback(t *testing.T) {
t.Parallel()
tests := map[string]struct {
backendConfig *BackendConfig
isErrorExpected bool
expectedErrorMessage string
}{
"non-fallback-device-with-filter": {
backendConfig: &BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "false",
"file_path": discard,
"filter": "mount_type == kv",
},
},
isErrorExpected: true,
expectedErrorMessage: "enterprise-only options supplied: invalid configuration",
},
"fallback-device-with-filter": {
backendConfig: &BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "true",
"file_path": discard,
"filter": "mount_type == kv",
},
},
isErrorExpected: true,
expectedErrorMessage: "enterprise-only options supplied: invalid configuration",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
be, err := newFileBackend(tc.backendConfig, &NoopHeaderFormatter{})
if tc.isErrorExpected {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrorMessage)
} else {
require.NoError(t, err)
require.NotNil(t, be)
}
})
}
}
// TestFileBackend_newFileBackend_FilterFormatterSink ensures that when configuring
// a backend in community edition we cannot configure a filter node.
// We can verify that we have formatter and sink nodes added to the backend.
// The order of calls influences the slice of IDs on the Backend.
func TestFileBackend_newFileBackend_FilterFormatterSink(t *testing.T) {
t.Parallel()
cfg := map[string]string{
"file_path": "/tmp/foo",
"mode": "0777",
"format": "json",
"filter": "mount_type == \"kv\"",
}
backendConfig := &BackendConfig{
SaltView: &logical.InmemStorage{},
SaltConfig: &salt.Config{},
Config: cfg,
MountPath: "bar",
Logger: hclog.NewNullLogger(),
}
b, err := newFileBackend(backendConfig, &NoopHeaderFormatter{})
require.Error(t, err)
require.EqualError(t, err, "enterprise-only options supplied: invalid configuration")
// Try without filter option
delete(cfg, "filter")
b, err = newFileBackend(backendConfig, &NoopHeaderFormatter{})
require.NoError(t, err)
require.Len(t, b.nodeIDList, 2)
require.Len(t, b.nodeMap, 2)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeFormatter, node.Type())
id = b.nodeIDList[1]
node = b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
}
// TestBackend_IsFallback ensures that no CE audit device can be a fallback.
func TestBackend_IsFallback(t *testing.T) {
t.Parallel()
cfg := &BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "true",
"file_path": discard,
},
}
be, err := newFileBackend(cfg, &NoopHeaderFormatter{})
require.Error(t, err)
require.EqualError(t, err, "enterprise-only options supplied: invalid configuration")
// Remove the option and try again
delete(cfg.Config, "fallback")
be, err = newFileBackend(cfg, &NoopHeaderFormatter{})
require.NoError(t, err)
require.NotNil(t, be)
require.Equal(t, false, be.IsFallback())
}

289
audit/backend_file_test.go Normal file
View File

@@ -0,0 +1,289 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"os"
"path/filepath"
"strconv"
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestAuditFile_fileModeNew verifies that the backend Factory correctly sets
// the file mode when the mode argument is set.
func TestAuditFile_fileModeNew(t *testing.T) {
t.Parallel()
modeStr := "0777"
mode, err := strconv.ParseUint(modeStr, 8, 32)
require.NoError(t, err)
file := filepath.Join(t.TempDir(), "auditTest.txt")
backendConfig := &BackendConfig{
Config: map[string]string{
"path": file,
"mode": modeStr,
},
MountPath: "foo/bar",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
}
_, err = newFileBackend(backendConfig, &NoopHeaderFormatter{})
require.NoError(t, err)
info, err := os.Stat(file)
require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`")
require.Equalf(t, os.FileMode(mode), info.Mode(), "File mode does not match.")
}
// TestAuditFile_fileModeExisting verifies that the backend Factory correctly sets
// the mode on an existing file.
func TestAuditFile_fileModeExisting(t *testing.T) {
t.Parallel()
dir := t.TempDir()
f, err := os.CreateTemp(dir, "auditTest.log")
require.NoErrorf(t, err, "Failure to create test file.")
err = os.Chmod(f.Name(), 0o777)
require.NoErrorf(t, err, "Failure to chmod temp file for testing.")
err = f.Close()
require.NoErrorf(t, err, "Failure to close temp file for test.")
backendConfig := &BackendConfig{
Config: map[string]string{
"path": f.Name(),
},
MountPath: "foo/bar",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
}
_, err = newFileBackend(backendConfig, &NoopHeaderFormatter{})
require.NoError(t, err)
info, err := os.Stat(f.Name())
require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`")
require.Equalf(t, os.FileMode(0o600), info.Mode(), "File mode does not match.")
}
// TestAuditFile_fileMode0000 verifies that setting the audit file mode to
// "0000" prevents Vault from modifying the permissions of the file.
func TestAuditFile_fileMode0000(t *testing.T) {
t.Parallel()
dir := t.TempDir()
f, err := os.CreateTemp(dir, "auditTest.log")
require.NoErrorf(t, err, "Failure to create test file.")
err = os.Chmod(f.Name(), 0o777)
require.NoErrorf(t, err, "Failure to chmod temp file for testing.")
err = f.Close()
require.NoErrorf(t, err, "Failure to close temp file for test.")
backendConfig := &BackendConfig{
Config: map[string]string{
"path": f.Name(),
"mode": "0000",
},
MountPath: "foo/bar",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
}
_, err = newFileBackend(backendConfig, &NoopHeaderFormatter{})
require.NoError(t, err)
info, err := os.Stat(f.Name())
require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`. The error is %v", err)
require.Equalf(t, os.FileMode(0o777), info.Mode(), "File mode does not match.")
}
// TestAuditFile_EventLogger_fileModeNew verifies that newFileBackend correctly sets
// the file mode when the 'file_path' and 'mode' options are supplied.
func TestAuditFile_EventLogger_fileModeNew(t *testing.T) {
modeStr := "0777"
mode, err := strconv.ParseUint(modeStr, 8, 32)
require.NoError(t, err)
file := filepath.Join(t.TempDir(), "auditTest.txt")
backendConfig := &BackendConfig{
Config: map[string]string{
"file_path": file,
"mode": modeStr,
},
MountPath: "foo/bar",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
}
_, err = newFileBackend(backendConfig, &NoopHeaderFormatter{})
require.NoError(t, err)
info, err := os.Stat(file)
require.NoError(t, err)
require.Equalf(t, os.FileMode(mode), info.Mode(), "File mode does not match.")
}
// TestFileBackend_newFileBackend ensures that we can correctly configure the sink
// node on the Backend, and any incorrect parameters result in the relevant errors.
func TestFileBackend_newFileBackend(t *testing.T) {
t.Parallel()
tests := map[string]struct {
mountPath string
filePath string
mode string
format string
wantErr bool
expectedErrMsg string
expectedName string
}{
"name-empty": {
mountPath: "",
format: "json",
wantErr: true,
expectedErrMsg: "mount path cannot be empty: invalid configuration",
},
"name-whitespace": {
mountPath: " ",
format: "json",
wantErr: true,
expectedErrMsg: "mount path cannot be empty: invalid configuration",
},
"filePath-empty": {
mountPath: "foo",
filePath: "",
format: "json",
wantErr: true,
expectedErrMsg: "file path is required: invalid configuration",
},
"filePath-whitespace": {
mountPath: "foo",
filePath: " ",
format: "json",
wantErr: true,
expectedErrMsg: "file path is required: invalid configuration",
},
"filePath-stdout-lower": {
mountPath: "foo",
expectedName: "stdout",
filePath: "stdout",
format: "json",
},
"filePath-stdout-upper": {
mountPath: "foo",
expectedName: "stdout",
filePath: "STDOUT",
format: "json",
},
"filePath-stdout-mixed": {
mountPath: "foo",
expectedName: "stdout",
filePath: "StdOut",
format: "json",
},
"filePath-discard-lower": {
mountPath: "foo",
expectedName: "discard",
filePath: "discard",
format: "json",
},
"filePath-discard-upper": {
mountPath: "foo",
expectedName: "discard",
filePath: "DISCARD",
format: "json",
},
"filePath-discard-mixed": {
mountPath: "foo",
expectedName: "discard",
filePath: "DisCArd",
format: "json",
},
"format-empty": {
mountPath: "foo",
filePath: "/tmp/",
format: "",
wantErr: true,
expectedErrMsg: "unsupported \"format\": invalid configuration",
},
"format-whitespace": {
mountPath: "foo",
filePath: "/tmp/",
format: " ",
wantErr: true,
expectedErrMsg: "unsupported \"format\": invalid configuration",
},
"filePath-weird-with-mode-zero": {
mountPath: "foo",
filePath: "/tmp/qwerty",
format: "json",
mode: "0",
wantErr: true,
expectedErrMsg: "file sink creation failed for path \"/tmp/qwerty\": unable to determine existing file mode: stat /tmp/qwerty: no such file or directory",
},
"happy": {
mountPath: "foo",
filePath: "/tmp/log",
mode: "",
format: "json",
wantErr: false,
expectedName: "foo",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
cfg := &BackendConfig{
SaltView: &logical.InmemStorage{},
SaltConfig: &salt.Config{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"file_path": tc.filePath,
"mode": tc.mode,
"format": tc.format,
},
MountPath: tc.mountPath,
}
b, err := newFileBackend(cfg, &NoopHeaderFormatter{})
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrMsg)
require.Nil(t, b)
} else {
require.NoError(t, err)
require.Len(t, b.nodeIDList, 2) // Expect formatter + the sink
require.Len(t, b.nodeMap, 2)
id := b.nodeIDList[1]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
mc, ok := node.(*event.MetricsCounter)
require.True(t, ok)
require.Equal(t, tc.expectedName, mc.Name)
}
})
}
}

View File

@@ -122,7 +122,7 @@ func NewNoopAudit(config *BackendConfig) (*NoopAudit, error) {
nodeMap: make(map[eventlogger.NodeID]eventlogger.Node, 2),
}
-cfg, err := NewFormatterConfig(&NoopHeaderFormatter{})
+cfg, err := newFormatterConfig(&NoopHeaderFormatter{}, nil)
if err != nil {
return nil, err
}
@@ -132,7 +132,7 @@ func NewNoopAudit(config *BackendConfig) (*NoopAudit, error) {
return nil, fmt.Errorf("error generating random NodeID for formatter node: %w", err)
}
-formatterNode, err := NewEntryFormatter(config.MountPath, cfg, noopBackend, config.Logger)
+formatterNode, err := newEntryFormatter(config.MountPath, cfg, noopBackend, config.Logger)
if err != nil {
return nil, fmt.Errorf("error creating formatter: %w", err)
}
@@ -268,7 +268,7 @@ func (n *noopWrapper) Type() eventlogger.NodeType {
// LogTestMessage will manually crank the handle on the nodes associated with this backend.
func (n *NoopAudit) LogTestMessage(ctx context.Context, in *logical.LogInput) error {
if len(n.nodeIDList) > 0 {
-return ProcessManual(ctx, in, n.nodeIDList, n.nodeMap)
+return processManual(ctx, in, n.nodeIDList, n.nodeMap)
}
return nil

125
audit/backend_socket.go Normal file
View File

@@ -0,0 +1,125 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"fmt"
"reflect"
"strings"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/vault/internal/observability/event"
)
const (
optionAddress = "address"
optionSocketType = "socket_type"
optionWriteTimeout = "write_timeout"
)
var _ Backend = (*SocketBackend)(nil)
type SocketBackend struct {
*backend
}
// NewSocketBackend provides a means to create socket backend audit devices that
// satisfy the Factory pattern expected elsewhere in Vault.
func NewSocketBackend(conf *BackendConfig, headersConfig HeaderFormatter) (be Backend, err error) {
be, err = newSocketBackend(conf, headersConfig)
return
}
// newSocketBackend creates a backend and configures all nodes including a socket sink.
func newSocketBackend(conf *BackendConfig, headersConfig HeaderFormatter) (*SocketBackend, error) {
if headersConfig == nil || reflect.ValueOf(headersConfig).IsNil() {
return nil, fmt.Errorf("nil header formatter: %w", ErrInvalidParameter)
}
if conf == nil {
return nil, fmt.Errorf("nil config: %w", ErrInvalidParameter)
}
if err := conf.Validate(); err != nil {
return nil, err
}
bec, err := newBackend(headersConfig, conf)
if err != nil {
return nil, err
}
address, ok := conf.Config[optionAddress]
if !ok {
return nil, fmt.Errorf("%q is required: %w", optionAddress, ErrExternalOptions)
}
address = strings.TrimSpace(address)
if address == "" {
return nil, fmt.Errorf("%q cannot be empty: %w", optionAddress, ErrExternalOptions)
}
socketType, ok := conf.Config[optionSocketType]
if !ok {
socketType = "tcp"
}
writeDeadline, ok := conf.Config[optionWriteTimeout]
if !ok {
writeDeadline = "2s"
}
sinkOpts := []event.Option{
event.WithSocketType(socketType),
event.WithMaxDuration(writeDeadline),
}
err = event.ValidateOptions(sinkOpts...)
if err != nil {
return nil, err
}
b := &SocketBackend{backend: bec}
// Configure the sink.
cfg, err := newFormatterConfig(headersConfig, conf.Config)
if err != nil {
return nil, err
}
err = b.configureSinkNode(conf.MountPath, address, cfg.requiredFormat, sinkOpts...)
if err != nil {
return nil, err
}
return b, nil
}
func (b *SocketBackend) configureSinkNode(name string, address string, format format, opts ...event.Option) error {
sinkNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for sink node: %w", err)
}
n, err := event.NewSocketSink(address, format.String(), opts...)
if err != nil {
return err
}
// Wrap the sink node with metrics middleware
err = b.wrapMetrics(name, sinkNodeID, n)
if err != nil {
return err
}
return nil
}
// Reload will trigger the reload action on the sink node for this backend.
func (b *SocketBackend) Reload() error {
for _, n := range b.nodeMap {
if n.Type() == eventlogger.NodeTypeSink {
return n.Reopen()
}
}
return nil
}
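For comparison, a similar sketch for a TCP socket device (not part of this diff); the address is hypothetical, and socket_type/write_timeout default to "tcp" and "2s" when omitted:

package main

import (
	"log"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/audit"
	"github.com/hashicorp/vault/sdk/helper/salt"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	be, err := audit.NewSocketBackend(&audit.BackendConfig{
		MountPath:  "socket/",
		SaltConfig: &salt.Config{},
		SaltView:   &logical.InmemStorage{},
		Logger:     hclog.NewNullLogger(),
		Config: map[string]string{
			"address":       "127.0.0.1:9090", // hypothetical listener
			"write_timeout": "5s",
			"format":        "json",
		},
	}, &audit.NoopHeaderFormatter{})
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("socket audit device %q configured", be.Name())
}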

View File

@@ -0,0 +1,136 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestSocketBackend_newSocketBackend ensures that we can correctly configure the sink
// node on the Backend, and any incorrect parameters result in the relevant errors.
func TestSocketBackend_newSocketBackend(t *testing.T) {
t.Parallel()
tests := map[string]struct {
mountPath string
address string
socketType string
writeDuration string
format string
wantErr bool
expectedErrMsg string
expectedName string
}{
"name-empty": {
mountPath: "",
address: "wss://foo",
format: "json",
wantErr: true,
expectedErrMsg: "mount path cannot be empty: invalid configuration",
},
"name-whitespace": {
mountPath: " ",
address: "wss://foo",
format: "json",
wantErr: true,
expectedErrMsg: "mount path cannot be empty: invalid configuration",
},
"address-empty": {
mountPath: "foo",
address: "",
format: "json",
wantErr: true,
expectedErrMsg: "\"address\" cannot be empty: invalid configuration",
},
"address-whitespace": {
mountPath: "foo",
address: " ",
format: "json",
wantErr: true,
expectedErrMsg: "\"address\" cannot be empty: invalid configuration",
},
"format-empty": {
mountPath: "foo",
address: "wss://foo",
format: "",
wantErr: true,
expectedErrMsg: "unsupported \"format\": invalid configuration",
},
"format-whitespace": {
mountPath: "foo",
address: "wss://foo",
format: " ",
wantErr: true,
expectedErrMsg: "unsupported \"format\": invalid configuration",
},
"write-duration-valid": {
mountPath: "foo",
address: "wss://foo",
writeDuration: "5s",
format: "json",
wantErr: false,
expectedName: "foo",
},
"write-duration-not-valid": {
mountPath: "foo",
address: "wss://foo",
writeDuration: "qwerty",
format: "json",
wantErr: true,
expectedErrMsg: "unable to parse max duration: invalid parameter: time: invalid duration \"qwerty\"",
},
"happy": {
mountPath: "foo",
address: "wss://foo",
format: "json",
wantErr: false,
expectedName: "foo",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
cfg := &BackendConfig{
SaltView: &logical.InmemStorage{},
SaltConfig: &salt.Config{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"address": tc.address,
"format": tc.format,
"write_timeout": tc.writeDuration,
"socket": tc.socketType,
},
MountPath: tc.mountPath,
}
b, err := newSocketBackend(cfg, &NoopHeaderFormatter{})
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrMsg)
require.Nil(t, b)
} else {
require.NoError(t, err)
require.Len(t, b.nodeIDList, 2) // formatter + sink
require.Len(t, b.nodeMap, 2)
id := b.nodeIDList[1] // sink is 2nd
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
mc, ok := node.(*event.MetricsCounter)
require.True(t, ok)
require.Equal(t, tc.expectedName, mc.Name)
}
})
}
}

108
audit/backend_syslog.go Normal file
View File

@@ -0,0 +1,108 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"fmt"
"reflect"
"github.com/hashicorp/vault/internal/observability/event"
)
const (
optionFacility = "facility"
optionTag = "tag"
)
var _ Backend = (*SyslogBackend)(nil)
type SyslogBackend struct {
*backend
}
// NewSyslogBackend provides a wrapper to support the expectation elsewhere in Vault that
// all audit backends can be created via a factory that returns an interface (Backend).
func NewSyslogBackend(conf *BackendConfig, headersConfig HeaderFormatter) (be Backend, err error) {
be, err = newSyslogBackend(conf, headersConfig)
return
}
// newSyslogBackend creates a backend and configures all nodes including a syslog sink.
func newSyslogBackend(conf *BackendConfig, headersConfig HeaderFormatter) (*SyslogBackend, error) {
if headersConfig == nil || reflect.ValueOf(headersConfig).IsNil() {
return nil, fmt.Errorf("nil header formatter: %w", ErrInvalidParameter)
}
if conf == nil {
return nil, fmt.Errorf("nil config: %w", ErrInvalidParameter)
}
if err := conf.Validate(); err != nil {
return nil, err
}
bec, err := newBackend(headersConfig, conf)
if err != nil {
return nil, err
}
// Get facility or default to AUTH
facility, ok := conf.Config[optionFacility]
if !ok {
facility = "AUTH"
}
// Get tag or default to 'vault'
tag, ok := conf.Config[optionTag]
if !ok {
tag = "vault"
}
sinkOpts := []event.Option{
event.WithFacility(facility),
event.WithTag(tag),
}
err = event.ValidateOptions(sinkOpts...)
if err != nil {
return nil, err
}
b := &SyslogBackend{backend: bec}
// Configure the sink.
cfg, err := newFormatterConfig(headersConfig, conf.Config)
if err != nil {
return nil, err
}
err = b.configureSinkNode(conf.MountPath, cfg.requiredFormat, sinkOpts...)
if err != nil {
return nil, err
}
return b, nil
}
func (b *SyslogBackend) configureSinkNode(name string, format format, opts ...event.Option) error {
sinkNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for sink node: %w: %w", ErrInternal, err)
}
n, err := event.NewSyslogSink(format.String(), opts...)
if err != nil {
return fmt.Errorf("error creating syslog sink node: %w", err)
}
err = b.wrapMetrics(name, sinkNodeID, n)
if err != nil {
return err
}
return nil
}
// Reload will trigger the reload action on the sink node for this backend.
func (b *SyslogBackend) Reload() error {
return nil
}
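And the syslog equivalent, where facility and tag default to "AUTH" and "vault"; creating the sink may require a local syslog daemon, so this is purely an illustrative sketch rather than part of this diff:

package main

import (
	"log"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/audit"
	"github.com/hashicorp/vault/sdk/helper/salt"
	"github.com/hashicorp/vault/sdk/logical"
)

func main() {
	be, err := audit.NewSyslogBackend(&audit.BackendConfig{
		MountPath:  "syslog/",
		SaltConfig: &salt.Config{},
		SaltView:   &logical.InmemStorage{},
		Logger:     hclog.NewNullLogger(),
		Config: map[string]string{
			"facility": "daemon",
			"tag":      "vault-audit",
			"format":   "json",
		},
	}, &audit.NoopHeaderFormatter{})
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("syslog audit device %q configured", be.Name())
}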

View File

@@ -0,0 +1,119 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestSyslogBackend_newSyslogBackend tests the ways we can try to create a new
// SyslogBackend both good and bad.
func TestSyslogBackend_newSyslogBackend(t *testing.T) {
t.Parallel()
tests := map[string]struct {
mountPath string
format string
tag string
facility string
wantErr bool
expectedErrMsg string
expectedName string
}{
"name-empty": {
mountPath: "",
wantErr: true,
expectedErrMsg: "mount path cannot be empty: invalid configuration",
},
"name-whitespace": {
mountPath: " ",
wantErr: true,
expectedErrMsg: "mount path cannot be empty: invalid configuration",
},
"format-empty": {
mountPath: "foo",
format: "",
wantErr: true,
expectedErrMsg: "unsupported \"format\": invalid configuration",
},
"format-whitespace": {
mountPath: "foo",
format: " ",
wantErr: true,
expectedErrMsg: "unsupported \"format\": invalid configuration",
},
"happy": {
mountPath: "foo",
format: "json",
wantErr: false,
expectedName: "foo",
},
"happy-tag": {
mountPath: "foo",
format: "json",
tag: "beep",
wantErr: false,
expectedName: "foo",
},
"happy-facility": {
mountPath: "foo",
format: "json",
facility: "daemon",
wantErr: false,
expectedName: "foo",
},
"happy-all": {
mountPath: "foo",
format: "json",
tag: "beep",
facility: "daemon",
wantErr: false,
expectedName: "foo",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
cfg := &BackendConfig{
SaltView: &logical.InmemStorage{},
SaltConfig: &salt.Config{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"tag": tc.tag,
"facility": tc.facility,
"format": tc.format,
},
MountPath: tc.mountPath,
}
b, err := newSyslogBackend(cfg, &NoopHeaderFormatter{})
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrMsg)
require.Nil(t, b)
} else {
require.NoError(t, err)
require.Len(t, b.nodeIDList, 2)
require.Len(t, b.nodeMap, 2)
id := b.nodeIDList[1]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
mc, ok := node.(*event.MetricsCounter)
require.True(t, ok)
require.Equal(t, tc.expectedName, mc.Name)
}
})
}
}

146
audit/backend_test.go Normal file
View File

@@ -0,0 +1,146 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
)
// TestBackend_newFormatterConfig ensures that all the configuration values are
// parsed correctly when trying to create a new formatterConfig via newFormatterConfig.
func TestBackend_newFormatterConfig(t *testing.T) {
t.Parallel()
tests := map[string]struct {
config map[string]string
want formatterConfig
wantErr bool
expectedMessage string
}{
"happy-path-json": {
config: map[string]string{
"format": JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: formatterConfig{
raw: true,
hmacAccessor: true,
elideListResponses: true,
requiredFormat: "json",
}, wantErr: false,
},
"happy-path-jsonx": {
config: map[string]string{
"format": JSONxFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: formatterConfig{
raw: true,
hmacAccessor: true,
elideListResponses: true,
requiredFormat: "jsonx",
},
wantErr: false,
},
"invalid-format": {
config: map[string]string{
"format": " squiggly ",
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: formatterConfig{},
wantErr: true,
expectedMessage: "unsupported \"format\": invalid configuration",
},
"invalid-hmac-accessor": {
config: map[string]string{
"format": JSONFormat.String(),
"hmac_accessor": "maybe",
},
want: formatterConfig{},
wantErr: true,
expectedMessage: "unable to parse \"hmac_accessor\": invalid configuration",
},
"invalid-log-raw": {
config: map[string]string{
"format": JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "maybe",
},
want: formatterConfig{},
wantErr: true,
expectedMessage: "unable to parse \"log_raw\": invalid configuration",
},
"invalid-elide-bool": {
config: map[string]string{
"format": JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "maybe",
},
want: formatterConfig{},
wantErr: true,
expectedMessage: "unable to parse \"elide_list_responses\": invalid configuration",
},
"prefix": {
config: map[string]string{
"format": JSONFormat.String(),
"prefix": "foo",
},
want: formatterConfig{
requiredFormat: JSONFormat,
prefix: "foo",
hmacAccessor: true,
},
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
got, err := newFormatterConfig(&NoopHeaderFormatter{}, tc.config)
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedMessage)
} else {
require.NoError(t, err)
}
require.Equal(t, tc.want.requiredFormat, got.requiredFormat)
require.Equal(t, tc.want.raw, got.raw)
require.Equal(t, tc.want.elideListResponses, got.elideListResponses)
require.Equal(t, tc.want.hmacAccessor, got.hmacAccessor)
require.Equal(t, tc.want.omitTime, got.omitTime)
require.Equal(t, tc.want.prefix, got.prefix)
})
}
}
// TestBackend_configureFormatterNode ensures that configureFormatterNode
// populates the nodeIDList and nodeMap on backend when given valid config.
func TestBackend_configureFormatterNode(t *testing.T) {
t.Parallel()
b, err := newBackend(&NoopHeaderFormatter{}, &BackendConfig{
MountPath: "foo",
Logger: hclog.NewNullLogger(),
})
require.NoError(t, err)
require.Len(t, b.nodeIDList, 1)
require.Len(t, b.nodeMap, 1)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeFormatter, node.Type())
}

426
audit/broker.go Normal file
View File

@@ -0,0 +1,426 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package audit
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/logical"
)
var (
_ Registrar = (*Broker)(nil)
_ Auditor = (*Broker)(nil)
)
// Registrar interface describes a means to register and deregister audit devices.
type Registrar interface {
Register(backend Backend, local bool) error
Deregister(ctx context.Context, name string) error
IsRegistered(name string) bool
IsLocal(name string) (bool, error)
}
// Auditor interface describes methods which can be used to perform auditing.
type Auditor interface {
LogRequest(ctx context.Context, input *logical.LogInput) error
LogResponse(ctx context.Context, input *logical.LogInput) error
GetHash(ctx context.Context, name string, input string) (string, error)
Invalidate(ctx context.Context, key string)
}
// backendEntry composes a backend with additional settings.
type backendEntry struct {
// backend is the underlying audit backend.
backend Backend
// local indicates whether this audit backend should be local to the Vault cluster.
local bool
}
// Broker represents an audit broker which performs actions such as registering/de-registering
// backends and logging audit entries for a request or response.
// NOTE: NewBroker should be used to initialize the Broker struct.
type Broker struct {
*brokerEnt
sync.RWMutex
logger hclog.Logger
// backends is the map of audit device name to its backendEntry.
backends map[string]backendEntry
// broker is used to register pipelines for audit devices.
broker *eventlogger.Broker
}
// NewBroker initializes a broker, which can be used to perform audit logging.
func NewBroker(logger hclog.Logger) (*Broker, error) {
if logger == nil || reflect.ValueOf(logger).IsNil() {
return nil, fmt.Errorf("cannot create a new audit broker with nil logger: %w", ErrInvalidParameter)
}
eventBroker, err := eventlogger.NewBroker()
if err != nil {
return nil, fmt.Errorf("error creating event broker for audit events: %w", err)
}
ent, err := newBrokerEnt()
if err != nil {
return nil, fmt.Errorf("error creating audit broker extentions: %w", err)
}
return &Broker{
backends: make(map[string]backendEntry),
broker: eventBroker,
brokerEnt: ent,
logger: logger,
}, nil
}
// hasAuditPipelines can be used as a shorthand to check if a broker has any
// registered pipelines that are for the audit event type.
func hasAuditPipelines(broker *eventlogger.Broker) bool {
return broker.IsAnyPipelineRegistered(event.AuditType.AsEventType())
}
// isRegistered checks whether the supplied backend can be registered, returning an error
// if a backend with the same name is already registered or if the registration request is
// invalid. This method should be used within the broker to prevent locking issues.
func (b *Broker) isRegistered(backend Backend) error {
if b.isRegisteredByName(backend.Name()) {
return fmt.Errorf("backend already registered '%s': %w", backend.Name(), ErrExternalOptions)
}
if err := b.validateRegistrationRequest(backend); err != nil {
return err
}
return nil
}
// isRegisteredByName returns a boolean to indicate whether an audit backend is
// registered with the broker.
func (b *Broker) isRegisteredByName(name string) bool {
_, ok := b.backends[name]
return ok
}
// register can be used to register a normal audit device; it will also calculate
// and configure the success threshold required for sinks.
// NOTE: register assumes that the backend which is being registered has not yet
// been added to the broker's backends.
func (b *Broker) register(backend Backend) error {
err := registerNodesAndPipeline(b.broker, backend)
if err != nil {
return fmt.Errorf("audit pipeline registration error: %w", err)
}
threshold := 0
if !backend.HasFiltering() {
threshold = 1
} else {
threshold = b.requiredSuccessThresholdSinks()
}
// Update the success threshold now that the pipeline is registered.
err = b.broker.SetSuccessThresholdSinks(event.AuditType.AsEventType(), threshold)
if err != nil {
return fmt.Errorf("unable to configure sink success threshold (%d): %w", threshold, err)
}
return nil
}
// deregister can be used to deregister an audit device; it will also configure
// the success threshold required for sinks.
// NOTE: deregister assumes that the backend which is being deregistered has already
// been removed from the broker's backends.
func (b *Broker) deregister(ctx context.Context, name string) error {
threshold := b.requiredSuccessThresholdSinks()
err := b.broker.SetSuccessThresholdSinks(event.AuditType.AsEventType(), threshold)
if err != nil {
return fmt.Errorf("unable to reconfigure sink success threshold (%d): %w", threshold, err)
}
// The first return value, a bool, indicates whether
// RemovePipelineAndNodes encountered the error while evaluating
// pre-conditions (false) or once it started removing the pipeline and
// the nodes (true). This code doesn't care either way.
_, err = b.broker.RemovePipelineAndNodes(ctx, event.AuditType.AsEventType(), eventlogger.PipelineID(name))
if err != nil {
return fmt.Errorf("unable to remove pipeline and nodes: %w", err)
}
return nil
}
// registerNodesAndPipeline registers eventlogger nodes and a pipeline with the
// backend's name, on the specified eventlogger.Broker using the Backend to supply them.
func registerNodesAndPipeline(broker *eventlogger.Broker, b Backend) error {
for id, node := range b.Nodes() {
err := broker.RegisterNode(id, node)
if err != nil {
return fmt.Errorf("unable to register nodes for %q: %w", b.Name(), err)
}
}
pipeline := eventlogger.Pipeline{
PipelineID: eventlogger.PipelineID(b.Name()),
EventType: b.EventType(),
NodeIDs: b.NodeIDs(),
}
err := broker.RegisterPipeline(pipeline)
if err != nil {
return fmt.Errorf("unable to register pipeline for %q: %w", b.Name(), err)
}
return nil
}
func (b *Broker) Register(backend Backend, local bool) error {
b.Lock()
defer b.Unlock()
if backend == nil || reflect.ValueOf(backend).IsNil() {
return fmt.Errorf("backend cannot be nil: %w", ErrInvalidParameter)
}
// If the backend is already registered, we cannot re-register it.
err := b.isRegistered(backend)
if err != nil {
return err
}
if err := b.handlePipelineRegistration(backend); err != nil {
return err
}
b.backends[backend.Name()] = backendEntry{
backend: backend,
local: local,
}
return nil
}
func (b *Broker) Deregister(ctx context.Context, name string) error {
b.Lock()
defer b.Unlock()
name = strings.TrimSpace(name)
if name == "" {
return fmt.Errorf("name is required: %w", ErrInvalidParameter)
}
// If the backend isn't actually registered, then there's nothing to do.
// We don't return any error so that Deregister can be idempotent.
if !b.isRegisteredByName(name) {
return nil
}
// Remove the Backend from the map first, so that if an error occurs while
// removing the pipeline and nodes, we can quickly exit this method with
// the error.
delete(b.backends, name)
if err := b.handlePipelineDeregistration(ctx, name); err != nil {
return err
}
return nil
}
// LogRequest is used to ensure all the audit backends have an opportunity to
// log the given request and that *at least one* succeeds.
func (b *Broker) LogRequest(ctx context.Context, in *logical.LogInput) (ret error) {
b.RLock()
defer b.RUnlock()
// If no backends are registered then we have no devices to log the request.
if len(b.backends) < 1 {
return nil
}
defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now())
defer func() {
metricVal := float32(0.0)
if ret != nil {
metricVal = 1.0
}
metrics.IncrCounter([]string{"audit", "log_request_failure"}, metricVal)
}()
var retErr *multierror.Error
e, err := NewEvent(RequestType)
if err != nil {
retErr = multierror.Append(retErr, err)
return retErr.ErrorOrNil()
}
e.Data = in
var status eventlogger.Status
if hasAuditPipelines(b.broker) {
status, err = b.broker.Send(ctx, event.AuditType.AsEventType(), e)
if err != nil {
retErr = multierror.Append(retErr, multierror.Append(err, status.Warnings...))
return retErr.ErrorOrNil()
}
}
// Audit event ended up in at least 1 sink.
if len(status.CompleteSinks()) > 0 {
return retErr.ErrorOrNil()
}
// There were errors from inside the pipeline and we didn't write to a sink.
if len(status.Warnings) > 0 {
retErr = multierror.Append(retErr, multierror.Append(errors.New("error during audit pipeline processing"), status.Warnings...))
return retErr.ErrorOrNil()
}
// Handle any additional audit that is required (Enterprise/CE dependent).
err = b.handleAdditionalAudit(ctx, e)
if err != nil {
retErr = multierror.Append(retErr, err)
}
return retErr.ErrorOrNil()
}
// LogResponse is used to ensure all the audit backends have an opportunity to
// log the given response and that *at least one* succeeds.
func (b *Broker) LogResponse(ctx context.Context, in *logical.LogInput) (ret error) {
b.RLock()
defer b.RUnlock()
// If no backends are registered then we have no devices to send audit entries to.
if len(b.backends) < 1 {
return nil
}
defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now())
defer func() {
metricVal := float32(0.0)
if ret != nil {
metricVal = 1.0
}
metrics.IncrCounter([]string{"audit", "log_response_failure"}, metricVal)
}()
var retErr *multierror.Error
e, err := NewEvent(ResponseType)
if err != nil {
retErr = multierror.Append(retErr, err)
return retErr.ErrorOrNil()
}
e.Data = in
// In cases where we are trying to audit the response, we detach
// ourselves from the original context (keeping only the namespace).
// This is so that we get a fair run at writing audit entries if Vault
// has taken up a lot of time handling the request before audit (response)
// is triggered. Pipeline nodes and the eventlogger.Broker may check for a
// cancelled context and refuse to process the nodes further.
ns, err := namespace.FromContext(ctx)
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("namespace missing from context: %w", err))
return retErr.ErrorOrNil()
}
auditContext, auditCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer auditCancel()
auditContext = namespace.ContextWithNamespace(auditContext, ns)
var status eventlogger.Status
if hasAuditPipelines(b.broker) {
status, err = b.broker.Send(auditContext, event.AuditType.AsEventType(), e)
if err != nil {
retErr = multierror.Append(retErr, multierror.Append(err, status.Warnings...))
return retErr.ErrorOrNil()
}
}
// Audit event ended up in at least 1 sink.
if len(status.CompleteSinks()) > 0 {
return retErr.ErrorOrNil()
}
// There were errors from inside the pipeline and we didn't write to a sink.
if len(status.Warnings) > 0 {
retErr = multierror.Append(retErr, multierror.Append(errors.New("error during audit pipeline processing"), status.Warnings...))
return retErr.ErrorOrNil()
}
// Handle any additional audit that is required (Enterprise/CE dependent).
err = b.handleAdditionalAudit(auditContext, e)
if err != nil {
retErr = multierror.Append(retErr, err)
}
return retErr.ErrorOrNil()
}
func (b *Broker) Invalidate(ctx context.Context, _ string) {
// For now, we ignore the key as this would only apply to salts.
// We just sort of brute force it on each one.
b.Lock()
defer b.Unlock()
for _, be := range b.backends {
be.backend.Invalidate(ctx)
}
}
// IsLocal reports whether a given registered audit backend is local to the Vault cluster.
func (b *Broker) IsLocal(name string) (bool, error) {
b.RLock()
defer b.RUnlock()
be, ok := b.backends[name]
if ok {
return be.local, nil
}
return false, fmt.Errorf("unknown audit backend %q", name)
}
// GetHash returns a hash using the salt of the given backend
func (b *Broker) GetHash(ctx context.Context, name string, input string) (string, error) {
b.RLock()
defer b.RUnlock()
be, ok := b.backends[name]
if !ok {
return "", fmt.Errorf("unknown audit backend %q", name)
}
return HashString(ctx, be.backend, input)
}
// IsRegistered is used to check if a given audit backend is registered.
func (b *Broker) IsRegistered(name string) bool {
b.RLock()
defer b.RUnlock()
return b.isRegisteredByName(name)
}
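The detached-context pattern in LogResponse above is worth calling out. The sketch below is illustrative only (the helper name and the standalone main are assumptions, not part of this change); it shows how a fresh context can keep just the caller's namespace plus a bounded deadline, mirroring the 5 second budget used above.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/hashicorp/vault/helper/namespace"
)

// detachedAuditContext keeps only the namespace from the caller's context and
// gives the audit write its own bounded deadline, as LogResponse does above.
func detachedAuditContext(ctx context.Context) (context.Context, context.CancelFunc, error) {
	ns, err := namespace.FromContext(ctx)
	if err != nil {
		return nil, nil, err
	}
	auditCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	return namespace.ContextWithNamespace(auditCtx, ns), cancel, nil
}

func main() {
	// RootContext carries the root namespace, as used in the tests further down.
	ctx := namespace.RootContext(context.Background())
	auditCtx, cancel, err := detachedAuditContext(ctx)
	if err != nil {
		panic(err)
	}
	defer cancel()
	deadline, _ := auditCtx.Deadline()
	fmt.Println("audit context deadline:", deadline)
}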

49
audit/broker_ce.go Normal file
View File

@@ -0,0 +1,49 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package audit
import (
"context"
"fmt"
)
// brokerEnt provides Enterprise-only extensions to the broker behavior; in the community edition it adds nothing.
type brokerEnt struct{}
func newBrokerEnt() (*brokerEnt, error) {
return &brokerEnt{}, nil
}
func (b *Broker) validateRegistrationRequest(_ Backend) error {
return nil
}
func (b *Broker) handlePipelineRegistration(backend Backend) error {
err := b.register(backend)
if err != nil {
return fmt.Errorf("unable to register device for %q: %w", backend.Name(), err)
}
return nil
}
func (b *Broker) handlePipelineDeregistration(ctx context.Context, name string) error {
return b.deregister(ctx, name)
}
// requiredSuccessThresholdSinks returns the value that should be used as the
// success threshold for sink nodes in the eventlogger broker.
func (b *Broker) requiredSuccessThresholdSinks() int {
if len(b.backends) > 0 {
return 1
}
return 0
}
func (b *brokerEnt) handleAdditionalAudit(_ context.Context, _ *AuditEvent) error {
return nil
}
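To make the community/Enterprise split above concrete: shared Broker code always calls an edition hook, and the community edition (the !enterprise build above) supplies a no-op. The snippet below is a minimal, self-contained sketch of that pattern with made-up names; it is not Vault code.

package main

import (
	"context"
	"fmt"
)

// editionHooks is the community-edition stub: no extra behavior.
type editionHooks struct{}

func (editionHooks) handleAdditionalAudit(context.Context, string) error { return nil }

// broker embeds the hooks; an enterprise build would embed a different implementation.
type broker struct {
	editionHooks
}

func (b *broker) log(ctx context.Context, msg string) error {
	fmt.Println("audit:", msg)
	// Edition-specific auditing hook; a no-op in the community edition.
	return b.handleAdditionalAudit(ctx, msg)
}

func main() {
	b := &broker{}
	_ = b.log(context.Background(), "request received")
}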

View File

@@ -1,7 +1,7 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
-package vault
+package audit
import (
"context"
@@ -10,9 +10,6 @@ import (
"time"
"github.com/hashicorp/go-hclog"
-"github.com/hashicorp/vault/audit"
-"github.com/hashicorp/vault/builtin/audit/file"
-"github.com/hashicorp/vault/builtin/audit/syslog"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/sdk/helper/salt"
@@ -21,13 +18,12 @@ import (
)
// testAuditBackend will create an audit.Backend (which expects to use the eventlogger).
-// NOTE: this will create the backend, it does not care whether or not Enterprise
-// only options are in place.
-func testAuditBackend(t *testing.T, path string, config map[string]string) audit.Backend {
+// NOTE: this will create the backend, it does not care whether Enterprise only options are in place.
+func testAuditBackend(t *testing.T, path string, config map[string]string) Backend {
t.Helper()
-headersCfg := &AuditedHeadersConfig{
-headerSettings: make(map[string]*auditedHeaderSettings),
+headersCfg := &HeadersConfig{
+headerSettings: make(map[string]*HeaderSettings),
view: nil,
}
@@ -36,7 +32,7 @@ func testAuditBackend(t *testing.T, path string, config map[string]string) audit
err := view.Put(context.Background(), se)
require.NoError(t, err)
-cfg := &audit.BackendConfig{
+cfg := &BackendConfig{
SaltView: view,
SaltConfig: &salt.Config{
HMAC: sha256.New,
@@ -47,7 +43,7 @@ func testAuditBackend(t *testing.T, path string, config map[string]string) audit
MountPath: path,
}
-be, err := syslog.Factory(cfg, headersCfg)
+be, err := NewSyslogBackend(cfg, headersCfg)
require.NoError(t, err)
require.NotNil(t, be)
@@ -60,7 +56,7 @@ func TestAuditBroker_Deregister_Multiple(t *testing.T) {
t.Parallel()
l := corehelpers.NewTestLogger(t)
-a, err := NewAuditBroker(l)
+a, err := NewBroker(l)
require.NoError(t, err)
require.NotNil(t, a)
@@ -77,17 +73,17 @@ func TestAuditBroker_Register_MultipleFails(t *testing.T) {
t.Parallel()
l := corehelpers.NewTestLogger(t)
-a, err := NewAuditBroker(l)
+a, err := NewBroker(l)
require.NoError(t, err)
require.NotNil(t, a)
path := "b2-no-filter"
noFilterBackend := testAuditBackend(t, path, map[string]string{})
-err = a.Register(path, noFilterBackend, false)
+err = a.Register(noFilterBackend, false)
require.NoError(t, err)
-err = a.Register(path, noFilterBackend, false)
+err = a.Register(noFilterBackend, false)
require.Error(t, err)
require.EqualError(t, err, "backend already registered 'b2-no-filter': invalid configuration")
}
@@ -108,7 +104,7 @@ func TestAuditBroker_Register_MultipleFails(t *testing.T) {
// formatter nodes format the events (to JSON/JSONX and perform HMACing etc)
// sink nodes handle sending the formatted data to a file, syslog or socket.
func BenchmarkAuditBroker_File_Request_DevNull(b *testing.B) {
-backendConfig := &audit.BackendConfig{
+backendConfig := &BackendConfig{
Config: map[string]string{
"path": "/dev/null",
},
@@ -118,13 +114,13 @@ func BenchmarkAuditBroker_File_Request_DevNull(b *testing.B) {
Logger: hclog.NewNullLogger(),
}
-sink, err := file.Factory(backendConfig, nil)
+sink, err := NewFileBackend(backendConfig, nil)
require.NoError(b, err)
-broker, err := NewAuditBroker(nil)
+broker, err := NewBroker(nil)
require.NoError(b, err)
-err = broker.Register("test", sink, false)
+err = broker.Register(sink, false)
require.NoError(b, err)
in := &logical.LogInput{

View File

@@ -14,18 +14,18 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)
-var _ eventlogger.Node = (*EntryFilter)(nil)
+var _ eventlogger.Node = (*entryFilter)(nil)
-// EntryFilter should be used to filter audit requests and responses which should
+// entryFilter should be used to filter audit requests and responses which should
// make it to a sink.
-type EntryFilter struct {
+type entryFilter struct {
// the evaluator for the bexpr expression that should be applied by the node.
evaluator *bexpr.Evaluator
}
-// NewEntryFilter should be used to create an EntryFilter node.
+// newEntryFilter should be used to create an entryFilter node.
// The filter supplied should be in bexpr format and reference fields from logical.LogInputBexpr.
-func NewEntryFilter(filter string) (*EntryFilter, error) {
+func newEntryFilter(filter string) (*entryFilter, error) {
filter = strings.TrimSpace(filter)
if filter == "" {
return nil, fmt.Errorf("cannot create new audit filter with empty filter expression: %w", ErrExternalOptions)
@@ -45,22 +45,22 @@ func NewEntryFilter(filter string) (*EntryFilter, error) {
return nil, fmt.Errorf("filter references an unsupported field: %s: %w", filter, ErrExternalOptions)
}
-return &EntryFilter{evaluator: eval}, nil
+return &entryFilter{evaluator: eval}, nil
}
// Reopen is a no-op for the filter node.
-func (*EntryFilter) Reopen() error {
+func (*entryFilter) Reopen() error {
return nil
}
// Type describes the type of this node (filter).
-func (*EntryFilter) Type() eventlogger.NodeType {
+func (*entryFilter) Type() eventlogger.NodeType {
return eventlogger.NodeTypeFilter
}
// Process will attempt to parse the incoming event data and decide whether it
// should be filtered or remain in the pipeline and passed to the next node.
-func (f *EntryFilter) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) {
+func (f *entryFilter) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
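For a sense of what a filter expression does at runtime, the standalone sketch below evaluates a go-bexpr expression of the same shape accepted by newEntryFilter above. The struct is a stand-in for logical.LogInputBexpr; its exact field set and tags here are illustrative assumptions.

package main

import (
	"fmt"

	"github.com/hashicorp/go-bexpr"
)

// auditDatum mimics the kind of flattened audit fields a filter can reference.
type auditDatum struct {
	MountType string `bexpr:"mount_type"`
	Namespace string `bexpr:"namespace"`
	Operation string `bexpr:"operation"`
}

func main() {
	// Same shape of expression used by the filter node and its tests below.
	eval, err := bexpr.CreateEvaluator("mount_type == kv and operation == read")
	if err != nil {
		panic(err)
	}
	match, err := eval.Evaluate(auditDatum{MountType: "kv", Namespace: "root", Operation: "read"})
	if err != nil {
		panic(err)
	}
	// A match is what the filter node uses to decide whether the event
	// continues down the pipeline towards the sink.
	fmt.Println(match) // true
}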

View File

@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/require"
)
-// TestEntryFilter_NewEntryFilter tests that we can create EntryFilter types correctly.
+// TestEntryFilter_NewEntryFilter tests that we can create entryFilter types correctly.
func TestEntryFilter_NewEntryFilter(t *testing.T) {
t.Parallel()
@@ -72,7 +72,7 @@ func TestEntryFilter_NewEntryFilter(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
-f, err := NewEntryFilter(tc.Filter)
+f, err := newEntryFilter(tc.Filter)
switch {
case tc.IsErrorExpected:
require.Error(t, err)
@@ -90,7 +90,7 @@ func TestEntryFilter_NewEntryFilter(t *testing.T) {
func TestEntryFilter_Reopen(t *testing.T) {
t.Parallel()
-f := &EntryFilter{}
+f := &entryFilter{}
res := f.Reopen()
require.Nil(t, res)
}
@@ -99,7 +99,7 @@ func TestEntryFilter_Reopen(t *testing.T) {
func TestEntryFilter_Type(t *testing.T) {
t.Parallel()
-f := &EntryFilter{}
+f := &entryFilter{}
require.Equal(t, eventlogger.NodeTypeFilter, f.Type())
}
@@ -113,7 +113,7 @@ func TestEntryFilter_Process_ContextDone(t *testing.T) {
// Explicitly cancel the context
cancel()
-l, err := NewEntryFilter("operation == foo")
+l, err := newEntryFilter("operation == foo")
require.NoError(t, err)
// Fake audit event
@@ -142,7 +142,7 @@ func TestEntryFilter_Process_ContextDone(t *testing.T) {
func TestEntryFilter_Process_NilEvent(t *testing.T) {
t.Parallel()
-l, err := NewEntryFilter("operation == foo")
+l, err := newEntryFilter("operation == foo")
require.NoError(t, err)
e, err := l.Process(context.Background(), nil)
require.Error(t, err)
@@ -158,7 +158,7 @@ func TestEntryFilter_Process_NilEvent(t *testing.T) {
func TestEntryFilter_Process_BadPayload(t *testing.T) {
t.Parallel()
-l, err := NewEntryFilter("operation == foo")
+l, err := newEntryFilter("operation == foo")
require.NoError(t, err)
e := &eventlogger.Event{
@@ -181,7 +181,7 @@ func TestEntryFilter_Process_BadPayload(t *testing.T) {
func TestEntryFilter_Process_NoAuditDataInPayload(t *testing.T) {
t.Parallel()
-l, err := NewEntryFilter("operation == foo")
+l, err := newEntryFilter("operation == foo")
require.NoError(t, err)
a, err := NewEvent(RequestType)
@@ -209,7 +209,7 @@ func TestEntryFilter_Process_NoAuditDataInPayload(t *testing.T) {
func TestEntryFilter_Process_FilterSuccess(t *testing.T) {
t.Parallel()
-l, err := NewEntryFilter("mount_type == juan")
+l, err := newEntryFilter("mount_type == juan")
require.NoError(t, err)
a, err := NewEvent(RequestType)
@@ -242,7 +242,7 @@ func TestEntryFilter_Process_FilterSuccess(t *testing.T) {
func TestEntryFilter_Process_FilterFail(t *testing.T) {
t.Parallel()
-l, err := NewEntryFilter("mount_type == john and operation == create and namespace == root")
+l, err := newEntryFilter("mount_type == john and operation == create and namespace == root")
require.NoError(t, err)
a, err := NewEvent(RequestType)

View File

@@ -24,10 +24,7 @@ import (
"github.com/jefferai/jsonx"
)
-var (
-_ Formatter = (*EntryFormatter)(nil)
-_ eventlogger.Node = (*EntryFormatter)(nil)
-)
+var _ eventlogger.Node = (*entryFormatter)(nil)
// timeProvider offers a way to supply a pre-configured time.
type timeProvider interface {
@@ -35,11 +32,14 @@ type timeProvider interface {
formattedTime() string
}
-// FormatterConfig is used to provide basic configuration to a formatter.
-// Use NewFormatterConfig to initialize the FormatterConfig struct.
-type FormatterConfig struct {
-Raw bool
-HMACAccessor bool
+// nonPersistentSalt is used for obtaining a salt that is not persisted.
+type nonPersistentSalt struct{}
+// formatterConfig is used to provide basic configuration to a formatter.
+// Use newFormatterConfig to initialize the formatterConfig struct.
+type formatterConfig struct {
+raw bool
+hmacAccessor bool
// Vault lacks pagination in its APIs. As a result, certain list operations can return **very** large responses.
// The user's chosen audit sinks may experience difficulty consuming audit records that swell to tens of megabytes
@@ -61,55 +61,32 @@ type FormatterConfig struct {
// The elision replaces the values of the "keys" and "key_info" fields with an integer count of the number of
// entries. This allows even the elided audit logs to still be useful for answering questions like
// "Was any data returned?" or "How many records were listed?".
-ElideListResponses bool
+elideListResponses bool
// This should only ever be used in a testing context
-OmitTime bool
+omitTime bool
// The required/target format for the event (supported: JSONFormat and JSONxFormat).
-RequiredFormat format
+requiredFormat format
// headerFormatter specifies the formatter used for headers that existing in any incoming audit request.
headerFormatter HeaderFormatter
-// Prefix specifies a Prefix that should be prepended to any formatted request or response before serialization.
-Prefix string
+// prefix specifies a prefix that should be prepended to any formatted request or response before serialization.
+prefix string
}
-// EntryFormatter should be used to format audit requests and responses.
-// NOTE: Use NewEntryFormatter to initialize the EntryFormatter struct.
-type EntryFormatter struct {
-config FormatterConfig
+// entryFormatter should be used to format audit requests and responses.
+// NOTE: Use newEntryFormatter to initialize the entryFormatter struct.
+type entryFormatter struct {
+config formatterConfig
salter Salter
logger hclog.Logger
name string
}
-// NewFormatterConfig should be used to create a FormatterConfig.
-// Accepted options: WithElision, WithFormat, WithHMACAccessor, WithOmitTime, WithPrefix, WithRaw.
-func NewFormatterConfig(headerFormatter HeaderFormatter, opt ...Option) (FormatterConfig, error) {
-if headerFormatter == nil || reflect.ValueOf(headerFormatter).IsNil() {
-return FormatterConfig{}, fmt.Errorf("header formatter is required: %w", ErrInvalidParameter)
-}
-opts, err := getOpts(opt...)
-if err != nil {
-return FormatterConfig{}, err
-}
-return FormatterConfig{
-headerFormatter: headerFormatter,
-ElideListResponses: opts.withElision,
-HMACAccessor: opts.withHMACAccessor,
-OmitTime: opts.withOmitTime,
-Prefix: opts.withPrefix,
-Raw: opts.withRaw,
-RequiredFormat: opts.withFormat,
-}, nil
-}
-// NewEntryFormatter should be used to create an EntryFormatter.
-func NewEntryFormatter(name string, config FormatterConfig, salter Salter, logger hclog.Logger) (*EntryFormatter, error) {
+// newEntryFormatter should be used to create an entryFormatter.
+func newEntryFormatter(name string, config formatterConfig, salter Salter, logger hclog.Logger) (*entryFormatter, error) {
name = strings.TrimSpace(name)
if name == "" {
return nil, fmt.Errorf("name is required: %w", ErrInvalidParameter)
@@ -123,7 +100,7 @@ func NewEntryFormatter(name string, config FormatterConfig, salter Salter, logge
return nil, fmt.Errorf("cannot create a new audit formatter with nil logger: %w", ErrInvalidParameter)
}
-return &EntryFormatter{
+return &entryFormatter{
config: config,
salter: salter,
logger: logger,
@@ -132,18 +109,18 @@ func NewEntryFormatter(name string, config FormatterConfig, salter Salter, logge
}
// Reopen is a no-op for the formatter node.
-func (*EntryFormatter) Reopen() error {
+func (*entryFormatter) Reopen() error {
return nil
}
// Type describes the type of this node (formatter).
-func (*EntryFormatter) Type() eventlogger.NodeType {
+func (*entryFormatter) Type() eventlogger.NodeType {
return eventlogger.NodeTypeFormatter
}
// Process will attempt to parse the incoming event data into a corresponding
// audit Request/Response which is serialized to JSON/JSONx and stored within the event.
-func (f *EntryFormatter) Process(ctx context.Context, e *eventlogger.Event) (_ *eventlogger.Event, retErr error) {
+func (f *entryFormatter) Process(ctx context.Context, e *eventlogger.Event) (_ *eventlogger.Event, retErr error) {
// Return early if the context was cancelled, eventlogger will not carry on
// asking nodes to process, so any sink node in the pipeline won't be called.
select {
@@ -211,14 +188,14 @@ func (f *EntryFormatter) Process(ctx context.Context, e *eventlogger.Event) (_ *
}
// Using 'any' as we have two different types that we can get back from either
-// FormatRequest or FormatResponse, but the JSON encoder doesn't care about types.
+// formatRequest or formatResponse, but the JSON encoder doesn't care about types.
var entry any
switch a.Subtype {
case RequestType:
-entry, err = f.FormatRequest(ctx, data, a)
+entry, err = f.formatRequest(ctx, data, a)
case ResponseType:
-entry, err = f.FormatResponse(ctx, data, a)
+entry, err = f.formatResponse(ctx, data, a)
default:
return nil, fmt.Errorf("unknown audit event subtype: %q", a.Subtype)
}
@@ -231,7 +208,7 @@ func (f *EntryFormatter) Process(ctx context.Context, e *eventlogger.Event) (_ *
return nil, fmt.Errorf("unable to format %s: %w", a.Subtype, err)
}
-if f.config.RequiredFormat == JSONxFormat {
+if f.config.requiredFormat == JSONxFormat {
var err error
result, err = jsonx.EncodeJSONBytes(result)
if err != nil {
@@ -246,8 +223,8 @@ func (f *EntryFormatter) Process(ctx context.Context, e *eventlogger.Event) (_ *
// don't support a prefix just sitting there.
// However, this would be a breaking change to how Vault currently works to
// include the prefix as part of the JSON object or XML document.
-if f.config.Prefix != "" {
-result = append([]byte(f.config.Prefix), result...)
+if f.config.prefix != "" {
+result = append([]byte(f.config.prefix), result...)
}
// Copy some properties from the event (and audit event) and store the
@@ -267,13 +244,13 @@ func (f *EntryFormatter) Process(ctx context.Context, e *eventlogger.Event) (_ *
Payload: a2,
}
-e2.FormattedAs(f.config.RequiredFormat.String(), result)
+e2.FormattedAs(f.config.requiredFormat.String(), result)
return e2, nil
}
-// FormatRequest attempts to format the specified logical.LogInput into a RequestEntry.
-func (f *EntryFormatter) FormatRequest(ctx context.Context, in *logical.LogInput, provider timeProvider) (*RequestEntry, error) {
+// formatRequest attempts to format the specified logical.LogInput into a RequestEntry.
+func (f *entryFormatter) formatRequest(ctx context.Context, in *logical.LogInput, provider timeProvider) (*RequestEntry, error) {
switch {
case in == nil || in.Request == nil:
return nil, errors.New("request to request-audit a nil request")
@@ -293,14 +270,14 @@ func (f *EntryFormatter) FormatRequest(ctx context.Context, in *logical.LogInput
connState = in.Request.Connection.ConnState
}
-if !f.config.Raw {
+if !f.config.raw {
var err error
-auth, err = HashAuth(ctx, f.salter, auth, f.config.HMACAccessor)
+auth, err = HashAuth(ctx, f.salter, auth, f.config.hmacAccessor)
if err != nil {
return nil, err
}
-req, err = HashRequest(ctx, f.salter, req, f.config.HMACAccessor, in.NonHMACReqDataKeys)
+req, err = HashRequest(ctx, f.salter, req, f.config.hmacAccessor, in.NonHMACReqDataKeys)
if err != nil {
return nil, err
}
@@ -395,7 +372,7 @@ func (f *EntryFormatter) FormatRequest(ctx context.Context, in *logical.LogInput
reqEntry.Request.WrapTTL = int(req.WrapInfo.TTL / time.Second)
}
-if !f.config.OmitTime {
+if !f.config.omitTime {
// Use the time provider to supply the time for this entry.
reqEntry.Time = provider.formattedTime()
}
@@ -403,8 +380,8 @@ func (f *EntryFormatter) FormatRequest(ctx context.Context, in *logical.LogInput
return reqEntry, nil
}
-// FormatResponse attempts to format the specified logical.LogInput into a ResponseEntry.
-func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInput, provider timeProvider) (*ResponseEntry, error) {
+// formatResponse attempts to format the specified logical.LogInput into a ResponseEntry.
+func (f *entryFormatter) formatResponse(ctx context.Context, in *logical.LogInput, provider timeProvider) (*ResponseEntry, error) {
switch {
case f == nil:
return nil, errors.New("formatter is nil")
@@ -428,10 +405,10 @@ func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInpu
connState = in.Request.Connection.ConnState
}
-elideListResponseData := f.config.ElideListResponses && req.Operation == logical.ListOperation
+elideListResponseData := f.config.elideListResponses && req.Operation == logical.ListOperation
var respData map[string]interface{}
-if f.config.Raw {
+if f.config.raw {
// In the non-raw case, elision of list response data occurs inside HashResponse, to avoid redundant deep
// copies and hashing of data only to elide it later. In the raw case, we need to do it here.
if elideListResponseData && resp.Data != nil {
@@ -447,17 +424,17 @@ func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInpu
}
} else {
var err error
-auth, err = HashAuth(ctx, f.salter, auth, f.config.HMACAccessor)
+auth, err = HashAuth(ctx, f.salter, auth, f.config.hmacAccessor)
if err != nil {
return nil, err
}
-req, err = HashRequest(ctx, f.salter, req, f.config.HMACAccessor, in.NonHMACReqDataKeys)
+req, err = HashRequest(ctx, f.salter, req, f.config.hmacAccessor, in.NonHMACReqDataKeys)
if err != nil {
return nil, err
}
-resp, err = HashResponse(ctx, f.salter, resp, f.config.HMACAccessor, in.NonHMACRespDataKeys, elideListResponseData)
+resp, err = HashResponse(ctx, f.salter, resp, f.config.hmacAccessor, in.NonHMACRespDataKeys, elideListResponseData)
if err != nil {
return nil, err
}
@@ -616,7 +593,7 @@ func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInpu
respEntry.Request.WrapTTL = int(req.WrapInfo.TTL / time.Second)
}
-if !f.config.OmitTime {
+if !f.config.omitTime {
// Use the time provider to supply the time for this entry.
respEntry.Time = provider.formattedTime()
}
@@ -674,7 +651,7 @@ func parseVaultTokenFromJWT(token string) *string {
// determined it should apply to a particular request. The data map that is passed in must be a copy that is safe to
// modify in place, but need not be a full recursive deep copy, as only top-level keys are changed.
//
-// See the documentation of the controlling option in FormatterConfig for more information on the purpose.
+// See the documentation of the controlling option in formatterConfig for more information on the purpose.
func doElideListResponseData(data map[string]interface{}) {
for k, v := range data {
if k == "keys" {
@@ -689,9 +666,9 @@ func doElideListResponseData(data map[string]interface{}) {
}
}
-// newTemporaryEntryFormatter creates a cloned EntryFormatter instance with a non-persistent Salter.
-func newTemporaryEntryFormatter(n *EntryFormatter) *EntryFormatter {
-return &EntryFormatter{
+// newTemporaryEntryFormatter creates a cloned entryFormatter instance with a non-persistent Salter.
+func newTemporaryEntryFormatter(n *entryFormatter) *entryFormatter {
+return &entryFormatter{
salter: &nonPersistentSalt{},
config: n.config,
}
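As a standalone illustration of the elision described in the comments above (this is not the package's implementation), the snippet below replaces the "keys" and "key_info" values of a list response with entry counts, which is what keeps elided audit records small while still answering "how many records were listed?".

package main

import "fmt"

// elideListData mirrors the documented behavior: list payloads keep their
// shape, but "keys" and "key_info" are reduced to counts of their entries.
func elideListData(data map[string]any) {
	if keys, ok := data["keys"].([]string); ok {
		data["keys"] = len(keys)
	}
	if keyInfo, ok := data["key_info"].(map[string]any); ok {
		data["key_info"] = len(keyInfo)
	}
}

func main() {
	resp := map[string]any{
		"keys":     []string{"secret-a", "secret-b", "secret-c"},
		"key_info": map[string]any{"secret-a": map[string]any{"ttl": "1h"}},
	}
	elideListData(resp)
	fmt.Println(resp) // map[key_info:1 keys:3]
}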

View File

@@ -60,7 +60,7 @@ const testFormatJSONReqBasicStrFmt = `
`
// testHeaderFormatter is a stub to prevent the need to import the vault package
-// to bring in vault.AuditedHeadersConfig for testing.
+// to bring in vault.HeadersConfig for testing.
type testHeaderFormatter struct {
shouldReturnEmpty bool
}
@@ -86,7 +86,7 @@ func (p *testTimeProvider) formattedTime() string {
return time.Date(2024, time.March, 22, 10, 0o0, 5, 10, time.UTC).UTC().Format(time.RFC3339Nano)
}
-// TestNewEntryFormatter ensures we can create new EntryFormatter structs.
+// TestNewEntryFormatter ensures we can create new entryFormatter structs.
func TestNewEntryFormatter(t *testing.T) {
t.Parallel()
@@ -94,7 +94,7 @@ func TestNewEntryFormatter(t *testing.T) {
Name string
UseStaticSalt bool
Logger hclog.Logger
-Options []Option // Only supports WithPrefix
+Options map[string]string
IsErrorExpected bool
ExpectedErrorMessage string
ExpectedFormat format
@@ -128,8 +128,8 @@ func TestNewEntryFormatter(t *testing.T) {
UseStaticSalt: true,
Logger: hclog.NewNullLogger(),
IsErrorExpected: false,
-Options: []Option{
-WithFormat(JSONFormat.String()),
+Options: map[string]string{
+"format": "json",
},
ExpectedFormat: JSONFormat,
},
@@ -144,8 +144,8 @@ func TestNewEntryFormatter(t *testing.T) {
Name: "juan",
UseStaticSalt: true,
Logger: hclog.NewNullLogger(),
-Options: []Option{
-WithFormat(JSONFormat.String()),
+Options: map[string]string{
+"format": "json",
},
IsErrorExpected: false,
ExpectedFormat: JSONFormat,
@@ -154,8 +154,8 @@ func TestNewEntryFormatter(t *testing.T) {
Name: "juan",
UseStaticSalt: true,
Logger: hclog.NewNullLogger(),
-Options: []Option{
-WithFormat(JSONxFormat.String()),
+Options: map[string]string{
+"format": "jsonx",
},
IsErrorExpected: false,
ExpectedFormat: JSONxFormat,
@@ -164,9 +164,9 @@ func TestNewEntryFormatter(t *testing.T) {
Name: "juan",
UseStaticSalt: true,
Logger: hclog.NewNullLogger(),
-Options: []Option{
-WithPrefix("foo"),
-WithFormat(JSONFormat.String()),
+Options: map[string]string{
+"prefix": "foo",
+"format": "json",
},
IsErrorExpected: false,
ExpectedFormat: JSONFormat,
@@ -176,9 +176,9 @@ func TestNewEntryFormatter(t *testing.T) {
Name: "juan",
UseStaticSalt: true,
Logger: hclog.NewNullLogger(),
-Options: []Option{
-WithPrefix("foo"),
-WithFormat(JSONxFormat.String()),
+Options: map[string]string{
+"prefix": "foo",
+"format": "jsonx",
},
IsErrorExpected: false,
ExpectedFormat: JSONxFormat,
@@ -196,9 +196,9 @@ func TestNewEntryFormatter(t *testing.T) {
ss = newStaticSalt(t)
}
-cfg, err := NewFormatterConfig(&testHeaderFormatter{}, tc.Options...)
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, tc.Options)
require.NoError(t, err)
-f, err := NewEntryFormatter(tc.Name, cfg, ss, tc.Logger)
+f, err := newEntryFormatter(tc.Name, cfg, ss, tc.Logger)
switch {
case tc.IsErrorExpected:
@@ -208,8 +208,8 @@ func TestNewEntryFormatter(t *testing.T) {
default:
require.NoError(t, err)
require.NotNil(t, f)
-require.Equal(t, tc.ExpectedFormat, f.config.RequiredFormat)
-require.Equal(t, tc.ExpectedPrefix, f.config.Prefix)
+require.Equal(t, tc.ExpectedFormat, f.config.requiredFormat)
+require.Equal(t, tc.ExpectedPrefix, f.config.prefix)
}
})
}
@@ -220,10 +220,10 @@ func TestEntryFormatter_Reopen(t *testing.T) {
t.Parallel()
ss := newStaticSalt(t)
-cfg, err := NewFormatterConfig(&testHeaderFormatter{})
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, nil)
require.NoError(t, err)
-f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+f, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, f)
require.NoError(t, f.Reopen())
@@ -234,10 +234,10 @@ func TestEntryFormatter_Type(t *testing.T) {
t.Parallel()
ss := newStaticSalt(t)
-cfg, err := NewFormatterConfig(&testHeaderFormatter{})
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, nil)
require.NoError(t, err)
-f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+f, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, f)
require.Equal(t, eventlogger.NodeTypeFormatter, f.Type())
@@ -379,10 +379,10 @@ func TestEntryFormatter_Process(t *testing.T) {
require.NotNil(t, e)
ss := newStaticSalt(t)
-cfg, err := NewFormatterConfig(&testHeaderFormatter{}, WithFormat(tc.RequiredFormat.String()))
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, map[string]string{"format": tc.RequiredFormat.String()})
require.NoError(t, err)
-f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+f, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, f)
@@ -412,7 +412,7 @@ func TestEntryFormatter_Process(t *testing.T) {
}
}
-// BenchmarkAuditFileSink_Process benchmarks the EntryFormatter and then event.FileSink calling Process.
+// BenchmarkAuditFileSink_Process benchmarks the entryFormatter and then event.FileSink calling Process.
// This should replicate the original benchmark testing which used to perform both of these roles together.
func BenchmarkAuditFileSink_Process(b *testing.B) {
// Base input
@@ -444,10 +444,10 @@ func BenchmarkAuditFileSink_Process(b *testing.B) {
ctx := namespace.RootContext(context.Background())
// Create the formatter node.
-cfg, err := NewFormatterConfig(&testHeaderFormatter{})
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, nil)
require.NoError(b, err)
ss := newStaticSalt(b)
-formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+formatter, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(b, err)
require.NotNil(b, formatter)
@@ -475,7 +475,7 @@ func BenchmarkAuditFileSink_Process(b *testing.B) {
})
}
-// TestEntryFormatter_FormatRequest exercises EntryFormatter.FormatRequest with
+// TestEntryFormatter_FormatRequest exercises entryFormatter.formatRequest with
// varying inputs.
func TestEntryFormatter_FormatRequest(t *testing.T) {
t.Parallel()
@@ -522,9 +522,10 @@ func TestEntryFormatter_FormatRequest(t *testing.T) {
t.Parallel()
ss := newStaticSalt(t)
-cfg, err := NewFormatterConfig(&testHeaderFormatter{}, WithOmitTime(tc.ShouldOmitTime))
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, nil)
+cfg.omitTime = tc.ShouldOmitTime
require.NoError(t, err)
-f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+f, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
var ctx context.Context
@@ -535,7 +536,7 @@ func TestEntryFormatter_FormatRequest(t *testing.T) {
ctx = context.Background()
}
-entry, err := f.FormatRequest(ctx, tc.Input, &testTimeProvider{})
+entry, err := f.formatRequest(ctx, tc.Input, &testTimeProvider{})
switch {
case tc.IsErrorExpected:
@@ -556,7 +557,7 @@ func TestEntryFormatter_FormatRequest(t *testing.T) {
}
}
-// TestEntryFormatter_FormatResponse exercises EntryFormatter.FormatResponse with
+// TestEntryFormatter_FormatResponse exercises entryFormatter.formatResponse with
// varying inputs.
func TestEntryFormatter_FormatResponse(t *testing.T) {
t.Parallel()
@@ -604,9 +605,10 @@ func TestEntryFormatter_FormatResponse(t *testing.T) {
t.Parallel()
ss := newStaticSalt(t)
-cfg, err := NewFormatterConfig(&testHeaderFormatter{}, WithOmitTime(tc.ShouldOmitTime))
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, nil)
+cfg.omitTime = tc.ShouldOmitTime
require.NoError(t, err)
-f, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+f, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
var ctx context.Context
@@ -617,7 +619,7 @@ func TestEntryFormatter_FormatResponse(t *testing.T) {
ctx = context.Background()
}
-entry, err := f.FormatResponse(ctx, tc.Input, &testTimeProvider{})
+entry, err := f.formatResponse(ctx, tc.Input, &testTimeProvider{})
switch {
case tc.IsErrorExpected:
@@ -720,9 +722,12 @@ func TestEntryFormatter_Process_JSON(t *testing.T) {
}
for name, tc := range cases {
-cfg, err := NewFormatterConfig(&testHeaderFormatter{}, WithHMACAccessor(false), WithPrefix(tc.Prefix))
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, map[string]string{
+"hmac_accessor": "false",
+"prefix": tc.Prefix,
+})
require.NoError(t, err)
-formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+formatter, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
in := &logical.LogInput{
@@ -877,15 +882,16 @@ func TestEntryFormatter_Process_JSONx(t *testing.T) {
}
for name, tc := range cases {
-cfg, err := NewFormatterConfig(
-&testHeaderFormatter{},
-WithOmitTime(true),
-WithHMACAccessor(false),
-WithFormat(JSONxFormat.String()),
-WithPrefix(tc.Prefix),
-)
+cfg, err := newFormatterConfig(
+&testHeaderFormatter{},
+map[string]string{
+"format": "jsonx",
+"hmac_accessor": "false",
+"prefix": tc.Prefix,
+})
+cfg.omitTime = true
require.NoError(t, err)
-formatter, err := NewEntryFormatter("juan", cfg, tempStaticSalt, hclog.NewNullLogger())
+formatter, err := newEntryFormatter("juan", cfg, tempStaticSalt, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, formatter)
@@ -997,11 +1003,11 @@ func TestEntryFormatter_FormatResponse_ElideListResponses(t *testing.T) {
ss := newStaticSalt(t)
ctx := namespace.RootContext(context.Background())
-var formatter *EntryFormatter
+var formatter *entryFormatter
var err error
-format := func(t *testing.T, config FormatterConfig, operation logical.Operation, inputData map[string]any) *ResponseEntry {
-formatter, err = NewEntryFormatter("juan", config, ss, hclog.NewNullLogger())
+format := func(t *testing.T, config formatterConfig, operation logical.Operation, inputData map[string]any) *ResponseEntry {
+formatter, err = newEntryFormatter("juan", config, ss, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, formatter)
@@ -1010,14 +1016,14 @@ func TestEntryFormatter_FormatResponse_ElideListResponses(t *testing.T) {
Response: &logical.Response{Data: inputData},
}
-resp, err := formatter.FormatResponse(ctx, in, &testTimeProvider{})
+resp, err := formatter.formatResponse(ctx, in, &testTimeProvider{})
require.NoError(t, err)
return resp
}
t.Run("Default case", func(t *testing.T) {
-config, err := NewFormatterConfig(&testHeaderFormatter{}, WithElision(true))
+config, err := newFormatterConfig(&testHeaderFormatter{}, map[string]string{"elide_list_responses": "true"})
require.NoError(t, err)
for name, tc := range tests {
name := name
@@ -1030,23 +1036,30 @@ func TestEntryFormatter_FormatResponse_ElideListResponses(t *testing.T) {
})
t.Run("When Operation is not list, eliding does not happen", func(t *testing.T) {
-config, err := NewFormatterConfig(&testHeaderFormatter{}, WithElision(true))
+config, err := newFormatterConfig(&testHeaderFormatter{}, map[string]string{"elide_list_responses": "true"})
require.NoError(t, err)
tc := oneInterestingTestCase
entry := format(t, config, logical.ReadOperation, tc.inputData)
assert.Equal(t, formatter.hashExpectedValueForComparison(tc.inputData), entry.Response.Data)
})
-t.Run("When ElideListResponses is false, eliding does not happen", func(t *testing.T) {
-config, err := NewFormatterConfig(&testHeaderFormatter{}, WithElision(false), WithFormat(JSONFormat.String()))
+t.Run("When elideListResponses is false, eliding does not happen", func(t *testing.T) {
+config, err := newFormatterConfig(&testHeaderFormatter{}, map[string]string{
+"elide_list_responses": "false",
+"format": "json",
+})
require.NoError(t, err)
tc := oneInterestingTestCase
entry := format(t, config, logical.ListOperation, tc.inputData)
assert.Equal(t, formatter.hashExpectedValueForComparison(tc.inputData), entry.Response.Data)
})
-t.Run("When Raw is true, eliding still happens", func(t *testing.T) {
-config, err := NewFormatterConfig(&testHeaderFormatter{}, WithElision(true), WithRaw(true), WithFormat(JSONFormat.String()))
+t.Run("When raw is true, eliding still happens", func(t *testing.T) {
+config, err := newFormatterConfig(&testHeaderFormatter{}, map[string]string{
+"elide_list_responses": "true",
+"format": "json",
+"log_raw": "true",
+})
require.NoError(t, err)
tc := oneInterestingTestCase
entry := format(t, config, logical.ListOperation, tc.inputData)
@@ -1055,15 +1068,15 @@ func TestEntryFormatter_FormatResponse_ElideListResponses(t *testing.T) {
}
// TestEntryFormatter_Process_NoMutation tests that the event returned by an
-// EntryFormatter.Process method is not the same as the one that it accepted.
+// entryFormatter.Process method is not the same as the one that it accepted.
func TestEntryFormatter_Process_NoMutation(t *testing.T) {
t.Parallel()
// Create the formatter node.
-cfg, err := NewFormatterConfig(&testHeaderFormatter{})
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, nil)
require.NoError(t, err)
ss := newStaticSalt(t)
-formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+formatter, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, formatter)
@@ -1113,17 +1126,17 @@ func TestEntryFormatter_Process_NoMutation(t *testing.T) {
require.NotEqual(t, a2, a)
}
-// TestEntryFormatter_Process_Panic tries to send data into the EntryFormatter
+// TestEntryFormatter_Process_Panic tries to send data into the entryFormatter
// which will currently cause a panic when a response is formatted due to the
// underlying hashing that is done with reflectwalk.
func TestEntryFormatter_Process_Panic(t *testing.T) {
t.Parallel()
// Create the formatter node.
-cfg, err := NewFormatterConfig(&testHeaderFormatter{})
+cfg, err := newFormatterConfig(&testHeaderFormatter{}, nil)
require.NoError(t, err)
ss := newStaticSalt(t)
-formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+formatter, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, formatter)
@@ -1174,9 +1187,9 @@ func TestEntryFormatter_Process_Panic(t *testing.T) {
}
// TestEntryFormatter_NewFormatterConfig_NilHeaderFormatter ensures we cannot
-// create a FormatterConfig using NewFormatterConfig if we supply a nil formatter.
+// create a formatterConfig using NewFormatterConfig if we supply a nil formatter.
func TestEntryFormatter_NewFormatterConfig_NilHeaderFormatter(t *testing.T) {
-_, err := NewFormatterConfig(nil)
+_, err := newFormatterConfig(nil, nil)
require.Error(t, err)
}
@@ -1187,10 +1200,10 @@ func TestEntryFormatter_Process_NeverLeaksHeaders(t *testing.T) {
t.Parallel()
// Create the formatter node.
-cfg, err := NewFormatterConfig(&testHeaderFormatter{shouldReturnEmpty: true})
+cfg, err := newFormatterConfig(&testHeaderFormatter{shouldReturnEmpty: true}, nil)
require.NoError(t, err)
ss := newStaticSalt(t)
-formatter, err := NewEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
+formatter, err := newEntryFormatter("juan", cfg, ss, hclog.NewNullLogger())
require.NoError(t, err)
require.NotNil(t, formatter)
@@ -1222,7 +1235,7 @@ func TestEntryFormatter_Process_NeverLeaksHeaders(t *testing.T) {
// hashExpectedValueForComparison replicates enough of the audit HMAC process on a piece of expected data in a test,
// so that we can use assert.Equal to compare the expected and output values.
-func (f *EntryFormatter) hashExpectedValueForComparison(input map[string]any) map[string]any {
+func (f *entryFormatter) hashExpectedValueForComparison(input map[string]any) map[string]any {
// Copy input before modifying, since we may re-use the same data in another test
copied, err := copystructure.Copy(input)
if err != nil {

View File

@@ -1,7 +1,7 @@
// Copyright (c) HashiCorp, Inc. // Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1 // SPDX-License-Identifier: BUSL-1.1
package vault package audit
import ( import (
"context" "context"
@@ -9,7 +9,6 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
@@ -17,56 +16,71 @@ import (
// requires all headers to be converted to lower case, so we just do that. // requires all headers to be converted to lower case, so we just do that.
const ( const (
// Key used in the BarrierView to store and retrieve the header config // auditedHeadersEntry is the key used in storage to store and retrieve the header config
auditedHeadersEntry = "audited-headers" auditedHeadersEntry = "audited-headers"
// Path used to create a sub view off of BarrierView
auditedHeadersSubPath = "audited-headers-config/" // AuditedHeadersSubPath is the path used to create a sub view within storage.
AuditedHeadersSubPath = "audited-headers-config/"
) )
// auditedHeadersKey returns the key at which audit header configuration is stored. type durableStorer interface {
func auditedHeadersKey() string { Get(ctx context.Context, key string) (*logical.StorageEntry, error)
return auditedHeadersSubPath + auditedHeadersEntry Put(ctx context.Context, entry *logical.StorageEntry) error
} }
type auditedHeaderSettings struct { // HeaderFormatter is an interface defining the methods of the
// vault.HeadersConfig structure needed in this package.
type HeaderFormatter interface {
// ApplyConfig returns a map of header values that consists of the
// intersection of the provided set of header values with a configured
// set of headers and will hash headers that have been configured as such.
ApplyConfig(context.Context, map[string][]string, Salter) (map[string][]string, error)
}
// AuditedHeadersKey returns the key at which audit header configuration is stored.
func AuditedHeadersKey() string {
return AuditedHeadersSubPath + auditedHeadersEntry
}
type HeaderSettings struct {
// HMAC is used to indicate whether the value of the header should be HMAC'd. // HMAC is used to indicate whether the value of the header should be HMAC'd.
HMAC bool `json:"hmac"` HMAC bool `json:"hmac"`
} }
// AuditedHeadersConfig is used by the Audit Broker to write only approved // HeadersConfig is used by the Audit Broker to write only approved
// headers to the audit logs. It uses a BarrierView to persist the settings. // headers to the audit logs. It uses a BarrierView to persist the settings.
type AuditedHeadersConfig struct { type HeadersConfig struct {
// headerSettings stores the current headers that should be audited, and their settings. // headerSettings stores the current headers that should be audited, and their settings.
headerSettings map[string]*auditedHeaderSettings headerSettings map[string]*HeaderSettings
// view is the barrier view which should be used to access underlying audit header config data. // view is the barrier view which should be used to access underlying audit header config data.
view *BarrierView view durableStorer
sync.RWMutex sync.RWMutex
} }
// NewAuditedHeadersConfig should be used to create AuditedHeadersConfig. // NewHeadersConfig should be used to create HeadersConfig.
func NewAuditedHeadersConfig(view *BarrierView) (*AuditedHeadersConfig, error) { func NewHeadersConfig(view durableStorer) (*HeadersConfig, error) {
if view == nil { if view == nil {
return nil, fmt.Errorf("barrier view cannot be nil") return nil, fmt.Errorf("barrier view cannot be nil")
} }
// This should be the only place where the AuditedHeadersConfig struct is initialized. // This should be the only place where the HeadersConfig struct is initialized.
// Store the view so that we can reload headers when we 'invalidate'. // Store the view so that we can reload headers when we 'Invalidate'.
return &AuditedHeadersConfig{ return &HeadersConfig{
view: view, view: view,
headerSettings: make(map[string]*auditedHeaderSettings), headerSettings: make(map[string]*HeaderSettings),
}, nil }, nil
} }
// header attempts to retrieve a copy of the settings associated with the specified header. // Header attempts to retrieve a copy of the settings associated with the specified header.
// The second boolean return parameter indicates whether the header existed in configuration, // The second boolean return parameter indicates whether the header existed in configuration,
// it should be checked as when 'false' the returned settings will have the default values. // it should be checked as when 'false' the returned settings will have the default values.
func (a *AuditedHeadersConfig) header(name string) (auditedHeaderSettings, bool) { func (a *HeadersConfig) Header(name string) (HeaderSettings, bool) {
a.RLock() a.RLock()
defer a.RUnlock() defer a.RUnlock()
var s auditedHeaderSettings var s HeaderSettings
v, ok := a.headerSettings[strings.ToLower(name)] v, ok := a.headerSettings[strings.ToLower(name)]
if ok { if ok {
@@ -76,25 +90,25 @@ func (a *AuditedHeadersConfig) header(name string) (auditedHeaderSettings, bool)
return s, ok return s, ok
} }
// headers returns all existing headers along with a copy of their current settings. // Headers returns all existing headers along with a copy of their current settings.
func (a *AuditedHeadersConfig) headers() map[string]auditedHeaderSettings { func (a *HeadersConfig) Headers() map[string]HeaderSettings {
a.RLock() a.RLock()
defer a.RUnlock() defer a.RUnlock()
// We know how many entries the map should have. // We know how many entries the map should have.
headers := make(map[string]auditedHeaderSettings, len(a.headerSettings)) headers := make(map[string]HeaderSettings, len(a.headerSettings))
// Clone the headers // Clone the headers
for name, setting := range a.headerSettings { for name, setting := range a.headerSettings {
headers[name] = auditedHeaderSettings{HMAC: setting.HMAC} headers[name] = HeaderSettings{HMAC: setting.HMAC}
} }
return headers return headers
} }
// add adds or overwrites a header in the config and updates the barrier view // Add adds or overwrites a header in the config and updates the barrier view
// NOTE: add will acquire a write lock in order to update the underlying headers. // NOTE: Add will acquire a write lock in order to update the underlying headers.
func (a *AuditedHeadersConfig) add(ctx context.Context, header string, hmac bool) error { func (a *HeadersConfig) Add(ctx context.Context, header string, hmac bool) error {
if header == "" { if header == "" {
return fmt.Errorf("header value cannot be empty") return fmt.Errorf("header value cannot be empty")
} }
@@ -104,10 +118,10 @@ func (a *AuditedHeadersConfig) add(ctx context.Context, header string, hmac bool
defer a.Unlock() defer a.Unlock()
if a.headerSettings == nil { if a.headerSettings == nil {
a.headerSettings = make(map[string]*auditedHeaderSettings, 1) a.headerSettings = make(map[string]*HeaderSettings, 1)
} }
a.headerSettings[strings.ToLower(header)] = &auditedHeaderSettings{hmac} a.headerSettings[strings.ToLower(header)] = &HeaderSettings{hmac}
entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.headerSettings) entry, err := logical.StorageEntryJSON(auditedHeadersEntry, a.headerSettings)
if err != nil { if err != nil {
return fmt.Errorf("failed to persist audited headers config: %w", err) return fmt.Errorf("failed to persist audited headers config: %w", err)
@@ -120,9 +134,9 @@ func (a *AuditedHeadersConfig) add(ctx context.Context, header string, hmac bool
return nil return nil
} }
// remove deletes a header out of the header config and updates the barrier view // Remove deletes a header out of the header config and updates the barrier view
// NOTE: remove will acquire a write lock in order to update the underlying headers. // NOTE: Remove will acquire a write lock in order to update the underlying headers.
func (a *AuditedHeadersConfig) remove(ctx context.Context, header string) error { func (a *HeadersConfig) Remove(ctx context.Context, header string) error {
if header == "" { if header == "" {
return fmt.Errorf("header value cannot be empty") return fmt.Errorf("header value cannot be empty")
} }
@@ -149,9 +163,9 @@ func (a *AuditedHeadersConfig) remove(ctx context.Context, header string) error
return nil return nil
} }
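Editor's note: a hedged sketch (not part of this commit) of the write side. It assumes NewHeadersConfig accepts any logical.Storage implementation, such as the sdk's logical.InmemStorage, and that context, audit and sdk/logical are imported; the header name is invented.

func configureHeaders(ctx context.Context) error {
	// Any logical.Storage should serve as the backing view; InmemStorage is used
	// here purely for illustration.
	ahc, err := audit.NewHeadersConfig(&logical.InmemStorage{})
	if err != nil {
		return err
	}

	// Add lower-cases the header name and persists the whole set to the view.
	if err := ahc.Add(ctx, "X-Example-Header", true); err != nil { // true => HMAC the values
		return err
	}

	// Remove is also case-insensitive and persists the updated set.
	return ahc.Remove(ctx, "x-example-HEADER")
}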
// invalidate attempts to refresh the allowed audit headers and their settings. // Invalidate attempts to refresh the allowed audit headers and their settings.
// NOTE: invalidate will acquire a write lock in order to update the underlying headers. // NOTE: Invalidate will acquire a write lock in order to update the underlying headers.
func (a *AuditedHeadersConfig) invalidate(ctx context.Context) error { func (a *HeadersConfig) Invalidate(ctx context.Context) error {
a.Lock() a.Lock()
defer a.Unlock() defer a.Unlock()
@@ -163,7 +177,7 @@ func (a *AuditedHeadersConfig) invalidate(ctx context.Context) error {
// If we cannot update the stored 'new' headers, we will clear the existing // If we cannot update the stored 'new' headers, we will clear the existing
// ones as part of invalidation. // ones as part of invalidation.
headers := make(map[string]*auditedHeaderSettings) headers := make(map[string]*HeaderSettings)
if out != nil { if out != nil {
err = out.DecodeJSON(&headers) err = out.DecodeJSON(&headers)
if err != nil { if err != nil {
@@ -173,7 +187,7 @@ func (a *AuditedHeadersConfig) invalidate(ctx context.Context) error {
// Ensure that we are able to case-sensitively access the headers; // Ensure that we are able to case-sensitively access the headers;
// necessary for the upgrade case // necessary for the upgrade case
lowerHeaders := make(map[string]*auditedHeaderSettings, len(headers)) lowerHeaders := make(map[string]*HeaderSettings, len(headers))
for k, v := range headers { for k, v := range headers {
lowerHeaders[strings.ToLower(k)] = v lowerHeaders[strings.ToLower(k)] = v
} }
@@ -184,7 +198,7 @@ func (a *AuditedHeadersConfig) invalidate(ctx context.Context) error {
// ApplyConfig returns a map of approved headers and their values, either HMAC'd or plaintext. // ApplyConfig returns a map of approved headers and their values, either HMAC'd or plaintext.
// If the supplied headers are empty or nil, an empty set of headers will be returned. // If the supplied headers are empty or nil, an empty set of headers will be returned.
func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, salter audit.Salter) (result map[string][]string, retErr error) { func (a *HeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, salter Salter) (result map[string][]string, retErr error) {
// Return early if we don't have headers. // Return early if we don't have headers.
if len(headers) < 1 { if len(headers) < 1 {
return map[string][]string{}, nil return map[string][]string{}, nil
@@ -211,7 +225,7 @@ func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[stri
// Optionally hmac the values // Optionally hmac the values
if settings.HMAC { if settings.HMAC {
for i, el := range hVals { for i, el := range hVals {
hVal, err := audit.HashString(ctx, salter, el) hVal, err := HashString(ctx, salter, el)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -225,26 +239,3 @@ func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[stri
return result, nil return result, nil
} }
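Editor's note: the interplay between ApplyConfig and a Salter is easiest to see with a small sketch. The salter below mirrors the TestSalter used elsewhere in this diff (salt.NewSalt with a nil view and nil config); it is illustrative only and assumes imports for context, audit and sdk/helper/salt.

type sketchSalter struct{}

// Salt satisfies audit.Salter; a real backend lazily creates and caches its salt.
func (sketchSalter) Salt(ctx context.Context) (*salt.Salt, error) {
	return salt.NewSalt(ctx, nil, nil)
}

func auditableHeaders(ctx context.Context, ahc *audit.HeadersConfig, req map[string][]string) (map[string][]string, error) {
	// Only headers present in the config come back; values for headers configured
	// with hmac=true are returned as HMACs, the rest as plaintext.
	return ahc.ApplyConfig(ctx, req, sketchSalter{})
}

Request headers that are not configured are simply dropped from the result, which is the behaviour the ApplyConfig tests further down exercise.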
// setupAuditedHeadersConfig will initialize new audited headers configuration on
// the Core by loading data from the barrier view.
func (c *Core) setupAuditedHeadersConfig(ctx context.Context) error {
// Create a sub-view, e.g. sys/audited-headers-config/
view := c.systemBarrierView.SubView(auditedHeadersSubPath)
headers, err := NewAuditedHeadersConfig(view)
if err != nil {
return err
}
// Invalidate the headers now in order to load them for the first time.
err = headers.invalidate(ctx)
if err != nil {
return err
}
// Update the Core.
c.auditedHeaders = headers
return nil
}
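Editor's note: the Core wiring above leaves this file as part of the refactor, and its new home is not shown in this hunk. For readers tracing the change, the equivalent sequence with the newly exported names would look roughly like the sketch below; everything except the audit package identifiers is an assumption.

// Hypothetical replacement wiring; mirrors the removed setupAuditedHeadersConfig.
func setupAuditedHeaders(ctx context.Context, view logical.Storage) (*audit.HeadersConfig, error) {
	headers, err := audit.NewHeadersConfig(view)
	if err != nil {
		return nil, err
	}
	// Invalidate (now exported) performs the initial load from storage, exactly as
	// the removed code called invalidate during setup.
	if err := headers.Invalidate(ctx); err != nil {
		return nil, err
	}
	return headers, nil
}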

View File

@@ -1,7 +1,7 @@
// Copyright (c) HashiCorp, Inc. // Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1 // SPDX-License-Identifier: BUSL-1.1
package vault package audit
import ( import (
"context" "context"
@@ -20,6 +20,7 @@ import (
// mockStorage is a struct that is used to mock barrier storage. // mockStorage is a struct that is used to mock barrier storage.
type mockStorage struct { type mockStorage struct {
mock.Mock mock.Mock
v map[string][]byte
} }
// List implements List from BarrierStorage interface. // List implements List from BarrierStorage interface.
@@ -30,12 +31,27 @@ func (m *mockStorage) List(_ context.Context, _ string) ([]string, error) {
// Get implements Get from BarrierStorage interface. // Get implements Get from BarrierStorage interface.
// ignore-nil-nil-function-check. // ignore-nil-nil-function-check.
func (m *mockStorage) Get(_ context.Context, _ string) (*logical.StorageEntry, error) { func (m *mockStorage) Get(_ context.Context, key string) (*logical.StorageEntry, error) {
b, ok := m.v[key]
if !ok {
return nil, nil return nil, nil
} }
var entry *logical.StorageEntry
err := json.Unmarshal(b, &entry)
return entry, err
}
// Put implements Put from BarrierStorage interface. // Put implements Put from BarrierStorage interface.
func (m *mockStorage) Put(_ context.Context, _ *logical.StorageEntry) error { func (m *mockStorage) Put(_ context.Context, entry *logical.StorageEntry) error {
b, err := json.Marshal(entry)
if err != nil {
return err
}
m.v[entry.Key] = b
return nil return nil
} }
@@ -44,12 +60,19 @@ func (m *mockStorage) Delete(_ context.Context, _ string) error {
return nil return nil
} }
func mockAuditedHeadersConfig(t *testing.T) *AuditedHeadersConfig { func newMockStorage(t *testing.T) *mockStorage {
_, barrier, _ := mockBarrier(t) t.Helper()
view := NewBarrierView(barrier, "foo/")
return &AuditedHeadersConfig{ return &mockStorage{
headerSettings: make(map[string]*auditedHeaderSettings), Mock: mock.Mock{},
view: view, v: make(map[string][]byte),
}
}
func mockAuditedHeadersConfig(t *testing.T) *HeadersConfig {
return &HeadersConfig{
headerSettings: make(map[string]*HeaderSettings),
view: newMockStorage(t),
} }
} }
@@ -60,8 +83,8 @@ func TestAuditedHeadersConfig_CRUD(t *testing.T) {
testAuditedHeadersConfig_Remove(t, conf) testAuditedHeadersConfig_Remove(t, conf)
} }
func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) { func testAuditedHeadersConfig_Add(t *testing.T, conf *HeadersConfig) {
err := conf.add(context.Background(), "X-Test-Header", false) err := conf.Add(context.Background(), "X-Test-Header", false)
if err != nil { if err != nil {
t.Fatalf("Error when adding header to config: %s", err) t.Fatalf("Error when adding header to config: %s", err)
} }
@@ -83,13 +106,13 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatal("nil value") t.Fatal("nil value")
} }
headers := make(map[string]*auditedHeaderSettings) headers := make(map[string]*HeaderSettings)
err = out.DecodeJSON(&headers) err = out.DecodeJSON(&headers)
if err != nil { if err != nil {
t.Fatalf("Error decoding header view: %s", err) t.Fatalf("Error decoding header view: %s", err)
} }
expected := map[string]*auditedHeaderSettings{ expected := map[string]*HeaderSettings{
"x-test-header": { "x-test-header": {
HMAC: false, HMAC: false,
}, },
@@ -99,7 +122,7 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers) t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers)
} }
err = conf.add(context.Background(), "X-Vault-Header", true) err = conf.Add(context.Background(), "X-Vault-Header", true)
if err != nil { if err != nil {
t.Fatalf("Error when adding header to config: %s", err) t.Fatalf("Error when adding header to config: %s", err)
} }
@@ -121,13 +144,13 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatal("nil value") t.Fatal("nil value")
} }
headers = make(map[string]*auditedHeaderSettings) headers = make(map[string]*HeaderSettings)
err = out.DecodeJSON(&headers) err = out.DecodeJSON(&headers)
if err != nil { if err != nil {
t.Fatalf("Error decoding header view: %s", err) t.Fatalf("Error decoding header view: %s", err)
} }
expected["x-vault-header"] = &auditedHeaderSettings{ expected["x-vault-header"] = &HeaderSettings{
HMAC: true, HMAC: true,
} }
@@ -136,8 +159,8 @@ func testAuditedHeadersConfig_Add(t *testing.T, conf *AuditedHeadersConfig) {
} }
} }
func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) { func testAuditedHeadersConfig_Remove(t *testing.T, conf *HeadersConfig) {
err := conf.remove(context.Background(), "X-Test-Header") err := conf.Remove(context.Background(), "X-Test-Header")
if err != nil { if err != nil {
t.Fatalf("Error when adding header to config: %s", err) t.Fatalf("Error when adding header to config: %s", err)
} }
@@ -155,13 +178,13 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatal("nil value") t.Fatal("nil value")
} }
headers := make(map[string]*auditedHeaderSettings) headers := make(map[string]*HeaderSettings)
err = out.DecodeJSON(&headers) err = out.DecodeJSON(&headers)
if err != nil { if err != nil {
t.Fatalf("Error decoding header view: %s", err) t.Fatalf("Error decoding header view: %s", err)
} }
expected := map[string]*auditedHeaderSettings{ expected := map[string]*HeaderSettings{
"x-vault-header": { "x-vault-header": {
HMAC: true, HMAC: true,
}, },
@@ -171,7 +194,7 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers) t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers)
} }
err = conf.remove(context.Background(), "x-VaulT-Header") err = conf.Remove(context.Background(), "x-VaulT-Header")
if err != nil { if err != nil {
t.Fatalf("Error when adding header to config: %s", err) t.Fatalf("Error when adding header to config: %s", err)
} }
@@ -189,30 +212,24 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) {
t.Fatal("nil value") t.Fatal("nil value")
} }
headers = make(map[string]*auditedHeaderSettings) headers = make(map[string]*HeaderSettings)
err = out.DecodeJSON(&headers) err = out.DecodeJSON(&headers)
if err != nil { if err != nil {
t.Fatalf("Error decoding header view: %s", err) t.Fatalf("Error decoding header view: %s", err)
} }
expected = make(map[string]*auditedHeaderSettings) expected = make(map[string]*HeaderSettings)
if !reflect.DeepEqual(headers, expected) { if !reflect.DeepEqual(headers, expected) {
t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers) t.Fatalf("Expected config didn't match actual. Expected: %#v, Got: %#v", expected, headers)
} }
} }
type TestSalter struct{}
func (*TestSalter) Salt(ctx context.Context) (*salt.Salt, error) {
return salt.NewSalt(ctx, nil, nil)
}
func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) {
conf := mockAuditedHeadersConfig(t) conf := mockAuditedHeadersConfig(t)
conf.add(context.Background(), "X-TesT-Header", false) conf.Add(context.Background(), "X-TesT-Header", false)
conf.add(context.Background(), "X-Vault-HeAdEr", true) conf.Add(context.Background(), "X-Vault-HeAdEr", true)
reqHeaders := map[string][]string{ reqHeaders := map[string][]string{
"X-Test-Header": {"foo"}, "X-Test-Header": {"foo"},
@@ -273,9 +290,9 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) {
func TestAuditedHeadersConfig_ApplyConfig_NoRequestHeaders(t *testing.T) { func TestAuditedHeadersConfig_ApplyConfig_NoRequestHeaders(t *testing.T) {
conf := mockAuditedHeadersConfig(t) conf := mockAuditedHeadersConfig(t)
err := conf.add(context.Background(), "X-TesT-Header", false) err := conf.Add(context.Background(), "X-TesT-Header", false)
require.NoError(t, err) require.NoError(t, err)
err = conf.add(context.Background(), "X-Vault-HeAdEr", true) err = conf.Add(context.Background(), "X-Vault-HeAdEr", true)
require.NoError(t, err) require.NoError(t, err)
salter := &TestSalter{} salter := &TestSalter{}
@@ -337,8 +354,8 @@ func (s *FailingSalter) Salt(context.Context) (*salt.Salt, error) {
func TestAuditedHeadersConfig_ApplyConfig_HashStringError(t *testing.T) { func TestAuditedHeadersConfig_ApplyConfig_HashStringError(t *testing.T) {
conf := mockAuditedHeadersConfig(t) conf := mockAuditedHeadersConfig(t)
conf.add(context.Background(), "X-TesT-Header", false) conf.Add(context.Background(), "X-TesT-Header", false)
conf.add(context.Background(), "X-Vault-HeAdEr", true) conf.Add(context.Background(), "X-Vault-HeAdEr", true)
reqHeaders := map[string][]string{ reqHeaders := map[string][]string{
"X-Test-Header": {"foo"}, "X-Test-Header": {"foo"},
@@ -355,12 +372,12 @@ func TestAuditedHeadersConfig_ApplyConfig_HashStringError(t *testing.T) {
} }
func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) { func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) {
conf := &AuditedHeadersConfig{ conf := &HeadersConfig{
headerSettings: make(map[string]*auditedHeaderSettings), headerSettings: make(map[string]*HeaderSettings),
view: nil, view: nil,
} }
conf.headerSettings = map[string]*auditedHeaderSettings{ conf.headerSettings = map[string]*HeaderSettings{
"X-Test-Header": {false}, "X-Test-Header": {false},
"X-Vault-Header": {true}, "X-Vault-Header": {true},
} }
@@ -383,46 +400,45 @@ func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) {
// TestAuditedHeaders_auditedHeadersKey is used to check that the key we use to handle // TestAuditedHeaders_auditedHeadersKey is used to check that the key we use to handle
// invalidation doesn't change when we aren't expecting it to. // invalidation doesn't change when we aren't expecting it to.
func TestAuditedHeaders_auditedHeadersKey(t *testing.T) { func TestAuditedHeaders_auditedHeadersKey(t *testing.T) {
require.Equal(t, "audited-headers-config/audited-headers", auditedHeadersKey()) require.Equal(t, "audited-headers-config/audited-headers", AuditedHeadersKey())
} }
// TestAuditedHeaders_NewAuditedHeadersConfig checks supplying incorrect params to // TestAuditedHeaders_NewAuditedHeadersConfig checks supplying incorrect params to
// the constructor for AuditedHeadersConfig returns an error. // the constructor for HeadersConfig returns an error.
func TestAuditedHeaders_NewAuditedHeadersConfig(t *testing.T) { func TestAuditedHeaders_NewAuditedHeadersConfig(t *testing.T) {
ac, err := NewAuditedHeadersConfig(nil) ac, err := NewHeadersConfig(nil)
require.Error(t, err) require.Error(t, err)
require.Nil(t, ac) require.Nil(t, ac)
ac, err = NewAuditedHeadersConfig(&BarrierView{}) ac, err = NewHeadersConfig(newMockStorage(t))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, ac) require.NotNil(t, ac)
} }
// TestAuditedHeaders_invalidate ensures that we can update the headers on AuditedHeadersConfig // TestAuditedHeaders_invalidate ensures that we can update the headers on HeadersConfig
// when we invalidate, and load the updated headers from the view/storage. // when we invalidate, and load the updated headers from the view/storage.
func TestAuditedHeaders_invalidate(t *testing.T) { func TestAuditedHeaders_invalidate(t *testing.T) {
_, barrier, _ := mockBarrier(t) view := newMockStorage(t)
view := NewBarrierView(barrier, auditedHeadersSubPath) ahc, err := NewHeadersConfig(view)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0) require.Len(t, ahc.headerSettings, 0)
// Store some data using the view. // Store some data using the view.
fakeHeaders1 := map[string]*auditedHeaderSettings{"x-magic-header": {}} fakeHeaders1 := map[string]*HeaderSettings{"x-magic-header": {}}
fakeBytes1, err := json.Marshal(fakeHeaders1) fakeBytes1, err := json.Marshal(fakeHeaders1)
require.NoError(t, err) require.NoError(t, err)
err = view.Put(context.Background(), &logical.StorageEntry{Key: auditedHeadersEntry, Value: fakeBytes1}) err = view.Put(context.Background(), &logical.StorageEntry{Key: auditedHeadersEntry, Value: fakeBytes1})
require.NoError(t, err) require.NoError(t, err)
// Invalidate and check we now see the header we stored // Invalidate and check we now see the header we stored
err = ahc.invalidate(context.Background()) err = ahc.Invalidate(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 1) require.Len(t, ahc.headerSettings, 1)
_, ok := ahc.headerSettings["x-magic-header"] _, ok := ahc.headerSettings["x-magic-header"]
require.True(t, ok) require.True(t, ok)
// Do it again with more headers and random casing. // Do it again with more headers and random casing.
fakeHeaders2 := map[string]*auditedHeaderSettings{ fakeHeaders2 := map[string]*HeaderSettings{
"x-magic-header": {}, "x-magic-header": {},
"x-even-MORE-magic-header": {}, "x-even-MORE-magic-header": {},
} }
@@ -432,7 +448,7 @@ func TestAuditedHeaders_invalidate(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Invalidate and check we now see the header we stored // Invalidate and check we now see the header we stored
err = ahc.invalidate(context.Background()) err = ahc.Invalidate(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 2) require.Len(t, ahc.headerSettings, 2)
_, ok = ahc.headerSettings["x-magic-header"] _, ok = ahc.headerSettings["x-magic-header"]
@@ -444,21 +460,20 @@ func TestAuditedHeaders_invalidate(t *testing.T) {
// TestAuditedHeaders_invalidate_nil_view ensures that we invalidate the headers // TestAuditedHeaders_invalidate_nil_view ensures that we invalidate the headers
// correctly (clear them) when we get nil for the storage entry from the view. // correctly (clear them) when we get nil for the storage entry from the view.
func TestAuditedHeaders_invalidate_nil_view(t *testing.T) { func TestAuditedHeaders_invalidate_nil_view(t *testing.T) {
_, barrier, _ := mockBarrier(t) view := newMockStorage(t)
view := NewBarrierView(barrier, auditedHeadersSubPath) ahc, err := NewHeadersConfig(view)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0) require.Len(t, ahc.headerSettings, 0)
// Store some data using the view. // Store some data using the view.
fakeHeaders1 := map[string]*auditedHeaderSettings{"x-magic-header": {}} fakeHeaders1 := map[string]*HeaderSettings{"x-magic-header": {}}
fakeBytes1, err := json.Marshal(fakeHeaders1) fakeBytes1, err := json.Marshal(fakeHeaders1)
require.NoError(t, err) require.NoError(t, err)
err = view.Put(context.Background(), &logical.StorageEntry{Key: auditedHeadersEntry, Value: fakeBytes1}) err = view.Put(context.Background(), &logical.StorageEntry{Key: auditedHeadersEntry, Value: fakeBytes1})
require.NoError(t, err) require.NoError(t, err)
// Invalidate and check we now see the header we stored // Invalidate and check we now see the header we stored
err = ahc.invalidate(context.Background()) err = ahc.Invalidate(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 1) require.Len(t, ahc.headerSettings, 1)
_, ok := ahc.headerSettings["x-magic-header"] _, ok := ahc.headerSettings["x-magic-header"]
@@ -466,12 +481,13 @@ func TestAuditedHeaders_invalidate_nil_view(t *testing.T) {
// Swap out the view with a mock that returns nil when we try to invalidate. // Swap out the view with a mock that returns nil when we try to invalidate.
// This should mean we end up just clearing the headers (no errors). // This should mean we end up just clearing the headers (no errors).
mockStorageBarrier := new(mockStorage) mockStorageBarrier := newMockStorage(t)
mockStorageBarrier.On("Get", mock.Anything, mock.Anything).Return(nil, nil) mockStorageBarrier.On("Get", mock.Anything, mock.Anything).Return(nil, nil)
ahc.view = NewBarrierView(mockStorageBarrier, auditedHeadersSubPath) ahc.view = mockStorageBarrier
// ahc.view = NewBarrierView(mockStorageBarrier, AuditedHeadersSubPath)
// Invalidate should clear out the existing headers without error // Invalidate should clear out the existing headers without error
err = ahc.invalidate(context.Background()) err = ahc.Invalidate(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0) require.Len(t, ahc.headerSettings, 0)
} }
@@ -479,9 +495,8 @@ func TestAuditedHeaders_invalidate_nil_view(t *testing.T) {
// TestAuditedHeaders_invalidate_bad_data ensures that we correctly error if the // TestAuditedHeaders_invalidate_bad_data ensures that we correctly error if the
// underlying data cannot be parsed as expected. // underlying data cannot be parsed as expected.
func TestAuditedHeaders_invalidate_bad_data(t *testing.T) { func TestAuditedHeaders_invalidate_bad_data(t *testing.T) {
_, barrier, _ := mockBarrier(t) view := newMockStorage(t)
view := NewBarrierView(barrier, auditedHeadersSubPath) ahc, err := NewHeadersConfig(view)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0) require.Len(t, ahc.headerSettings, 0)
@@ -492,7 +507,7 @@ func TestAuditedHeaders_invalidate_bad_data(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Invalidate should error because the underlying data cannot be parsed as expected. // Invalidate should error because the underlying data cannot be parsed as expected.
err = ahc.invalidate(context.Background()) err = ahc.Invalidate(context.Background())
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "failed to parse config") require.ErrorContains(t, err, "failed to parse config")
} }
@@ -500,40 +515,38 @@ func TestAuditedHeaders_invalidate_bad_data(t *testing.T) {
// TestAuditedHeaders_header checks we can return a copy of settings associated with // TestAuditedHeaders_header checks we can return a copy of settings associated with
// an existing header, and we also know when a header wasn't found. // an existing header, and we also know when a header wasn't found.
func TestAuditedHeaders_header(t *testing.T) { func TestAuditedHeaders_header(t *testing.T) {
_, barrier, _ := mockBarrier(t) view := newMockStorage(t)
view := NewBarrierView(barrier, auditedHeadersSubPath) ahc, err := NewHeadersConfig(view)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0) require.Len(t, ahc.headerSettings, 0)
err = ahc.add(context.Background(), "juan", true) err = ahc.Add(context.Background(), "juan", true)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 1) require.Len(t, ahc.headerSettings, 1)
s, ok := ahc.header("juan") s, ok := ahc.Header("juan")
require.True(t, ok) require.True(t, ok)
require.Equal(t, true, s.HMAC) require.Equal(t, true, s.HMAC)
s, ok = ahc.header("x-magic-token") s, ok = ahc.Header("x-magic-token")
require.False(t, ok) require.False(t, ok)
} }
// TestAuditedHeaders_headers checks we are able to return a copy of the existing // TestAuditedHeaders_headers checks we are able to return a copy of the existing
// configured headers. // configured headers.
func TestAuditedHeaders_headers(t *testing.T) { func TestAuditedHeaders_headers(t *testing.T) {
_, barrier, _ := mockBarrier(t) view := newMockStorage(t)
view := NewBarrierView(barrier, auditedHeadersSubPath) ahc, err := NewHeadersConfig(view)
ahc, err := NewAuditedHeadersConfig(view)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 0) require.Len(t, ahc.headerSettings, 0)
err = ahc.add(context.Background(), "juan", true) err = ahc.Add(context.Background(), "juan", true)
require.NoError(t, err) require.NoError(t, err)
err = ahc.add(context.Background(), "john", false) err = ahc.Add(context.Background(), "john", false)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ahc.headerSettings, 2) require.Len(t, ahc.headerSettings, 2)
s := ahc.headers() s := ahc.Headers()
require.Len(t, s, 2) require.Len(t, s, 2)
require.Equal(t, true, s["juan"].HMAC) require.Equal(t, true, s["juan"].HMAC)
require.Equal(t, false, s["john"].HMAC) require.Equal(t, false, s["john"].HMAC)

View File

@@ -14,14 +14,14 @@ import (
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
// ProcessManual will attempt to create an (audit) event with the specified data // processManual will attempt to create an (audit) event with the specified data
// and manually iterate over the supplied nodes calling Process on each until the // and manually iterate over the supplied nodes calling Process on each until the
// event is nil (which indicates the pipeline has completed). // event is nil (which indicates the pipeline has completed).
// Order of IDs in the NodeID slice determines the order they are processed. // Order of IDs in the NodeID slice determines the order they are processed.
// (Audit) Event will be of RequestType (as opposed to ResponseType). // (Audit) Event will be of RequestType (as opposed to ResponseType).
// The last node must be a filter node (eventlogger.NodeTypeFilter) or // The last node must be a filter node (eventlogger.NodeTypeFilter) or
// sink node (eventlogger.NodeTypeSink). // sink node (eventlogger.NodeTypeSink).
func ProcessManual(ctx context.Context, data *logical.LogInput, ids []eventlogger.NodeID, nodes map[eventlogger.NodeID]eventlogger.Node) error { func processManual(ctx context.Context, data *logical.LogInput, ids []eventlogger.NodeID, nodes map[eventlogger.NodeID]eventlogger.Node) error {
switch { switch {
case data == nil: case data == nil:
return errors.New("data cannot be nil") return errors.New("data cannot be nil")
@@ -71,7 +71,7 @@ func ProcessManual(ctx context.Context, data *logical.LogInput, ids []eventlogge
switch node.Type() { switch node.Type() {
case eventlogger.NodeTypeFormatter: case eventlogger.NodeTypeFormatter:
// Use a temporary formatter node which doesn't persist its salt anywhere. // Use a temporary formatter node which doesn't persist its salt anywhere.
if formatNode, ok := node.(*EntryFormatter); ok && formatNode != nil { if formatNode, ok := node.(*entryFormatter); ok && formatNode != nil {
e, err = newTemporaryEntryFormatter(formatNode).Process(ctx, e) e, err = newTemporaryEntryFormatter(formatNode).Process(ctx, e)
} }
default: default:
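Editor's note: processManual is unexported by this change, so code outside the package reaches it through Backend.LogTestMessage (the removed file backend later in this diff shows that method delegating to the old ProcessManual). A hedged sketch, assuming an already-constructed audit.Backend named b and imports for context and sdk/logical; the field values are illustrative.

func sendTestMessage(ctx context.Context, b audit.Backend) error {
	// LogTestMessage pushes a synthetic entry through the backend's formatter and
	// sink nodes without using the persisted salt.
	return b.LogTestMessage(ctx, &logical.LogInput{
		Type:    "request",                                // illustrative value
		Request: &logical.Request{Path: "sys/audit/test"}, // illustrative path
	})
}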

View File

@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// TestProcessManual_NilData tests ProcessManual when nil data is supplied. // TestProcessManual_NilData tests processManual when nil data is supplied.
func TestProcessManual_NilData(t *testing.T) { func TestProcessManual_NilData(t *testing.T) {
t.Parallel() t.Parallel()
@@ -32,12 +32,12 @@ func TestProcessManual_NilData(t *testing.T) {
ids = append(ids, sinkId) ids = append(ids, sinkId)
nodes[sinkId] = sinkNode nodes[sinkId] = sinkNode
err := ProcessManual(namespace.RootContext(context.Background()), nil, ids, nodes) err := processManual(namespace.RootContext(context.Background()), nil, ids, nodes)
require.Error(t, err) require.Error(t, err)
require.EqualError(t, err, "data cannot be nil") require.EqualError(t, err, "data cannot be nil")
} }
// TestProcessManual_BadIDs tests ProcessManual when different bad values are // TestProcessManual_BadIDs tests processManual when different bad values are
// supplied for the ID parameter. // supplied for the ID parameter.
func TestProcessManual_BadIDs(t *testing.T) { func TestProcessManual_BadIDs(t *testing.T) {
t.Parallel() t.Parallel()
@@ -76,14 +76,14 @@ func TestProcessManual_BadIDs(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
data := newData(requestId) data := newData(requestId)
err = ProcessManual(namespace.RootContext(context.Background()), data, tc.IDs, nodes) err = processManual(namespace.RootContext(context.Background()), data, tc.IDs, nodes)
require.Error(t, err) require.Error(t, err)
require.EqualError(t, err, tc.ExpectedErrorMessage) require.EqualError(t, err, tc.ExpectedErrorMessage)
}) })
} }
} }
// TestProcessManual_NoNodes tests ProcessManual when no nodes are supplied. // TestProcessManual_NoNodes tests processManual when no nodes are supplied.
func TestProcessManual_NoNodes(t *testing.T) { func TestProcessManual_NoNodes(t *testing.T) {
t.Parallel() t.Parallel()
@@ -103,12 +103,12 @@ func TestProcessManual_NoNodes(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
data := newData(requestId) data := newData(requestId)
err = ProcessManual(namespace.RootContext(context.Background()), data, ids, nodes) err = processManual(namespace.RootContext(context.Background()), data, ids, nodes)
require.Error(t, err) require.Error(t, err)
require.EqualError(t, err, "nodes are required") require.EqualError(t, err, "nodes are required")
} }
// TestProcessManual_IdNodeMismatch tests ProcessManual when IDs don't match with // TestProcessManual_IdNodeMismatch tests processManual when IDs don't match with
// the nodes in the supplied map. // the nodes in the supplied map.
func TestProcessManual_IdNodeMismatch(t *testing.T) { func TestProcessManual_IdNodeMismatch(t *testing.T) {
t.Parallel() t.Parallel()
@@ -130,12 +130,12 @@ func TestProcessManual_IdNodeMismatch(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
data := newData(requestId) data := newData(requestId)
err = ProcessManual(namespace.RootContext(context.Background()), data, ids, nodes) err = processManual(namespace.RootContext(context.Background()), data, ids, nodes)
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "node not found: ") require.ErrorContains(t, err, "node not found: ")
} }
// TestProcessManual_NotEnoughNodes tests ProcessManual when there is only one // TestProcessManual_NotEnoughNodes tests processManual when there is only one
// node provided. // node provided.
func TestProcessManual_NotEnoughNodes(t *testing.T) { func TestProcessManual_NotEnoughNodes(t *testing.T) {
t.Parallel() t.Parallel()
@@ -153,12 +153,12 @@ func TestProcessManual_NotEnoughNodes(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
data := newData(requestId) data := newData(requestId)
err = ProcessManual(namespace.RootContext(context.Background()), data, ids, nodes) err = processManual(namespace.RootContext(context.Background()), data, ids, nodes)
require.Error(t, err) require.Error(t, err)
require.EqualError(t, err, "minimum of 2 ids are required") require.EqualError(t, err, "minimum of 2 ids are required")
} }
// TestProcessManual_LastNodeNotSink tests ProcessManual when the last node is // TestProcessManual_LastNodeNotSink tests processManual when the last node is
// not a Sink node. // not a Sink node.
func TestProcessManual_LastNodeNotSink(t *testing.T) { func TestProcessManual_LastNodeNotSink(t *testing.T) {
t.Parallel() t.Parallel()
@@ -181,7 +181,7 @@ func TestProcessManual_LastNodeNotSink(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
data := newData(requestId) data := newData(requestId)
err = ProcessManual(namespace.RootContext(context.Background()), data, ids, nodes) err = processManual(namespace.RootContext(context.Background()), data, ids, nodes)
require.Error(t, err) require.Error(t, err)
require.EqualError(t, err, "last node must be a filter or sink") require.EqualError(t, err, "last node must be a filter or sink")
} }
@@ -210,7 +210,7 @@ func TestProcessManualEndWithSink(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
data := newData(requestId) data := newData(requestId)
err = ProcessManual(namespace.RootContext(context.Background()), data, ids, nodes) err = processManual(namespace.RootContext(context.Background()), data, ids, nodes)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -243,7 +243,7 @@ func TestProcessManual_EndWithFilter(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
data := newData(requestId) data := newData(requestId)
err = ProcessManual(namespace.RootContext(context.Background()), data, ids, nodes) err = processManual(namespace.RootContext(context.Background()), data, ids, nodes)
require.NoError(t, err) require.NoError(t, err)
} }

View File

@@ -50,6 +50,13 @@ func getOpts(opt ...Option) (options, error) {
return opts, nil return opts, nil
} }
// ValidateOptions can be used to validate options before they are required.
func ValidateOptions(opt ...Option) error {
_, err := getOpts(opt...)
return err
}
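Editor's note: ValidateOptions simply runs the supplied options through getOpts and discards the result, which lets callers reject a bad option set before it is ever applied. A small sketch, assuming these options live in the audit package like the surrounding files and that WithID behaves as shown in this file; the ID string is made up.

func checkOptions() error {
	// Validate up front; the same options can later be passed to whichever
	// constructor actually consumes them.
	return audit.ValidateOptions(audit.WithID("sketch-id"))
}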
// WithID provides an optional ID. // WithID provides an optional ID.
func WithID(id string) Option { func WithID(id string) Option {
return func(o *options) error { return func(o *options) error {

View File

@@ -9,8 +9,8 @@ import (
) )
var ( var (
_ event.Labeler = (*MetricLabelerAuditSink)(nil) _ event.Labeler = (*metricLabelerAuditSink)(nil)
_ event.Labeler = (*MetricLabelerAuditFallback)(nil) _ event.Labeler = (*metricLabelerAuditFallback)(nil)
) )
var ( var (
@@ -20,18 +20,18 @@ var (
metricLabelAuditFallbackMiss = []string{"audit", "fallback", "miss"} metricLabelAuditFallbackMiss = []string{"audit", "fallback", "miss"}
) )
// MetricLabelerAuditSink can be used to provide labels for the success or failure // metricLabelerAuditSink can be used to provide labels for the success or failure
// of a sink node used for a normal audit device. // of a sink node used for a normal audit device.
type MetricLabelerAuditSink struct{} type metricLabelerAuditSink struct{}
// MetricLabelerAuditFallback can be used to provide labels for the success or failure // metricLabelerAuditFallback can be used to provide labels for the success or failure
// of a sink node used for an audit fallback device. // of a sink node used for an audit fallback device.
type MetricLabelerAuditFallback struct{} type metricLabelerAuditFallback struct{}
// Labels provides the success and failure labels for an audit sink, based on the error supplied. // Labels provides the success and failure labels for an audit sink, based on the error supplied.
// Success: 'vault.audit.sink.success' // Success: 'vault.audit.sink.success'
// Failure: 'vault.audit.sink.failure' // Failure: 'vault.audit.sink.failure'
func (m MetricLabelerAuditSink) Labels(_ *eventlogger.Event, err error) []string { func (m metricLabelerAuditSink) Labels(_ *eventlogger.Event, err error) []string {
if err != nil { if err != nil {
return metricLabelAuditSinkFailure return metricLabelAuditSinkFailure
} }
@@ -42,7 +42,7 @@ func (m MetricLabelerAuditSink) Labels(_ *eventlogger.Event, err error) []string
// Labels provides the success and failure labels for an audit fallback sink, based on the error supplied. // Labels provides the success and failure labels for an audit fallback sink, based on the error supplied.
// Success: 'vault.audit.fallback.success' // Success: 'vault.audit.fallback.success'
// Failure: 'vault.audit.sink.failure' // Failure: 'vault.audit.sink.failure'
func (m MetricLabelerAuditFallback) Labels(_ *eventlogger.Event, err error) []string { func (m metricLabelerAuditFallback) Labels(_ *eventlogger.Event, err error) []string {
if err != nil { if err != nil {
return metricLabelAuditSinkFailure return metricLabelAuditSinkFailure
} }
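Editor's note: with the labelers now package-private, they are only consumed inside the audit package when sinks are wrapped with metrics. A short, package-internal sketch of what Labels yields, matching the tests that follow; it assumes the label slices correspond to the metric names in the doc comments above.

func sinkLabels(err error) []string {
	// Labels ignores the event argument and keys purely off the error, so passing
	// nil for the event (as the tests below do) is fine.
	var labeler event.Labeler = &metricLabelerAuditSink{}
	return labeler.Labels(nil, err) // nil error => sink success labels; non-nil => failure labels
}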

View File

@@ -35,7 +35,7 @@ func TestMetricLabelerAuditSink_Label(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
m := &MetricLabelerAuditSink{} m := &metricLabelerAuditSink{}
result := m.Labels(nil, tc.err) result := m.Labels(nil, tc.err)
assert.Equal(t, tc.expected, result) assert.Equal(t, tc.expected, result)
}) })
@@ -67,7 +67,7 @@ func TestMetricLabelerAuditFallback_Label(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
m := &MetricLabelerAuditFallback{} m := &metricLabelerAuditFallback{}
result := m.Labels(nil, tc.err) result := m.Labels(nil, tc.err)
assert.Equal(t, tc.expected, result) assert.Equal(t, tc.expected, result)
}) })

View File

@@ -13,21 +13,21 @@ import (
"github.com/hashicorp/eventlogger" "github.com/hashicorp/eventlogger"
) )
var _ eventlogger.Node = (*SinkMetricTimer)(nil) var _ eventlogger.Node = (*sinkMetricTimer)(nil)
// SinkMetricTimer is a wrapper for any kind of eventlogger.NodeTypeSink node that // sinkMetricTimer is a wrapper for any kind of eventlogger.NodeTypeSink node that
// processes events containing an AuditEvent payload. // processes events containing an AuditEvent payload.
// It decorates the implemented eventlogger.Node Process method in order to emit // It decorates the implemented eventlogger.Node Process method in order to emit
// timing metrics for the duration between the creation time of the event and the // timing metrics for the duration between the creation time of the event and the
// time the node completes processing. // time the node completes processing.
type SinkMetricTimer struct { type sinkMetricTimer struct {
Name string name string
Sink eventlogger.Node sink eventlogger.Node
} }
// NewSinkMetricTimer should be used to create the SinkMetricTimer. // newSinkMetricTimer should be used to create the sinkMetricTimer.
// It expects that an eventlogger.NodeTypeSink should be supplied as the sink. // It expects that an eventlogger.NodeTypeSink should be supplied as the sink.
func NewSinkMetricTimer(name string, sink eventlogger.Node) (*SinkMetricTimer, error) { func newSinkMetricTimer(name string, sink eventlogger.Node) (*sinkMetricTimer, error) {
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
if name == "" { if name == "" {
return nil, fmt.Errorf("name is required: %w", ErrInvalidParameter) return nil, fmt.Errorf("name is required: %w", ErrInvalidParameter)
@@ -41,9 +41,9 @@ func NewSinkMetricTimer(name string, sink eventlogger.Node) (*SinkMetricTimer, e
return nil, fmt.Errorf("sink node must be of type 'sink': %w", ErrInvalidParameter) return nil, fmt.Errorf("sink node must be of type 'sink': %w", ErrInvalidParameter)
} }
return &SinkMetricTimer{ return &sinkMetricTimer{
Name: name, name: name,
Sink: sink, sink: sink,
}, nil }, nil
} }
@@ -54,23 +54,23 @@ func NewSinkMetricTimer(name string, sink eventlogger.Node) (*SinkMetricTimer, e
// Examples: // Examples:
// 'vault.audit.{DEVICE}.log_request' // 'vault.audit.{DEVICE}.log_request'
// 'vault.audit.{DEVICE}.log_response' // 'vault.audit.{DEVICE}.log_response'
func (s *SinkMetricTimer) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) { func (s *sinkMetricTimer) Process(ctx context.Context, e *eventlogger.Event) (*eventlogger.Event, error) {
defer func() { defer func() {
auditEvent, ok := e.Payload.(*AuditEvent) auditEvent, ok := e.Payload.(*AuditEvent)
if ok { if ok {
metrics.MeasureSince([]string{"audit", s.Name, auditEvent.Subtype.MetricTag()}, e.CreatedAt) metrics.MeasureSince([]string{"audit", s.name, auditEvent.Subtype.MetricTag()}, e.CreatedAt)
} }
}() }()
return s.Sink.Process(ctx, e) return s.sink.Process(ctx, e)
} }
// Reopen wraps the Reopen method of the underlying sink (eventlogger.Node). // Reopen wraps the Reopen method of the underlying sink (eventlogger.Node).
func (s *SinkMetricTimer) Reopen() error { func (s *sinkMetricTimer) Reopen() error {
return s.Sink.Reopen() return s.sink.Reopen()
} }
// Type wraps the Type method of the underlying sink (eventlogger.Node). // Type wraps the Type method of the underlying sink (eventlogger.Node).
func (s *SinkMetricTimer) Type() eventlogger.NodeType { func (s *sinkMetricTimer) Type() eventlogger.NodeType {
return s.Sink.Type() return s.sink.Type()
} }
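Editor's note: because the timer wrapper is now unexported, wiring happens inside the audit package. The sketch below mirrors the pattern visible in the removed file backend later in this diff (the old NewSinkMetricTimer followed by event.NewMetricsCounter), using the renamed identifiers; the function itself and its parameters are assumptions, not code from this commit.

// Wrap a sink so that per-device timing and success/failure counts are emitted.
func wrapSinkWithMetrics(name string, sink eventlogger.Node) (eventlogger.Node, error) {
	timer, err := newSinkMetricTimer(name, sink)
	if err != nil {
		return nil, err
	}
	// NewMetricsCounter is the same helper the removed file backend used; the
	// labeler picks between the sink and fallback metric names.
	counter, err := event.NewMetricsCounter(name, timer, &metricLabelerAuditSink{})
	if err != nil {
		return nil, err
	}
	return counter, nil
}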

View File

@@ -12,7 +12,7 @@ import (
) )
// TestNewSinkMetricTimer ensures that parameters are checked correctly and errors // TestNewSinkMetricTimer ensures that parameters are checked correctly and errors
// reported as expected when attempting to create a SinkMetricTimer. // reported as expected when attempting to create a sinkMetricTimer.
func TestNewSinkMetricTimer(t *testing.T) { func TestNewSinkMetricTimer(t *testing.T) {
t.Parallel() t.Parallel()
@@ -40,7 +40,7 @@ func TestNewSinkMetricTimer(t *testing.T) {
}, },
"bad-node": { "bad-node": {
name: "foo", name: "foo",
node: &EntryFormatter{}, node: &entryFormatter{},
isErrorExpected: true, isErrorExpected: true,
expectedErrorMessage: "sink node must be of type 'sink': invalid internal parameter", expectedErrorMessage: "sink node must be of type 'sink': invalid internal parameter",
}, },
@@ -52,7 +52,7 @@ func TestNewSinkMetricTimer(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
m, err := NewSinkMetricTimer(tc.name, tc.node) m, err := newSinkMetricTimer(tc.name, tc.node)
switch { switch {
case tc.isErrorExpected: case tc.isErrorExpected:

View File

@@ -4,68 +4,9 @@
package audit package audit
import ( import (
"context"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
) )
// Backend interface must be implemented for an audit
// mechanism to be made available. Audit backends can be enabled to
// sink information to different backends such as logs, file, databases,
// or other external services.
type Backend interface {
// Salter interface must be implemented by anything implementing Backend.
Salter
// The PipelineReader interface allows backends to surface information about their
// nodes for node and pipeline registration.
event.PipelineReader
// IsFallback can be used to determine if this audit backend device is intended to
// be used as a fallback to catch all events that are not written when only using
// filtered pipelines.
IsFallback() bool
// LogTestMessage is used to check an audit backend before adding it
// permanently. It should attempt to synchronously log the given test
// message, WITHOUT using the normal Salt (which would require a storage
// operation on creation, which is currently disallowed.)
LogTestMessage(context.Context, *logical.LogInput) error
// Reload is called on SIGHUP for supporting backends.
Reload() error
// Invalidate is called for path invalidation
Invalidate(context.Context)
}
// Salter is an interface that provides a way to obtain a Salt for hashing.
type Salter interface {
// Salt returns a non-nil salt or an error.
Salt(context.Context) (*salt.Salt, error)
}
// Formatter is an interface that is responsible for formatting a request/response into some format.
// It is recommended that you pass data through Hash prior to formatting it.
type Formatter interface {
// FormatRequest formats the logical.LogInput into an RequestEntry.
FormatRequest(context.Context, *logical.LogInput, timeProvider) (*RequestEntry, error)
// FormatResponse formats the logical.LogInput into an ResponseEntry.
FormatResponse(context.Context, *logical.LogInput, timeProvider) (*ResponseEntry, error)
}
// HeaderFormatter is an interface defining the methods of the
// vault.AuditedHeadersConfig structure needed in this package.
type HeaderFormatter interface {
// ApplyConfig returns a map of header values that consists of the
// intersection of the provided set of header values with a configured
// set of headers and will hash headers that have been configured as such.
ApplyConfig(context.Context, map[string][]string, Salter) (map[string][]string, error)
}
// RequestEntry is the structure of a request audit log entry. // RequestEntry is the structure of a request audit log entry.
type RequestEntry struct { type RequestEntry struct {
Auth *Auth `json:"auth,omitempty"` Auth *Auth `json:"auth,omitempty"`
@@ -179,28 +120,3 @@ type Namespace struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Path string `json:"path,omitempty"` Path string `json:"path,omitempty"`
} }
// nonPersistentSalt is used for obtaining a salt that is not persisted.
type nonPersistentSalt struct{}
// BackendConfig contains configuration parameters used in the factory func to
// instantiate audit backends
type BackendConfig struct {
// The view to store the salt
SaltView logical.Storage
// The salt config that should be used for any secret obfuscation
SaltConfig *salt.Config
// Config is the opaque user configuration provided when mounting
Config map[string]string
// MountPath is the path where this Backend is mounted
MountPath string
// Logger is used to emit log messages usually captured in the server logs.
Logger hclog.Logger
}
// Factory is the factory function to create an audit backend.
type Factory func(*BackendConfig, HeaderFormatter) (Backend, error)

View File

@@ -1,356 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package file
import (
"context"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
const (
stdout = "stdout"
discard = "discard"
)
var _ audit.Backend = (*Backend)(nil)
// Backend is the audit backend for the file-based audit store.
//
// NOTE: This audit backend is currently very simple: it appends to a file.
// It doesn't do anything more at the moment to assist with rotation
// or reset the write cursor, this should be done in the future.
type Backend struct {
fallback bool
name string
nodeIDList []eventlogger.NodeID
nodeMap map[eventlogger.NodeID]eventlogger.Node
salt *atomic.Value
saltConfig *salt.Config
saltMutex sync.RWMutex
saltView logical.Storage
}
func Factory(conf *audit.BackendConfig, headersConfig audit.HeaderFormatter) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config: %w", audit.ErrInvalidParameter)
}
if conf.SaltView == nil {
return nil, fmt.Errorf("nil salt view: %w", audit.ErrInvalidParameter)
}
if conf.Logger == nil || reflect.ValueOf(conf.Logger).IsNil() {
return nil, fmt.Errorf("nil logger: %w", audit.ErrInvalidParameter)
}
if conf.MountPath == "" {
return nil, fmt.Errorf("mount path cannot be empty: %w", audit.ErrInvalidParameter)
}
// The config options 'fallback' and 'filter' are mutually exclusive, a fallback
// device catches everything, so it cannot be allowed to filter.
var fallback bool
var err error
if fallbackRaw, ok := conf.Config["fallback"]; ok {
fallback, err = parseutil.ParseBool(fallbackRaw)
if err != nil {
return nil, fmt.Errorf("unable to parse 'fallback': %w", audit.ErrExternalOptions)
}
}
if _, ok := conf.Config["filter"]; ok && fallback {
return nil, fmt.Errorf("cannot configure a fallback device with a filter: %w", audit.ErrExternalOptions)
}
// Get file path from config or fall back to the old option name ('path') for compatibility
// (see commit bac4fe0799a372ba1245db642f3f6cd1f1d02669).
var filePath string
if p, ok := conf.Config["file_path"]; ok {
filePath = p
} else if p, ok = conf.Config["path"]; ok {
filePath = p
} else {
return nil, fmt.Errorf("file_path is required: %w", audit.ErrExternalOptions)
}
// normalize file path if configured for stdout
if strings.EqualFold(filePath, stdout) {
filePath = stdout
}
if strings.EqualFold(filePath, discard) {
filePath = discard
}
cfg, err := newFormatterConfig(headersConfig, conf.Config)
if err != nil {
return nil, err
}
b := &Backend{
fallback: fallback,
name: conf.MountPath,
saltConfig: conf.SaltConfig,
saltView: conf.SaltView,
salt: new(atomic.Value),
nodeIDList: []eventlogger.NodeID{},
nodeMap: make(map[eventlogger.NodeID]eventlogger.Node),
}
// Ensure we are working with the right type by explicitly storing a nil of
// the right type
b.salt.Store((*salt.Salt)(nil))
err = b.configureFilterNode(conf.Config["filter"])
if err != nil {
return nil, err
}
err = b.configureFormatterNode(conf.MountPath, cfg, conf.Logger)
if err != nil {
return nil, err
}
err = b.configureSinkNode(conf.MountPath, filePath, conf.Config["mode"], cfg.RequiredFormat.String())
if err != nil {
return nil, fmt.Errorf("error configuring sink node: %w", err)
}
return b, nil
}
func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) {
s := b.salt.Load().(*salt.Salt)
if s != nil {
return s, nil
}
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
s = b.salt.Load().(*salt.Salt)
if s != nil {
return s, nil
}
newSalt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig)
if err != nil {
b.salt.Store((*salt.Salt)(nil))
return nil, err
}
b.salt.Store(newSalt)
return newSalt, nil
}
func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput) error {
if len(b.nodeIDList) > 0 {
return audit.ProcessManual(ctx, in, b.nodeIDList, b.nodeMap)
}
return nil
}
func (b *Backend) Reload() error {
for _, n := range b.nodeMap {
if n.Type() == eventlogger.NodeTypeSink {
return n.Reopen()
}
}
return nil
}
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt.Store((*salt.Salt)(nil))
}
// newFormatterConfig creates the configuration required by a formatter node using
// the config map supplied to the factory.
func newFormatterConfig(headerFormatter audit.HeaderFormatter, config map[string]string) (audit.FormatterConfig, error) {
var opts []audit.Option
if format, ok := config["format"]; ok {
if !audit.IsValidFormat(format) {
return audit.FormatterConfig{}, fmt.Errorf("unsupported 'format': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithFormat(format))
}
// Check if hashing of accessor is disabled
if hmacAccessorRaw, ok := config["hmac_accessor"]; ok {
v, err := strconv.ParseBool(hmacAccessorRaw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'hmac_accessor': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithHMACAccessor(v))
}
// Check if raw logging is enabled
if raw, ok := config["log_raw"]; ok {
v, err := strconv.ParseBool(raw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'log_raw: %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithRaw(v))
}
if elideListResponsesRaw, ok := config["elide_list_responses"]; ok {
v, err := strconv.ParseBool(elideListResponsesRaw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'elide_list_responses': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithElision(v))
}
if prefix, ok := config["prefix"]; ok {
opts = append(opts, audit.WithPrefix(prefix))
}
return audit.NewFormatterConfig(headerFormatter, opts...)
}
// configureFormatterNode is used to configure a formatter node and associated ID on the Backend.
func (b *Backend) configureFormatterNode(name string, formatConfig audit.FormatterConfig, logger hclog.Logger) error {
formatterNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for formatter node: %w: %w", audit.ErrInternal, err)
}
formatterNode, err := audit.NewEntryFormatter(name, formatConfig, b, logger)
if err != nil {
return fmt.Errorf("error creating formatter: %w", err)
}
b.nodeIDList = append(b.nodeIDList, formatterNodeID)
b.nodeMap[formatterNodeID] = formatterNode
return nil
}
// configureSinkNode is used to configure a sink node and associated ID on the Backend.
func (b *Backend) configureSinkNode(name string, filePath string, mode string, format string) error {
name = strings.TrimSpace(name)
if name == "" {
return fmt.Errorf("name is required: %w", audit.ErrExternalOptions)
}
filePath = strings.TrimSpace(filePath)
if filePath == "" {
return fmt.Errorf("file path is required: %w", audit.ErrExternalOptions)
}
format = strings.TrimSpace(format)
if format == "" {
return fmt.Errorf("format is required: %w", audit.ErrInvalidParameter)
}
sinkNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for sink node: %w: %w", audit.ErrInternal, err)
}
// normalize file path if configured for stdout or discard
if strings.EqualFold(filePath, stdout) {
filePath = stdout
} else if strings.EqualFold(filePath, discard) {
filePath = discard
}
var sinkNode eventlogger.Node
var sinkName string
switch filePath {
case stdout:
sinkName = stdout
sinkNode, err = event.NewStdoutSinkNode(format)
case discard:
sinkName = discard
sinkNode = event.NewNoopSink()
default:
// The NewFileSink function attempts to open the file and will return an error if it can't.
sinkName = name
sinkNode, err = event.NewFileSink(filePath, format, []event.Option{event.WithFileMode(mode)}...)
}
if err != nil {
return fmt.Errorf("file sink creation failed for path %q: %w", filePath, err)
}
// Wrap the sink node with metrics middleware
sinkMetricTimer, err := audit.NewSinkMetricTimer(sinkName, sinkNode)
if err != nil {
return fmt.Errorf("unable to add timing metrics to sink for path %q: %w", filePath, err)
}
// Decide what kind of labels we want and wrap the sink node inside a metrics counter.
var metricLabeler event.Labeler
switch {
case b.fallback:
metricLabeler = &audit.MetricLabelerAuditFallback{}
default:
metricLabeler = &audit.MetricLabelerAuditSink{}
}
sinkMetricCounter, err := event.NewMetricsCounter(sinkName, sinkMetricTimer, metricLabeler)
if err != nil {
return fmt.Errorf("unable to add counting metrics to sink for path %q: %w", filePath, err)
}
b.nodeIDList = append(b.nodeIDList, sinkNodeID)
b.nodeMap[sinkNodeID] = sinkMetricCounter
return nil
}
// Name for this backend, this would ideally correspond to the mount path for the audit device.
func (b *Backend) Name() string {
return b.name
}
// Nodes returns the nodes which should be used by the event framework to process audit entries.
func (b *Backend) Nodes() map[eventlogger.NodeID]eventlogger.Node {
return b.nodeMap
}
// NodeIDs returns the IDs of the nodes, in the order they are required.
func (b *Backend) NodeIDs() []eventlogger.NodeID {
return b.nodeIDList
}
// EventType returns the event type for the backend.
func (b *Backend) EventType() eventlogger.EventType {
return event.AuditType.AsEventType()
}
// HasFiltering determines if the first node for the pipeline is an eventlogger.NodeTypeFilter.
func (b *Backend) HasFiltering() bool {
if b.nodeMap == nil {
return false
}
return len(b.nodeIDList) > 0 && b.nodeMap[b.nodeIDList[0]].Type() == eventlogger.NodeTypeFilter
}
// IsFallback can be used to determine if this audit backend device is intended to
// be used as a fallback to catch all events that are not written when only using
// filtered pipelines.
func (b *Backend) IsFallback() bool {
return b.fallback
}

View File

@@ -1,11 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package file
// configureFilterNode is used to configure a filter node and associated ID on the Backend.
func (b *Backend) configureFilterNode(_ string) error {
return nil
}

View File

@@ -1,99 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package file
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/audit"
"github.com/stretchr/testify/require"
)
// TestBackend_configureFilterNode ensures that configureFilterNode handles various
// filter values as expected. Empty (including whitespace) strings should return
// no error but skip configuration of the node.
// NOTE: Audit filtering is an Enterprise feature and behaves differently in the
// community edition of Vault.
func TestBackend_configureFilterNode(t *testing.T) {
t.Parallel()
tests := map[string]struct {
filter string
}{
"happy": {
filter: "operation == update",
},
"empty": {
filter: "",
},
"spacey": {
filter: " ",
},
"bad": {
filter: "___qwerty",
},
"unsupported-field": {
filter: "foo == bar",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
err := b.configureFilterNode(tc.filter)
require.NoError(t, err)
require.Len(t, b.nodeIDList, 0)
require.Len(t, b.nodeMap, 0)
})
}
}
// TestBackend_configureFilterFormatterSink ensures that configuring all three
// types of nodes on a Backend works as expected, i.e. we have only formatter and sink
// nodes at the end and nothing gets overwritten. The order of calls influences the
// slice of IDs on the Backend.
// NOTE: Audit filtering is an Enterprise feature and behaves differently in the
// community edition of Vault.
func TestBackend_configureFilterFormatterSink(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
formatConfig, err := audit.NewFormatterConfig(&audit.NoopHeaderFormatter{})
require.NoError(t, err)
err = b.configureFilterNode("path == bar")
require.NoError(t, err)
err = b.configureFormatterNode("juan", formatConfig, hclog.NewNullLogger())
require.NoError(t, err)
err = b.configureSinkNode("foo", "/tmp/foo", "0777", "json")
require.NoError(t, err)
require.Len(t, b.nodeIDList, 2)
require.Len(t, b.nodeMap, 2)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeFormatter, node.Type())
id = b.nodeIDList[1]
node = b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
}

View File

@@ -1,555 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package file
import (
"os"
"path/filepath"
"strconv"
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestAuditFile_fileModeNew verifies that the backend Factory correctly sets
// the file mode when the mode argument is set.
func TestAuditFile_fileModeNew(t *testing.T) {
t.Parallel()
modeStr := "0777"
mode, err := strconv.ParseUint(modeStr, 8, 32)
require.NoError(t, err)
file := filepath.Join(t.TempDir(), "auditTest.txt")
backendConfig := &audit.BackendConfig{
Config: map[string]string{
"path": file,
"mode": modeStr,
},
MountPath: "foo/bar",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
}
_, err = Factory(backendConfig, &audit.NoopHeaderFormatter{})
require.NoError(t, err)
info, err := os.Stat(file)
require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`")
require.Equalf(t, os.FileMode(mode), info.Mode(), "File mode does not match.")
}
// TestAuditFile_fileModeExisting verifies that the backend Factory correctly sets
// the mode on an existing file.
func TestAuditFile_fileModeExisting(t *testing.T) {
t.Parallel()
dir := t.TempDir()
f, err := os.CreateTemp(dir, "auditTest.log")
require.NoErrorf(t, err, "Failure to create test file.")
err = os.Chmod(f.Name(), 0o777)
require.NoErrorf(t, err, "Failure to chmod temp file for testing.")
err = f.Close()
require.NoErrorf(t, err, "Failure to close temp file for test.")
backendConfig := &audit.BackendConfig{
Config: map[string]string{
"path": f.Name(),
},
MountPath: "foo/bar",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
}
_, err = Factory(backendConfig, &audit.NoopHeaderFormatter{})
require.NoError(t, err)
info, err := os.Stat(f.Name())
require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`")
require.Equalf(t, os.FileMode(0o600), info.Mode(), "File mode does not match.")
}
// TestAuditFile_fileMode0000 verifies that setting the audit file mode to
// "0000" prevents Vault from modifying the permissions of the file.
func TestAuditFile_fileMode0000(t *testing.T) {
t.Parallel()
dir := t.TempDir()
f, err := os.CreateTemp(dir, "auditTest.log")
require.NoErrorf(t, err, "Failure to create test file.")
err = os.Chmod(f.Name(), 0o777)
require.NoErrorf(t, err, "Failure to chmod temp file for testing.")
err = f.Close()
require.NoErrorf(t, err, "Failure to close temp file for test.")
backendConfig := &audit.BackendConfig{
Config: map[string]string{
"path": f.Name(),
"mode": "0000",
},
MountPath: "foo/bar",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
}
_, err = Factory(backendConfig, &audit.NoopHeaderFormatter{})
require.NoError(t, err)
info, err := os.Stat(f.Name())
require.NoErrorf(t, err, "cannot retrieve file mode from `Stat`. The error is %v", err)
require.Equalf(t, os.FileMode(0o777), info.Mode(), "File mode does not match.")
}
// TestAuditFile_EventLogger_fileModeNew verifies that the Factory function
// correctly sets the requested file mode on a newly created audit file.
func TestAuditFile_EventLogger_fileModeNew(t *testing.T) {
modeStr := "0777"
mode, err := strconv.ParseUint(modeStr, 8, 32)
require.NoError(t, err)
file := filepath.Join(t.TempDir(), "auditTest.txt")
backendConfig := &audit.BackendConfig{
Config: map[string]string{
"path": file,
"mode": modeStr,
},
MountPath: "foo/bar",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
}
_, err = Factory(backendConfig, &audit.NoopHeaderFormatter{})
require.NoError(t, err)
info, err := os.Stat(file)
require.NoErrorf(t, err, "Cannot retrieve file mode from `Stat`")
require.Equalf(t, os.FileMode(mode), info.Mode(), "File mode does not match.")
}
// TestBackend_newFormatterConfig ensures that all the configuration values are parsed correctly.
func TestBackend_newFormatterConfig(t *testing.T) {
t.Parallel()
tests := map[string]struct {
config map[string]string
want audit.FormatterConfig
wantErr bool
expectedMessage string
}{
"happy-path-json": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{
Raw: true,
HMACAccessor: true,
ElideListResponses: true,
RequiredFormat: "json",
}, wantErr: false,
},
"happy-path-jsonx": {
config: map[string]string{
"format": audit.JSONxFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{
Raw: true,
HMACAccessor: true,
ElideListResponses: true,
RequiredFormat: "jsonx",
},
wantErr: false,
},
"invalid-format": {
config: map[string]string{
"format": " squiggly ",
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedMessage: "unsupported 'format': invalid configuration",
},
"invalid-hmac-accessor": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedMessage: "unable to parse 'hmac_accessor': invalid configuration",
},
"invalid-log-raw": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedMessage: "unable to parse 'log_raw: invalid configuration",
},
"invalid-elide-bool": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedMessage: "unable to parse 'elide_list_responses': invalid configuration",
},
"prefix": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"prefix": "foo",
},
want: audit.FormatterConfig{
RequiredFormat: audit.JSONFormat,
Prefix: "foo",
HMACAccessor: true,
},
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
got, err := newFormatterConfig(&audit.NoopHeaderFormatter{}, tc.config)
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedMessage)
} else {
require.NoError(t, err)
}
require.Equal(t, tc.want.RequiredFormat, got.RequiredFormat)
require.Equal(t, tc.want.Raw, got.Raw)
require.Equal(t, tc.want.ElideListResponses, got.ElideListResponses)
require.Equal(t, tc.want.HMACAccessor, got.HMACAccessor)
require.Equal(t, tc.want.OmitTime, got.OmitTime)
require.Equal(t, tc.want.Prefix, got.Prefix)
})
}
}
// TestBackend_configureFormatterNode ensures that configureFormatterNode
// populates the nodeIDList and nodeMap on Backend when given valid formatConfig.
func TestBackend_configureFormatterNode(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
formatConfig, err := audit.NewFormatterConfig(&audit.NoopHeaderFormatter{})
require.NoError(t, err)
err = b.configureFormatterNode("juan", formatConfig, hclog.NewNullLogger())
require.NoError(t, err)
require.Len(t, b.nodeIDList, 1)
require.Len(t, b.nodeMap, 1)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeFormatter, node.Type())
}
// TestBackend_configureSinkNode ensures that we can correctly configure the sink
// node on the Backend, and that any incorrect parameters result in the relevant errors.
func TestBackend_configureSinkNode(t *testing.T) {
t.Parallel()
tests := map[string]struct {
name string
filePath string
mode string
format string
wantErr bool
expectedErrMsg string
expectedName string
}{
"name-empty": {
name: "",
wantErr: true,
expectedErrMsg: "name is required: invalid configuration",
},
"name-whitespace": {
name: " ",
wantErr: true,
expectedErrMsg: "name is required: invalid configuration",
},
"filePath-empty": {
name: "foo",
filePath: "",
wantErr: true,
expectedErrMsg: "file path is required: invalid configuration",
},
"filePath-whitespace": {
name: "foo",
filePath: " ",
wantErr: true,
expectedErrMsg: "file path is required: invalid configuration",
},
"filePath-stdout-lower": {
name: "foo",
expectedName: "stdout",
filePath: "stdout",
format: "json",
},
"filePath-stdout-upper": {
name: "foo",
expectedName: "stdout",
filePath: "STDOUT",
format: "json",
},
"filePath-stdout-mixed": {
name: "foo",
expectedName: "stdout",
filePath: "StdOut",
format: "json",
},
"filePath-discard-lower": {
name: "foo",
expectedName: "discard",
filePath: "discard",
format: "json",
},
"filePath-discard-upper": {
name: "foo",
expectedName: "discard",
filePath: "DISCARD",
format: "json",
},
"filePath-discard-mixed": {
name: "foo",
expectedName: "discard",
filePath: "DisCArd",
format: "json",
},
"format-empty": {
name: "foo",
filePath: "/tmp/",
format: "",
wantErr: true,
expectedErrMsg: "format is required: invalid internal parameter",
},
"format-whitespace": {
name: "foo",
filePath: "/tmp/",
format: " ",
wantErr: true,
expectedErrMsg: "format is required: invalid internal parameter",
},
"filePath-weird-with-mode-zero": {
name: "foo",
filePath: "/tmp/qwerty",
format: "json",
mode: "0",
wantErr: true,
expectedErrMsg: "file sink creation failed for path \"/tmp/qwerty\": unable to determine existing file mode: stat /tmp/qwerty: no such file or directory",
},
"happy": {
name: "foo",
filePath: "/tmp/audit.log",
mode: "",
format: "json",
wantErr: false,
expectedName: "foo",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
err := b.configureSinkNode(tc.name, tc.filePath, tc.mode, tc.format)
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrMsg)
require.Len(t, b.nodeIDList, 0)
require.Len(t, b.nodeMap, 0)
} else {
require.NoError(t, err)
require.Len(t, b.nodeIDList, 1)
require.Len(t, b.nodeMap, 1)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
mc, ok := node.(*event.MetricsCounter)
require.True(t, ok)
require.Equal(t, tc.expectedName, mc.Name)
}
})
}
}
// TestBackend_Factory_Conf ensures that any supplied configuration is validated
// and that the expected errors are returned.
func TestBackend_Factory_Conf(t *testing.T) {
t.Parallel()
tests := map[string]struct {
backendConfig *audit.BackendConfig
isErrorExpected bool
expectedErrorMessage string
}{
"nil-salt-config": {
backendConfig: &audit.BackendConfig{
SaltConfig: nil,
},
isErrorExpected: true,
expectedErrorMessage: "nil salt config: invalid internal parameter",
},
"nil-salt-view": {
backendConfig: &audit.BackendConfig{
SaltConfig: &salt.Config{},
},
isErrorExpected: true,
expectedErrorMessage: "nil salt view: invalid internal parameter",
},
"nil-logger": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: nil,
},
isErrorExpected: true,
expectedErrorMessage: "nil logger: invalid internal parameter",
},
"fallback-device-with-filter": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "true",
"file_path": discard,
"filter": "mount_type == kv",
},
},
isErrorExpected: true,
expectedErrorMessage: "cannot configure a fallback device with a filter: invalid configuration",
},
"non-fallback-device-with-filter": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "false",
"file_path": discard,
"filter": "mount_type == kv",
},
},
isErrorExpected: false,
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
be, err := Factory(tc.backendConfig, &audit.NoopHeaderFormatter{})
switch {
case tc.isErrorExpected:
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrorMessage)
default:
require.NoError(t, err)
require.NotNil(t, be)
}
})
}
}
// TestBackend_IsFallback ensures that the 'fallback' config setting is parsed
// and set correctly, then exposed via the interface method IsFallback().
func TestBackend_IsFallback(t *testing.T) {
t.Parallel()
tests := map[string]struct {
backendConfig *audit.BackendConfig
isFallbackExpected bool
}{
"fallback": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "true",
"file_path": discard,
},
},
isFallbackExpected: true,
},
"no-fallback": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "false",
"file_path": discard,
},
},
isFallbackExpected: false,
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
be, err := Factory(tc.backendConfig, &audit.NoopHeaderFormatter{})
require.NoError(t, err)
require.NotNil(t, be)
require.Equal(t, tc.isFallbackExpected, be.IsFallback())
})
}
}
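
The tests above stop at construction; the sketch below (hedged, with illustrative values, the same imports as this file plus context) pushes a test message through a backend built against the discard sink, exercising LogTestMessage:

cfg := &audit.BackendConfig{
	MountPath:  "foo/bar",
	SaltConfig: &salt.Config{},
	SaltView:   &logical.InmemStorage{},
	Logger:     hclog.NewNullLogger(),
	Config:     map[string]string{"file_path": "discard"},
}
be, err := Factory(cfg, &audit.NoopHeaderFormatter{})
if err != nil {
	// handle configuration error
}
// The message is processed synchronously by the backend's formatter and sink nodes;
// with file_path=discard the formatted entry is simply dropped.
in := &logical.LogInput{Request: &logical.Request{Operation: logical.UpdateOperation, Path: "sys/audit/test"}}
err = be.LogTestMessage(context.Background(), in)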

View File

@@ -1,319 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package socket
import (
"context"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
var _ audit.Backend = (*Backend)(nil)
// Backend is the audit backend for the socket audit transport.
type Backend struct {
fallback bool
name string
nodeIDList []eventlogger.NodeID
nodeMap map[eventlogger.NodeID]eventlogger.Node
salt *salt.Salt
saltConfig *salt.Config
saltMutex sync.RWMutex
saltView logical.Storage
}
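// Factory creates a socket audit backend, validating the supplied configuration and
// assembling the formatter, sink and (Enterprise-only) filter nodes that make up its pipeline.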
func Factory(conf *audit.BackendConfig, headersConfig audit.HeaderFormatter) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config: %w", audit.ErrInvalidParameter)
}
if conf.SaltView == nil {
return nil, fmt.Errorf("nil salt view: %w", audit.ErrInvalidParameter)
}
if conf.Logger == nil || reflect.ValueOf(conf.Logger).IsNil() {
return nil, fmt.Errorf("nil logger: %w", audit.ErrInvalidParameter)
}
if conf.MountPath == "" {
return nil, fmt.Errorf("mount path cannot be empty: %w", audit.ErrInvalidParameter)
}
address, ok := conf.Config["address"]
if !ok {
return nil, fmt.Errorf("address is required: %w", audit.ErrExternalOptions)
}
socketType, ok := conf.Config["socket_type"]
if !ok {
socketType = "tcp"
}
writeDeadline, ok := conf.Config["write_timeout"]
if !ok {
writeDeadline = "2s"
}
sinkOpts := []event.Option{
event.WithSocketType(socketType),
event.WithMaxDuration(writeDeadline),
}
err := event.ValidateOptions(sinkOpts...)
if err != nil {
return nil, err
}
// The config options 'fallback' and 'filter' are mutually exclusive; a fallback
// device catches everything, so it cannot be allowed to filter.
var fallback bool
if fallbackRaw, ok := conf.Config["fallback"]; ok {
fallback, err = parseutil.ParseBool(fallbackRaw)
if err != nil {
return nil, fmt.Errorf("unable to parse 'fallback': %w", audit.ErrExternalOptions)
}
}
if _, ok := conf.Config["filter"]; ok && fallback {
return nil, fmt.Errorf("cannot configure a fallback device with a filter: %w", audit.ErrExternalOptions)
}
cfg, err := newFormatterConfig(headersConfig, conf.Config)
if err != nil {
return nil, err
}
b := &Backend{
fallback: fallback,
name: conf.MountPath,
saltConfig: conf.SaltConfig,
saltView: conf.SaltView,
nodeIDList: []eventlogger.NodeID{},
nodeMap: make(map[eventlogger.NodeID]eventlogger.Node),
}
err = b.configureFilterNode(conf.Config["filter"])
if err != nil {
return nil, err
}
err = b.configureFormatterNode(conf.MountPath, cfg, conf.Logger)
if err != nil {
return nil, err
}
err = b.configureSinkNode(conf.MountPath, address, cfg.RequiredFormat.String(), sinkOpts...)
if err != nil {
return nil, err
}
return b, nil
}
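// LogTestMessage processes the supplied test message through the backend's nodes;
// it is a no-op when no nodes have been configured.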
func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput) error {
if len(b.nodeIDList) > 0 {
return audit.ProcessManual(ctx, in, b.nodeIDList, b.nodeMap)
}
return nil
}
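// Reload locates the sink node and asks it to reopen.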
func (b *Backend) Reload() error {
for _, n := range b.nodeMap {
if n.Type() == eventlogger.NodeTypeSink {
return n.Reopen()
}
}
return nil
}
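// Salt lazily creates and caches the salt, using double-checked locking so that
// concurrent callers share a single initialization.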
func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) {
b.saltMutex.RLock()
if b.salt != nil {
defer b.saltMutex.RUnlock()
return b.salt, nil
}
b.saltMutex.RUnlock()
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
if b.salt != nil {
return b.salt, nil
}
s, err := salt.NewSalt(ctx, b.saltView, b.saltConfig)
if err != nil {
return nil, err
}
b.salt = s
return s, nil
}
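// Invalidate clears the cached salt so that it is recreated on next use.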
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil
}
// newFormatterConfig creates the configuration required by a formatter node using
// the config map supplied to the factory.
func newFormatterConfig(headerFormatter audit.HeaderFormatter, config map[string]string) (audit.FormatterConfig, error) {
var opts []audit.Option
if format, ok := config["format"]; ok {
if !audit.IsValidFormat(format) {
return audit.FormatterConfig{}, fmt.Errorf("unsupported 'format': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithFormat(format))
}
// Check if hashing of accessor is disabled
if hmacAccessorRaw, ok := config["hmac_accessor"]; ok {
v, err := strconv.ParseBool(hmacAccessorRaw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'hmac_accessor': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithHMACAccessor(v))
}
// Check if raw logging is enabled
if raw, ok := config["log_raw"]; ok {
v, err := strconv.ParseBool(raw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'log_raw: %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithRaw(v))
}
if elideListResponsesRaw, ok := config["elide_list_responses"]; ok {
v, err := strconv.ParseBool(elideListResponsesRaw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'elide_list_responses': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithElision(v))
}
if prefix, ok := config["prefix"]; ok {
opts = append(opts, audit.WithPrefix(prefix))
}
return audit.NewFormatterConfig(headerFormatter, opts...)
}
// configureFormatterNode is used to configure a formatter node and associated ID on the Backend.
func (b *Backend) configureFormatterNode(name string, formatConfig audit.FormatterConfig, logger hclog.Logger) error {
formatterNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for formatter node: %w: %w", audit.ErrInternal, err)
}
formatterNode, err := audit.NewEntryFormatter(name, formatConfig, b, logger)
if err != nil {
return fmt.Errorf("error creating formatter: %w", err)
}
b.nodeIDList = append(b.nodeIDList, formatterNodeID)
b.nodeMap[formatterNodeID] = formatterNode
return nil
}
// configureSinkNode is used to configure a sink node and associated ID on the Backend.
func (b *Backend) configureSinkNode(name string, address string, format string, opts ...event.Option) error {
name = strings.TrimSpace(name)
if name == "" {
return fmt.Errorf("name is required: %w", audit.ErrInvalidParameter)
}
address = strings.TrimSpace(address)
if address == "" {
return fmt.Errorf("address is required: %w", audit.ErrInvalidParameter)
}
format = strings.TrimSpace(format)
if format == "" {
return fmt.Errorf("format is required: %w", audit.ErrInvalidParameter)
}
sinkNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for sink node: %w", err)
}
n, err := event.NewSocketSink(address, format, opts...)
if err != nil {
return err
}
// Wrap the sink node with metrics middleware
sinkMetricTimer, err := audit.NewSinkMetricTimer(name, n)
if err != nil {
return fmt.Errorf("unable to add timing metrics to sink for path %q: %w", name, err)
}
// Decide what kind of labels we want and wrap the sink node inside a metrics counter.
var metricLabeler event.Labeler
switch {
case b.fallback:
metricLabeler = &audit.MetricLabelerAuditFallback{}
default:
metricLabeler = &audit.MetricLabelerAuditSink{}
}
sinkMetricCounter, err := event.NewMetricsCounter(name, sinkMetricTimer, metricLabeler)
if err != nil {
return fmt.Errorf("unable to add counting metrics to sink for path %q: %w", name, err)
}
b.nodeIDList = append(b.nodeIDList, sinkNodeID)
b.nodeMap[sinkNodeID] = sinkMetricCounter
return nil
}
// Name returns the name of this backend; ideally this corresponds to the mount path of the audit device.
func (b *Backend) Name() string {
return b.name
}
// Nodes returns the nodes which should be used by the event framework to process audit entries.
func (b *Backend) Nodes() map[eventlogger.NodeID]eventlogger.Node {
return b.nodeMap
}
// NodeIDs returns the IDs of the nodes, in the order they are required.
func (b *Backend) NodeIDs() []eventlogger.NodeID {
return b.nodeIDList
}
// EventType returns the event type for the backend.
func (b *Backend) EventType() eventlogger.EventType {
return event.AuditType.AsEventType()
}
// HasFiltering determines if the first node for the pipeline is an eventlogger.NodeTypeFilter.
func (b *Backend) HasFiltering() bool {
if b.nodeMap == nil {
return false
}
return len(b.nodeIDList) > 0 && b.nodeMap[b.nodeIDList[0]].Type() == eventlogger.NodeTypeFilter
}
// IsFallback can be used to determine if this audit backend device is intended to
// be used as a fallback to catch all events that are not written when only using
// filtered pipelines.
func (b *Backend) IsFallback() bool {
return b.fallback
}
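
For reference, a hedged sketch of the configuration this factory expects, mirroring the socket tests that follow; the mount path, address and timeout values are illustrative, and socket_type/write_timeout default to "tcp" and "2s" when omitted:

cfg := &audit.BackendConfig{
	MountPath:  "socket-audit",
	SaltConfig: &salt.Config{},
	SaltView:   &logical.InmemStorage{},
	Logger:     hclog.NewNullLogger(),
	Config: map[string]string{
		"address":       "127.0.0.1:9090", // required
		"socket_type":   "tcp",            // optional, defaults to "tcp"
		"write_timeout": "2s",             // optional, defaults to "2s"
	},
}
be, err := Factory(cfg, &audit.NoopHeaderFormatter{})

A 'fallback' or 'filter' key could also be supplied, but as the Factory above enforces, not both at once.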

View File

@@ -1,11 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package socket
// configureFilterNode is used to configure a filter node and associated ID on the Backend.
// In the community edition of Vault this is a no-op; audit filtering is an Enterprise feature.
func (b *Backend) configureFilterNode(_ string) error {
return nil
}

View File

@@ -1,99 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package socket
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/audit"
"github.com/stretchr/testify/require"
)
// TestBackend_configureFilterNode ensures that configureFilterNode handles various
// filter values as expected. Empty (including whitespace) strings should return
// no error but skip configuration of the node.
// NOTE: Audit filtering is an Enterprise feature and behaves differently in the
// community edition of Vault.
func TestBackend_configureFilterNode(t *testing.T) {
t.Parallel()
tests := map[string]struct {
filter string
}{
"happy": {
filter: "operation == update",
},
"empty": {
filter: "",
},
"spacey": {
filter: " ",
},
"bad": {
filter: "___qwerty",
},
"unsupported-field": {
filter: "foo == bar",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
err := b.configureFilterNode(tc.filter)
require.NoError(t, err)
require.Len(t, b.nodeIDList, 0)
require.Len(t, b.nodeMap, 0)
})
}
}
// TestBackend_configureFilterFormatterSink ensures that configuring all three
// types of nodes on a Backend works as expected, i.e. we have only formatter and sink
// nodes at the end and nothing gets overwritten. The order of calls influences the
// slice of IDs on the Backend.
// NOTE: Audit filtering is an Enterprise feature and behaves differently in the
// community edition of Vault.
func TestBackend_configureFilterFormatterSink(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
formatConfig, err := audit.NewFormatterConfig(&audit.NoopHeaderFormatter{})
require.NoError(t, err)
err = b.configureFilterNode("path == bar")
require.NoError(t, err)
err = b.configureFormatterNode("juan", formatConfig, hclog.NewNullLogger())
require.NoError(t, err)
err = b.configureSinkNode("foo", "https://hashicorp.com", "json")
require.NoError(t, err)
require.Len(t, b.nodeIDList, 2)
require.Len(t, b.nodeMap, 2)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeFormatter, node.Type())
id = b.nodeIDList[1]
node = b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
}

View File

@@ -1,451 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package socket
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestBackend_newFormatterConfig ensures that all the configuration values are parsed correctly.
func TestBackend_newFormatterConfig(t *testing.T) {
t.Parallel()
tests := map[string]struct {
config map[string]string
want audit.FormatterConfig
wantErr bool
expectedErrMsg string
}{
"happy-path-json": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{
Raw: true,
HMACAccessor: true,
ElideListResponses: true,
RequiredFormat: "json",
}, wantErr: false,
},
"happy-path-jsonx": {
config: map[string]string{
"format": audit.JSONxFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{
Raw: true,
HMACAccessor: true,
ElideListResponses: true,
RequiredFormat: "jsonx",
},
wantErr: false,
},
"invalid-format": {
config: map[string]string{
"format": " squiggly ",
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedErrMsg: "unsupported 'format': invalid configuration",
},
"invalid-hmac-accessor": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedErrMsg: "unable to parse 'hmac_accessor': invalid configuration",
},
"invalid-log-raw": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedErrMsg: "unable to parse 'log_raw: invalid configuration",
},
"invalid-elide-bool": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedErrMsg: "unable to parse 'elide_list_responses': invalid configuration",
},
"prefix": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"prefix": "foo",
},
want: audit.FormatterConfig{
RequiredFormat: audit.JSONFormat,
Prefix: "foo",
HMACAccessor: true,
},
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
got, err := newFormatterConfig(&audit.NoopHeaderFormatter{}, tc.config)
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrMsg)
} else {
require.NoError(t, err)
}
require.Equal(t, tc.want.RequiredFormat, got.RequiredFormat)
require.Equal(t, tc.want.Raw, got.Raw)
require.Equal(t, tc.want.ElideListResponses, got.ElideListResponses)
require.Equal(t, tc.want.HMACAccessor, got.HMACAccessor)
require.Equal(t, tc.want.OmitTime, got.OmitTime)
require.Equal(t, tc.want.Prefix, got.Prefix)
})
}
}
// TestBackend_configureFormatterNode ensures that configureFormatterNode
// populates the nodeIDList and nodeMap on Backend when given valid formatConfig.
func TestBackend_configureFormatterNode(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
formatConfig, err := audit.NewFormatterConfig(&audit.NoopHeaderFormatter{})
require.NoError(t, err)
err = b.configureFormatterNode("juan", formatConfig, hclog.NewNullLogger())
require.NoError(t, err)
require.Len(t, b.nodeIDList, 1)
require.Len(t, b.nodeMap, 1)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeFormatter, node.Type())
}
// TestBackend_configureSinkNode ensures that we can correctly configure the sink
// node on the Backend, and that any incorrect parameters result in the relevant errors.
func TestBackend_configureSinkNode(t *testing.T) {
t.Parallel()
tests := map[string]struct {
name string
address string
format string
wantErr bool
expectedErrMsg string
expectedName string
}{
"name-empty": {
name: "",
address: "wss://foo",
wantErr: true,
expectedErrMsg: "name is required: invalid internal parameter",
},
"name-whitespace": {
name: " ",
address: "wss://foo",
wantErr: true,
expectedErrMsg: "name is required: invalid internal parameter",
},
"address-empty": {
name: "foo",
address: "",
wantErr: true,
expectedErrMsg: "address is required: invalid internal parameter",
},
"address-whitespace": {
name: "foo",
address: " ",
wantErr: true,
expectedErrMsg: "address is required: invalid internal parameter",
},
"format-empty": {
name: "foo",
address: "wss://foo",
format: "",
wantErr: true,
expectedErrMsg: "format is required: invalid internal parameter",
},
"format-whitespace": {
name: "foo",
address: "wss://foo",
format: " ",
wantErr: true,
expectedErrMsg: "format is required: invalid internal parameter",
},
"happy": {
name: "foo",
address: "wss://foo",
format: "json",
wantErr: false,
expectedName: "foo",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
err := b.configureSinkNode(tc.name, tc.address, tc.format)
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrMsg)
require.Len(t, b.nodeIDList, 0)
require.Len(t, b.nodeMap, 0)
} else {
require.NoError(t, err)
require.Len(t, b.nodeIDList, 1)
require.Len(t, b.nodeMap, 1)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
mc, ok := node.(*event.MetricsCounter)
require.True(t, ok)
require.Equal(t, tc.expectedName, mc.Name)
}
})
}
}
// TestBackend_Factory_Conf ensures that any supplied configuration is validated
// and that the expected errors are returned.
func TestBackend_Factory_Conf(t *testing.T) {
t.Parallel()
tests := map[string]struct {
backendConfig *audit.BackendConfig
isErrorExpected bool
expectedErrorMessage string
}{
"nil-salt-config": {
backendConfig: &audit.BackendConfig{
SaltConfig: nil,
},
isErrorExpected: true,
expectedErrorMessage: "nil salt config: invalid internal parameter",
},
"nil-salt-view": {
backendConfig: &audit.BackendConfig{
SaltConfig: &salt.Config{},
},
isErrorExpected: true,
expectedErrorMessage: "nil salt view: invalid internal parameter",
},
"nil-logger": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: nil,
},
isErrorExpected: true,
expectedErrorMessage: "nil logger: invalid internal parameter",
},
"no-address": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{},
},
isErrorExpected: true,
expectedErrorMessage: "address is required: invalid configuration",
},
"empty-address": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"address": "",
},
},
isErrorExpected: true,
expectedErrorMessage: "address is required: invalid internal parameter",
},
"whitespace-address": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"address": " ",
},
},
isErrorExpected: true,
expectedErrorMessage: "address is required: invalid internal parameter",
},
"write-duration-valid": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"address": "hashicorp.com",
"write_timeout": "5s",
},
},
isErrorExpected: false,
},
"write-duration-not-valid": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"address": "hashicorp.com",
"write_timeout": "qwerty",
},
},
isErrorExpected: true,
expectedErrorMessage: "unable to parse max duration: invalid parameter: time: invalid duration \"qwerty\"",
},
"non-fallback-device-with-filter": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"address": "hashicorp.com",
"write_timeout": "5s",
"fallback": "false",
"filter": "mount_type == kv",
},
},
isErrorExpected: false,
},
"fallback-device-with-filter": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"address": "hashicorp.com",
"write_timeout": "2s",
"fallback": "true",
"filter": "mount_type == kv",
},
},
isErrorExpected: true,
expectedErrorMessage: "cannot configure a fallback device with a filter: invalid configuration",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
be, err := Factory(tc.backendConfig, &audit.NoopHeaderFormatter{})
switch {
case tc.isErrorExpected:
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrorMessage)
default:
require.NoError(t, err)
require.NotNil(t, be)
}
})
}
}
// TestBackend_IsFallback ensures that the 'fallback' config setting is parsed
// and set correctly, then exposed via the interface method IsFallback().
func TestBackend_IsFallback(t *testing.T) {
t.Parallel()
tests := map[string]struct {
backendConfig *audit.BackendConfig
isFallbackExpected bool
}{
"fallback": {
backendConfig: &audit.BackendConfig{
MountPath: "qwerty",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "true",
"address": "hashicorp.com",
"write_timeout": "5s",
},
},
isFallbackExpected: true,
},
"no-fallback": {
backendConfig: &audit.BackendConfig{
MountPath: "qwerty",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "false",
"address": "hashicorp.com",
"write_timeout": "5s",
},
},
isFallbackExpected: false,
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
be, err := Factory(tc.backendConfig, &audit.NoopHeaderFormatter{})
require.NoError(t, err)
require.NotNil(t, be)
require.Equal(t, tc.isFallbackExpected, be.IsFallback())
})
}
}

View File

@@ -1,306 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package syslog
import (
"context"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
var _ audit.Backend = (*Backend)(nil)
// Backend is the audit backend for the syslog-based audit store.
type Backend struct {
fallback bool
name string
nodeIDList []eventlogger.NodeID
nodeMap map[eventlogger.NodeID]eventlogger.Node
salt *salt.Salt
saltConfig *salt.Config
saltMutex sync.RWMutex
saltView logical.Storage
}
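// Factory creates a syslog audit backend, validating the supplied configuration and
// assembling the formatter, sink and (Enterprise-only) filter nodes that make up its pipeline.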
func Factory(conf *audit.BackendConfig, headersConfig audit.HeaderFormatter) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config: %w", audit.ErrInvalidParameter)
}
if conf.SaltView == nil {
return nil, fmt.Errorf("nil salt view: %w", audit.ErrInvalidParameter)
}
if conf.Logger == nil || reflect.ValueOf(conf.Logger).IsNil() {
return nil, fmt.Errorf("nil logger: %w", audit.ErrInvalidParameter)
}
if conf.MountPath == "" {
return nil, fmt.Errorf("mount path cannot be empty: %w", audit.ErrInvalidParameter)
}
// Get facility or default to AUTH
facility, ok := conf.Config["facility"]
if !ok {
facility = "AUTH"
}
// Get tag or default to 'vault'
tag, ok := conf.Config["tag"]
if !ok {
tag = "vault"
}
sinkOpts := []event.Option{
event.WithFacility(facility),
event.WithTag(tag),
}
err := event.ValidateOptions(sinkOpts...)
if err != nil {
return nil, err
}
// The config options 'fallback' and 'filter' are mutually exclusive; a fallback
// device catches everything, so it cannot be allowed to filter.
var fallback bool
if fallbackRaw, ok := conf.Config["fallback"]; ok {
fallback, err = parseutil.ParseBool(fallbackRaw)
if err != nil {
return nil, fmt.Errorf("unable to parse 'fallback': %w", audit.ErrExternalOptions)
}
}
if _, ok := conf.Config["filter"]; ok && fallback {
return nil, fmt.Errorf("cannot configure a fallback device with a filter: %w", audit.ErrExternalOptions)
}
cfg, err := newFormatterConfig(headersConfig, conf.Config)
if err != nil {
return nil, err
}
b := &Backend{
fallback: fallback,
name: conf.MountPath,
saltConfig: conf.SaltConfig,
saltView: conf.SaltView,
nodeIDList: []eventlogger.NodeID{},
nodeMap: make(map[eventlogger.NodeID]eventlogger.Node),
}
err = b.configureFilterNode(conf.Config["filter"])
if err != nil {
return nil, err
}
err = b.configureFormatterNode(conf.MountPath, cfg, conf.Logger)
if err != nil {
return nil, err
}
err = b.configureSinkNode(conf.MountPath, cfg.RequiredFormat.String(), sinkOpts...)
if err != nil {
return nil, err
}
return b, nil
}
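// LogTestMessage processes the supplied test message through the backend's nodes;
// it is a no-op when no nodes have been configured.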
func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput) error {
if len(b.nodeIDList) > 0 {
return audit.ProcessManual(ctx, in, b.nodeIDList, b.nodeMap)
}
return nil
}
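// Reload is a no-op for the syslog backend.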
func (b *Backend) Reload() error {
return nil
}
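// Salt lazily creates and caches the salt, using double-checked locking so that
// concurrent callers share a single initialization.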
func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) {
b.saltMutex.RLock()
if b.salt != nil {
defer b.saltMutex.RUnlock()
return b.salt, nil
}
b.saltMutex.RUnlock()
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
if b.salt != nil {
return b.salt, nil
}
s, err := salt.NewSalt(ctx, b.saltView, b.saltConfig)
if err != nil {
return nil, err
}
b.salt = s
return s, nil
}
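// Invalidate clears the cached salt so that it is recreated on next use.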
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil
}
// newFormatterConfig creates the configuration required by a formatter node using
// the config map supplied to the factory.
func newFormatterConfig(headerFormatter audit.HeaderFormatter, config map[string]string) (audit.FormatterConfig, error) {
var opts []audit.Option
if format, ok := config["format"]; ok {
if !audit.IsValidFormat(format) {
return audit.FormatterConfig{}, fmt.Errorf("unsupported 'format': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithFormat(format))
}
// Check if hashing of accessor is disabled
if hmacAccessorRaw, ok := config["hmac_accessor"]; ok {
v, err := strconv.ParseBool(hmacAccessorRaw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'hmac_accessor': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithHMACAccessor(v))
}
// Check if raw logging is enabled
if raw, ok := config["log_raw"]; ok {
v, err := strconv.ParseBool(raw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'log_raw: %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithRaw(v))
}
if elideListResponsesRaw, ok := config["elide_list_responses"]; ok {
v, err := strconv.ParseBool(elideListResponsesRaw)
if err != nil {
return audit.FormatterConfig{}, fmt.Errorf("unable to parse 'elide_list_responses': %w", audit.ErrExternalOptions)
}
opts = append(opts, audit.WithElision(v))
}
if prefix, ok := config["prefix"]; ok {
opts = append(opts, audit.WithPrefix(prefix))
}
return audit.NewFormatterConfig(headerFormatter, opts...)
}
// configureFormatterNode is used to configure a formatter node and associated ID on the Backend.
func (b *Backend) configureFormatterNode(name string, formatConfig audit.FormatterConfig, logger hclog.Logger) error {
formatterNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for formatter node: %w: %w", audit.ErrInternal, err)
}
formatterNode, err := audit.NewEntryFormatter(name, formatConfig, b, logger)
if err != nil {
return fmt.Errorf("error creating formatter: %w", err)
}
b.nodeIDList = append(b.nodeIDList, formatterNodeID)
b.nodeMap[formatterNodeID] = formatterNode
return nil
}
// configureSinkNode is used to configure a sink node and associated ID on the Backend.
func (b *Backend) configureSinkNode(name string, format string, opts ...event.Option) error {
name = strings.TrimSpace(name)
if name == "" {
return fmt.Errorf("name is required: %w", audit.ErrInvalidParameter)
}
format = strings.TrimSpace(format)
if format == "" {
return fmt.Errorf("format is required: %w", audit.ErrInvalidParameter)
}
sinkNodeID, err := event.GenerateNodeID()
if err != nil {
return fmt.Errorf("error generating random NodeID for sink node: %w: %w", audit.ErrInternal, err)
}
n, err := event.NewSyslogSink(format, opts...)
if err != nil {
return fmt.Errorf("error creating syslog sink node: %w", err)
}
// Wrap the sink node with metrics middleware
sinkMetricTimer, err := audit.NewSinkMetricTimer(name, n)
if err != nil {
return fmt.Errorf("unable to add timing metrics to sink for path %q: %w", name, err)
}
// Decide what kind of labels we want and wrap the sink node inside a metrics counter.
var metricLabeler event.Labeler
switch {
case b.fallback:
metricLabeler = &audit.MetricLabelerAuditFallback{}
default:
metricLabeler = &audit.MetricLabelerAuditSink{}
}
sinkMetricCounter, err := event.NewMetricsCounter(name, sinkMetricTimer, metricLabeler)
if err != nil {
return fmt.Errorf("unable to add counting metrics to sink for path %q: %w", name, err)
}
b.nodeIDList = append(b.nodeIDList, sinkNodeID)
b.nodeMap[sinkNodeID] = sinkMetricCounter
return nil
}
// Name returns the name of this backend; ideally this corresponds to the mount path of the audit device.
func (b *Backend) Name() string {
return b.name
}
// Nodes returns the nodes which should be used by the event framework to process audit entries.
func (b *Backend) Nodes() map[eventlogger.NodeID]eventlogger.Node {
return b.nodeMap
}
// NodeIDs returns the IDs of the nodes, in the order they are required.
func (b *Backend) NodeIDs() []eventlogger.NodeID {
return b.nodeIDList
}
// EventType returns the event type for the backend.
func (b *Backend) EventType() eventlogger.EventType {
return event.AuditType.AsEventType()
}
// HasFiltering determines if the first node for the pipeline is an eventlogger.NodeTypeFilter.
func (b *Backend) HasFiltering() bool {
if b.nodeMap == nil {
return false
}
return len(b.nodeIDList) > 0 && b.nodeMap[b.nodeIDList[0]].Type() == eventlogger.NodeTypeFilter
}
// IsFallback can be used to determine if this audit backend device is intended to
// be used as a fallback to catch all events that are not written when only using
// filtered pipelines.
func (b *Backend) IsFallback() bool {
return b.fallback
}
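
As with the socket device, a hedged sketch of the configuration this factory expects; facility and tag are optional and default to "AUTH" and "vault", all values shown are illustrative, and on hosts without a reachable syslog daemon sink creation may fail (which is why the CLI test later in this diff skips syslog on some platforms):

cfg := &audit.BackendConfig{
	MountPath:  "syslog-audit",
	SaltConfig: &salt.Config{},
	SaltView:   &logical.InmemStorage{},
	Logger:     hclog.NewNullLogger(),
	Config: map[string]string{
		"facility": "AUTH",  // optional, defaults to "AUTH"
		"tag":      "vault", // optional, defaults to "vault"
		"format":   "json",
	},
}
be, err := Factory(cfg, &audit.NoopHeaderFormatter{})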

View File

@@ -1,11 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package syslog
// configureFilterNode is used to configure a filter node and associated ID on the Backend.
// In the community edition of Vault this is a no-op; audit filtering is an Enterprise feature.
func (b *Backend) configureFilterNode(_ string) error {
return nil
}

View File

@@ -1,99 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package syslog
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/audit"
"github.com/stretchr/testify/require"
)
// TestBackend_configureFilterNode ensures that configureFilterNode handles various
// filter values as expected. Empty (including whitespace) strings should return
// no error but skip configuration of the node.
// NOTE: Audit filtering is an Enterprise feature and behaves differently in the
// community edition of Vault.
func TestBackend_configureFilterNode(t *testing.T) {
t.Parallel()
tests := map[string]struct {
filter string
}{
"happy": {
filter: "operation == update",
},
"empty": {
filter: "",
},
"spacey": {
filter: " ",
},
"bad": {
filter: "___qwerty",
},
"unsupported-field": {
filter: "foo == bar",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
err := b.configureFilterNode(tc.filter)
require.NoError(t, err)
require.Len(t, b.nodeIDList, 0)
require.Len(t, b.nodeMap, 0)
})
}
}
// TestBackend_configureFilterFormatterSink ensures that configuring all three
// types of nodes on a Backend works as expected, i.e. we have only formatter and sink
// nodes at the end and nothing gets overwritten. The order of calls influences the
// slice of IDs on the Backend.
// NOTE: Audit filtering is an Enterprise feature and behaves differently in the
// community edition of Vault.
func TestBackend_configureFilterFormatterSink(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
formatConfig, err := audit.NewFormatterConfig(&audit.NoopHeaderFormatter{})
require.NoError(t, err)
err = b.configureFilterNode("path == bar")
require.NoError(t, err)
err = b.configureFormatterNode("juan", formatConfig, hclog.NewNullLogger())
require.NoError(t, err)
err = b.configureSinkNode("foo", "json")
require.NoError(t, err)
require.Len(t, b.nodeIDList, 2)
require.Len(t, b.nodeMap, 2)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeFormatter, node.Type())
id = b.nodeIDList[1]
node = b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
}

View File

@@ -1,351 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package syslog
import (
"testing"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// TestBackend_newFormatterConfig ensures that all the configuration values are parsed correctly.
func TestBackend_newFormatterConfig(t *testing.T) {
t.Parallel()
tests := map[string]struct {
config map[string]string
want audit.FormatterConfig
wantErr bool
expectedErrMsg string
}{
"happy-path-json": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{
Raw: true,
HMACAccessor: true,
ElideListResponses: true,
RequiredFormat: "json",
}, wantErr: false,
},
"happy-path-jsonx": {
config: map[string]string{
"format": audit.JSONxFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{
Raw: true,
HMACAccessor: true,
ElideListResponses: true,
RequiredFormat: "jsonx",
},
wantErr: false,
},
"invalid-format": {
config: map[string]string{
"format": " squiggly ",
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "true",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedErrMsg: "unsupported 'format': invalid configuration",
},
"invalid-hmac-accessor": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedErrMsg: "unable to parse 'hmac_accessor': invalid configuration",
},
"invalid-log-raw": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedErrMsg: "unable to parse 'log_raw: invalid configuration",
},
"invalid-elide-bool": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"hmac_accessor": "true",
"log_raw": "true",
"elide_list_responses": "maybe",
},
want: audit.FormatterConfig{},
wantErr: true,
expectedErrMsg: "unable to parse 'elide_list_responses': invalid configuration",
},
"prefix": {
config: map[string]string{
"format": audit.JSONFormat.String(),
"prefix": "foo",
},
want: audit.FormatterConfig{
RequiredFormat: audit.JSONFormat,
Prefix: "foo",
HMACAccessor: true,
},
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
got, err := newFormatterConfig(&audit.NoopHeaderFormatter{}, tc.config)
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrMsg)
} else {
require.NoError(t, err)
}
require.Equal(t, tc.want.RequiredFormat, got.RequiredFormat)
require.Equal(t, tc.want.Raw, got.Raw)
require.Equal(t, tc.want.ElideListResponses, got.ElideListResponses)
require.Equal(t, tc.want.HMACAccessor, got.HMACAccessor)
require.Equal(t, tc.want.OmitTime, got.OmitTime)
require.Equal(t, tc.want.Prefix, got.Prefix)
})
}
}
// TestBackend_configureFormatterNode ensures that configureFormatterNode
// populates the nodeIDList and nodeMap on Backend when given valid formatConfig.
func TestBackend_configureFormatterNode(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
formatConfig, err := audit.NewFormatterConfig(&audit.NoopHeaderFormatter{})
require.NoError(t, err)
err = b.configureFormatterNode("juan", formatConfig, hclog.NewNullLogger())
require.NoError(t, err)
require.Len(t, b.nodeIDList, 1)
require.Len(t, b.nodeMap, 1)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeFormatter, node.Type())
}
// TestBackend_configureSinkNode ensures that we can correctly configure the sink
// node on the Backend, and that any incorrect parameters result in the relevant errors.
func TestBackend_configureSinkNode(t *testing.T) {
t.Parallel()
tests := map[string]struct {
name string
format string
wantErr bool
expectedErrMsg string
expectedName string
}{
"name-empty": {
name: "",
wantErr: true,
expectedErrMsg: "name is required: invalid internal parameter",
},
"name-whitespace": {
name: " ",
wantErr: true,
expectedErrMsg: "name is required: invalid internal parameter",
},
"format-empty": {
name: "foo",
format: "",
wantErr: true,
expectedErrMsg: "format is required: invalid internal parameter",
},
"format-whitespace": {
name: "foo",
format: " ",
wantErr: true,
expectedErrMsg: "format is required: invalid internal parameter",
},
"happy": {
name: "foo",
format: "json",
wantErr: false,
expectedName: "foo",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
b := &Backend{
nodeIDList: []eventlogger.NodeID{},
nodeMap: map[eventlogger.NodeID]eventlogger.Node{},
}
err := b.configureSinkNode(tc.name, tc.format)
if tc.wantErr {
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrMsg)
require.Len(t, b.nodeIDList, 0)
require.Len(t, b.nodeMap, 0)
} else {
require.NoError(t, err)
require.Len(t, b.nodeIDList, 1)
require.Len(t, b.nodeMap, 1)
id := b.nodeIDList[0]
node := b.nodeMap[id]
require.Equal(t, eventlogger.NodeTypeSink, node.Type())
mc, ok := node.(*event.MetricsCounter)
require.True(t, ok)
require.Equal(t, tc.expectedName, mc.Name)
}
})
}
}
// TestBackend_Factory_Conf ensures that any supplied configuration is validated
// and that the expected errors are returned.
func TestBackend_Factory_Conf(t *testing.T) {
t.Parallel()
tests := map[string]struct {
backendConfig *audit.BackendConfig
isErrorExpected bool
expectedErrorMessage string
}{
"nil-salt-config": {
backendConfig: &audit.BackendConfig{
SaltConfig: nil,
},
isErrorExpected: true,
expectedErrorMessage: "nil salt config: invalid internal parameter",
},
"nil-salt-view": {
backendConfig: &audit.BackendConfig{
SaltConfig: &salt.Config{},
},
isErrorExpected: true,
expectedErrorMessage: "nil salt view: invalid internal parameter",
},
"non-fallback-device-with-filter": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "false",
"filter": "mount_type == kv",
},
},
isErrorExpected: false,
},
"fallback-device-with-filter": {
backendConfig: &audit.BackendConfig{
MountPath: "discard",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "true",
"filter": "mount_type == kv",
},
},
isErrorExpected: true,
expectedErrorMessage: "cannot configure a fallback device with a filter: invalid configuration",
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
be, err := Factory(tc.backendConfig, &audit.NoopHeaderFormatter{})
switch {
case tc.isErrorExpected:
require.Error(t, err)
require.EqualError(t, err, tc.expectedErrorMessage)
default:
require.NoError(t, err)
require.NotNil(t, be)
}
})
}
}
// TestBackend_IsFallback ensures that the 'fallback' config setting is parsed
// and set correctly, then exposed via the interface method IsFallback().
func TestBackend_IsFallback(t *testing.T) {
t.Parallel()
tests := map[string]struct {
backendConfig *audit.BackendConfig
isFallbackExpected bool
}{
"fallback": {
backendConfig: &audit.BackendConfig{
MountPath: "qwerty",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "true",
},
},
isFallbackExpected: true,
},
"no-fallback": {
backendConfig: &audit.BackendConfig{
MountPath: "qwerty",
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Logger: hclog.NewNullLogger(),
Config: map[string]string{
"fallback": "false",
},
},
isFallbackExpected: false,
},
}
for name, tc := range tests {
name := name
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
be, err := Factory(tc.backendConfig, &audit.NoopHeaderFormatter{})
require.NoError(t, err)
require.NotNil(t, be)
require.Equal(t, tc.isFallbackExpected, be.IsFallback())
})
}
}

View File

@@ -13,7 +13,6 @@ import (
uuid "github.com/hashicorp/go-uuid" uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/builtin/audit/file"
"github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/builtin/logical/transit"
vaulthttp "github.com/hashicorp/vault/http" vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
@@ -26,7 +25,7 @@ func TestTransit_Issue_2958(t *testing.T) {
"transit": transit.Factory, "transit": transit.Factory,
}, },
AuditBackends: map[string]audit.Factory{ AuditBackends: map[string]audit.Factory{
"file": file.Factory, "file": audit.NewFileBackend,
}, },
} }

View File

@@ -4,7 +4,6 @@
  package command
  import (
- "io/ioutil"
  "os"
  "strings"
  "testing"
@@ -169,26 +168,12 @@ func TestAuditEnableCommand_Run(t *testing.T) {
  client, closer := testVaultServerAllBackends(t)
  defer closer()
- files, err := ioutil.ReadDir("../builtin/audit")
- if err != nil {
- t.Fatal(err)
- }
- var backends []string
- for _, f := range files {
- if f.IsDir() {
- backends = append(backends, f.Name())
- }
- }
- for _, b := range backends {
+ for name := range auditBackends {
  ui, cmd := testAuditEnableCommand(t)
  cmd.client = client
- args := []string{
- b,
- }
- switch b {
+ args := []string{name}
+ switch name {
  case "file":
  args = append(args, "file_path=discard")
  case "socket":
@@ -199,15 +184,10 @@ func TestAuditEnableCommand_Run(t *testing.T) {
t.Log("skipping syslog test on WSL") t.Log("skipping syslog test on WSL")
continue continue
} }
if os.Getenv("CIRCLECI") == "true" {
// TODO install syslog in docker image we run our tests in
t.Log("skipping syslog test on CircleCI")
continue
}
} }
code := cmd.Run(args) code := cmd.Run(args)
if exp := 0; code != exp { if exp := 0; code != exp {
t.Errorf("type %s, expected %d to be %d - %s", b, code, exp, ui.OutputWriter.String()+ui.ErrorWriter.String()) t.Errorf("type %s, expected %d to be %d - %s", name, code, exp, ui.OutputWriter.String()+ui.ErrorWriter.String())
} }
} }
}) })

View File

@@ -17,7 +17,6 @@ import (
kv "github.com/hashicorp/vault-plugin-secrets-kv" kv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/builtin/logical/pki" "github.com/hashicorp/vault/builtin/logical/pki"
"github.com/hashicorp/vault/builtin/logical/ssh" "github.com/hashicorp/vault/builtin/logical/ssh"
@@ -39,7 +38,7 @@ var (
}
defaultVaultAuditBackends = map[string]audit.Factory{
- "file": auditFile.Factory,
+ "file": audit.NewFileBackend,
}
defaultVaultLogicalBackends = map[string]logical.Factory{

View File

@@ -19,9 +19,6 @@ import (
credOCI "github.com/hashicorp/vault-plugin-auth-oci" credOCI "github.com/hashicorp/vault-plugin-auth-oci"
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv" logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
auditSocket "github.com/hashicorp/vault/builtin/audit/socket"
auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog"
credAws "github.com/hashicorp/vault/builtin/credential/aws" credAws "github.com/hashicorp/vault/builtin/credential/aws"
credCert "github.com/hashicorp/vault/builtin/credential/cert" credCert "github.com/hashicorp/vault/builtin/credential/cert"
credGitHub "github.com/hashicorp/vault/builtin/credential/github" credGitHub "github.com/hashicorp/vault/builtin/credential/github"
@@ -166,9 +163,9 @@ const (
var (
auditBackends = map[string]audit.Factory{
- "file": auditFile.Factory,
- "socket": auditSocket.Factory,
- "syslog": auditSyslog.Factory,
+ "file": audit.NewFileBackend,
+ "socket": audit.NewSocketBackend,
+ "syslog": audit.NewSyslogBackend,
}
credentialBackends = map[string]logical.Factory{

View File

@@ -11,7 +11,6 @@ import (
"testing" "testing"
"github.com/hashicorp/cap/ldap" "github.com/hashicorp/cap/ldap"
"github.com/hashicorp/vault/sdk/helper/docker" "github.com/hashicorp/vault/sdk/helper/docker"
"github.com/hashicorp/vault/sdk/helper/ldaputil" "github.com/hashicorp/vault/sdk/helper/ldaputil"
) )

View File

@@ -6,9 +6,6 @@ package minimal
import (
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/audit"
- auditFile "github.com/hashicorp/vault/builtin/audit/file"
- auditSocket "github.com/hashicorp/vault/builtin/audit/socket"
- auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog"
logicalDb "github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/helper/builtinplugins"
@@ -64,9 +61,9 @@ func NewTestSoloCluster(t testing.T, config *vault.CoreConfig) *vault.TestCluste
}
if mycfg.AuditBackends == nil {
mycfg.AuditBackends = map[string]audit.Factory{
- "file": auditFile.Factory,
- "socket": auditSocket.Factory,
- "syslog": auditSyslog.Factory,
+ "file": audit.NewFileBackend,
+ "socket": audit.NewSocketBackend,
+ "syslog": audit.NewSyslogBackend,
}
}
if mycfg.BuiltinRegistry == nil {

View File

@@ -14,9 +14,6 @@ import (
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv" logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
auditSocket "github.com/hashicorp/vault/builtin/audit/socket"
auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog"
logicalDb "github.com/hashicorp/vault/builtin/logical/database" logicalDb "github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
@@ -320,9 +317,9 @@ func ClusterSetup(conf *vault.CoreConfig, opts *vault.TestClusterOptions, setup
}
if localConf.AuditBackends == nil {
localConf.AuditBackends = map[string]audit.Factory{
- "file": auditFile.Factory,
- "socket": auditSocket.Factory,
- "syslog": auditSyslog.Factory,
+ "file": audit.NewFileBackend,
+ "socket": audit.NewSocketBackend,
+ "syslog": audit.NewSyslogBackend,
"noop": audit.NoopAuditFactory(nil),
}
}

View File

@@ -23,7 +23,6 @@ import (
kv "github.com/hashicorp/vault-plugin-secrets-kv" kv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers" "github.com/hashicorp/vault/helper/testhelpers/corehelpers"
@@ -681,7 +680,7 @@ func TestLogical_AuditPort(t *testing.T) {
"kv": kv.VersionedKVFactory, "kv": kv.VersionedKVFactory,
}, },
AuditBackends: map[string]audit.Factory{ AuditBackends: map[string]audit.Factory{
"file": auditFile.Factory, "file": audit.NewFileBackend,
}, },
} }
@@ -876,7 +875,7 @@ func testBuiltinPluginMetadataAuditLog(t *testing.T, log map[string]interface{},
func TestLogical_AuditEnabled_ShouldLogPluginMetadata_Auth(t *testing.T) {
coreConfig := &vault.CoreConfig{
AuditBackends: map[string]audit.Factory{
- "file": auditFile.Factory,
+ "file": audit.NewFileBackend,
},
}
@@ -949,7 +948,7 @@ func TestLogical_AuditEnabled_ShouldLogPluginMetadata_Secret(t *testing.T) {
"kv": kv.VersionedKVFactory, "kv": kv.VersionedKVFactory,
}, },
AuditBackends: map[string]audit.Factory{ AuditBackends: map[string]audit.Factory{
"file": auditFile.Factory, "file": audit.NewFileBackend,
}, },
} }

View File

@@ -17,7 +17,6 @@ import (
"github.com/go-test/deep" "github.com/go-test/deep"
"github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers" "github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/internalshared/configutil" "github.com/hashicorp/vault/internalshared/configutil"
@@ -573,7 +572,7 @@ func TestSysSealStatusRedaction(t *testing.T) {
EnableRaw: true,
BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(),
AuditBackends: map[string]audit.Factory{
- "file": auditFile.Factory,
+ "file": audit.NewFileBackend,
},
}
core, _, token := vault.TestCoreUnsealedWithConfig(t, conf)

View File

@@ -152,7 +152,7 @@ func (c *Core) enableAudit(ctx context.Context, entry *MountEntry, updateStorage
defer view.setReadOnlyErr(origViewReadOnlyErr)
// Lookup the new backend
- backend, err := c.newAuditBackend(ctx, entry, view, entry.Options)
+ backend, err := c.newAuditBackend(entry, view, entry.Options)
if err != nil {
return err
}
@@ -193,7 +193,7 @@ func (c *Core) enableAudit(ctx context.Context, entry *MountEntry, updateStorage
c.audit = newTable
// Register the backend
- err = c.auditBroker.Register(entry.Path, backend, entry.Local)
+ err = c.auditBroker.Register(backend, entry.Local)
if err != nil {
return fmt.Errorf("failed to register %q audit backend %q: %w", entry.Type, entry.Path, err)
}
@@ -432,7 +432,7 @@ func (c *Core) setupAudits(ctx context.Context) error {
brokerLogger := c.baseLogger.Named("audit")
- broker, err := NewAuditBroker(brokerLogger)
+ broker, err := audit.NewBroker(brokerLogger)
if err != nil {
return err
}
@@ -456,7 +456,7 @@ func (c *Core) setupAudits(ctx context.Context) error {
})
// Initialize the backend
- backend, err := c.newAuditBackend(ctx, entry, view, entry.Options)
+ backend, err := c.newAuditBackend(entry, view, entry.Options)
if err != nil {
c.logger.Error("failed to create audit entry", "path", entry.Path, "error", err)
continue
@@ -467,7 +467,7 @@ func (c *Core) setupAudits(ctx context.Context) error {
}
// Mount the backend
- err = broker.Register(entry.Path, backend, entry.Local)
+ err = broker.Register(backend, entry.Local)
if err != nil {
c.logger.Error("failed to setup audit backed", "path", entry.Path, "type", entry.Type, "error", err)
continue
@@ -528,7 +528,7 @@ func (c *Core) removeAuditReloadFunc(entry *MountEntry) {
}
// newAuditBackend is used to create and configure a new audit backend by name
- func (c *Core) newAuditBackend(ctx context.Context, entry *MountEntry, view logical.Storage, conf map[string]string) (audit.Backend, error) {
+ func (c *Core) newAuditBackend(entry *MountEntry, view logical.Storage, conf map[string]string) (audit.Backend, error) {
// Ensure that non-Enterprise versions aren't trying to supply Enterprise only options.
if hasInvalidAuditOptions(entry.Options) {
return nil, fmt.Errorf("enterprise-only options supplied: %w", audit.ErrInvalidParameter)

View File

@@ -340,16 +340,16 @@ func verifyDefaultAuditTable(t *testing.T, table *MountTable) {
func TestAuditBroker_LogRequest(t *testing.T) {
l := logging.NewVaultLogger(log.Trace)
- b, err := NewAuditBroker(l)
+ b, err := audit.NewBroker(l)
if err != nil {
t.Fatal(err)
}
a1 := audit.TestNoopAudit(t, "foo", nil)
a2 := audit.TestNoopAudit(t, "bar", nil)
- err = b.Register("foo", a1, false)
+ err = b.Register(a1, false)
require.NoError(t, err)
- err = b.Register("bar", a2, false)
+ err = b.Register(a2, false)
require.NoError(t, err)
auth := &logical.Auth{
@@ -429,16 +429,16 @@ func TestAuditBroker_LogRequest(t *testing.T) {
func TestAuditBroker_LogResponse(t *testing.T) {
l := logging.NewVaultLogger(log.Trace)
- b, err := NewAuditBroker(l)
+ b, err := audit.NewBroker(l)
if err != nil {
t.Fatal(err)
}
a1 := audit.TestNoopAudit(t, "foo", nil)
a2 := audit.TestNoopAudit(t, "bar", nil)
- err = b.Register("foo", a1, false)
+ err = b.Register(a1, false)
require.NoError(t, err)
- err = b.Register("bar", a2, false)
+ err = b.Register(a2, false)
require.NoError(t, err)
auth := &logical.Auth{
@@ -534,7 +534,7 @@ func TestAuditBroker_LogResponse(t *testing.T) {
func TestAuditBroker_AuditHeaders(t *testing.T) {
logger := logging.NewVaultLogger(log.Trace)
- b, err := NewAuditBroker(logger)
+ b, err := audit.NewBroker(logger)
if err != nil {
t.Fatal(err)
}
@@ -542,9 +542,9 @@ func TestAuditBroker_AuditHeaders(t *testing.T) {
a1 := audit.TestNoopAudit(t, "foo", nil)
a2 := audit.TestNoopAudit(t, "bar", nil)
- err = b.Register("foo", a1, false)
+ err = b.Register(a1, false)
require.NoError(t, err)
- err = b.Register("bar", a2, false)
+ err = b.Register(a2, false)
require.NoError(t, err)
auth := &logical.Auth{
@@ -741,10 +741,8 @@ func TestAudit_newAuditBackend(t *testing.T) {
Type: "noop", Type: "noop",
Options: map[string]string{"fallback": "true"}, Options: map[string]string{"fallback": "true"},
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := c.newAuditBackend(ctx, me, &logical.InmemStorage{}, me.Options) _, err := c.newAuditBackend(me, &logical.InmemStorage{}, me.Options)
if constants.IsEnterprise { if constants.IsEnterprise {
require.NoError(t, err) require.NoError(t, err)
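
Taken together, the broker hunks reduce to two signature moves: NewAuditBroker becomes audit.NewBroker, and Register drops the explicit path argument because the backend already carries its mount path. A minimal sketch of the new call sequence, assuming the helpers visible in these hunks (audit.NewBroker, audit.TestNoopAudit, Register) keep the shapes used here:

package audit_test

import (
	"testing"

	"github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/audit"
	"github.com/stretchr/testify/require"
)

// Sketch of the relocated broker API exercised above: NewBroker replaces
// NewAuditBroker, and Register takes only the backend and its local flag.
func TestBrokerRegister_sketch(t *testing.T) {
	broker, err := audit.NewBroker(hclog.NewNullLogger())
	require.NoError(t, err)

	// TestNoopAudit builds a no-op audit device mounted at "foo"; the broker
	// derives the path from the backend itself rather than a separate argument.
	device := audit.TestNoopAudit(t, "foo", nil)
	require.NoError(t, broker.Register(device, false))
}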

View File

@@ -389,11 +389,11 @@ type Core struct {
// auditBroker is used to ingest the audit events and fan
// out into the configured audit backends
- auditBroker *AuditBroker
+ auditBroker *audit.Broker
// auditedHeaders is used to configure which http headers
// can be output in the audit logs
- auditedHeaders *AuditedHeadersConfig
+ auditedHeaders *audit.HeadersConfig
// systemBackend is the backend which is used to manage internal operations
systemBackend *SystemBackend
@@ -2477,7 +2477,7 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c
}
} else {
- broker, err := NewAuditBroker(logger)
+ broker, err := audit.NewBroker(logger)
if err != nil {
return err
}
@@ -2935,7 +2935,7 @@ func (c *Core) BarrierKeyLength() (min, max int) {
return
}
- func (c *Core) AuditedHeadersConfig() *AuditedHeadersConfig {
+ func (c *Core) AuditedHeadersConfig() *audit.HeadersConfig {
return c.auditedHeaders
}
@@ -4551,3 +4551,26 @@ func (c *Core) DetectStateLockDeadlocks() bool {
}
return false
}
// setupAuditedHeadersConfig will initialize new audited headers configuration on
// the Core by loading data from the barrier view.
func (c *Core) setupAuditedHeadersConfig(ctx context.Context) error {
// Create a sub-view, e.g. sys/audited-headers-config/
view := c.systemBarrierView.SubView(audit.AuditedHeadersSubPath)
headers, err := audit.NewHeadersConfig(view)
if err != nil {
return err
}
// Invalidate the headers now in order to load them for the first time.
err = headers.Invalidate(ctx)
if err != nil {
return err
}
// Update the Core.
c.auditedHeaders = headers
return nil
}

View File

@@ -21,9 +21,6 @@ import (
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
logicalKv "github.com/hashicorp/vault-plugin-secrets-kv" logicalKv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/builtin/audit/file"
"github.com/hashicorp/vault/builtin/audit/socket"
"github.com/hashicorp/vault/builtin/audit/syslog"
logicalDb "github.com/hashicorp/vault/builtin/logical/database" logicalDb "github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/builtin/plugin" "github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/command/server"
@@ -59,24 +56,24 @@ func TestNewCore_configureAuditBackends(t *testing.T) {
},
"file": {
backends: map[string]audit.Factory{
- "file": file.Factory,
+ "file": audit.NewFileBackend,
},
},
"socket": {
backends: map[string]audit.Factory{
- "socket": socket.Factory,
+ "socket": audit.NewSocketBackend,
},
},
"syslog": {
backends: map[string]audit.Factory{
- "syslog": syslog.Factory,
+ "syslog": audit.NewSyslogBackend,
},
},
"all": {
backends: map[string]audit.Factory{
- "file": file.Factory,
- "socket": socket.Factory,
- "syslog": syslog.Factory,
+ "file": audit.NewFileBackend,
+ "socket": audit.NewSocketBackend,
+ "syslog": audit.NewSyslogBackend,
},
},
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
"github.com/hashicorp/vault/builtin/logical/database" "github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/builtin/logical/pki" "github.com/hashicorp/vault/builtin/logical/pki"
@@ -40,7 +39,7 @@ func testVaultServerUnseal(t testing.TB) (*api.Client, []string, func()) {
"userpass": credUserpass.Factory, "userpass": credUserpass.Factory,
}, },
AuditBackends: map[string]audit.Factory{ AuditBackends: map[string]audit.Factory{
"file": auditFile.Factory, "file": audit.NewFileBackend,
}, },
LogicalBackends: map[string]logical.Factory{ LogicalBackends: map[string]logical.Factory{
"database": database.Factory, "database": database.Factory,

View File

@@ -15,7 +15,6 @@ import (
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/api/auth/approle" "github.com/hashicorp/vault/api/auth/approle"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
"github.com/hashicorp/vault/builtin/logical/database" "github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/testhelpers/consul" "github.com/hashicorp/vault/helper/testhelpers/consul"
@@ -38,7 +37,7 @@ func getCluster(t *testing.T, numCores int, types ...consts.PluginType) *vault.T
"database": database.Factory, "database": database.Factory,
}, },
AuditBackends: map[string]audit.Factory{ AuditBackends: map[string]audit.Factory{
"file": auditFile.Factory, "file": audit.NewFileBackend,
}, },
} }

View File

@@ -1133,14 +1133,14 @@ func (b *SystemBackend) handlePluginRuntimeCatalogList(ctx context.Context, _ *l
}
// handleAuditedHeaderUpdate creates or overwrites a header entry
- func (b *SystemBackend) handleAuditedHeaderUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
+ func (b *SystemBackend) handleAuditedHeaderUpdate(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) {
header := d.Get("header").(string)
hmac := d.Get("hmac").(bool)
if header == "" {
return logical.ErrorResponse("missing header name"), nil
}
- err := b.Core.AuditedHeadersConfig().add(ctx, header, hmac)
+ err := b.Core.AuditedHeadersConfig().Add(ctx, header, hmac)
if err != nil {
return nil, err
}
@@ -1149,13 +1149,13 @@ func (b *SystemBackend) handleAuditedHeaderUpdate(ctx context.Context, req *logi
}
// handleAuditedHeaderDelete deletes the header with the given name
- func (b *SystemBackend) handleAuditedHeaderDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
+ func (b *SystemBackend) handleAuditedHeaderDelete(ctx context.Context, _ *logical.Request, d *framework.FieldData) (*logical.Response, error) {
header := d.Get("header").(string)
if header == "" {
return logical.ErrorResponse("missing header name"), nil
}
- err := b.Core.AuditedHeadersConfig().remove(ctx, header)
+ err := b.Core.AuditedHeadersConfig().Remove(ctx, header)
if err != nil {
return nil, err
}
@@ -1170,7 +1170,7 @@ func (b *SystemBackend) handleAuditedHeaderRead(_ context.Context, _ *logical.Re
return logical.ErrorResponse("missing header name"), nil
}
- settings, ok := b.Core.AuditedHeadersConfig().header(header)
+ settings, ok := b.Core.AuditedHeadersConfig().Header(header)
if !ok {
return logical.ErrorResponse("Could not find header in config"), nil
}
@@ -1184,7 +1184,7 @@ func (b *SystemBackend) handleAuditedHeaderRead(_ context.Context, _ *logical.Re
// handleAuditedHeadersRead returns the whole audited headers config
func (b *SystemBackend) handleAuditedHeadersRead(_ context.Context, _ *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
- headerSettings := b.Core.AuditedHeadersConfig().headers()
+ headerSettings := b.Core.AuditedHeadersConfig().Headers()
return &logical.Response{
Data: map[string]interface{}{
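
The handler changes above swap the unexported vault-package helpers for the exported audit.HeadersConfig methods (Add, Remove, Header, Headers). A rough sketch of that surface, assuming a *audit.HeadersConfig obtained elsewhere (in core it comes from Core.AuditedHeadersConfig()); the header name and error handling are illustrative:

package example

import (
	"context"
	"fmt"

	"github.com/hashicorp/vault/audit"
)

// configureAuditedHeaders sketches the exported surface the sys handlers use.
func configureAuditedHeaders(ctx context.Context, headers *audit.HeadersConfig) error {
	// Add creates or overwrites a header entry; the bool controls HMACing of values.
	if err := headers.Add(ctx, "x-request-id", true); err != nil {
		return err
	}

	// Header looks up a single entry; Headers returns the whole configuration.
	if settings, ok := headers.Header("x-request-id"); ok {
		fmt.Printf("settings: %+v\n", settings)
	}
	fmt.Printf("all headers: %+v\n", headers.Headers())

	// Remove deletes the entry again.
	return headers.Remove(ctx, "x-request-id")
}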

View File

@@ -36,7 +36,6 @@ import (
kv "github.com/hashicorp/vault-plugin-secrets-kv" kv "github.com/hashicorp/vault-plugin-secrets-kv"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
"github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/command/server"
"github.com/hashicorp/vault/helper/constants" "github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/metricsutil"
@@ -131,7 +130,7 @@ func TestCoreWithSeal(t testing.T, testSeal Seal, enableRaw bool) *Core {
EnableRaw: enableRaw,
BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(),
AuditBackends: map[string]audit.Factory{
- "file": auditFile.Factory,
+ "file": audit.NewFileBackend,
},
}
return TestCoreWithSealAndUI(t, conf)
@@ -144,7 +143,7 @@ func TestCoreWithDeadlockDetection(t testing.T, testSeal Seal, enableRaw bool) *
EnableRaw: enableRaw,
BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(),
AuditBackends: map[string]audit.Factory{
- "file": auditFile.Factory,
+ "file": audit.NewFileBackend,
},
DetectDeadlocks: "expiration,quotas,statelock,barrier",
}