OpenAPI: Fix generation of correct fields (#21942)

* OpenAPI: Fix generation of correct fields

Currently, the OpenAPI generator logic is wrong about how it maps from
Vault framework fields to OpenAPI. This manifests most obviously with
endpoints making use of `framework.OptionalParamRegex` or similar
regex-level optional path parameters, and results in various incorrect
fields showing up in the generated request structures.

The fix is a bit complicated, but in essence is just rewriting the
OpenAPI logic to properly parallel the real request processing logic.

With these changes:

* A path parameter in an optional part of the regex, no longer gets
  erroneously treated as a body parameter when creating OpenAPI
  endpoints that do not include the optional parameter.

* A field marked as `Query: true` no longer gets incorrectly skipped
  when creating OpenAPI `POST` operations.

* changelog
This commit is contained in:
Max Bowsher
2023-07-25 04:10:33 +01:00
committed by GitHub
parent b2e110ec5a
commit 8e4409dbf0
3 changed files with 148 additions and 86 deletions

3
changelog/21942.txt Normal file
View File

@@ -0,0 +1,3 @@
```release-note:improvement
openapi: Fix generation of correct fields in some rarer cases
```

View File

@@ -229,7 +229,7 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc *
// Convert optional parameters into distinct patterns to be processed independently. // Convert optional parameters into distinct patterns to be processed independently.
forceUnpublished := false forceUnpublished := false
paths, err := expandPattern(p.Pattern) paths, captures, err := expandPattern(p.Pattern)
if err != nil { if err != nil {
if errors.Is(err, errUnsupportableRegexpOperationForOpenAPI) { if errors.Is(err, errUnsupportableRegexpOperationForOpenAPI) {
// Pattern cannot be transformed into sensible OpenAPI paths. In this case, we override the later // Pattern cannot be transformed into sensible OpenAPI paths. In this case, we override the later
@@ -270,26 +270,14 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc *
// Process path and header parameters, which are common to all operations. // Process path and header parameters, which are common to all operations.
// Body fields will be added to individual operations. // Body fields will be added to individual operations.
pathFields, bodyFields := splitFields(p.Fields, path) pathFields, queryFields, bodyFields := splitFields(p.Fields, path, captures)
for name, field := range pathFields { for name, field := range pathFields {
location := "path"
required := true
if field == nil {
continue
}
if field.Query {
location = "query"
required = false
}
t := convertType(field.Type) t := convertType(field.Type)
p := OASParameter{ p := OASParameter{
Name: name, Name: name,
Description: cleanString(field.Description), Description: cleanString(field.Description),
In: location, In: "path",
Schema: &OASSchema{ Schema: &OASSchema{
Type: t.baseType, Type: t.baseType,
Pattern: t.pattern, Pattern: t.pattern,
@@ -297,7 +285,7 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc *
Default: field.Default, Default: field.Default,
DisplayAttrs: withoutOperationHints(field.DisplayAttrs), DisplayAttrs: withoutOperationHints(field.DisplayAttrs),
}, },
Required: required, Required: true,
Deprecated: field.Deprecated, Deprecated: field.Deprecated,
} }
pi.Parameters = append(pi.Parameters, p) pi.Parameters = append(pi.Parameters, p)
@@ -342,8 +330,12 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc *
op.Deprecated = props.Deprecated op.Deprecated = props.Deprecated
op.OperationID = operationID op.OperationID = operationID
// Add any fields not present in the path as body parameters for POST. switch opType {
if opType == logical.CreateOperation || opType == logical.UpdateOperation { // For the operation types which map to POST/PUT methods, and so allow for request body parameters,
// prepare the request body definition
case logical.CreateOperation:
fallthrough
case logical.UpdateOperation:
s := &OASSchema{ s := &OASSchema{
Type: "object", Type: "object",
Properties: make(map[string]*OASSchema), Properties: make(map[string]*OASSchema),
@@ -357,27 +349,14 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc *
continue continue
} }
openapiField := convertType(field.Type) addFieldToOASSchema(s, name, field)
if field.Required { }
s.Required = append(s.Required, name)
}
p := OASSchema{ // Contrary to what one might guess, fields marked with "Query: true" are only query fields when the
Type: openapiField.baseType, // request method is one which does not allow for a request body - they are still body fields when
Description: cleanString(field.Description), // dealing with a POST/PUT request.
Format: openapiField.format, for name, field := range queryFields {
Pattern: openapiField.pattern, addFieldToOASSchema(s, name, field)
Enum: field.AllowedValues,
Default: field.Default,
Deprecated: field.Deprecated,
DisplayAttrs: withoutOperationHints(field.DisplayAttrs),
}
if openapiField.baseType == "array" {
p.Items = &OASSchema{
Type: openapiField.items,
}
}
s.Properties[name] = &p
} }
// Make the ordering deterministic, so that the generated OpenAPI spec document, observed over several // Make the ordering deterministic, so that the generated OpenAPI spec document, observed over several
@@ -426,12 +405,12 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc *
}, },
} }
} }
}
// LIST is represented as GET with a `list` query parameter. Code later on in this function will assign // For the operation types which map to HTTP methods without a request body, populate query parameters
// list operations to a path with an extra trailing slash, ensuring they do not collide with read case logical.ListOperation:
// operations. // LIST is represented as GET with a `list` query parameter. Code later on in this function will assign
if opType == logical.ListOperation { // list operations to a path with an extra trailing slash, ensuring they do not collide with read
// operations.
op.Parameters = append(op.Parameters, OASParameter{ op.Parameters = append(op.Parameters, OASParameter{
Name: "list", Name: "list",
Description: "Must be set to `true`", Description: "Must be set to `true`",
@@ -439,6 +418,27 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc *
In: "query", In: "query",
Schema: &OASSchema{Type: "string", Enum: []interface{}{"true"}}, Schema: &OASSchema{Type: "string", Enum: []interface{}{"true"}},
}) })
fallthrough
case logical.DeleteOperation:
fallthrough
case logical.ReadOperation:
for name, field := range queryFields {
t := convertType(field.Type)
p := OASParameter{
Name: name,
Description: cleanString(field.Description),
In: "query",
Schema: &OASSchema{
Type: t.baseType,
Pattern: t.pattern,
Enum: field.AllowedValues,
Default: field.Default,
DisplayAttrs: withoutOperationHints(field.DisplayAttrs),
},
Deprecated: field.Deprecated,
}
op.Parameters = append(op.Parameters, p)
}
} }
// Add tags based on backend type // Add tags based on backend type
@@ -612,6 +612,31 @@ func documentPath(p *Path, backend *Backend, requestResponsePrefix string, doc *
return nil return nil
} }
func addFieldToOASSchema(s *OASSchema, name string, field *FieldSchema) {
openapiField := convertType(field.Type)
if field.Required {
s.Required = append(s.Required, name)
}
p := OASSchema{
Type: openapiField.baseType,
Description: cleanString(field.Description),
Format: openapiField.format,
Pattern: openapiField.pattern,
Enum: field.AllowedValues,
Default: field.Default,
Deprecated: field.Deprecated,
DisplayAttrs: withoutOperationHints(field.DisplayAttrs),
}
if openapiField.baseType == "array" {
p.Items = &OASSchema{
Type: openapiField.items,
}
}
s.Properties[name] = &p
}
// specialPathMatch checks whether the given path matches one of the special // specialPathMatch checks whether the given path matches one of the special
// paths, taking into account * and + wildcards (e.g. foo/+/bar/*) // paths, taking into account * and + wildcards (e.g. foo/+/bar/*)
func specialPathMatch(path string, specialPaths []string) bool { func specialPathMatch(path string, specialPaths []string) bool {
@@ -776,8 +801,9 @@ func constructOperationID(
} }
// expandPattern expands a regex pattern by generating permutations of any optional parameters // expandPattern expands a regex pattern by generating permutations of any optional parameters
// and changing named parameters into their {openapi} equivalents. // and changing named parameters into their {openapi} equivalents. It also returns the names of all capturing groups
func expandPattern(pattern string) ([]string, error) { // observed in the pattern.
func expandPattern(pattern string) (paths []string, captures map[string]struct{}, err error) {
// Happily, the Go regexp library exposes its underlying "parse to AST" functionality, so we can rely on that to do // Happily, the Go regexp library exposes its underlying "parse to AST" functionality, so we can rely on that to do
// the hard work of interpreting the regexp syntax. // the hard work of interpreting the regexp syntax.
rx, err := syntax.Parse(pattern, syntax.Perl) rx, err := syntax.Parse(pattern, syntax.Perl)
@@ -787,12 +813,12 @@ func expandPattern(pattern string) ([]string, error) {
panic(err) panic(err)
} }
paths, err := collectPathsFromRegexpAST(rx) paths, captures, err = collectPathsFromRegexpAST(rx)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return paths, nil return paths, captures, nil
} }
type pathCollector struct { type pathCollector struct {
@@ -813,23 +839,28 @@ type pathCollector struct {
// //
// Each named capture group - i.e. (?P<name>something here) - is replaced with an OpenAPI parameter - i.e. {name} - and // Each named capture group - i.e. (?P<name>something here) - is replaced with an OpenAPI parameter - i.e. {name} - and
// the subtree of regexp AST inside the parameter is completely skipped. // the subtree of regexp AST inside the parameter is completely skipped.
func collectPathsFromRegexpAST(rx *syntax.Regexp) ([]string, error) { func collectPathsFromRegexpAST(rx *syntax.Regexp) (paths []string, captures map[string]struct{}, err error) {
pathCollectors, err := collectPathsFromRegexpASTInternal(rx, []*pathCollector{{}}) captures = make(map[string]struct{})
pathCollectors, err := collectPathsFromRegexpASTInternal(rx, []*pathCollector{{}}, captures)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
paths := make([]string, 0, len(pathCollectors)) paths = make([]string, 0, len(pathCollectors))
for _, collector := range pathCollectors { for _, collector := range pathCollectors {
if collector.conditionalSlashAppendedAtLength != collector.Len() { if collector.conditionalSlashAppendedAtLength != collector.Len() {
paths = append(paths, collector.String()) paths = append(paths, collector.String())
} }
} }
return paths, nil return paths, captures, nil
} }
var errUnsupportableRegexpOperationForOpenAPI = errors.New("path regexp uses an operation that cannot be translated to an OpenAPI pattern") var errUnsupportableRegexpOperationForOpenAPI = errors.New("path regexp uses an operation that cannot be translated to an OpenAPI pattern")
func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCollector) ([]*pathCollector, error) { func collectPathsFromRegexpASTInternal(
rx *syntax.Regexp,
appendingTo []*pathCollector,
captures map[string]struct{},
) ([]*pathCollector, error) {
var err error var err error
// Depending on the type of this regexp AST node (its Op, i.e. operation), figure out whether it contributes any // Depending on the type of this regexp AST node (its Op, i.e. operation), figure out whether it contributes any
@@ -856,7 +887,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol
// those pieces. // those pieces.
case syntax.OpConcat: case syntax.OpConcat:
for _, child := range rx.Sub { for _, child := range rx.Sub {
appendingTo, err = collectPathsFromRegexpASTInternal(child, appendingTo) appendingTo, err = collectPathsFromRegexpASTInternal(child, appendingTo, captures)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -887,7 +918,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol
childAppendingTo = append(childAppendingTo, newCollector) childAppendingTo = append(childAppendingTo, newCollector)
} }
} }
childAppendingTo, err = collectPathsFromRegexpASTInternal(child, childAppendingTo) childAppendingTo, err = collectPathsFromRegexpASTInternal(child, childAppendingTo, captures)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -905,7 +936,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol
newCollector.conditionalSlashAppendedAtLength = collector.conditionalSlashAppendedAtLength newCollector.conditionalSlashAppendedAtLength = collector.conditionalSlashAppendedAtLength
childAppendingTo = append(childAppendingTo, newCollector) childAppendingTo = append(childAppendingTo, newCollector)
} }
childAppendingTo, err = collectPathsFromRegexpASTInternal(child, childAppendingTo) childAppendingTo, err = collectPathsFromRegexpASTInternal(child, childAppendingTo, captures)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -927,7 +958,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol
// In Vault, an unnamed capturing group is not actually used for capturing. // In Vault, an unnamed capturing group is not actually used for capturing.
// We treat it exactly the same as OpConcat. // We treat it exactly the same as OpConcat.
for _, child := range rx.Sub { for _, child := range rx.Sub {
appendingTo, err = collectPathsFromRegexpASTInternal(child, appendingTo) appendingTo, err = collectPathsFromRegexpASTInternal(child, appendingTo, captures)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -940,6 +971,7 @@ func collectPathsFromRegexpASTInternal(rx *syntax.Regexp, appendingTo []*pathCol
builder.WriteString(rx.Name) builder.WriteString(rx.Name)
builder.WriteRune('}') builder.WriteRune('}')
} }
captures[rx.Name] = struct{}{}
} }
// Any other kind of operation is a problem, and will trigger an error, resulting in the pattern being left out of // Any other kind of operation is a problem, and will trigger an error, resulting in the pattern being left out of
@@ -1041,29 +1073,37 @@ func cleanString(s string) string {
return s return s
} }
// splitFields partitions fields into path and body groups // splitFields partitions fields into path, query and body groups. It uses information on capturing groups previously
// The input pattern is expected to have been run through expandPattern, // collected by expandPattern, which is necessary to correctly match the treatment in (*Backend).HandleRequest:
// with paths parameters denotes in {braces}. // a field counts as a path field if it appears in any capture in the regex, and if that capture was inside an
func splitFields(allFields map[string]*FieldSchema, pattern string) (pathFields, bodyFields map[string]*FieldSchema) { // alternation or optional part of the regex which does not survive in the OpenAPI path pattern currently being
// processed, that field should NOT be rendered to the OpenAPI spec AT ALL.
func splitFields(
allFields map[string]*FieldSchema,
openAPIPathPattern string,
captures map[string]struct{},
) (pathFields, queryFields, bodyFields map[string]*FieldSchema) {
pathFields = make(map[string]*FieldSchema) pathFields = make(map[string]*FieldSchema)
queryFields = make(map[string]*FieldSchema)
bodyFields = make(map[string]*FieldSchema) bodyFields = make(map[string]*FieldSchema)
for _, match := range pathFieldsRe.FindAllStringSubmatch(pattern, -1) { for _, match := range pathFieldsRe.FindAllStringSubmatch(openAPIPathPattern, -1) {
name := match[1] name := match[1]
pathFields[name] = allFields[name] pathFields[name] = allFields[name]
} }
for name, field := range allFields { for name, field := range allFields {
if _, ok := pathFields[name]; !ok { // Any field which relates to a regex capture was already processed above, if it needed to be.
if _, ok := captures[name]; !ok {
if field.Query { if field.Query {
pathFields[name] = field queryFields[name] = field
} else { } else {
bodyFields[name] = field bodyFields[name] = field
} }
} }
} }
return pathFields, bodyFields return pathFields, queryFields, bodyFields
} }
// withoutOperationHints returns a copy of the given DisplayAttributes without // withoutOperationHints returns a copy of the given DisplayAttributes without

