Add write deadline and a Reload function

This commit is contained in:
Brian Kassouf
2017-02-02 15:44:56 -08:00
parent 6da4806582
commit b32cb4bedf

View File

@@ -5,8 +5,11 @@ import (
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"sync"
"time"
"github.com/hashicorp/vault/audit" "github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/helper/duration"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
@@ -20,9 +23,18 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
return nil, fmt.Errorf("address is required") return nil, fmt.Errorf("address is required")
} }
socket_type, ok := conf.Config["socket_type"] socketType, ok := conf.Config["socket_type"]
if !ok { if !ok {
socket_type = "tcp" socketType = "tcp"
}
writeDeadline, ok := conf.Config["write_deadline"]
if !ok {
writeDeadline = "2s"
}
writeDuration, err := duration.ParseDurationSecond(writeDeadline)
if err != nil {
return nil, err
} }
format, ok := conf.Config["format"] format, ok := conf.Config["format"]
@@ -55,7 +67,7 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
logRaw = b logRaw = b
} }
conn, err := net.Dial(socket_type, address) conn, err := net.Dial(socketType, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -67,6 +79,9 @@ func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
Salt: conf.Salt, Salt: conf.Salt,
HMACAccessor: hmacAccessor, HMACAccessor: hmacAccessor,
}, },
writeDuration: writeDuration,
address: address,
socketType: socketType,
} }
switch format { switch format {
@@ -85,6 +100,12 @@ type Backend struct {
formatter audit.AuditFormatter formatter audit.AuditFormatter
formatConfig audit.FormatterConfig formatConfig audit.FormatterConfig
writeDuration time.Duration
address string
socketType string
sync.Mutex
} }
func (b *Backend) GetHash(data string) string { func (b *Backend) GetHash(data string) string {
@@ -97,20 +118,50 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
return err return err
} }
b.connection.Write(buf.Bytes()) b.Lock()
return nil
b.connection.SetDeadline(time.Now().Add(b.writeDuration))
_, err := b.connection.Write(buf.Bytes())
b.Unlock()
if err != nil {
b.Reload()
}
return err
} }
func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request,
resp *logical.Response, err error) error { resp *logical.Response, outerErr error) error {
var buf bytes.Buffer var buf bytes.Buffer
if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, err); err != nil { if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, outerErr); err != nil {
return err return err
} }
b.connection.Write(buf.Bytes())
return nil b.Lock()
b.connection.SetDeadline(time.Now().Add(b.writeDuration))
_, err := b.connection.Write(buf.Bytes())
b.Unlock()
if err != nil {
b.Reload()
}
return err
} }
func (b *Backend) Reload() error { func (b *Backend) Reload() error {
b.Lock()
defer b.Unlock()
conn, err := net.Dial(b.socketType, b.address)
if err != nil {
return err
}
b.connection.Close()
b.connection = conn
return nil return nil
} }