diff --git a/logical/framework/backend.go b/logical/framework/backend.go index 70c09f477b..7885a308c2 100644 --- a/logical/framework/backend.go +++ b/logical/framework/backend.go @@ -34,7 +34,7 @@ type Backend struct { // Rollback is called when a WAL entry (see wal.go) has to be rolled // back. It is called with the data from the entry. Boolean true should // be returned on success. Errors should just be logged. - Rollback func(data interface{}) bool + Rollback func(kind string, data interface{}) bool once sync.Once pathsRe []*regexp.Regexp @@ -193,13 +193,13 @@ func (b *Backend) handleRollback( } for _, k := range keys { - data, err := GetWAL(req.Storage, k) + kind, data, err := GetWAL(req.Storage, k) if err != nil { merr = multierror.Append(merr, err) continue } - if b.Rollback(data) { + if b.Rollback(kind, data) { if err := DeleteWAL(req.Storage, k); err != nil { merr = multierror.Append(merr, err) } diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index ef94121e0b..543f2936c9 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -138,7 +138,7 @@ func TestBackendHandleRequest_help(t *testing.T) { func TestBackendHandleRequest_rollback(t *testing.T) { var called uint32 - callback := func(data interface{}) bool { + callback := func(kind string, data interface{}) bool { if data == "foo" { atomic.AddUint32(&called, 1) } @@ -151,7 +151,7 @@ func TestBackendHandleRequest_rollback(t *testing.T) { } storage := new(logical.InmemStorage) - if _, err := PutWAL(storage, "foo"); err != nil { + if _, err := PutWAL(storage, "kind", "foo"); err != nil { t.Fatalf("err: %s", err) } diff --git a/logical/framework/wal.go b/logical/framework/wal.go index 72893a7ba3..e087dfd8e0 100644 --- a/logical/framework/wal.go +++ b/logical/framework/wal.go @@ -12,6 +12,10 @@ const WALPrefix = "wal/" // PutWAL writes some data to the WAL. // +// The kind parameter is used by the framework to allow users to store +// multiple kinds of WAL data and to easily disambiguate what data they're +// expecting. +// // Data within the WAL that is uncommitted (CommitWAL hasn't be called) // will be given to the rollback callback when an rollback operation is // received, allowing the backend to clean up some partial states. @@ -21,8 +25,11 @@ const WALPrefix = "wal/" // This returns a unique ID that can be used to reference this WAL data. // WAL data cannot be modified. You can only add to the WAL and commit existing // WAL entries. -func PutWAL(s logical.Storage, data interface{}) (string, error) { - value, err := json.Marshal(data) +func PutWAL(s logical.Storage, kind string, data interface{}) (string, error) { + value, err := json.Marshal(map[string]interface{}{ + "kind": kind, + "data": data, + }) if err != nil { return "", err } @@ -40,21 +47,23 @@ func PutWAL(s logical.Storage, data interface{}) (string, error) { // GetWAL reads a specific entry from the WAL. If the entry doesn't exist, // then nil value is returned. -func GetWAL(s logical.Storage, id string) (interface{}, error) { +// +// The kind, value, and error are returned. +func GetWAL(s logical.Storage, id string) (string, interface{}, error) { entry, err := s.Get(WALPrefix + id) if err != nil { - return nil, err + return "", nil, err } if entry == nil { - return nil, nil + return "", nil, nil } - var result interface{} + var result map[string]interface{} if err := json.Unmarshal(entry.Value, &result); err != nil { - return nil, err + return "", nil, err } - return result, nil + return result["kind"].(string), result["data"], nil } // DeleteWAL commits the WAL entry with the given ID. Once comitted, diff --git a/logical/framework/wal_test.go b/logical/framework/wal_test.go index 6c93e53f1f..495195e999 100644 --- a/logical/framework/wal_test.go +++ b/logical/framework/wal_test.go @@ -20,7 +20,7 @@ func TestWAL(t *testing.T) { } // Write an entry to the WAL - id, err := PutWAL(s, "bar") + id, err := PutWAL(s, "foo", "bar") if err != nil { t.Fatalf("err: %s", err) } @@ -35,10 +35,13 @@ func TestWAL(t *testing.T) { } // Should be able to get the value - v, err := GetWAL(s, id) + kind, v, err := GetWAL(s, id) if err != nil { t.Fatalf("err: %s", err) } + if kind != "foo" { + t.Fatalf("bad: %#v", kind) + } if v != "bar" { t.Fatalf("bad: %#v", v) } @@ -47,7 +50,7 @@ func TestWAL(t *testing.T) { if err := DeleteWAL(s, id); err != nil { t.Fatalf("err: %s", err) } - v, err = GetWAL(s, id) + _, v, err = GetWAL(s, id) if err != nil { t.Fatalf("err: %s", err) }