diff options
Diffstat (limited to 'vendor/github.com/authzed/cel-go/cel/folding.go')
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/folding.go | 559 |
1 files changed, 559 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/cel-go/cel/folding.go b/vendor/github.com/authzed/cel-go/cel/folding.go new file mode 100644 index 0000000..9a2d904 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/folding.go @@ -0,0 +1,559 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" +) + +// ConstantFoldingOption defines a functional option for configuring constant folding. +type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) + +// MaxConstantFoldIterations limits the number of times literals may be folding during optimization. +// +// Defaults to 100 if not set. +func MaxConstantFoldIterations(limit int) ConstantFoldingOption { + return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) { + opt.maxFoldIterations = limit + return opt, nil + } +} + +// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate +// literal values within function calls and select statements with their evaluated result. +func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) { + folder := &constantFoldingOptimizer{ + maxFoldIterations: defaultMaxConstantFoldIterations, + } + var err error + for _, o := range opts { + folder, err = o(folder) + if err != nil { + return nil, err + } + } + return folder, nil +} + +type constantFoldingOptimizer struct { + maxFoldIterations int +} + +// Optimize queries the expression graph for scalar and aggregate literal expressions within call and +// select statements and then evaluates them and replaces the call site with the literal result. +// +// Note: only values which can be represented as literals in CEL syntax are supported. +func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + + // Walk the list of foldable expression and continue to fold until there are no more folds left. + // All of the fold candidates returned by the constantExprMatcher should succeed unless there's + // a logic bug with the selection of expressions. + foldableExprs := ast.MatchDescendants(root, constantExprMatcher) + foldCount := 0 + for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations { + for _, fold := range foldableExprs { + // If the expression could be folded because it's a non-strict call, and the + // branches are pruned, continue to the next fold. + if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) { + continue + } + // Otherwise, assume all context is needed to evaluate the expression. + err := tryFold(ctx, a, fold) + if err != nil { + ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error()) + return a + } + } + foldCount++ + foldableExprs = ast.MatchDescendants(root, constantExprMatcher) + } + // Once all of the constants have been folded, try to run through the remaining comprehensions + // one last time. In this case, there's no guarantee they'll run, so we only update the + // target comprehension node with the literal value if the evaluation succeeds. + for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) { + tryFold(ctx, a, compre) + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + pruneOptionalElements(ctx, root) + + // Ensure that all intermediate values in the folded expression can be represented as valid + // CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents` + // to avoid extra allocations during this final pass through the AST. + ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() != ast.LiteralKind { + return + } + val := e.AsLiteral() + adapted, err := adaptLiteral(ctx, val) + if err != nil { + ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error()) + return + } + ctx.UpdateExpr(e, adapted) + })) + + return a +} + +// tryFold attempts to evaluate a sub-expression to a literal. +// +// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise +// the method will return an error. +func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { + // Assume all context is needed to evaluate the expression. + subAST := &Ast{ + impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()), + } + prg, err := ctx.Program(subAST) + if err != nil { + return err + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + return err + } + // Update the fold expression to be a literal. + ctx.UpdateExpr(expr, ctx.NewLiteral(out)) + return nil +} + +// maybePruneBranches inspects the non-strict call expression to determine whether +// a branch can be removed. Evaluation will naturally prune logical and / or calls, +// but conditional will not be pruned cleanly, so this is one small area where the +// constant folding step reimplements a portion of the evaluator. +func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool { + call := expr.AsCall() + args := call.Args() + switch call.FunctionName() { + case operators.LogicalAnd, operators.LogicalOr: + return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr) + case operators.Conditional: + cond := args[0] + truthy := args[1] + falsy := args[2] + if cond.Kind() != ast.LiteralKind { + return false + } + if cond.AsLiteral() == types.True { + ctx.UpdateExpr(expr, truthy) + } else { + ctx.UpdateExpr(expr, falsy) + } + return true + case operators.In: + haystack := args[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + ctx.UpdateExpr(expr, ctx.NewLiteral(types.False)) + return true + } + needle := args[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + ctx.UpdateExpr(expr, ctx.NewLiteral(types.True)) + return true + } + } + } + } + return false +} + +func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool { + shortcircuit := types.False + skip := types.True + if function == operators.LogicalOr { + shortcircuit = types.True + skip = types.False + } + newArgs := []ast.Expr{} + for _, arg := range args { + if arg.Kind() != ast.LiteralKind { + newArgs = append(newArgs, arg) + continue + } + if arg.AsLiteral() == skip { + continue + } + if arg.AsLiteral() == shortcircuit { + ctx.UpdateExpr(expr, arg) + return true + } + } + if len(newArgs) == 0 { + newArgs = append(newArgs, args[0]) + ctx.UpdateExpr(expr, newArgs[0]) + return true + } + if len(newArgs) == 1 { + ctx.UpdateExpr(expr, newArgs[0]) + return true + } + ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...)) + return true +} + +// pruneOptionalElements works from the bottom up to resolve optional elements within +// aggregate literals. +// +// Note, many aggregate literals will be resolved as arguments to functions or select +// statements, so this method exists to handle the case where the literal could not be +// fully resolved or exists outside of a call, select, or comprehension context. +func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) { + aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher) + for _, lit := range aggregateLiterals { + switch lit.Kind() { + case ast.ListKind: + pruneOptionalListElements(ctx, lit) + case ast.MapKind: + pruneOptionalMapEntries(ctx, lit) + case ast.StructKind: + pruneOptionalStructFields(ctx, lit) + } + } +} + +func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) { + l := e.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + if len(optIndices) == 0 { + return + } + updatedElems := []ast.Expr{} + updatedIndices := []int32{} + newOptIndex := -1 + for _, e := range elems { + newOptIndex++ + if !l.IsOptional(int32(newOptIndex)) { + updatedElems = append(updatedElems, e) + continue + } + if e.Kind() != ast.LiteralKind { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(newOptIndex)) + continue + } + optElemVal, ok := e.AsLiteral().(*types.Optional) + if !ok { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(newOptIndex)) + continue + } + if !optElemVal.HasValue() { + newOptIndex-- // Skipping causes the list to get smaller. + continue + } + ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue())) + updatedElems = append(updatedElems, e) + } + ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices)) +} + +func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { + m := e.AsMap() + entries := m.Entries() + updatedEntries := []ast.EntryExpr{} + modified := false + for _, e := range entries { + entry := e.AsMapEntry() + key := entry.Key() + val := entry.Value() + // If the entry is not optional, or the value-side of the optional hasn't + // been resolved to a literal, then preserve the entry as-is. + if !entry.IsOptional() || val.Kind() != ast.LiteralKind { + updatedEntries = append(updatedEntries, e) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedEntries = append(updatedEntries, e) + continue + } + // When the key is not a literal, but the value is, then it needs to be + // restored to an optional value. + if key.Kind() != ast.LiteralKind { + undoOptVal, err := adaptLiteral(ctx, optElemVal) + if err != nil { + ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err) + } + ctx.UpdateExpr(val, undoOptVal) + updatedEntries = append(updatedEntries, e) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue())) + updatedEntry := ctx.NewMapEntry(key, val, false) + updatedEntries = append(updatedEntries, updatedEntry) + } + if modified { + ctx.UpdateExpr(e, ctx.NewMap(updatedEntries)) + } +} + +func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) { + s := e.AsStruct() + fields := s.Fields() + updatedFields := []ast.EntryExpr{} + modified := false + for _, f := range fields { + field := f.AsStructField() + val := field.Value() + if !field.IsOptional() || val.Kind() != ast.LiteralKind { + updatedFields = append(updatedFields, f) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedFields = append(updatedFields, f) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue())) + updatedField := ctx.NewStructField(field.Name(), val, false) + updatedFields = append(updatedFields, updatedField) + } + if modified { + ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields)) + } +} + +// adaptLiteral converts a runtime CEL value to its equivalent literal expression. +// +// For strongly typed values, the type-provider will be used to reconstruct the fields +// which are present in the literal and their equivalent initialization values. +func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) { + switch t := val.Type().(type) { + case *types.Type: + switch t { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + return ctx.NewLiteral(val), nil + case types.DurationType: + return ctx.NewCall( + overloads.TypeConvertDuration, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.TimestampType: + return ctx.NewCall( + overloads.TypeConvertTimestamp, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.OptionalType: + opt := val.(*types.Optional) + if !opt.HasValue() { + return ctx.NewCall("optional.none"), nil + } + target, err := adaptLiteral(ctx, opt.GetValue()) + if err != nil { + return nil, err + } + return ctx.NewCall("optional.of", target), nil + case types.TypeType: + return ctx.NewIdent(val.(*types.Type).TypeName()), nil + case types.ListType: + l, ok := val.(traits.Lister) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + elems := make([]ast.Expr, l.Size().(types.Int)) + idx := 0 + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + elemExpr, err := adaptLiteral(ctx, elemVal) + if err != nil { + return nil, err + } + elems[idx] = elemExpr + idx++ + } + return ctx.NewList(elems, []int32{}), nil + case types.MapType: + m, ok := val.(traits.Mapper) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + entries := make([]ast.EntryExpr, m.Size().(types.Int)) + idx := 0 + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + keyExpr, err := adaptLiteral(ctx, keyVal) + if err != nil { + return nil, err + } + valVal := m.Get(keyVal) + valExpr, err := adaptLiteral(ctx, valVal) + if err != nil { + return nil, err + } + entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false) + idx++ + } + return ctx.NewMap(entries), nil + default: + provider := ctx.CELTypeProvider() + fields, found := provider.FindStructFieldNames(t.TypeName()) + if !found { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + tester := val.(traits.FieldTester) + indexer := val.(traits.Indexer) + fieldInits := []ast.EntryExpr{} + for _, f := range fields { + field := types.String(f) + if tester.IsSet(field) != types.True { + continue + } + fieldVal := indexer.Get(field) + fieldExpr, err := adaptLiteral(ctx, fieldVal) + if err != nil { + return nil, err + } + fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false)) + } + return ctx.NewStruct(t.TypeName(), fieldInits), nil + } + } + return nil, fmt.Errorf("failed to adapt %v to literal", val) +} + +// constantExprMatcher matches calls, select statements, and comprehensions whose arguments +// are all constant scalar or aggregate literal values. +// +// Only comprehensions which are not nested are included as possible constant folds, and only +// if all variables referenced in the comprehension stack exist are only iteration or +// accumulation variables. +func constantExprMatcher(e ast.NavigableExpr) bool { + switch e.Kind() { + case ast.CallKind: + return constantCallMatcher(e) + case ast.SelectKind: + sel := e.AsSelect() // guaranteed to be a navigable value + return constantMatcher(sel.Operand().(ast.NavigableExpr)) + case ast.ComprehensionKind: + if isNestedComprehension(e) { + return false + } + vars := map[string]bool{} + constantExprs := true + visitor := ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() == ast.ComprehensionKind { + nested := e.AsComprehension() + vars[nested.AccuVar()] = true + vars[nested.IterVar()] = true + } + if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { + constantExprs = false + } + }) + ast.PreOrderVisit(e, visitor) + return constantExprs + default: + return false + } +} + +// constantCallMatcher identifies strict and non-strict calls which can be folded. +func constantCallMatcher(e ast.NavigableExpr) bool { + call := e.AsCall() + children := e.Children() + fnName := call.FunctionName() + if fnName == operators.LogicalAnd { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.LogicalOr { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.Conditional { + cond := children[0] + if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType { + return true + } + } + if fnName == operators.In { + haystack := children[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + return true + } + needle := children[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + return true + } + } + } + } + // convert all other calls with constant arguments + for _, child := range children { + if !constantMatcher(child) { + return false + } + } + return true +} + +func isNestedComprehension(e ast.NavigableExpr) bool { + parent, found := e.Parent() + for found { + if parent.Kind() == ast.ComprehensionKind { + return true + } + parent, found = parent.Parent() + } + return false +} + +func aggregateLiteralMatcher(e ast.NavigableExpr) bool { + return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind +} + +var ( + constantMatcher = ast.ConstantValueMatcher() +) + +const ( + defaultMaxConstantFoldIterations = 100 +) |
