hack/pin-dependency.sh github.com/google/cel-go v0.22.0

This commit is contained in:
Joe Betz
2024-11-05 19:21:09 -05:00
parent 2caf4eddd8
commit b0180a9a37
104 changed files with 8161 additions and 639 deletions

View File

@@ -39,6 +39,7 @@ go_library(
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"//parser:go_default_library",
"@dev_cel_expr//:expr",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protodesc:go_default_library",
@@ -81,7 +82,6 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",

View File

@@ -23,6 +23,7 @@ import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
celpb "cel.dev/expr"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@@ -312,20 +313,34 @@ func ExprTypeToType(t *exprpb.Type) (*Type, error) {
// ExprDeclToDeclaration converts a protobuf CEL declaration to a CEL-native declaration, either a Variable or Function.
func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
return AlphaProtoAsDeclaration(d)
}
// AlphaProtoAsDeclaration converts a v1alpha1.Decl value describing a variable or function into an EnvOption.
func AlphaProtoAsDeclaration(d *exprpb.Decl) (EnvOption, error) {
canonical := &celpb.Decl{}
if err := convertProto(d, canonical); err != nil {
return nil, err
}
return ProtoAsDeclaration(canonical)
}
// ProtoAsDeclaration converts a canonical celpb.Decl value describing a variable or function into an EnvOption.
func ProtoAsDeclaration(d *celpb.Decl) (EnvOption, error) {
switch d.GetDeclKind().(type) {
case *exprpb.Decl_Function:
case *celpb.Decl_Function:
overloads := d.GetFunction().GetOverloads()
opts := make([]FunctionOpt, len(overloads))
for i, o := range overloads {
args := make([]*Type, len(o.GetParams()))
for j, p := range o.GetParams() {
a, err := types.ExprTypeToType(p)
a, err := types.ProtoAsType(p)
if err != nil {
return nil, err
}
args[j] = a
}
res, err := types.ExprTypeToType(o.GetResultType())
res, err := types.ProtoAsType(o.GetResultType())
if err != nil {
return nil, err
}
@@ -336,15 +351,15 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
}
}
return Function(d.GetName(), opts...), nil
case *exprpb.Decl_Ident:
t, err := types.ExprTypeToType(d.GetIdent().GetType())
case *celpb.Decl_Ident:
t, err := types.ProtoAsType(d.GetIdent().GetType())
if err != nil {
return nil, err
}
if d.GetIdent().GetValue() == nil {
return Variable(d.GetName(), t), nil
}
val, err := ast.ConstantToVal(d.GetIdent().GetValue())
val, err := ast.ProtoConstantAsVal(d.GetIdent().GetValue())
if err != nil {
return nil, err
}

View File

@@ -459,6 +459,12 @@ func (e *Env) ParseSource(src Source) (*Ast, *Issues) {
// Program generates an evaluable instance of the Ast within the environment (Env).
func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) {
return e.PlanProgram(ast.NativeRep(), opts...)
}
// PlanProgram generates an evaluable instance of the AST in the go-native representation within
// the environment (Env).
func (e *Env) PlanProgram(a *celast.AST, opts ...ProgramOption) (Program, error) {
optSet := e.progOpts
if len(opts) != 0 {
mergedOpts := []ProgramOption{}
@@ -466,7 +472,7 @@ func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) {
mergedOpts = append(mergedOpts, opts...)
optSet = mergedOpts
}
return newProgram(e, ast, optSet)
return newProgram(e, a, optSet)
}
// CELTypeAdapter returns the `types.Adapter` configured for the environment.

View File

@@ -28,6 +28,7 @@ import (
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
celpb "cel.dev/expr"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
anypb "google.golang.org/protobuf/types/known/anypb"
)
@@ -104,72 +105,86 @@ func AstToString(a *Ast) (string, error) {
// RefValueToValue converts between ref.Val and api.expr.Value.
// The result Value is the serialized proto form. The ref.Val must not be error or unknown.
func RefValueToValue(res ref.Val) (*exprpb.Value, error) {
return ValueAsAlphaProto(res)
}
func ValueAsAlphaProto(res ref.Val) (*exprpb.Value, error) {
canonical, err := ValueAsProto(res)
if err != nil {
return nil, err
}
alpha := &exprpb.Value{}
err = convertProto(canonical, alpha)
return alpha, err
}
func ValueAsProto(res ref.Val) (*celpb.Value, error) {
switch res.Type() {
case types.BoolType:
return &exprpb.Value{
Kind: &exprpb.Value_BoolValue{BoolValue: res.Value().(bool)}}, nil
return &celpb.Value{
Kind: &celpb.Value_BoolValue{BoolValue: res.Value().(bool)}}, nil
case types.BytesType:
return &exprpb.Value{
Kind: &exprpb.Value_BytesValue{BytesValue: res.Value().([]byte)}}, nil
return &celpb.Value{
Kind: &celpb.Value_BytesValue{BytesValue: res.Value().([]byte)}}, nil
case types.DoubleType:
return &exprpb.Value{
Kind: &exprpb.Value_DoubleValue{DoubleValue: res.Value().(float64)}}, nil
return &celpb.Value{
Kind: &celpb.Value_DoubleValue{DoubleValue: res.Value().(float64)}}, nil
case types.IntType:
return &exprpb.Value{
Kind: &exprpb.Value_Int64Value{Int64Value: res.Value().(int64)}}, nil
return &celpb.Value{
Kind: &celpb.Value_Int64Value{Int64Value: res.Value().(int64)}}, nil
case types.ListType:
l := res.(traits.Lister)
sz := l.Size().(types.Int)
elts := make([]*exprpb.Value, 0, int64(sz))
elts := make([]*celpb.Value, 0, int64(sz))
for i := types.Int(0); i < sz; i++ {
v, err := RefValueToValue(l.Get(i))
v, err := ValueAsProto(l.Get(i))
if err != nil {
return nil, err
}
elts = append(elts, v)
}
return &exprpb.Value{
Kind: &exprpb.Value_ListValue{
ListValue: &exprpb.ListValue{Values: elts}}}, nil
return &celpb.Value{
Kind: &celpb.Value_ListValue{
ListValue: &celpb.ListValue{Values: elts}}}, nil
case types.MapType:
mapper := res.(traits.Mapper)
sz := mapper.Size().(types.Int)
entries := make([]*exprpb.MapValue_Entry, 0, int64(sz))
entries := make([]*celpb.MapValue_Entry, 0, int64(sz))
for it := mapper.Iterator(); it.HasNext().(types.Bool); {
k := it.Next()
v := mapper.Get(k)
kv, err := RefValueToValue(k)
kv, err := ValueAsProto(k)
if err != nil {
return nil, err
}
vv, err := RefValueToValue(v)
vv, err := ValueAsProto(v)
if err != nil {
return nil, err
}
entries = append(entries, &exprpb.MapValue_Entry{Key: kv, Value: vv})
entries = append(entries, &celpb.MapValue_Entry{Key: kv, Value: vv})
}
return &exprpb.Value{
Kind: &exprpb.Value_MapValue{
MapValue: &exprpb.MapValue{Entries: entries}}}, nil
return &celpb.Value{
Kind: &celpb.Value_MapValue{
MapValue: &celpb.MapValue{Entries: entries}}}, nil
case types.NullType:
return &exprpb.Value{
Kind: &exprpb.Value_NullValue{}}, nil
return &celpb.Value{
Kind: &celpb.Value_NullValue{}}, nil
case types.StringType:
return &exprpb.Value{
Kind: &exprpb.Value_StringValue{StringValue: res.Value().(string)}}, nil
return &celpb.Value{
Kind: &celpb.Value_StringValue{StringValue: res.Value().(string)}}, nil
case types.TypeType:
typeName := res.(ref.Type).TypeName()
return &exprpb.Value{Kind: &exprpb.Value_TypeValue{TypeValue: typeName}}, nil
return &celpb.Value{Kind: &celpb.Value_TypeValue{TypeValue: typeName}}, nil
case types.UintType:
return &exprpb.Value{
Kind: &exprpb.Value_Uint64Value{Uint64Value: res.Value().(uint64)}}, nil
return &celpb.Value{
Kind: &celpb.Value_Uint64Value{Uint64Value: res.Value().(uint64)}}, nil
default:
any, err := res.ConvertToNative(anyPbType)
if err != nil {
return nil, err
}
return &exprpb.Value{
Kind: &exprpb.Value_ObjectValue{ObjectValue: any.(*anypb.Any)}}, nil
return &celpb.Value{
Kind: &celpb.Value_ObjectValue{ObjectValue: any.(*anypb.Any)}}, nil
}
}
@@ -192,55 +207,67 @@ var (
// ValueToRefValue converts between exprpb.Value and ref.Val.
func ValueToRefValue(adapter types.Adapter, v *exprpb.Value) (ref.Val, error) {
return AlphaProtoAsValue(adapter, v)
}
func AlphaProtoAsValue(adapter types.Adapter, v *exprpb.Value) (ref.Val, error) {
canonical := &celpb.Value{}
if err := convertProto(v, canonical); err != nil {
return nil, err
}
return ProtoAsValue(adapter, canonical)
}
func ProtoAsValue(adapter types.Adapter, v *celpb.Value) (ref.Val, error) {
switch v.Kind.(type) {
case *exprpb.Value_NullValue:
case *celpb.Value_NullValue:
return types.NullValue, nil
case *exprpb.Value_BoolValue:
case *celpb.Value_BoolValue:
return types.Bool(v.GetBoolValue()), nil
case *exprpb.Value_Int64Value:
case *celpb.Value_Int64Value:
return types.Int(v.GetInt64Value()), nil
case *exprpb.Value_Uint64Value:
case *celpb.Value_Uint64Value:
return types.Uint(v.GetUint64Value()), nil
case *exprpb.Value_DoubleValue:
case *celpb.Value_DoubleValue:
return types.Double(v.GetDoubleValue()), nil
case *exprpb.Value_StringValue:
case *celpb.Value_StringValue:
return types.String(v.GetStringValue()), nil
case *exprpb.Value_BytesValue:
case *celpb.Value_BytesValue:
return types.Bytes(v.GetBytesValue()), nil
case *exprpb.Value_ObjectValue:
case *celpb.Value_ObjectValue:
any := v.GetObjectValue()
msg, err := anypb.UnmarshalNew(any, proto.UnmarshalOptions{DiscardUnknown: true})
if err != nil {
return nil, err
}
return adapter.NativeToValue(msg), nil
case *exprpb.Value_MapValue:
case *celpb.Value_MapValue:
m := v.GetMapValue()
entries := make(map[ref.Val]ref.Val)
for _, entry := range m.Entries {
key, err := ValueToRefValue(adapter, entry.Key)
key, err := ProtoAsValue(adapter, entry.Key)
if err != nil {
return nil, err
}
pb, err := ValueToRefValue(adapter, entry.Value)
pb, err := ProtoAsValue(adapter, entry.Value)
if err != nil {
return nil, err
}
entries[key] = pb
}
return adapter.NativeToValue(entries), nil
case *exprpb.Value_ListValue:
case *celpb.Value_ListValue:
l := v.GetListValue()
elts := make([]ref.Val, len(l.Values))
for i, e := range l.Values {
rv, err := ValueToRefValue(adapter, e)
rv, err := ProtoAsValue(adapter, e)
if err != nil {
return nil, err
}
elts[i] = rv
}
return adapter.NativeToValue(elts), nil
case *exprpb.Value_TypeValue:
case *celpb.Value_TypeValue:
typeName := v.GetTypeValue()
tv, ok := typeNameToTypeValue[typeName]
if ok {
@@ -250,3 +277,12 @@ func ValueToRefValue(adapter types.Adapter, v *exprpb.Value) (ref.Val, error) {
}
return nil, errors.New("unknown value")
}
func convertProto(src, dst proto.Message) error {
pb, err := proto.Marshal(src)
if err != nil {
return err
}
err = proto.Unmarshal(pb, dst)
return err
}

View File

@@ -211,6 +211,16 @@ type OptimizerContext struct {
*Issues
}
// ExtendEnv auguments the context's environment with the additional options.
func (opt *OptimizerContext) ExtendEnv(opts ...EnvOption) error {
e, err := opt.Env.Extend(opts...)
if err != nil {
return err
}
opt.Env = e
return nil
}
// ASTOptimizer applies an optimization over an AST and returns the optimized result.
type ASTOptimizer interface {
// Optimize optimizes a type-checked AST within an Environment and accumulates any issues.

View File

@@ -19,6 +19,7 @@ import (
"fmt"
"sync"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
@@ -151,7 +152,7 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()
@@ -255,9 +256,9 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
return p.initInterpretable(a, decorators)
}
func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
func (p *prog) initInterpretable(a *ast.AST, decs []interpreter.InterpretableDecorator) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...)
interpretable, err := p.interpreter.NewInterpretable(a, decs...)
if err != nil {
return nil, err
}

View File

@@ -496,16 +496,32 @@ func (c *checker) checkComprehension(e ast.Expr) {
comp := e.AsComprehension()
c.check(comp.IterRange())
c.check(comp.AccuInit())
accuType := c.getType(comp.AccuInit())
rangeType := substitute(c.mappings, c.getType(comp.IterRange()), false)
var varType *types.Type
// Create a scope for the comprehension since it has a local accumulation variable.
// This scope will contain the accumulation variable used to compute the result.
accuType := c.getType(comp.AccuInit())
c.env = c.env.enterScope()
c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType))
var varType, var2Type *types.Type
switch rangeType.Kind() {
case types.ListKind:
// varType represents the list element type for one-variable comprehensions.
varType = rangeType.Parameters()[0]
if comp.HasIterVar2() {
// varType represents the list index (int) for two-variable comprehensions,
// and var2Type represents the list element type.
var2Type = varType
varType = types.IntType
}
case types.MapKind:
// Ranges over the keys.
// varType represents the map entry key for all comprehension types.
varType = rangeType.Parameters()[0]
if comp.HasIterVar2() {
// var2Type represents the map entry value for two-variable comprehensions.
var2Type = rangeType.Parameters()[1]
}
case types.DynKind, types.ErrorKind, types.TypeParamKind:
// Set the range type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
@@ -518,13 +534,12 @@ func (c *checker) checkComprehension(e ast.Expr) {
varType = types.ErrorType
}
// Create a scope for the comprehension since it has a local accumulation variable.
// This scope will contain the accumulation variable used to compute the result.
c.env = c.env.enterScope()
c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType))
// Create a block scope for the loop.
c.env = c.env.enterScope()
c.env.AddIdents(decls.NewVariable(comp.IterVar(), varType))
if comp.HasIterVar2() {
c.env.AddIdents(decls.NewVariable(comp.IterVar2(), var2Type))
}
// Check the variable references in the condition and step.
c.check(comp.LoopCondition())
c.assertType(comp.LoopCondition(), types.BoolType)

View File

@@ -15,11 +15,13 @@ go_library(
"navigable.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
deps = [
"//common:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"@dev_cel_expr//:expr",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],
)
@@ -35,12 +37,13 @@ go_test(
embed = [
":go_default_library",
],
deps = [
deps = [
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",

View File

@@ -17,12 +17,14 @@ package ast
import (
"fmt"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
structpb "google.golang.org/protobuf/types/known/structpb"
celpb "cel.dev/expr"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
structpb "google.golang.org/protobuf/types/known/structpb"
)
// ToProto converts an AST to a CheckedExpr protobouf.
@@ -173,9 +175,10 @@ func exprComprehension(factory ExprFactory, id int64, comp *exprpb.Expr_Comprehe
if err != nil {
return nil, err
}
return factory.NewComprehension(id,
return factory.NewComprehensionTwoVar(id,
iterRange,
comp.GetIterVar(),
comp.GetIterVar2(),
comp.GetAccuVar(),
accuInit,
loopCond,
@@ -363,6 +366,7 @@ func protoComprehension(id int64, comp ComprehensionExpr) (*exprpb.Expr, error)
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: comp.IterVar(),
IterVar2: comp.IterVar2(),
IterRange: iterRange,
AccuVar: comp.AccuVar(),
AccuInit: accuInit,
@@ -609,24 +613,47 @@ func ValToConstant(v ref.Val) (*exprpb.Constant, error) {
// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val.
func ConstantToVal(c *exprpb.Constant) (ref.Val, error) {
return AlphaProtoConstantAsVal(c)
}
// AlphaProtoConstantAsVal converts a v1alpha1.Constant protobuf to a CEL-native ref.Val.
func AlphaProtoConstantAsVal(c *exprpb.Constant) (ref.Val, error) {
if c == nil {
return nil, nil
}
canonical := &celpb.Constant{}
if err := convertProto(c, canonical); err != nil {
return nil, err
}
return ProtoConstantAsVal(canonical)
}
// ProtoConstantAsVal converts a canonical celpb.Constant protobuf to a CEL-native ref.Val.
func ProtoConstantAsVal(c *celpb.Constant) (ref.Val, error) {
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
case *celpb.Constant_BoolValue:
return types.Bool(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
case *celpb.Constant_BytesValue:
return types.Bytes(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
case *celpb.Constant_DoubleValue:
return types.Double(c.GetDoubleValue()), nil
case *exprpb.Constant_Int64Value:
case *celpb.Constant_Int64Value:
return types.Int(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
case *celpb.Constant_NullValue:
return types.NullValue, nil
case *exprpb.Constant_StringValue:
case *celpb.Constant_StringValue:
return types.String(c.GetStringValue()), nil
case *exprpb.Constant_Uint64Value:
case *celpb.Constant_Uint64Value:
return types.Uint(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind())
}
func convertProto(src, dst proto.Message) error {
pb, err := proto.Marshal(src)
if err != nil {
return err
}
err = proto.Unmarshal(pb, dst)
return err
}

View File

@@ -269,8 +269,22 @@ type ComprehensionExpr interface {
IterRange() Expr
// IterVar returns the iteration variable name.
//
// For one-variable comprehensions, the iter var refers to the element value
// when iterating over a list, or the map key when iterating over a map.
//
// For two-variable comprehneions, the iter var refers to the list index or the
// map key.
IterVar() string
// IterVar2 returns the second iteration variable name.
//
// When the value is non-empty, the comprehension is a two-variable comprehension.
IterVar2() string
// HasIterVar2 returns true if the second iteration variable is non-empty.
HasIterVar2() bool
// AccuVar returns the accumulation variable name.
AccuVar() string
@@ -397,6 +411,7 @@ func (e *expr) SetKindCase(other Expr) {
e.exprKindCase = &baseComprehensionExpr{
iterRange: c.IterRange(),
iterVar: c.IterVar(),
iterVar2: c.IterVar2(),
accuVar: c.AccuVar(),
accuInit: c.AccuInit(),
loopCond: c.LoopCondition(),
@@ -505,6 +520,7 @@ var _ ComprehensionExpr = &baseComprehensionExpr{}
type baseComprehensionExpr struct {
iterRange Expr
iterVar string
iterVar2 string
accuVar string
accuInit Expr
loopCond Expr
@@ -527,6 +543,14 @@ func (e *baseComprehensionExpr) IterVar() string {
return e.iterVar
}
func (e *baseComprehensionExpr) IterVar2() string {
return e.iterVar2
}
func (e *baseComprehensionExpr) HasIterVar2() bool {
return e.iterVar2 != ""
}
func (e *baseComprehensionExpr) AccuVar() string {
return e.accuVar
}

View File

@@ -27,9 +27,12 @@ type ExprFactory interface {
// NewCall creates an Expr value representing a global function call.
NewCall(id int64, function string, args ...Expr) Expr
// NewComprehension creates an Expr value representing a comprehension over a value range.
// NewComprehension creates an Expr value representing a one-variable comprehension over a value range.
NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr
// NewComprehensionTwoVar creates an Expr value representing a two-variable comprehension over a value range.
NewComprehensionTwoVar(id int64, iterRange Expr, iterVar, iterVar2, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr
// NewMemberCall creates an Expr value representing a member function call.
NewMemberCall(id int64, function string, receiver Expr, args ...Expr) Expr
@@ -111,11 +114,17 @@ func (fac *baseExprFactory) NewMemberCall(id int64, function string, target Expr
}
func (fac *baseExprFactory) NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
// Set the iter_var2 to empty string to indicate the second variable is omitted
return fac.NewComprehensionTwoVar(id, iterRange, iterVar, "", accuVar, accuInit, loopCond, loopStep, result)
}
func (fac *baseExprFactory) NewComprehensionTwoVar(id int64, iterRange Expr, iterVar, iterVar2, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
return fac.newExpr(
id,
&baseComprehensionExpr{
iterRange: iterRange,
iterVar: iterVar,
iterVar2: iterVar2,
accuVar: accuVar,
accuInit: accuInit,
loopCond: loopCond,
@@ -223,9 +232,10 @@ func (fac *baseExprFactory) CopyExpr(e Expr) Expr {
return fac.NewMemberCall(e.ID(), c.FunctionName(), fac.CopyExpr(c.Target()), argsCopy...)
case ComprehensionKind:
compre := e.AsComprehension()
return fac.NewComprehension(e.ID(),
return fac.NewComprehensionTwoVar(e.ID(),
fac.CopyExpr(compre.IterRange()),
compre.IterVar(),
compre.IterVar2(),
compre.AccuVar(),
fac.CopyExpr(compre.AccuInit()),
fac.CopyExpr(compre.LoopCondition()),

View File

@@ -390,6 +390,14 @@ func (comp navigableComprehensionImpl) IterVar() string {
return comp.Expr.AsComprehension().IterVar()
}
func (comp navigableComprehensionImpl) IterVar2() string {
return comp.Expr.AsComprehension().IterVar2()
}
func (comp navigableComprehensionImpl) HasIterVar2() bool {
return comp.Expr.AsComprehension().HasIterVar2()
}
func (comp navigableComprehensionImpl) AccuVar() string {
return comp.Expr.AsComprehension().AccuVar()
}

View File

@@ -19,6 +19,7 @@ package containers
import (
"fmt"
"strings"
"unicode"
"github.com/google/cel-go/common/ast"
)
@@ -212,6 +213,13 @@ type ContainerOption func(*Container) (*Container, error)
func Abbrevs(qualifiedNames ...string) ContainerOption {
return func(c *Container) (*Container, error) {
for _, qn := range qualifiedNames {
qn = strings.TrimSpace(qn)
for _, r := range qn {
if !isIdentifierChar(r) {
return nil, fmt.Errorf(
"invalid qualified name: %s, wanted name of the form 'qualified.name'", qn)
}
}
ind := strings.LastIndex(qn, ".")
if ind <= 0 || ind >= len(qn)-1 {
return nil, fmt.Errorf(
@@ -278,6 +286,10 @@ func aliasAs(kind, qualifiedName, alias string) ContainerOption {
}
}
func isIdentifierChar(r rune) bool {
return r <= unicode.MaxASCII && (r == '.' || r == '_' || unicode.IsLetter(r) || unicode.IsNumber(r))
}
// Name sets the fully-qualified name of the Container.
func Name(name string) ContainerOption {
return func(c *Container) (*Container, error) {

View File

@@ -215,6 +215,11 @@ func (w *debugWriter) appendComprehension(comprehension ast.ComprehensionExpr) {
w.append(comprehension.IterVar())
w.append(",")
w.appendLine()
if comprehension.HasIterVar2() {
w.append(comprehension.IterVar2())
w.append(",")
w.appendLine()
}
w.append("// Target")
w.appendLine()
w.Buffer(comprehension.IterRange())

View File

@@ -251,15 +251,15 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
// are preserved in order to assist with the function resolution step.
switch len(args) {
case 1:
if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
if o.unaryOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) {
return o.unaryOp(args[0])
}
case 2:
if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
if o.binaryOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) {
return o.binaryOp(args[0], args[1])
}
}
if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
if o.functionOp != nil && o.matchesRuntimeSignature(f.disableTypeGuards, args...) {
return o.functionOp(args...)
}
// eventually this will fall through to the noSuchOverload below.
@@ -777,8 +777,13 @@ func (v *VariableDecl) DeclarationIsEquivalent(other *VariableDecl) bool {
return v.Name() == other.Name() && v.Type().IsEquivalentType(other.Type())
}
// VariableDeclToExprDecl converts a go-native variable declaration into a protobuf-type variable declaration.
func VariableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) {
// TypeVariable creates a new type identifier for use within a types.Provider
func TypeVariable(t *types.Type) *VariableDecl {
return NewVariable(t.TypeName(), types.NewTypeTypeWithParam(t))
}
// variableDeclToExprDecl converts a go-native variable declaration into a protobuf-type variable declaration.
func variableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) {
varType, err := types.TypeToExprType(v.Type())
if err != nil {
return nil, err
@@ -786,13 +791,8 @@ func VariableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) {
return chkdecls.NewVar(v.Name(), varType), nil
}
// TypeVariable creates a new type identifier for use within a types.Provider
func TypeVariable(t *types.Type) *VariableDecl {
return NewVariable(t.TypeName(), types.NewTypeTypeWithParam(t))
}
// FunctionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration.
func FunctionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) {
// functionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration.
func functionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) {
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(f.overloads))
for i, oID := range f.overloadOrdinals {
o := f.overloads[oID]

View File

@@ -40,10 +40,12 @@ go_library(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@com_github_stoewer_go_strcase//:go_default_library",
"@dev_cel_expr//:expr",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//types/dynamicpb:go_default_library",
"@org_golang_google_protobuf//types/known/anypb:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",

View File

@@ -256,6 +256,15 @@ func (l *baseList) IsZeroValue() bool {
return l.size == 0
}
// Fold calls the FoldEntry method for each (index, value) pair in the list.
func (l *baseList) Fold(f traits.Folder) {
for i := 0; i < l.size; i++ {
if !f.FoldEntry(i, l.get(i)) {
break
}
}
}
// Iterator implements the traits.Iterable interface method.
func (l *baseList) Iterator() traits.Iterator {
return newListIterator(l)
@@ -433,6 +442,15 @@ func (l *concatList) IsZeroValue() bool {
return l.Size().(Int) == 0
}
// Fold calls the FoldEntry method for each (index, value) pair in the list.
func (l *concatList) Fold(f traits.Folder) {
for i := Int(0); i < l.Size().(Int); i++ {
if !f.FoldEntry(i, l.Get(i)) {
break
}
}
}
// Iterator implements the traits.Iterable interface method.
func (l *concatList) Iterator() traits.Iterator {
return newListIterator(l)
@@ -527,3 +545,30 @@ func IndexOrError(index ref.Val) (int, error) {
return -1, fmt.Errorf("unsupported index type '%s' in list", index.Type())
}
}
// ToFoldableList will create a Foldable version of a list suitable for key-value pair iteration.
//
// For values which are already Foldable, this call is a no-op. For all other values, the fold is
// driven via the Size() and Get() calls which means that the folding will function, but take a
// performance hit.
func ToFoldableList(l traits.Lister) traits.Foldable {
if f, ok := l.(traits.Foldable); ok {
return f
}
return interopFoldableList{Lister: l}
}
type interopFoldableList struct {
traits.Lister
}
// Fold implements the traits.Foldable interface method and performs an iteration over the
// range of elements of the list.
func (l interopFoldableList) Fold(f traits.Folder) {
sz := l.Size().(Int)
for i := Int(0); i < sz; i++ {
if !f.FoldEntry(i, l.Get(i)) {
break
}
}
}

View File

@@ -94,6 +94,24 @@ func NewProtoMap(adapter Adapter, value *pb.Map) traits.Mapper {
}
}
// NewMutableMap constructs a mutable map from an adapter and a set of map values.
func NewMutableMap(adapter Adapter, mutableValues map[ref.Val]ref.Val) traits.MutableMapper {
mutableCopy := make(map[ref.Val]ref.Val, len(mutableValues))
for k, v := range mutableValues {
mutableCopy[k] = v
}
m := &mutableMap{
baseMap: &baseMap{
Adapter: adapter,
mapAccessor: newRefValMapAccessor(mutableCopy),
value: mutableCopy,
size: len(mutableCopy),
},
mutableValues: mutableCopy,
}
return m
}
// mapAccessor is a private interface for finding values within a map and iterating over the keys.
// This interface implements portions of the API surface area required by the traits.Mapper
// interface.
@@ -105,6 +123,9 @@ type mapAccessor interface {
// Iterator returns an Iterator over the map key set.
Iterator() traits.Iterator
// Fold calls the FoldEntry method for each (key, value) pair in the map.
Fold(traits.Folder)
}
// baseMap is a reflection based map implementation designed to handle a variety of map-like types.
@@ -307,6 +328,28 @@ func (m *baseMap) Value() any {
return m.value
}
// mutableMap holds onto a set of mutable values which are used for intermediate computations.
type mutableMap struct {
*baseMap
mutableValues map[ref.Val]ref.Val
}
// Insert implements the traits.MutableMapper interface method, returning true if the key insertion
// succeeds.
func (m *mutableMap) Insert(k, v ref.Val) ref.Val {
if _, found := m.Find(k); found {
return NewErr("insert failed: key %v already exists", k)
}
m.mutableValues[k] = v
return m
}
// ToImmutableMap implements the traits.MutableMapper interface method, converting a mutable map
// an immutable map implementation.
func (m *mutableMap) ToImmutableMap() traits.Mapper {
return NewRefValMap(m.Adapter, m.mutableValues)
}
func newJSONStructAccessor(adapter Adapter, st map[string]*structpb.Value) mapAccessor {
return &jsonStructAccessor{
Adapter: adapter,
@@ -350,6 +393,15 @@ func (a *jsonStructAccessor) Iterator() traits.Iterator {
}
}
// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (a *jsonStructAccessor) Fold(f traits.Folder) {
for k, v := range a.st {
if !f.FoldEntry(k, v) {
break
}
}
}
func newReflectMapAccessor(adapter Adapter, value reflect.Value) mapAccessor {
keyType := value.Type().Key()
return &reflectMapAccessor{
@@ -424,6 +476,16 @@ func (m *reflectMapAccessor) Iterator() traits.Iterator {
}
}
// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (m *reflectMapAccessor) Fold(f traits.Folder) {
mapRange := m.refValue.MapRange()
for mapRange.Next() {
if !f.FoldEntry(mapRange.Key().Interface(), mapRange.Value().Interface()) {
break
}
}
}
func newRefValMapAccessor(mapVal map[ref.Val]ref.Val) mapAccessor {
return &refValMapAccessor{mapVal: mapVal}
}
@@ -477,6 +539,15 @@ func (a *refValMapAccessor) Iterator() traits.Iterator {
}
}
// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (a *refValMapAccessor) Fold(f traits.Folder) {
for k, v := range a.mapVal {
if !f.FoldEntry(k, v) {
break
}
}
}
func newStringMapAccessor(strMap map[string]string) mapAccessor {
return &stringMapAccessor{mapVal: strMap}
}
@@ -515,6 +586,15 @@ func (a *stringMapAccessor) Iterator() traits.Iterator {
}
}
// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (a *stringMapAccessor) Fold(f traits.Folder) {
for k, v := range a.mapVal {
if !f.FoldEntry(k, v) {
break
}
}
}
func newStringIfaceMapAccessor(adapter Adapter, mapVal map[string]any) mapAccessor {
return &stringIfaceMapAccessor{
Adapter: adapter,
@@ -557,6 +637,15 @@ func (a *stringIfaceMapAccessor) Iterator() traits.Iterator {
}
}
// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (a *stringIfaceMapAccessor) Fold(f traits.Folder) {
for k, v := range a.mapVal {
if !f.FoldEntry(k, v) {
break
}
}
}
// protoMap is a specialized, separate implementation of the traits.Mapper interfaces tailored to
// accessing protoreflect.Map values.
type protoMap struct {
@@ -769,6 +858,13 @@ func (m *protoMap) Iterator() traits.Iterator {
}
}
// Fold calls the FoldEntry method for each (key, value) pair in the map.
func (m *protoMap) Fold(f traits.Folder) {
m.value.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
return f.FoldEntry(k.Interface(), v.Interface())
})
}
// Size returns the number of entries in the protoreflect.Map.
func (m *protoMap) Size() ref.Val {
return Int(m.value.Len())
@@ -852,3 +948,55 @@ func (it *stringKeyIterator) Next() ref.Val {
}
return nil
}
// ToFoldableMap will create a Foldable version of a map suitable for key-value pair iteration.
//
// For values which are already Foldable, this call is a no-op. For all other values, the fold
// is driven via the Iterator HasNext() and Next() calls as well as the map's Get() method
// which means that the folding will function, but take a performance hit.
func ToFoldableMap(m traits.Mapper) traits.Foldable {
if f, ok := m.(traits.Foldable); ok {
return f
}
return interopFoldableMap{Mapper: m}
}
type interopFoldableMap struct {
traits.Mapper
}
func (m interopFoldableMap) Fold(f traits.Folder) {
it := m.Iterator()
for it.HasNext() == True {
k := it.Next()
if !f.FoldEntry(k, m.Get(k)) {
break
}
}
}
// InsertMapKeyValue inserts a key, value pair into the target map if the target map does not
// already contain the given key.
//
// If the map is mutable, it is modified in-place per the MutableMapper contract.
// If the map is not mutable, a copy containing the new key, value pair is made.
func InsertMapKeyValue(m traits.Mapper, k, v ref.Val) ref.Val {
if mutable, ok := m.(traits.MutableMapper); ok {
return mutable.Insert(k, v)
}
// Otherwise perform the slow version of the insertion which makes a copy of the incoming map.
if _, found := m.Find(k); !found {
size := m.Size().(Int)
copy := make(map[ref.Val]ref.Val, size+1)
copy[k] = v
it := m.Iterator()
for it.HasNext() == True {
nextK := it.Next()
nextV := m.Get(nextK)
copy[nextK] = nextV
}
return DefaultTypeAdapter.NativeToValue(copy)
}
return NewErr("insert failed: key %v already exists", k)
}

View File

@@ -35,6 +35,8 @@ var (
// golang reflect type for Null values.
nullReflectType = reflect.TypeOf(NullValue)
protoIfaceType = reflect.TypeOf((*proto.Message)(nil)).Elem()
)
// ConvertToNative implements ref.Val.ConvertToNative.
@@ -61,8 +63,14 @@ func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) {
return structpb.NewNullValue(), nil
case boolWrapperType, byteWrapperType, doubleWrapperType, floatWrapperType,
int32WrapperType, int64WrapperType, stringWrapperType, uint32WrapperType,
uint64WrapperType:
uint64WrapperType, durationValueType, timestampValueType, protoIfaceType:
return nil, nil
case jsonListValueType, jsonStructType:
// skip handling
default:
if typeDesc.Implements(protoIfaceType) {
return nil, nil
}
}
case reflect.Interface:
nv := n.Value()

View File

@@ -34,3 +34,16 @@ type Iterator interface {
// Next returns the next element.
Next() ref.Val
}
// Foldable aggregate types support iteration over (key, value) or (index, value) pairs.
type Foldable interface {
// Fold invokes the Folder.FoldEntry for all entries in the type
Fold(Folder)
}
// Folder performs a fold on a given entry and indicates whether to continue folding.
type Folder interface {
// FoldEntry indicates the key, value pair associated with the entry.
// If the output is true, continue folding. Otherwise, terminate the fold.
FoldEntry(key, val any) bool
}

View File

@@ -27,6 +27,9 @@ type Lister interface {
}
// MutableLister interface which emits an immutable result after an intermediate computation.
//
// Note, this interface is intended only to be used within Comprehensions where the mutable
// value is not directly observable within the user-authored CEL expression.
type MutableLister interface {
Lister
ToImmutableList() Lister

View File

@@ -31,3 +31,18 @@ type Mapper interface {
// (Unknown|Err, false).
Find(key ref.Val) (ref.Val, bool)
}
// MutableMapper interface which emits an immutable result after an intermediate computation.
//
// Note, this interface is intended only to be used within Comprehensions where the mutable
// value is not directly observable within the user-authored CEL expression.
type MutableMapper interface {
Mapper
// Insert a key, value pair into the map, returning the map if the insert is successful
// and an error if key already exists in the mutable map.
Insert(k, v ref.Val) ref.Val
// ToImmutableMap converts a mutable map into an immutable map.
ToImmutableMap() Mapper
}

View File

@@ -59,6 +59,21 @@ const (
// SizerType types support the size() method.
SizerType
// SubtractorType type support '-' operations.
// SubtractorType types support '-' operations.
SubtractorType
// FoldableType types support comprehensions v2 macros which iterate over (key, value) pairs.
FoldableType
)
const (
// ListerType supports a set of traits necessary for list operations.
//
// The ListerType is syntactic sugar and not intended to be a perfect reflection of all List operators.
ListerType = AdderType | ContainerType | IndexerType | IterableType | SizerType
// MapperType supports a set of traits necessary for map operations.
//
// The MapperType is syntactic sugar and not intended to be a perfect reflection of all Map operators.
MapperType = ContainerType | IndexerType | IterableType | SizerType
)

View File

@@ -19,10 +19,13 @@ import (
"reflect"
"strings"
"google.golang.org/protobuf/proto"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
celpb "cel.dev/expr"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@@ -666,85 +669,99 @@ func TypeToExprType(t *Type) (*exprpb.Type, error) {
// ExprTypeToType converts a protobuf CEL type representation to a CEL-native type representation.
func ExprTypeToType(t *exprpb.Type) (*Type, error) {
return AlphaProtoAsType(t)
}
// AlphaProtoAsType converts a CEL v1alpha1.Type protobuf type to a CEL-native type representation.
func AlphaProtoAsType(t *exprpb.Type) (*Type, error) {
canonical := &celpb.Type{}
if err := convertProto(t, canonical); err != nil {
return nil, err
}
return ProtoAsType(canonical)
}
// ProtoAsType converts a canonical CEL celpb.Type protobuf type to a CEL-native type representation.
func ProtoAsType(t *celpb.Type) (*Type, error) {
switch t.GetTypeKind().(type) {
case *exprpb.Type_Dyn:
case *celpb.Type_Dyn:
return DynType, nil
case *exprpb.Type_AbstractType_:
case *celpb.Type_AbstractType_:
paramTypes := make([]*Type, len(t.GetAbstractType().GetParameterTypes()))
for i, p := range t.GetAbstractType().GetParameterTypes() {
pt, err := ExprTypeToType(p)
pt, err := ProtoAsType(p)
if err != nil {
return nil, err
}
paramTypes[i] = pt
}
return NewOpaqueType(t.GetAbstractType().GetName(), paramTypes...), nil
case *exprpb.Type_ListType_:
et, err := ExprTypeToType(t.GetListType().GetElemType())
case *celpb.Type_ListType_:
et, err := ProtoAsType(t.GetListType().GetElemType())
if err != nil {
return nil, err
}
return NewListType(et), nil
case *exprpb.Type_MapType_:
kt, err := ExprTypeToType(t.GetMapType().GetKeyType())
case *celpb.Type_MapType_:
kt, err := ProtoAsType(t.GetMapType().GetKeyType())
if err != nil {
return nil, err
}
vt, err := ExprTypeToType(t.GetMapType().GetValueType())
vt, err := ProtoAsType(t.GetMapType().GetValueType())
if err != nil {
return nil, err
}
return NewMapType(kt, vt), nil
case *exprpb.Type_MessageType:
case *celpb.Type_MessageType:
return NewObjectType(t.GetMessageType()), nil
case *exprpb.Type_Null:
case *celpb.Type_Null:
return NullType, nil
case *exprpb.Type_Primitive:
case *celpb.Type_Primitive:
switch t.GetPrimitive() {
case exprpb.Type_BOOL:
case celpb.Type_BOOL:
return BoolType, nil
case exprpb.Type_BYTES:
case celpb.Type_BYTES:
return BytesType, nil
case exprpb.Type_DOUBLE:
case celpb.Type_DOUBLE:
return DoubleType, nil
case exprpb.Type_INT64:
case celpb.Type_INT64:
return IntType, nil
case exprpb.Type_STRING:
case celpb.Type_STRING:
return StringType, nil
case exprpb.Type_UINT64:
case celpb.Type_UINT64:
return UintType, nil
default:
return nil, fmt.Errorf("unsupported primitive type: %v", t)
}
case *exprpb.Type_TypeParam:
case *celpb.Type_TypeParam:
return NewTypeParamType(t.GetTypeParam()), nil
case *exprpb.Type_Type:
case *celpb.Type_Type:
if t.GetType().GetTypeKind() != nil {
p, err := ExprTypeToType(t.GetType())
p, err := ProtoAsType(t.GetType())
if err != nil {
return nil, err
}
return NewTypeTypeWithParam(p), nil
}
return TypeType, nil
case *exprpb.Type_WellKnown:
case *celpb.Type_WellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
case celpb.Type_ANY:
return AnyType, nil
case exprpb.Type_DURATION:
case celpb.Type_DURATION:
return DurationType, nil
case exprpb.Type_TIMESTAMP:
case celpb.Type_TIMESTAMP:
return TimestampType, nil
default:
return nil, fmt.Errorf("unsupported well-known type: %v", t)
}
case *exprpb.Type_Wrapper:
t, err := ExprTypeToType(&exprpb.Type{TypeKind: &exprpb.Type_Primitive{Primitive: t.GetWrapper()}})
case *celpb.Type_Wrapper:
t, err := ProtoAsType(&celpb.Type{TypeKind: &celpb.Type_Primitive{Primitive: t.GetWrapper()}})
if err != nil {
return nil, err
}
return NewNullableType(t), nil
case *exprpb.Type_Error:
case *celpb.Type_Error:
return ErrorType, nil
default:
return nil, fmt.Errorf("unsupported type: %v", t)
@@ -776,6 +793,23 @@ func maybeForeignType(t ref.Type) *Type {
return NewObjectType(t.TypeName(), traitMask)
}
func convertProto(src, dst proto.Message) error {
pb, err := proto.Marshal(src)
if err != nil {
return err
}
err = proto.Unmarshal(pb, dst)
return err
}
func primitiveType(primitive celpb.Type_PrimitiveType) *celpb.Type {
return &celpb.Type{
TypeKind: &celpb.Type_Primitive{
Primitive: primitive,
},
}
}
var (
checkedWellKnowns = map[string]*Type{
// Wrapper types.
@@ -820,4 +854,11 @@ var (
}
structTypeTraitMask = traits.FieldTesterType | traits.IndexerType
boolType = primitiveType(celpb.Type_BOOL)
bytesType = primitiveType(celpb.Type_BYTES)
doubleType = primitiveType(celpb.Type_DOUBLE)
intType = primitiveType(celpb.Type_INT64)
stringType = primitiveType(celpb.Type_STRING)
uintType = primitiveType(celpb.Type_UINT64)
)

View File

@@ -24,6 +24,7 @@ go_library(
"//cel:go_default_library",
"//checker:go_default_library",
"//common/ast:go_default_library",
"//common/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
@@ -31,6 +32,7 @@ go_library(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"//parser:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//types/known/structpb",
@@ -61,8 +63,8 @@ go_test(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",

View File

@@ -3,12 +3,12 @@
CEL extensions are a related set of constants, functions, macros, or other
features which may not be covered by the core CEL spec.
## Bindings
## Bindings
Returns a cel.EnvOption to configure support for local variable bindings
in expressions.
# Cel.Bind
### Cel.Bind
Binds a simple identifier to an initialization expression which may be used
in a subsequenct result expression. Bindings may also be nested within each
@@ -19,11 +19,11 @@ other.
Examples:
cel.bind(a, 'hello',
cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello"
cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello"
// Avoid a list allocation within the exists comprehension.
cel.bind(valid_values, [a, b, c],
[d, e, f].exists(elem, elem in valid_values))
[d, e, f].exists(elem, elem in valid_values))
Local bindings are not guaranteed to be evaluated before use.
@@ -393,6 +393,65 @@ Example:
Extended functions for list manipulation. As a general note, all indices are
zero-based.
### Distinct
**Introduced in version 2**
Returns the distinct elements of a list.
<list(T)>.distinct() -> <list(T)>
Examples:
[1, 2, 2, 3, 3, 3].distinct() // return [1, 2, 3]
["b", "b", "c", "a", "c"].distinct() // return ["b", "c", "a"]
[1, "b", 2, "b"].distinct() // return [1, "b", 2]
### Flatten
**Introduced in version 1**
Flattens a list recursively.
If an optional depth is provided, the list is flattened to a the specificied level.
A negative depth value will result in an error.
<list>.flatten(<list>) -> <list>
<list>.flatten(<list>, <int>) -> <list>
Examples:
[1,[2,3],[4]].flatten() // return [1, 2, 3, 4]
[1,[2,[3,4]]].flatten() // return [1, 2, [3, 4]]
[1,2,[],[],[3,4]].flatten() // return [1, 2, 3, 4]
[1,[2,[3,[4]]]].flatten(2) // return [1, 2, 3, [4]]
[1,[2,[3,[4]]]].flatten(-1) // error
### Range
**Introduced in version 2**
Returns a list of integers from 0 to n-1.
lists.range(<int>) -> <list(int)>
Examples:
lists.range(5) -> [0, 1, 2, 3, 4]
### Reverse
**Introduced in version 2**
Returns the elements of a list in reverse order.
<list(T)>.reverse() -> <list(T)>
Examples:
[5, 3, 1, 2].reverse() // return [2, 1, 3, 5]
### Slice
@@ -403,7 +462,43 @@ Returns a new sub-list using the indexes provided.
Examples:
[1,2,3,4].slice(1, 3) // return [2, 3]
[1,2,3,4].slice(2, 4) // return [3 ,4]
[1,2,3,4].slice(2, 4) // return [3, 4]
### Sort
**Introduced in version 2**
Sorts a list with comparable elements. If the element type is not comparable
or the element types are not the same, the function will produce an error.
<list(T)>.sort() -> <list(T)>
T in {int, uint, double, bool, duration, timestamp, string, bytes}
Examples:
[3, 2, 1].sort() // return [1, 2, 3]
["b", "c", "a"].sort() // return ["a", "b", "c"]
[1, "b"].sort() // error
[[1, 2, 3]].sort() // error
### SortBy
**Introduced in version 2**
Sorts a list by a key value, i.e., the order is determined by the result of
an expression applied to each element of the list.
<list(T)>.sortBy(<bindingName>, <keyExpr>) -> <list(T)>
keyExpr returns a value in {int, uint, double, bool, duration, timestamp, string, bytes}
Examples:
[
Player { name: "foo", score: 0 },
Player { name: "bar", score: -10 },
Player { name: "baz", score: 1000 },
].sortBy(e, e.score).map(e, e.name)
== ["bar", "foo", "baz"]
## Sets
@@ -498,7 +593,8 @@ Examples:
'hello mellow'.indexOf('jello') // returns -1
'hello mellow'.indexOf('', 2) // returns 2
'hello mellow'.indexOf('ello', 2) // returns 7
'hello mellow'.indexOf('ello', 20) // error
'hello mellow'.indexOf('ello', 20) // returns -1
'hello mellow'.indexOf('ello', -1) // error
### Join
@@ -536,6 +632,7 @@ Examples:
'hello mellow'.lastIndexOf('ello') // returns 7
'hello mellow'.lastIndexOf('jello') // returns -1
'hello mellow'.lastIndexOf('ello', 6) // returns 1
'hello mellow'.lastIndexOf('ello', 20) // returns -1
'hello mellow'.lastIndexOf('ello', -1) // error
### LowerAscii
@@ -666,4 +763,137 @@ It can be located in Version 3 of strings.
Examples:
'gums'.reverse() // returns 'smug'
'John Smith'.reverse() // returns 'htimS nhoJ'
'John Smith'.reverse() // returns 'htimS nhoJ'
## TwoVarComprehensions
TwoVarComprehensions introduces support for two-variable comprehensions.
The two-variable form of comprehensions looks similar to the one-variable
counterparts. Where possible, the same macro names were used and additional
macro signatures added. The notable distinction for two-variable comprehensions
is the introduction of `transformList`, `transformMap`, and `transformMapEntry`
support for list and map types rather than the more traditional `map` and
`filter` macros.
### All
Comprehension which tests whether all elements in the list or map satisfy a
given predicate. The `all` macro evaluates in a manner consistent with logical
AND and will short-circuit when encountering a `false` value.
<list>.all(indexVar, valueVar, <predicate>) -> bool
<map>.all(keyVar, valueVar, <predicate>) -> bool
Examples:
[1, 2, 3].all(i, j, i < j) // returns true
{'hello': 'world', 'taco': 'taco'}.all(k, v, k != v) // returns false
// Combines two-variable comprehension with single variable
{'h': ['hello', 'hi'], 'j': ['joke', 'jog']}
.all(k, vals, vals.all(v, v.startsWith(k))) // returns true
### Exists
Comprehension which tests whether any element in a list or map exists which
satisfies a given predicate. The `exists` macro evaluates in a manner consistent
with logical OR and will short-circuit when encountering a `true` value.
<list>.exists(indexVar, valueVar, <predicate>) -> bool
<map>.exists(keyVar, valueVar, <predicate>) -> bool
Examples:
{'greeting': 'hello', 'farewell': 'goodbye'}
.exists(k, v, k.startsWith('good') || v.endsWith('bye')) // returns true
[1, 2, 4, 8, 16].exists(i, v, v == 1024 && i == 10) // returns false
### ExistsOne
Comprehension which tests whether exactly one element in a list or map exists
which satisfies a given predicate expression. This comprehension does not
short-circuit in keeping with the one-variable exists one macro semantics.
<list>.existsOne(indexVar, valueVar, <predicate>)
<map>.existsOne(keyVar, valueVar, <predicate>)
This macro may also be used with the `exists_one` function name, for
compatibility with the one-variable macro of the same name.
Examples:
[1, 2, 1, 3, 1, 4].existsOne(i, v, i == 1 || v == 1) // returns false
[1, 1, 2, 2, 3, 3].existsOne(i, v, i == 2 && v == 2) // returns true
{'i': 0, 'j': 1, 'k': 2}.existsOne(i, v, i == 'l' || v == 1) // returns true
### TransformList
Comprehension which converts a map or a list into a list value. The output
expression of the comprehension determines the contents of the output list.
Elements in the list may optionally be filtered according to a predicate
expression, where elements that satisfy the predicate are transformed.
<list>.transformList(indexVar, valueVar, <transform>)
<list>.transformList(indexVar, valueVar, <filter>, <transform>)
<map>.transformList(keyVar, valueVar, <transform>)
<map>.transformList(keyVar, valueVar, <filter>, <transform>)
Examples:
[1, 2, 3].transformList(indexVar, valueVar,
(indexVar * valueVar) + valueVar) // returns [1, 4, 9]
[1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0
(indexVar * valueVar) + valueVar) // returns [1, 9]
{'greeting': 'hello', 'farewell': 'goodbye'}
.transformList(k, _, k) // returns ['greeting', 'farewell']
{'greeting': 'hello', 'farewell': 'goodbye'}
.transformList(_, v, v) // returns ['hello', 'goodbye']
### TransformMap
Comprehension which converts a map or a list into a map value. The output
expression of the comprehension determines the value of the output map entry;
however, the key remains fixed. Elements in the map may optionally be filtered
according to a predicate expression, where elements that satisfy the predicate
are transformed.
<list>.transformMap(indexVar, valueVar, <transform>)
<list>.transformMap(indexVar, valueVar, <filter>, <transform>)
<map>.transformMap(keyVar, valueVar, <transform>)
<map>.transformMap(keyVar, valueVar, <filter>, <transform>)
Examples:
[1, 2, 3].transformMap(indexVar, valueVar,
(indexVar * valueVar) + valueVar) // returns {0: 1, 1: 4, 2: 9}
[1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0
(indexVar * valueVar) + valueVar) // returns {0: 1, 2: 9}
{'greeting': 'hello'}.transformMap(k, v, v + '!') // returns {'greeting': 'hello!'}
### TransformMapEntry
Comprehension which converts a map or a list into a map value; however, this
transform expects the entry expression be a map literal. If the transform
produces an entry which duplicates a key in the target map, the comprehension
will error. Note, that key equality is determined using CEL equality which
asserts that numeric values which are equal, even if they don't have the same
type will cause a key collision.
Elements in the map may optionally be filtered according to a predicate
expression, where elements that satisfy the predicate are transformed.
<list>.transformMap(indexVar, valueVar, <transform>)
<list>.transformMap(indexVar, valueVar, <filter>, <transform>)
<map>.transformMap(keyVar, valueVar, <transform>)
<map>.transformMap(keyVar, valueVar, <filter>, <transform>)
Examples:
// returns {'hello': 'greeting'}
{'greeting': 'hello'}.transformMapEntry(keyVar, valueVar, {valueVar: keyVar})
// reverse lookup, require all values in list be unique
[1, 2, 3].transformMapEntry(indexVar, valueVar, {valueVar: indexVar})
{'greeting': 'aloha', 'farewell': 'aloha'}
.transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key

View File

@@ -15,9 +15,19 @@
package ext
import (
"errors"
"fmt"
"math"
"strconv"
"strings"
"sync"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)
// Bindings returns a cel.EnvOption to configure support for local variable
@@ -41,35 +51,120 @@ import (
// [d, e, f].exists(elem, elem in valid_values))
//
// Local bindings are not guaranteed to be evaluated before use.
func Bindings() cel.EnvOption {
return cel.Lib(celBindings{})
func Bindings(options ...BindingsOption) cel.EnvOption {
b := &celBindings{version: math.MaxUint32}
for _, o := range options {
b = o(b)
}
return cel.Lib(b)
}
const (
celNamespace = "cel"
bindMacro = "bind"
blockFunc = "@block"
unusedIterVar = "#unused"
)
type celBindings struct{}
// BindingsOption declares a functional operator for configuring the Bindings library behavior.
type BindingsOption func(*celBindings) *celBindings
func (celBindings) LibraryName() string {
// BindingsVersion sets the version of the bindings library to an explicit version.
func BindingsVersion(version uint32) BindingsOption {
return func(lib *celBindings) *celBindings {
lib.version = version
return lib
}
}
type celBindings struct {
version uint32
}
func (*celBindings) LibraryName() string {
return "cel.lib.ext.cel.bindings"
}
func (celBindings) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
func (lib *celBindings) CompileOptions() []cel.EnvOption {
opts := []cel.EnvOption{
cel.Macros(
// cel.bind(var, <init>, <expr>)
cel.ReceiverMacro(bindMacro, 3, celBind),
),
}
if lib.version >= 1 {
// The cel.@block signature takes a list of subexpressions and a typed expression which is
// used as the output type.
paramType := cel.TypeParamType("T")
opts = append(opts,
cel.Function("cel.@block",
cel.Overload("cel_block_list",
[]*cel.Type{cel.ListType(cel.DynType), paramType}, paramType)),
)
opts = append(opts, cel.ASTValidators(blockValidationExemption{}))
}
return opts
}
func (celBindings) ProgramOptions() []cel.ProgramOption {
func (lib *celBindings) ProgramOptions() []cel.ProgramOption {
if lib.version >= 1 {
celBlockPlan := func(i interpreter.Interpretable) (interpreter.Interpretable, error) {
call, ok := i.(interpreter.InterpretableCall)
if !ok {
return i, nil
}
switch call.Function() {
case "cel.@block":
args := call.Args()
if len(args) != 2 {
return nil, fmt.Errorf("cel.@block expects two arguments, but got %d", len(args))
}
expr := args[1]
// Non-empty block
if block, ok := args[0].(interpreter.InterpretableConstructor); ok {
slotExprs := block.InitVals()
return newDynamicBlock(slotExprs, expr), nil
}
// Constant valued block which can happen during runtime optimization.
if cons, ok := args[0].(interpreter.InterpretableConst); ok {
if cons.Value().Type() == types.ListType {
l := cons.Value().(traits.Lister)
if l.Size().Equal(types.IntZero) == types.True {
return args[1], nil
}
return newConstantBlock(l, expr), nil
}
}
return nil, errors.New("cel.@block expects a list constructor as the first argument")
default:
return i, nil
}
}
return []cel.ProgramOption{cel.CustomDecorator(celBlockPlan)}
}
return []cel.ProgramOption{}
}
type blockValidationExemption struct{}
// Name returns the name of the validator.
func (blockValidationExemption) Name() string {
return "cel.lib.ext.validate.functions.cel.block"
}
// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip
// during homogeneous aggregate literal type-checks.
func (blockValidationExemption) Configure(config cel.MutableValidatorConfig) error {
functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string)
functions = append(functions, "cel.@block")
return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions)
}
// Validate is a no-op as the intent is to simply disable strong type-checks for list literals during
// when they occur within cel.@block calls as the arg types have already been validated.
func (blockValidationExemption) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) {
}
func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(celNamespace, target) {
return nil, nil
@@ -94,3 +189,148 @@ func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Ex
resultExpr,
), nil
}
func newDynamicBlock(slotExprs []interpreter.Interpretable, expr interpreter.Interpretable) interpreter.Interpretable {
bs := &dynamicBlock{
slotExprs: slotExprs,
expr: expr,
}
bs.slotActivationPool = &sync.Pool{
New: func() any {
slotCount := len(slotExprs)
sa := &dynamicSlotActivation{
slotExprs: slotExprs,
slotCount: slotCount,
slotVals: make([]*slotVal, slotCount),
}
for i := 0; i < slotCount; i++ {
sa.slotVals[i] = &slotVal{}
}
return sa
},
}
return bs
}
type dynamicBlock struct {
slotExprs []interpreter.Interpretable
expr interpreter.Interpretable
slotActivationPool *sync.Pool
}
// ID implements the Interpretable interface method.
func (b *dynamicBlock) ID() int64 {
return b.expr.ID()
}
// Eval implements the Interpretable interface method.
func (b *dynamicBlock) Eval(activation interpreter.Activation) ref.Val {
sa := b.slotActivationPool.Get().(*dynamicSlotActivation)
sa.Activation = activation
defer b.clearSlots(sa)
return b.expr.Eval(sa)
}
func (b *dynamicBlock) clearSlots(sa *dynamicSlotActivation) {
sa.reset()
b.slotActivationPool.Put(sa)
}
type slotVal struct {
value *ref.Val
visited bool
}
type dynamicSlotActivation struct {
interpreter.Activation
slotExprs []interpreter.Interpretable
slotCount int
slotVals []*slotVal
}
// ResolveName implements the Activation interface method but handles variables prefixed with `@index`
// as special variables which exist within the slot-based memory of the cel.@block() where each slot
// refers to an expression which must be computed only once.
func (sa *dynamicSlotActivation) ResolveName(name string) (any, bool) {
if idx, found := matchSlot(name, sa.slotCount); found {
v := sa.slotVals[idx]
if v.visited {
// Return not found if the index expression refers to itself
if v.value == nil {
return nil, false
}
return *v.value, true
}
v.visited = true
val := sa.slotExprs[idx].Eval(sa)
v.value = &val
return val, true
}
return sa.Activation.ResolveName(name)
}
func (sa *dynamicSlotActivation) reset() {
sa.Activation = nil
for _, sv := range sa.slotVals {
sv.visited = false
sv.value = nil
}
}
func newConstantBlock(slots traits.Lister, expr interpreter.Interpretable) interpreter.Interpretable {
count := slots.Size().(types.Int)
return &constantBlock{slots: slots, slotCount: int(count), expr: expr}
}
type constantBlock struct {
slots traits.Lister
slotCount int
expr interpreter.Interpretable
}
// ID implements the interpreter.Interpretable interface method.
func (b *constantBlock) ID() int64 {
return b.expr.ID()
}
// Eval implements the interpreter.Interpretable interface method, and will proxy @index prefixed variable
// lookups into a set of constant slots determined from the plan step.
func (b *constantBlock) Eval(activation interpreter.Activation) ref.Val {
vars := constantSlotActivation{Activation: activation, slots: b.slots, slotCount: b.slotCount}
return b.expr.Eval(vars)
}
type constantSlotActivation struct {
interpreter.Activation
slots traits.Lister
slotCount int
}
// ResolveName implements Activation interface method and proxies @index prefixed lookups into the slot
// activation associated with the block scope.
func (sa constantSlotActivation) ResolveName(name string) (any, bool) {
if idx, found := matchSlot(name, sa.slotCount); found {
return sa.slots.Get(types.Int(idx)), true
}
return sa.Activation.ResolveName(name)
}
func matchSlot(name string, slotCount int) (int, bool) {
if idx, found := strings.CutPrefix(name, indexPrefix); found {
idx, err := strconv.Atoi(idx)
// Return not found if the index is not numeric
if err != nil {
return -1, false
}
// Return not found if the index is not a valid slot
if idx < 0 || idx >= slotCount {
return -1, false
}
return idx, true
}
return -1, false
}
var (
indexPrefix = "@index"
)

410
vendor/github.com/google/cel-go/ext/comprehensions.go generated vendored Normal file
View File

@@ -0,0 +1,410 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
)
const (
mapInsert = "cel.@mapInsert"
mapInsertOverloadMap = "@mapInsert_map_map"
mapInsertOverloadKeyValue = "@mapInsert_map_key_value"
)
// TwoVarComprehensions introduces support for two-variable comprehensions.
//
// The two-variable form of comprehensions looks similar to the one-variable counterparts.
// Where possible, the same macro names were used and additional macro signatures added.
// The notable distinction for two-variable comprehensions is the introduction of
// `transformList`, `transformMap`, and `transformMapEntry` support for list and map types
// rather than the more traditional `map` and `filter` macros.
//
// # All
//
// Comprehension which tests whether all elements in the list or map satisfy a given
// predicate. The `all` macro evaluates in a manner consistent with logical AND and will
// short-circuit when encountering a `false` value.
//
// <list>.all(indexVar, valueVar, <predicate>) -> bool
// <map>.all(keyVar, valueVar, <predicate>) -> bool
//
// Examples:
//
// [1, 2, 3].all(i, j, i < j) // returns true
// {'hello': 'world', 'taco': 'taco'}.all(k, v, k != v) // returns false
//
// // Combines two-variable comprehension with single variable
// {'h': ['hello', 'hi'], 'j': ['joke', 'jog']}
// .all(k, vals, vals.all(v, v.startsWith(k))) // returns true
//
// # Exists
//
// Comprehension which tests whether any element in a list or map exists which satisfies
// a given predicate. The `exists` macro evaluates in a manner consistent with logical OR
// and will short-circuit when encountering a `true` value.
//
// <list>.exists(indexVar, valueVar, <predicate>) -> bool
// <map>.exists(keyVar, valueVar, <predicate>) -> bool
//
// Examples:
//
// {'greeting': 'hello', 'farewell': 'goodbye'}
// .exists(k, v, k.startsWith('good') || v.endsWith('bye')) // returns true
// [1, 2, 4, 8, 16].exists(i, v, v == 1024 && i == 10) // returns false
//
// # ExistsOne
//
// Comprehension which tests whether exactly one element in a list or map exists which
// satisfies a given predicate expression. This comprehension does not short-circuit in
// keeping with the one-variable exists one macro semantics.
//
// <list>.existsOne(indexVar, valueVar, <predicate>)
// <map>.existsOne(keyVar, valueVar, <predicate>)
//
// This macro may also be used with the `exists_one` function name, for compatibility
// with the one-variable macro of the same name.
//
// Examples:
//
// [1, 2, 1, 3, 1, 4].existsOne(i, v, i == 1 || v == 1) // returns false
// [1, 1, 2, 2, 3, 3].existsOne(i, v, i == 2 && v == 2) // returns true
// {'i': 0, 'j': 1, 'k': 2}.existsOne(i, v, i == 'l' || v == 1) // returns true
//
// # TransformList
//
// Comprehension which converts a map or a list into a list value. The output expression
// of the comprehension determines the contents of the output list. Elements in the list
// may optionally be filtered according to a predicate expression, where elements that
// satisfy the predicate are transformed.
//
// <list>.transformList(indexVar, valueVar, <transform>)
// <list>.transformList(indexVar, valueVar, <filter>, <transform>)
// <map>.transformList(keyVar, valueVar, <transform>)
// <map>.transformList(keyVar, valueVar, <filter>, <transform>)
//
// Examples:
//
// [1, 2, 3].transformList(indexVar, valueVar,
// (indexVar * valueVar) + valueVar) // returns [1, 4, 9]
// [1, 2, 3].transformList(indexVar, valueVar, indexVar % 2 == 0
// (indexVar * valueVar) + valueVar) // returns [1, 9]
// {'greeting': 'hello', 'farewell': 'goodbye'}
// .transformList(k, _, k) // returns ['greeting', 'farewell']
// {'greeting': 'hello', 'farewell': 'goodbye'}
// .transformList(_, v, v) // returns ['hello', 'goodbye']
//
// # TransformMap
//
// Comprehension which converts a map or a list into a map value. The output expression
// of the comprehension determines the value of the output map entry; however, the key
// remains fixed. Elements in the map may optionally be filtered according to a predicate
// expression, where elements that satisfy the predicate are transformed.
//
// <list>.transformMap(indexVar, valueVar, <transform>)
// <list>.transformMap(indexVar, valueVar, <filter>, <transform>)
// <map>.transformMap(keyVar, valueVar, <transform>)
// <map>.transformMap(keyVar, valueVar, <filter>, <transform>)
//
// Examples:
//
// [1, 2, 3].transformMap(indexVar, valueVar,
// (indexVar * valueVar) + valueVar) // returns {0: 1, 1: 4, 2: 9}
// [1, 2, 3].transformMap(indexVar, valueVar, indexVar % 2 == 0
// (indexVar * valueVar) + valueVar) // returns {0: 1, 2: 9}
// {'greeting': 'hello'}.transformMap(k, v, v + '!') // returns {'greeting': 'hello!'}
//
// # TransformMapEntry
//
// Comprehension which converts a map or a list into a map value; however, this transform
// expects the entry expression be a map literal. If the tranform produces an entry which
// duplicates a key in the target map, the comprehension will error. Note, that key
// equality is determined using CEL equality which asserts that numeric values which are
// equal, even if they don't have the same type will cause a key collision.
//
// Elements in the map may optionally be filtered according to a predicate expression, where
// elements that satisfy the predicate are transformed.
//
// <list>.transformMap(indexVar, valueVar, <transform>)
// <list>.transformMap(indexVar, valueVar, <filter>, <transform>)
// <map>.transformMap(keyVar, valueVar, <transform>)
// <map>.transformMap(keyVar, valueVar, <filter>, <transform>)
//
// Examples:
//
// // returns {'hello': 'greeting'}
// {'greeting': 'hello'}.transformMapEntry(keyVar, valueVar, {valueVar: keyVar})
// // reverse lookup, require all values in list be unique
// [1, 2, 3].transformMapEntry(indexVar, valueVar, {valueVar: indexVar})
//
// {'greeting': 'aloha', 'farewell': 'aloha'}
// .transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key
func TwoVarComprehensions() cel.EnvOption {
return cel.Lib(compreV2Lib{})
}
type compreV2Lib struct{}
// LibraryName implements that SingletonLibrary interface method.
func (compreV2Lib) LibraryName() string {
return "cel.lib.ext.comprev2"
}
// CompileOptions implements the cel.Library interface method.
func (compreV2Lib) CompileOptions() []cel.EnvOption {
kType := cel.TypeParamType("K")
vType := cel.TypeParamType("V")
mapKVType := cel.MapType(kType, vType)
opts := []cel.EnvOption{
cel.Macros(
cel.ReceiverMacro("all", 3, quantifierAll),
cel.ReceiverMacro("exists", 3, quantifierExists),
cel.ReceiverMacro("existsOne", 3, quantifierExistsOne),
cel.ReceiverMacro("exists_one", 3, quantifierExistsOne),
cel.ReceiverMacro("transformList", 3, transformList),
cel.ReceiverMacro("transformList", 4, transformList),
cel.ReceiverMacro("transformMap", 3, transformMap),
cel.ReceiverMacro("transformMap", 4, transformMap),
cel.ReceiverMacro("transformMapEntry", 3, transformMapEntry),
cel.ReceiverMacro("transformMapEntry", 4, transformMapEntry),
),
cel.Function(mapInsert,
cel.Overload(mapInsertOverloadKeyValue, []*cel.Type{mapKVType, kType, vType}, mapKVType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
m := args[0].(traits.Mapper)
k := args[1]
v := args[2]
return types.InsertMapKeyValue(m, k, v)
})),
cel.Overload(mapInsertOverloadMap, []*cel.Type{mapKVType, mapKVType}, mapKVType,
cel.BinaryBinding(func(targetMap, updateMap ref.Val) ref.Val {
tm := targetMap.(traits.Mapper)
um := updateMap.(traits.Mapper)
umIt := um.Iterator()
for umIt.HasNext() == types.True {
k := umIt.Next()
updateOrErr := types.InsertMapKeyValue(tm, k, um.Get(k))
if types.IsError(updateOrErr) {
return updateOrErr
}
tm = updateOrErr.(traits.Mapper)
}
return tm
})),
),
}
return opts
}
// ProgramOptions implements the cel.Library interface method
func (compreV2Lib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
/*accuInit=*/ mef.NewLiteral(types.True),
/*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewAccuIdent()),
/*step=*/ mef.NewCall(operators.LogicalAnd, mef.NewAccuIdent(), args[2]),
/*result=*/ mef.NewAccuIdent(),
), nil
}
func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
/*accuInit=*/ mef.NewLiteral(types.False),
/*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewCall(operators.LogicalNot, mef.NewAccuIdent())),
/*step=*/ mef.NewCall(operators.LogicalOr, mef.NewAccuIdent(), args[2]),
/*result=*/ mef.NewAccuIdent(),
), nil
}
func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
/*accuInit=*/ mef.NewLiteral(types.Int(0)),
/*condition=*/ mef.NewLiteral(types.True),
/*step=*/ mef.NewCall(operators.Conditional, args[2],
mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewLiteral(types.Int(1))),
mef.NewAccuIdent()),
/*result=*/ mef.NewCall(operators.Equals, mef.NewAccuIdent(), mef.NewLiteral(types.Int(1))),
), nil
}
func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
var transform ast.Expr
var filter ast.Expr
if len(args) == 4 {
filter = args[2]
transform = args[3]
} else {
filter = nil
transform = args[2]
}
// __result__ = __result__ + [transform]
step := mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewList(transform))
if filter != nil {
// __result__ = (filter) ? __result__ + [transform] : __result__
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
/*accuInit=*/ mef.NewList(),
/*condition=*/ mef.NewLiteral(types.True),
step,
/*result=*/ mef.NewAccuIdent(),
), nil
}
func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
var transform ast.Expr
var filter ast.Expr
if len(args) == 4 {
filter = args[2]
transform = args[3]
} else {
filter = nil
transform = args[2]
}
// __result__ = cel.@mapInsert(__result__, iterVar1, transform)
step := mef.NewCall(mapInsert, mef.NewAccuIdent(), mef.NewIdent(iterVar1), transform)
if filter != nil {
// __result__ = (filter) ? cel.@mapInsert(__result__, iterVar1, transform) : __result__
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
/*accuInit=*/ mef.NewMap(),
/*condition=*/ mef.NewLiteral(types.True),
step,
/*result=*/ mef.NewAccuIdent(),
), nil
}
func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
var transform ast.Expr
var filter ast.Expr
if len(args) == 4 {
filter = args[2]
transform = args[3]
} else {
filter = nil
transform = args[2]
}
// __result__ = cel.@mapInsert(__result__, transform)
step := mef.NewCall(mapInsert, mef.NewAccuIdent(), transform)
if filter != nil {
// __result__ = (filter) ? cel.@mapInsert(__result__, transform) : __result__
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
}
return mef.NewComprehensionTwoVar(
target,
iterVar1,
iterVar2,
parser.AccumulatorName,
/*accuInit=*/ mef.NewMap(),
/*condition=*/ mef.NewLiteral(types.True),
step,
/*result=*/ mef.NewAccuIdent(),
), nil
}
func extractIterVars(mef cel.MacroExprFactory, arg0, arg1 ast.Expr) (string, string, *cel.Error) {
iterVar1, err := extractIterVar(mef, arg0)
if err != nil {
return "", "", err
}
iterVar2, err := extractIterVar(mef, arg1)
if err != nil {
return "", "", err
}
if iterVar1 == iterVar2 {
return "", "", mef.NewError(arg1.ID(), fmt.Sprintf("duplicate variable name: %s", iterVar1))
}
if iterVar1 == parser.AccumulatorName {
return "", "", mef.NewError(arg0.ID(), "iteration variable overwrites accumulator variable")
}
if iterVar2 == parser.AccumulatorName {
return "", "", mef.NewError(arg1.ID(), "iteration variable overwrites accumulator variable")
}
return iterVar1, iterVar2, nil
}
func extractIterVar(mef cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) {
iterVar, found := extractIdent(target)
if !found {
return "", mef.NewError(target.ID(), "argument must be a simple name")
}
return iterVar, nil
}

View File

@@ -36,7 +36,7 @@ import (
// Examples:
//
// base64.decode('aGVsbG8=') // return b'hello'
// base64.decode('aGVsbG8') // error
// base64.decode('aGVsbG8') // return b'hello'
//
// # Base64.Encode
//
@@ -79,7 +79,14 @@ func (encoderLib) ProgramOptions() []cel.ProgramOption {
}
func base64DecodeString(str string) ([]byte, error) {
return base64.StdEncoding.DecodeString(str)
b, err := base64.StdEncoding.DecodeString(str)
if err == nil {
return b, nil
}
if _, tryAltEncoding := err.(base64.CorruptInputError); tryAltEncoding {
return base64.RawStdEncoding.DecodeString(str)
}
return nil, err
}
func base64EncodeBytes(bytes []byte) (string, error) {

View File

@@ -50,14 +50,18 @@ func listStringOrError(strs []string, err error) ref.Val {
return types.DefaultTypeAdapter.NativeToValue(strs)
}
func macroTargetMatchesNamespace(ns string, target ast.Expr) bool {
func extractIdent(target ast.Expr) (string, bool) {
switch target.Kind() {
case ast.IdentKind:
if target.AsIdent() != ns {
return false
}
return true
return target.AsIdent(), true
default:
return false
return "", false
}
}
func macroTargetMatchesNamespace(ns string, target ast.Expr) bool {
if id, found := extractIdent(target); found {
return id == ns
}
return false
}

View File

@@ -16,15 +16,70 @@ package ext
import (
"fmt"
"math"
"sort"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
)
var comparableTypes = []*cel.Type{
cel.IntType,
cel.UintType,
cel.DoubleType,
cel.BoolType,
cel.DurationType,
cel.TimestampType,
cel.StringType,
cel.BytesType,
}
// Lists returns a cel.EnvOption to configure extended functions for list manipulation.
// As a general note, all indices are zero-based.
//
// # Distinct
//
// Introduced in version: 2
//
// Returns the distinct elements of a list.
//
// <list(T)>.distinct() -> <list(T)>
//
// Examples:
//
// [1, 2, 2, 3, 3, 3].distinct() // return [1, 2, 3]
// ["b", "b", "c", "a", "c"].distinct() // return ["b", "c", "a"]
// [1, "b", 2, "b"].distinct() // return [1, "b", 2]
//
// # Range
//
// Introduced in version: 2
//
// Returns a list of integers from 0 to n-1.
//
// lists.range(<int>) -> <list(int)>
//
// Examples:
//
// lists.range(5) -> [0, 1, 2, 3, 4]
//
// # Reverse
//
// Introduced in version: 2
//
// Returns the elements of a list in reverse order.
//
// <list(T)>.reverse() -> <list(T)>
//
// Examples:
//
// [5, 3, 1, 2].reverse() // return [2, 1, 3, 5]
//
// # Slice
//
// Returns a new sub-list using the indexes provided.
@@ -35,21 +90,105 @@ import (
//
// [1,2,3,4].slice(1, 3) // return [2, 3]
// [1,2,3,4].slice(2, 4) // return [3 ,4]
func Lists() cel.EnvOption {
return cel.Lib(listsLib{})
//
// # Flatten
//
// Flattens a list recursively.
// If an optional depth is provided, the list is flattened to a the specificied level.
// A negative depth value will result in an error.
//
// <list>.flatten(<list>) -> <list>
// <list>.flatten(<list>, <int>) -> <list>
//
// Examples:
//
// [1,[2,3],[4]].flatten() // return [1, 2, 3, 4]
// [1,[2,[3,4]]].flatten() // return [1, 2, [3, 4]]
// [1,2,[],[],[3,4]].flatten() // return [1, 2, 3, 4]
// [1,[2,[3,[4]]]].flatten(2) // return [1, 2, 3, [4]]
// [1,[2,[3,[4]]]].flatten(-1) // error
//
// # Sort
//
// Introduced in version: 2
//
// Sorts a list with comparable elements. If the element type is not comparable
// or the element types are not the same, the function will produce an error.
//
// <list(T)>.sort() -> <list(T)>
// T in {int, uint, double, bool, duration, timestamp, string, bytes}
//
// Examples:
//
// [3, 2, 1].sort() // return [1, 2, 3]
// ["b", "c", "a"].sort() // return ["a", "b", "c"]
// [1, "b"].sort() // error
// [[1, 2, 3]].sort() // error
//
// # SortBy
//
// Sorts a list by a key value, i.e., the order is determined by the result of
// an expression applied to each element of the list.
// The output of the key expression must be a comparable type, otherwise the
// function will return an error.
//
// <list(T)>.sortBy(<bindingName>, <keyExpr>) -> <list(T)>
// keyExpr returns a value in {int, uint, double, bool, duration, timestamp, string, bytes}
// Examples:
//
// [
// Player { name: "foo", score: 0 },
// Player { name: "bar", score: -10 },
// Player { name: "baz", score: 1000 },
// ].sortBy(e, e.score).map(e, e.name)
// == ["bar", "foo", "baz"]
func Lists(options ...ListsOption) cel.EnvOption {
l := &listsLib{
version: math.MaxUint32,
}
for _, o := range options {
l = o(l)
}
return cel.Lib(l)
}
type listsLib struct{}
type listsLib struct {
version uint32
}
// LibraryName implements the SingletonLibrary interface method.
func (listsLib) LibraryName() string {
return "cel.lib.ext.lists"
}
// ListsOption is a functional interface for configuring the strings library.
type ListsOption func(*listsLib) *listsLib
// ListsVersion configures the version of the string library.
//
// The version limits which functions are available. Only functions introduced
// below or equal to the given version included in the library. If this option
// is not set, all functions are available.
//
// See the library documentation to determine which version a function was introduced.
// If the documentation does not state which version a function was introduced, it can
// be assumed to be introduced at version 0, when the library was first created.
func ListsVersion(version uint32) ListsOption {
return func(lib *listsLib) *listsLib {
lib.version = version
return lib
}
}
// CompileOptions implements the Library interface method.
func (listsLib) CompileOptions() []cel.EnvOption {
func (lib listsLib) CompileOptions() []cel.EnvOption {
listType := cel.ListType(cel.TypeParamType("T"))
return []cel.EnvOption{
listListType := cel.ListType(listType)
listDyn := cel.ListType(cel.DynType)
opts := []cel.EnvOption{
cel.Function("slice",
cel.MemberOverload("list_slice",
[]*cel.Type{listType, cel.IntType, cel.IntType}, listType,
@@ -66,6 +205,151 @@ func (listsLib) CompileOptions() []cel.EnvOption {
),
),
}
if lib.version >= 1 {
opts = append(opts,
cel.Function("flatten",
cel.MemberOverload("list_flatten",
[]*cel.Type{listListType}, listType,
cel.UnaryBinding(func(arg ref.Val) ref.Val {
list, ok := arg.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
flatList, err := flatten(list, 1)
if err != nil {
return types.WrapErr(err)
}
return types.DefaultTypeAdapter.NativeToValue(flatList)
}),
),
cel.MemberOverload("list_flatten_int",
[]*cel.Type{listDyn, types.IntType}, listDyn,
cel.BinaryBinding(func(arg1, arg2 ref.Val) ref.Val {
list, ok := arg1.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
depth, ok := arg2.(types.Int)
if !ok {
return types.MaybeNoSuchOverloadErr(arg2)
}
flatList, err := flatten(list, int64(depth))
if err != nil {
return types.WrapErr(err)
}
return types.DefaultTypeAdapter.NativeToValue(flatList)
}),
),
// To handle the case where a variable of just `list(T)` is provided at runtime
// with a graceful failure more, disable the type guards since the implementation
// can handle lists which are already flat.
decls.DisableTypeGuards(true),
),
)
}
if lib.version >= 2 {
sortDecl := cel.Function("sort",
append(
templatedOverloads(comparableTypes, func(t *cel.Type) cel.FunctionOpt {
return cel.MemberOverload(
fmt.Sprintf("list_%s_sort", t.TypeName()),
[]*cel.Type{cel.ListType(t)}, cel.ListType(t),
)
}),
cel.SingletonUnaryBinding(
func(arg ref.Val) ref.Val {
list, ok := arg.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
sorted, err := sortList(list)
if err != nil {
return types.WrapErr(err)
}
return sorted
},
// List traits
traits.ListerType,
),
)...,
)
opts = append(opts, sortDecl)
opts = append(opts, cel.Macros(cel.ReceiverMacro("sortBy", 2, sortByMacro)))
opts = append(opts, cel.Function("@sortByAssociatedKeys",
append(
templatedOverloads(comparableTypes, func(u *cel.Type) cel.FunctionOpt {
return cel.MemberOverload(
fmt.Sprintf("list_%s_sortByAssociatedKeys", u.TypeName()),
[]*cel.Type{listType, cel.ListType(u)}, listType,
)
}),
cel.SingletonBinaryBinding(
func(arg1 ref.Val, arg2 ref.Val) ref.Val {
list, ok := arg1.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
keys, ok := arg2.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg2)
}
sorted, err := sortListByAssociatedKeys(list, keys)
if err != nil {
return types.WrapErr(err)
}
return sorted
},
// List traits
traits.ListerType,
),
)...,
))
opts = append(opts, cel.Function("lists.range",
cel.Overload("lists_range",
[]*cel.Type{cel.IntType}, cel.ListType(cel.IntType),
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
n := args[0].(types.Int)
result, err := genRange(n)
if err != nil {
return types.WrapErr(err)
}
return result
}),
),
))
opts = append(opts, cel.Function("reverse",
cel.MemberOverload("list_reverse",
[]*cel.Type{listType}, listType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
list := args[0].(traits.Lister)
result, err := reverseList(list)
if err != nil {
return types.WrapErr(err)
}
return result
}),
),
))
opts = append(opts, cel.Function("distinct",
cel.MemberOverload("list_distinct",
[]*cel.Type{listType}, listType,
cel.UnaryBinding(func(list ref.Val) ref.Val {
result, err := distinctList(list.(traits.Lister))
if err != nil {
return types.WrapErr(err)
}
return result
}),
),
))
}
return opts
}
// ProgramOptions implements the Library interface method.
@@ -73,6 +357,24 @@ func (listsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func genRange(n types.Int) (ref.Val, error) {
var newList []ref.Val
for i := types.Int(0); i < n; i++ {
newList = append(newList, i)
}
return types.DefaultTypeAdapter.NativeToValue(newList), nil
}
func reverseList(list traits.Lister) (ref.Val, error) {
var newList []ref.Val
listLength := list.Size().(types.Int)
for i := types.Int(0); i < listLength; i++ {
val := list.Get(listLength - i - 1)
newList = append(newList, val)
}
return types.DefaultTypeAdapter.NativeToValue(newList), nil
}
func slice(list traits.Lister, start, end types.Int) (ref.Val, error) {
listLength := list.Size().(types.Int)
if start < 0 || end < 0 {
@@ -92,3 +394,167 @@ func slice(list traits.Lister, start, end types.Int) (ref.Val, error) {
}
return types.DefaultTypeAdapter.NativeToValue(newList), nil
}
func flatten(list traits.Lister, depth int64) ([]ref.Val, error) {
if depth < 0 {
return nil, fmt.Errorf("level must be non-negative")
}
var newList []ref.Val
iter := list.Iterator()
for iter.HasNext() == types.True {
val := iter.Next()
nestedList, isList := val.(traits.Lister)
if !isList || depth == 0 {
newList = append(newList, val)
continue
} else {
flattenedList, err := flatten(nestedList, depth-1)
if err != nil {
return nil, err
}
newList = append(newList, flattenedList...)
}
}
return newList, nil
}
func sortList(list traits.Lister) (ref.Val, error) {
return sortListByAssociatedKeys(list, list)
}
// Internal function used for the implementation of sort() and sortBy().
//
// Sorts a list of arbitrary elements, according to the order produced by sorting
// another list of comparable elements. If the element type of the keys is not
// comparable or the element types are not the same, the function will produce an error.
//
// <list(T)>.@sortByAssociatedKeys(<list(U)>) -> <list(T)>
// U in {int, uint, double, bool, duration, timestamp, string, bytes}
//
// Example:
//
// ["foo", "bar", "baz"].@sortByAssociatedKeys([3, 1, 2]) // return ["bar", "baz", "foo"]
func sortListByAssociatedKeys(list, keys traits.Lister) (ref.Val, error) {
listLength := list.Size().(types.Int)
keysLength := keys.Size().(types.Int)
if listLength != keysLength {
return nil, fmt.Errorf(
"@sortByAssociatedKeys() expected a list of the same size as the associated keys list, but got %d and %d elements respectively",
listLength,
keysLength,
)
}
if listLength == 0 {
return list, nil
}
elem := keys.Get(types.IntZero)
if _, ok := elem.(traits.Comparer); !ok {
return nil, fmt.Errorf("list elements must be comparable")
}
sortedIndices := make([]ref.Val, 0, listLength)
for i := types.IntZero; i < listLength; i++ {
if keys.Get(i).Type() != elem.Type() {
return nil, fmt.Errorf("list elements must have the same type")
}
sortedIndices = append(sortedIndices, i)
}
sort.Slice(sortedIndices, func(i, j int) bool {
iKey := keys.Get(sortedIndices[i])
jKey := keys.Get(sortedIndices[j])
return iKey.(traits.Comparer).Compare(jKey) == types.IntNegOne
})
sorted := make([]ref.Val, 0, listLength)
for _, sortedIdx := range sortedIndices {
sorted = append(sorted, list.Get(sortedIdx))
}
return types.DefaultTypeAdapter.NativeToValue(sorted), nil
}
// sortByMacro transforms an expression like:
//
// mylistExpr.sortBy(e, -math.abs(e))
//
// into something equivalent to:
//
// cel.bind(
// __sortBy_input__,
// myListExpr,
// __sortBy_input__.@sortByAssociatedKeys(__sortBy_input__.map(e, -math.abs(e))
// )
func sortByMacro(meh cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
varIdent := meh.NewIdent("@__sortBy_input__")
varName := varIdent.AsIdent()
targetKind := target.Kind()
if targetKind != ast.ListKind &&
targetKind != ast.SelectKind &&
targetKind != ast.IdentKind &&
targetKind != ast.ComprehensionKind && targetKind != ast.CallKind {
return nil, meh.NewError(target.ID(), fmt.Sprintf("sortBy can only be applied to a list, identifier, comprehension, call or select expression"))
}
mapCompr, err := parser.MakeMap(meh, meh.Copy(varIdent), args)
if err != nil {
return nil, err
}
callExpr := meh.NewMemberCall("@sortByAssociatedKeys",
meh.Copy(varIdent),
mapCompr,
)
bindExpr := meh.NewComprehension(
meh.NewList(),
"#unused",
varName,
target,
meh.NewLiteral(types.False),
varIdent,
callExpr,
)
return bindExpr, nil
}
func distinctList(list traits.Lister) (ref.Val, error) {
listLength := list.Size().(types.Int)
if listLength == 0 {
return list, nil
}
uniqueList := make([]ref.Val, 0, listLength)
for i := types.IntZero; i < listLength; i++ {
val := list.Get(i)
seen := false
for j := types.IntZero; j < types.Int(len(uniqueList)); j++ {
if i == j {
continue
}
other := uniqueList[j]
if val.Equal(other) == types.True {
seen = true
break
}
}
if !seen {
uniqueList = append(uniqueList, val)
}
}
return types.DefaultTypeAdapter.NativeToValue(uniqueList), nil
}
func templatedOverloads(types []*cel.Type, template func(t *cel.Type) cel.FunctionOpt) []cel.FunctionOpt {
overloads := make([]cel.FunctionOpt, len(types))
for i, t := range types {
overloads[i] = template(t)
}
return overloads
}

View File

@@ -325,8 +325,12 @@ import (
//
// math.isFinite(0.0/0.0) // returns false
// math.isFinite(1.2) // returns true
func Math() cel.EnvOption {
return cel.Lib(&mathLib{version: math.MaxUint32})
func Math(options ...MathOption) cel.EnvOption {
m := &mathLib{version: math.MaxUint32}
for _, o := range options {
m = o(m)
}
return cel.Lib(m)
}
const (
@@ -366,8 +370,10 @@ var (
errIntOverflow = types.NewErr("integer overflow")
)
// MathOption declares a functional operator for configuring math extensions.
type MathOption func(*mathLib) *mathLib
// MathVersion sets the library version for math extensions.
func MathVersion(version uint32) MathOption {
return func(lib *mathLib) *mathLib {
lib.version = version

View File

@@ -128,16 +128,66 @@ func NativeTypes(args ...any) cel.EnvOption {
// NativeTypesOption is a functional interface for configuring handling of native types.
type NativeTypesOption func(*nativeTypeOptions) error
// NativeTypesFieldNameHandler is a handler for mapping a reflect.StructField to a CEL field name.
// This can be used to override the default Go struct field to CEL field name mapping.
type NativeTypesFieldNameHandler = func(field reflect.StructField) string
func fieldNameByTag(structTagToParse string) func(field reflect.StructField) string {
return func(field reflect.StructField) string {
tag, found := field.Tag.Lookup(structTagToParse)
if found {
splits := strings.Split(tag, ",")
if len(splits) > 0 {
// We make the assumption that the leftmost entry in the tag is the name.
// This seems to be true for most tags that have the concept of a name/key, such as:
// https://pkg.go.dev/encoding/xml#Marshal
// https://pkg.go.dev/encoding/json#Marshal
// https://pkg.go.dev/go.mongodb.org/mongo-driver/bson#hdr-Structs
// https://pkg.go.dev/gopkg.in/yaml.v2#Marshal
name := splits[0]
return name
}
}
return field.Name
}
}
type nativeTypeOptions struct {
// parseStructTags controls if CEL should support struct field renames, by parsing
// struct field tags.
parseStructTags bool
// fieldNameHandler controls how CEL should perform struct field renames.
// This is most commonly used for switching to parsing based off the struct field tag,
// such as "cel" or "json".
fieldNameHandler NativeTypesFieldNameHandler
}
// ParseStructTags configures if native types field names should be overridable by CEL struct tags.
// This is equivalent to ParseStructTag("cel")
func ParseStructTags(enabled bool) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.parseStructTags = true
if enabled {
ntp.fieldNameHandler = fieldNameByTag("cel")
} else {
ntp.fieldNameHandler = nil
}
return nil
}
}
// ParseStructTag configures the struct tag to parse. The 0th item in the tag is used as the name of the CEL field.
// For example:
// If the tag to parse is "cel" and the struct field has tag cel:"foo", the CEL struct field will be "foo".
// If the tag to parse is "json" and the struct field has tag json:"foo,omitempty", the CEL struct field will be "foo".
func ParseStructTag(tag string) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.fieldNameHandler = fieldNameByTag(tag)
return nil
}
}
// ParseStructField configures how to parse Go struct fields. It can be used to customize struct field parsing.
func ParseStructField(handler NativeTypesFieldNameHandler) NativeTypesOption {
return func(ntp *nativeTypeOptions) error {
ntp.fieldNameHandler = handler
return nil
}
}
@@ -147,7 +197,7 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
result, err := newNativeTypes(tpOptions.parseStructTags, rt)
result, err := newNativeTypes(tpOptions.fieldNameHandler, rt)
if err != nil {
return nil, err
}
@@ -155,7 +205,7 @@ func newNativeTypeProvider(tpOptions nativeTypeOptions, adapter types.Adapter, p
nativeTypes[result[idx].TypeName()] = result[idx]
}
case reflect.Value:
result, err := newNativeTypes(tpOptions.parseStructTags, rt.Type())
result, err := newNativeTypes(tpOptions.fieldNameHandler, rt.Type())
if err != nil {
return nil, err
}
@@ -208,16 +258,12 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool
return tp.baseProvider.FindStructType(typeName)
}
func toFieldName(parseStructTag bool, f reflect.StructField) string {
if !parseStructTag {
func toFieldName(fieldNameHandler NativeTypesFieldNameHandler, f reflect.StructField) string {
if fieldNameHandler == nil {
return f.Name
}
if name, found := f.Tag.Lookup("cel"); found {
return name
}
return f.Name
return fieldNameHandler(f)
}
// FindStructFieldNames looks up the type definition first from the native types, then from
@@ -228,7 +274,7 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
fieldCount := t.refType.NumField()
fields := make([]string, fieldCount)
for i := 0; i < fieldCount; i++ {
fields[i] = toFieldName(tp.options.parseStructTags, t.refType.Field(i))
fields[i] = toFieldName(tp.options.fieldNameHandler, t.refType.Field(i))
}
return fields, true
}
@@ -238,22 +284,6 @@ func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, b
return tp.baseProvider.FindStructFieldNames(typeName)
}
// valueFieldByName retrieves the corresponding reflect.Value field for the given field name, by
// searching for a matching field tag value or field name.
func valueFieldByName(parseStructTags bool, target reflect.Value, fieldName string) reflect.Value {
if !parseStructTags {
return target.FieldByName(fieldName)
}
for i := 0; i < target.Type().NumField(); i++ {
f := target.Type().Field(i)
if toFieldName(parseStructTags, f) == fieldName {
return target.FieldByIndex(f.Index)
}
}
return reflect.Value{}
}
// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
@@ -273,12 +303,12 @@ func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*
Type: celType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
refField := refVal.FieldByName(refField.Name)
return !refField.IsZero()
},
GetFrom: func(obj any) (any, error) {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := valueFieldByName(tp.options.parseStructTags, refVal, fieldName)
refField := refVal.FieldByName(refField.Name)
return getFieldValue(refField), nil
},
}, true
@@ -404,7 +434,7 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
}
func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(tp.options.parseStructTags, refValue.Type())
valType, err := newNativeType(tp.options.fieldNameHandler, refValue.Type())
if err != nil {
return types.NewErr(err.Error())
}
@@ -456,7 +486,7 @@ func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldName := toFieldName(o.valType.parseStructTags, fieldType)
fieldName := toFieldName(o.valType.fieldNameHandler, fieldType)
fieldCELVal := o.NativeToValue(fieldValue.Interface())
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
if err != nil {
@@ -554,8 +584,8 @@ func (o *nativeObj) Value() any {
return o.val
}
func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(parseStructTags, rawType)
func newNativeTypes(fieldNameHandler NativeTypesFieldNameHandler, rawType reflect.Type) ([]*nativeType, error) {
nt, err := newNativeType(fieldNameHandler, rawType)
if err != nil {
return nil, err
}
@@ -574,7 +604,7 @@ func newNativeTypes(parseStructTags bool, rawType reflect.Type) ([]*nativeType,
return
}
alreadySeen[t.String()] = struct{}{}
nt, ntErr := newNativeType(parseStructTags, t)
nt, ntErr := newNativeType(fieldNameHandler, t)
if ntErr != nil {
err = ntErr
return
@@ -594,7 +624,7 @@ var (
errDuplicatedFieldName = errors.New("field name already exists in struct")
)
func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, error) {
func newNativeType(fieldNameHandler NativeTypesFieldNameHandler, rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
refType = refType.Elem()
@@ -604,12 +634,12 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
}
// Since naming collisions can only happen with struct tag parsing, we only check for them if it is enabled.
if parseStructTags {
if fieldNameHandler != nil {
fieldNames := make(map[string]struct{})
for idx := 0; idx < refType.NumField(); idx++ {
field := refType.Field(idx)
fieldName := toFieldName(parseStructTags, field)
fieldName := toFieldName(fieldNameHandler, field)
if _, found := fieldNames[fieldName]; found {
return nil, fmt.Errorf("invalid field name `%s` in struct `%s`: %w", fieldName, refType.Name(), errDuplicatedFieldName)
@@ -620,16 +650,16 @@ func newNativeType(parseStructTags bool, rawType reflect.Type) (*nativeType, err
}
return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
parseStructTags: parseStructTags,
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
fieldNameHandler: fieldNameHandler,
}, nil
}
type nativeType struct {
typeName string
refType reflect.Type
parseStructTags bool
typeName string
refType reflect.Type
fieldNameHandler NativeTypesFieldNameHandler
}
// ConvertToNative implements ref.Val.ConvertToNative.
@@ -680,13 +710,13 @@ func (t *nativeType) Value() any {
// fieldByName returns the corresponding reflect.StructField for the give name either by matching
// field tag or field name.
func (t *nativeType) fieldByName(fieldName string) (reflect.StructField, bool) {
if !t.parseStructTags {
if t.fieldNameHandler == nil {
return t.refType.FieldByName(fieldName)
}
for i := 0; i < t.refType.NumField(); i++ {
f := t.refType.Field(i)
if toFieldName(t.parseStructTags, f) == fieldName {
if toFieldName(t.fieldNameHandler, f) == fieldName {
return f, true
}
}

View File

@@ -119,7 +119,8 @@ const (
// 'hello mellow'.indexOf('jello') // returns -1
// 'hello mellow'.indexOf('', 2) // returns 2
// 'hello mellow'.indexOf('ello', 2) // returns 7
// 'hello mellow'.indexOf('ello', 20) // error
// 'hello mellow'.indexOf('ello', 20) // returns -1
// 'hello mellow'.indexOf('ello', -1) // error
//
// # Join
//
@@ -155,6 +156,7 @@ const (
// 'hello mellow'.lastIndexOf('ello') // returns 7
// 'hello mellow'.lastIndexOf('jello') // returns -1
// 'hello mellow'.lastIndexOf('ello', 6) // returns 1
// 'hello mellow'.lastIndexOf('ello', 20) // returns -1
// 'hello mellow'.lastIndexOf('ello', -1) // error
//
// # LowerAscii
@@ -520,7 +522,7 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption {
if lib.version >= 3 {
opts = append(opts,
cel.Function("reverse",
cel.MemberOverload("reverse", []*cel.Type{cel.StringType}, cel.StringType,
cel.MemberOverload("string_reverse", []*cel.Type{cel.StringType}, cel.StringType,
cel.UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return stringOrError(reverse(string(s)))
@@ -561,9 +563,13 @@ func indexOfOffset(str, substr string, offset int64) (int64, error) {
off := int(offset)
runes := []rune(str)
subrunes := []rune(substr)
if off < 0 || off >= len(runes) {
if off < 0 {
return -1, fmt.Errorf("index out of range: %d", off)
}
// If the offset exceeds the length, return -1 rather than error.
if off >= len(runes) {
return -1, nil
}
for i := off; i < len(runes)-(len(subrunes)-1); i++ {
found := true
for j := 0; j < len(subrunes); j++ {
@@ -594,9 +600,13 @@ func lastIndexOfOffset(str, substr string, offset int64) (int64, error) {
off := int(offset)
runes := []rune(str)
subrunes := []rune(substr)
if off < 0 || off >= len(runes) {
if off < 0 {
return -1, fmt.Errorf("index out of range: %d", off)
}
// If the offset is far greater than the length return -1
if off >= len(runes) {
return -1, nil
}
if off > len(runes)-len(subrunes) {
off = len(runes) - len(subrunes)
}

View File

@@ -17,7 +17,6 @@ package interpreter
import (
"errors"
"fmt"
"sync"
"github.com/google/cel-go/common/types/ref"
)
@@ -167,35 +166,3 @@ type partActivation struct {
func (a *partActivation) UnknownAttributePatterns() []*AttributePattern {
return a.unknowns
}
// varActivation represents a single mutable variable binding.
//
// This activation type should only be used within folds as the fold loop controls the object
// life-cycle.
type varActivation struct {
parent Activation
name string
val ref.Val
}
// Parent implements the Activation interface method.
func (v *varActivation) Parent() Activation {
return v.parent
}
// ResolveName implements the Activation interface method.
func (v *varActivation) ResolveName(name string) (any, bool) {
if name == v.name {
return v.val, true
}
return v.parent.ResolveName(name)
}
var (
// pool of var activations to reduce allocations during folds.
varActivationPool = &sync.Pool{
New: func() any {
return &varActivation{}
},
}
)

View File

@@ -16,6 +16,7 @@ package interpreter
import (
"fmt"
"sync"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
@@ -96,7 +97,7 @@ type InterpretableCall interface {
Args() []Interpretable
}
// InterpretableConstructor interface for inspecting Interpretable instructions that initialize a list, map
// InterpretableConstructor interface for inspecting Interpretable instructions that initialize a list, map
// or struct.
type InterpretableConstructor interface {
Interpretable
@@ -720,24 +721,31 @@ func (o *evalObj) Eval(ctx Activation) ref.Val {
return types.LabelErrNode(o.id, o.provider.NewValue(o.typeName, fieldVals))
}
// InitVals implements the InterpretableConstructor interface method.
func (o *evalObj) InitVals() []Interpretable {
return o.vals
}
// Type implements the InterpretableConstructor interface method.
func (o *evalObj) Type() ref.Type {
return types.NewObjectTypeValue(o.typeName)
return types.NewObjectType(o.typeName)
}
type evalFold struct {
id int64
accuVar string
iterVar string
iterRange Interpretable
accu Interpretable
cond Interpretable
step Interpretable
result Interpretable
adapter types.Adapter
id int64
accuVar string
iterVar string
iterVar2 string
iterRange Interpretable
accu Interpretable
cond Interpretable
step Interpretable
result Interpretable
adapter types.Adapter
// note an exhaustive fold will ensure that all branches are evaluated
// when using mutable values, these branches will mutate the final result
// rather than make a throw-away computation.
exhaustive bool
interruptable bool
}
@@ -749,64 +757,30 @@ func (fold *evalFold) ID() int64 {
// Eval implements the Interpretable interface method.
func (fold *evalFold) Eval(ctx Activation) ref.Val {
// Initialize the folder interface
f := newFolder(fold, ctx)
defer releaseFolder(f)
foldRange := fold.iterRange.Eval(ctx)
if fold.iterVar2 != "" {
var foldable traits.Foldable
switch r := foldRange.(type) {
case traits.Mapper:
foldable = types.ToFoldableMap(r)
case traits.Lister:
foldable = types.ToFoldableList(r)
default:
return types.NewErrWithNodeID(fold.ID(), "unsupported comprehension range type: %T", foldRange)
}
foldable.Fold(f)
return f.evalResult()
}
if !foldRange.Type().HasTrait(traits.IterableType) {
return types.ValOrErr(foldRange, "got '%T', expected iterable type", foldRange)
}
// Configure the fold activation with the accumulator initial value.
accuCtx := varActivationPool.Get().(*varActivation)
accuCtx.parent = ctx
accuCtx.name = fold.accuVar
accuCtx.val = fold.accu.Eval(ctx)
// If the accumulator starts as an empty list, then the comprehension will build a list
// so create a mutable list to optimize the cost of the inner loop.
l, ok := accuCtx.val.(traits.Lister)
buildingList := false
if !fold.exhaustive && ok && l.Size() == types.IntZero {
buildingList = true
accuCtx.val = types.NewMutableList(fold.adapter)
}
iterCtx := varActivationPool.Get().(*varActivation)
iterCtx.parent = accuCtx
iterCtx.name = fold.iterVar
interrupted := false
it := foldRange.(traits.Iterable).Iterator()
for it.HasNext() == types.True {
// Modify the iter var in the fold activation.
iterCtx.val = it.Next()
// Evaluate the condition, terminate the loop if false.
cond := fold.cond.Eval(iterCtx)
condBool, ok := cond.(types.Bool)
if !fold.exhaustive && ok && condBool != types.True {
break
}
// Evaluate the evaluation step into accu var.
accuCtx.val = fold.step.Eval(iterCtx)
if fold.interruptable {
if stop, found := ctx.ResolveName("#interrupted"); found && stop == true {
interrupted = true
break
}
}
}
varActivationPool.Put(iterCtx)
if interrupted {
varActivationPool.Put(accuCtx)
return types.NewErr("operation interrupted")
}
// Compute the result.
res := fold.result.Eval(accuCtx)
varActivationPool.Put(accuCtx)
// Convert a mutable list to an immutable one, if the comprehension has generated a list as a result.
if !types.IsUnknownOrError(res) && buildingList {
if _, ok := res.(traits.MutableLister); ok {
res = res.(traits.MutableLister).ToImmutableList()
}
}
return res
iterable := foldRange.(traits.Iterable)
return f.foldIterable(iterable)
}
// Optional Interpretable implementations that specialize, subsume, or extend the core evaluation
@@ -1262,3 +1236,172 @@ func invalidOptionalEntryInit(field any, value ref.Val) ref.Val {
func invalidOptionalElementInit(value ref.Val) ref.Val {
return types.NewErr("cannot initialize optional list element from non-optional value %v", value)
}
// newFolder creates or initializes a pooled folder instance.
func newFolder(eval *evalFold, ctx Activation) *folder {
f := folderPool.Get().(*folder)
f.evalFold = eval
f.Activation = ctx
return f
}
// releaseFolder resets and releases a pooled folder instance.
func releaseFolder(f *folder) {
f.reset()
folderPool.Put(f)
}
// folder tracks the state associated with folding a list or map with a comprehension v2 style macro.
//
// The folder embeds an interpreter.Activation and Interpretable evalFold value as well as implements
// the traits.Folder interface methods.
//
// Instances of a folder are intended to be pooled to minimize allocation overhead with this temporary
// bookkeeping object which supports lazy evaluation of the accumulator init expression which is useful
// in preserving evaluation order semantics which might otherwise be disrupted through the use of
// cel.bind or cel.@block.
type folder struct {
*evalFold
Activation
// fold state objects.
accuVal ref.Val
iterVar1Val any
iterVar2Val any
// bookkeeping flags to modify Activation and fold behaviors.
initialized bool
mutableValue bool
interrupted bool
computeResult bool
}
func (f *folder) foldIterable(iterable traits.Iterable) ref.Val {
it := iterable.Iterator()
for it.HasNext() == types.True {
f.iterVar1Val = it.Next()
cond := f.cond.Eval(f)
condBool, ok := cond.(types.Bool)
if f.interrupted || (!f.exhaustive && ok && condBool != types.True) {
return f.evalResult()
}
// Update the accumulation value and check for eval interuption.
f.accuVal = f.step.Eval(f)
f.initialized = true
if f.interruptable && checkInterrupt(f.Activation) {
f.interrupted = true
return f.evalResult()
}
}
return f.evalResult()
}
// FoldEntry will either fold comprehension v1 style macros if iterVar2 is unset, or comprehension v2 style
// macros if both the iterVar and iterVar2 are set to non-empty strings.
func (f *folder) FoldEntry(key, val any) bool {
// Default to referencing both values.
f.iterVar1Val = key
f.iterVar2Val = val
// Terminate evaluation if evaluation is interrupted or the condition is not true and exhaustive
// eval is not enabled.
cond := f.cond.Eval(f)
condBool, ok := cond.(types.Bool)
if f.interrupted || (!f.exhaustive && ok && condBool != types.True) {
return false
}
// Update the accumulation value and check for eval interuption.
f.accuVal = f.step.Eval(f)
f.initialized = true
if f.interruptable && checkInterrupt(f.Activation) {
f.interrupted = true
return false
}
return true
}
// ResolveName overrides the default Activation lookup to perform lazy initialization of the accumulator
// and specialized lookups of iteration values with consideration for whether the final result is being
// computed and the iteration variables should be ignored.
func (f *folder) ResolveName(name string) (any, bool) {
if name == f.accuVar {
if !f.initialized {
f.initialized = true
initVal := f.accu.Eval(f.Activation)
if !f.exhaustive {
if l, isList := initVal.(traits.Lister); isList && l.Size() == types.IntZero {
initVal = types.NewMutableList(f.adapter)
f.mutableValue = true
}
if m, isMap := initVal.(traits.Mapper); isMap && m.Size() == types.IntZero {
initVal = types.NewMutableMap(f.adapter, map[ref.Val]ref.Val{})
f.mutableValue = true
}
}
f.accuVal = initVal
}
return f.accuVal, true
}
if !f.computeResult {
if name == f.iterVar {
f.iterVar1Val = f.adapter.NativeToValue(f.iterVar1Val)
return f.iterVar1Val, true
}
if name == f.iterVar2 {
f.iterVar2Val = f.adapter.NativeToValue(f.iterVar2Val)
return f.iterVar2Val, true
}
}
return f.Activation.ResolveName(name)
}
// evalResult computes the final result of the fold after all entries have been folded and accumulated.
func (f *folder) evalResult() ref.Val {
f.computeResult = true
if f.interrupted {
return types.NewErr("operation interrupted")
}
res := f.result.Eval(f)
// Convert a mutable list or map to an immutable one if the comprehension has generated a list or
// map as a result.
if !types.IsUnknownOrError(res) && f.mutableValue {
if _, ok := res.(traits.MutableLister); ok {
res = res.(traits.MutableLister).ToImmutableList()
}
if _, ok := res.(traits.MutableMapper); ok {
res = res.(traits.MutableMapper).ToImmutableMap()
}
}
return res
}
// reset clears any state associated with folder evaluation.
func (f *folder) reset() {
f.evalFold = nil
f.Activation = nil
f.accuVal = nil
f.iterVar1Val = nil
f.iterVar2Val = nil
f.initialized = false
f.mutableValue = false
f.interrupted = false
f.computeResult = false
}
func checkInterrupt(a Activation) bool {
stop, found := a.ResolveName("#interrupted")
return found && stop == true
}
var (
// pool of var folders to reduce allocations during folds.
folderPool = &sync.Pool{
New: func() any {
return &folder{}
},
}
)

View File

@@ -603,6 +603,7 @@ func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) {
accuVar: fold.AccuVar(),
accu: accu,
iterVar: fold.IterVar(),
iterVar2: fold.IterVar2(),
iterRange: iterRange,
cond: cond,
step: step,

View File

@@ -1,7 +1,7 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
package(
default_visibility = ["//parser:__subpackages__"],
default_visibility = ["//:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)

View File

@@ -115,7 +115,7 @@ func (p *parserHelper) newObjectField(fieldID int64, field string, value ast.Exp
func (p *parserHelper) newComprehension(ctx any,
iterRange ast.Expr,
iterVar string,
iterVar,
accuVar string,
accuInit ast.Expr,
condition ast.Expr,
@@ -125,6 +125,18 @@ func (p *parserHelper) newComprehension(ctx any,
p.newID(ctx), iterRange, iterVar, accuVar, accuInit, condition, step, result)
}
func (p *parserHelper) newComprehensionTwoVar(ctx any,
iterRange ast.Expr,
iterVar, iterVar2,
accuVar string,
accuInit ast.Expr,
condition ast.Expr,
step ast.Expr,
result ast.Expr) ast.Expr {
return p.exprFactory.NewComprehensionTwoVar(
p.newID(ctx), iterRange, iterVar, iterVar2, accuVar, accuInit, condition, step, result)
}
func (p *parserHelper) newID(ctx any) int64 {
if id, isID := ctx.(int64); isID {
return id
@@ -383,8 +395,10 @@ func (e *exprHelper) Copy(expr ast.Expr) ast.Expr {
cond := e.Copy(compre.LoopCondition())
step := e.Copy(compre.LoopStep())
result := e.Copy(compre.Result())
return e.exprFactory.NewComprehension(copyID,
iterRange, compre.IterVar(), compre.AccuVar(), accuInit, cond, step, result)
// All comprehensions can be represented by the two-variable comprehension since the
// differentiation between one and two-variable is whether the iterVar2 value is non-empty.
return e.exprFactory.NewComprehensionTwoVar(copyID,
iterRange, compre.IterVar(), compre.IterVar2(), compre.AccuVar(), accuInit, cond, step, result)
}
return e.exprFactory.NewUnspecifiedExpr(copyID)
}
@@ -432,6 +446,20 @@ func (e *exprHelper) NewComprehension(
e.nextMacroID(), iterRange, iterVar, accuVar, accuInit, condition, step, result)
}
// NewComprehensionTwoVar implements the ExprHelper interface method.
func (e *exprHelper) NewComprehensionTwoVar(
iterRange ast.Expr,
iterVar,
iterVar2,
accuVar string,
accuInit,
condition,
step,
result ast.Expr) ast.Expr {
return e.exprFactory.NewComprehensionTwoVar(
e.nextMacroID(), iterRange, iterVar, iterVar2, accuVar, accuInit, condition, step, result)
}
// NewIdent implements the ExprHelper interface method.
func (e *exprHelper) NewIdent(name string) ast.Expr {
return e.exprFactory.NewIdent(e.nextMacroID(), name)

View File

@@ -170,11 +170,12 @@ type ExprHelper interface {
// NewStructField creates a new struct field initializer from the field name and value.
NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr
// NewComprehension creates a new comprehension instruction.
// NewComprehension creates a new one-variable comprehension instruction.
//
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - iterVar is the iteration variable name.
// - iterVar is the variable name for the list element value, or the map key, depending on the
// range type.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
@@ -186,11 +187,36 @@ type ExprHelper interface {
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
NewComprehension(iterRange ast.Expr,
iterVar string,
iterVar,
accuVar string,
accuInit ast.Expr,
condition ast.Expr,
step ast.Expr,
accuInit,
condition,
step,
result ast.Expr) ast.Expr
// NewComprehensionTwoVar creates a new two-variable comprehension instruction.
//
// - iterRange represents the expression that resolves to a list or map where the elements or
// keys (respectively) will be iterated over.
// - iterVar is the iteration variable assigned to the list index or the map key.
// - iterVar2 is the iteration variable assigned to the list element value or the map key value.
// - accuVar is the accumulation variable name, typically parser.AccumulatorName.
// - accuInit is the initial expression whose value will be set for the accuVar prior to
// folding.
// - condition is the expression to test to determine whether to continue folding.
// - step is the expression to evaluation at the conclusion of a single fold iteration.
// - result is the computation to evaluate at the conclusion of the fold.
//
// The accuVar should not shadow variable names that you would like to reference within the
// environment in the step and condition expressions. Presently, the name __result__ is commonly
// used by built-in macros but this may change in the future.
NewComprehensionTwoVar(iterRange ast.Expr,
iterVar,
iterVar2,
accuVar string,
accuInit,
condition,
step,
result ast.Expr) ast.Expr
// NewIdent creates an identifier Expr value.