diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index fe2eda1555..6c70f82ebc 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -133,32 +133,38 @@ func TestHashString(t *testing.T) { } func TestHashAuth(t *testing.T) { - cases := []struct { + cases := map[string]struct { Input *logical.Auth Output *logical.Auth HMACAccessor bool }{ - { - &logical.Auth{ClientToken: "foo"}, - &logical.Auth{ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a"}, + "no-accessor-hmac": { + &logical.Auth{ + ClientToken: "foo", + Accessor: "very-accessible", + }, + &logical.Auth{ + ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a", + Accessor: "very-accessible", + }, false, }, - { + "accessor-hmac": { &logical.Auth{ LeaseOptions: logical.LeaseOptions{ TTL: 1 * time.Hour, }, - + Accessor: "very-accessible", ClientToken: "foo", }, &logical.Auth{ LeaseOptions: logical.LeaseOptions{ TTL: 1 * time.Hour, }, - + Accessor: "hmac-sha256:5d6d7c8da5b699ace193ea453bbf77082a8aaca42a474436509487d646a7c0af", ClientToken: "hmac-sha256:08ba357e274f528065766c770a639abf6809b39ccfd37c2a3157c7f51954da0a", }, - false, + true, }, } @@ -170,14 +176,9 @@ func TestHashAuth(t *testing.T) { require.NoError(t, err) salter := &TestSalter{} for _, tc := range cases { - input := fmt.Sprintf("%#v", tc.Input) err := hashAuth(context.Background(), salter, tc.Input, tc.HMACAccessor) - if err != nil { - t.Fatalf("err: %s\n\n%s", err, input) - } - if !reflect.DeepEqual(tc.Input, tc.Output) { - t.Fatalf("bad:\nInput:\n%s\nOutput:\n%#v\nExpected output:\n%#v", input, tc.Input, tc.Output) - } + require.NoError(t, err) + require.Equal(t, tc.Output, tc.Input) } } diff --git a/http/logical_test.go b/http/logical_test.go index 682b8f6dea..6550128185 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -34,6 +34,7 @@ import ( "github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/sdk/physical/inmem" "github.com/hashicorp/vault/vault" + "github.com/stretchr/testify/require" ) func TestLogical(t *testing.T) { @@ -744,9 +745,12 @@ func TestLogical_AuditPort(t *testing.T) { decoder := json.NewDecoder(auditLogFile) - var auditRecord map[string]interface{} count := 0 - for decoder.Decode(&auditRecord) == nil { + for decoder.More() { + var auditRecord map[string]interface{} + err := decoder.Decode(&auditRecord) + require.NoError(t, err) + count += 1 // Skip the first line @@ -851,14 +855,25 @@ func TestLogical_ErrRelativePath(t *testing.T) { } func testBuiltinPluginMetadataAuditLog(t *testing.T, log map[string]interface{}, expectedMountClass string) { + t.Helper() + if mountClass, ok := log["mount_class"].(string); !ok { t.Fatalf("mount_class should be a string, not %T", log["mount_class"]) } else if mountClass != expectedMountClass { t.Fatalf("bad: mount_class should be %s, not %s", expectedMountClass, mountClass) } - if _, ok := log["mount_running_version"].(string); !ok { - t.Fatalf("mount_running_version should be a string, not %T", log["mount_running_version"]) + // Requests have 'mount_running_version' but Responses have 'mount_running_plugin_version' + runningVersionRaw, runningVersionRawOK := log["mount_running_version"] + runningPluginVersionRaw, runningPluginVersionRawOK := log["mount_running_plugin_version"] + if !runningVersionRawOK && !runningPluginVersionRawOK { + t.Fatalf("mount_running_version/mount_running_plugin_version should be present") + } else if runningVersionRawOK { + if _, ok := runningVersionRaw.(string); !ok { + t.Fatalf("mount_running_version should be string, not %T", runningVersionRaw) + } + } else if _, ok := runningPluginVersionRaw.(string); !ok { + t.Fatalf("mount_running_plugin_version should be string, not %T", runningPluginVersionRaw) } if _, ok := log["mount_running_sha256"].(string); ok { @@ -905,38 +920,45 @@ func TestLogical_AuditEnabled_ShouldLogPluginMetadata_Auth(t *testing.T) { "file_path": auditLogFile.Name(), }, }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) _, err = c.Logical().Write("auth/token/create", map[string]interface{}{ "ttl": "10s", }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + // Disable audit now we're done performing operations + err = c.Sys().DisableAudit("file") + require.NoError(t, err) // Check the audit trail on request and response decoder := json.NewDecoder(auditLogFile) - var auditRecord map[string]interface{} - for decoder.Decode(&auditRecord) == nil { - auditRequest := map[string]interface{}{} - if req, ok := auditRecord["request"]; ok { - auditRequest = req.(map[string]interface{}) - if auditRequest["path"] != "auth/token/create" { - continue - } - } - testBuiltinPluginMetadataAuditLog(t, auditRequest, consts.PluginTypeCredential.String()) + for decoder.More() { + var auditRecord map[string]interface{} + err := decoder.Decode(&auditRecord) + require.NoError(t, err) - auditResponse := map[string]interface{}{} - if req, ok := auditRecord["response"]; ok { - auditRequest = req.(map[string]interface{}) - if auditResponse["path"] != "auth/token/create" { + if req, ok := auditRecord["request"]; ok { + auditRequest, ok := req.(map[string]interface{}) + require.True(t, ok) + + path, ok := auditRequest["path"].(string) + require.True(t, ok) + + if path != "auth/token/create" { continue } + + testBuiltinPluginMetadataAuditLog(t, auditRequest, consts.PluginTypeCredential.String()) + } + + // Should never have a response without a corresponding request. + if resp, ok := auditRecord["response"]; ok { + auditResponse, ok := resp.(map[string]interface{}) + require.True(t, ok) + + testBuiltinPluginMetadataAuditLog(t, auditResponse, consts.PluginTypeCredential.String()) } - testBuiltinPluginMetadataAuditLog(t, auditResponse, consts.PluginTypeCredential.String()) } } @@ -974,9 +996,7 @@ func TestLogical_AuditEnabled_ShouldLogPluginMetadata_Secret(t *testing.T) { // Enable the audit backend tempDir := t.TempDir() auditLogFile, err := os.CreateTemp(tempDir, "") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) err = c.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ Type: "file", @@ -984,9 +1004,7 @@ func TestLogical_AuditEnabled_ShouldLogPluginMetadata_Secret(t *testing.T) { "file_path": auditLogFile.Name(), }, }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) { writeData := map[string]interface{}{ @@ -1003,26 +1021,36 @@ func TestLogical_AuditEnabled_ShouldLogPluginMetadata_Secret(t *testing.T) { }) } + // Disable audit now we're done performing operations + err = c.Sys().DisableAudit("file") + require.NoError(t, err) + // Check the audit trail on request and response decoder := json.NewDecoder(auditLogFile) - var auditRecord map[string]interface{} - for decoder.Decode(&auditRecord) == nil { - auditRequest := map[string]interface{}{} - if req, ok := auditRecord["request"]; ok { - auditRequest = req.(map[string]interface{}) - if auditRequest["path"] != "kv/data/foo" { - continue - } - } - testBuiltinPluginMetadataAuditLog(t, auditRequest, consts.PluginTypeSecrets.String()) + for decoder.More() { + var auditRecord map[string]interface{} + err := decoder.Decode(&auditRecord) + require.NoError(t, err) - auditResponse := map[string]interface{}{} - if req, ok := auditRecord["response"]; ok { - auditRequest = req.(map[string]interface{}) - if auditResponse["path"] != "kv/data/foo" { + if req, ok := auditRecord["request"]; ok { + auditRequest, ok := req.(map[string]interface{}) + require.True(t, ok) + + path, ok := auditRequest["path"] + require.True(t, ok) + + if path != "kv/data/foo" { continue } + + testBuiltinPluginMetadataAuditLog(t, auditRequest, consts.PluginTypeSecrets.String()) + } + + if resp, ok := auditRecord["response"]; ok { + auditResponse, ok := resp.(map[string]interface{}) + require.True(t, ok) + + testBuiltinPluginMetadataAuditLog(t, auditResponse, consts.PluginTypeSecrets.String()) } - testBuiltinPluginMetadataAuditLog(t, auditResponse, consts.PluginTypeSecrets.String()) } } diff --git a/vault/external_tests/kv/kv_patch_test.go b/vault/external_tests/kv/kv_patch_test.go index 1b4802821e..62f65e41c3 100644 --- a/vault/external_tests/kv/kv_patch_test.go +++ b/vault/external_tests/kv/kv_patch_test.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/testhelpers/minimal" + "github.com/stretchr/testify/require" ) func TestKV_Patch_BadContentTypeHeader(t *testing.T) { @@ -158,7 +159,10 @@ func TestKV_Patch_Audit(t *testing.T) { decoder := json.NewDecoder(auditLogFile) var auditRecord map[string]interface{} - for decoder.Decode(&auditRecord) == nil { + for decoder.More() { + err := decoder.Decode(&auditRecord) + require.NoError(t, err) + auditRequest := map[string]interface{}{} if req, ok := auditRecord["request"]; ok { diff --git a/vault/external_tests/plugin/external_plugin_test.go b/vault/external_tests/plugin/external_plugin_test.go index 9f81521901..ef508b44e1 100644 --- a/vault/external_tests/plugin/external_plugin_test.go +++ b/vault/external_tests/plugin/external_plugin_test.go @@ -28,6 +28,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" _ "github.com/jackc/pgx/v4/stdlib" + "github.com/stretchr/testify/require" ) func getCluster(t *testing.T, numCores int, types ...consts.PluginType) *vault.TestCluster { @@ -946,25 +947,32 @@ func TestExternalPlugin_AuditEnabled_ShouldLogPluginMetadata_Auth(t *testing.T) // Check the audit trail on request and response decoder := json.NewDecoder(auditLogFile) - var auditRecord map[string]interface{} - for decoder.Decode(&auditRecord) == nil { - auditRequest := map[string]interface{}{} - if req, ok := auditRecord["request"]; ok { - auditRequest = req.(map[string]interface{}) - if auditRequest["path"] != "auth/"+plugin.Name+"/role/role1" { - continue - } - } - testExternalPluginMetadataAuditLog(t, auditRequest, consts.PluginTypeCredential.String()) + for decoder.More() { + var auditRecord map[string]interface{} + err := decoder.Decode(&auditRecord) + require.NoError(t, err) - auditResponse := map[string]interface{}{} - if req, ok := auditRecord["response"]; ok { - auditRequest = req.(map[string]interface{}) - if auditResponse["path"] != "auth/"+plugin.Name+"/role/role1" { + if req, ok := auditRecord["request"]; ok { + auditRequest, ok := req.(map[string]interface{}) + require.True(t, ok) + + path, ok := auditRequest["path"] + require.True(t, ok) + + if path != "auth/"+plugin.Name+"/role/role1" { continue } + + testExternalPluginMetadataAuditLog(t, auditRequest, consts.PluginTypeCredential.String()) } - testExternalPluginMetadataAuditLog(t, auditResponse, consts.PluginTypeCredential.String()) + + if resp, ok := auditRecord["response"]; ok { + auditResponse, ok := resp.(map[string]interface{}) + require.True(t, ok) + + testExternalPluginMetadataAuditLog(t, auditResponse, consts.PluginTypeCredential.String()) + } + } // Deregister @@ -1014,31 +1022,39 @@ func TestExternalPlugin_AuditEnabled_ShouldLogPluginMetadata_Secret(t *testing.T "address": consulConfig.Address(), "token": consulConfig.Token, }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + // Disable audit now we're done performing operations + err = client.Sys().DisableAudit("file") + require.NoError(t, err) // Check the audit trail on request and response decoder := json.NewDecoder(auditLogFile) - var auditRecord map[string]interface{} - for decoder.Decode(&auditRecord) == nil { - auditRequest := map[string]interface{}{} - if req, ok := auditRecord["request"]; ok { - auditRequest = req.(map[string]interface{}) - if auditRequest["path"] != plugin.Name+"/config/access" { - continue - } - } - testExternalPluginMetadataAuditLog(t, auditRequest, consts.PluginTypeSecrets.String()) + for decoder.More() { + var auditRecord map[string]interface{} + err := decoder.Decode(&auditRecord) + require.NoError(t, err) - auditResponse := map[string]interface{}{} - if req, ok := auditRecord["response"]; ok { - auditRequest = req.(map[string]interface{}) - if auditResponse["path"] != plugin.Name+"/config/access" { + if req, ok := auditRecord["request"]; ok { + auditRequest, ok := req.(map[string]interface{}) + require.True(t, ok) + + path, ok := auditRequest["path"].(string) + require.True(t, ok) + + if path != plugin.Name+"/config/access" { continue } + + testExternalPluginMetadataAuditLog(t, auditRequest, consts.PluginTypeSecrets.String()) + } + + if resp, ok := auditRecord["response"]; ok { + auditResponse, ok := resp.(map[string]interface{}) + require.True(t, ok) + + testExternalPluginMetadataAuditLog(t, auditResponse, consts.PluginTypeSecrets.String()) } - testExternalPluginMetadataAuditLog(t, auditResponse, consts.PluginTypeSecrets.String()) } // Deregister