View File

@@ -160,13 +160,13 @@ func TestOpenAPI_ExpandPattern(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
out, err := expandPattern(test.inPattern) paths, _, err := expandPattern(test.inPattern)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
sort.Strings(out) sort.Strings(paths)
if !reflect.DeepEqual(out, test.outPathlets) { if !reflect.DeepEqual(paths, test.outPathlets) {
t.Fatalf("Test %d: Expected %v got %v", i, test.outPathlets, out) t.Fatalf("Test %d: Expected %v got %v", i, test.outPathlets, paths)
} }
} }
} }
@@ -188,7 +188,7 @@ func TestOpenAPI_ExpandPattern_ReturnsError(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
_, err := expandPattern(test.inPattern) _, _, err := expandPattern(test.inPattern)
if err != test.outError { if err != test.outError {
t.Fatalf("Test %d: Expected %q got %q", i, test.outError, err) t.Fatalf("Test %d: Expected %q got %q", i, test.outError, err)
} }
@@ -196,31 +196,50 @@ func TestOpenAPI_ExpandPattern_ReturnsError(t *testing.T) {
} }
func TestOpenAPI_SplitFields(t *testing.T) { func TestOpenAPI_SplitFields(t *testing.T) {
paths, captures, err := expandPattern("some/" + GenericNameRegex("a") + "/path" + OptionalParamRegex("e"))
if err != nil {
t.Fatal(err)
}
fields := map[string]*FieldSchema{ fields := map[string]*FieldSchema{
"a": {Description: "path"}, "a": {Description: "path"},
"b": {Description: "body"}, "b": {Description: "body"},
"c": {Description: "body"}, "c": {Description: "body"},
"d": {Description: "body"}, "d": {Description: "body"},
"e": {Description: "path"}, "e": {Description: "path"},
"f": {Description: "query", Query: true},
} }
pathFields, bodyFields := splitFields(fields, "some/{a}/path/{e}") for index, path := range paths {
pathFields, queryFields, bodyFields := splitFields(fields, path, captures)
lp := len(pathFields) numPath := len(pathFields)
lb := len(bodyFields) numQuery := len(queryFields)
l := len(fields) numBody := len(bodyFields)
if lp+lb != l { numExpectedDiscarded := 0
t.Fatalf("split length error: %d + %d != %d", lp, lb, l) // The first path generated is expected to be the one omitting the optional parameter field "e"
} if index == 0 {
numExpectedDiscarded = 1
for name, field := range pathFields {
if field.Description != "path" {
t.Fatalf("expected field %s to be in 'path', found in %s", name, field.Description)
} }
} l := len(fields)
for name, field := range bodyFields { if numPath+numQuery+numBody+numExpectedDiscarded != l {
if field.Description != "body" { t.Fatalf("split length error: %d + %d + %d + %d != %d", numPath, numQuery, numBody, numExpectedDiscarded, l)
t.Fatalf("expected field %s to be in 'body', found in %s", name, field.Description) }
for name, field := range pathFields {
if field.Description != "path" {
t.Fatalf("expected field %s to be in 'path', found in %s", name, field.Description)
}
}
for name, field := range queryFields {
if field.Description != "query" {
t.Fatalf("expected field %s to be in 'query', found in %s", name, field.Description)
}
}
for name, field := range bodyFields {
if field.Description != "body" {
t.Fatalf("expected field %s to be in 'body', found in %s", name, field.Description)
}
} }
} }
} }