mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	Delay salt initialization for audit backends
This commit is contained in:
		| @@ -25,15 +25,21 @@ type Backend interface { | |||||||
| 	// GetHash is used to return the given data with the backend's hash, | 	// GetHash is used to return the given data with the backend's hash, | ||||||
| 	// so that a caller can determine if a value in the audit log matches | 	// so that a caller can determine if a value in the audit log matches | ||||||
| 	// an expected plaintext value | 	// an expected plaintext value | ||||||
| 	GetHash(string) string | 	GetHash(string) (string, error) | ||||||
|  |  | ||||||
| 	// Reload is called on SIGHUP for supporting backends. | 	// Reload is called on SIGHUP for supporting backends. | ||||||
| 	Reload() error | 	Reload() error | ||||||
|  |  | ||||||
|  | 	// Invalidate is called for path invalidation | ||||||
|  | 	Invalidate() | ||||||
| } | } | ||||||
|  |  | ||||||
| type BackendConfig struct { | type BackendConfig struct { | ||||||
| 	// The salt that should be used for any secret obfuscation | 	// The view to store the salt | ||||||
| 	Salt *salt.Salt | 	SaltView logical.Storage | ||||||
|  |  | ||||||
|  | 	// The salt config that should be used for any secret obfuscation | ||||||
|  | 	SaltConfig *salt.Config | ||||||
|  |  | ||||||
| 	// Config is the opaque user configuration provided when mounting | 	// Config is the opaque user configuration provided when mounting | ||||||
| 	Config map[string]string | 	Config map[string]string | ||||||
|   | |||||||
| @@ -7,6 +7,8 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/SermoDigital/jose/jws" | 	"github.com/SermoDigital/jose/jws" | ||||||
|  | 	"github.com/hashicorp/errwrap" | ||||||
|  | 	"github.com/hashicorp/vault/helper/salt" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	"github.com/mitchellh/copystructure" | 	"github.com/mitchellh/copystructure" | ||||||
| ) | ) | ||||||
| @@ -14,6 +16,7 @@ import ( | |||||||
| type AuditFormatWriter interface { | type AuditFormatWriter interface { | ||||||
| 	WriteRequest(io.Writer, *AuditRequestEntry) error | 	WriteRequest(io.Writer, *AuditRequestEntry) error | ||||||
| 	WriteResponse(io.Writer, *AuditResponseEntry) error | 	WriteResponse(io.Writer, *AuditResponseEntry) error | ||||||
|  | 	Salt() (*salt.Salt, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| // AuditFormatter implements the Formatter interface, and allows the underlying | // AuditFormatter implements the Formatter interface, and allows the underlying | ||||||
| @@ -41,6 +44,11 @@ func (f *AuditFormatter) FormatRequest( | |||||||
| 		return fmt.Errorf("no format writer specified") | 		return fmt.Errorf("no format writer specified") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	salt, err := f.Salt() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errwrap.Wrapf("error fetching salt: {{err}}", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if !config.Raw { | 	if !config.Raw { | ||||||
| 		// Before we copy the structure we must nil out some data | 		// Before we copy the structure we must nil out some data | ||||||
| 		// otherwise we will cause reflection to panic and die | 		// otherwise we will cause reflection to panic and die | ||||||
| @@ -70,7 +78,7 @@ func (f *AuditFormatter) FormatRequest( | |||||||
|  |  | ||||||
| 		// Hash any sensitive information | 		// Hash any sensitive information | ||||||
| 		if auth != nil { | 		if auth != nil { | ||||||
| 			if err := Hash(config.Salt, auth); err != nil { | 			if err := Hash(salt, auth); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @@ -80,7 +88,7 @@ func (f *AuditFormatter) FormatRequest( | |||||||
| 		if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { | 		if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { | ||||||
| 			clientTokenAccessor = req.ClientTokenAccessor | 			clientTokenAccessor = req.ClientTokenAccessor | ||||||
| 		} | 		} | ||||||
| 		if err := Hash(config.Salt, req); err != nil { | 		if err := Hash(salt, req); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		if clientTokenAccessor != "" { | 		if clientTokenAccessor != "" { | ||||||
| @@ -152,6 +160,11 @@ func (f *AuditFormatter) FormatResponse( | |||||||
| 		return fmt.Errorf("no format writer specified") | 		return fmt.Errorf("no format writer specified") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	salt, err := f.Salt() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errwrap.Wrapf("error fetching salt: {{err}}", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if !config.Raw { | 	if !config.Raw { | ||||||
| 		// Before we copy the structure we must nil out some data | 		// Before we copy the structure we must nil out some data | ||||||
| 		// otherwise we will cause reflection to panic and die | 		// otherwise we will cause reflection to panic and die | ||||||
| @@ -195,7 +208,7 @@ func (f *AuditFormatter) FormatResponse( | |||||||
| 			if !config.HMACAccessor && auth.Accessor != "" { | 			if !config.HMACAccessor && auth.Accessor != "" { | ||||||
| 				accessor = auth.Accessor | 				accessor = auth.Accessor | ||||||
| 			} | 			} | ||||||
| 			if err := Hash(config.Salt, auth); err != nil { | 			if err := Hash(salt, auth); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			if accessor != "" { | 			if accessor != "" { | ||||||
| @@ -208,7 +221,7 @@ func (f *AuditFormatter) FormatResponse( | |||||||
| 		if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { | 		if !config.HMACAccessor && req != nil && req.ClientTokenAccessor != "" { | ||||||
| 			clientTokenAccessor = req.ClientTokenAccessor | 			clientTokenAccessor = req.ClientTokenAccessor | ||||||
| 		} | 		} | ||||||
| 		if err := Hash(config.Salt, req); err != nil { | 		if err := Hash(salt, req); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		if clientTokenAccessor != "" { | 		if clientTokenAccessor != "" { | ||||||
| @@ -224,7 +237,7 @@ func (f *AuditFormatter) FormatResponse( | |||||||
| 			if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" { | 			if !config.HMACAccessor && resp != nil && resp.WrapInfo != nil && resp.WrapInfo.WrappedAccessor != "" { | ||||||
| 				wrappedAccessor = resp.WrapInfo.WrappedAccessor | 				wrappedAccessor = resp.WrapInfo.WrappedAccessor | ||||||
| 			} | 			} | ||||||
| 			if err := Hash(config.Salt, resp); err != nil { | 			if err := Hash(salt, resp); err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 			if accessor != "" { | 			if accessor != "" { | ||||||
|   | |||||||
| @@ -4,12 +4,15 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/helper/salt" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // JSONFormatWriter is an AuditFormatWriter implementation that structures data into | // JSONFormatWriter is an AuditFormatWriter implementation that structures data into | ||||||
| // a JSON format. | // a JSON format. | ||||||
| type JSONFormatWriter struct { | type JSONFormatWriter struct { | ||||||
| 	Prefix string | 	Prefix   string | ||||||
|  | 	SaltFunc func() (*salt.Salt, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (f *JSONFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { | func (f *JSONFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { | ||||||
| @@ -43,3 +46,7 @@ func (f *JSONFormatWriter) WriteResponse(w io.Writer, resp *AuditResponseEntry) | |||||||
| 	enc := json.NewEncoder(w) | 	enc := json.NewEncoder(w) | ||||||
| 	return enc.Encode(resp) | 	return enc.Encode(resp) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (f *JSONFormatWriter) Salt() (*salt.Salt, error) { | ||||||
|  | 	return f.SaltFunc() | ||||||
|  | } | ||||||
|   | |||||||
| @@ -15,6 +15,13 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestFormatJSON_formatRequest(t *testing.T) { | func TestFormatJSON_formatRequest(t *testing.T) { | ||||||
|  | 	salter, err := salt.NewSalt(nil, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	saltFunc := func() (*salt.Salt, error) { | ||||||
|  | 		return salter, nil | ||||||
|  | 	} | ||||||
| 	cases := map[string]struct { | 	cases := map[string]struct { | ||||||
| 		Auth   *logical.Auth | 		Auth   *logical.Auth | ||||||
| 		Req    *logical.Request | 		Req    *logical.Request | ||||||
| @@ -66,13 +73,11 @@ func TestFormatJSON_formatRequest(t *testing.T) { | |||||||
| 		var buf bytes.Buffer | 		var buf bytes.Buffer | ||||||
| 		formatter := AuditFormatter{ | 		formatter := AuditFormatter{ | ||||||
| 			AuditFormatWriter: &JSONFormatWriter{ | 			AuditFormatWriter: &JSONFormatWriter{ | ||||||
| 				Prefix: tc.Prefix, | 				Prefix:   tc.Prefix, | ||||||
|  | 				SaltFunc: saltFunc, | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 		salter, _ := salt.NewSalt(nil, nil) | 		config := FormatterConfig{} | ||||||
| 		config := FormatterConfig{ |  | ||||||
| 			Salt: salter, |  | ||||||
| 		} |  | ||||||
| 		if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil { | 		if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil { | ||||||
| 			t.Fatalf("bad: %s\nerr: %s", name, err) | 			t.Fatalf("bad: %s\nerr: %s", name, err) | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -5,13 +5,15 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
|  |  | ||||||
|  | 	"github.com/hashicorp/vault/helper/salt" | ||||||
| 	"github.com/jefferai/jsonx" | 	"github.com/jefferai/jsonx" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // JSONxFormatWriter is an AuditFormatWriter implementation that structures data into | // JSONxFormatWriter is an AuditFormatWriter implementation that structures data into | ||||||
| // a XML format. | // a XML format. | ||||||
| type JSONxFormatWriter struct { | type JSONxFormatWriter struct { | ||||||
| 	Prefix string | 	Prefix   string | ||||||
|  | 	SaltFunc func() (*salt.Salt, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (f *JSONxFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { | func (f *JSONxFormatWriter) WriteRequest(w io.Writer, req *AuditRequestEntry) error { | ||||||
| @@ -65,3 +67,7 @@ func (f *JSONxFormatWriter) WriteResponse(w io.Writer, resp *AuditResponseEntry) | |||||||
| 	_, err = w.Write(xmlBytes) | 	_, err = w.Write(xmlBytes) | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (f *JSONxFormatWriter) Salt() (*salt.Salt, error) { | ||||||
|  | 	return f.SaltFunc() | ||||||
|  | } | ||||||
|   | |||||||
| @@ -13,6 +13,13 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestFormatJSONx_formatRequest(t *testing.T) { | func TestFormatJSONx_formatRequest(t *testing.T) { | ||||||
|  | 	salter, err := salt.NewSalt(nil, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	saltFunc := func() (*salt.Salt, error) { | ||||||
|  | 		return salter, nil | ||||||
|  | 	} | ||||||
| 	cases := map[string]struct { | 	cases := map[string]struct { | ||||||
| 		Auth     *logical.Auth | 		Auth     *logical.Auth | ||||||
| 		Req      *logical.Request | 		Req      *logical.Request | ||||||
| @@ -67,12 +74,11 @@ func TestFormatJSONx_formatRequest(t *testing.T) { | |||||||
| 		var buf bytes.Buffer | 		var buf bytes.Buffer | ||||||
| 		formatter := AuditFormatter{ | 		formatter := AuditFormatter{ | ||||||
| 			AuditFormatWriter: &JSONxFormatWriter{ | 			AuditFormatWriter: &JSONxFormatWriter{ | ||||||
| 				Prefix: tc.Prefix, | 				Prefix:   tc.Prefix, | ||||||
|  | 				SaltFunc: saltFunc, | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 		salter, _ := salt.NewSalt(nil, nil) |  | ||||||
| 		config := FormatterConfig{ | 		config := FormatterConfig{ | ||||||
| 			Salt:     salter, |  | ||||||
| 			OmitTime: true, | 			OmitTime: true, | ||||||
| 		} | 		} | ||||||
| 		if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil { | 		if err := formatter.FormatRequest(&buf, config, tc.Auth, tc.Req, tc.Err); err != nil { | ||||||
|   | |||||||
| @@ -10,6 +10,8 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| type noopFormatWriter struct { | type noopFormatWriter struct { | ||||||
|  | 	salt     *salt.Salt | ||||||
|  | 	SaltFunc func() (*salt.Salt, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error { | func (n *noopFormatWriter) WriteRequest(_ io.Writer, _ *AuditRequestEntry) error { | ||||||
| @@ -20,11 +22,20 @@ func (n *noopFormatWriter) WriteResponse(_ io.Writer, _ *AuditResponseEntry) err | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestFormatRequestErrors(t *testing.T) { | func (n *noopFormatWriter) Salt() (*salt.Salt, error) { | ||||||
| 	salter, _ := salt.NewSalt(nil, nil) | 	if n.salt != nil { | ||||||
| 	config := FormatterConfig{ | 		return n.salt, nil | ||||||
| 		Salt: salter, |  | ||||||
| 	} | 	} | ||||||
|  | 	var err error | ||||||
|  | 	n.salt, err = salt.NewSalt(nil, nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return n.salt, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestFormatRequestErrors(t *testing.T) { | ||||||
|  | 	config := FormatterConfig{} | ||||||
| 	formatter := AuditFormatter{ | 	formatter := AuditFormatter{ | ||||||
| 		AuditFormatWriter: &noopFormatWriter{}, | 		AuditFormatWriter: &noopFormatWriter{}, | ||||||
| 	} | 	} | ||||||
| @@ -38,10 +49,7 @@ func TestFormatRequestErrors(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestFormatResponseErrors(t *testing.T) { | func TestFormatResponseErrors(t *testing.T) { | ||||||
| 	salter, _ := salt.NewSalt(nil, nil) | 	config := FormatterConfig{} | ||||||
| 	config := FormatterConfig{ |  | ||||||
| 		Salt: salter, |  | ||||||
| 	} |  | ||||||
| 	formatter := AuditFormatter{ | 	formatter := AuditFormatter{ | ||||||
| 		AuditFormatWriter: &noopFormatWriter{}, | 		AuditFormatWriter: &noopFormatWriter{}, | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,7 +3,6 @@ package audit | |||||||
| import ( | import ( | ||||||
| 	"io" | 	"io" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/helper/salt" |  | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -19,7 +18,6 @@ type Formatter interface { | |||||||
|  |  | ||||||
| type FormatterConfig struct { | type FormatterConfig struct { | ||||||
| 	Raw          bool | 	Raw          bool | ||||||
| 	Salt         *salt.Salt |  | ||||||
| 	HMACAccessor bool | 	HMACAccessor bool | ||||||
|  |  | ||||||
| 	// This should only ever be used in a testing context | 	// This should only ever be used in a testing context | ||||||
|   | |||||||
| @@ -8,12 +8,16 @@ import ( | |||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/audit" | 	"github.com/hashicorp/vault/audit" | ||||||
|  | 	"github.com/hashicorp/vault/helper/salt" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | ||||||
| 	if conf.Salt == nil { | 	if conf.SaltConfig == nil { | ||||||
| 		return nil, fmt.Errorf("nil salt") | 		return nil, fmt.Errorf("nil salt config") | ||||||
|  | 	} | ||||||
|  | 	if conf.SaltView == nil { | ||||||
|  | 		return nil, fmt.Errorf("nil salt view") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	path, ok := conf.Config["file_path"] | 	path, ok := conf.Config["file_path"] | ||||||
| @@ -65,11 +69,12 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	b := &Backend{ | 	b := &Backend{ | ||||||
| 		path: path, | 		path:       path, | ||||||
| 		mode: mode, | 		mode:       mode, | ||||||
|  | 		saltConfig: conf.SaltConfig, | ||||||
|  | 		saltView:   conf.SaltView, | ||||||
| 		formatConfig: audit.FormatterConfig{ | 		formatConfig: audit.FormatterConfig{ | ||||||
| 			Raw:          logRaw, | 			Raw:          logRaw, | ||||||
| 			Salt:         conf.Salt, |  | ||||||
| 			HMACAccessor: hmacAccessor, | 			HMACAccessor: hmacAccessor, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| @@ -77,11 +82,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | |||||||
| 	switch format { | 	switch format { | ||||||
| 	case "json": | 	case "json": | ||||||
| 		b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ | 		b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ | ||||||
| 			Prefix: conf.Config["prefix"], | 			Prefix:   conf.Config["prefix"], | ||||||
|  | 			SaltFunc: b.Salt, | ||||||
| 		} | 		} | ||||||
| 	case "jsonx": | 	case "jsonx": | ||||||
| 		b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ | 		b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ | ||||||
| 			Prefix: conf.Config["prefix"], | 			Prefix:   conf.Config["prefix"], | ||||||
|  | 			SaltFunc: b.Salt, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -109,10 +116,39 @@ type Backend struct { | |||||||
| 	fileLock sync.RWMutex | 	fileLock sync.RWMutex | ||||||
| 	f        *os.File | 	f        *os.File | ||||||
| 	mode     os.FileMode | 	mode     os.FileMode | ||||||
|  |  | ||||||
|  | 	saltMutex  sync.RWMutex | ||||||
|  | 	salt       *salt.Salt | ||||||
|  | 	saltConfig *salt.Config | ||||||
|  | 	saltView   logical.Storage | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *Backend) GetHash(data string) string { | func (b *Backend) Salt() (*salt.Salt, error) { | ||||||
| 	return audit.HashString(b.formatConfig.Salt, data) | 	b.saltMutex.RLock() | ||||||
|  | 	if b.salt != nil { | ||||||
|  | 		defer b.saltMutex.RUnlock() | ||||||
|  | 		return b.salt, nil | ||||||
|  | 	} | ||||||
|  | 	b.saltMutex.RUnlock() | ||||||
|  | 	b.saltMutex.Lock() | ||||||
|  | 	defer b.saltMutex.Unlock() | ||||||
|  | 	if b.salt != nil { | ||||||
|  | 		return b.salt, nil | ||||||
|  | 	} | ||||||
|  | 	salt, err := salt.NewSalt(b.saltView, b.saltConfig) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	b.salt = salt | ||||||
|  | 	return salt, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *Backend) GetHash(data string) (string, error) { | ||||||
|  | 	salt, err := b.Salt() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	return audit.HashString(salt, data), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { | func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { | ||||||
| @@ -189,3 +225,9 @@ func (b *Backend) Reload() error { | |||||||
|  |  | ||||||
| 	return b.open() | 	return b.open() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (b *Backend) Invalidate() { | ||||||
|  | 	b.saltMutex.Lock() | ||||||
|  | 	defer b.saltMutex.Unlock() | ||||||
|  | 	b.salt = nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -9,11 +9,10 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/hashicorp/vault/audit" | 	"github.com/hashicorp/vault/audit" | ||||||
| 	"github.com/hashicorp/vault/helper/salt" | 	"github.com/hashicorp/vault/helper/salt" | ||||||
|  | 	"github.com/hashicorp/vault/logical" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestAuditFile_fileModeNew(t *testing.T) { | func TestAuditFile_fileModeNew(t *testing.T) { | ||||||
| 	salter, _ := salt.NewSalt(nil, nil) |  | ||||||
|  |  | ||||||
| 	modeStr := "0777" | 	modeStr := "0777" | ||||||
| 	mode, err := strconv.ParseUint(modeStr, 8, 32) | 	mode, err := strconv.ParseUint(modeStr, 8, 32) | ||||||
|  |  | ||||||
| @@ -28,8 +27,9 @@ func TestAuditFile_fileModeNew(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	_, err = Factory(&audit.BackendConfig{ | 	_, err = Factory(&audit.BackendConfig{ | ||||||
| 		Salt:   salter, | 		SaltConfig: &salt.Config{}, | ||||||
| 		Config: config, | 		SaltView:   &logical.InmemStorage{}, | ||||||
|  | 		Config:     config, | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| @@ -45,8 +45,6 @@ func TestAuditFile_fileModeNew(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestAuditFile_fileModeExisting(t *testing.T) { | func TestAuditFile_fileModeExisting(t *testing.T) { | ||||||
| 	salter, _ := salt.NewSalt(nil, nil) |  | ||||||
|  |  | ||||||
| 	f, err := ioutil.TempFile("", "test") | 	f, err := ioutil.TempFile("", "test") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failure to create test file.") | 		t.Fatalf("Failure to create test file.") | ||||||
| @@ -68,8 +66,9 @@ func TestAuditFile_fileModeExisting(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	_, err = Factory(&audit.BackendConfig{ | 	_, err = Factory(&audit.BackendConfig{ | ||||||
| 		Salt:   salter, | 		Config:     config, | ||||||
| 		Config: config, | 		SaltConfig: &salt.Config{}, | ||||||
|  | 		SaltView:   &logical.InmemStorage{}, | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
|   | |||||||
| @@ -11,12 +11,16 @@ import ( | |||||||
| 	multierror "github.com/hashicorp/go-multierror" | 	multierror "github.com/hashicorp/go-multierror" | ||||||
| 	"github.com/hashicorp/vault/audit" | 	"github.com/hashicorp/vault/audit" | ||||||
| 	"github.com/hashicorp/vault/helper/parseutil" | 	"github.com/hashicorp/vault/helper/parseutil" | ||||||
|  | 	"github.com/hashicorp/vault/helper/salt" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | ||||||
| 	if conf.Salt == nil { | 	if conf.SaltConfig == nil { | ||||||
| 		return nil, fmt.Errorf("nil salt passed in") | 		return nil, fmt.Errorf("nil salt config") | ||||||
|  | 	} | ||||||
|  | 	if conf.SaltView == nil { | ||||||
|  | 		return nil, fmt.Errorf("nil salt view") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	address, ok := conf.Config["address"] | 	address, ok := conf.Config["address"] | ||||||
| @@ -75,11 +79,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | |||||||
|  |  | ||||||
| 	b := &Backend{ | 	b := &Backend{ | ||||||
| 		connection: conn, | 		connection: conn, | ||||||
|  | 		saltConfig: conf.SaltConfig, | ||||||
|  | 		saltView:   conf.SaltView, | ||||||
| 		formatConfig: audit.FormatterConfig{ | 		formatConfig: audit.FormatterConfig{ | ||||||
| 			Raw:          logRaw, | 			Raw:          logRaw, | ||||||
| 			Salt:         conf.Salt, |  | ||||||
| 			HMACAccessor: hmacAccessor, | 			HMACAccessor: hmacAccessor, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
| 		writeDuration: writeDuration, | 		writeDuration: writeDuration, | ||||||
| 		address:       address, | 		address:       address, | ||||||
| 		socketType:    socketType, | 		socketType:    socketType, | ||||||
| @@ -88,11 +94,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | |||||||
| 	switch format { | 	switch format { | ||||||
| 	case "json": | 	case "json": | ||||||
| 		b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ | 		b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ | ||||||
| 			Prefix: conf.Config["prefix"], | 			Prefix:   conf.Config["prefix"], | ||||||
|  | 			SaltFunc: b.Salt, | ||||||
| 		} | 		} | ||||||
| 	case "jsonx": | 	case "jsonx": | ||||||
| 		b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ | 		b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ | ||||||
| 			Prefix: conf.Config["prefix"], | 			Prefix:   conf.Config["prefix"], | ||||||
|  | 			SaltFunc: b.Salt, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -111,10 +119,19 @@ type Backend struct { | |||||||
| 	socketType    string | 	socketType    string | ||||||
|  |  | ||||||
| 	sync.Mutex | 	sync.Mutex | ||||||
|  |  | ||||||
|  | 	saltMutex  sync.RWMutex | ||||||
|  | 	salt       *salt.Salt | ||||||
|  | 	saltConfig *salt.Config | ||||||
|  | 	saltView   logical.Storage | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *Backend) GetHash(data string) string { | func (b *Backend) GetHash(data string) (string, error) { | ||||||
| 	return audit.HashString(b.formatConfig.Salt, data) | 	salt, err := b.Salt() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	return audit.HashString(salt, data), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { | func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { | ||||||
| @@ -198,3 +215,29 @@ func (b *Backend) Reload() error { | |||||||
|  |  | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (b *Backend) Salt() (*salt.Salt, error) { | ||||||
|  | 	b.saltMutex.RLock() | ||||||
|  | 	if b.salt != nil { | ||||||
|  | 		defer b.saltMutex.RUnlock() | ||||||
|  | 		return b.salt, nil | ||||||
|  | 	} | ||||||
|  | 	b.saltMutex.RUnlock() | ||||||
|  | 	b.saltMutex.Lock() | ||||||
|  | 	defer b.saltMutex.Unlock() | ||||||
|  | 	if b.salt != nil { | ||||||
|  | 		return b.salt, nil | ||||||
|  | 	} | ||||||
|  | 	salt, err := salt.NewSalt(b.saltView, b.saltConfig) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	b.salt = salt | ||||||
|  | 	return salt, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *Backend) Invalidate() { | ||||||
|  | 	b.saltMutex.Lock() | ||||||
|  | 	defer b.saltMutex.Unlock() | ||||||
|  | 	b.salt = nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -4,15 +4,20 @@ import ( | |||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/hashicorp/go-syslog" | 	"github.com/hashicorp/go-syslog" | ||||||
| 	"github.com/hashicorp/vault/audit" | 	"github.com/hashicorp/vault/audit" | ||||||
|  | 	"github.com/hashicorp/vault/helper/salt" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | ||||||
| 	if conf.Salt == nil { | 	if conf.SaltConfig == nil { | ||||||
| 		return nil, fmt.Errorf("Nil salt passed in") | 		return nil, fmt.Errorf("nil salt config") | ||||||
|  | 	} | ||||||
|  | 	if conf.SaltView == nil { | ||||||
|  | 		return nil, fmt.Errorf("nil salt view") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Get facility or default to AUTH | 	// Get facility or default to AUTH | ||||||
| @@ -64,10 +69,11 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	b := &Backend{ | 	b := &Backend{ | ||||||
| 		logger: logger, | 		logger:     logger, | ||||||
|  | 		saltConfig: conf.SaltConfig, | ||||||
|  | 		saltView:   conf.SaltView, | ||||||
| 		formatConfig: audit.FormatterConfig{ | 		formatConfig: audit.FormatterConfig{ | ||||||
| 			Raw:          logRaw, | 			Raw:          logRaw, | ||||||
| 			Salt:         conf.Salt, |  | ||||||
| 			HMACAccessor: hmacAccessor, | 			HMACAccessor: hmacAccessor, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| @@ -75,11 +81,13 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) { | |||||||
| 	switch format { | 	switch format { | ||||||
| 	case "json": | 	case "json": | ||||||
| 		b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ | 		b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ | ||||||
| 			Prefix: conf.Config["prefix"], | 			Prefix:   conf.Config["prefix"], | ||||||
|  | 			SaltFunc: b.Salt, | ||||||
| 		} | 		} | ||||||
| 	case "jsonx": | 	case "jsonx": | ||||||
| 		b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ | 		b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ | ||||||
| 			Prefix: conf.Config["prefix"], | 			Prefix:   conf.Config["prefix"], | ||||||
|  | 			SaltFunc: b.Salt, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -92,10 +100,19 @@ type Backend struct { | |||||||
|  |  | ||||||
| 	formatter    audit.AuditFormatter | 	formatter    audit.AuditFormatter | ||||||
| 	formatConfig audit.FormatterConfig | 	formatConfig audit.FormatterConfig | ||||||
|  |  | ||||||
|  | 	saltMutex  sync.RWMutex | ||||||
|  | 	salt       *salt.Salt | ||||||
|  | 	saltConfig *salt.Config | ||||||
|  | 	saltView   logical.Storage | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *Backend) GetHash(data string) string { | func (b *Backend) GetHash(data string) (string, error) { | ||||||
| 	return audit.HashString(b.formatConfig.Salt, data) | 	salt, err := b.Salt() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	return audit.HashString(salt, data), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { | func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error { | ||||||
| @@ -123,3 +140,29 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *lo | |||||||
| func (b *Backend) Reload() error { | func (b *Backend) Reload() error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (b *Backend) Salt() (*salt.Salt, error) { | ||||||
|  | 	b.saltMutex.RLock() | ||||||
|  | 	if b.salt != nil { | ||||||
|  | 		defer b.saltMutex.RUnlock() | ||||||
|  | 		return b.salt, nil | ||||||
|  | 	} | ||||||
|  | 	b.saltMutex.RUnlock() | ||||||
|  | 	b.saltMutex.Lock() | ||||||
|  | 	defer b.saltMutex.Unlock() | ||||||
|  | 	if b.salt != nil { | ||||||
|  | 		return b.salt, nil | ||||||
|  | 	} | ||||||
|  | 	salt, err := salt.NewSalt(b.saltView, b.saltConfig) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	b.salt = salt | ||||||
|  | 	return salt, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (b *Backend) Invalidate() { | ||||||
|  | 	b.saltMutex.Lock() | ||||||
|  | 	defer b.saltMutex.Unlock() | ||||||
|  | 	b.salt = nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -368,17 +368,16 @@ func (c *Core) newAuditBackend(entry *MountEntry, view logical.Storage, conf map | |||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, fmt.Errorf("unknown backend type: %s", entry.Type) | 		return nil, fmt.Errorf("unknown backend type: %s", entry.Type) | ||||||
| 	} | 	} | ||||||
| 	salter, err := salt.NewSalt(view, &salt.Config{ | 	saltConfig := &salt.Config{ | ||||||
| 		HMAC:     sha256.New, | 		HMAC:     sha256.New, | ||||||
| 		HMACType: "hmac-sha256", | 		HMACType: "hmac-sha256", | ||||||
| 	}) | 		Location: salt.DefaultLocation, | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, fmt.Errorf("core: unable to generate salt: %v", err) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	be, err := f(&audit.BackendConfig{ | 	be, err := f(&audit.BackendConfig{ | ||||||
| 		Salt:   salter, | 		SaltView:   view, | ||||||
| 		Config: conf, | 		SaltConfig: saltConfig, | ||||||
|  | 		Config:     conf, | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -474,20 +473,25 @@ func (a *AuditBroker) GetHash(name string, input string) (string, error) { | |||||||
| 		return "", fmt.Errorf("unknown audit backend %s", name) | 		return "", fmt.Errorf("unknown audit backend %s", name) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return be.backend.GetHash(input), nil | 	return be.backend.GetHash(input) | ||||||
| } | } | ||||||
|  |  | ||||||
| // LogRequest is used to ensure all the audit backends have an opportunity to | // LogRequest is used to ensure all the audit backends have an opportunity to | ||||||
| // log the given request and that *at least one* succeeds. | // log the given request and that *at least one* succeeds. | ||||||
| func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, headersConfig *AuditedHeadersConfig, outerErr error) (retErr error) { | func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, headersConfig *AuditedHeadersConfig, outerErr error) (ret error) { | ||||||
| 	defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now()) | 	defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now()) | ||||||
| 	a.RLock() | 	a.RLock() | ||||||
| 	defer a.RUnlock() | 	defer a.RUnlock() | ||||||
|  |  | ||||||
|  | 	var retErr *multierror.Error | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		if r := recover(); r != nil { | 		if r := recover(); r != nil { | ||||||
| 			a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) | 			a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) | ||||||
| 			retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log")) | 			retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log")) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		ret = retErr.ErrorOrNil() | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	// All logged requests must have an identifier | 	// All logged requests must have an identifier | ||||||
| @@ -506,36 +510,46 @@ func (a *AuditBroker) LogRequest(auth *logical.Auth, req *logical.Request, heade | |||||||
| 	anyLogged := false | 	anyLogged := false | ||||||
| 	for name, be := range a.backends { | 	for name, be := range a.backends { | ||||||
| 		req.Headers = nil | 		req.Headers = nil | ||||||
| 		req.Headers = headersConfig.ApplyConfig(headers, be.backend.GetHash) | 		transHeaders, thErr := headersConfig.ApplyConfig(headers, be.backend.GetHash) | ||||||
|  | 		if thErr != nil { | ||||||
|  | 			a.logger.Error("audit: backend failed to include headers", "backend", name, "error", thErr) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		req.Headers = transHeaders | ||||||
|  |  | ||||||
| 		start := time.Now() | 		start := time.Now() | ||||||
| 		err := be.backend.LogRequest(auth, req, outerErr) | 		lrErr := be.backend.LogRequest(auth, req, outerErr) | ||||||
| 		metrics.MeasureSince([]string{"audit", name, "log_request"}, start) | 		metrics.MeasureSince([]string{"audit", name, "log_request"}, start) | ||||||
| 		if err != nil { | 		if lrErr != nil { | ||||||
| 			a.logger.Error("audit: backend failed to log request", "backend", name, "error", err) | 			a.logger.Error("audit: backend failed to log request", "backend", name, "error", lrErr) | ||||||
| 		} else { | 		} else { | ||||||
| 			anyLogged = true | 			anyLogged = true | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if !anyLogged && len(a.backends) > 0 { | 	if !anyLogged && len(a.backends) > 0 { | ||||||
| 		retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the request")) | 		retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the request")) | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 	return nil |  | ||||||
|  | 	return retErr.ErrorOrNil() | ||||||
| } | } | ||||||
|  |  | ||||||
| // LogResponse is used to ensure all the audit backends have an opportunity to | // LogResponse is used to ensure all the audit backends have an opportunity to | ||||||
| // log the given response and that *at least one* succeeds. | // log the given response and that *at least one* succeeds. | ||||||
| func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request, | func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request, | ||||||
| 	resp *logical.Response, headersConfig *AuditedHeadersConfig, err error) (reterr error) { | 	resp *logical.Response, headersConfig *AuditedHeadersConfig, err error) (ret error) { | ||||||
| 	defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now()) | 	defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now()) | ||||||
| 	a.RLock() | 	a.RLock() | ||||||
| 	defer a.RUnlock() | 	defer a.RUnlock() | ||||||
|  |  | ||||||
|  | 	var retErr *multierror.Error | ||||||
|  |  | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		if r := recover(); r != nil { | 		if r := recover(); r != nil { | ||||||
| 			a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) | 			a.logger.Error("audit: panic during logging", "request_path", req.Path, "error", r) | ||||||
| 			reterr = fmt.Errorf("panic generating audit log") | 			retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log")) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		ret = retErr.ErrorOrNil() | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	headers := req.Headers | 	headers := req.Headers | ||||||
| @@ -547,19 +561,35 @@ func (a *AuditBroker) LogResponse(auth *logical.Auth, req *logical.Request, | |||||||
| 	anyLogged := false | 	anyLogged := false | ||||||
| 	for name, be := range a.backends { | 	for name, be := range a.backends { | ||||||
| 		req.Headers = nil | 		req.Headers = nil | ||||||
| 		req.Headers = headersConfig.ApplyConfig(headers, be.backend.GetHash) | 		transHeaders, thErr := headersConfig.ApplyConfig(headers, be.backend.GetHash) | ||||||
|  | 		if thErr != nil { | ||||||
|  | 			a.logger.Error("audit: backend failed to include headers", "backend", name, "error", thErr) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		req.Headers = transHeaders | ||||||
|  |  | ||||||
| 		start := time.Now() | 		start := time.Now() | ||||||
| 		err := be.backend.LogResponse(auth, req, resp, err) | 		lrErr := be.backend.LogResponse(auth, req, resp, err) | ||||||
| 		metrics.MeasureSince([]string{"audit", name, "log_response"}, start) | 		metrics.MeasureSince([]string{"audit", name, "log_response"}, start) | ||||||
| 		if err != nil { | 		if lrErr != nil { | ||||||
| 			a.logger.Error("audit: backend failed to log response", "backend", name, "error", err) | 			a.logger.Error("audit: backend failed to log response", "backend", name, "error", lrErr) | ||||||
| 		} else { | 		} else { | ||||||
| 			anyLogged = true | 			anyLogged = true | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if !anyLogged && len(a.backends) > 0 { | 	if !anyLogged && len(a.backends) > 0 { | ||||||
| 		return fmt.Errorf("no audit backend succeeded in logging the response") | 		retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the response")) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return retErr.ErrorOrNil() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *AuditBroker) Invalidate(key string) { | ||||||
|  | 	// For now we ignore the key as this would only apply to salts. We just | ||||||
|  | 	// sort of brute force it on each one. | ||||||
|  | 	a.Lock() | ||||||
|  | 	defer a.Unlock() | ||||||
|  | 	for _, be := range a.backends { | ||||||
|  | 		be.backend.Invalidate() | ||||||
| 	} | 	} | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -3,6 +3,8 @@ package vault | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -13,6 +15,7 @@ import ( | |||||||
| 	"github.com/hashicorp/vault/audit" | 	"github.com/hashicorp/vault/audit" | ||||||
| 	"github.com/hashicorp/vault/helper/jsonutil" | 	"github.com/hashicorp/vault/helper/jsonutil" | ||||||
| 	"github.com/hashicorp/vault/helper/logformat" | 	"github.com/hashicorp/vault/helper/logformat" | ||||||
|  | 	"github.com/hashicorp/vault/helper/salt" | ||||||
| 	"github.com/hashicorp/vault/logical" | 	"github.com/hashicorp/vault/logical" | ||||||
| 	log "github.com/mgutz/logxi/v1" | 	log "github.com/mgutz/logxi/v1" | ||||||
| 	"github.com/mitchellh/copystructure" | 	"github.com/mitchellh/copystructure" | ||||||
| @@ -31,6 +34,9 @@ type NoopAudit struct { | |||||||
| 	RespReq  []*logical.Request | 	RespReq  []*logical.Request | ||||||
| 	Resp     []*logical.Response | 	Resp     []*logical.Response | ||||||
| 	RespErrs []error | 	RespErrs []error | ||||||
|  |  | ||||||
|  | 	salt      *salt.Salt | ||||||
|  | 	saltMutex sync.RWMutex | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *NoopAudit) LogRequest(a *logical.Auth, r *logical.Request, err error) error { | func (n *NoopAudit) LogRequest(a *logical.Auth, r *logical.Request, err error) error { | ||||||
| @@ -49,14 +55,44 @@ func (n *NoopAudit) LogResponse(a *logical.Auth, r *logical.Request, re *logical | |||||||
| 	return n.RespErr | 	return n.RespErr | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *NoopAudit) GetHash(data string) string { | func (n *NoopAudit) Salt() (*salt.Salt, error) { | ||||||
| 	return n.Config.Salt.GetIdentifiedHMAC(data) | 	n.saltMutex.RLock() | ||||||
|  | 	if n.salt != nil { | ||||||
|  | 		defer n.saltMutex.RUnlock() | ||||||
|  | 		return n.salt, nil | ||||||
|  | 	} | ||||||
|  | 	n.saltMutex.RUnlock() | ||||||
|  | 	n.saltMutex.Lock() | ||||||
|  | 	defer n.saltMutex.Unlock() | ||||||
|  | 	if n.salt != nil { | ||||||
|  | 		return n.salt, nil | ||||||
|  | 	} | ||||||
|  | 	salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	n.salt = salt | ||||||
|  | 	return salt, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *NoopAudit) GetHash(data string) (string, error) { | ||||||
|  | 	salt, err := n.Salt() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	return salt.GetIdentifiedHMAC(data), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *NoopAudit) Reload() error { | func (n *NoopAudit) Reload() error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (n *NoopAudit) Invalidate() { | ||||||
|  | 	n.saltMutex.Lock() | ||||||
|  | 	defer n.saltMutex.Unlock() | ||||||
|  | 	n.salt = nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestCore_EnableAudit(t *testing.T) { | func TestCore_EnableAudit(t *testing.T) { | ||||||
| 	c, keys, _ := TestCoreUnsealed(t) | 	c, keys, _ := TestCoreUnsealed(t) | ||||||
| 	c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { | 	c.auditBackends["noop"] = func(config *audit.BackendConfig) (audit.Backend, error) { | ||||||
| @@ -508,7 +544,7 @@ func TestAuditBroker_LogResponse(t *testing.T) { | |||||||
| 			t.Fatalf("Bad: %#v", a.Resp[0]) | 			t.Fatalf("Bad: %#v", a.Resp[0]) | ||||||
| 		} | 		} | ||||||
| 		if !reflect.DeepEqual(a.RespErrs[0], respErr) { | 		if !reflect.DeepEqual(a.RespErrs[0], respErr) { | ||||||
| 			t.Fatalf("Bad: %#v", a.RespErrs[0]) | 			t.Fatalf("Expected\n%v\nGot\n%#v", respErr, a.RespErrs[0]) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -522,7 +558,7 @@ func TestAuditBroker_LogResponse(t *testing.T) { | |||||||
| 	// Should FAIL work with both failing backends | 	// Should FAIL work with both failing backends | ||||||
| 	a2.RespErr = fmt.Errorf("failed") | 	a2.RespErr = fmt.Errorf("failed") | ||||||
| 	err = b.LogResponse(auth, req, resp, headersConf, respErr) | 	err = b.LogResponse(auth, req, resp, headersConf, respErr) | ||||||
| 	if err.Error() != "no audit backend succeeded in logging the response" { | 	if !strings.Contains(err.Error(), "no audit backend succeeded in logging the response") { | ||||||
| 		t.Fatalf("err: %v", err) | 		t.Fatalf("err: %v", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -88,7 +88,7 @@ func (a *AuditedHeadersConfig) remove(header string) error { | |||||||
|  |  | ||||||
| // ApplyConfig returns a map of approved headers and their values, either | // ApplyConfig returns a map of approved headers and their values, either | ||||||
| // hmac'ed or plaintext | // hmac'ed or plaintext | ||||||
| func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc func(string) string) (result map[string][]string) { | func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc func(string) (string, error)) (result map[string][]string, retErr error) { | ||||||
| 	// Grab a read lock | 	// Grab a read lock | ||||||
| 	a.RLock() | 	a.RLock() | ||||||
| 	defer a.RUnlock() | 	defer a.RUnlock() | ||||||
| @@ -110,7 +110,11 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc | |||||||
| 			// Optionally hmac the values | 			// Optionally hmac the values | ||||||
| 			if settings.HMAC { | 			if settings.HMAC { | ||||||
| 				for i, el := range hVals { | 				for i, el := range hVals { | ||||||
| 					hVals[i] = hashFunc(el) | 					hVal, err := hashFunc(el) | ||||||
|  | 					if err != nil { | ||||||
|  | 						return nil, err | ||||||
|  | 					} | ||||||
|  | 					hVals[i] = hVal | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| @@ -118,7 +122,7 @@ func (a *AuditedHeadersConfig) ApplyConfig(headers map[string][]string, hashFunc | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return | 	return result, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Initalize the headers config by loading from the barrier view | // Initalize the headers config by loading from the barrier view | ||||||
|   | |||||||
| @@ -166,9 +166,12 @@ func TestAuditedHeadersConfig_ApplyConfig(t *testing.T) { | |||||||
| 		"Content-Type":   []string{"json"}, | 		"Content-Type":   []string{"json"}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	hashFunc := func(s string) string { return "hashed" } | 	hashFunc := func(s string) (string, error) { return "hashed", nil } | ||||||
|  |  | ||||||
| 	result := conf.ApplyConfig(reqHeaders, hashFunc) | 	result, err := conf.ApplyConfig(reqHeaders, hashFunc) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	expected := map[string][]string{ | 	expected := map[string][]string{ | ||||||
| 		"x-test-header":  []string{"foo"}, | 		"x-test-header":  []string{"foo"}, | ||||||
| @@ -214,7 +217,7 @@ func BenchmarkAuditedHeaderConfig_ApplyConfig(b *testing.B) { | |||||||
| 		b.Fatal(err) | 		b.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	hashFunc := func(s string) string { return salter.GetIdentifiedHMAC(s) } | 	hashFunc := func(s string) (string, error) { return salter.GetIdentifiedHMAC(s), nil } | ||||||
|  |  | ||||||
| 	// Reset the timer since we did a lot above | 	// Reset the timer since we did a lot above | ||||||
| 	b.ResetTimer() | 	b.ResetTimer() | ||||||
|   | |||||||
| @@ -1254,13 +1254,11 @@ func TestSystemBackend_auditHash(t *testing.T) { | |||||||
| 			Key:   "salt", | 			Key:   "salt", | ||||||
| 			Value: []byte("foo"), | 			Value: []byte("foo"), | ||||||
| 		}) | 		}) | ||||||
| 		var err error | 		config.SaltView = view | ||||||
| 		config.Salt, err = salt.NewSalt(view, &salt.Config{ | 		config.SaltConfig = &salt.Config{ | ||||||
| 			HMAC:     sha256.New, | 			HMAC:     sha256.New, | ||||||
| 			HMACType: "hmac-sha256", | 			HMACType: "hmac-sha256", | ||||||
| 		}) | 			Location: salt.DefaultLocation, | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatalf("error getting new salt: %v", err) |  | ||||||
| 		} | 		} | ||||||
| 		return &NoopAudit{ | 		return &NoopAudit{ | ||||||
| 			Config: config, | 			Config: config, | ||||||
|   | |||||||
| @@ -14,6 +14,7 @@ import ( | |||||||
| 	"os" | 	"os" | ||||||
| 	"os/exec" | 	"os/exec" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -111,14 +112,11 @@ func testCoreConfig(t testing.TB, physicalBackend physical.Backend, logger log.L | |||||||
| 				Key:   "salt", | 				Key:   "salt", | ||||||
| 				Value: []byte("foo"), | 				Value: []byte("foo"), | ||||||
| 			}) | 			}) | ||||||
| 			var err error | 			config.SaltConfig = &salt.Config{ | ||||||
| 			config.Salt, err = salt.NewSalt(view, &salt.Config{ |  | ||||||
| 				HMAC:     sha256.New, | 				HMAC:     sha256.New, | ||||||
| 				HMACType: "hmac-sha256", | 				HMACType: "hmac-sha256", | ||||||
| 			}) |  | ||||||
| 			if err != nil { |  | ||||||
| 				t.Fatalf("error getting new salt: %v", err) |  | ||||||
| 			} | 			} | ||||||
|  | 			config.SaltView = view | ||||||
| 			return &noopAudit{ | 			return &noopAudit{ | ||||||
| 				Config: config, | 				Config: config, | ||||||
| 			}, nil | 			}, nil | ||||||
| @@ -442,11 +440,17 @@ func AddTestLogicalBackend(name string, factory logical.Factory) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| type noopAudit struct { | type noopAudit struct { | ||||||
| 	Config *audit.BackendConfig | 	Config    *audit.BackendConfig | ||||||
|  | 	salt      *salt.Salt | ||||||
|  | 	saltMutex sync.RWMutex | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *noopAudit) GetHash(data string) string { | func (n *noopAudit) GetHash(data string) (string, error) { | ||||||
| 	return n.Config.Salt.GetIdentifiedHMAC(data) | 	salt, err := n.Salt() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  | 	return salt.GetIdentifiedHMAC(data), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error { | func (n *noopAudit) LogRequest(a *logical.Auth, r *logical.Request, e error) error { | ||||||
| @@ -461,6 +465,32 @@ func (n *noopAudit) Reload() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (n *noopAudit) Invalidate() { | ||||||
|  | 	n.saltMutex.Lock() | ||||||
|  | 	defer n.saltMutex.Unlock() | ||||||
|  | 	n.salt = nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *noopAudit) Salt() (*salt.Salt, error) { | ||||||
|  | 	n.saltMutex.RLock() | ||||||
|  | 	if n.salt != nil { | ||||||
|  | 		defer n.saltMutex.RUnlock() | ||||||
|  | 		return n.salt, nil | ||||||
|  | 	} | ||||||
|  | 	n.saltMutex.RUnlock() | ||||||
|  | 	n.saltMutex.Lock() | ||||||
|  | 	defer n.saltMutex.Unlock() | ||||||
|  | 	if n.salt != nil { | ||||||
|  | 		return n.salt, nil | ||||||
|  | 	} | ||||||
|  | 	salt, err := salt.NewSalt(n.Config.SaltView, n.Config.SaltConfig) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	n.salt = salt | ||||||
|  | 	return salt, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| type rawHTTP struct{} | type rawHTTP struct{} | ||||||
|  |  | ||||||
| func (n *rawHTTP) HandleRequest(req *logical.Request) (*logical.Response, error) { | func (n *rawHTTP) HandleRequest(req *logical.Request) (*logical.Response, error) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jeff Mitchell
					Jeff Mitchell