mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-04 04:28:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			467 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			467 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) HashiCorp, Inc.
 | 
						|
// SPDX-License-Identifier: MPL-2.0
 | 
						|
 | 
						|
package framework
 | 
						|
 | 
						|
import (
 | 
						|
	"encoding/base64"
 | 
						|
	"encoding/json"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"regexp"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"github.com/hashicorp/errwrap"
 | 
						|
	"github.com/hashicorp/go-secure-stdlib/parseutil"
 | 
						|
	"github.com/hashicorp/go-secure-stdlib/strutil"
 | 
						|
	"github.com/hashicorp/vault/sdk/helper/jsonutil"
 | 
						|
	"github.com/mitchellh/mapstructure"
 | 
						|
)
 | 
						|
 | 
						|
// FieldData is the structure passed to the callback to handle a path
 | 
						|
// containing the populated parameters for fields. This should be used
 | 
						|
// instead of the raw (*vault.Request).Data to access data in a type-safe
 | 
						|
// way.
 | 
						|
type FieldData struct {
 | 
						|
	Raw    map[string]interface{}
 | 
						|
	Schema map[string]*FieldSchema
 | 
						|
}
 | 
						|
 | 
						|
// Validate cycles through raw data and validates conversions in
 | 
						|
// the schema, so we don't get an error/panic later when
 | 
						|
// trying to get data out.  Data not in the schema is not
 | 
						|
// an error at this point, so we don't worry about it.
 | 
						|
