diff options
Diffstat (limited to 'vendor/github.com/authzed/cel-go/interpreter/prune.go')
| -rw-r--r-- | vendor/github.com/authzed/cel-go/interpreter/prune.go | 543 |
1 files changed, 543 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/cel-go/interpreter/prune.go b/vendor/github.com/authzed/cel-go/interpreter/prune.go new file mode 100644 index 0000000..b093f2a --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/prune.go @@ -0,0 +1,543 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" +) + +type astPruner struct { + ast.ExprFactory + expr ast.Expr + macroCalls map[int64]ast.Expr + state EvalState + nextExprID int64 +} + +// TODO Consider having a separate walk of the AST that finds common +// subexpressions. This can be called before or after constant folding to find +// common subexpressions. + +// PruneAst prunes the given AST based on the given EvalState and generates a new AST. +// Given AST is copied on write and a new AST is returned. +// Couple of typical use cases this interface would be: +// +// A) +// 1) Evaluate expr with some unknowns, +// 2) If result is unknown: +// +// a) PruneAst +// b) Goto 1 +// +// Functional call results which are known would be effectively cached across +// iterations. +// +// B) +// 1) Compile the expression (maybe via a service and maybe after checking a +// +// compiled expression does not exists in local cache) +// +// 2) Prepare the environment and the interpreter. Activation might be empty. +// 3) Eval the expression. This might return unknown or error or a concrete +// +// value. +// +// 4) PruneAst +// 4) Maybe cache the expression +// This is effectively constant folding the expression. How the environment is +// prepared in step 2 is flexible. For example, If the caller caches the +// compiled and constant folded expressions, but is not willing to constant +// fold(and thus cache results of) some external calls, then they can prepare +// the overloads accordingly. +func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST { + pruneState := NewEvalState() + for _, id := range state.IDs() { + v, _ := state.Value(id) + pruneState.SetValue(id, v) + } + pruner := &astPruner{ + ExprFactory: ast.NewExprFactory(), + expr: expr, + macroCalls: macroCalls, + state: pruneState, + nextExprID: getMaxID(expr)} + newExpr, _ := pruner.maybePrune(expr) + newInfo := ast.NewSourceInfo(nil) + for id, call := range pruner.macroCalls { + newInfo.SetMacroCall(id, call) + } + return ast.NewAST(newExpr, newInfo) +} + +func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) { + switch v := val.(type) { + case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint: + p.state.SetValue(id, val) + return p.NewLiteral(id, val), true + case types.Duration: + p.state.SetValue(id, val) + durationString := v.ConvertToType(types.StringType).(types.String) + return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true + case types.Timestamp: + timestampString := v.ConvertToType(types.StringType).(types.String) + return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true + } + + // Attempt to build a list literal. + if list, isList := val.(traits.Lister); isList { + sz := list.Size().(types.Int) + elemExprs := make([]ast.Expr, sz) + for i := types.Int(0); i < sz; i++ { + elem := list.Get(i) + if types.IsUnknownOrError(elem) { + return nil, false + } + elemExpr, ok := p.maybeCreateLiteral(p.nextID(), elem) + if !ok { + return nil, false + } + elemExprs[i] = elemExpr + } + p.state.SetValue(id, val) + return p.NewList(id, elemExprs, []int32{}), true + } + + // Create a map literal if possible. + if mp, isMap := val.(traits.Mapper); isMap { + it := mp.Iterator() + entries := make([]ast.EntryExpr, mp.Size().(types.Int)) + i := 0 + for it.HasNext() != types.False { + key := it.Next() + val := mp.Get(key) + if types.IsUnknownOrError(key) || types.IsUnknownOrError(val) { + return nil, false + } + keyExpr, ok := p.maybeCreateLiteral(p.nextID(), key) + if !ok { + return nil, false + } + valExpr, ok := p.maybeCreateLiteral(p.nextID(), val) + if !ok { + return nil, false + } + entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false) + entries[i] = entry + i++ + } + p.state.SetValue(id, val) + return p.NewMap(id, entries), true + } + + // TODO(issues/377) To construct message literals, the type provider will need to support + // the enumeration the fields for a given message. + return nil, false +} + +func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) { + elemVal, found := p.value(elem.ID()) + if found && elemVal.Type() == types.OptionalType { + opt := elemVal.(*types.Optional) + if !opt.HasValue() { + return nil, true + } + if newElem, pruned := p.maybeCreateLiteral(elem.ID(), opt.GetValue()); pruned { + return newElem, true + } + } + return elem, false +} + +func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) { + // elem in list + call := node.AsCall() + val, exists := p.maybeValue(call.Args()[1].ID()) + if !exists { + return nil, false + } + if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { + return p.maybeCreateLiteral(node.ID(), types.False) + } + return nil, false +} + +func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + arg := call.Args()[0] + val, exists := p.maybeValue(arg.ID()) + if !exists { + return nil, false + } + if b, ok := val.(types.Bool); ok { + return p.maybeCreateLiteral(node.ID(), !b) + } + return nil, false +} + +func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + // We know result is unknown, so we have at least one unknown arg + // and if one side is a known value, we know we can ignore it. + if v, exists := p.maybeValue(call.Args()[0].ID()); exists { + if v == types.True { + return p.maybeCreateLiteral(node.ID(), types.True) + } + return call.Args()[1], true + } + if v, exists := p.maybeValue(call.Args()[1].ID()); exists { + if v == types.True { + return p.maybeCreateLiteral(node.ID(), types.True) + } + return call.Args()[0], true + } + return nil, false +} + +func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + // We know result is unknown, so we have at least one unknown arg + // and if one side is a known value, we know we can ignore it. + if v, exists := p.maybeValue(call.Args()[0].ID()); exists { + if v == types.False { + return p.maybeCreateLiteral(node.ID(), types.False) + } + return call.Args()[1], true + } + if v, exists := p.maybeValue(call.Args()[1].ID()); exists { + if v == types.False { + return p.maybeCreateLiteral(node.ID(), types.False) + } + return call.Args()[0], true + } + return nil, false +} + +func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + cond, exists := p.maybeValue(call.Args()[0].ID()) + if !exists { + return nil, false + } + if cond.Value().(bool) { + return call.Args()[1], true + } + return call.Args()[2], true +} + +func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) { + if _, exists := p.value(node.ID()); !exists { + return nil, false + } + call := node.AsCall() + if call.FunctionName() == operators.LogicalOr { + return p.maybePruneOr(node) + } + if call.FunctionName() == operators.LogicalAnd { + return p.maybePruneAnd(node) + } + if call.FunctionName() == operators.Conditional { + return p.maybePruneConditional(node) + } + if call.FunctionName() == operators.In { + return p.maybePruneIn(node) + } + if call.FunctionName() == operators.LogicalNot { + return p.maybePruneLogicalNot(node) + } + return nil, false +} + +func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) { + return p.prune(node) +} + +func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) { + if node == nil { + return node, false + } + val, valueExists := p.maybeValue(node.ID()) + if valueExists { + if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok { + delete(p.macroCalls, node.ID()) + return newNode, true + } + } + if macro, found := p.macroCalls[node.ID()]; found { + // Ensure that intermediate values for the comprehension are cleared during pruning + if node.Kind() == ast.ComprehensionKind { + compre := node.AsComprehension() + visit(macro, clearIterVarVisitor(compre.IterVar(), p.state)) + } + // prune the expression in terms of the macro call instead of the expanded form. + if newMacro, pruned := p.prune(macro); pruned { + p.macroCalls[node.ID()] = newMacro + } + } + + // We have either an unknown/error value, or something we don't want to + // transform, or expression was not evaluated. If possible, drill down + // more. + switch node.Kind() { + case ast.SelectKind: + sel := node.AsSelect() + if operand, isPruned := p.maybePrune(sel.Operand()); isPruned { + if sel.IsTestOnly() { + return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true + } + return p.NewSelect(node.ID(), operand, sel.FieldName()), true + } + case ast.CallKind: + argsPruned := false + call := node.AsCall() + args := call.Args() + newArgs := make([]ast.Expr, len(args)) + for i, a := range args { + newArgs[i] = a + if arg, isPruned := p.maybePrune(a); isPruned { + argsPruned = true + newArgs[i] = arg + } + } + if !call.IsMemberFunction() { + newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...) + if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { + return prunedCall, true + } + return newCall, argsPruned + } + newTarget := call.Target() + targetPruned := false + if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned { + targetPruned = true + newTarget = prunedTarget + } + newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...) + if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { + return prunedCall, true + } + return newCall, targetPruned || argsPruned + case ast.ListKind: + l := node.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + optIndexMap := map[int32]bool{} + for _, i := range optIndices { + optIndexMap[i] = true + } + newOptIndexMap := make(map[int32]bool, len(optIndexMap)) + newElems := make([]ast.Expr, 0, len(elems)) + var listPruned bool + prunedIdx := 0 + for i, elem := range elems { + _, isOpt := optIndexMap[int32(i)] + if isOpt { + newElem, pruned := p.maybePruneOptional(elem) + if pruned { + listPruned = true + if newElem != nil { + newElems = append(newElems, newElem) + prunedIdx++ + } + continue + } + newOptIndexMap[int32(prunedIdx)] = true + } + if newElem, prunedElem := p.maybePrune(elem); prunedElem { + newElems = append(newElems, newElem) + listPruned = true + } else { + newElems = append(newElems, elem) + } + prunedIdx++ + } + optIndices = make([]int32, len(newOptIndexMap)) + idx := 0 + for i := range newOptIndexMap { + optIndices[idx] = i + idx++ + } + if listPruned { + return p.NewList(node.ID(), newElems, optIndices), true + } + case ast.MapKind: + var mapPruned bool + m := node.AsMap() + entries := m.Entries() + newEntries := make([]ast.EntryExpr, len(entries)) + for i, entry := range entries { + newEntries[i] = entry + e := entry.AsMapEntry() + newKey, keyPruned := p.maybePrune(e.Key()) + newValue, valuePruned := p.maybePrune(e.Value()) + if !keyPruned && !valuePruned { + continue + } + mapPruned = true + newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional()) + newEntries[i] = newEntry + } + if mapPruned { + return p.NewMap(node.ID(), newEntries), true + } + case ast.StructKind: + var structPruned bool + obj := node.AsStruct() + fields := obj.Fields() + newFields := make([]ast.EntryExpr, len(fields)) + for i, field := range fields { + newFields[i] = field + f := field.AsStructField() + newValue, prunedValue := p.maybePrune(f.Value()) + if !prunedValue { + continue + } + structPruned = true + newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional()) + newFields[i] = newEntry + } + if structPruned { + return p.NewStruct(node.ID(), obj.TypeName(), newFields), true + } + case ast.ComprehensionKind: + compre := node.AsComprehension() + // Only the range of the comprehension is pruned since the state tracking only records + // the last iteration of the comprehension and not each step in the evaluation which + // means that the any residuals computed in between might be inaccurate. + if newRange, pruned := p.maybePrune(compre.IterRange()); pruned { + return p.NewComprehension( + node.ID(), + newRange, + compre.IterVar(), + compre.AccuVar(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result(), + ), true + } + } + return node, false +} + +func (p *astPruner) value(id int64) (ref.Val, bool) { + val, found := p.state.Value(id) + return val, (found && val != nil) +} + +func (p *astPruner) maybeValue(id int64) (ref.Val, bool) { + val, found := p.value(id) + if !found || types.IsUnknownOrError(val) { + return nil, false + } + return val, true +} + +func (p *astPruner) nextID() int64 { + next := p.nextExprID + p.nextExprID++ + return next +} + +type astVisitor struct { + // visitEntry is called on every expr node, including those within a map/struct entry. + visitExpr func(expr ast.Expr) + // visitEntry is called before entering the key, value of a map/struct entry. + visitEntry func(entry ast.EntryExpr) +} + +func getMaxID(expr ast.Expr) int64 { + maxID := int64(1) + visit(expr, maxIDVisitor(&maxID)) + return maxID +} + +func clearIterVarVisitor(varName string, state EvalState) astVisitor { + return astVisitor{ + visitExpr: func(e ast.Expr) { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + state.SetValue(e.ID(), nil) + } + }, + } +} + +func maxIDVisitor(maxID *int64) astVisitor { + return astVisitor{ + visitExpr: func(e ast.Expr) { + if e.ID() >= *maxID { + *maxID = e.ID() + 1 + } + }, + visitEntry: func(e ast.EntryExpr) { + if e.ID() >= *maxID { + *maxID = e.ID() + 1 + } + }, + } +} + +func visit(expr ast.Expr, visitor astVisitor) { + exprs := []ast.Expr{expr} + for len(exprs) != 0 { + e := exprs[0] + if visitor.visitExpr != nil { + visitor.visitExpr(e) + } + exprs = exprs[1:] + switch e.Kind() { + case ast.SelectKind: + exprs = append(exprs, e.AsSelect().Operand()) + case ast.CallKind: + call := e.AsCall() + if call.Target() != nil { + exprs = append(exprs, call.Target()) + } + exprs = append(exprs, call.Args()...) + case ast.ComprehensionKind: + compre := e.AsComprehension() + exprs = append(exprs, + compre.IterRange(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result()) + case ast.ListKind: + list := e.AsList() + exprs = append(exprs, list.Elements()...) + case ast.MapKind: + for _, entry := range e.AsMap().Entries() { + e := entry.AsMapEntry() + if visitor.visitEntry != nil { + visitor.visitEntry(entry) + } + exprs = append(exprs, e.Key()) + exprs = append(exprs, e.Value()) + } + case ast.StructKind: + for _, entry := range e.AsStruct().Fields() { + f := entry.AsStructField() + if visitor.visitEntry != nil { + visitor.visitEntry(entry) + } + exprs = append(exprs, f.Value()) + } + } + } +} |
