summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/cel-go/cel/folding.go
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/cel-go/cel/folding.go
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/cel-go/cel/folding.go')
-rw-r--r--vendor/github.com/authzed/cel-go/cel/folding.go559
1 files changed, 559 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/cel-go/cel/folding.go b/vendor/github.com/authzed/cel-go/cel/folding.go
new file mode 100644
index 0000000..9a2d904
--- /dev/null
+++ b/vendor/github.com/authzed/cel-go/cel/folding.go
@@ -0,0 +1,559 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cel
+
+import (
+ "fmt"
+
+ "github.com/authzed/cel-go/common/ast"
+ "github.com/authzed/cel-go/common/operators"
+ "github.com/authzed/cel-go/common/overloads"
+ "github.com/authzed/cel-go/common/types"
+ "github.com/authzed/cel-go/common/types/ref"
+ "github.com/authzed/cel-go/common/types/traits"
+)
+
+// ConstantFoldingOption defines a functional option for configuring constant folding.
+type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error)
+
+// MaxConstantFoldIterations limits the number of times literals may be folding during optimization.
+//
+// Defaults to 100 if not set.
+func MaxConstantFoldIterations(limit int) ConstantFoldingOption {
+ return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) {
+ opt.maxFoldIterations = limit
+ return opt, nil
+ }
+}
+
+// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate
+// literal values within function calls and select statements with their evaluated result.
+func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) {
+ folder := &constantFoldingOptimizer{
+ maxFoldIterations: defaultMaxConstantFoldIterations,
+ }
+ var err error
+ for _, o := range opts {
+ folder, err = o(folder)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return folder, nil
+}
+
+type constantFoldingOptimizer struct {
+ maxFoldIterations int
+}
+
+// Optimize queries the expression graph for scalar and aggregate literal expressions within call and
+// select statements and then evaluates them and replaces the call site with the literal result.
+//
+// Note: only values which can be represented as literals in CEL syntax are supported.
+func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST {
+ root := ast.NavigateAST(a)
+
+ // Walk the list of foldable expression and continue to fold until there are no more folds left.
+ // All of the fold candidates returned by the constantExprMatcher should succeed unless there's
+ // a logic bug with the selection of expressions.
+ foldableExprs := ast.MatchDescendants(root, constantExprMatcher)
+ foldCount := 0
+ for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations {
+ for _, fold := range foldableExprs {
+ // If the expression could be folded because it's a non-strict call, and the
+ // branches are pruned, continue to the next fold.
+ if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) {
+ continue
+ }
+ // Otherwise, assume all context is needed to evaluate the expression.
+ err := tryFold(ctx, a, fold)
+ if err != nil {
+ ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error())
+ return a
+ }
+ }
+ foldCount++
+ foldableExprs = ast.MatchDescendants(root, constantExprMatcher)
+ }
+ // Once all of the constants have been folded, try to run through the remaining comprehensions
+ // one last time. In this case, there's no guarantee they'll run, so we only update the
+ // target comprehension node with the literal value if the evaluation succeeds.
+ for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) {
+ tryFold(ctx, a, compre)
+ }
+
+ // If the output is a list, map, or struct which contains optional entries, then prune it
+ // to make sure that the optionals, if resolved, do not surface in the output literal.
+ pruneOptionalElements(ctx, root)
+
+ // Ensure that all intermediate values in the folded expression can be represented as valid
+ // CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents`
+ // to avoid extra allocations during this final pass through the AST.
+ ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) {
+ if e.Kind() != ast.LiteralKind {
+ return
+ }
+ val := e.AsLiteral()
+ adapted, err := adaptLiteral(ctx, val)
+ if err != nil {
+ ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error())
+ return
+ }
+ ctx.UpdateExpr(e, adapted)
+ }))
+
+ return a
+}
+
+// tryFold attempts to evaluate a sub-expression to a literal.
+//
+// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise
+// the method will return an error.
+func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error {
+ // Assume all context is needed to evaluate the expression.
+ subAST := &Ast{
+ impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()),
+ }
+ prg, err := ctx.Program(subAST)
+ if err != nil {
+ return err
+ }
+ out, _, err := prg.Eval(NoVars())
+ if err != nil {
+ return err
+ }
+ // Update the fold expression to be a literal.
+ ctx.UpdateExpr(expr, ctx.NewLiteral(out))
+ return nil
+}
+
+// maybePruneBranches inspects the non-strict call expression to determine whether
+// a branch can be removed. Evaluation will naturally prune logical and / or calls,
+// but conditional will not be pruned cleanly, so this is one small area where the
+// constant folding step reimplements a portion of the evaluator.
+func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool {
+ call := expr.AsCall()
+ args := call.Args()
+ switch call.FunctionName() {
+ case operators.LogicalAnd, operators.LogicalOr:
+ return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr)
+ case operators.Conditional:
+ cond := args[0]
+ truthy := args[1]
+ falsy := args[2]
+ if cond.Kind() != ast.LiteralKind {
+ return false
+ }
+ if cond.AsLiteral() == types.True {
+ ctx.UpdateExpr(expr, truthy)
+ } else {
+ ctx.UpdateExpr(expr, falsy)
+ }
+ return true
+ case operators.In:
+ haystack := args[1]
+ if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
+ ctx.UpdateExpr(expr, ctx.NewLiteral(types.False))
+ return true
+ }
+ needle := args[0]
+ if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
+ needleValue := needle.AsLiteral()
+ list := haystack.AsList()
+ for _, e := range list.Elements() {
+ if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
+ ctx.UpdateExpr(expr, ctx.NewLiteral(types.True))
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
+func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool {
+ shortcircuit := types.False
+ skip := types.True
+ if function == operators.LogicalOr {
+ shortcircuit = types.True
+ skip = types.False
+ }
+ newArgs := []ast.Expr{}
+ for _, arg := range args {
+ if arg.Kind() != ast.LiteralKind {
+ newArgs = append(newArgs, arg)
+ continue
+ }
+ if arg.AsLiteral() == skip {
+ continue
+ }
+ if arg.AsLiteral() == shortcircuit {
+ ctx.UpdateExpr(expr, arg)
+ return true
+ }
+ }
+ if len(newArgs) == 0 {
+ newArgs = append(newArgs, args[0])
+ ctx.UpdateExpr(expr, newArgs[0])
+ return true
+ }
+ if len(newArgs) == 1 {
+ ctx.UpdateExpr(expr, newArgs[0])
+ return true
+ }
+ ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...))
+ return true
+}
+
+// pruneOptionalElements works from the bottom up to resolve optional elements within
+// aggregate literals.
+//
+// Note, many aggregate literals will be resolved as arguments to functions or select
+// statements, so this method exists to handle the case where the literal could not be
+// fully resolved or exists outside of a call, select, or comprehension context.
+func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) {
+ aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher)
+ for _, lit := range aggregateLiterals {
+ switch lit.Kind() {
+ case ast.ListKind:
+ pruneOptionalListElements(ctx, lit)
+ case ast.MapKind:
+ pruneOptionalMapEntries(ctx, lit)
+ case ast.StructKind:
+ pruneOptionalStructFields(ctx, lit)
+ }
+ }
+}
+
+func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) {
+ l := e.AsList()
+ elems := l.Elements()
+ optIndices := l.OptionalIndices()
+ if len(optIndices) == 0 {
+ return
+ }
+ updatedElems := []ast.Expr{}
+ updatedIndices := []int32{}
+ newOptIndex := -1
+ for _, e := range elems {
+ newOptIndex++
+ if !l.IsOptional(int32(newOptIndex)) {
+ updatedElems = append(updatedElems, e)
+ continue
+ }
+ if e.Kind() != ast.LiteralKind {
+ updatedElems = append(updatedElems, e)
+ updatedIndices = append(updatedIndices, int32(newOptIndex))
+ continue
+ }
+ optElemVal, ok := e.AsLiteral().(*types.Optional)
+ if !ok {
+ updatedElems = append(updatedElems, e)
+ updatedIndices = append(updatedIndices, int32(newOptIndex))
+ continue
+ }
+ if !optElemVal.HasValue() {
+ newOptIndex-- // Skipping causes the list to get smaller.
+ continue
+ }
+ ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue()))
+ updatedElems = append(updatedElems, e)
+ }
+ ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices))
+}
+
+func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) {
+ m := e.AsMap()
+ entries := m.Entries()
+ updatedEntries := []ast.EntryExpr{}
+ modified := false
+ for _, e := range entries {
+ entry := e.AsMapEntry()
+ key := entry.Key()
+ val := entry.Value()
+ // If the entry is not optional, or the value-side of the optional hasn't
+ // been resolved to a literal, then preserve the entry as-is.
+ if !entry.IsOptional() || val.Kind() != ast.LiteralKind {
+ updatedEntries = append(updatedEntries, e)
+ continue
+ }
+ optElemVal, ok := val.AsLiteral().(*types.Optional)
+ if !ok {
+ updatedEntries = append(updatedEntries, e)
+ continue
+ }
+ // When the key is not a literal, but the value is, then it needs to be
+ // restored to an optional value.
+ if key.Kind() != ast.LiteralKind {
+ undoOptVal, err := adaptLiteral(ctx, optElemVal)
+ if err != nil {
+ ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err)
+ }
+ ctx.UpdateExpr(val, undoOptVal)
+ updatedEntries = append(updatedEntries, e)
+ continue
+ }
+ modified = true
+ if !optElemVal.HasValue() {
+ continue
+ }
+ ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
+ updatedEntry := ctx.NewMapEntry(key, val, false)
+ updatedEntries = append(updatedEntries, updatedEntry)
+ }
+ if modified {
+ ctx.UpdateExpr(e, ctx.NewMap(updatedEntries))
+ }
+}
+
+func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) {
+ s := e.AsStruct()
+ fields := s.Fields()
+ updatedFields := []ast.EntryExpr{}
+ modified := false
+ for _, f := range fields {
+ field := f.AsStructField()
+ val := field.Value()
+ if !field.IsOptional() || val.Kind() != ast.LiteralKind {
+ updatedFields = append(updatedFields, f)
+ continue
+ }
+ optElemVal, ok := val.AsLiteral().(*types.Optional)
+ if !ok {
+ updatedFields = append(updatedFields, f)
+ continue
+ }
+ modified = true
+ if !optElemVal.HasValue() {
+ continue
+ }
+ ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue()))
+ updatedField := ctx.NewStructField(field.Name(), val, false)
+ updatedFields = append(updatedFields, updatedField)
+ }
+ if modified {
+ ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields))
+ }
+}
+
+// adaptLiteral converts a runtime CEL value to its equivalent literal expression.
+//
+// For strongly typed values, the type-provider will be used to reconstruct the fields
+// which are present in the literal and their equivalent initialization values.
+func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) {
+ switch t := val.Type().(type) {
+ case *types.Type:
+ switch t {
+ case types.BoolType, types.BytesType, types.DoubleType, types.IntType,
+ types.NullType, types.StringType, types.UintType:
+ return ctx.NewLiteral(val), nil
+ case types.DurationType:
+ return ctx.NewCall(
+ overloads.TypeConvertDuration,
+ ctx.NewLiteral(val.ConvertToType(types.StringType)),
+ ), nil
+ case types.TimestampType:
+ return ctx.NewCall(
+ overloads.TypeConvertTimestamp,
+ ctx.NewLiteral(val.ConvertToType(types.StringType)),
+ ), nil
+ case types.OptionalType:
+ opt := val.(*types.Optional)
+ if !opt.HasValue() {
+ return ctx.NewCall("optional.none"), nil
+ }
+ target, err := adaptLiteral(ctx, opt.GetValue())
+ if err != nil {
+ return nil, err
+ }
+ return ctx.NewCall("optional.of", target), nil
+ case types.TypeType:
+ return ctx.NewIdent(val.(*types.Type).TypeName()), nil
+ case types.ListType:
+ l, ok := val.(traits.Lister)
+ if !ok {
+ return nil, fmt.Errorf("failed to adapt %v to literal", val)
+ }
+ elems := make([]ast.Expr, l.Size().(types.Int))
+ idx := 0
+ it := l.Iterator()
+ for it.HasNext() == types.True {
+ elemVal := it.Next()
+ elemExpr, err := adaptLiteral(ctx, elemVal)
+ if err != nil {
+ return nil, err
+ }
+ elems[idx] = elemExpr
+ idx++
+ }
+ return ctx.NewList(elems, []int32{}), nil
+ case types.MapType:
+ m, ok := val.(traits.Mapper)
+ if !ok {
+ return nil, fmt.Errorf("failed to adapt %v to literal", val)
+ }
+ entries := make([]ast.EntryExpr, m.Size().(types.Int))
+ idx := 0
+ it := m.Iterator()
+ for it.HasNext() == types.True {
+ keyVal := it.Next()
+ keyExpr, err := adaptLiteral(ctx, keyVal)
+ if err != nil {
+ return nil, err
+ }
+ valVal := m.Get(keyVal)
+ valExpr, err := adaptLiteral(ctx, valVal)
+ if err != nil {
+ return nil, err
+ }
+ entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false)
+ idx++
+ }
+ return ctx.NewMap(entries), nil
+ default:
+ provider := ctx.CELTypeProvider()
+ fields, found := provider.FindStructFieldNames(t.TypeName())
+ if !found {
+ return nil, fmt.Errorf("failed to adapt %v to literal", val)
+ }
+ tester := val.(traits.FieldTester)
+ indexer := val.(traits.Indexer)
+ fieldInits := []ast.EntryExpr{}
+ for _, f := range fields {
+ field := types.String(f)
+ if tester.IsSet(field) != types.True {
+ continue
+ }
+ fieldVal := indexer.Get(field)
+ fieldExpr, err := adaptLiteral(ctx, fieldVal)
+ if err != nil {
+ return nil, err
+ }
+ fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false))
+ }
+ return ctx.NewStruct(t.TypeName(), fieldInits), nil
+ }
+ }
+ return nil, fmt.Errorf("failed to adapt %v to literal", val)
+}
+
+// constantExprMatcher matches calls, select statements, and comprehensions whose arguments
+// are all constant scalar or aggregate literal values.
+//
+// Only comprehensions which are not nested are included as possible constant folds, and only
+// if all variables referenced in the comprehension stack exist are only iteration or
+// accumulation variables.
+func constantExprMatcher(e ast.NavigableExpr) bool {
+ switch e.Kind() {
+ case ast.CallKind:
+ return constantCallMatcher(e)
+ case ast.SelectKind:
+ sel := e.AsSelect() // guaranteed to be a navigable value
+ return constantMatcher(sel.Operand().(ast.NavigableExpr))
+ case ast.ComprehensionKind:
+ if isNestedComprehension(e) {
+ return false
+ }
+ vars := map[string]bool{}
+ constantExprs := true
+ visitor := ast.NewExprVisitor(func(e ast.Expr) {
+ if e.Kind() == ast.ComprehensionKind {
+ nested := e.AsComprehension()
+ vars[nested.AccuVar()] = true
+ vars[nested.IterVar()] = true
+ }
+ if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
+ constantExprs = false
+ }
+ })
+ ast.PreOrderVisit(e, visitor)
+ return constantExprs
+ default:
+ return false
+ }
+}
+
+// constantCallMatcher identifies strict and non-strict calls which can be folded.
+func constantCallMatcher(e ast.NavigableExpr) bool {
+ call := e.AsCall()
+ children := e.Children()
+ fnName := call.FunctionName()
+ if fnName == operators.LogicalAnd {
+ for _, child := range children {
+ if child.Kind() == ast.LiteralKind {
+ return true
+ }
+ }
+ }
+ if fnName == operators.LogicalOr {
+ for _, child := range children {
+ if child.Kind() == ast.LiteralKind {
+ return true
+ }
+ }
+ }
+ if fnName == operators.Conditional {
+ cond := children[0]
+ if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType {
+ return true
+ }
+ }
+ if fnName == operators.In {
+ haystack := children[1]
+ if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 {
+ return true
+ }
+ needle := children[0]
+ if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind {
+ needleValue := needle.AsLiteral()
+ list := haystack.AsList()
+ for _, e := range list.Elements() {
+ if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True {
+ return true
+ }
+ }
+ }
+ }
+ // convert all other calls with constant arguments
+ for _, child := range children {
+ if !constantMatcher(child) {
+ return false
+ }
+ }
+ return true
+}
+
+func isNestedComprehension(e ast.NavigableExpr) bool {
+ parent, found := e.Parent()
+ for found {
+ if parent.Kind() == ast.ComprehensionKind {
+ return true
+ }
+ parent, found = parent.Parent()
+ }
+ return false
+}
+
+func aggregateLiteralMatcher(e ast.NavigableExpr) bool {
+ return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind
+}
+
+var (
+ constantMatcher = ast.ConstantValueMatcher()
+)
+
+const (
+ defaultMaxConstantFoldIterations = 100
+)