mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	VAULT-17079: Adding Hash Function and HeaderAdjuster to EntryFormatter (#22042)
* add hashfunc field to EntryFormatter struct and adjust NewEntryFormatter function and tests * add HeaderAdjuster interface and require it in EntryFormatter dquote> adjust all references to NewEntryFormatter to include a HeaderAdjuster parameter * replace use of hash function in AuditedHeadersConfig's ApplyConfig method with Salter interface instance * fixup! replace use of hash function in AuditedHeadersConfig's ApplyConfig method with Salter interface instance * review feedback * Go doc typo * add another test function --------- Co-authored-by: Peter Wilson <peter.wilson@hashicorp.com>
This commit is contained in:
		| @@ -30,7 +30,7 @@ var ( | ||||
|  | ||||
| // NewEntryFormatter should be used to create an EntryFormatter. | ||||
| // Accepted options: WithPrefix. | ||||
| func NewEntryFormatter(config FormatterConfig, salter Salter, opt ...Option) (*EntryFormatter, error) { | ||||
| func NewEntryFormatter(config FormatterConfig, salter Salter, headersConfig HeaderFormatter, opt ...Option) (*EntryFormatter, error) { | ||||
| 	const op = "audit.NewEntryFormatter" | ||||
|  | ||||
| 	if salter == nil { | ||||
| @@ -50,17 +50,18 @@ func NewEntryFormatter(config FormatterConfig, salter Salter, opt ...Option) (*E | ||||
| 	return &EntryFormatter{ | ||||
| 		salter:        salter, | ||||
| 		config:        config, | ||||
| 		headersConfig: headersConfig, | ||||
| 		prefix:        opts.withPrefix, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // Reopen is a no-op for the formatter node. | ||||
| func (_ *EntryFormatter) Reopen() error { | ||||
| func (*EntryFormatter) Reopen() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Type describes the type of this node (formatter). | ||||
| func (_ *EntryFormatter) Type() eventlogger.NodeType { | ||||
| func (*EntryFormatter) Type() eventlogger.NodeType { | ||||
| 	return eventlogger.NodeTypeFormatter | ||||
| } | ||||
|  | ||||
| @@ -145,11 +146,6 @@ func (f *EntryFormatter) FormatRequest(ctx context.Context, in *logical.LogInput | ||||
| 		return nil, errors.New("salt func not configured") | ||||
| 	} | ||||
|  | ||||
| 	s, err := f.salter.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error fetching salt: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Set these to the input values at first | ||||
| 	auth := in.Auth | ||||
| 	req := in.Request | ||||
| @@ -163,12 +159,13 @@ func (f *EntryFormatter) FormatRequest(ctx context.Context, in *logical.LogInput | ||||
| 	} | ||||
|  | ||||
| 	if !f.config.Raw { | ||||
| 		auth, err = HashAuth(s, auth, f.config.HMACAccessor) | ||||
| 		var err error | ||||
| 		auth, err = HashAuth(ctx, f.salter, auth, f.config.HMACAccessor) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		req, err = HashRequest(s, req, f.config.HMACAccessor, in.NonHMACReqDataKeys) | ||||
| 		req, err = HashRequest(ctx, f.salter, req, f.config.HMACAccessor, in.NonHMACReqDataKeys) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -277,11 +274,6 @@ func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInpu | ||||
| 		return nil, errors.New("salt func not configured") | ||||
| 	} | ||||
|  | ||||
| 	s, err := f.salter.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error fetching salt: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// Set these to the input values at first | ||||
| 	auth, req, resp := in.Auth, in.Request, in.Response | ||||
| 	if auth == nil { | ||||
| @@ -314,17 +306,18 @@ func (f *EntryFormatter) FormatResponse(ctx context.Context, in *logical.LogInpu | ||||
| 			respData = resp.Data | ||||
| 		} | ||||
| 	} else { | ||||
| 		auth, err = HashAuth(s, auth, f.config.HMACAccessor) | ||||
| 		var err error | ||||
| 		auth, err = HashAuth(ctx, f.salter, auth, f.config.HMACAccessor) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		req, err = HashRequest(s, req, f.config.HMACAccessor, in.NonHMACReqDataKeys) | ||||
| 		req, err = HashRequest(ctx, f.salter, req, f.config.HMACAccessor, in.NonHMACReqDataKeys) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		resp, err = HashResponse(s, resp, f.config.HMACAccessor, in.NonHMACRespDataKeys, elideListResponseData) | ||||
| 		resp, err = HashResponse(ctx, f.salter, resp, f.config.HMACAccessor, in.NonHMACRespDataKeys, elideListResponseData) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|   | ||||
| @@ -127,7 +127,7 @@ func TestNewEntryFormatter(t *testing.T) { | ||||
|  | ||||
| 			cfg, err := NewFormatterConfig(tc.Options...) | ||||
| 			require.NoError(t, err) | ||||
| 			f, err := NewEntryFormatter(cfg, ss, tc.Options...) | ||||
| 			f, err := NewEntryFormatter(cfg, ss, nil, tc.Options...) | ||||
|  | ||||
| 			switch { | ||||
| 			case tc.IsErrorExpected: | ||||
| @@ -150,7 +150,7 @@ func TestEntryFormatter_Reopen(t *testing.T) { | ||||
| 	cfg, err := NewFormatterConfig() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	f, err := NewEntryFormatter(cfg, ss) | ||||
| 	f, err := NewEntryFormatter(cfg, ss, nil) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, f) | ||||
| 	require.NoError(t, f.Reopen()) | ||||
| @@ -162,7 +162,7 @@ func TestEntryFormatter_Type(t *testing.T) { | ||||
| 	cfg, err := NewFormatterConfig() | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	f, err := NewEntryFormatter(cfg, ss) | ||||
| 	f, err := NewEntryFormatter(cfg, ss, nil) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, f) | ||||
| 	require.Equal(t, eventlogger.NodeTypeFormatter, f.Type()) | ||||
| @@ -305,7 +305,7 @@ func TestEntryFormatter_Process(t *testing.T) { | ||||
| 			cfg, err := NewFormatterConfig(WithFormat(tc.RequiredFormat.String())) | ||||
| 			require.NoError(t, err) | ||||
|  | ||||
| 			f, err := NewEntryFormatter(cfg, ss) | ||||
| 			f, err := NewEntryFormatter(cfg, ss, nil) | ||||
| 			require.NoError(t, err) | ||||
| 			require.NotNil(t, f) | ||||
|  | ||||
| @@ -366,13 +366,13 @@ func BenchmarkAuditFileSink_Process(b *testing.B) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	ctx := namespace.RootContext(nil) | ||||
| 	ctx := namespace.RootContext(context.Background()) | ||||
|  | ||||
| 	// Create the formatter node. | ||||
| 	cfg, err := NewFormatterConfig() | ||||
| 	require.NoError(b, err) | ||||
| 	ss := newStaticSalt(b) | ||||
| 	formatter, err := NewEntryFormatter(cfg, ss) | ||||
| 	formatter, err := NewEntryFormatter(cfg, ss, nil) | ||||
| 	require.NoError(b, err) | ||||
| 	require.NotNil(b, formatter) | ||||
|  | ||||
|   | ||||
| @@ -96,7 +96,7 @@ func NewTemporaryFormatter(requiredFormat, prefix string) (*EntryFormatterWriter | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	eventFormatter, err := NewEntryFormatter(cfg, &nonPersistentSalt{}, WithPrefix(prefix)) | ||||
| 	eventFormatter, err := NewEntryFormatter(cfg, &nonPersistentSalt{}, nil, WithPrefix(prefix)) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|   | ||||
| @@ -127,7 +127,7 @@ func TestNewEntryFormatterWriter(t *testing.T) { | ||||
|  | ||||
| 			var f Formatter | ||||
| 			if !tc.UseNilFormatter { | ||||
| 				tempFormatter, err := NewEntryFormatter(cfg, s) | ||||
| 				tempFormatter, err := NewEntryFormatter(cfg, s, nil) | ||||
| 				require.NoError(t, err) | ||||
| 				require.NotNil(t, tempFormatter) | ||||
| 				f = tempFormatter | ||||
| @@ -189,9 +189,10 @@ func TestEntryFormatter_FormatRequest(t *testing.T) { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			ss := newStaticSalt(t) | ||||
| 			cfg, err := NewFormatterConfig() | ||||
| 			require.NoError(t, err) | ||||
| 			f, err := NewEntryFormatter(cfg, newStaticSalt(t)) | ||||
| 			f, err := NewEntryFormatter(cfg, ss, nil) | ||||
| 			require.NoError(t, err) | ||||
|  | ||||
| 			var ctx context.Context | ||||
| @@ -255,9 +256,10 @@ func TestEntryFormatter_FormatResponse(t *testing.T) { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			ss := newStaticSalt(t) | ||||
| 			cfg, err := NewFormatterConfig() | ||||
| 			require.NoError(t, err) | ||||
| 			f, err := NewEntryFormatter(cfg, newStaticSalt(t)) | ||||
| 			f, err := NewEntryFormatter(cfg, ss, nil) | ||||
| 			require.NoError(t, err) | ||||
|  | ||||
| 			var ctx context.Context | ||||
| @@ -359,7 +361,7 @@ func TestElideListResponses(t *testing.T) { | ||||
|  | ||||
| 	formatResponse := func(t *testing.T, config FormatterConfig, operation logical.Operation, inputData map[string]interface{}, | ||||
| 	) { | ||||
| 		f, err := NewEntryFormatter(config, &tfw) | ||||
| 		f, err := NewEntryFormatter(config, &tfw, nil) | ||||
| 		require.NoError(t, err) | ||||
| 		formatter, err := NewEntryFormatterWriter(config, f, &tfw) | ||||
| 		require.NoError(t, err) | ||||
|   | ||||
| @@ -4,13 +4,13 @@ | ||||
| package audit | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/hashicorp/go-secure-stdlib/strutil" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/salt" | ||||
| 	"github.com/hashicorp/vault/sdk/helper/wrapping" | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| 	"github.com/mitchellh/copystructure" | ||||
| @@ -18,17 +18,27 @@ import ( | ||||
| ) | ||||
|  | ||||
| // HashString hashes the given opaque string and returns it | ||||
| func HashString(salter *salt.Salt, data string) string { | ||||
| 	return salter.GetIdentifiedHMAC(data) | ||||
| func HashString(ctx context.Context, salter Salter, data string) (string, error) { | ||||
| 	salt, err := salter.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	return salt.GetIdentifiedHMAC(data), nil | ||||
| } | ||||
|  | ||||
| // HashAuth returns a hashed copy of the logical.Auth input. | ||||
| func HashAuth(salter *salt.Salt, in *logical.Auth, HMACAccessor bool) (*logical.Auth, error) { | ||||
| func HashAuth(ctx context.Context, salter Salter, in *logical.Auth, HMACAccessor bool) (*logical.Auth, error) { | ||||
| 	if in == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | ||||
| 	fn := salter.GetIdentifiedHMAC | ||||
| 	salt, err := salter.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	fn := salt.GetIdentifiedHMAC | ||||
| 	auth := *in | ||||
|  | ||||
| 	if auth.ClientToken != "" { | ||||
| @@ -41,12 +51,17 @@ func HashAuth(salter *salt.Salt, in *logical.Auth, HMACAccessor bool) (*logical. | ||||
| } | ||||
|  | ||||
| // HashRequest returns a hashed copy of the logical.Request input. | ||||
| func HashRequest(salter *salt.Salt, in *logical.Request, HMACAccessor bool, nonHMACDataKeys []string) (*logical.Request, error) { | ||||
| func HashRequest(ctx context.Context, salter Salter, in *logical.Request, HMACAccessor bool, nonHMACDataKeys []string) (*logical.Request, error) { | ||||
| 	if in == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | ||||
| 	fn := salter.GetIdentifiedHMAC | ||||
| 	salt, err := salter.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	fn := salt.GetIdentifiedHMAC | ||||
| 	req := *in | ||||
|  | ||||
| 	if req.Auth != nil { | ||||
| @@ -55,7 +70,7 @@ func HashRequest(salter *salt.Salt, in *logical.Request, HMACAccessor bool, nonH | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		req.Auth, err = HashAuth(salter, cp.(*logical.Auth), HMACAccessor) | ||||
| 		req.Auth, err = HashAuth(ctx, salter, cp.(*logical.Auth), HMACAccessor) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -84,11 +99,11 @@ func HashRequest(salter *salt.Salt, in *logical.Request, HMACAccessor bool, nonH | ||||
| 	return &req, nil | ||||
| } | ||||
|  | ||||
| func hashMap(fn func(string) string, data map[string]interface{}, nonHMACDataKeys []string) error { | ||||
| func hashMap(hashFunc HashCallback, data map[string]interface{}, nonHMACDataKeys []string) error { | ||||
| 	for k, v := range data { | ||||
| 		if o, ok := v.(logical.OptMarshaler); ok { | ||||
| 			marshaled, err := o.MarshalJSONWithOptions(&logical.MarshalOptions{ | ||||
| 				ValueHasher: fn, | ||||
| 				ValueHasher: hashFunc, | ||||
| 			}) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| @@ -97,22 +112,21 @@ func hashMap(fn func(string) string, data map[string]interface{}, nonHMACDataKey | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return HashStructure(data, fn, nonHMACDataKeys) | ||||
| 	return HashStructure(data, hashFunc, nonHMACDataKeys) | ||||
| } | ||||
|  | ||||
| // HashResponse returns a hashed copy of the logical.Request input. | ||||
| func HashResponse( | ||||
| 	salter *salt.Salt, | ||||
| 	in *logical.Response, | ||||
| 	HMACAccessor bool, | ||||
| 	nonHMACDataKeys []string, | ||||
| 	elideListResponseData bool, | ||||
| ) (*logical.Response, error) { | ||||
| func HashResponse(ctx context.Context, salter Salter, in *logical.Response, HMACAccessor bool, nonHMACDataKeys []string, elideListResponseData bool) (*logical.Response, error) { | ||||
| 	if in == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | ||||
| 	fn := salter.GetIdentifiedHMAC | ||||
| 	salt, err := salter.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	fn := salt.GetIdentifiedHMAC | ||||
| 	resp := *in | ||||
|  | ||||
| 	if resp.Auth != nil { | ||||
| @@ -121,7 +135,7 @@ func HashResponse( | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		resp.Auth, err = HashAuth(salter, cp.(*logical.Auth), HMACAccessor) | ||||
| 		resp.Auth, err = HashAuth(ctx, salter, cp.(*logical.Auth), HMACAccessor) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -154,7 +168,7 @@ func HashResponse( | ||||
|  | ||||
| 	if resp.WrapInfo != nil { | ||||
| 		var err error | ||||
| 		resp.WrapInfo, err = HashWrapInfo(salter, resp.WrapInfo, HMACAccessor) | ||||
| 		resp.WrapInfo, err = hashWrapInfo(fn, resp.WrapInfo, HMACAccessor) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| @@ -163,22 +177,21 @@ func HashResponse( | ||||
| 	return &resp, nil | ||||
| } | ||||
|  | ||||
| // HashWrapInfo returns a hashed copy of the wrapping.ResponseWrapInfo input. | ||||
| func HashWrapInfo(salter *salt.Salt, in *wrapping.ResponseWrapInfo, HMACAccessor bool) (*wrapping.ResponseWrapInfo, error) { | ||||
| // hashWrapInfo returns a hashed copy of the wrapping.ResponseWrapInfo input. | ||||
| func hashWrapInfo(hashFunc HashCallback, in *wrapping.ResponseWrapInfo, HMACAccessor bool) (*wrapping.ResponseWrapInfo, error) { | ||||
| 	if in == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | ||||
| 	fn := salter.GetIdentifiedHMAC | ||||
| 	wrapinfo := *in | ||||
|  | ||||
| 	wrapinfo.Token = fn(wrapinfo.Token) | ||||
| 	wrapinfo.Token = hashFunc(wrapinfo.Token) | ||||
|  | ||||
| 	if HMACAccessor { | ||||
| 		wrapinfo.Accessor = fn(wrapinfo.Accessor) | ||||
| 		wrapinfo.Accessor = hashFunc(wrapinfo.Accessor) | ||||
|  | ||||
| 		if wrapinfo.WrappedAccessor != "" { | ||||
| 			wrapinfo.WrappedAccessor = fn(wrapinfo.WrappedAccessor) | ||||
| 			wrapinfo.WrappedAccessor = hashFunc(wrapinfo.WrappedAccessor) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -98,20 +98,32 @@ func TestCopy_response(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestHashString(t *testing.T) { | ||||
| // TestSalter is a structure that implements the Salter interface in a trivial | ||||
| // manner. | ||||
| type TestSalter struct{} | ||||
|  | ||||
| // Salt returns a salt.Salt pointer based on dummy data stored in an in-memory | ||||
| // storage instance. | ||||
| func (*TestSalter) Salt(ctx context.Context) (*salt.Salt, error) { | ||||
| 	inmemStorage := &logical.InmemStorage{} | ||||
| 	inmemStorage.Put(context.Background(), &logical.StorageEntry{ | ||||
| 		Key:   "salt", | ||||
| 		Value: []byte("foo"), | ||||
| 	}) | ||||
| 	localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ | ||||
|  | ||||
| 	return salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ | ||||
| 		HMAC:     sha256.New, | ||||
| 		HMACType: "hmac-sha256", | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestHashString(t *testing.T) { | ||||
| 	salter := &TestSalter{} | ||||
|  | ||||
| 	out, err := HashString(context.Background(), salter, "foo") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error instantiating salt: %s", err) | ||||
| 	} | ||||
| 	out := HashString(localSalt, "foo") | ||||
| 	if out != "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a" { | ||||
| 		t.Fatalf("err: HashString output did not match expected") | ||||
| 	} | ||||
| @@ -152,16 +164,10 @@ func TestHashAuth(t *testing.T) { | ||||
| 		Key:   "salt", | ||||
| 		Value: []byte("foo"), | ||||
| 	}) | ||||
| 	localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ | ||||
| 		HMAC:     sha256.New, | ||||
| 		HMACType: "hmac-sha256", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error instantiating salt: %s", err) | ||||
| 	} | ||||
| 	salter := &TestSalter{} | ||||
| 	for _, tc := range cases { | ||||
| 		input := fmt.Sprintf("%#v", tc.Input) | ||||
| 		out, err := HashAuth(localSalt, tc.Input, tc.HMACAccessor) | ||||
| 		out, err := HashAuth(context.Background(), salter, tc.Input, tc.HMACAccessor) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("err: %s\n\n%s", err, input) | ||||
| 		} | ||||
| @@ -216,16 +222,10 @@ func TestHashRequest(t *testing.T) { | ||||
| 		Key:   "salt", | ||||
| 		Value: []byte("foo"), | ||||
| 	}) | ||||
| 	localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ | ||||
| 		HMAC:     sha256.New, | ||||
| 		HMACType: "hmac-sha256", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error instantiating salt: %s", err) | ||||
| 	} | ||||
| 	salter := &TestSalter{} | ||||
| 	for _, tc := range cases { | ||||
| 		input := fmt.Sprintf("%#v", tc.Input) | ||||
| 		out, err := HashRequest(localSalt, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys) | ||||
| 		out, err := HashRequest(context.Background(), salter, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("err: %s\n\n%s", err, input) | ||||
| 		} | ||||
| @@ -287,16 +287,10 @@ func TestHashResponse(t *testing.T) { | ||||
| 		Key:   "salt", | ||||
| 		Value: []byte("foo"), | ||||
| 	}) | ||||
| 	localSalt, err := salt.NewSalt(context.Background(), inmemStorage, &salt.Config{ | ||||
| 		HMAC:     sha256.New, | ||||
| 		HMACType: "hmac-sha256", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Error instantiating salt: %s", err) | ||||
| 	} | ||||
| 	salter := &TestSalter{} | ||||
| 	for _, tc := range cases { | ||||
| 		input := fmt.Sprintf("%#v", tc.Input) | ||||
| 		out, err := HashResponse(localSalt, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys, false) | ||||
| 		out, err := HashResponse(context.Background(), salter, tc.Input, tc.HMACAccessor, tc.NonHMACDataKeys, false) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("err: %s\n\n%s", err, input) | ||||
| 		} | ||||
|   | ||||
| @@ -85,9 +85,19 @@ type Writer interface { | ||||
| 	WriteResponse(io.Writer, *ResponseEntry) error | ||||
| } | ||||
|  | ||||
| // HeaderFormatter is an interface defining the methods of the | ||||
| // vault.AuditedHeadersConfig structure needed in this package. | ||||
| type HeaderFormatter interface { | ||||
| 	// ApplyConfig returns a map of header values that consists of the | ||||
| 	// intersection of the provided set of header values with a configured | ||||
| 	// set of headers and will hash headers that have been configured as such. | ||||
| 	ApplyConfig(context.Context, map[string][]string, Salter) (map[string][]string, error) | ||||
| } | ||||
|  | ||||
| // EntryFormatter should be used to format audit requests and responses. | ||||
| type EntryFormatter struct { | ||||
| 	salter        Salter | ||||
| 	headersConfig HeaderFormatter | ||||
| 	config        FormatterConfig | ||||
| 	prefix        string | ||||
| } | ||||
| @@ -255,6 +265,9 @@ type nonPersistentSalt struct{} | ||||
| // sink information to different backends such as logs, file, databases, | ||||
| // or other external services. | ||||
| type Backend interface { | ||||
| 	// Salter interface must be implemented by anything implementing Backend. | ||||
| 	Salter | ||||
|  | ||||
| 	// LogRequest is used to synchronously log a request. This is done after the | ||||
| 	// request is authorized but before the request is executed. The arguments | ||||
| 	// MUST not be modified in any way. They should be deep copied if this is | ||||
| @@ -273,11 +286,6 @@ type Backend interface { | ||||
| 	// operation on creation, which is currently disallowed.) | ||||
| 	LogTestMessage(context.Context, *logical.LogInput, map[string]string) error | ||||
|  | ||||
| 	// GetHash is used to return the given data with the backend's hash, | ||||
| 	// so that a caller can determine if a value in the audit log matches | ||||
| 	// an expected plaintext value | ||||
| 	GetHash(context.Context, string) (string, error) | ||||
|  | ||||
| 	// Reload is called on SIGHUP for supporting backends. | ||||
| 	Reload(context.Context) error | ||||
|  | ||||
| @@ -305,4 +313,4 @@ type BackendConfig struct { | ||||
| } | ||||
|  | ||||
| // Factory is the factory function to create an audit backend. | ||||
| type Factory func(context.Context, *BackendConfig, bool) (Backend, error) | ||||
| type Factory func(context.Context, *BackendConfig, bool, HeaderFormatter) (Backend, error) | ||||
|   | ||||
| @@ -100,7 +100,7 @@ func TestFormatJSON_formatRequest(t *testing.T) { | ||||
| 		var buf bytes.Buffer | ||||
| 		cfg, err := NewFormatterConfig() | ||||
| 		require.NoError(t, err) | ||||
| 		f, err := NewEntryFormatter(cfg, ss) | ||||
| 		f, err := NewEntryFormatter(cfg, ss, nil) | ||||
| 		require.NoError(t, err) | ||||
| 		formatter := EntryFormatterWriter{ | ||||
| 			Formatter: f, | ||||
|   | ||||
| @@ -119,7 +119,7 @@ func TestFormatJSONx_formatRequest(t *testing.T) { | ||||
| 			WithFormat(JSONxFormat.String()), | ||||
| 		) | ||||
| 		require.NoError(t, err) | ||||
| 		f, err := NewEntryFormatter(cfg, tempStaticSalt) | ||||
| 		f, err := NewEntryFormatter(cfg, tempStaticSalt, nil) | ||||
| 		require.NoError(t, err) | ||||
| 		writer := &JSONxWriter{Prefix: tc.Prefix} | ||||
| 		formatter, err := NewEntryFormatterWriter(cfg, f, writer) | ||||
|   | ||||
| @@ -22,7 +22,7 @@ import ( | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
|  | ||||
| func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool) (audit.Backend, error) { | ||||
| func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 	if conf.SaltConfig == nil { | ||||
| 		return nil, fmt.Errorf("nil salt config") | ||||
| 	} | ||||
| @@ -131,7 +131,7 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool | ||||
| 	b.salt.Store((*salt.Salt)(nil)) | ||||
|  | ||||
| 	// Configure the formatter for either case. | ||||
| 	f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithPrefix(conf.Config["prefix"])) | ||||
| 	f, err := audit.NewEntryFormatter(b.formatConfig, b, headersConfig, audit.WithPrefix(conf.Config["prefix"])) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error creating formatter: %w", err) | ||||
| 	} | ||||
| @@ -253,15 +253,6 @@ func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { | ||||
| 	return newSalt, nil | ||||
| } | ||||
|  | ||||
| func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { | ||||
| 	salt, err := b.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	return audit.HashString(salt, data), nil | ||||
| } | ||||
|  | ||||
| func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { | ||||
| 	var writer io.Writer | ||||
| 	switch b.path { | ||||
|   | ||||
| @@ -35,7 +35,7 @@ func TestAuditFile_fileModeNew(t *testing.T) { | ||||
| 		SaltConfig: &salt.Config{}, | ||||
| 		SaltView:   &logical.InmemStorage{}, | ||||
| 		Config:     config, | ||||
| 	}, false) | ||||
| 	}, false, nil) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -74,7 +74,7 @@ func TestAuditFile_fileModeExisting(t *testing.T) { | ||||
| 		Config:     config, | ||||
| 		SaltConfig: &salt.Config{}, | ||||
| 		SaltView:   &logical.InmemStorage{}, | ||||
| 	}, false) | ||||
| 	}, false, nil) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -114,7 +114,7 @@ func TestAuditFile_fileMode0000(t *testing.T) { | ||||
| 		Config:     config, | ||||
| 		SaltConfig: &salt.Config{}, | ||||
| 		SaltView:   &logical.InmemStorage{}, | ||||
| 	}, false) | ||||
| 	}, false, nil) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -148,7 +148,7 @@ func TestAuditFile_EventLogger_fileModeNew(t *testing.T) { | ||||
| 		SaltConfig: &salt.Config{}, | ||||
| 		SaltView:   &logical.InmemStorage{}, | ||||
| 		Config:     config, | ||||
| 	}, true) | ||||
| 	}, true, nil) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -170,7 +170,7 @@ func BenchmarkAuditFile_request(b *testing.B) { | ||||
| 		Config:     config, | ||||
| 		SaltConfig: &salt.Config{}, | ||||
| 		SaltView:   &logical.InmemStorage{}, | ||||
| 	}, false) | ||||
| 	}, false, nil) | ||||
| 	if err != nil { | ||||
| 		b.Fatal(err) | ||||
| 	} | ||||
|   | ||||
| @@ -21,7 +21,7 @@ import ( | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
|  | ||||
| func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool) (audit.Backend, error) { | ||||
| func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 	if conf.SaltConfig == nil { | ||||
| 		return nil, fmt.Errorf("nil salt config") | ||||
| 	} | ||||
| @@ -108,7 +108,7 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool | ||||
| 	} | ||||
|  | ||||
| 	// Configure the formatter for either case. | ||||
| 	f, err := audit.NewEntryFormatter(b.formatConfig, b) | ||||
| 	f, err := audit.NewEntryFormatter(b.formatConfig, b, headersConfig) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error creating formatter: %w", err) | ||||
| 	} | ||||
| @@ -177,14 +177,6 @@ type Backend struct { | ||||
|  | ||||
| var _ audit.Backend = (*Backend)(nil) | ||||
|  | ||||
| func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { | ||||
| 	salt, err := b.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return audit.HashString(salt, data), nil | ||||
| } | ||||
|  | ||||
| func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { | ||||
| 	var buf bytes.Buffer | ||||
| 	if err := b.formatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { | ||||
|   | ||||
| @@ -18,7 +18,7 @@ import ( | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
|  | ||||
| func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool) (audit.Backend, error) { | ||||
| func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 	if conf.SaltConfig == nil { | ||||
| 		return nil, fmt.Errorf("nil salt config") | ||||
| 	} | ||||
| @@ -102,7 +102,7 @@ func Factory(ctx context.Context, conf *audit.BackendConfig, useEventLogger bool | ||||
| 	} | ||||
|  | ||||
| 	// Configure the formatter for either case. | ||||
| 	f, err := audit.NewEntryFormatter(b.formatConfig, b, audit.WithPrefix(conf.Config["prefix"])) | ||||
| 	f, err := audit.NewEntryFormatter(b.formatConfig, b, headersConfig, audit.WithPrefix(conf.Config["prefix"])) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error creating formatter: %w", err) | ||||
| 	} | ||||
| @@ -166,14 +166,6 @@ type Backend struct { | ||||
|  | ||||
| var _ audit.Backend = (*Backend)(nil) | ||||
|  | ||||
| func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { | ||||
| 	salt, err := b.Salt(ctx) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return audit.HashString(salt, data), nil | ||||
| } | ||||
|  | ||||
| func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { | ||||
| 	var buf bytes.Buffer | ||||
| 	if err := b.formatter.FormatAndWriteRequest(ctx, &buf, in); err != nil { | ||||
|   | ||||
| @@ -252,7 +252,7 @@ func NewNoopAudit(config map[string]string) (*NoopAudit, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	f, err := audit.NewEntryFormatter(cfg, n) | ||||
| 	f, err := audit.NewEntryFormatter(cfg, n, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error creating formatter: %w", err) | ||||
| 	} | ||||
| @@ -268,7 +268,7 @@ func NewNoopAudit(config map[string]string) (*NoopAudit, error) { | ||||
| } | ||||
|  | ||||
| func NoopAuditFactory(records **[][]byte) audit.Factory { | ||||
| 	return func(_ context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 	return func(_ context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 		n, err := NewNoopAudit(config.Config) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
|   | ||||
| @@ -482,7 +482,7 @@ func TestLogical_Audit_invalidWrappingToken(t *testing.T) { | ||||
| 	noop := corehelpers.TestNoopAudit(t, nil) | ||||
| 	c, _, root := vault.TestCoreUnsealedWithConfig(t, &vault.CoreConfig{ | ||||
| 		AuditBackends: map[string]audit.Factory{ | ||||
| 			"noop": func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 			"noop": func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 				return noop, nil | ||||
| 			}, | ||||
| 		}, | ||||
|   | ||||
| @@ -486,7 +486,7 @@ func (c *Core) newAuditBackend(ctx context.Context, entry *MountEntry, view logi | ||||
| 		SaltView:   view, | ||||
| 		SaltConfig: saltConfig, | ||||
| 		Config:     conf, | ||||
| 	}, c.IsExperimentEnabled(experiments.VaultExperimentCoreAuditEventsAlpha1)) | ||||
| 	}, c.IsExperimentEnabled(experiments.VaultExperimentCoreAuditEventsAlpha1), c.auditedHeaders) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|   | ||||
| @@ -129,7 +129,7 @@ func (a *AuditBroker) GetHash(ctx context.Context, name string, input string) (s | ||||
| 		return "", fmt.Errorf("unknown audit backend %q", name) | ||||
| 	} | ||||
|  | ||||
| 	return be.backend.GetHash(ctx, input) | ||||
| 	return audit.HashString(ctx, be.backend, input) | ||||
| } | ||||
|  | ||||
| // LogRequest is used to ensure all the audit backends have an opportunity to | ||||
| @@ -182,7 +182,7 @@ func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput, head | ||||
| 	anyLogged := false | ||||
| 	for name, be := range a.backends { | ||||
| 		in.Request.Headers = nil | ||||
| 		transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend.GetHash) | ||||
| 		transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend) | ||||
| 		if thErr != nil { | ||||
| 			a.logger.Error("backend failed to include headers", "backend", name, "error", thErr) | ||||
| 			continue | ||||
| @@ -247,7 +247,7 @@ func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput, hea | ||||
| 	anyLogged := false | ||||
| 	for name, be := range a.backends { | ||||
| 		in.Request.Headers = nil | ||||
| 		transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend.GetHash) | ||||
| 		transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend) | ||||
| 		if thErr != nil { | ||||
| 			a.logger.Error("backend failed to include headers", "backend", name, "error", thErr) | ||||
| 			continue | ||||
|   | ||||
| @@ -27,7 +27,7 @@ import ( | ||||
|  | ||||
| func TestAudit_ReadOnlyViewDuringMount(t *testing.T) { | ||||
| 	c, _, _ := TestCoreUnsealed(t) | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 		err := config.SaltView.Put(ctx, &logical.StorageEntry{ | ||||
| 			Key:   "bar", | ||||
| 			Value: []byte("baz"), | ||||
| @@ -36,7 +36,7 @@ func TestAudit_ReadOnlyViewDuringMount(t *testing.T) { | ||||
| 			t.Fatalf("expected a read-only error") | ||||
| 		} | ||||
| 		factory := corehelpers.NoopAuditFactory(nil) | ||||
| 		return factory(ctx, config, false) | ||||
| 		return factory(ctx, config, false, nil) | ||||
| 	} | ||||
|  | ||||
| 	me := &MountEntry{ | ||||
| @@ -103,7 +103,7 @@ func TestCore_EnableAudit(t *testing.T) { | ||||
| func TestCore_EnableAudit_MixedFailures(t *testing.T) { | ||||
| 	c, _, _ := TestCoreUnsealed(t) | ||||
| 	c.auditBackends["noop"] = corehelpers.NoopAuditFactory(nil) | ||||
| 	c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 	c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 		return nil, fmt.Errorf("failing enabling") | ||||
| 	} | ||||
|  | ||||
| @@ -152,7 +152,7 @@ func TestCore_EnableAudit_MixedFailures(t *testing.T) { | ||||
| func TestCore_EnableAudit_Local(t *testing.T) { | ||||
| 	c, _, _ := TestCoreUnsealed(t) | ||||
| 	c.auditBackends["noop"] = corehelpers.NoopAuditFactory(nil) | ||||
| 	c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 	c.auditBackends["fail"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 		return nil, fmt.Errorf("failing enabling") | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -9,6 +9,7 @@ import ( | ||||
| 	"strings" | ||||
| 	"sync" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/audit" | ||||
| 	"github.com/hashicorp/vault/sdk/logical" | ||||
| ) | ||||
|  | ||||
| @@ -92,7 +93,7 @@ func (a *AuditedHeadersConfig) remove(ctx context.Context, header string) error | ||||
|  | ||||
| // ApplyConfig returns a map of approved headers and their values, either | ||||
| // hmac'ed or plaintext | ||||
| func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, hashFunc func(context.Context, string) (string, error)) (result map[string][]string, retErr error) { | ||||
| func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, salter audit.Salter) (result map[string][]string, retErr error) { | ||||
| 	// Grab a read lock | ||||
| 	a.RLock() | ||||
| 	defer a.RUnlock() | ||||
| @@ -114,7 +115,7 @@ func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[stri | ||||
| 			// Optionally hmac the values | ||||
| 			if settings.HMAC { | ||||
| 				for i, el := range hVals { | ||||
| 					hVal, err := hashFunc(ctx, el) | ||||
| 					hVal, err := audit.HashString(ctx, salter, el) | ||||
| 					if err != nil { | ||||
| 						return nil, err | ||||
| 					} | ||||
|   | ||||
| @@ -5,7 +5,9 @@ package vault | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/hashicorp/vault/sdk/helper/salt" | ||||
| @@ -169,6 +171,12 @@ func testAuditedHeadersConfig_Remove(t *testing.T, conf *AuditedHeadersConfig) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type TestSalter struct{} | ||||
|  | ||||
| func (*TestSalter) Salt(ctx context.Context) (*salt.Salt, error) { | ||||
| 	return salt.NewSalt(ctx, nil, nil) | ||||
| } | ||||
|  | ||||
| func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { | ||||
| 	conf := mockAuditedHeadersConfig(t) | ||||
|  | ||||
| @@ -181,20 +189,40 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { | ||||
| 		"Content-Type":   {"json"}, | ||||
| 	} | ||||
|  | ||||
| 	hashFunc := func(ctx context.Context, s string) (string, error) { return "hashed", nil } | ||||
| 	salter := &TestSalter{} | ||||
|  | ||||
| 	result, err := conf.ApplyConfig(context.Background(), reqHeaders, hashFunc) | ||||
| 	result, err := conf.ApplyConfig(context.Background(), reqHeaders, salter) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	expected := map[string][]string{ | ||||
| 		"x-test-header":  {"foo"}, | ||||
| 		"x-vault-header": {"hashed", "hashed"}, | ||||
| 		"x-vault-header": {"hmac-sha256:", "hmac-sha256:"}, | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(result, expected) { | ||||
| 		t.Fatalf("Expected headers did not match actual: Expected %#v\n Got %#v\n", expected, result) | ||||
| 	if len(expected) != len(result) { | ||||
| 		t.Fatalf("Expected headers count did not match actual count: Expected count %d\n Got %d\n", len(expected), len(result)) | ||||
| 	} | ||||
|  | ||||
| 	for resultKey, resultValues := range result { | ||||
| 		expectedValues := expected[resultKey] | ||||
|  | ||||
| 		if len(expectedValues) != len(resultValues) { | ||||
| 			t.Fatalf("Expected header values count did not match actual values count: Expected count: %d\n Got %d\n", len(expectedValues), len(resultValues)) | ||||
| 		} | ||||
|  | ||||
| 		for i, e := range expectedValues { | ||||
| 			if e == "hmac-sha256:" { | ||||
| 				if !strings.HasPrefix(resultValues[i], e) { | ||||
| 					t.Fatalf("Expected headers did not match actual: Expected %#v...\n Got %#v\n", e, resultValues[i]) | ||||
| 				} | ||||
| 			} else { | ||||
| 				if e != resultValues[i] { | ||||
| 					t.Fatalf("Expected headers did not match actual: Expected %#v\n Got %#v\n", e, resultValues[i]) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Make sure we didn't edit the reqHeaders map | ||||
| @@ -209,6 +237,91 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestAuditedHeadersConfig_ApplyConfig_NoHeaders tests the case where there are | ||||
| // no headers in the request. | ||||
| func TestAuditedHeadersConfig_ApplyConfig_NoRequestHeaders(t *testing.T) { | ||||
| 	conf := mockAuditedHeadersConfig(t) | ||||
|  | ||||
| 	conf.add(context.Background(), "X-TesT-Header", false) | ||||
| 	conf.add(context.Background(), "X-Vault-HeAdEr", true) | ||||
|  | ||||
| 	reqHeaders := map[string][]string{} | ||||
|  | ||||
| 	salter := &TestSalter{} | ||||
|  | ||||
| 	result, err := conf.ApplyConfig(context.Background(), reqHeaders, salter) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if len(result) != 0 { | ||||
| 		t.Fatalf("Expected no headers but actually got: %d\n", len(result)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAuditedHeadersConfig_ApplyConfig_NoConfiguredHeaders(t *testing.T) { | ||||
| 	conf := mockAuditedHeadersConfig(t) | ||||
|  | ||||
| 	reqHeaders := map[string][]string{ | ||||
| 		"X-Test-Header":  {"foo"}, | ||||
| 		"X-Vault-Header": {"bar", "bar"}, | ||||
| 		"Content-Type":   {"json"}, | ||||
| 	} | ||||
|  | ||||
| 	salter := &TestSalter{} | ||||
|  | ||||
| 	result, err := conf.ApplyConfig(context.Background(), reqHeaders, salter) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if len(result) != 0 { | ||||
| 		t.Fatalf("Expected no headers but actually got: %d\n", len(result)) | ||||
| 	} | ||||
|  | ||||
| 	// Make sure we didn't edit the reqHeaders map | ||||
| 	reqHeadersCopy := map[string][]string{ | ||||
| 		"X-Test-Header":  {"foo"}, | ||||
| 		"X-Vault-Header": {"bar", "bar"}, | ||||
| 		"Content-Type":   {"json"}, | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(reqHeaders, reqHeadersCopy) { | ||||
| 		t.Fatalf("Req headers were changed, expected %#v\n got %#v", reqHeadersCopy, reqHeaders) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // FailingSalter is an implementation of the Salter interface where the Salt | ||||
| // method always returns an error. | ||||
| type FailingSalter struct{} | ||||
|  | ||||
| // Salt always returns an error. | ||||
| func (s *FailingSalter) Salt(context.Context) (*salt.Salt, error) { | ||||
| 	return nil, errors.New("testing error") | ||||
| } | ||||
|  | ||||
| // TestAuditedHeadersConfig_ApplyConfig_HashStringError tests the case where | ||||
| // an error is returned from HashString instead of a map of headers. | ||||
| func TestAuditedHeadersConfig_ApplyConfig_HashStringError(t *testing.T) { | ||||
| 	conf := mockAuditedHeadersConfig(t) | ||||
|  | ||||
| 	conf.add(context.Background(), "X-TesT-Header", false) | ||||
| 	conf.add(context.Background(), "X-Vault-HeAdEr", true) | ||||
|  | ||||
| 	reqHeaders := map[string][]string{ | ||||
| 		"X-Test-Header":  {"foo"}, | ||||
| 		"X-Vault-Header": {"bar", "bar"}, | ||||
| 		"Content-Type":   {"json"}, | ||||
| 	} | ||||
|  | ||||
| 	salter := &FailingSalter{} | ||||
|  | ||||
| 	_, err := conf.ApplyConfig(context.Background(), reqHeaders, salter) | ||||
| 	if err == nil { | ||||
| 		t.Fatal("expected error from ApplyConfig") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) { | ||||
| 	conf := &AuditedHeadersConfig{ | ||||
| 		Headers: make(map[string]*auditedHeaderSettings), | ||||
| @@ -226,16 +339,11 @@ func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) { | ||||
| 		"Content-Type":   {"json"}, | ||||
| 	} | ||||
|  | ||||
| 	salter, err := salt.NewSalt(context.Background(), nil, nil) | ||||
| 	if err != nil { | ||||
| 		b.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	hashFunc := func(ctx context.Context, s string) (string, error) { return salter.GetIdentifiedHMAC(s), nil } | ||||
| 	salter := &TestSalter{} | ||||
|  | ||||
| 	// Reset the timer since we did a lot above | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		conf.ApplyConfig(context.Background(), reqHeaders, hashFunc) | ||||
| 		conf.ApplyConfig(context.Background(), reqHeaders, salter) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -1137,7 +1137,7 @@ func TestCore_HandleRequest_AuditTrail(t *testing.T) { | ||||
| 	// Create a noop audit backend | ||||
| 	noop := &corehelpers.NoopAudit{} | ||||
| 	c, _, root := TestCoreUnsealed(t) | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 		noop = &corehelpers.NoopAudit{ | ||||
| 			Config: config, | ||||
| 		} | ||||
| @@ -1201,7 +1201,7 @@ func TestCore_HandleRequest_AuditTrail_noHMACKeys(t *testing.T) { | ||||
| 	// Create a noop audit backend | ||||
| 	var noop *corehelpers.NoopAudit | ||||
| 	c, _, root := TestCoreUnsealed(t) | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 		noop = &corehelpers.NoopAudit{ | ||||
| 			Config: config, | ||||
| 		} | ||||
| @@ -1323,7 +1323,7 @@ func TestCore_HandleLogin_AuditTrail(t *testing.T) { | ||||
| 	c.credentialBackends["noop"] = func(context.Context, *logical.BackendConfig) (logical.Backend, error) { | ||||
| 		return noopBack, nil | ||||
| 	} | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 		noop = &corehelpers.NoopAudit{ | ||||
| 			Config: config, | ||||
| 		} | ||||
|   | ||||
| @@ -61,7 +61,7 @@ func TestLoginMfaGenerateTOTPTestAuditIncluded(t *testing.T) { | ||||
| 			"totp": totp.Factory, | ||||
| 		}, | ||||
| 		AuditBackends: map[string]audit.Factory{ | ||||
| 			"noop": func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 			"noop": func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 				return noop, nil | ||||
| 			}, | ||||
| 		}, | ||||
|   | ||||
| @@ -724,7 +724,7 @@ func TestDefaultMountTable(t *testing.T) { | ||||
| func TestCore_MountTable_UpgradeToTyped(t *testing.T) { | ||||
| 	c, _, _ := TestCoreUnsealed(t) | ||||
|  | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool) (audit.Backend, error) { | ||||
| 	c.auditBackends["noop"] = func(ctx context.Context, config *audit.BackendConfig, _ bool, _ audit.HeaderFormatter) (audit.Backend, error) { | ||||
| 		return &corehelpers.NoopAudit{ | ||||
| 			Config: config, | ||||
| 		}, nil | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Marc Boudreau
					Marc Boudreau