mirror of
				https://github.com/optim-enterprises-bv/kubernetes.git
				synced 2025-10-30 17:58:14 +00:00 
			
		
		
		
	Merge pull request #126359 from jpbetz/quantity-estimated-cost
Fix estimated cost for Kubernetes defined CEL types for equals
This commit is contained in:
		| @@ -2010,14 +2010,14 @@ func TestCelEstimatedCostStability(t *testing.T) { | ||||
| 				`isQuantity(self.val2)`: 314575, | ||||
| 				`isQuantity("200M")`:    1, | ||||
| 				`isQuantity("20Mi")`:    1, | ||||
| 				`quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`:                                           uint64(3689348814741910532), | ||||
| 				`quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`: uint64(5534023222112865798), | ||||
| 				`quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`:                                           uint64(6), | ||||
| 				`quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`: uint64(9), | ||||
| 				`quantity(self.val1).isLessThan(quantity(self.val2))`:                                                                    629151, | ||||
| 				`quantity("50M").isLessThan(quantity("100M"))`:                                                                           3, | ||||
| 				`quantity("50Mi").isGreaterThan(quantity("50M"))`:                                                                        3, | ||||
| 				`quantity("200M").compareTo(quantity("0.2G")) == 0`:                                                                      4, | ||||
| 				`quantity("50k").add(quantity("20")) == quantity("50.02k")`:                                                              uint64(1844674407370955268), | ||||
| 				`quantity("50k").sub(20) == quantity("49980")`:                                                                           uint64(1844674407370955267), | ||||
| 				`quantity("50k").add(quantity("20")) == quantity("50.02k")`:                                                              uint64(5), | ||||
| 				`quantity("50k").sub(20) == quantity("49980")`:                                                                           uint64(4), | ||||
| 				`quantity("50").isInteger()`:                                                                                             2, | ||||
| 				`quantity(self.val1).isInteger()`:                                                                                        314576, | ||||
| 			}, | ||||
|   | ||||
| @@ -19,6 +19,7 @@ package library | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"reflect" | ||||
|  | ||||
| 	"github.com/google/cel-go/checker" | ||||
| 	"github.com/google/cel-go/common" | ||||
| @@ -27,6 +28,7 @@ import ( | ||||
| 	"github.com/google/cel-go/common/types/ref" | ||||
| 	"github.com/google/cel-go/common/types/traits" | ||||
|  | ||||
| 	"k8s.io/apimachinery/pkg/util/sets" | ||||
| 	"k8s.io/apiserver/pkg/cel" | ||||
| ) | ||||
|  | ||||
| @@ -48,6 +50,22 @@ var knownUnhandledFunctions = map[string]bool{ | ||||
| 	"strings.quote": true, | ||||
| } | ||||
|  | ||||
| // TODO: Replace this with a utility that extracts types from libraries. | ||||
| var knownKubernetesRuntimeTypes = sets.New[reflect.Type]( | ||||
| 	reflect.ValueOf(cel.URL{}).Type(), | ||||
| 	reflect.ValueOf(cel.IP{}).Type(), | ||||
| 	reflect.ValueOf(cel.CIDR{}).Type(), | ||||
| 	reflect.ValueOf(&cel.Format{}).Type(), | ||||
| 	reflect.ValueOf(cel.Quantity{}).Type(), | ||||
| ) | ||||
| var knownKubernetesCompilerTypes = sets.New[ref.Type]( | ||||
| 	cel.CIDRType, | ||||
| 	cel.IPType, | ||||
| 	cel.FormatType, | ||||
| 	cel.QuantityType, | ||||
| 	cel.URLType, | ||||
| ) | ||||
|  | ||||
| // CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator. | ||||
| type CostEstimator struct { | ||||
| 	// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation | ||||
| @@ -235,6 +253,27 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re | ||||
| 		// url accessors | ||||
| 		cost := uint64(1) | ||||
| 		return &cost | ||||
| 	case "_==_": | ||||
| 		if len(args) == 2 { | ||||
| 			unitCost := uint64(1) | ||||
| 			lhs := args[0] | ||||
| 			switch lhs.(type) { | ||||
| 			case cel.Quantity: | ||||
| 				return &unitCost | ||||
| 			case cel.IP: | ||||
| 				return &unitCost | ||||
| 			case cel.CIDR: | ||||
| 				return &unitCost | ||||
| 			case *cel.Format: // Formats have a small max size. | ||||
| 				return &unitCost | ||||
| 			case cel.URL: // TODO: Computing the actual cost is expensive, and changing this would be a breaking change | ||||
| 				return &unitCost | ||||
| 			default: | ||||
| 				if panicOnUnknown && knownKubernetesRuntimeTypes.Has(reflect.ValueOf(lhs).Type()) { | ||||
| 					panic(fmt.Errorf("CallCost: unhandled equality for Kubernetes type %T", lhs)) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	if panicOnUnknown && !knownUnhandledFunctions[function] { | ||||
| 		panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args)) | ||||
| @@ -278,7 +317,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch | ||||
| 	case "url": | ||||
| 		if len(args) == 1 { | ||||
| 			sz := l.sizeEstimate(args[0]) | ||||
| 			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} | ||||
| 			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz} | ||||
| 		} | ||||
| 	case "lowerAscii", "upperAscii", "substring", "trim": | ||||
| 		if target != nil { | ||||
| @@ -475,6 +514,39 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch | ||||
| 	case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery": | ||||
| 		// url accessors | ||||
| 		return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} | ||||
| 	case "_==_": | ||||
| 		if len(args) == 2 { | ||||
| 			lhs := args[0] | ||||
| 			rhs := args[1] | ||||
| 			if lhs.Type().Equal(rhs.Type()) == types.True { | ||||
| 				t := lhs.Type() | ||||
| 				if t.Kind() == types.OpaqueKind { | ||||
| 					switch t.TypeName() { | ||||
| 					case cel.IPType.TypeName(), cel.CIDRType.TypeName(): | ||||
| 						return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} | ||||
| 					} | ||||
| 				} | ||||
| 				if t.Kind() == types.StructKind { | ||||
| 					switch t { | ||||
| 					case cel.QuantityType: // O(1) cost equality checks | ||||
| 						return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} | ||||
| 					case cel.FormatType: | ||||
| 						return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: cel.MaxFormatSize}.MultiplyByCostFactor(common.StringTraversalCostFactor)} | ||||
| 					case cel.URLType: | ||||
| 						size := checker.SizeEstimate{Min: 1, Max: 1} | ||||
| 						rhSize := rhs.ComputedSize() | ||||
| 						lhSize := rhs.ComputedSize() | ||||
| 						if rhSize != nil && lhSize != nil { | ||||
| 							size = rhSize.Union(*lhSize) | ||||
| 						} | ||||
| 						return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: size.Max}.MultiplyByCostFactor(common.StringTraversalCostFactor)} | ||||
| 					} | ||||
| 				} | ||||
| 				if panicOnUnknown && knownKubernetesCompilerTypes.Has(t) { | ||||
| 					panic(fmt.Errorf("EstimateCallCost: unhandled equality for Kubernetes type %v", t)) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	if panicOnUnknown && !knownUnhandledFunctions[function] { | ||||
| 		panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args)) | ||||
|   | ||||
| @@ -206,6 +206,16 @@ func TestURLsCost(t *testing.T) { | ||||
| 			expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4}, | ||||
| 			expectRuntimeCost:  4, | ||||
| 		}, | ||||
| 		{ | ||||
| 			ops:                []string{" == url('https:://kubernetes.io/')"}, | ||||
| 			expectEsimatedCost: checker.CostEstimate{Min: 7, Max: 9}, | ||||
| 			expectRuntimeCost:  7, | ||||
| 		}, | ||||
| 		{ | ||||
| 			ops:                []string{" == url('http://x.b')"}, | ||||
| 			expectEsimatedCost: checker.CostEstimate{Min: 5, Max: 5}, | ||||
| 			expectRuntimeCost:  5, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| @@ -245,6 +255,14 @@ func TestIPCost(t *testing.T) { | ||||
| 			}, | ||||
| 			expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, | ||||
| 		}, | ||||
| 		{ | ||||
| 			ops: []string{" == ip('192.168.0.1')"}, | ||||
| 			// For most other operations, the cost is expected to be the base + 1. | ||||
| 			expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { | ||||
| 				return c.Add(ipv4BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1}) | ||||
| 			}, | ||||
| 			expectRuntimeCost: func(c uint64) uint64 { return c + ipv4BaseRuntimeCost + 1 }, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| @@ -320,6 +338,14 @@ func TestCIDRCost(t *testing.T) { | ||||
| 			}, | ||||
| 			expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, | ||||
| 		}, | ||||
| 		{ | ||||
| 			ops: []string{" == cidr('2001:db8::/32')"}, | ||||
| 			// For most other operations, the cost is expected to be the base + 1. | ||||
| 			expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { | ||||
| 				return c.Add(ipv6BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1}) | ||||
| 			}, | ||||
| 			expectRuntimeCost: func(c uint64) uint64 { return c + ipv6BaseRuntimeCost + 1 }, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	//nolint:gocritic | ||||
| @@ -708,19 +734,19 @@ func TestQuantityCost(t *testing.T) { | ||||
| 		{ | ||||
| 			name:                "equality_reflexivity", | ||||
| 			expr:                `quantity("200M") == quantity("200M")`, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 1844674407370955266}, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, | ||||
| 			expectRuntimeCost:   3, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:                "equality_symmetry", | ||||
| 			expr:                `quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3689348814741910532}, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 6}, | ||||
| 			expectRuntimeCost:   6, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:                "equality_transitivity", | ||||
| 			expr:                `quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 5534023222112865798}, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 9}, | ||||
| 			expectRuntimeCost:   9, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -744,19 +770,19 @@ func TestQuantityCost(t *testing.T) { | ||||
| 		{ | ||||
| 			name:                "add_quantity", | ||||
| 			expr:                `quantity("50k").add(quantity("20")) == quantity("50.02k")`, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268}, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, | ||||
| 			expectRuntimeCost:   5, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:                "sub_quantity", | ||||
| 			expr:                `quantity("50k").sub(quantity("20")) == quantity("49.98k")`, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268}, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, | ||||
| 			expectRuntimeCost:   5, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:                "sub_int", | ||||
| 			expr:                `quantity("50k").sub(20) == quantity("49980")`, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 1844674407370955267}, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 4}, | ||||
| 			expectRuntimeCost:   4, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -825,6 +851,18 @@ func TestNameFormatCost(t *testing.T) { | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34}, | ||||
| 			expectRuntimeCost:   10, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:                "format.dns1123label.validate", | ||||
| 			expr:                `format.named("dns1123Label").value().validate("my-name")`, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34}, | ||||
| 			expectRuntimeCost:   10, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:                "format.dns1123label.validate", | ||||
| 			expr:                `format.named("dns1123Label").value() == format.named("dns1123Label").value()`, | ||||
| 			expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 11}, | ||||
| 			expectRuntimeCost:   5, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
|   | ||||
| @@ -22,6 +22,8 @@ import ( | ||||
|  | ||||
| 	"github.com/google/cel-go/common/types" | ||||
| 	"github.com/google/cel-go/common/types/ref" | ||||
|  | ||||
| 	"k8s.io/apiserver/pkg/cel" | ||||
| 	"k8s.io/apiserver/pkg/cel/library" | ||||
| ) | ||||
|  | ||||
| @@ -228,3 +230,11 @@ func TestFormat(t *testing.T) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestSizeLimit(t *testing.T) { | ||||
| 	for name := range library.ConstantFormats { | ||||
| 		if len(name) > cel.MaxFormatSize { | ||||
| 			t.Fatalf("All formats must be <= %d chars in length", cel.MaxFormatSize) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -48,5 +48,7 @@ const ( | ||||
| 	// MinNumberSize is the length of literal 0 | ||||
| 	MinNumberSize = 1 | ||||
|  | ||||
| 	// MaxFormatSize is the maximum size we allow for format strings | ||||
| 	MaxFormatSize          = 64 | ||||
| 	MaxNameFormatRegexSize = 128 | ||||
| ) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Kubernetes Prow Robot
					Kubernetes Prow Robot