Bump cel-go to v0.10.0

This commit is contained in:
Joe Betz
2022-03-07 20:47:04 -05:00
parent f93be6584e
commit 2a6b85c395
66 changed files with 3332 additions and 817 deletions

View File

@@ -17,8 +17,10 @@ go_library(
"evalstate.go",
"interpretable.go",
"interpreter.go",
"optimizations.go",
"planner.go",
"prune.go",
"runtimecost.go",
],
importpath = "github.com/google/cel-go/interpreter",
deps = [

View File

@@ -157,14 +157,23 @@ func (q *stringQualifier) QualifierValueEquals(value interface{}) bool {
// QualifierValueEquals implementation for int qualifiers.
func (q *intQualifier) QualifierValueEquals(value interface{}) bool {
ival, ok := value.(int64)
return ok && q.value == ival
return numericValueEquals(value, q.celValue)
}
// QualifierValueEquals implementation for uint qualifiers.
func (q *uintQualifier) QualifierValueEquals(value interface{}) bool {
uval, ok := value.(uint64)
return ok && q.value == uval
return numericValueEquals(value, q.celValue)
}
// QualifierValueEquals implementation for double qualifiers.
func (q *doubleQualifier) QualifierValueEquals(value interface{}) bool {
return numericValueEquals(value, q.celValue)
}
// numericValueEquals uses CEL equality to determine whether two number values are
func numericValueEquals(value interface{}, celValue ref.Val) bool {
val := types.DefaultTypeAdapter.NativeToValue(value)
return celValue.Equal(val) == types.True
}
// NewPartialAttributeFactory returns an AttributeFactory implementation capable of performing
@@ -348,7 +357,8 @@ func (m *attributeMatcher) Resolve(vars Activation) (interface{}, error) {
// the standard Resolve logic applies.
func (m *attributeMatcher) TryResolve(vars Activation) (interface{}, bool, error) {
id := m.NamespacedAttribute.ID()
partial, isPartial := vars.(PartialActivation)
// Bug in how partial activation is resolved, should search parents as well.
partial, isPartial := toPartialActivation(vars)
if isPartial {
unk, err := m.fac.matchesUnknownPatterns(
partial,
@@ -381,3 +391,14 @@ func (m *attributeMatcher) Qualify(vars Activation, obj interface{}) (interface{
}
return qual.Qualify(vars, obj)
}
func toPartialActivation(vars Activation) (PartialActivation, bool) {
pv, ok := vars.(PartialActivation)
if ok {
return pv, true
}
if vars.Parent() != nil {
return toPartialActivation(vars.Parent())
}
return nil, false
}

View File

@@ -15,7 +15,6 @@
package interpreter
import (
"errors"
"fmt"
"math"
@@ -487,9 +486,7 @@ func (a *maybeAttribute) AddQualifier(qual Qualifier) (Attribute, error) {
}
}
// Next, ensure the most specific variable / type reference is searched first.
a.attrs = append([]NamespacedAttribute{
a.fac.AbsoluteAttribute(qual.ID(), augmentedNames...),
}, a.attrs...)
a.attrs = append([]NamespacedAttribute{a.fac.AbsoluteAttribute(qual.ID(), augmentedNames...)}, a.attrs...)
return a, nil
}
@@ -628,6 +625,10 @@ func newQualifier(adapter ref.TypeAdapter, id int64, v interface{}) (Qualifier,
qual = &uintQualifier{id: id, value: val, celValue: types.Uint(val), adapter: adapter}
case bool:
qual = &boolQualifier{id: id, value: val, celValue: types.Bool(val), adapter: adapter}
case float32:
qual = &doubleQualifier{id: id, value: float64(val), celValue: types.Double(val), adapter: adapter}
case float64:
qual = &doubleQualifier{id: id, value: val, celValue: types.Double(val), adapter: adapter}
case types.String:
qual = &stringQualifier{id: id, value: string(val), celValue: val, adapter: adapter}
case types.Int:
@@ -714,9 +715,6 @@ func (q *stringQualifier) Qualify(vars Activation, obj interface{}) (interface{}
if err != nil {
return nil, err
}
if types.IsUnknown(elem) {
return elem, nil
}
return elem, nil
}
if isMap && !isKey {
@@ -829,9 +827,6 @@ func (q *intQualifier) Qualify(vars Activation, obj interface{}) (interface{}, e
if err != nil {
return nil, err
}
if types.IsUnknown(elem) {
return elem, nil
}
return elem, nil
}
if isMap && !isKey {
@@ -891,9 +886,6 @@ func (q *uintQualifier) Qualify(vars Activation, obj interface{}) (interface{},
if err != nil {
return nil, err
}
if types.IsUnknown(elem) {
return elem, nil
}
return elem, nil
}
if isMap && !isKey {
@@ -942,9 +934,6 @@ func (q *boolQualifier) Qualify(vars Activation, obj interface{}) (interface{},
if err != nil {
return nil, err
}
if types.IsUnknown(elem) {
return elem, nil
}
return elem, nil
}
if !isKey {
@@ -996,6 +985,37 @@ func (q *fieldQualifier) Cost() (min, max int64) {
return 0, 0
}
// doubleQualifier qualifies a CEL object, map, or list using a double value.
//
// This qualifier is used for working with dynamic data like JSON or protobuf.Any where the value
// type may not be known ahead of time and may not conform to the standard types supported as valid
// protobuf map key types.
type doubleQualifier struct {
id int64
value float64
celValue ref.Val
adapter ref.TypeAdapter
}
// ID is an implementation of the Qualifier interface method.
func (q *doubleQualifier) ID() int64 {
return q.id
}
// Qualify implements the Qualifier interface method.
func (q *doubleQualifier) Qualify(vars Activation, obj interface{}) (interface{}, error) {
switch o := obj.(type) {
case types.Unknown:
return o, nil
default:
elem, err := refResolve(q.adapter, q.celValue, obj)
if err != nil {
return nil, err
}
return elem, nil
}
}
// refResolve attempts to convert the value to a CEL value and then uses reflection methods
// to try and resolve the qualifier.
func refResolve(adapter ref.TypeAdapter, idx ref.Val, obj interface{}) (ref.Val, error) {
@@ -1006,9 +1026,6 @@ func refResolve(adapter ref.TypeAdapter, idx ref.Val, obj interface{}) (ref.Val,
if !found {
return nil, fmt.Errorf("no such key: %v", idx)
}
if types.IsError(elem) {
return nil, elem.(*types.Err)
}
return elem, nil
}
indexer, isIndexer := celVal.(traits.Indexer)
@@ -1028,5 +1045,5 @@ func refResolve(adapter ref.TypeAdapter, idx ref.Val, obj interface{}) (ref.Val,
if types.IsError(celVal) {
return nil, celVal.(*types.Err)
}
return nil, errors.New("no such overload")
return nil, fmt.Errorf("no such key: %v", idx)
}

View File

@@ -16,7 +16,11 @@ package interpreter
import "math"
// TODO: remove Coster.
// Coster calculates the heuristic cost incurred during evaluation.
// Deprecated: Please migrate cel.EstimateCost, it supports length estimates for input data and cost estimates for
// extension functions.
type Coster interface {
Cost() (min, max int64)
}

View File

@@ -25,11 +25,8 @@ import (
// Interpretable expression nodes at construction time.
type InterpretableDecorator func(Interpretable) (Interpretable, error)
// evalObserver is a functional interface that accepts an expression id and an observed value.
type evalObserver func(int64, ref.Val)
// decObserveEval records evaluation state into an EvalState object.
func decObserveEval(observer evalObserver) InterpretableDecorator {
func decObserveEval(observer EvalObserver) InterpretableDecorator {
return func(i Interpretable) (Interpretable, error) {
switch inst := i.(type) {
case *evalWatch, *evalWatchAttr, *evalWatchConst:
@@ -54,6 +51,19 @@ func decObserveEval(observer evalObserver) InterpretableDecorator {
}
}
// decInterruptFolds creates an intepretable decorator which marks comprehensions as interruptable
// where the interrupt state is communicated via a hidden variable on the Activation.
func decInterruptFolds() InterpretableDecorator {
return func(i Interpretable) (Interpretable, error) {
fold, ok := i.(*evalFold)
if !ok {
return i, nil
}
fold.interruptable = true
return fold, nil
}
}
// decDisableShortcircuits ensures that all branches of an expression will be evaluated, no short-circuiting.
func decDisableShortcircuits() InterpretableDecorator {
return func(i Interpretable) (Interpretable, error) {
@@ -71,16 +81,8 @@ func decDisableShortcircuits() InterpretableDecorator {
rhs: expr.rhs,
}, nil
case *evalFold:
return &evalExhaustiveFold{
id: expr.id,
accu: expr.accu,
accuVar: expr.accuVar,
iterRange: expr.iterRange,
iterVar: expr.iterVar,
cond: expr.cond,
step: expr.step,
result: expr.result,
}, nil
expr.exhaustive = true
return expr, nil
case InterpretableAttribute:
cond, isCond := expr.Attr().(*conditionalAttribute)
if isCond {
@@ -118,6 +120,48 @@ func decOptimize() InterpretableDecorator {
}
}
// decRegexOptimizer compiles regex pattern string constants.
func decRegexOptimizer(regexOptimizations ...*RegexOptimization) InterpretableDecorator {
functionMatchMap := make(map[string]*RegexOptimization)
overloadMatchMap := make(map[string]*RegexOptimization)
for _, m := range regexOptimizations {
functionMatchMap[m.Function] = m
if m.OverloadID != "" {
overloadMatchMap[m.OverloadID] = m
}
}
return func(i Interpretable) (Interpretable, error) {
call, ok := i.(InterpretableCall)
if !ok {
return i, nil
}
var matcher *RegexOptimization
var found bool
if call.OverloadID() != "" {
matcher, found = overloadMatchMap[call.OverloadID()]
}
if !found {
matcher, found = functionMatchMap[call.Function()]
}
if !found || matcher.RegexIndex >= len(call.Args()) {
return i, nil
}
args := call.Args()
regexArg := args[matcher.RegexIndex]
regexStr, isConst := regexArg.(InterpretableConst)
if !isConst {
return i, nil
}
pattern, ok := regexStr.Value().(types.String)
if !ok {
return i, nil
}
return matcher.Factory(call, string(pattern))
}
}
func maybeOptimizeConstUnary(i Interpretable, call InterpretableCall) (Interpretable, error) {
args := call.Args()
if len(args) != 1 {
@@ -177,7 +221,6 @@ func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Inte
return NewConstValue(inlist.ID(), types.False), nil
}
it := list.Iterator()
var typ ref.Type
valueSet := make(map[ref.Val]ref.Val)
for it.HasNext() == types.True {
elem := it.Next()
@@ -185,17 +228,44 @@ func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Inte
// Note, non-primitive type are not yet supported.
return i, nil
}
if typ == nil {
typ = elem.Type()
} else if typ.TypeName() != elem.Type().TypeName() {
return i, nil
}
valueSet[elem] = types.True
switch ev := elem.(type) {
case types.Double:
iv := ev.ConvertToType(types.IntType)
// Ensure that only lossless conversions are added to the set
if !types.IsError(iv) && iv.Equal(ev) == types.True {
valueSet[iv] = types.True
}
// Ensure that only lossless conversions are added to the set
uv := ev.ConvertToType(types.UintType)
if !types.IsError(uv) && uv.Equal(ev) == types.True {
valueSet[uv] = types.True
}
case types.Int:
dv := ev.ConvertToType(types.DoubleType)
if !types.IsError(dv) {
valueSet[dv] = types.True
}
uv := ev.ConvertToType(types.UintType)
if !types.IsError(uv) {
valueSet[uv] = types.True
}
case types.Uint:
dv := ev.ConvertToType(types.DoubleType)
if !types.IsError(dv) {
valueSet[dv] = types.True
}
iv := ev.ConvertToType(types.IntType)
if !types.IsError(iv) {
valueSet[iv] = types.True
}
default:
break
}
}
return &evalSetMembership{
inst: inlist,
arg: lhs,
argTypeName: typ.TypeName(),
valueSet: valueSet,
inst: inlist,
arg: lhs,
valueSet: valueSet,
}, nil
}

View File

@@ -100,8 +100,6 @@ func StandardOverloads() []*Overload {
return cmp
}},
// TODO: Verify overflow, NaN, underflow cases for numeric values.
// Add operator
{Operator: operators.Add,
OperandTrait: traits.AdderType,

View File

@@ -88,6 +88,18 @@ type InterpretableCall interface {
Args() []Interpretable
}
// InterpretableConstructor interface for inspecting Interpretable instructions that initialize a list, map
// or struct.
type InterpretableConstructor interface {
Interpretable
// InitVals returns all the list elements, map key and values or struct field values.
InitVals() []Interpretable
// Type returns the type constructed.
Type() ref.Type
}
// Core Interpretable implementations used during the program planning phase.
type evalTestOnly struct {
@@ -298,7 +310,13 @@ func (eq *evalEq) ID() int64 {
func (eq *evalEq) Eval(ctx Activation) ref.Val {
lVal := eq.lhs.Eval(ctx)
rVal := eq.rhs.Eval(ctx)
return lVal.Equal(rVal)
if types.IsUnknownOrError(lVal) {
return lVal
}
if types.IsUnknownOrError(rVal) {
return rVal
}
return types.Equal(lVal, rVal)
}
// Cost implements the Coster interface method.
@@ -336,12 +354,13 @@ func (ne *evalNe) ID() int64 {
func (ne *evalNe) Eval(ctx Activation) ref.Val {
lVal := ne.lhs.Eval(ctx)
rVal := ne.rhs.Eval(ctx)
eqVal := lVal.Equal(rVal)
eqBool, ok := eqVal.(types.Bool)
if !ok {
return types.ValOrErr(eqVal, "no such overload: _!=_")
if types.IsUnknownOrError(lVal) {
return lVal
}
return !eqBool
if types.IsUnknownOrError(rVal) {
return rVal
}
return types.Bool(types.Equal(lVal, rVal) != types.True)
}
// Cost implements the Coster interface method.
@@ -526,6 +545,17 @@ type evalVarArgs struct {
impl functions.FunctionOp
}
// NewCall creates a new call Interpretable.
func NewCall(id int64, function, overload string, args []Interpretable, impl functions.FunctionOp) InterpretableCall {
return &evalVarArgs{
id: id,
function: function,
overload: overload,
args: args,
impl: impl,
}
}
// ID implements the Interpretable interface method.
func (fn *evalVarArgs) ID() int64 {
return fn.id
@@ -603,6 +633,14 @@ func (l *evalList) Eval(ctx Activation) ref.Val {
return l.adapter.NativeToValue(elemVals)
}
func (l *evalList) InitVals() []Interpretable {
return l.elems
}
func (l *evalList) Type() ref.Type {
return types.ListType
}
// Cost implements the Coster interface method.
func (l *evalList) Cost() (min, max int64) {
return sumOfCost(l.elems)
@@ -638,6 +676,14 @@ func (m *evalMap) Eval(ctx Activation) ref.Val {
return m.adapter.NativeToValue(entries)
}
func (m *evalMap) InitVals() []Interpretable {
return append(m.keys, m.vals...)
}
func (m *evalMap) Type() ref.Type {
return types.MapType
}
// Cost implements the Coster interface method.
func (m *evalMap) Cost() (min, max int64) {
kMin, kMax := sumOfCost(m.keys)
@@ -672,6 +718,14 @@ func (o *evalObj) Eval(ctx Activation) ref.Val {
return o.provider.NewValue(o.typeName, fieldVals)
}
func (o *evalObj) InitVals() []Interpretable {
return o.vals
}
func (o *evalObj) Type() ref.Type {
return types.NewObjectTypeValue(o.typeName)
}
// Cost implements the Coster interface method.
func (o *evalObj) Cost() (min, max int64) {
return sumOfCost(o.vals)
@@ -688,14 +742,17 @@ func sumOfCost(interps []Interpretable) (min, max int64) {
}
type evalFold struct {
id int64
accuVar string
iterVar string
iterRange Interpretable
accu Interpretable
cond Interpretable
step Interpretable
result Interpretable
id int64
accuVar string
iterVar string
iterRange Interpretable
accu Interpretable
cond Interpretable
step Interpretable
result Interpretable
adapter ref.TypeAdapter
exhaustive bool
interruptable bool
}
// ID implements the Interpretable interface method.
@@ -714,9 +771,19 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val {
accuCtx.parent = ctx
accuCtx.name = fold.accuVar
accuCtx.val = fold.accu.Eval(ctx)
// If the accumulator starts as an empty list, then the comprehension will build a list
// so create a mutable list to optimize the cost of the inner loop.
l, ok := accuCtx.val.(traits.Lister)
buildingList := false
if !fold.exhaustive && ok && l.Size() == types.IntZero {
buildingList = true
accuCtx.val = types.NewMutableList(fold.adapter)
}
iterCtx := varActivationPool.Get().(*varActivation)
iterCtx.parent = accuCtx
iterCtx.name = fold.iterVar
interrupted := false
it := foldRange.(traits.Iterable).Iterator()
for it.HasNext() == types.True {
// Modify the iter var in the fold activation.
@@ -725,17 +792,31 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val {
// Evaluate the condition, terminate the loop if false.
cond := fold.cond.Eval(iterCtx)
condBool, ok := cond.(types.Bool)
if !types.IsUnknown(cond) && ok && condBool != types.True {
if !fold.exhaustive && ok && condBool != types.True {
break
}
// Evalute the evaluation step into accu var.
// Evaluate the evaluation step into accu var.
accuCtx.val = fold.step.Eval(iterCtx)
if fold.interruptable {
if stop, found := ctx.ResolveName("#interrupted"); found && stop == true {
interrupted = true
break
}
}
}
varActivationPool.Put(iterCtx)
if interrupted {
varActivationPool.Put(accuCtx)
return types.NewErr("operation interrupted")
}
// Compute the result.
res := fold.result.Eval(accuCtx)
varActivationPool.Put(iterCtx)
varActivationPool.Put(accuCtx)
// Convert a mutable list to an immutable one, if the comprehension has generated a list as a result.
if !types.IsUnknownOrError(res) && buildingList {
res = res.(traits.MutableLister).ToImmutableList()
}
return res
}
@@ -760,6 +841,10 @@ func (fold *evalFold) Cost() (min, max int64) {
cMin, cMax := estimateCost(fold.cond)
sMin, sMax := estimateCost(fold.step)
rMin, rMax := estimateCost(fold.result)
if fold.exhaustive {
cMin = cMin * rangeCnt
sMin = sMin * rangeCnt
}
// The cond and step costs are multiplied by size(iterRange). The minimum possible cost incurs
// when the evaluation result can be determined by the first iteration.
@@ -773,10 +858,9 @@ func (fold *evalFold) Cost() (min, max int64) {
// evalSetMembership is an Interpretable implementation which tests whether an input value
// exists within the set of map keys used to model a set.
type evalSetMembership struct {
inst Interpretable
arg Interpretable
argTypeName string
valueSet map[ref.Val]ref.Val
inst Interpretable
arg Interpretable
valueSet map[ref.Val]ref.Val
}
// ID implements the Interpretable interface method.
@@ -787,9 +871,6 @@ func (e *evalSetMembership) ID() int64 {
// Eval implements the Interpretable interface method.
func (e *evalSetMembership) Eval(ctx Activation) ref.Val {
val := e.arg.Eval(ctx)
if val.Type().TypeName() != e.argTypeName {
return types.ValOrErr(val, "no such overload")
}
if ret, found := e.valueSet[val]; found {
return ret
}
@@ -805,13 +886,13 @@ func (e *evalSetMembership) Cost() (min, max int64) {
// expression so that it may observe the computed value and send it to an observer.
type evalWatch struct {
Interpretable
observer evalObserver
observer EvalObserver
}
// Eval implements the Interpretable interface method.
func (e *evalWatch) Eval(ctx Activation) ref.Val {
val := e.Interpretable.Eval(ctx)
e.observer(e.ID(), val)
e.observer(e.ID(), e.Interpretable, val)
return val
}
@@ -826,7 +907,7 @@ func (e *evalWatch) Cost() (min, max int64) {
// must implement the instAttr interface by proxy.
type evalWatchAttr struct {
InterpretableAttribute
observer evalObserver
observer EvalObserver
}
// AddQualifier creates a wrapper over the incoming qualifier which observes the qualification
@@ -850,11 +931,23 @@ func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) {
return e, err
}
// Cost implements the Coster interface method.
func (e *evalWatchAttr) Cost() (min, max int64) {
return estimateCost(e.InterpretableAttribute)
}
// Eval implements the Interpretable interface method.
func (e *evalWatchAttr) Eval(vars Activation) ref.Val {
val := e.InterpretableAttribute.Eval(vars)
e.observer(e.ID(), e.InterpretableAttribute, val)
return val
}
// evalWatchConstQual observes the qualification of an object using a constant boolean, int,
// string, or uint.
type evalWatchConstQual struct {
ConstantQualifier
observer evalObserver
observer EvalObserver
adapter ref.TypeAdapter
}
@@ -872,7 +965,7 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj interface{}) (interfac
} else {
val = e.adapter.NativeToValue(out)
}
e.observer(e.ID(), val)
e.observer(e.ID(), e.ConstantQualifier, val)
return out, err
}
@@ -885,7 +978,7 @@ func (e *evalWatchConstQual) QualifierValueEquals(value interface{}) bool {
// evalWatchQual observes the qualification of an object by a value computed at runtime.
type evalWatchQual struct {
Qualifier
observer evalObserver
observer EvalObserver
adapter ref.TypeAdapter
}
@@ -903,32 +996,20 @@ func (e *evalWatchQual) Qualify(vars Activation, obj interface{}) (interface{},
} else {
val = e.adapter.NativeToValue(out)
}
e.observer(e.ID(), val)
e.observer(e.ID(), e.Qualifier, val)
return out, err
}
// Cost implements the Coster interface method.
func (e *evalWatchAttr) Cost() (min, max int64) {
return estimateCost(e.InterpretableAttribute)
}
// Eval implements the Interpretable interface method.
func (e *evalWatchAttr) Eval(vars Activation) ref.Val {
val := e.InterpretableAttribute.Eval(vars)
e.observer(e.ID(), val)
return val
}
// evalWatchConst describes a watcher of an instConst Interpretable.
type evalWatchConst struct {
InterpretableConst
observer evalObserver
observer EvalObserver
}
// Eval implements the Interpretable interface method.
func (e *evalWatchConst) Eval(vars Activation) ref.Val {
val := e.Value()
e.observer(e.ID(), val)
e.observer(e.ID(), e.InterpretableConst, val)
return val
}
@@ -1074,83 +1155,6 @@ func (cond *evalExhaustiveConditional) Cost() (min, max int64) {
return cond.attr.Cost()
}
// evalExhaustiveFold is like evalFold, but does not short-circuit argument evaluation.
type evalExhaustiveFold struct {
id int64
accuVar string
iterVar string
iterRange Interpretable
accu Interpretable
cond Interpretable
step Interpretable
result Interpretable
}
// ID implements the Interpretable interface method.
func (fold *evalExhaustiveFold) ID() int64 {
return fold.id
}
// Eval implements the Interpretable interface method.
func (fold *evalExhaustiveFold) Eval(ctx Activation) ref.Val {
foldRange := fold.iterRange.Eval(ctx)
if !foldRange.Type().HasTrait(traits.IterableType) {
return types.ValOrErr(foldRange, "got '%T', expected iterable type", foldRange)
}
// Configure the fold activation with the accumulator initial value.
accuCtx := varActivationPool.Get().(*varActivation)
accuCtx.parent = ctx
accuCtx.name = fold.accuVar
accuCtx.val = fold.accu.Eval(ctx)
iterCtx := varActivationPool.Get().(*varActivation)
iterCtx.parent = accuCtx
iterCtx.name = fold.iterVar
it := foldRange.(traits.Iterable).Iterator()
for it.HasNext() == types.True {
// Modify the iter var in the fold activation.
iterCtx.val = it.Next()
// Evaluate the condition, but don't terminate the loop as this is exhaustive eval!
fold.cond.Eval(iterCtx)
// Evalute the evaluation step into accu var.
accuCtx.val = fold.step.Eval(iterCtx)
}
// Compute the result.
res := fold.result.Eval(accuCtx)
varActivationPool.Put(iterCtx)
varActivationPool.Put(accuCtx)
return res
}
// Cost implements the Coster interface method.
func (fold *evalExhaustiveFold) Cost() (min, max int64) {
// Compute the cost for evaluating iterRange.
iMin, iMax := estimateCost(fold.iterRange)
// Compute the size of iterRange. If the size depends on the input, return the maximum possible
// cost range.
foldRange := fold.iterRange.Eval(EmptyActivation())
if !foldRange.Type().HasTrait(traits.IterableType) {
return 0, math.MaxInt64
}
var rangeCnt int64
it := foldRange.(traits.Iterable).Iterator()
for it.HasNext() == types.True {
it.Next()
rangeCnt++
}
aMin, aMax := estimateCost(fold.accu)
cMin, cMax := estimateCost(fold.cond)
sMin, sMax := estimateCost(fold.step)
rMin, rMax := estimateCost(fold.result)
// The cond and step costs are multiplied by size(iterRange).
return iMin + aMin + cMin*rangeCnt + sMin*rangeCnt + rMin,
iMax + aMax + cMax*rangeCnt + sMax*rangeCnt + rMax
}
// evalAttr evaluates an Attribute value.
type evalAttr struct {
adapter ref.TypeAdapter

View File

@@ -38,41 +38,118 @@ type Interpreter interface {
decorators ...InterpretableDecorator) (Interpretable, error)
}
// EvalObserver is a functional interface that accepts an expression id and an observed value.
// The id identifies the expression that was evaluated, the programStep is the Interpretable or Qualifier that
// was evaluated and value is the result of the evaluation.
type EvalObserver func(id int64, programStep interface{}, value ref.Val)
// Observe constructs a decorator that calls all the provided observers in order after evaluating each Interpretable
// or Qualifier during program evaluation.
func Observe(observers ...EvalObserver) InterpretableDecorator {
if len(observers) == 1 {
return decObserveEval(observers[0])
}
observeFn := func(id int64, programStep interface{}, val ref.Val) {
for _, observer := range observers {
observer(id, programStep, val)
}
}
return decObserveEval(observeFn)
}
// EvalCancelledError represents a cancelled program evaluation operation.
type EvalCancelledError struct {
Message string
// Type identifies the cause of the cancellation.
Cause CancellationCause
}
func (e EvalCancelledError) Error() string {
return e.Message
}
// CancellationCause enumerates the ways a program evaluation operation can be cancelled.
type CancellationCause int
const (
// ContextCancelled indicates that the operation was cancelled in response to a Golang context cancellation.
ContextCancelled CancellationCause = iota
// CostLimitExceeded indicates that the operation was cancelled in response to the actual cost limit being
// exceeded.
CostLimitExceeded
)
// TODO: Replace all usages of TrackState with EvalStateObserver
// TrackState decorates each expression node with an observer which records the value
// associated with the given expression id. EvalState must be provided to the decorator.
// This decorator is not thread-safe, and the EvalState must be reset between Eval()
// calls.
// DEPRECATED: Please use EvalStateObserver instead. It composes gracefully with additional observers.
func TrackState(state EvalState) InterpretableDecorator {
observer := func(id int64, val ref.Val) {
return Observe(EvalStateObserver(state))
}
// EvalStateObserver provides an observer which records the value
// associated with the given expression id. EvalState must be provided to the observer.
// This decorator is not thread-safe, and the EvalState must be reset between Eval()
// calls.
func EvalStateObserver(state EvalState) EvalObserver {
return func(id int64, programStep interface{}, val ref.Val) {
state.SetValue(id, val)
}
return decObserveEval(observer)
}
// TODO: Replace all usages of ExhaustiveEval with ExhaustiveEvalWrapper
// ExhaustiveEval replaces operations that short-circuit with versions that evaluate
// expressions and couples this behavior with the TrackState() decorator to provide
// insight into the evaluation state of the entire expression. EvalState must be
// provided to the decorator. This decorator is not thread-safe, and the EvalState
// must be reset between Eval() calls.
func ExhaustiveEval(state EvalState) InterpretableDecorator {
func ExhaustiveEval() InterpretableDecorator {
ex := decDisableShortcircuits()
obs := TrackState(state)
return func(i Interpretable) (Interpretable, error) {
var err error
i, err = ex(i)
if err != nil {
return nil, err
}
return obs(i)
return ex(i)
}
}
func InterruptableEval() InterpretableDecorator {
return decInterruptFolds()
}
// Optimize will pre-compute operations such as list and map construction and optimize
// call arguments to set membership tests. The set of optimizations will increase over time.
func Optimize() InterpretableDecorator {
return decOptimize()
}
// RegexOptimization provides a way to replace an InterpretableCall for a regex function when the
// RegexIndex argument is a string constant. Typically, the Factory would compile the regex pattern at
// RegexIndex and report any errors (at program creation time) and then use the compiled regex for
// all regex function invocations.
type RegexOptimization struct {
// Function is the name of the function to optimize.
Function string
// OverloadID is the ID of the overload to optimize.
OverloadID string
// RegexIndex is the index position of the regex pattern argument. Only calls to the function where this argument is
// a string constant will be delegated to this optimizer.
RegexIndex int
// Factory constructs a replacement InterpretableCall node that optimizes the regex function call. Factory is
// provided with the unoptimized regex call and the string constant at the RegexIndex argument.
// The Factory may compile the regex for use across all invocations of the call, return any errors and
// return an interpreter.NewCall with the desired regex optimized function impl.
Factory func(call InterpretableCall, regexPattern string) (InterpretableCall, error)
}
// CompileRegexConstants compiles regex pattern string constants at program creation time and reports any regex pattern
// compile errors.
func CompileRegexConstants(regexOptimizations ...*RegexOptimization) InterpretableDecorator {
return decRegexOptimizer(regexOptimizations...)
}
type exprInterpreter struct {
dispatcher Dispatcher
container *containers.Container

View File

@@ -0,0 +1,46 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
import (
"regexp"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// MatchesRegexOptimization optimizes the 'matches' standard library function by compiling the regex pattern and
// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function
// call invocations.
var MatchesRegexOptimization = &RegexOptimization{
Function: "matches",
RegexIndex: 1,
Factory: func(call InterpretableCall, regexPattern string) (InterpretableCall, error) {
compiledRegex, err := regexp.Compile(regexPattern)
if err != nil {
return nil, err
}
return NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(values ...ref.Val) ref.Val {
if len(values) != 2 {
return types.NoSuchOverloadErr()
}
in, ok := values[0].Value().(string)
if !ok {
return types.NoSuchOverloadErr()
}
return types.Bool(compiledRegex.MatchString(in))
}), nil
},
}

View File

@@ -617,6 +617,7 @@ func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) {
cond: cond,
step: step,
result: result,
adapter: p.adapter,
}, nil
}

View File

@@ -0,0 +1,192 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package interpreter provides functions to evaluate parsed expressions with
// the option to augment the evaluation with inputs and functions supplied at
// evaluation time.
package interpreter
import (
"math"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// WARNING: Any changes to cost calculations in this file require a corresponding change in checker/cost.go
// ActualCostEstimator provides function call cost estimations at runtime
// CallCost returns an estimated cost for the function overload invocation with the given args, or nil if it has no
// estimate to provide. CEL attempts to provide reasonable estimates for its standard function library, so CallCost
// should typically not need to provide an estimate for CELs standard function.
type ActualCostEstimator interface {
CallCost(overloadId string, args []ref.Val) *uint64
}
// CostObserver provides an observer that tracks runtime cost.
func CostObserver(tracker *CostTracker) EvalObserver {
observer := func(id int64, programStep interface{}, val ref.Val) {
switch t := programStep.(type) {
case ConstantQualifier:
// TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them
// and enable the below pop. Once enabled this can case can be collapsed into the Qualifier case.
//tracker.stack.pop(1)
tracker.cost += 1
case InterpretableConst:
// zero cost
case InterpretableAttribute:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
_, isConditional := t.Attr().(*conditionalAttribute)
if !isConditional {
tracker.cost += common.SelectAndIdentCost
}
case *evalExhaustiveConditional, *evalOr, *evalAnd, *evalExhaustiveOr, *evalExhaustiveAnd:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
case Qualifier:
tracker.stack.pop(1)
tracker.cost += 1
case InterpretableCall:
if argVals, ok := tracker.stack.pop(len(t.Args())); ok {
tracker.cost += tracker.costCall(t, argVals)
}
case InterpretableConstructor:
switch t.Type() {
case types.ListType:
tracker.cost += common.ListCreateBaseCost
case types.MapType:
tracker.cost += common.MapCreateBaseCost
default:
tracker.cost += common.StructCreateBaseCost
}
}
tracker.stack.push(val)
if tracker.Limit != nil && tracker.cost > *tracker.Limit {
panic(EvalCancelledError{Cause: CostLimitExceeded, Message: "operation cancelled: actual cost limit exceeded"})
}
}
return observer
}
// CostTracker represents the information needed for tacking runtime cost
type CostTracker struct {
Estimator ActualCostEstimator
Limit *uint64
cost uint64
stack refValStack
}
// ActualCost returns the runtime cost
func (c CostTracker) ActualCost() uint64 {
return c.cost
}
func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val) uint64 {
var cost uint64
if c.Estimator != nil {
callCost := c.Estimator.CallCost(call.OverloadID(), argValues)
if callCost != nil {
cost += *callCost
return cost
}
}
// if user didn't specify, the default way of calculating runtime cost would be used.
// if user has their own implementation of ActualCostEstimator, make sure to cover the mapping between overloadId and cost calculation
switch call.OverloadID() {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
// We just assume all list containment checks are O(n).
cost += c.actualSize(argValues[1])
// O(min(m, n)) functions
case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString,
overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes,
overloads.Equals, overloads.NotEquals:
// When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.),
// the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost
// of 1.
lhsSize := c.actualSize(argValues[0])
rhsSize := c.actualSize(argValues[1])
minSize := lhsSize
if rhsSize < minSize {
minSize = rhsSize
}
cost += uint64(math.Ceil(float64(minSize) * common.StringTraversalCostFactor))
// O(m+n) functions
case overloads.AddString, overloads.AddBytes:
// In the worst case scenario, we would need to reallocate a new backing store and copy both operands over.
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
// O(nm) functions
case overloads.MatchesString:
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor))
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor))
cost += strCost * regexCost
case overloads.ContainsString:
strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
cost += strCost * substrCost
default:
// The following operations are assumed to have O(1) complexity.
// - AddList due to the implementation. Index lookup can be O(c) the
// number of concatenated lists, but we don't track that is cost calculations.
// - Conversions, since none perform a traversal of a type of unbound length.
// - Computing the size of strings, byte sequences, lists and maps.
// - Logical operations and all operators on fixed width scalars (comparisons, equality)
// - Any functions that don't have a declared cost either here or in provided ActualCostEstimator.
cost += 1
}
return cost
}
// actualSize returns the size of value
func (c CostTracker) actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
return 1
}
// refValStack keeps track of values of the stack for cost calculation purposes
type refValStack []ref.Val
func (s *refValStack) push(value ref.Val) {
*s = append(*s, value)
}
func (s *refValStack) pop(count int) ([]ref.Val, bool) {
if len(*s) < count {
return nil, false
}
idx := len(*s) - count
el := (*s)[idx:]
*s = (*s)[:idx]
return el, true
}