mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-10-31 02:08:13 +00:00 
			
		
		
		
	Merge pull request #126359 from jpbetz/quantity-estimated-cost
Fix estimated cost for Kubernetes defined CEL types for equals
This commit is contained in:
		| @@ -2010,14 +2010,14 @@ func TestCelEstimatedCostStability(t *testing.T) { | |||||||
| 				`isQuantity(self.val2)`: 314575, | 				`isQuantity(self.val2)`: 314575, | ||||||
| 				`isQuantity("200M")`:    1, | 				`isQuantity("200M")`:    1, | ||||||
| 				`isQuantity("20Mi")`:    1, | 				`isQuantity("20Mi")`:    1, | ||||||
| 				`quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`:                                           uint64(3689348814741910532), | 				`quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`:                                           uint64(6), | ||||||
| 				`quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`: uint64(5534023222112865798), | 				`quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`: uint64(9), | ||||||
| 				`quantity(self.val1).isLessThan(quantity(self.val2))`:                                                                    629151, | 				`quantity(self.val1).isLessThan(quantity(self.val2))`:                                                                    629151, | ||||||
| 				`quantity("50M").isLessThan(quantity("100M"))`:                                                                           3, | 				`quantity("50M").isLessThan(quantity("100M"))`:                                                                           3, | ||||||
| 				`quantity("50Mi").isGreaterThan(quantity("50M"))`:                                                                        3, | 				`quantity("50Mi").isGreaterThan(quantity("50M"))`:                                                                        3, | ||||||
| 				`quantity("200M").compareTo(quantity("0.2G")) == 0`:                                                                      4, | 				`quantity("200M").compareTo(quantity("0.2G")) == 0`:                                                                      4, | ||||||
| 				`quantity("50k").add(quantity("20")) == quantity("50.02k")`:                                                              uint64(1844674407370955268), | 				`quantity("50k").add(quantity("20")) == quantity("50.02k")`:                                                              uint64(5), | ||||||
| 				`quantity("50k").sub(20) == quantity("49980")`:                                                                           uint64(1844674407370955267), | 				`quantity("50k").sub(20) == quantity("49980")`:                                                                           uint64(4), | ||||||
| 				`quantity("50").isInteger()`:                                                                                             2, | 				`quantity("50").isInteger()`:                                                                                             2, | ||||||
| 				`quantity(self.val1).isInteger()`:                                                                                        314576, | 				`quantity(self.val1).isInteger()`:                                                                                        314576, | ||||||
| 			}, | 			}, | ||||||
|   | |||||||
| @@ -19,6 +19,7 @@ package library | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math" | 	"math" | ||||||
|  | 	"reflect" | ||||||
|  |  | ||||||
| 	"github.com/google/cel-go/checker" | 	"github.com/google/cel-go/checker" | ||||||
| 	"github.com/google/cel-go/common" | 	"github.com/google/cel-go/common" | ||||||
| @@ -27,6 +28,7 @@ import ( | |||||||
| 	"github.com/google/cel-go/common/types/ref" | 	"github.com/google/cel-go/common/types/ref" | ||||||
| 	"github.com/google/cel-go/common/types/traits" | 	"github.com/google/cel-go/common/types/traits" | ||||||
|  |  | ||||||
|  | 	"k8s.io/apimachinery/pkg/util/sets" | ||||||
| 	"k8s.io/apiserver/pkg/cel" | 	"k8s.io/apiserver/pkg/cel" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -48,6 +50,22 @@ var knownUnhandledFunctions = map[string]bool{ | |||||||
| 	"strings.quote": true, | 	"strings.quote": true, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // TODO: Replace this with a utility that extracts types from libraries. | ||||||
|  | var knownKubernetesRuntimeTypes = sets.New[reflect.Type]( | ||||||
|  | 	reflect.ValueOf(cel.URL{}).Type(), | ||||||
|  | 	reflect.ValueOf(cel.IP{}).Type(), | ||||||
|  | 	reflect.ValueOf(cel.CIDR{}).Type(), | ||||||
|  | 	reflect.ValueOf(&cel.Format{}).Type(), | ||||||
|  | 	reflect.ValueOf(cel.Quantity{}).Type(), | ||||||
|  | ) | ||||||
|  | var knownKubernetesCompilerTypes = sets.New[ref.Type]( | ||||||
|  | 	cel.CIDRType, | ||||||
|  | 	cel.IPType, | ||||||
|  | 	cel.FormatType, | ||||||
|  | 	cel.QuantityType, | ||||||
|  | 	cel.URLType, | ||||||
|  | ) | ||||||
|  |  | ||||||
| // CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator. | // CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator. | ||||||
| type CostEstimator struct { | type CostEstimator struct { | ||||||
| 	// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation | 	// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation | ||||||
| @@ -235,6 +253,27 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re | |||||||
| 		// url accessors | 		// url accessors | ||||||
| 		cost := uint64(1) | 		cost := uint64(1) | ||||||
| 		return &cost | 		return &cost | ||||||
|  | 	case "_==_": | ||||||
|  | 		if len(args) == 2 { | ||||||
|  | 			unitCost := uint64(1) | ||||||
|  | 			lhs := args[0] | ||||||
|  | 			switch lhs.(type) { | ||||||
|  | 			case cel.Quantity: | ||||||
|  | 				return &unitCost | ||||||
|  | 			case cel.IP: | ||||||
|  | 				return &unitCost | ||||||
|  | 			case cel.CIDR: | ||||||
|  | 				return &unitCost | ||||||
|  | 			case *cel.Format: // Formats have a small max size. | ||||||
|  | 				return &unitCost | ||||||
|  | 			case cel.URL: // TODO: Computing the actual cost is expensive, and changing this would be a breaking change | ||||||
|  | 				return &unitCost | ||||||
|  | 			default: | ||||||
|  | 				if panicOnUnknown && knownKubernetesRuntimeTypes.Has(reflect.ValueOf(lhs).Type()) { | ||||||
|  | 					panic(fmt.Errorf("CallCost: unhandled equality for Kubernetes type %T", lhs)) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	if panicOnUnknown && !knownUnhandledFunctions[function] { | 	if panicOnUnknown && !knownUnhandledFunctions[function] { | ||||||
| 		panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args)) | 		panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args)) | ||||||
| @@ -278,7 +317,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch | |||||||
| 	case "url": | 	case "url": | ||||||
| 		if len(args) == 1 { | 		if len(args) == 1 { | ||||||
| 			sz := l.sizeEstimate(args[0]) | 			sz := l.sizeEstimate(args[0]) | ||||||
| 			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} | 			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz} | ||||||
| 		} | 		} | ||||||
| 	case "lowerAscii", "upperAscii", "substring", "trim": | 	case "lowerAscii", "upperAscii", "substring", "trim": | ||||||
| 		if target != nil { | 		if target != nil { | ||||||
| @@ -475,6 +514,39 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch | |||||||
| 	case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery": | 	case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery": | ||||||
| 		// url accessors | 		// url accessors | ||||||
| 		return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} | 		return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} | ||||||
|  | 	case "_==_": | ||||||
|  | 		if len(args) == 2 { | ||||||
|  | 			lhs := args[0] | ||||||
|  | 			rhs := args[1] | ||||||
|  | 			if lhs.Type().Equal(rhs.Type()) == types.True { | ||||||
|  | 				t := lhs.Type() | ||||||
|  | 				if t.Kind() == types.OpaqueKind { | ||||||
|  | 					switch t.TypeName() { | ||||||
|  | 					case cel.IPType.TypeName(), cel.CIDRType.TypeName(): | ||||||
|  | 						return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				if t.Kind() == types.StructKind { | ||||||
|  | 					switch t { | ||||||
|  | 					case cel.QuantityType: // O(1) cost equality checks | ||||||
|  | 						return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} | ||||||
|  | 					case cel.FormatType: | ||||||
|  | 						return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: cel.MaxFormatSize}.MultiplyByCostFactor(common.StringTraversalCostFactor)} | ||||||
|  | 					case cel.URLType: | ||||||
|  | 						size := checker.SizeEstimate{Min: 1, Max: 1} | ||||||
|  | 						rhSize := rhs.ComputedSize() | ||||||
|  | 						lhSize := rhs.ComputedSize() | ||||||
|  | 						if rhSize != nil && lhSize != nil { | ||||||
|  | 							size = rhSize.Union(*lhSize) | ||||||
|  | 						} | ||||||
|  | 						return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: size.Max}.MultiplyByCostFactor(common.StringTraversalCostFactor)} | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				if panicOnUnknown && knownKubernetesCompilerTypes.Has(t) { | ||||||
|  | 					panic(fmt.Errorf("EstimateCallCost: unhandled equality for Kubernetes type %v", t)) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| 	if panicOnUnknown && !knownUnhandledFunctions[function] { | 	if panicOnUnknown && !knownUnhandledFunctions[function] { | ||||||
| 		panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args)) | 		panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args)) | ||||||
|   | |||||||
| @@ -206,6 +206,16 @@ func TestURLsCost(t *testing.T) { | |||||||
| 			expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4}, | 			expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4}, | ||||||
| 			expectRuntimeCost:  4, | 			expectRuntimeCost:  4, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ops:                []string{" == url('https:://kubernetes.io/')"}, | ||||||
|  | 			expectEsimatedCost: checker.CostEstimate{Min: 7, Max: 9}, | ||||||
|  | 			expectRuntimeCost:  7, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ops:                []string{" == url('http://x.b')"}, | ||||||
|  | 			expectEsimatedCost: checker.CostEstimate{Min: 5, Max: 5}, | ||||||
|  | 			expectRuntimeCost:  5, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, tc := range cases { | 	for _, tc := range cases { | ||||||
| @@ -245,6 +255,14 @@ func TestIPCost(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 			expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, | 			expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ops: []string{" == ip('192.168.0.1')"}, | ||||||
|  | 			// For most other operations, the cost is expected to be the base + 1. | ||||||
|  | 			expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { | ||||||
|  | 				return c.Add(ipv4BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1}) | ||||||
|  | 			}, | ||||||
|  | 			expectRuntimeCost: func(c uint64) uint64 { return c + ipv4BaseRuntimeCost + 1 }, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, tc := range testCases { | 	for _, tc := range testCases { | ||||||
| @@ -320,6 +338,14 @@ func TestCIDRCost(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 			expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, | 			expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ops: []string{" == cidr('2001:db8::/32')"}, | ||||||
|  | 			// For most other operations, the cost is expected to be the base + 1. | ||||||
|  | 			expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { | ||||||
|  | 				return c.Add(ipv6BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1}) | ||||||
|  | 			}, | ||||||
|  | 			expectRuntimeCost: func(c uint64) uint64 { return c + ipv6BaseRuntimeCost + 1 }, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	//nolint:gocritic | 	//nolint:gocritic | ||||||
| @@ -708,19 +734,19 @@ func TestQuantityCost(t *testing.T) { | |||||||
| 		{ | 		{ | ||||||
| 			name:                "equality_reflexivity", | 			name:                "equality_reflexivity", | ||||||
| 			expr:                `quantity("200M") == quantity("200M")`, | 			expr:                `quantity("200M") == quantity("200M")`, | ||||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 1844674407370955266}, | 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, | ||||||
| 			expectRuntimeCost:   3, | 			expectRuntimeCost:   3, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:                "equality_symmetry", | 			name:                "equality_symmetry", | ||||||
| 			expr:                `quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`, | 			expr:                `quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`, | ||||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3689348814741910532}, | 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 6}, | ||||||
| 			expectRuntimeCost:   6, | 			expectRuntimeCost:   6, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:                "equality_transitivity", | 			name:                "equality_transitivity", | ||||||
| 			expr:                `quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`, | 			expr:                `quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`, | ||||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 5534023222112865798}, | 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 9}, | ||||||
| 			expectRuntimeCost:   9, | 			expectRuntimeCost:   9, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| @@ -744,19 +770,19 @@ func TestQuantityCost(t *testing.T) { | |||||||
| 		{ | 		{ | ||||||
| 			name:                "add_quantity", | 			name:                "add_quantity", | ||||||
| 			expr:                `quantity("50k").add(quantity("20")) == quantity("50.02k")`, | 			expr:                `quantity("50k").add(quantity("20")) == quantity("50.02k")`, | ||||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268}, | 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, | ||||||
| 			expectRuntimeCost:   5, | 			expectRuntimeCost:   5, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:                "sub_quantity", | 			name:                "sub_quantity", | ||||||
| 			expr:                `quantity("50k").sub(quantity("20")) == quantity("49.98k")`, | 			expr:                `quantity("50k").sub(quantity("20")) == quantity("49.98k")`, | ||||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268}, | 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, | ||||||
| 			expectRuntimeCost:   5, | 			expectRuntimeCost:   5, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:                "sub_int", | 			name:                "sub_int", | ||||||
| 			expr:                `quantity("50k").sub(20) == quantity("49980")`, | 			expr:                `quantity("50k").sub(20) == quantity("49980")`, | ||||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 1844674407370955267}, | 			expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 4}, | ||||||
| 			expectRuntimeCost:   4, | 			expectRuntimeCost:   4, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| @@ -825,6 +851,18 @@ func TestNameFormatCost(t *testing.T) { | |||||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34}, | 			expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34}, | ||||||
| 			expectRuntimeCost:   10, | 			expectRuntimeCost:   10, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:                "format.dns1123label.validate", | ||||||
|  | 			expr:                `format.named("dns1123Label").value().validate("my-name")`, | ||||||
|  | 			expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34}, | ||||||
|  | 			expectRuntimeCost:   10, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:                "format.dns1123label.validate", | ||||||
|  | 			expr:                `format.named("dns1123Label").value() == format.named("dns1123Label").value()`, | ||||||
|  | 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 11}, | ||||||
|  | 			expectRuntimeCost:   5, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, tc := range cases { | 	for _, tc := range cases { | ||||||
|   | |||||||
| @@ -22,6 +22,8 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/google/cel-go/common/types" | 	"github.com/google/cel-go/common/types" | ||||||
| 	"github.com/google/cel-go/common/types/ref" | 	"github.com/google/cel-go/common/types/ref" | ||||||
|  |  | ||||||
|  | 	"k8s.io/apiserver/pkg/cel" | ||||||
| 	"k8s.io/apiserver/pkg/cel/library" | 	"k8s.io/apiserver/pkg/cel/library" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -228,3 +230,11 @@ func TestFormat(t *testing.T) { | |||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestSizeLimit(t *testing.T) { | ||||||
|  | 	for name := range library.ConstantFormats { | ||||||
|  | 		if len(name) > cel.MaxFormatSize { | ||||||
|  | 			t.Fatalf("All formats must be <= %d chars in length", cel.MaxFormatSize) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -48,5 +48,7 @@ const ( | |||||||
| 	// MinNumberSize is the length of literal 0 | 	// MinNumberSize is the length of literal 0 | ||||||
| 	MinNumberSize = 1 | 	MinNumberSize = 1 | ||||||
|  |  | ||||||
|  | 	// MaxFormatSize is the maximum size we allow for format strings | ||||||
|  | 	MaxFormatSize          = 64 | ||||||
| 	MaxNameFormatRegexSize = 128 | 	MaxNameFormatRegexSize = 128 | ||||||
| ) | ) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Kubernetes Prow Robot
					Kubernetes Prow Robot