diff --git a/command/operator_diagnose.go b/command/operator_diagnose.go index 95ca896f63..b66dea73eb 100644 --- a/command/operator_diagnose.go +++ b/command/operator_diagnose.go @@ -37,6 +37,7 @@ type OperatorDiagnoseCommand struct { reloadFuncs *map[string][]reloadutil.ReloadFunc startedCh chan struct{} // for tests reloadedCh chan struct{} // for tests + skipEndEnd bool // for tests } func (c *OperatorDiagnoseCommand) Synopsis() string { @@ -224,11 +225,8 @@ func (c *OperatorDiagnoseCommand) offlineDiagnostics(ctx context.Context) error // Errors in these items could stop Vault from starting but are not yet covered: // TODO: logging configuration // TODO: SetupTelemetry - // TODO: check for storage backend - if err := diagnose.Test(ctx, "storage", func(ctx context.Context) error { - - _, err = server.setupStorage(config) + b, err := server.setupStorage(config) if err != nil { return err } @@ -246,6 +244,15 @@ func (c *OperatorDiagnoseCommand) offlineDiagnostics(ctx context.Context) error return err } } + + // Attempt to use storage backend + if !c.skipEndEnd { + err = diagnose.StorageEndToEndLatencyCheck(ctx, b) + if err != nil { + return err + } + } + return nil }); err != nil { return err diff --git a/command/operator_diagnose_test.go b/command/operator_diagnose_test.go index 7c33c66f11..35c2e9d395 100644 --- a/command/operator_diagnose_test.go +++ b/command/operator_diagnose_test.go @@ -21,6 +21,7 @@ func testOperatorDiagnoseCommand(tb testing.TB) *OperatorDiagnoseCommand { BaseCommand: &BaseCommand{ UI: ui, }, + skipEndEnd: true, } } @@ -215,6 +216,10 @@ func compareResult(t *testing.T, exp *diagnose.Result, act *diagnose.Result) err return fmt.Errorf("names mismatch: %s vs %s", exp.Name, act.Name) } if exp.Status != act.Status { + if act.Status != diagnose.OkStatus { + return fmt.Errorf("section %s, status mismatch: %s vs %s, got error %s", exp.Name, exp.Status, act.Status, act.Message) + + } return fmt.Errorf("section %s, status mismatch: %s vs %s", exp.Name, exp.Status, act.Status) } if exp.Message != "" && exp.Message != act.Message && !strings.Contains(act.Message, exp.Message) { diff --git a/vault/diagnose/mock_storage_backend.go b/vault/diagnose/mock_storage_backend.go new file mode 100644 index 0000000000..dbbc9fb693 --- /dev/null +++ b/vault/diagnose/mock_storage_backend.go @@ -0,0 +1,75 @@ +package diagnose + +import ( + "context" + "fmt" + "time" + + "github.com/hashicorp/vault/sdk/physical" +) + +const ( + timeoutCallRead string = "lag Read" + timeoutCallWrite string = "lag Write" + timeoutCallDelete string = "lag Delete" + errCallWrite string = "err Write" + errCallDelete string = "err Delete" + errCallRead string = "err Read" + badReadCall string = "bad Read" + storageErrStringWrite string = "storage error on write" + storageErrStringRead string = "storage error on read" + storageErrStringDelete string = "storage error on delete" + readOp string = "read" + writeOp string = "write" + deleteOp string = "delete" +) + +var goodEntry physical.Entry = physical.Entry{Key: secretKey, Value: []byte(secretVal)} +var badEntry physical.Entry = physical.Entry{} + +type mockStorageBackend struct { + callType string +} + +func (m mockStorageBackend) storageLogicGeneralInternal(op string) error { + if (m.callType == timeoutCallRead && op == readOp) || (m.callType == timeoutCallWrite && op == writeOp) || + (m.callType == timeoutCallDelete && op == deleteOp) { + time.Sleep(25 * time.Second) + } else if m.callType == errCallWrite && op == writeOp { + return fmt.Errorf(storageErrStringWrite) + } else if m.callType == errCallDelete && op == deleteOp { + return fmt.Errorf(storageErrStringDelete) + } else if m.callType == errCallRead && op == readOp { + return fmt.Errorf(storageErrStringRead) + } + + return nil +} + +// Put is used to insert or update an entry +func (m mockStorageBackend) Put(ctx context.Context, entry *physical.Entry) error { + return m.storageLogicGeneralInternal(writeOp) +} + +// Get is used to fetch an entry +func (m mockStorageBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { + if m.callType == errCallRead || m.callType == timeoutCallRead { + return nil, m.storageLogicGeneralInternal(readOp) + } + if m.callType == badReadCall { + return &badEntry, nil + } + return &goodEntry, nil + +} + +// Delete is used to permanently delete an entry +func (m mockStorageBackend) Delete(ctx context.Context, key string) error { + + return m.storageLogicGeneralInternal(deleteOp) +} + +// List is not used in a mock. +func (m mockStorageBackend) List(ctx context.Context, prefix string) ([]string, error) { + return nil, fmt.Errorf("method not implemented") +} diff --git a/vault/diagnose/storage_checks.go b/vault/diagnose/storage_checks.go new file mode 100644 index 0000000000..5fdf0de2ee --- /dev/null +++ b/vault/diagnose/storage_checks.go @@ -0,0 +1,76 @@ +package diagnose + +import ( + "context" + "fmt" + "time" + + "github.com/hashicorp/vault/sdk/physical" +) + +const ( + success string = "success" + secretKey string = "diagnose" + secretVal string = "diagnoseSecret" + + timeOutErr string = "storage call timed out after 20 seconds: " + wrongRWValsPrefix string = "Storage get and put gave wrong values: " +) + +// StorageEndToEndLatencyCheck calls Write, Read, and Delete on a secret in the root +// directory of the backend. +// Note: Just checking read, write, and delete for root. It's a very basic check, +// but I don't think we can necessarily do any more than that. We could check list, +// but I don't think List is ever going to break in isolation. +func StorageEndToEndLatencyCheck(ctx context.Context, b physical.Backend) error { + + c2 := make(chan error) + go func() { + err := b.Put(context.Background(), &physical.Entry{Key: secretKey, Value: []byte(secretVal)}) + c2 <- err + }() + select { + case errOut := <-c2: + if errOut != nil { + return errOut + } + case <-time.After(20 * time.Second): + return fmt.Errorf(timeOutErr + "operation: Put") + } + + c3 := make(chan *physical.Entry) + c4 := make(chan error) + go func() { + val, err := b.Get(context.Background(), "diagnose") + if err != nil { + c4 <- err + } else { + c3 <- val + } + }() + select { + case err := <-c4: + return err + case val := <-c3: + if val.Key != "diagnose" && string(val.Value) != "diagnose" { + return fmt.Errorf(wrongRWValsPrefix+"expecting diagnose, but got %s, %s", val.Key, val.Value) + } + case <-time.After(20 * time.Second): + return fmt.Errorf(timeOutErr + "operation: Get") + } + + c5 := make(chan error) + go func() { + err := b.Delete(context.Background(), "diagnose") + c5 <- err + }() + select { + case errOut := <-c5: + if errOut != nil { + return errOut + } + case <-time.After(20 * time.Second): + return fmt.Errorf(timeOutErr + "operation: Delete") + } + return nil +} diff --git a/vault/diagnose/storage_checks_test.go b/vault/diagnose/storage_checks_test.go new file mode 100644 index 0000000000..91caec376c --- /dev/null +++ b/vault/diagnose/storage_checks_test.go @@ -0,0 +1,61 @@ +package diagnose + +import ( + "context" + "strings" + "testing" + + "github.com/hashicorp/vault/sdk/physical" +) + +func TestStorageTimeout(t *testing.T) { + + testCases := []struct { + errSubString string + mb physical.Backend + }{ + { + errSubString: timeOutErr + "operation: Put", + mb: mockStorageBackend{callType: timeoutCallWrite}, + }, + { + errSubString: timeOutErr + "operation: Get", + mb: mockStorageBackend{callType: timeoutCallRead}, + }, + { + errSubString: timeOutErr + "operation: Delete", + mb: mockStorageBackend{callType: timeoutCallDelete}, + }, + { + errSubString: storageErrStringWrite, + mb: mockStorageBackend{callType: errCallWrite}, + }, + { + errSubString: storageErrStringDelete, + mb: mockStorageBackend{callType: errCallDelete}, + }, + { + errSubString: storageErrStringRead, + mb: mockStorageBackend{callType: errCallRead}, + }, + { + errSubString: wrongRWValsPrefix, + mb: mockStorageBackend{callType: badReadCall}, + }, + { + errSubString: "", + mb: mockStorageBackend{callType: ""}, + }, + } + + for _, tc := range testCases { + outErr := StorageEndToEndLatencyCheck(context.Background(), tc.mb) + if tc.errSubString == "" && outErr == nil { + // this is the success case where the Storage Latency check passes + continue + } + if !strings.Contains(outErr.Error(), tc.errSubString) { + t.Errorf("wrong error: expected %s to be contained in %s", tc.errSubString, outErr) + } + } +}