summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/cel-go/parser/helper.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/authzed/cel-go/parser/helper.go')
-rw-r--r--vendor/github.com/authzed/cel-go/parser/helper.go478
1 files changed, 478 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/cel-go/parser/helper.go b/vendor/github.com/authzed/cel-go/parser/helper.go
new file mode 100644
index 0000000..fe8012c
--- /dev/null
+++ b/vendor/github.com/authzed/cel-go/parser/helper.go
@@ -0,0 +1,478 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package parser
+
+import (
+ "sync"
+
+ antlr "github.com/antlr4-go/antlr/v4"
+
+ "github.com/authzed/cel-go/common"
+ "github.com/authzed/cel-go/common/ast"
+ "github.com/authzed/cel-go/common/types"
+ "github.com/authzed/cel-go/common/types/ref"
+)
+
+type parserHelper struct {
+ exprFactory ast.ExprFactory
+ source common.Source
+ sourceInfo *ast.SourceInfo
+ nextID int64
+}
+
+func newParserHelper(source common.Source, fac ast.ExprFactory) *parserHelper {
+ return &parserHelper{
+ exprFactory: fac,
+ source: source,
+ sourceInfo: ast.NewSourceInfo(source),
+ nextID: 1,
+ }
+}
+
+func (p *parserHelper) getSourceInfo() *ast.SourceInfo {
+ return p.sourceInfo
+}
+
+func (p *parserHelper) newLiteral(ctx any, value ref.Val) ast.Expr {
+ return p.exprFactory.NewLiteral(p.newID(ctx), value)
+}
+
+func (p *parserHelper) newLiteralBool(ctx any, value bool) ast.Expr {
+ return p.newLiteral(ctx, types.Bool(value))
+}
+
+func (p *parserHelper) newLiteralString(ctx any, value string) ast.Expr {
+ return p.newLiteral(ctx, types.String(value))
+}
+
+func (p *parserHelper) newLiteralBytes(ctx any, value []byte) ast.Expr {
+ return p.newLiteral(ctx, types.Bytes(value))
+}
+
+func (p *parserHelper) newLiteralInt(ctx any, value int64) ast.Expr {
+ return p.newLiteral(ctx, types.Int(value))
+}
+
+func (p *parserHelper) newLiteralUint(ctx any, value uint64) ast.Expr {
+ return p.newLiteral(ctx, types.Uint(value))
+}
+
+func (p *parserHelper) newLiteralDouble(ctx any, value float64) ast.Expr {
+ return p.newLiteral(ctx, types.Double(value))
+}
+
+func (p *parserHelper) newIdent(ctx any, name string) ast.Expr {
+ return p.exprFactory.NewIdent(p.newID(ctx), name)
+}
+
+func (p *parserHelper) newSelect(ctx any, operand ast.Expr, field string) ast.Expr {
+ return p.exprFactory.NewSelect(p.newID(ctx), operand, field)
+}
+
+func (p *parserHelper) newPresenceTest(ctx any, operand ast.Expr, field string) ast.Expr {
+ return p.exprFactory.NewPresenceTest(p.newID(ctx), operand, field)
+}
+
+func (p *parserHelper) newGlobalCall(ctx any, function string, args ...ast.Expr) ast.Expr {
+ return p.exprFactory.NewCall(p.newID(ctx), function, args...)
+}
+
+func (p *parserHelper) newReceiverCall(ctx any, function string, target ast.Expr, args ...ast.Expr) ast.Expr {
+ return p.exprFactory.NewMemberCall(p.newID(ctx), function, target, args...)
+}
+
+func (p *parserHelper) newList(ctx any, elements []ast.Expr, optionals ...int32) ast.Expr {
+ return p.exprFactory.NewList(p.newID(ctx), elements, optionals)
+}
+
+func (p *parserHelper) newMap(ctx any, entries ...ast.EntryExpr) ast.Expr {
+ return p.exprFactory.NewMap(p.newID(ctx), entries)
+}
+
+func (p *parserHelper) newMapEntry(entryID int64, key ast.Expr, value ast.Expr, optional bool) ast.EntryExpr {
+ return p.exprFactory.NewMapEntry(entryID, key, value, optional)
+}
+
+func (p *parserHelper) newObject(ctx any, typeName string, fields ...ast.EntryExpr) ast.Expr {
+ return p.exprFactory.NewStruct(p.newID(ctx), typeName, fields)
+}
+
+func (p *parserHelper) newObjectField(fieldID int64, field string, value ast.Expr, optional bool) ast.EntryExpr {
+ return p.exprFactory.NewStructField(fieldID, field, value, optional)
+}
+
+func (p *parserHelper) newComprehension(ctx any,
+ iterRange ast.Expr,
+ iterVar string,
+ accuVar string,
+ accuInit ast.Expr,
+ condition ast.Expr,
+ step ast.Expr,
+ result ast.Expr) ast.Expr {
+ return p.exprFactory.NewComprehension(
+ p.newID(ctx), iterRange, iterVar, accuVar, accuInit, condition, step, result)
+}
+
+func (p *parserHelper) newID(ctx any) int64 {
+ if id, isID := ctx.(int64); isID {
+ return id
+ }
+ return p.id(ctx)
+}
+
+func (p *parserHelper) newExpr(ctx any) ast.Expr {
+ return p.exprFactory.NewUnspecifiedExpr(p.newID(ctx))
+}
+
+func (p *parserHelper) id(ctx any) int64 {
+ var offset ast.OffsetRange
+ switch c := ctx.(type) {
+ case antlr.ParserRuleContext:
+ start := c.GetStart()
+ offset.Start = p.sourceInfo.ComputeOffset(int32(start.GetLine()), int32(start.GetColumn()))
+ offset.Stop = offset.Start + int32(len(c.GetText()))
+ case antlr.Token:
+ offset.Start = p.sourceInfo.ComputeOffset(int32(c.GetLine()), int32(c.GetColumn()))
+ offset.Stop = offset.Start + int32(len(c.GetText()))
+ case common.Location:
+ offset.Start = p.sourceInfo.ComputeOffset(int32(c.Line()), int32(c.Column()))
+ offset.Stop = offset.Start
+ case ast.OffsetRange:
+ offset = c
+ default:
+ // This should only happen if the ctx is nil
+ return -1
+ }
+ id := p.nextID
+ p.sourceInfo.SetOffsetRange(id, offset)
+ p.nextID++
+ return id
+}
+
+func (p *parserHelper) deleteId(id int64) {
+ p.sourceInfo.ClearOffsetRange(id)
+ if id == p.nextID-1 {
+ p.nextID--
+ }
+}
+
+func (p *parserHelper) getLocation(id int64) common.Location {
+ return p.sourceInfo.GetStartLocation(id)
+}
+
+// buildMacroCallArg iterates the expression and returns a new expression
+// where all macros have been replaced by their IDs in MacroCalls
+func (p *parserHelper) buildMacroCallArg(expr ast.Expr) ast.Expr {
+ if _, found := p.sourceInfo.GetMacroCall(expr.ID()); found {
+ return p.exprFactory.NewUnspecifiedExpr(expr.ID())
+ }
+
+ switch expr.Kind() {
+ case ast.CallKind:
+ // Iterate the AST from `expr` recursively looking for macros. Because we are at most
+ // starting from the top level macro, this recursion is bounded by the size of the AST. This
+ // means that the depth check on the AST during parsing will catch recursion overflows
+ // before we get to here.
+ call := expr.AsCall()
+ macroArgs := make([]ast.Expr, len(call.Args()))
+ for index, arg := range call.Args() {
+ macroArgs[index] = p.buildMacroCallArg(arg)
+ }
+ if !call.IsMemberFunction() {
+ return p.exprFactory.NewCall(expr.ID(), call.FunctionName(), macroArgs...)
+ }
+ macroTarget := p.buildMacroCallArg(call.Target())
+ return p.exprFactory.NewMemberCall(expr.ID(), call.FunctionName(), macroTarget, macroArgs...)
+ case ast.ListKind:
+ list := expr.AsList()
+ macroListArgs := make([]ast.Expr, list.Size())
+ for i, elem := range list.Elements() {
+ macroListArgs[i] = p.buildMacroCallArg(elem)
+ }
+ return p.exprFactory.NewList(expr.ID(), macroListArgs, list.OptionalIndices())
+ }
+ return expr
+}
+
+// addMacroCall adds the macro the the MacroCalls map in source info. If a macro has args/subargs/target
+// that are macros, their ID will be stored instead for later self-lookups.
+func (p *parserHelper) addMacroCall(exprID int64, function string, target ast.Expr, args ...ast.Expr) {
+ macroArgs := make([]ast.Expr, len(args))
+ for index, arg := range args {
+ macroArgs[index] = p.buildMacroCallArg(arg)
+ }
+ if target == nil {
+ p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewCall(0, function, macroArgs...))
+ return
+ }
+ macroTarget := target
+ if _, found := p.sourceInfo.GetMacroCall(target.ID()); found {
+ macroTarget = p.exprFactory.NewUnspecifiedExpr(target.ID())
+ } else {
+ macroTarget = p.buildMacroCallArg(target)
+ }
+ p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewMemberCall(0, function, macroTarget, macroArgs...))
+}
+
+// logicManager compacts logical trees into a more efficient structure which is semantically
+// equivalent with how the logic graph is constructed by the ANTLR parser.
+//
+// The purpose of the logicManager is to ensure a compact serialization format for the logical &&, ||
+// operators which have a tendency to create long DAGs which are skewed in one direction. Since the
+// operators are commutative re-ordering the terms *must not* affect the evaluation result.
+//
+// The logic manager will either render the terms to N-chained && / || operators as a single logical
+// call with N-terms, or will rebalance the tree. Rebalancing the terms is a safe, if somewhat
+// controversial choice as it alters the traditional order of execution assumptions present in most
+// expressions.
+type logicManager struct {
+ exprFactory ast.ExprFactory
+ function string
+ terms []ast.Expr
+ ops []int64
+ variadicASTs bool
+}
+
+// newVariadicLogicManager creates a logic manager instance bound to a specific function and its first term.
+func newVariadicLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager {
+ return &logicManager{
+ exprFactory: fac,
+ function: function,
+ terms: []ast.Expr{term},
+ ops: []int64{},
+ variadicASTs: true,
+ }
+}
+
+// newBalancingLogicManager creates a logic manager instance bound to a specific function and its first term.
+func newBalancingLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager {
+ return &logicManager{
+ exprFactory: fac,
+ function: function,
+ terms: []ast.Expr{term},
+ ops: []int64{},
+ variadicASTs: false,
+ }
+}
+
+// addTerm adds an operation identifier and term to the set of terms to be balanced.
+func (l *logicManager) addTerm(op int64, term ast.Expr) {
+ l.terms = append(l.terms, term)
+ l.ops = append(l.ops, op)
+}
+
+// toExpr renders the logic graph into an Expr value, either balancing a tree of logical
+// operations or creating a variadic representation of the logical operator.
+func (l *logicManager) toExpr() ast.Expr {
+ if len(l.terms) == 1 {
+ return l.terms[0]
+ }
+ if l.variadicASTs {
+ return l.exprFactory.NewCall(l.ops[0], l.function, l.terms...)
+ }
+ return l.balancedTree(0, len(l.ops)-1)
+}
+
+// balancedTree recursively balances the terms provided to a commutative operator.
+func (l *logicManager) balancedTree(lo, hi int) ast.Expr {
+ mid := (lo + hi + 1) / 2
+
+ var left ast.Expr
+ if mid == lo {
+ left = l.terms[mid]
+ } else {
+ left = l.balancedTree(lo, mid-1)
+ }
+
+ var right ast.Expr
+ if mid == hi {
+ right = l.terms[mid+1]
+ } else {
+ right = l.balancedTree(mid+1, hi)
+ }
+ return l.exprFactory.NewCall(l.ops[mid], l.function, left, right)
+}
+
+type exprHelper struct {
+ *parserHelper
+ id int64
+}
+
+func (e *exprHelper) nextMacroID() int64 {
+ return e.parserHelper.id(e.parserHelper.getLocation(e.id))
+}
+
+// Copy implements the ExprHelper interface method by producing a copy of the input Expr value
+// with a fresh set of numeric identifiers the Expr and all its descendants.
+func (e *exprHelper) Copy(expr ast.Expr) ast.Expr {
+ offsetRange, _ := e.parserHelper.sourceInfo.GetOffsetRange(expr.ID())
+ copyID := e.parserHelper.newID(offsetRange)
+ switch expr.Kind() {
+ case ast.LiteralKind:
+ return e.exprFactory.NewLiteral(copyID, expr.AsLiteral())
+ case ast.IdentKind:
+ return e.exprFactory.NewIdent(copyID, expr.AsIdent())
+ case ast.SelectKind:
+ sel := expr.AsSelect()
+ op := e.Copy(sel.Operand())
+ if sel.IsTestOnly() {
+ return e.exprFactory.NewPresenceTest(copyID, op, sel.FieldName())
+ }
+ return e.exprFactory.NewSelect(copyID, op, sel.FieldName())
+ case ast.CallKind:
+ call := expr.AsCall()
+ args := call.Args()
+ argsCopy := make([]ast.Expr, len(args))
+ for i, arg := range args {
+ argsCopy[i] = e.Copy(arg)
+ }
+ if !call.IsMemberFunction() {
+ return e.exprFactory.NewCall(copyID, call.FunctionName(), argsCopy...)
+ }
+ return e.exprFactory.NewMemberCall(copyID, call.FunctionName(), e.Copy(call.Target()), argsCopy...)
+ case ast.ListKind:
+ list := expr.AsList()
+ elems := list.Elements()
+ elemsCopy := make([]ast.Expr, len(elems))
+ for i, elem := range elems {
+ elemsCopy[i] = e.Copy(elem)
+ }
+ return e.exprFactory.NewList(copyID, elemsCopy, list.OptionalIndices())
+ case ast.MapKind:
+ m := expr.AsMap()
+ entries := m.Entries()
+ entriesCopy := make([]ast.EntryExpr, len(entries))
+ for i, en := range entries {
+ entry := en.AsMapEntry()
+ entryID := e.nextMacroID()
+ entriesCopy[i] = e.exprFactory.NewMapEntry(entryID,
+ e.Copy(entry.Key()), e.Copy(entry.Value()), entry.IsOptional())
+ }
+ return e.exprFactory.NewMap(copyID, entriesCopy)
+ case ast.StructKind:
+ s := expr.AsStruct()
+ fields := s.Fields()
+ fieldsCopy := make([]ast.EntryExpr, len(fields))
+ for i, f := range fields {
+ field := f.AsStructField()
+ fieldID := e.nextMacroID()
+ fieldsCopy[i] = e.exprFactory.NewStructField(fieldID,
+ field.Name(), e.Copy(field.Value()), field.IsOptional())
+ }
+ return e.exprFactory.NewStruct(copyID, s.TypeName(), fieldsCopy)
+ case ast.ComprehensionKind:
+ compre := expr.AsComprehension()
+ iterRange := e.Copy(compre.IterRange())
+ accuInit := e.Copy(compre.AccuInit())
+ cond := e.Copy(compre.LoopCondition())
+ step := e.Copy(compre.LoopStep())
+ result := e.Copy(compre.Result())
+ return e.exprFactory.NewComprehension(copyID,
+ iterRange, compre.IterVar(), compre.AccuVar(), accuInit, cond, step, result)
+ }
+ return e.exprFactory.NewUnspecifiedExpr(copyID)
+}
+
+// NewLiteral implements the ExprHelper interface method.
+func (e *exprHelper) NewLiteral(value ref.Val) ast.Expr {
+ return e.exprFactory.NewLiteral(e.nextMacroID(), value)
+}
+
+// NewList implements the ExprHelper interface method.
+func (e *exprHelper) NewList(elems ...ast.Expr) ast.Expr {
+ return e.exprFactory.NewList(e.nextMacroID(), elems, []int32{})
+}
+
+// NewMap implements the ExprHelper interface method.
+func (e *exprHelper) NewMap(entries ...ast.EntryExpr) ast.Expr {
+ return e.exprFactory.NewMap(e.nextMacroID(), entries)
+}
+
+// NewMapEntry implements the ExprHelper interface method.
+func (e *exprHelper) NewMapEntry(key ast.Expr, val ast.Expr, optional bool) ast.EntryExpr {
+ return e.exprFactory.NewMapEntry(e.nextMacroID(), key, val, optional)
+}
+
+// NewStruct implements the ExprHelper interface method.
+func (e *exprHelper) NewStruct(typeName string, fieldInits ...ast.EntryExpr) ast.Expr {
+ return e.exprFactory.NewStruct(e.nextMacroID(), typeName, fieldInits)
+}
+
+// NewStructField implements the ExprHelper interface method.
+func (e *exprHelper) NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr {
+ return e.exprFactory.NewStructField(e.nextMacroID(), field, init, optional)
+}
+
+// NewComprehension implements the ExprHelper interface method.
+func (e *exprHelper) NewComprehension(
+ iterRange ast.Expr,
+ iterVar string,
+ accuVar string,
+ accuInit ast.Expr,
+ condition ast.Expr,
+ step ast.Expr,
+ result ast.Expr) ast.Expr {
+ return e.exprFactory.NewComprehension(
+ e.nextMacroID(), iterRange, iterVar, accuVar, accuInit, condition, step, result)
+}
+
+// NewIdent implements the ExprHelper interface method.
+func (e *exprHelper) NewIdent(name string) ast.Expr {
+ return e.exprFactory.NewIdent(e.nextMacroID(), name)
+}
+
+// NewAccuIdent implements the ExprHelper interface method.
+func (e *exprHelper) NewAccuIdent() ast.Expr {
+ return e.exprFactory.NewAccuIdent(e.nextMacroID())
+}
+
+// NewGlobalCall implements the ExprHelper interface method.
+func (e *exprHelper) NewCall(function string, args ...ast.Expr) ast.Expr {
+ return e.exprFactory.NewCall(e.nextMacroID(), function, args...)
+}
+
+// NewMemberCall implements the ExprHelper interface method.
+func (e *exprHelper) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr {
+ return e.exprFactory.NewMemberCall(e.nextMacroID(), function, target, args...)
+}
+
+// NewPresenceTest implements the ExprHelper interface method.
+func (e *exprHelper) NewPresenceTest(operand ast.Expr, field string) ast.Expr {
+ return e.exprFactory.NewPresenceTest(e.nextMacroID(), operand, field)
+}
+
+// NewSelect implements the ExprHelper interface method.
+func (e *exprHelper) NewSelect(operand ast.Expr, field string) ast.Expr {
+ return e.exprFactory.NewSelect(e.nextMacroID(), operand, field)
+}
+
+// OffsetLocation implements the ExprHelper interface method.
+func (e *exprHelper) OffsetLocation(exprID int64) common.Location {
+ return e.parserHelper.sourceInfo.GetStartLocation(exprID)
+}
+
+// NewError associates an error message with a given expression id, populating the source offset location of the error if possible.
+func (e *exprHelper) NewError(exprID int64, message string) *common.Error {
+ return common.NewError(exprID, message, e.OffsetLocation(exprID))
+}
+
+var (
+ // Thread-safe pool of ExprHelper values to minimize alloc overhead of ExprHelper creations.
+ exprHelperPool = &sync.Pool{
+ New: func() any {
+ return &exprHelper{}
+ },
+ }
+)