func (d *FieldData) Validate() error {
 | 
						|
	for field, value := range d.Raw {
 | 
						|
 | 
						|
		schema, ok := d.Schema[field]
 | 
						|
		if !ok {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		switch schema.Type {
 | 
						|
		case TypeBool, TypeInt, TypeInt64, TypeMap, TypeDurationSecond, TypeSignedDurationSecond, TypeString,
 | 
						|
			TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice,
 | 
						|
			TypeKVPairs, TypeCommaIntSlice, TypeHeader, TypeFloat, TypeTime:
 | 
						|
			_, _, err := d.getPrimitive(field, schema)
 | 
						|
			if err != nil {
 | 
						|
				return errwrap.Wrapf(fmt.Sprintf("error converting input %v for field %q: {{err}}", value, field), err)
 | 
						|
			}
 | 
						|
		default:
 | 
						|
			return fmt.Errorf("unknown field type %q for field %q", schema.Type, field)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// ValidateStrict cycles through raw data and validates conversions in the
 | 
						|
// schema. In addition to the checks done by Validate, this function ensures
 | 
						|
// that the raw data has all of the schema's required fields and does not
 | 
						|
// have any fields outside of the schema. It will return a non-nil error if:
 | 
						|
//
 | 
						|
//  1. a conversion (parsing of the field's value) fails
 | 
						|
//  2. a raw field does not exist in the schema (unless the schema is nil)
 | 
						|
//  3. a required schema field is missing from the raw data
 | 
						|
//
 | 
						|
// This function is currently used for validating response schemas in tests.
 | 
						|
func (d *FieldData) ValidateStrict() error {
 | 
						|
	// the schema is nil, nothing to validate
 | 
						|
	if d.Schema == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	for field := range d.Raw {
 | 
						|
		if _, _, err := d.GetOkErr(field); err != nil {
 | 
						|
			return fmt.Errorf("field %q: %w", field, err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	for field, schema := range d.Schema {
 | 
						|
		if !schema.Required {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		if _, ok := d.Raw[field]; !ok {
 | 
						|
			return fmt.Errorf("missing required field %q", field)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Get gets the value for the given field. If the key is an invalid field,
 | 
						|
// FieldData will panic. If you want a safer version of this method, use
 | 
						|
// GetOk. If the field k is not set, the default value (if set) will be
 | 
						|
// returned, otherwise the zero value will be returned.
 | 
						|
func (d *FieldData) Get(k string) interface{} {
 | 
						|
	schema, ok := d.Schema[k]
 | 
						|
	if !ok {
 | 
						|
		panic(fmt.Sprintf("field %s not in the schema", k))
 | 
						|
	}
 | 
						|
 | 
						|
	// If the value can't be decoded, use the zero or default value for the field
 | 
						|
	// type
 | 
						|
	value, ok := d.GetOk(k)
 | 
						|
	if !ok || value == nil {
 | 
						|
		value = schema.DefaultOrZero()
 | 
						|
	}
 | 
						|
 | 
						|
	return value
 | 
						|
}
 | 
						|
 | 
						|
// GetDefaultOrZero gets the default value set on the schema for the given
 | 
						|
// field. If there is no default value set, the zero value of the type
 | 
						|
// will be returned.
 | 
						|
func (d *FieldData) GetDefaultOrZero(k string) interface{} {
 | 
						|
	schema, ok := d.Schema[k]
 | 
						|
	if !ok {
 | 
						|
		panic(fmt.Sprintf("field %s not in the schema", k))
 | 
						|
	}
 | 
						|
 | 
						|
	return schema.DefaultOrZero()
 | 
						|
}
 | 
						|
 | 
						|
// GetFirst gets the value for the given field names, in order from first
 | 
						|
// to last. This can be useful for fields with a current name, and one or
 | 
						|
// more deprecated names. The second return value will be false if the keys
 | 
						|
// are invalid or the keys are not set at all.
 | 
						|
func (d *FieldData) GetFirst(k ...string) (interface{}, bool) {
 | 
						|
	for _, v := range k {
 | 
						|
		if result, ok := d.GetOk(v); ok {
 | 
						|
			return result, ok
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil, false
 | 
						|
}
 | 
						|
 | 
						|
// GetOk gets the value for the given field. The second return value will be
 | 
						|
// false if the key is invalid or the key is not set at all. If the field k is
 | 
						|
// set and the decoded value is nil, the default or zero value
 | 
						|
// will be returned instead.
 | 
						|
func (d *FieldData) GetOk(k string) (interface{}, bool) {
 | 
						|
	schema, ok := d.Schema[k]
 | 
						|
	if !ok {
 | 
						|
		return nil, false
 | 
						|
	}
 | 
						|
 | 
						|
	result, ok, err := d.GetOkErr(k)
 | 
						|
	if err != nil {
 | 
						|
		panic(fmt.Sprintf("error reading %s: %s", k, err))
 | 
						|
	}
 | 
						|
 | 
						|
	if ok && result == nil {
 | 
						|
		result = schema.DefaultOrZero()
 | 
						|
	}
 | 
						|
 | 
						|
	return result, ok
 | 
						|
}
 | 
						|
 | 
						|
// GetOkErr is the most conservative of all the Get methods. It returns
 | 
						|
// whether key is set or not, but also an error value. The error value is
 | 
						|
// non-nil if the field doesn't exist or there was an error parsing the
 | 
						|
// field value.
 | 
						|
func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) {
 | 
						|
	schema, ok := d.Schema[k]
 | 
						|
	if !ok {
 | 
						|
		return nil, false, fmt.Errorf("unknown field: %q", k)
 | 
						|
	}
 | 
						|
 | 
						|
	switch schema.Type {
 | 
						|
	case TypeBool, TypeInt, TypeInt64, TypeMap, TypeDurationSecond, TypeSignedDurationSecond, TypeString,
 | 
						|
		TypeLowerCaseString, TypeNameString, TypeSlice, TypeStringSlice, TypeCommaStringSlice,
 | 
						|
		TypeKVPairs, TypeCommaIntSlice, TypeHeader, TypeFloat, TypeTime:
 | 
						|
		return d.getPrimitive(k, schema)
 | 
						|
	default:
 | 
						|
		return nil, false,
 | 
						|
			fmt.Errorf("unknown field type %q for field %q", schema.Type, k)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (d *FieldData) getPrimitive(k string, schema *FieldSchema) (interface{}, bool, error) {
 | 
						|
	raw, ok := d.Raw[k]
 | 
						|
	if !ok {
 | 
						|
		return nil, false, nil
 | 
						|
	}
 | 
						|
 | 
						|
	switch t := schema.Type; t {
 | 
						|
	case TypeBool:
 | 
						|
		var result bool
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeInt:
 | 
						|
		var result int
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeInt64:
 | 
						|
		var result int64
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeFloat:
 | 
						|
		var result float64
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeString:
 | 
						|
		var result string
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeLowerCaseString:
 | 
						|
		var result string
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		return strings.ToLower(result), true, nil
 | 
						|
 | 
						|
	case TypeNameString:
 | 
						|
		var result string
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		matched, err := regexp.MatchString("^\\w(([\\w-.]+)?\\w)?$", result)
 | 
						|
		if err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		if !matched {
 | 
						|
			return nil, false, errors.New("field does not match the formatting rules")
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeMap:
 | 
						|
		var result map[string]interface{}
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeDurationSecond, TypeSignedDurationSecond:
 | 
						|
		var result int
 | 
						|
		switch inp := raw.(type) {
 | 
						|
		case nil:
 | 
						|
			return nil, false, nil
 | 
						|
		default:
 | 
						|
			dur, err := parseutil.ParseDurationSecond(inp)
 | 
						|
			if err != nil {
 | 
						|
				return nil, false, err
 | 
						|
			}
 | 
						|
			result = int(dur.Seconds())
 | 
						|
		}
 | 
						|
		if t == TypeDurationSecond && result < 0 {
 | 
						|
			return nil, false, fmt.Errorf("cannot provide negative value '%d'", result)
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeTime:
 | 
						|
		switch inp := raw.(type) {
 | 
						|
		case nil:
 | 
						|
			// Handle nil interface{} as a non-error case
 | 
						|
			return nil, false, nil
 | 
						|
		default:
 | 
						|
			time, err := parseutil.ParseAbsoluteTime(inp)
 | 
						|
			if err != nil {
 | 
						|
				return nil, false, err
 | 
						|
			}
 | 
						|
			return time.UTC(), true, nil
 | 
						|
		}
 | 
						|
 | 
						|
	case TypeCommaIntSlice:
 | 
						|
		var result []int
 | 
						|
 | 
						|
		jsonIn, ok := raw.(json.Number)
 | 
						|
		if ok {
 | 
						|
			raw = jsonIn.String()
 | 
						|
		}
 | 
						|
 | 
						|
		config := &mapstructure.DecoderConfig{
 | 
						|
			Result:           &result,
 | 
						|
			WeaklyTypedInput: true,
 | 
						|
			DecodeHook:       mapstructure.StringToSliceHookFunc(","),
 | 
						|
		}
 | 
						|
		decoder, err := mapstructure.NewDecoder(config)
 | 
						|
		if err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		if err := decoder.Decode(raw); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		if len(result) == 0 {
 | 
						|
			return make([]int, 0), true, nil
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeSlice:
 | 
						|
		var result []interface{}
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		if len(result) == 0 {
 | 
						|
			return make([]interface{}, 0), true, nil
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeStringSlice:
 | 
						|
		rawString, ok := raw.(string)
 | 
						|
		if ok && rawString == "" {
 | 
						|
			return []string{}, true, nil
 | 
						|
		}
 | 
						|
 | 
						|
		var result []string
 | 
						|
		if err := mapstructure.WeakDecode(raw, &result); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		if len(result) == 0 {
 | 
						|
			return make([]string, 0), true, nil
 | 
						|
		}
 | 
						|
		return strutil.TrimStrings(result), true, nil
 | 
						|
 | 
						|
	case TypeCommaStringSlice:
 | 
						|
		res, err := parseutil.ParseCommaStringSlice(raw)
 | 
						|
		if err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
		return res, true, nil
 | 
						|
 | 
						|
	case TypeKVPairs:
 | 
						|
		// First try to parse this as a map
 | 
						|
		var mapResult map[string]string
 | 
						|
		if err := mapstructure.WeakDecode(raw, &mapResult); err == nil {
 | 
						|
			return mapResult, true, nil
 | 
						|
		}
 | 
						|
 | 
						|
		// If map parse fails, parse as a string list of = delimited pairs
 | 
						|
		var listResult []string
 | 
						|
		if err := mapstructure.WeakDecode(raw, &listResult); err != nil {
 | 
						|
			return nil, false, err
 | 
						|
		}
 | 
						|
 | 
						|
		result := make(map[string]string, len(listResult))
 | 
						|
		for _, keyPair := range listResult {
 | 
						|
			keyPairSlice := strings.SplitN(keyPair, "=", 2)
 | 
						|
			if len(keyPairSlice) != 2 || keyPairSlice[0] == "" {
 | 
						|
				return nil, false, fmt.Errorf("invalid key pair %q", keyPair)
 | 
						|
			}
 | 
						|
			result[keyPairSlice[0]] = keyPairSlice[1]
 | 
						|
		}
 | 
						|
		return result, true, nil
 | 
						|
 | 
						|
	case TypeHeader:
 | 
						|
		/*
 | 
						|
 | 
						|
			There are multiple ways a header could be provided:
 | 
						|
 | 
						|
			1.	As a map[string]interface{} that resolves to a map[string]string or map[string][]string, or a mix of both
 | 
						|
				because that's permitted for headers.
 | 
						|
				This mainly comes from the API.
 | 
						|
 | 
						|
			2.	As a string...
 | 
						|
				a. That contains JSON that originally was JSON, but then was base64 encoded.
 | 
						|
				b. That contains JSON, ex. `{"content-type":"text/json","accept":["encoding/json"]}`.
 | 
						|
				This mainly comes from the API and is used to save space while sending in the header.
 | 
						|
 | 
						|
			3.	As an array of strings that contains comma-delimited key-value pairs associated via a colon,
 | 
						|
				ex: `content-type:text/json`,`accept:encoding/json`.
 | 
						|
				This mainly comes from the CLI.
 | 
						|
 | 
						|
			We go through these sequentially below.
 | 
						|
 | 
						|
		*/
 | 
						|
		result := http.Header{}
 | 
						|
 | 
						|
		toHeader := func(resultMap map[string]interface{}) (http.Header, error) {
 | 
						|
			header := http.Header{}
 | 
						|
			for headerKey, headerValGroup := range resultMap {
 | 
						|
				switch typedHeader := headerValGroup.(type) {
 | 
						|
				case string:
 | 
						|
					header.Add(headerKey, typedHeader)
 | 
						|
				case []string:
 | 
						|
					for _, headerVal := range typedHeader {
 | 
						|
						header.Add(headerKey, headerVal)
 | 
						|
					}
 | 
						|
				case json.Number:
 | 
						|
					header.Add(headerKey, typedHeader.String())
 | 
						|
				case []interface{}:
 | 
						|
					for _, headerVal := range typedHeader {
 | 
						|
						switch typedHeader := headerVal.(type) {
 | 
						|
						case string:
 | 
						|
							header.Add(headerKey, typedHeader)
 | 
						|
						case json.Number:
 | 
						|
							header.Add(headerKey, typedHeader.String())
 | 
						|
						default:
 | 
						|
							// All header values should already be strings when they're being sent in.
 | 
						|
							// Even numbers and booleans will be treated as strings.
 | 
						|
							return nil, fmt.Errorf("received non-string value for header key:%s, val:%s", headerKey, headerValGroup)
 | 
						|
						}
 | 
						|
					}
 | 
						|
				default:
 | 
						|
					return nil, fmt.Errorf("unrecognized type for %s", headerValGroup)
 | 
						|
				}
 | 
						|
			}
 | 
						|
			return header, nil
 | 
						|
		}
 | 
						|
 | 
						|
		resultMap := make(map[string]interface{})
 | 
						|
 | 
						|
		// 1. Are we getting a map from the API?
 | 
						|
		if err := mapstructure.WeakDecode(raw, &resultMap); err == nil {
 | 
						|
			result, err = toHeader(resultMap)
 | 
						|
			if err != nil {
 | 
						|
				return nil, false, err
 | 
						|
			}
 | 
						|
			return result, true, nil
 | 
						|
		}
 | 
						|
 | 
						|
		// 2. Are we getting a JSON string?
 | 
						|
		if headerStr, ok := raw.(string); ok {
 | 
						|
			// a. Is it base64 encoded?
 | 
						|
			headerBytes, err := base64.StdEncoding.DecodeString(headerStr)
 | 
						|
			if err != nil {
 | 
						|
				// b. It's not base64 encoded, it's a straight-out JSON string.
 | 
						|
				headerBytes = []byte(headerStr)
 | 
						|
			}
 | 
						|
			if err := jsonutil.DecodeJSON(headerBytes, &resultMap); err != nil {
 | 
						|
				return nil, false, err
 | 
						|
			}
 | 
						|
			result, err = toHeader(resultMap)
 | 
						|
			if err != nil {
 | 
						|
				return nil, false, err
 | 
						|
			}
 | 
						|
			return result, true, nil
 | 
						|
		}
 | 
						|
 | 
						|
		// 3. Are we getting an array of fields like "content-type:encoding/json" from the CLI?
 | 
						|
		var keyPairs []interface{}
 | 
						|
		if err := mapstructure.WeakDecode(raw, &keyPairs); err == nil {
 | 
						|
			for _, keyPairIfc := range keyPairs {
 | 
						|
				keyPair, ok := keyPairIfc.(string)
 | 
						|
				if !ok {
 | 
						|
					return nil, false, fmt.Errorf("invalid key pair %q", keyPair)
 | 
						|
				}
 | 
						|
				keyPairSlice := strings.SplitN(keyPair, ":", 2)
 | 
						|
				if len(keyPairSlice) != 2 || keyPairSlice[0] == "" {
 | 
						|
					return nil, false, fmt.Errorf("invalid key pair %q", keyPair)
 | 
						|
				}
 | 
						|
				result.Add(keyPairSlice[0], keyPairSlice[1])
 | 
						|
			}
 | 
						|
			return result, true, nil
 | 
						|
		}
 | 
						|
		return nil, false, fmt.Errorf("%s not provided an expected format", raw)
 | 
						|
 | 
						|
	default:
 | 
						|
		panic(fmt.Sprintf("Unknown type: %s", schema.Type))
 | 
						|
	}
 | 
						|
}
 |