From de0c724d72b141cbf046ec0868f5ef2ee8e9d519 Mon Sep 17 00:00:00 2001 From: Chris Capurso <1036769+ccapurso@users.noreply.github.com> Date: Wed, 28 Aug 2024 09:49:03 -0400 Subject: [PATCH] handle mount fields for non-entity clients; prevent null values (#28202) --- vault/activity_log.go | 47 ++++++++++++++++--- .../activity_testonly_test.go | 31 ++++++------ 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/vault/activity_log.go b/vault/activity_log.go index 4623dd0dd2..71df6654a1 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -3067,6 +3067,13 @@ func (a *ActivityLog) writeExport(ctx context.Context, rw http.ResponseWriter, f NamespacePath: nsDisplayPath, Timestamp: ts.UTC().Format(time.RFC3339), MountAccessor: e.MountAccessor, + + // Default following to empty versus nil, will be overwritten if necessary + Policies: []string{}, + EntityMetadata: map[string]string{}, + EntityAliasMetadata: map[string]string{}, + EntityAliasCustomMetadata: map[string]string{}, + EntityGroupIDs: []string{}, } if e.MountAccessor != "" { @@ -3106,24 +3113,34 @@ func (a *ActivityLog) writeExport(ctx context.Context, rw http.ResponseWriter, f return fmt.Errorf("failed to process entity name") } - record.Policies, ok = entityResp.Data["policies"].([]string) + policies, ok := entityResp.Data["policies"].([]string) if !ok { return fmt.Errorf("failed to process policies") } - slices.Sort(record.Policies) + if policies != nil { + record.Policies = policies + slices.Sort(record.Policies) + } - record.EntityMetadata, ok = entityResp.Data["metadata"].(map[string]string) + entityMetadata, ok := entityResp.Data["metadata"].(map[string]string) if !ok { return fmt.Errorf("failed to process entity metadata") } - record.EntityGroupIDs, ok = entityResp.Data["group_ids"].([]string) + if entityMetadata != nil { + record.EntityMetadata = entityMetadata + } + + entityGroupIDs, ok := entityResp.Data["group_ids"].([]string) if !ok { return fmt.Errorf("failed to process entity group IDs") } - slices.Sort(record.EntityGroupIDs) + if entityGroupIDs != nil { + record.EntityGroupIDs = entityGroupIDs + slices.Sort(record.EntityGroupIDs) + } aliases, ok := entityResp.Data["aliases"].([]interface{}) if !ok { @@ -3165,15 +3182,31 @@ func (a *ActivityLog) writeExport(ctx context.Context, rw http.ResponseWriter, f return fmt.Errorf("failed to process mount path") } - record.EntityAliasMetadata, ok = alias["metadata"].(map[string]string) + entityAliasMetadata, ok := alias["metadata"].(map[string]string) if !ok { return fmt.Errorf("failed to process entity alias metadata") } - record.EntityAliasCustomMetadata, ok = alias["custom_metadata"].(map[string]string) + if entityAliasMetadata != nil { + record.EntityAliasMetadata = entityAliasMetadata + } + + entityAliasCustomMetadata, ok := alias["custom_metadata"].(map[string]string) if !ok { return fmt.Errorf("failed to process entity alias custom metadata") } + + if entityAliasCustomMetadata != nil { + record.EntityAliasCustomMetadata = entityAliasCustomMetadata + } + } + } else { + // fetch mount directly to ensure mount type and path are populated + // this will be necessary for non-entity client types (e.g. non-entity-token) + validateResp := a.core.router.ValidateMountByAccessor(e.MountAccessor) + if validateResp != nil { + record.MountPath = validateResp.MountPath + record.MountType = validateResp.MountType } } } diff --git a/vault/external_tests/activity_testonly/activity_testonly_test.go b/vault/external_tests/activity_testonly/activity_testonly_test.go index e1c7e97858..8141357efe 100644 --- a/vault/external_tests/activity_testonly/activity_testonly_test.go +++ b/vault/external_tests/activity_testonly/activity_testonly_test.go @@ -585,8 +585,22 @@ func getCSVExport(t *testing.T, client *api.Client, monthsPreviousTo int, now ti // skip initial row as it is header for rowIdx := 1; rowIdx < len(csvRecords); rowIdx++ { + baseRecord := vault.ActivityLogExportRecord{ + Policies: []string{}, + EntityMetadata: map[string]string{}, + EntityAliasMetadata: map[string]string{}, + EntityAliasCustomMetadata: map[string]string{}, + EntityGroupIDs: []string{}, + } + recordMap := make(map[string]interface{}) + // create base map + err = mapstructure.Decode(baseRecord, &recordMap) + if err != nil { + return nil, err + } + for columnIdx, columnName := range csvHeader { val := csvRecords[rowIdx][columnIdx] @@ -597,13 +611,7 @@ func getCSVExport(t *testing.T, client *api.Client, monthsPreviousTo int, now ti prefix := columnNameParts[0] if _, ok := mapFields[prefix]; ok { - m, mOK := recordMap[prefix] - - // ensure output contains non-nil map - if !mOK { - m = make(map[string]string) - recordMap[prefix] = m - } + m := recordMap[prefix] // ignore empty CSV column value if val != "" { @@ -611,12 +619,7 @@ func getCSVExport(t *testing.T, client *api.Client, monthsPreviousTo int, now ti recordMap[prefix] = m } } else if _, ok := sliceFields[prefix]; ok { - // ensure output contains non-nil slice - s, sOK := recordMap[prefix] - if !sOK { - s = make([]string, 0) - recordMap[prefix] = s - } + s := recordMap[prefix] // ignore empty CSV column value if val != "" { @@ -624,7 +627,7 @@ func getCSVExport(t *testing.T, client *api.Client, monthsPreviousTo int, now ti recordMap[prefix] = s } } else { - t.Fatalf("unexpected CSV field: %s", columnName) + t.Fatalf("unexpected CSV field: %q", columnName) } } else if _, ok := boolFields[columnName]; ok { recordMap[columnName], err = strconv.ParseBool(val)