diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/cel-go/cel | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/cel-go/cel')
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/BUILD.bazel | 91 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/cel.go | 19 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/decls.go | 355 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/env.go | 881 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/folding.go | 559 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/inlining.go | 228 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/io.go | 252 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/library.go | 790 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/macro.go | 576 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/optimizer.go | 509 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/options.go | 667 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/program.go | 545 | ||||
| -rw-r--r-- | vendor/github.com/authzed/cel-go/cel/validator.go | 375 |
13 files changed, 5847 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/cel-go/cel/BUILD.bazel b/vendor/github.com/authzed/cel-go/cel/BUILD.bazel new file mode 100644 index 0000000..74a619b --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/BUILD.bazel @@ -0,0 +1,91 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +go_library( + name = "go_default_library", + srcs = [ + "cel.go", + "decls.go", + "env.go", + "folding.go", + "io.go", + "inlining.go", + "library.go", + "macro.go", + "optimizer.go", + "options.go", + "program.go", + "validator.go", + ], + importpath = "github.com/authzed/cel-go/cel", + visibility = ["//visibility:public"], + deps = [ + "//checker:go_default_library", + "//checker/decls:go_default_library", + "//common:go_default_library", + "//common/ast:go_default_library", + "//common/containers:go_default_library", + "//common/decls:go_default_library", + "//common/functions:go_default_library", + "//common/operators:go_default_library", + "//common/overloads:go_default_library", + "//common/stdlib:go_default_library", + "//common/types:go_default_library", + "//common/types/pb:go_default_library", + "//common/types/ref:go_default_library", + "//common/types/traits:go_default_library", + "//interpreter:go_default_library", + "//parser:go_default_library", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//reflect/protodesc:go_default_library", + "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", + "@org_golang_google_protobuf//reflect/protoregistry:go_default_library", + "@org_golang_google_protobuf//types/descriptorpb:go_default_library", + "@org_golang_google_protobuf//types/dynamicpb:go_default_library", + "@org_golang_google_protobuf//types/known/anypb:go_default_library", + "@org_golang_google_protobuf//types/known/durationpb:go_default_library", + "@org_golang_google_protobuf//types/known/timestamppb:go_default_library", + ], +) + +go_test( + name = "go_default_test", + srcs = [ + "cel_example_test.go", + "cel_test.go", + "decls_test.go", + "env_test.go", + "folding_test.go", + "io_test.go", + "inlining_test.go", + "optimizer_test.go", + "validator_test.go", + ], + data = [ + "//cel/testdata:gen_test_fds", + ], + embed = [ + ":go_default_library", + ], + deps = [ + "//common/operators:go_default_library", + "//common/overloads:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//common/types/traits:go_default_library", + "//ext:go_default_library", + "//test:go_default_library", + "//test/proto2pb:go_default_library", + "//test/proto3pb:go_default_library", + "@io_bazel_rules_go//proto/wkt:descriptor_go_proto", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", + "@org_golang_google_protobuf//types/known/structpb:go_default_library", + "@org_golang_google_protobuf//types/known/wrapperspb:go_default_library", + ], +) diff --git a/vendor/github.com/authzed/cel-go/cel/cel.go b/vendor/github.com/authzed/cel-go/cel/cel.go new file mode 100644 index 0000000..eb5a9f4 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/cel.go @@ -0,0 +1,19 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cel defines the top-level interface for the Common Expression Language (CEL). +// +// CEL is a non-Turing complete expression language designed to parse, check, and evaluate +// expressions against user-defined environments. +package cel diff --git a/vendor/github.com/authzed/cel-go/cel/decls.go b/vendor/github.com/authzed/cel-go/cel/decls.go new file mode 100644 index 0000000..7b704c0 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/decls.go @@ -0,0 +1,355 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/decls" + "github.com/authzed/cel-go/common/functions" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// Kind indicates a CEL type's kind which is used to differentiate quickly between simple and complex types. +type Kind = types.Kind + +const ( + // DynKind represents a dynamic type. This kind only exists at type-check time. + DynKind Kind = types.DynKind + + // AnyKind represents a google.protobuf.Any type. This kind only exists at type-check time. + AnyKind = types.AnyKind + + // BoolKind represents a boolean type. + BoolKind = types.BoolKind + + // BytesKind represents a bytes type. + BytesKind = types.BytesKind + + // DoubleKind represents a double type. + DoubleKind = types.DoubleKind + + // DurationKind represents a CEL duration type. + DurationKind = types.DurationKind + + // IntKind represents an integer type. + IntKind = types.IntKind + + // ListKind represents a list type. + ListKind = types.ListKind + + // MapKind represents a map type. + MapKind = types.MapKind + + // NullTypeKind represents a null type. + NullTypeKind = types.NullTypeKind + + // OpaqueKind represents an abstract type which has no accessible fields. + OpaqueKind = types.OpaqueKind + + // StringKind represents a string type. + StringKind = types.StringKind + + // StructKind represents a structured object with typed fields. + StructKind = types.StructKind + + // TimestampKind represents a a CEL time type. + TimestampKind = types.TimestampKind + + // TypeKind represents the CEL type. + TypeKind = types.TypeKind + + // TypeParamKind represents a parameterized type whose type name will be resolved at type-check time, if possible. + TypeParamKind = types.TypeParamKind + + // UintKind represents a uint type. + UintKind = types.UintKind +) + +var ( + // AnyType represents the google.protobuf.Any type. + AnyType = types.AnyType + // BoolType represents the bool type. + BoolType = types.BoolType + // BytesType represents the bytes type. + BytesType = types.BytesType + // DoubleType represents the double type. + DoubleType = types.DoubleType + // DurationType represents the CEL duration type. + DurationType = types.DurationType + // DynType represents a dynamic CEL type whose type will be determined at runtime from context. + DynType = types.DynType + // IntType represents the int type. + IntType = types.IntType + // NullType represents the type of a null value. + NullType = types.NullType + // StringType represents the string type. + StringType = types.StringType + // TimestampType represents the time type. + TimestampType = types.TimestampType + // TypeType represents a CEL type + TypeType = types.TypeType + // UintType represents a uint type. + UintType = types.UintType + + // function references for instantiating new types. + + // ListType creates an instances of a list type value with the provided element type. + ListType = types.NewListType + // MapType creates an instance of a map type value with the provided key and value types. + MapType = types.NewMapType + // NullableType creates an instance of a nullable type with the provided wrapped type. + // + // Note: only primitive types are supported as wrapped types. + NullableType = types.NewNullableType + // OptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional. + OptionalType = types.NewOptionalType + // OpaqueType creates an abstract parameterized type with a given name. + OpaqueType = types.NewOpaqueType + // ObjectType creates a type references to an externally defined type, e.g. a protobuf message type. + ObjectType = types.NewObjectType + // TypeParamType creates a parameterized type instance. + TypeParamType = types.NewTypeParamType +) + +// Type holds a reference to a runtime type with an optional type-checked set of type parameters. +type Type = types.Type + +// Constant creates an instances of an identifier declaration with a variable name, type, and value. +func Constant(name string, t *Type, v ref.Val) EnvOption { + return func(e *Env) (*Env, error) { + e.variables = append(e.variables, decls.NewConstant(name, t, v)) + return e, nil + } +} + +// Variable creates an instance of a variable declaration with a variable name and type. +func Variable(name string, t *Type) EnvOption { + return func(e *Env) (*Env, error) { + e.variables = append(e.variables, decls.NewVariable(name, t)) + return e, nil + } +} + +// Function defines a function and overloads with optional singleton or per-overload bindings. +// +// Using Function is roughly equivalent to calling Declarations() to declare the function signatures +// and Functions() to define the function bindings, if they have been defined. Specifying the +// same function name more than once will result in the aggregation of the function overloads. If any +// signatures conflict between the existing and new function definition an error will be raised. +// However, if the signatures are identical and the overload ids are the same, the redefinition will +// be considered a no-op. +// +// One key difference with using Function() is that each FunctionDecl provided will handle dynamic +// dispatch based on the type-signatures of the overloads provided which means overload resolution at +// runtime is handled out of the box rather than via a custom binding for overload resolution via +// Functions(): +// +// - Overloads are searched in the order they are declared +// - Dynamic dispatch for lists and maps is limited by inspection of the list and map contents +// +// at runtime. Empty lists and maps will result in a 'default dispatch' +// +// - In the event that a default dispatch occurs, the first overload provided is the one invoked +// +// If you intend to use overloads which differentiate based on the key or element type of a list or +// map, consider using a generic function instead: e.g. func(list(T)) or func(map(K, V)) as this +// will allow your implementation to determine how best to handle dispatch and the default behavior +// for empty lists and maps whose contents cannot be inspected. +// +// For functions which use parameterized opaque types (abstract types), consider using a singleton +// function which is capable of inspecting the contents of the type and resolving the appropriate +// overload as CEL can only make inferences by type-name regarding such types. +func Function(name string, opts ...FunctionOpt) EnvOption { + return func(e *Env) (*Env, error) { + fn, err := decls.NewFunction(name, opts...) + if err != nil { + return nil, err + } + if existing, found := e.functions[fn.Name()]; found { + fn, err = existing.Merge(fn) + if err != nil { + return nil, err + } + } + e.functions[fn.Name()] = fn + return e, nil + } +} + +// FunctionOpt defines a functional option for configuring a function declaration. +type FunctionOpt = decls.FunctionOpt + +// SingletonUnaryBinding creates a singleton function definition to be used for all function overloads. +// +// Note, this approach works well if operand is expected to have a specific trait which it implements, +// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings. +func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt { + return decls.SingletonUnaryBinding(fn, traits...) +} + +// SingletonBinaryImpl creates a singleton function definition to be used with all function overloads. +// +// Note, this approach works well if operand is expected to have a specific trait which it implements, +// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings. +// +// Deprecated: use SingletonBinaryBinding +func SingletonBinaryImpl(fn functions.BinaryOp, traits ...int) FunctionOpt { + return decls.SingletonBinaryBinding(fn, traits...) +} + +// SingletonBinaryBinding creates a singleton function definition to be used with all function overloads. +// +// Note, this approach works well if operand is expected to have a specific trait which it implements, +// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings. +func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt { + return decls.SingletonBinaryBinding(fn, traits...) +} + +// SingletonFunctionImpl creates a singleton function definition to be used with all function overloads. +// +// Note, this approach works well if operand is expected to have a specific trait which it implements, +// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings. +// +// Deprecated: use SingletonFunctionBinding +func SingletonFunctionImpl(fn functions.FunctionOp, traits ...int) FunctionOpt { + return decls.SingletonFunctionBinding(fn, traits...) +} + +// SingletonFunctionBinding creates a singleton function definition to be used with all function overloads. +// +// Note, this approach works well if operand is expected to have a specific trait which it implements, +// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings. +func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOpt { + return decls.SingletonFunctionBinding(fn, traits...) +} + +// DisableDeclaration disables the function signatures, effectively removing them from the type-check +// environment while preserving the runtime bindings. +func DisableDeclaration(value bool) FunctionOpt { + return decls.DisableDeclaration(value) +} + +// Overload defines a new global overload with an overload id, argument types, and result type. Through the +// use of OverloadOpt options, the overload may also be configured with a binding, an operand trait, and to +// be non-strict. +// +// Note: function bindings should be commonly configured with Overload instances whereas operand traits and +// strict-ness should be rare occurrences. +func Overload(overloadID string, args []*Type, resultType *Type, opts ...OverloadOpt) FunctionOpt { + return decls.Overload(overloadID, args, resultType, opts...) +} + +// MemberOverload defines a new receiver-style overload (or member function) with an overload id, argument types, +// and result type. Through the use of OverloadOpt options, the overload may also be configured with a binding, +// an operand trait, and to be non-strict. +// +// Note: function bindings should be commonly configured with Overload instances whereas operand traits and +// strict-ness should be rare occurrences. +func MemberOverload(overloadID string, args []*Type, resultType *Type, opts ...OverloadOpt) FunctionOpt { + return decls.MemberOverload(overloadID, args, resultType, opts...) +} + +// OverloadOpt is a functional option for configuring a function overload. +type OverloadOpt = decls.OverloadOpt + +// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func UnaryBinding(binding functions.UnaryOp) OverloadOpt { + return decls.UnaryBinding(binding) +} + +// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func BinaryBinding(binding functions.BinaryOp) OverloadOpt { + return decls.BinaryBinding(binding) +} + +// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime +// type-guard which ensures runtime type agreement between the overload signature and runtime argument types. +func FunctionBinding(binding functions.FunctionOp) OverloadOpt { + return decls.FunctionBinding(binding) +} + +// OverloadIsNonStrict enables the function to be called with error and unknown argument values. +// +// Note: do not use this option unless absoluately necessary as it should be an uncommon feature. +func OverloadIsNonStrict() OverloadOpt { + return decls.OverloadIsNonStrict() +} + +// OverloadOperandTrait configures a set of traits which the first argument to the overload must implement in order to be +// successfully invoked. +func OverloadOperandTrait(trait int) OverloadOpt { + return decls.OverloadOperandTrait(trait) +} + +// TypeToExprType converts a CEL-native type representation to a protobuf CEL Type representation. +func TypeToExprType(t *Type) (*exprpb.Type, error) { + return types.TypeToExprType(t) +} + +// ExprTypeToType converts a protobuf CEL type representation to a CEL-native type representation. +func ExprTypeToType(t *exprpb.Type) (*Type, error) { + return types.ExprTypeToType(t) +} + +// ExprDeclToDeclaration converts a protobuf CEL declaration to a CEL-native declaration, either a Variable or Function. +func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) { + switch d.GetDeclKind().(type) { + case *exprpb.Decl_Function: + overloads := d.GetFunction().GetOverloads() + opts := make([]FunctionOpt, len(overloads)) + for i, o := range overloads { + args := make([]*Type, len(o.GetParams())) + for j, p := range o.GetParams() { + a, err := types.ExprTypeToType(p) + if err != nil { + return nil, err + } + args[j] = a + } + res, err := types.ExprTypeToType(o.GetResultType()) + if err != nil { + return nil, err + } + if o.IsInstanceFunction { + opts[i] = decls.MemberOverload(o.GetOverloadId(), args, res) + } else { + opts[i] = decls.Overload(o.GetOverloadId(), args, res) + } + } + return Function(d.GetName(), opts...), nil + case *exprpb.Decl_Ident: + t, err := types.ExprTypeToType(d.GetIdent().GetType()) + if err != nil { + return nil, err + } + if d.GetIdent().GetValue() == nil { + return Variable(d.GetName(), t), nil + } + val, err := ast.ConstantToVal(d.GetIdent().GetValue()) + if err != nil { + return nil, err + } + return Constant(d.GetName(), t, val), nil + default: + return nil, fmt.Errorf("unsupported decl: %v", d) + } +} diff --git a/vendor/github.com/authzed/cel-go/cel/env.go b/vendor/github.com/authzed/cel-go/cel/env.go new file mode 100644 index 0000000..5cbb1a2 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/env.go @@ -0,0 +1,881 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "errors" + "sync" + + "github.com/authzed/cel-go/checker" + chkdecls "github.com/authzed/cel-go/checker/decls" + "github.com/authzed/cel-go/common" + celast "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/containers" + "github.com/authzed/cel-go/common/decls" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/interpreter" + "github.com/authzed/cel-go/parser" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// Source interface representing a user-provided expression. +type Source = common.Source + +// Ast representing the checked or unchecked expression, its source, and related metadata such as +// source position information. +type Ast struct { + source Source + impl *celast.AST +} + +// NativeRep converts the AST to a Go-native representation. +func (ast *Ast) NativeRep() *celast.AST { + return ast.impl +} + +// Expr returns the proto serializable instance of the parsed/checked expression. +// +// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr() +// the result instead. +func (ast *Ast) Expr() *exprpb.Expr { + if ast == nil { + return nil + } + pbExpr, _ := celast.ExprToProto(ast.impl.Expr()) + return pbExpr +} + +// IsChecked returns whether the Ast value has been successfully type-checked. +func (ast *Ast) IsChecked() bool { + if ast == nil { + return false + } + return ast.impl.IsChecked() +} + +// SourceInfo returns character offset and newline position information about expression elements. +func (ast *Ast) SourceInfo() *exprpb.SourceInfo { + if ast == nil { + return nil + } + pbInfo, _ := celast.SourceInfoToProto(ast.impl.SourceInfo()) + return pbInfo +} + +// ResultType returns the output type of the expression if the Ast has been type-checked, else +// returns chkdecls.Dyn as the parse step cannot infer the type. +// +// Deprecated: use OutputType +func (ast *Ast) ResultType() *exprpb.Type { + out := ast.OutputType() + t, err := TypeToExprType(out) + if err != nil { + return chkdecls.Dyn + } + return t +} + +// OutputType returns the output type of the expression if the Ast has been type-checked, else +// returns cel.DynType as the parse step cannot infer types. +func (ast *Ast) OutputType() *Type { + if ast == nil { + return types.ErrorType + } + return ast.impl.GetType(ast.impl.Expr().ID()) +} + +// Source returns a view of the input used to create the Ast. This source may be complete or +// constructed from the SourceInfo. +func (ast *Ast) Source() Source { + if ast == nil { + return nil + } + return ast.source +} + +// FormatType converts a type message into a string representation. +// +// Deprecated: prefer FormatCELType +func FormatType(t *exprpb.Type) string { + return checker.FormatCheckedType(t) +} + +// FormatCELType formats a cel.Type value to a string representation. +// +// The type formatting is identical to FormatType. +func FormatCELType(t *Type) string { + return checker.FormatCELType(t) +} + +// Env encapsulates the context necessary to perform parsing, type checking, or generation of +// evaluable programs for different expressions. +type Env struct { + Container *containers.Container + variables []*decls.VariableDecl + functions map[string]*decls.FunctionDecl + macros []parser.Macro + adapter types.Adapter + provider types.Provider + features map[int]bool + appliedFeatures map[int]bool + libraries map[string]bool + validators []ASTValidator + costOptions []checker.CostOption + + // Internal parser representation + prsr *parser.Parser + prsrOpts []parser.Option + + // Internal checker representation + chkMutex sync.Mutex + chk *checker.Env + chkErr error + chkOnce sync.Once + chkOpts []checker.Option + + // Program options tied to the environment + progOpts []ProgramOption +} + +// NewEnv creates a program environment configured with the standard library of CEL functions and +// macros. The Env value returned can parse and check any CEL program which builds upon the core +// features documented in the CEL specification. +// +// See the EnvOption helper functions for the options that can be used to configure the +// environment. +func NewEnv(opts ...EnvOption) (*Env, error) { + // Extend the statically configured standard environment, disabling eager validation to ensure + // the cost of setup for the environment is still just as cheap as it is in v0.11.x and earlier + // releases. The user provided options can easily re-enable the eager validation as they are + // processed after this default option. + stdOpts := append([]EnvOption{EagerlyValidateDeclarations(false)}, opts...) + env, err := getStdEnv() + if err != nil { + return nil, err + } + return env.Extend(stdOpts...) +} + +// NewCustomEnv creates a custom program environment which is not automatically configured with the +// standard library of functions and macros documented in the CEL spec. +// +// The purpose for using a custom environment might be for subsetting the standard library produced +// by the cel.StdLib() function. Subsetting CEL is a core aspect of its design that allows users to +// limit the compute and memory impact of a CEL program by controlling the functions and macros +// that may appear in a given expression. +// +// See the EnvOption helper functions for the options that can be used to configure the +// environment. +func NewCustomEnv(opts ...EnvOption) (*Env, error) { + registry, err := types.NewRegistry() + if err != nil { + return nil, err + } + return (&Env{ + variables: []*decls.VariableDecl{}, + functions: map[string]*decls.FunctionDecl{}, + macros: []parser.Macro{}, + Container: containers.DefaultContainer, + adapter: registry, + provider: registry, + features: map[int]bool{}, + appliedFeatures: map[int]bool{}, + libraries: map[string]bool{}, + validators: []ASTValidator{}, + progOpts: []ProgramOption{}, + costOptions: []checker.CostOption{}, + }).configure(opts) +} + +// Check performs type-checking on the input Ast and yields a checked Ast and/or set of Issues. +// If any `ASTValidators` are configured on the environment, they will be applied after a valid +// type-check result. If any issues are detected, the validators will provide them on the +// output Issues object. +// +// Either checking or validation has failed if the returned Issues value and its Issues.Err() +// value are non-nil. Issues should be inspected if they are non-nil, but may not represent a +// fatal error. +// +// It is possible to have both non-nil Ast and Issues values returned from this call: however, +// the mere presence of an Ast does not imply that it is valid for use. +func (e *Env) Check(ast *Ast) (*Ast, *Issues) { + // Construct the internal checker env, erroring if there is an issue adding the declarations. + chk, err := e.initChecker() + if err != nil { + errs := common.NewErrors(ast.Source()) + errs.ReportError(common.NoLocation, err.Error()) + return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) + } + + checked, errs := checker.Check(ast.impl, ast.Source(), chk) + if len(errs.GetErrors()) > 0 { + return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) + } + // Manually create the Ast to ensure that the Ast source information (which may be more + // detailed than the information provided by Check), is returned to the caller. + ast = &Ast{ + source: ast.Source(), + impl: checked} + + // Avoid creating a validator config if it's not needed. + if len(e.validators) == 0 { + return ast, nil + } + + // Generate a validator configuration from the set of configured validators. + vConfig := newValidatorConfig() + for _, v := range e.validators { + if cv, ok := v.(ASTValidatorConfigurer); ok { + cv.Configure(vConfig) + } + } + // Apply additional validators on the type-checked result. + iss := NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) + for _, v := range e.validators { + v.Validate(e, vConfig, checked, iss) + } + if iss.Err() != nil { + return nil, iss + } + return ast, nil +} + +// Compile combines the Parse and Check phases CEL program compilation to produce an Ast and +// associated issues. +// +// If an error is encountered during parsing the Compile step will not continue with the Check +// phase. If non-error issues are encountered during Parse, they may be combined with any issues +// discovered during Check. +// +// Note, for parse-only uses of CEL use Parse. +func (e *Env) Compile(txt string) (*Ast, *Issues) { + return e.CompileSource(common.NewTextSource(txt)) +} + +// CompileSource combines the Parse and Check phases CEL program compilation to produce an Ast and +// associated issues. +// +// If an error is encountered during parsing the CompileSource step will not continue with the +// Check phase. If non-error issues are encountered during Parse, they may be combined with any +// issues discovered during Check. +// +// Note, for parse-only uses of CEL use Parse. +func (e *Env) CompileSource(src Source) (*Ast, *Issues) { + ast, iss := e.ParseSource(src) + if iss.Err() != nil { + return nil, iss + } + checked, iss2 := e.Check(ast) + if iss2.Err() != nil { + return nil, iss2 + } + return checked, iss2 +} + +// Extend the current environment with additional options to produce a new Env. +// +// Note, the extended Env value should not share memory with the original. It is possible, however, +// that a CustomTypeAdapter or CustomTypeProvider options could provide values which are mutable. +// To ensure separation of state between extended environments either make sure the TypeAdapter and +// TypeProvider are immutable, or that their underlying implementations are based on the +// ref.TypeRegistry which provides a Copy method which will be invoked by this method. +func (e *Env) Extend(opts ...EnvOption) (*Env, error) { + chk, chkErr := e.getCheckerOrError() + if chkErr != nil { + return nil, chkErr + } + + prsrOptsCopy := make([]parser.Option, len(e.prsrOpts)) + copy(prsrOptsCopy, e.prsrOpts) + + // The type-checker is configured with Declarations. The declarations may either be provided + // as options which have not yet been validated, or may come from a previous checker instance + // whose types have already been validated. + chkOptsCopy := make([]checker.Option, len(e.chkOpts)) + copy(chkOptsCopy, e.chkOpts) + + // Copy the declarations if needed. + varsCopy := []*decls.VariableDecl{} + if chk != nil { + // If the type-checker has already been instantiated, then the e.declarations have been + // validated within the chk instance. + chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(chk)) + } else { + // If the type-checker has not been instantiated, ensure the unvalidated declarations are + // provided to the extended Env instance. + varsCopy = make([]*decls.VariableDecl, len(e.variables)) + copy(varsCopy, e.variables) + } + + // Copy macros and program options + macsCopy := make([]parser.Macro, len(e.macros)) + progOptsCopy := make([]ProgramOption, len(e.progOpts)) + copy(macsCopy, e.macros) + copy(progOptsCopy, e.progOpts) + + // Copy the adapter / provider if they appear to be mutable. + adapter := e.adapter + provider := e.provider + adapterReg, isAdapterReg := e.adapter.(*types.Registry) + providerReg, isProviderReg := e.provider.(*types.Registry) + // In most cases the provider and adapter will be a ref.TypeRegistry; + // however, in the rare cases where they are not, they are assumed to + // be immutable. Since it is possible to set the TypeProvider separately + // from the TypeAdapter, the possible configurations which could use a + // TypeRegistry as the base implementation are captured below. + if isAdapterReg && isProviderReg { + reg := providerReg.Copy() + provider = reg + // If the adapter and provider are the same object, set the adapter + // to the same ref.TypeRegistry as the provider. + if adapterReg == providerReg { + adapter = reg + } else { + // Otherwise, make a copy of the adapter. + adapter = adapterReg.Copy() + } + } else if isProviderReg { + provider = providerReg.Copy() + } else if isAdapterReg { + adapter = adapterReg.Copy() + } + + featuresCopy := make(map[int]bool, len(e.features)) + for k, v := range e.features { + featuresCopy[k] = v + } + appliedFeaturesCopy := make(map[int]bool, len(e.appliedFeatures)) + for k, v := range e.appliedFeatures { + appliedFeaturesCopy[k] = v + } + funcsCopy := make(map[string]*decls.FunctionDecl, len(e.functions)) + for k, v := range e.functions { + funcsCopy[k] = v + } + libsCopy := make(map[string]bool, len(e.libraries)) + for k, v := range e.libraries { + libsCopy[k] = v + } + validatorsCopy := make([]ASTValidator, len(e.validators)) + copy(validatorsCopy, e.validators) + costOptsCopy := make([]checker.CostOption, len(e.costOptions)) + copy(costOptsCopy, e.costOptions) + + ext := &Env{ + Container: e.Container, + variables: varsCopy, + functions: funcsCopy, + macros: macsCopy, + progOpts: progOptsCopy, + adapter: adapter, + features: featuresCopy, + appliedFeatures: appliedFeaturesCopy, + libraries: libsCopy, + validators: validatorsCopy, + provider: provider, + chkOpts: chkOptsCopy, + prsrOpts: prsrOptsCopy, + costOptions: costOptsCopy, + } + return ext.configure(opts) +} + +// HasFeature checks whether the environment enables the given feature +// flag, as enumerated in options.go. +func (e *Env) HasFeature(flag int) bool { + enabled, has := e.features[flag] + return has && enabled +} + +// HasLibrary returns whether a specific SingletonLibrary has been configured in the environment. +func (e *Env) HasLibrary(libName string) bool { + configured, exists := e.libraries[libName] + return exists && configured +} + +// Libraries returns a list of SingletonLibrary that have been configured in the environment. +func (e *Env) Libraries() []string { + libraries := make([]string, 0, len(e.libraries)) + for libName := range e.libraries { + libraries = append(libraries, libName) + } + return libraries +} + +// HasValidator returns whether a specific ASTValidator has been configured in the environment. +func (e *Env) HasValidator(name string) bool { + for _, v := range e.validators { + if v.Name() == name { + return true + } + } + return false +} + +// Parse parses the input expression value `txt` to a Ast and/or a set of Issues. +// +// This form of Parse creates a Source value for the input `txt` and forwards to the +// ParseSource method. +func (e *Env) Parse(txt string) (*Ast, *Issues) { + src := common.NewTextSource(txt) + return e.ParseSource(src) +} + +// ParseSource parses the input source to an Ast and/or set of Issues. +// +// Parsing has failed if the returned Issues value and its Issues.Err() value is non-nil. +// Issues should be inspected if they are non-nil, but may not represent a fatal error. +// +// It is possible to have both non-nil Ast and Issues values returned from this call; however, +// the mere presence of an Ast does not imply that it is valid for use. +func (e *Env) ParseSource(src Source) (*Ast, *Issues) { + parsed, errs := e.prsr.Parse(src) + if len(errs.GetErrors()) > 0 { + return nil, &Issues{errs: errs} + } + return &Ast{source: src, impl: parsed}, nil +} + +// Program generates an evaluable instance of the Ast within the environment (Env). +func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) { + optSet := e.progOpts + if len(opts) != 0 { + mergedOpts := []ProgramOption{} + mergedOpts = append(mergedOpts, e.progOpts...) + mergedOpts = append(mergedOpts, opts...) + optSet = mergedOpts + } + return newProgram(e, ast, optSet) +} + +// CELTypeAdapter returns the `types.Adapter` configured for the environment. +func (e *Env) CELTypeAdapter() types.Adapter { + return e.adapter +} + +// CELTypeProvider returns the `types.Provider` configured for the environment. +func (e *Env) CELTypeProvider() types.Provider { + return e.provider +} + +// TypeAdapter returns the `ref.TypeAdapter` configured for the environment. +// +// Deprecated: use CELTypeAdapter() +func (e *Env) TypeAdapter() ref.TypeAdapter { + return e.adapter +} + +// TypeProvider returns the `ref.TypeProvider` configured for the environment. +// +// Deprecated: use CELTypeProvider() +func (e *Env) TypeProvider() ref.TypeProvider { + if legacyProvider, ok := e.provider.(ref.TypeProvider); ok { + return legacyProvider + } + return &interopLegacyTypeProvider{Provider: e.provider} +} + +// UnknownVars returns an interpreter.PartialActivation which marks all variables declared in the +// Env as unknown AttributePattern values. +// +// Note, the UnknownVars will behave the same as an interpreter.EmptyActivation unless the +// PartialAttributes option is provided as a ProgramOption. +func (e *Env) UnknownVars() interpreter.PartialActivation { + act := interpreter.EmptyActivation() + part, _ := PartialVars(act, e.computeUnknownVars(act)...) + return part +} + +// PartialVars returns an interpreter.PartialActivation where all variables not in the input variable +// set, but which have been configured in the environment, are marked as unknown. +// +// The `vars` value may either be an interpreter.Activation or any valid input to the +// interpreter.NewActivation call. +// +// Note, this is equivalent to calling cel.PartialVars and manually configuring the set of unknown +// variables. For more advanced use cases of partial state where portions of an object graph, rather +// than top-level variables, are missing the PartialVars() method may be a more suitable choice. +// +// Note, the PartialVars will behave the same as an interpreter.EmptyActivation unless the +// PartialAttributes option is provided as a ProgramOption. +func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) { + act, err := interpreter.NewActivation(vars) + if err != nil { + return nil, err + } + return PartialVars(act, e.computeUnknownVars(act)...) +} + +// ResidualAst takes an Ast and its EvalDetails to produce a new Ast which only contains the +// attribute references which are unknown. +// +// Residual expressions are beneficial in a few scenarios: +// +// - Optimizing constant expression evaluations away. +// - Indexing and pruning expressions based on known input arguments. +// - Surfacing additional requirements that are needed in order to complete an evaluation. +// - Sharing the evaluation of an expression across multiple machines/nodes. +// +// For example, if an expression targets a 'resource' and 'request' attribute and the possible +// values for the resource are known, a PartialActivation could mark the 'request' as an unknown +// interpreter.AttributePattern and the resulting ResidualAst would be reduced to only the parts +// of the expression that reference the 'request'. +// +// Note, the expression ids within the residual AST generated through this method have no +// correlation to the expression ids of the original AST. +// +// See the PartialVars helper for how to construct a PartialActivation. +// +// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an +// Ast format and then Program again. +func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { + pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State()) + newAST := &Ast{source: a.Source(), impl: pruned} + expr, err := AstToString(newAST) + if err != nil { + return nil, err + } + parsed, iss := e.Parse(expr) + if iss != nil && iss.Err() != nil { + return nil, iss.Err() + } + if !a.IsChecked() { + return parsed, nil + } + checked, iss := e.Check(parsed) + if iss != nil && iss.Err() != nil { + return nil, iss.Err() + } + return checked, nil +} + +// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and +// extension functions provided by estimator. +func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) { + extendedOpts := make([]checker.CostOption, 0, len(e.costOptions)) + extendedOpts = append(extendedOpts, opts...) + extendedOpts = append(extendedOpts, e.costOptions...) + return checker.Cost(ast.impl, estimator, extendedOpts...) +} + +// configure applies a series of EnvOptions to the current environment. +func (e *Env) configure(opts []EnvOption) (*Env, error) { + // Customized the environment using the provided EnvOption values. If an error is + // generated at any step this, will be returned as a nil Env with a non-nil error. + var err error + for _, opt := range opts { + e, err = opt(e) + if err != nil { + return nil, err + } + } + + // If the default UTC timezone fix has been enabled, make sure the library is configured + e, err = e.maybeApplyFeature(featureDefaultUTCTimeZone, Lib(timeUTCLibrary{})) + if err != nil { + return nil, err + } + + // Configure the parser. + prsrOpts := []parser.Option{} + prsrOpts = append(prsrOpts, e.prsrOpts...) + prsrOpts = append(prsrOpts, parser.Macros(e.macros...)) + + if e.HasFeature(featureEnableMacroCallTracking) { + prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true)) + } + if e.HasFeature(featureVariadicLogicalASTs) { + prsrOpts = append(prsrOpts, parser.EnableVariadicOperatorASTs(true)) + } + e.prsr, err = parser.NewParser(prsrOpts...) + if err != nil { + return nil, err + } + + // Ensure that the checker init happens eagerly rather than lazily. + if e.HasFeature(featureEagerlyValidateDeclarations) { + _, err := e.initChecker() + if err != nil { + return nil, err + } + } + + return e, nil +} + +func (e *Env) initChecker() (*checker.Env, error) { + e.chkOnce.Do(func() { + chkOpts := []checker.Option{} + chkOpts = append(chkOpts, e.chkOpts...) + chkOpts = append(chkOpts, + checker.CrossTypeNumericComparisons( + e.HasFeature(featureCrossTypeNumericComparisons))) + + ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...) + if err != nil { + e.setCheckerOrError(nil, err) + return + } + // Add the statically configured declarations. + err = ce.AddIdents(e.variables...) + if err != nil { + e.setCheckerOrError(nil, err) + return + } + // Add the function declarations which are derived from the FunctionDecl instances. + for _, fn := range e.functions { + if fn.IsDeclarationDisabled() { + continue + } + err = ce.AddFunctions(fn) + if err != nil { + e.setCheckerOrError(nil, err) + return + } + } + // Add function declarations here separately. + e.setCheckerOrError(ce, nil) + }) + return e.getCheckerOrError() +} + +// setCheckerOrError sets the checker.Env or error state in a concurrency-safe manner +func (e *Env) setCheckerOrError(chk *checker.Env, chkErr error) { + e.chkMutex.Lock() + e.chk = chk + e.chkErr = chkErr + e.chkMutex.Unlock() +} + +// getCheckerOrError gets the checker.Env or error state in a concurrency-safe manner +func (e *Env) getCheckerOrError() (*checker.Env, error) { + e.chkMutex.Lock() + defer e.chkMutex.Unlock() + return e.chk, e.chkErr +} + +// maybeApplyFeature determines whether the feature-guarded option is enabled, and if so applies +// the feature if it has not already been enabled. +func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) { + if !e.HasFeature(feature) { + return e, nil + } + _, applied := e.appliedFeatures[feature] + if applied { + return e, nil + } + e, err := option(e) + if err != nil { + return nil, err + } + // record that the feature has been applied since it will generate declarations + // and functions which will be propagated on Extend() calls and which should only + // be registered once. + e.appliedFeatures[feature] = true + return e, nil +} + +// computeUnknownVars determines a set of missing variables based on the input activation and the +// environment's configured declaration set. +func (e *Env) computeUnknownVars(vars interpreter.Activation) []*interpreter.AttributePattern { + var unknownPatterns []*interpreter.AttributePattern + for _, v := range e.variables { + varName := v.Name() + if _, found := vars.ResolveName(varName); found { + continue + } + unknownPatterns = append(unknownPatterns, interpreter.NewAttributePattern(varName)) + } + return unknownPatterns +} + +// Error type which references an expression id, a location within source, and a message. +type Error = common.Error + +// Issues defines methods for inspecting the error details of parse and check calls. +// +// Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct. +type Issues struct { + errs *common.Errors + info *celast.SourceInfo +} + +// NewIssues returns an Issues struct from a common.Errors object. +func NewIssues(errs *common.Errors) *Issues { + return NewIssuesWithSourceInfo(errs, nil) +} + +// NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata +// which can be used with the `ReportErrorAtID` method for additional error reports within the context +// information that's inferred from an expression id. +func NewIssuesWithSourceInfo(errs *common.Errors, info *celast.SourceInfo) *Issues { + return &Issues{ + errs: errs, + info: info, + } +} + +// Err returns an error value if the issues list contains one or more errors. +func (i *Issues) Err() error { + if i == nil { + return nil + } + if len(i.Errors()) > 0 { + return errors.New(i.String()) + } + return nil +} + +// Errors returns the collection of errors encountered in more granular detail. +func (i *Issues) Errors() []*Error { + if i == nil { + return []*Error{} + } + return i.errs.GetErrors() +} + +// Append collects the issues from another Issues struct into a new Issues object. +func (i *Issues) Append(other *Issues) *Issues { + if i == nil { + return other + } + if other == nil { + return i + } + return NewIssues(i.errs.Append(other.errs.GetErrors())) +} + +// String converts the issues to a suitable display string. +func (i *Issues) String() string { + if i == nil { + return "" + } + return i.errs.ToDisplayString() +} + +// ReportErrorAtID reports an error message with an optional set of formatting arguments. +// +// The source metadata for the expression at `id`, if present, is attached to the error report. +// To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo. +func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) { + i.errs.ReportErrorAtID(id, i.info.GetStartLocation(id), message, args...) +} + +// getStdEnv lazy initializes the CEL standard environment. +func getStdEnv() (*Env, error) { + stdEnvInit.Do(func() { + stdEnv, stdEnvErr = NewCustomEnv(StdLib(), EagerlyValidateDeclarations(true)) + }) + return stdEnv, stdEnvErr +} + +// interopCELTypeProvider layers support for the types.Provider interface on top of a ref.TypeProvider. +type interopCELTypeProvider struct { + ref.TypeProvider +} + +// FindStructType returns a types.Type instance for the given fully-qualified typeName if one exists. +// +// This method proxies to the underlying ref.TypeProvider's FindType method and converts protobuf type +// into a native type representation. If the conversion fails, the type is listed as not found. +func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, bool) { + if et, found := p.FindType(typeName); found { + t, err := types.ExprTypeToType(et) + if err != nil { + return nil, false + } + return t, true + } + return nil, false +} + +// FindStructFieldNames returns an empty set of field for the interop provider. +// +// To inspect the field names, migrate to a `types.Provider` implementation. +func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) { + return []string{}, false +} + +// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field +// name, if one exists. +// +// This method proxies to the underlying ref.TypeProvider's FindFieldType method and converts protobuf type +// into a native type representation. If the conversion fails, the type is listed as not found. +func (p *interopCELTypeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { + if ft, found := p.FindFieldType(structType, fieldName); found { + t, err := types.ExprTypeToType(ft.Type) + if err != nil { + return nil, false + } + return &types.FieldType{ + Type: t, + IsSet: ft.IsSet, + GetFrom: ft.GetFrom, + }, true + } + return nil, false +} + +// interopLegacyTypeProvider layers support for the ref.TypeProvider interface on top of a types.Provider. +type interopLegacyTypeProvider struct { + types.Provider +} + +// FindType retruns the protobuf Type representation for the input type name if one exists. +// +// This method proxies to the underlying types.Provider FindStructType method and converts the types.Type +// value to a protobuf Type representation. +// +// Failure to convert the type will result in the type not being found. +func (p *interopLegacyTypeProvider) FindType(typeName string) (*exprpb.Type, bool) { + if t, found := p.FindStructType(typeName); found { + et, err := types.TypeToExprType(t) + if err != nil { + return nil, false + } + return et, true + } + return nil, false +} + +// FindFieldType returns the protobuf-based FieldType representation for the input type name and field, +// if one exists. +// +// This call proxies to the types.Provider FindStructFieldType method and converts the types.FIeldType +// value to a protobuf-based ref.FieldType representation if found. +// +// Failure to convert the FieldType will result in the field not being found. +func (p *interopLegacyTypeProvider) FindFieldType(structType, fieldName string) (*ref.FieldType, bool) { + if cft, found := p.FindStructFieldType(structType, fieldName); found { + et, err := types.TypeToExprType(cft.Type) + if err != nil { + return nil, false + } + return &ref.FieldType{ + Type: et, + IsSet: cft.IsSet, + GetFrom: cft.GetFrom, + }, true + } + return nil, false +} + +var ( + stdEnvInit sync.Once + stdEnv *Env + stdEnvErr error +) diff --git a/vendor/github.com/authzed/cel-go/cel/folding.go b/vendor/github.com/authzed/cel-go/cel/folding.go new file mode 100644 index 0000000..9a2d904 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/folding.go @@ -0,0 +1,559 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" +) + +// ConstantFoldingOption defines a functional option for configuring constant folding. +type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) + +// MaxConstantFoldIterations limits the number of times literals may be folding during optimization. +// +// Defaults to 100 if not set. +func MaxConstantFoldIterations(limit int) ConstantFoldingOption { + return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) { + opt.maxFoldIterations = limit + return opt, nil + } +} + +// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate +// literal values within function calls and select statements with their evaluated result. +func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) { + folder := &constantFoldingOptimizer{ + maxFoldIterations: defaultMaxConstantFoldIterations, + } + var err error + for _, o := range opts { + folder, err = o(folder) + if err != nil { + return nil, err + } + } + return folder, nil +} + +type constantFoldingOptimizer struct { + maxFoldIterations int +} + +// Optimize queries the expression graph for scalar and aggregate literal expressions within call and +// select statements and then evaluates them and replaces the call site with the literal result. +// +// Note: only values which can be represented as literals in CEL syntax are supported. +func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + + // Walk the list of foldable expression and continue to fold until there are no more folds left. + // All of the fold candidates returned by the constantExprMatcher should succeed unless there's + // a logic bug with the selection of expressions. + foldableExprs := ast.MatchDescendants(root, constantExprMatcher) + foldCount := 0 + for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations { + for _, fold := range foldableExprs { + // If the expression could be folded because it's a non-strict call, and the + // branches are pruned, continue to the next fold. + if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) { + continue + } + // Otherwise, assume all context is needed to evaluate the expression. + err := tryFold(ctx, a, fold) + if err != nil { + ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error()) + return a + } + } + foldCount++ + foldableExprs = ast.MatchDescendants(root, constantExprMatcher) + } + // Once all of the constants have been folded, try to run through the remaining comprehensions + // one last time. In this case, there's no guarantee they'll run, so we only update the + // target comprehension node with the literal value if the evaluation succeeds. + for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) { + tryFold(ctx, a, compre) + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + pruneOptionalElements(ctx, root) + + // Ensure that all intermediate values in the folded expression can be represented as valid + // CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents` + // to avoid extra allocations during this final pass through the AST. + ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() != ast.LiteralKind { + return + } + val := e.AsLiteral() + adapted, err := adaptLiteral(ctx, val) + if err != nil { + ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error()) + return + } + ctx.UpdateExpr(e, adapted) + })) + + return a +} + +// tryFold attempts to evaluate a sub-expression to a literal. +// +// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise +// the method will return an error. +func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { + // Assume all context is needed to evaluate the expression. + subAST := &Ast{ + impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()), + } + prg, err := ctx.Program(subAST) + if err != nil { + return err + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + return err + } + // Update the fold expression to be a literal. + ctx.UpdateExpr(expr, ctx.NewLiteral(out)) + return nil +} + +// maybePruneBranches inspects the non-strict call expression to determine whether +// a branch can be removed. Evaluation will naturally prune logical and / or calls, +// but conditional will not be pruned cleanly, so this is one small area where the +// constant folding step reimplements a portion of the evaluator. +func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool { + call := expr.AsCall() + args := call.Args() + switch call.FunctionName() { + case operators.LogicalAnd, operators.LogicalOr: + return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr) + case operators.Conditional: + cond := args[0] + truthy := args[1] + falsy := args[2] + if cond.Kind() != ast.LiteralKind { + return false + } + if cond.AsLiteral() == types.True { + ctx.UpdateExpr(expr, truthy) + } else { + ctx.UpdateExpr(expr, falsy) + } + return true + case operators.In: + haystack := args[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + ctx.UpdateExpr(expr, ctx.NewLiteral(types.False)) + return true + } + needle := args[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + ctx.UpdateExpr(expr, ctx.NewLiteral(types.True)) + return true + } + } + } + } + return false +} + +func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool { + shortcircuit := types.False + skip := types.True + if function == operators.LogicalOr { + shortcircuit = types.True + skip = types.False + } + newArgs := []ast.Expr{} + for _, arg := range args { + if arg.Kind() != ast.LiteralKind { + newArgs = append(newArgs, arg) + continue + } + if arg.AsLiteral() == skip { + continue + } + if arg.AsLiteral() == shortcircuit { + ctx.UpdateExpr(expr, arg) + return true + } + } + if len(newArgs) == 0 { + newArgs = append(newArgs, args[0]) + ctx.UpdateExpr(expr, newArgs[0]) + return true + } + if len(newArgs) == 1 { + ctx.UpdateExpr(expr, newArgs[0]) + return true + } + ctx.UpdateExpr(expr, ctx.NewCall(function, newArgs...)) + return true +} + +// pruneOptionalElements works from the bottom up to resolve optional elements within +// aggregate literals. +// +// Note, many aggregate literals will be resolved as arguments to functions or select +// statements, so this method exists to handle the case where the literal could not be +// fully resolved or exists outside of a call, select, or comprehension context. +func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) { + aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher) + for _, lit := range aggregateLiterals { + switch lit.Kind() { + case ast.ListKind: + pruneOptionalListElements(ctx, lit) + case ast.MapKind: + pruneOptionalMapEntries(ctx, lit) + case ast.StructKind: + pruneOptionalStructFields(ctx, lit) + } + } +} + +func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) { + l := e.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + if len(optIndices) == 0 { + return + } + updatedElems := []ast.Expr{} + updatedIndices := []int32{} + newOptIndex := -1 + for _, e := range elems { + newOptIndex++ + if !l.IsOptional(int32(newOptIndex)) { + updatedElems = append(updatedElems, e) + continue + } + if e.Kind() != ast.LiteralKind { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(newOptIndex)) + continue + } + optElemVal, ok := e.AsLiteral().(*types.Optional) + if !ok { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(newOptIndex)) + continue + } + if !optElemVal.HasValue() { + newOptIndex-- // Skipping causes the list to get smaller. + continue + } + ctx.UpdateExpr(e, ctx.NewLiteral(optElemVal.GetValue())) + updatedElems = append(updatedElems, e) + } + ctx.UpdateExpr(e, ctx.NewList(updatedElems, updatedIndices)) +} + +func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { + m := e.AsMap() + entries := m.Entries() + updatedEntries := []ast.EntryExpr{} + modified := false + for _, e := range entries { + entry := e.AsMapEntry() + key := entry.Key() + val := entry.Value() + // If the entry is not optional, or the value-side of the optional hasn't + // been resolved to a literal, then preserve the entry as-is. + if !entry.IsOptional() || val.Kind() != ast.LiteralKind { + updatedEntries = append(updatedEntries, e) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedEntries = append(updatedEntries, e) + continue + } + // When the key is not a literal, but the value is, then it needs to be + // restored to an optional value. + if key.Kind() != ast.LiteralKind { + undoOptVal, err := adaptLiteral(ctx, optElemVal) + if err != nil { + ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err) + } + ctx.UpdateExpr(val, undoOptVal) + updatedEntries = append(updatedEntries, e) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue())) + updatedEntry := ctx.NewMapEntry(key, val, false) + updatedEntries = append(updatedEntries, updatedEntry) + } + if modified { + ctx.UpdateExpr(e, ctx.NewMap(updatedEntries)) + } +} + +func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) { + s := e.AsStruct() + fields := s.Fields() + updatedFields := []ast.EntryExpr{} + modified := false + for _, f := range fields { + field := f.AsStructField() + val := field.Value() + if !field.IsOptional() || val.Kind() != ast.LiteralKind { + updatedFields = append(updatedFields, f) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedFields = append(updatedFields, f) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + ctx.UpdateExpr(val, ctx.NewLiteral(optElemVal.GetValue())) + updatedField := ctx.NewStructField(field.Name(), val, false) + updatedFields = append(updatedFields, updatedField) + } + if modified { + ctx.UpdateExpr(e, ctx.NewStruct(s.TypeName(), updatedFields)) + } +} + +// adaptLiteral converts a runtime CEL value to its equivalent literal expression. +// +// For strongly typed values, the type-provider will be used to reconstruct the fields +// which are present in the literal and their equivalent initialization values. +func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) { + switch t := val.Type().(type) { + case *types.Type: + switch t { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + return ctx.NewLiteral(val), nil + case types.DurationType: + return ctx.NewCall( + overloads.TypeConvertDuration, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.TimestampType: + return ctx.NewCall( + overloads.TypeConvertTimestamp, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.OptionalType: + opt := val.(*types.Optional) + if !opt.HasValue() { + return ctx.NewCall("optional.none"), nil + } + target, err := adaptLiteral(ctx, opt.GetValue()) + if err != nil { + return nil, err + } + return ctx.NewCall("optional.of", target), nil + case types.TypeType: + return ctx.NewIdent(val.(*types.Type).TypeName()), nil + case types.ListType: + l, ok := val.(traits.Lister) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + elems := make([]ast.Expr, l.Size().(types.Int)) + idx := 0 + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + elemExpr, err := adaptLiteral(ctx, elemVal) + if err != nil { + return nil, err + } + elems[idx] = elemExpr + idx++ + } + return ctx.NewList(elems, []int32{}), nil + case types.MapType: + m, ok := val.(traits.Mapper) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + entries := make([]ast.EntryExpr, m.Size().(types.Int)) + idx := 0 + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + keyExpr, err := adaptLiteral(ctx, keyVal) + if err != nil { + return nil, err + } + valVal := m.Get(keyVal) + valExpr, err := adaptLiteral(ctx, valVal) + if err != nil { + return nil, err + } + entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false) + idx++ + } + return ctx.NewMap(entries), nil + default: + provider := ctx.CELTypeProvider() + fields, found := provider.FindStructFieldNames(t.TypeName()) + if !found { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + tester := val.(traits.FieldTester) + indexer := val.(traits.Indexer) + fieldInits := []ast.EntryExpr{} + for _, f := range fields { + field := types.String(f) + if tester.IsSet(field) != types.True { + continue + } + fieldVal := indexer.Get(field) + fieldExpr, err := adaptLiteral(ctx, fieldVal) + if err != nil { + return nil, err + } + fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false)) + } + return ctx.NewStruct(t.TypeName(), fieldInits), nil + } + } + return nil, fmt.Errorf("failed to adapt %v to literal", val) +} + +// constantExprMatcher matches calls, select statements, and comprehensions whose arguments +// are all constant scalar or aggregate literal values. +// +// Only comprehensions which are not nested are included as possible constant folds, and only +// if all variables referenced in the comprehension stack exist are only iteration or +// accumulation variables. +func constantExprMatcher(e ast.NavigableExpr) bool { + switch e.Kind() { + case ast.CallKind: + return constantCallMatcher(e) + case ast.SelectKind: + sel := e.AsSelect() // guaranteed to be a navigable value + return constantMatcher(sel.Operand().(ast.NavigableExpr)) + case ast.ComprehensionKind: + if isNestedComprehension(e) { + return false + } + vars := map[string]bool{} + constantExprs := true + visitor := ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() == ast.ComprehensionKind { + nested := e.AsComprehension() + vars[nested.AccuVar()] = true + vars[nested.IterVar()] = true + } + if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { + constantExprs = false + } + }) + ast.PreOrderVisit(e, visitor) + return constantExprs + default: + return false + } +} + +// constantCallMatcher identifies strict and non-strict calls which can be folded. +func constantCallMatcher(e ast.NavigableExpr) bool { + call := e.AsCall() + children := e.Children() + fnName := call.FunctionName() + if fnName == operators.LogicalAnd { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.LogicalOr { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.Conditional { + cond := children[0] + if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType { + return true + } + } + if fnName == operators.In { + haystack := children[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + return true + } + needle := children[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + return true + } + } + } + } + // convert all other calls with constant arguments + for _, child := range children { + if !constantMatcher(child) { + return false + } + } + return true +} + +func isNestedComprehension(e ast.NavigableExpr) bool { + parent, found := e.Parent() + for found { + if parent.Kind() == ast.ComprehensionKind { + return true + } + parent, found = parent.Parent() + } + return false +} + +func aggregateLiteralMatcher(e ast.NavigableExpr) bool { + return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind +} + +var ( + constantMatcher = ast.ConstantValueMatcher() +) + +const ( + defaultMaxConstantFoldIterations = 100 +) diff --git a/vendor/github.com/authzed/cel-go/cel/inlining.go b/vendor/github.com/authzed/cel-go/cel/inlining.go new file mode 100644 index 0000000..2ffe82f --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/inlining.go @@ -0,0 +1,228 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/containers" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/traits" +) + +// InlineVariable holds a variable name to be matched and an AST representing +// the expression graph which should be used to replace it. +type InlineVariable struct { + name string + alias string + def *ast.AST +} + +// Name returns the qualified variable or field selection to replace. +func (v *InlineVariable) Name() string { + return v.name +} + +// Alias returns the alias to use when performing cel.bind() calls during inlining. +func (v *InlineVariable) Alias() string { + return v.alias +} + +// Expr returns the inlined expression value. +func (v *InlineVariable) Expr() ast.Expr { + return v.def.Expr() +} + +// Type indicates the inlined expression type. +func (v *InlineVariable) Type() *Type { + return v.def.GetType(v.def.Expr().ID()) +} + +// NewInlineVariable declares a variable name to be replaced by a checked expression. +func NewInlineVariable(name string, definition *Ast) *InlineVariable { + return NewInlineVariableWithAlias(name, name, definition) +} + +// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression. +// If the variable occurs more than once, the provided alias will be used to replace the expressions +// where the variable name occurs. +func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable { + return &InlineVariable{name: name, alias: alias, def: definition.impl} +} + +// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions. +// +// If a variable occurs one time, the variable is replaced by the inline definition. If the +// variable occurs more than once, the variable occurences are replaced by a cel.bind() call. +func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer { + return &inliningOptimizer{variables: inlineVars} +} + +type inliningOptimizer struct { + variables []*InlineVariable +} + +func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + for _, inlineVar := range opt.variables { + matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name())) + // Skip cases where the variable isn't in the expression graph + if len(matches) == 0 { + continue + } + + // For a single match, do a direct replacement of the expression sub-graph. + if len(matches) == 1 || !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) { + for _, match := range matches { + // Copy the inlined AST expr and source info. + copyExpr := ctx.CopyASTAndMetadata(inlineVar.def) + opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type()) + } + continue + } + + // For multiple matches, find the least common ancestor (lca) and insert the + // variable as a cel.bind() macro. + var lca ast.NavigableExpr = root + lcaAncestorCount := 0 + ancestors := map[int64]int{} + for _, match := range matches { + // Update the identifier matches with the provided alias. + parent, found := match, true + for found { + ancestorCount, hasAncestor := ancestors[parent.ID()] + if !hasAncestor { + ancestors[parent.ID()] = 1 + parent, found = parent.Parent() + continue + } + if lcaAncestorCount < ancestorCount || (lcaAncestorCount == ancestorCount && lca.Depth() < parent.Depth()) { + lca = parent + lcaAncestorCount = ancestorCount + } + ancestors[parent.ID()] = ancestorCount + 1 + parent, found = parent.Parent() + } + aliasExpr := ctx.NewIdent(inlineVar.Alias()) + opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type()) + } + + // Copy the inlined AST expr and source info. + copyExpr := ctx.CopyASTAndMetadata(inlineVar.def) + // Update the least common ancestor by inserting a cel.bind() call to the alias. + inlined, bindMacro := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), copyExpr, lca) + opt.inlineExpr(ctx, lca, inlined, inlineVar.Type()) + ctx.SetMacroCall(lca.ID(), bindMacro) + } + return a +} + +// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining +// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is +// made to determine whether the inlined value can be presence or existence tested. +func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) { + switch prev.Kind() { + case ast.SelectKind: + sel := prev.AsSelect() + if !sel.IsTestOnly() { + ctx.UpdateExpr(prev, inlined) + return + } + opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType) + default: + ctx.UpdateExpr(prev, inlined) + } +} + +// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to type-safe +// expression appropriate for the inlined type, if possible. +// +// If the rewrite is not possible an error is reported at the inline expression site. +func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) { + // If the input inlined expression is not a select expression it won't work with the has() + // macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error. + if inlined.Kind() == ast.SelectKind { + presenceTest, hasMacro := ctx.NewHasMacro(prev.ID(), inlined) + ctx.UpdateExpr(prev, presenceTest) + ctx.SetMacroCall(prev.ID(), hasMacro) + return + } + + ctx.ClearMacroCall(prev.ID()) + if inlinedType.IsAssignableType(NullType) { + ctx.UpdateExpr(prev, + ctx.NewCall(operators.NotEquals, + inlined, + ctx.NewLiteral(types.NullValue), + )) + return + } + if inlinedType.HasTrait(traits.SizerType) { + ctx.UpdateExpr(prev, + ctx.NewCall(operators.NotEquals, + ctx.NewMemberCall(overloads.Size, inlined), + ctx.NewLiteral(types.IntZero), + )) + return + } + ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType) +} + +// isBindable indicates whether the inlined type can be used within a cel.bind() if the expression +// being replaced occurs within a presence test. Value types with a size() method or field selection +// support can be bound. +// +// In future iterations, support may also be added for indexer types which can be rewritten as an `in` +// expression; however, this would imply a rewrite of the inlined expression that may not be necessary +// in most cases. +func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) bool { + if inlinedType.IsAssignableType(NullType) || + inlinedType.HasTrait(traits.SizerType) { + return true + } + for _, m := range matches { + if m.Kind() != ast.SelectKind { + continue + } + sel := m.AsSelect() + if sel.IsTestOnly() { + return false + } + } + return true +} + +// matchVariable matches simple identifiers, select expressions, and presence test expressions +// which match the (potentially) qualified variable name provided as input. +// +// Note, this function does not support inlining against select expressions which includes optional +// field selection. This may be a future refinement. +func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher { + return func(e ast.NavigableExpr) bool { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + return true + } + if e.Kind() == ast.SelectKind { + sel := e.AsSelect() + // While the `ToQualifiedName` call could take the select directly, this + // would skip presence tests from possible matches, which we would like + // to include. + qualName, found := containers.ToQualifiedName(sel.Operand()) + return found && qualName+"."+sel.FieldName() == varName + } + return false + } +} diff --git a/vendor/github.com/authzed/cel-go/cel/io.go b/vendor/github.com/authzed/cel-go/cel/io.go new file mode 100644 index 0000000..3a9b2b1 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/io.go @@ -0,0 +1,252 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "errors" + "fmt" + "reflect" + + "google.golang.org/protobuf/proto" + + "github.com/authzed/cel-go/common" + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" + "github.com/authzed/cel-go/parser" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + anypb "google.golang.org/protobuf/types/known/anypb" +) + +// CheckedExprToAst converts a checked expression proto message to an Ast. +func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast { + checked, _ := CheckedExprToAstWithSource(checkedExpr, nil) + return checked +} + +// CheckedExprToAstWithSource converts a checked expression proto message to an Ast, +// using the provided Source as the textual contents. +// +// In general the source is not necessary unless the AST has been modified between the +// `Parse` and `Check` calls as an `Ast` created from the `Parse` step will carry the source +// through future calls. +// +// Prefer CheckedExprToAst if loading expressions from storage. +func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) { + checked, err := ast.ToAST(checkedExpr) + if err != nil { + return nil, err + } + return &Ast{source: src, impl: checked}, nil +} + +// AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value. +// +// If the Ast.IsChecked() returns false, this conversion method will return an error. +func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) { + if !a.IsChecked() { + return nil, fmt.Errorf("cannot convert unchecked ast") + } + return ast.ToProto(a.impl) +} + +// ParsedExprToAst converts a parsed expression proto message to an Ast. +func ParsedExprToAst(parsedExpr *exprpb.ParsedExpr) *Ast { + return ParsedExprToAstWithSource(parsedExpr, nil) +} + +// ParsedExprToAstWithSource converts a parsed expression proto message to an Ast, +// using the provided Source as the textual contents. +// +// In general you only need this if you need to recheck a previously checked +// expression, or if you need to separately check a subset of an expression. +// +// Prefer ParsedExprToAst if loading expressions from storage. +func ParsedExprToAstWithSource(parsedExpr *exprpb.ParsedExpr, src Source) *Ast { + info, _ := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo()) + if src == nil { + src = common.NewInfoSource(parsedExpr.GetSourceInfo()) + } + e, _ := ast.ProtoToExpr(parsedExpr.GetExpr()) + return &Ast{source: src, impl: ast.NewAST(e, info)} +} + +// AstToParsedExpr converts an Ast to an protobuf ParsedExpr value. +func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) { + return &exprpb.ParsedExpr{ + Expr: a.Expr(), + SourceInfo: a.SourceInfo(), + }, nil +} + +// AstToString converts an Ast back to a string if possible. +// +// Note, the conversion may not be an exact replica of the original expression, but will produce +// a string that is semantically equivalent and whose textual representation is stable. +func AstToString(a *Ast) (string, error) { + return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo()) +} + +// RefValueToValue converts between ref.Val and api.expr.Value. +// The result Value is the serialized proto form. The ref.Val must not be error or unknown. +func RefValueToValue(res ref.Val) (*exprpb.Value, error) { + switch res.Type() { + case types.BoolType: + return &exprpb.Value{ + Kind: &exprpb.Value_BoolValue{BoolValue: res.Value().(bool)}}, nil + case types.BytesType: + return &exprpb.Value{ + Kind: &exprpb.Value_BytesValue{BytesValue: res.Value().([]byte)}}, nil + case types.DoubleType: + return &exprpb.Value{ + Kind: &exprpb.Value_DoubleValue{DoubleValue: res.Value().(float64)}}, nil + case types.IntType: + return &exprpb.Value{ + Kind: &exprpb.Value_Int64Value{Int64Value: res.Value().(int64)}}, nil + case types.ListType: + l := res.(traits.Lister) + sz := l.Size().(types.Int) + elts := make([]*exprpb.Value, 0, int64(sz)) + for i := types.Int(0); i < sz; i++ { + v, err := RefValueToValue(l.Get(i)) + if err != nil { + return nil, err + } + elts = append(elts, v) + } + return &exprpb.Value{ + Kind: &exprpb.Value_ListValue{ + ListValue: &exprpb.ListValue{Values: elts}}}, nil + case types.MapType: + mapper := res.(traits.Mapper) + sz := mapper.Size().(types.Int) + entries := make([]*exprpb.MapValue_Entry, 0, int64(sz)) + for it := mapper.Iterator(); it.HasNext().(types.Bool); { + k := it.Next() + v := mapper.Get(k) + kv, err := RefValueToValue(k) + if err != nil { + return nil, err + } + vv, err := RefValueToValue(v) + if err != nil { + return nil, err + } + entries = append(entries, &exprpb.MapValue_Entry{Key: kv, Value: vv}) + } + return &exprpb.Value{ + Kind: &exprpb.Value_MapValue{ + MapValue: &exprpb.MapValue{Entries: entries}}}, nil + case types.NullType: + return &exprpb.Value{ + Kind: &exprpb.Value_NullValue{}}, nil + case types.StringType: + return &exprpb.Value{ + Kind: &exprpb.Value_StringValue{StringValue: res.Value().(string)}}, nil + case types.TypeType: + typeName := res.(ref.Type).TypeName() + return &exprpb.Value{Kind: &exprpb.Value_TypeValue{TypeValue: typeName}}, nil + case types.UintType: + return &exprpb.Value{ + Kind: &exprpb.Value_Uint64Value{Uint64Value: res.Value().(uint64)}}, nil + default: + any, err := res.ConvertToNative(anyPbType) + if err != nil { + return nil, err + } + return &exprpb.Value{ + Kind: &exprpb.Value_ObjectValue{ObjectValue: any.(*anypb.Any)}}, nil + } +} + +var ( + typeNameToTypeValue = map[string]ref.Val{ + "bool": types.BoolType, + "bytes": types.BytesType, + "double": types.DoubleType, + "null_type": types.NullType, + "int": types.IntType, + "list": types.ListType, + "map": types.MapType, + "string": types.StringType, + "type": types.TypeType, + "uint": types.UintType, + } + + anyPbType = reflect.TypeOf(&anypb.Any{}) +) + +// ValueToRefValue converts between exprpb.Value and ref.Val. +func ValueToRefValue(adapter types.Adapter, v *exprpb.Value) (ref.Val, error) { + switch v.Kind.(type) { + case *exprpb.Value_NullValue: + return types.NullValue, nil + case *exprpb.Value_BoolValue: + return types.Bool(v.GetBoolValue()), nil + case *exprpb.Value_Int64Value: + return types.Int(v.GetInt64Value()), nil + case *exprpb.Value_Uint64Value: + return types.Uint(v.GetUint64Value()), nil + case *exprpb.Value_DoubleValue: + return types.Double(v.GetDoubleValue()), nil + case *exprpb.Value_StringValue: + return types.String(v.GetStringValue()), nil + case *exprpb.Value_BytesValue: + return types.Bytes(v.GetBytesValue()), nil + case *exprpb.Value_ObjectValue: + any := v.GetObjectValue() + msg, err := anypb.UnmarshalNew(any, proto.UnmarshalOptions{DiscardUnknown: true}) + if err != nil { + return nil, err + } + return adapter.NativeToValue(msg), nil + case *exprpb.Value_MapValue: + m := v.GetMapValue() + entries := make(map[ref.Val]ref.Val) + for _, entry := range m.Entries { + key, err := ValueToRefValue(adapter, entry.Key) + if err != nil { + return nil, err + } + pb, err := ValueToRefValue(adapter, entry.Value) + if err != nil { + return nil, err + } + entries[key] = pb + } + return adapter.NativeToValue(entries), nil + case *exprpb.Value_ListValue: + l := v.GetListValue() + elts := make([]ref.Val, len(l.Values)) + for i, e := range l.Values { + rv, err := ValueToRefValue(adapter, e) + if err != nil { + return nil, err + } + elts[i] = rv + } + return adapter.NativeToValue(elts), nil + case *exprpb.Value_TypeValue: + typeName := v.GetTypeValue() + tv, ok := typeNameToTypeValue[typeName] + if ok { + return tv, nil + } + return types.NewObjectTypeValue(typeName), nil + } + return nil, errors.New("unknown value") +} diff --git a/vendor/github.com/authzed/cel-go/cel/library.go b/vendor/github.com/authzed/cel-go/cel/library.go new file mode 100644 index 0000000..5d6dde4 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/library.go @@ -0,0 +1,790 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "math" + "strconv" + "strings" + "time" + + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/stdlib" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" + "github.com/authzed/cel-go/interpreter" + "github.com/authzed/cel-go/parser" +) + +const ( + optMapMacro = "optMap" + optFlatMapMacro = "optFlatMap" + hasValueFunc = "hasValue" + optionalNoneFunc = "optional.none" + optionalOfFunc = "optional.of" + optionalOfNonZeroValueFunc = "optional.ofNonZeroValue" + valueFunc = "value" + unusedIterVar = "#unused" +) + +// Library provides a collection of EnvOption and ProgramOption values used to configure a CEL +// environment for a particular use case or with a related set of functionality. +// +// Note, the ProgramOption values provided by a library are expected to be static and not vary +// between calls to Env.Program(). If there is a need for such dynamic configuration, prefer to +// configure these options outside the Library and within the Env.Program() call directly. +type Library interface { + // CompileOptions returns a collection of functional options for configuring the Parse / Check + // environment. + CompileOptions() []EnvOption + + // ProgramOptions returns a collection of functional options which should be included in every + // Program generated from the Env.Program() call. + ProgramOptions() []ProgramOption +} + +// SingletonLibrary refines the Library interface to ensure that libraries in this format are only +// configured once within the environment. +type SingletonLibrary interface { + Library + + // LibraryName provides a namespaced name which is used to check whether the library has already + // been configured in the environment. + LibraryName() string +} + +// Lib creates an EnvOption out of a Library, allowing libraries to be provided as functional args, +// and to be linked to each other. +func Lib(l Library) EnvOption { + singleton, isSingleton := l.(SingletonLibrary) + return func(e *Env) (*Env, error) { + if isSingleton { + if e.HasLibrary(singleton.LibraryName()) { + return e, nil + } + e.libraries[singleton.LibraryName()] = true + } + var err error + for _, opt := range l.CompileOptions() { + e, err = opt(e) + if err != nil { + return nil, err + } + } + e.progOpts = append(e.progOpts, l.ProgramOptions()...) + return e, nil + } +} + +// StdLib returns an EnvOption for the standard library of CEL functions and macros. +func StdLib() EnvOption { + return Lib(stdLibrary{}) +} + +// stdLibrary implements the Library interface and provides functional options for the core CEL +// features documented in the specification. +type stdLibrary struct{} + +// LibraryName implements the SingletonLibrary interface method. +func (stdLibrary) LibraryName() string { + return "cel.lib.std" +} + +// CompileOptions returns options for the standard CEL function declarations and macros. +func (stdLibrary) CompileOptions() []EnvOption { + return []EnvOption{ + func(e *Env) (*Env, error) { + var err error + for _, fn := range stdlib.Functions() { + existing, found := e.functions[fn.Name()] + if found { + fn, err = existing.Merge(fn) + if err != nil { + return nil, err + } + } + e.functions[fn.Name()] = fn + } + return e, nil + }, + func(e *Env) (*Env, error) { + e.variables = append(e.variables, stdlib.Types()...) + return e, nil + }, + Macros(StandardMacros...), + } +} + +// ProgramOptions returns function implementations for the standard CEL functions. +func (stdLibrary) ProgramOptions() []ProgramOption { + return []ProgramOption{} +} + +// OptionalTypes enable support for optional syntax and types in CEL. +// +// The optional value type makes it possible to express whether variables have +// been provided, whether a result has been computed, and in the future whether +// an object field path, map key value, or list index has a value. +// +// # Syntax Changes +// +// OptionalTypes are unlike other CEL extensions because they modify the CEL +// syntax itself, notably through the use of a `?` preceding a field name or +// index value. +// +// ## Field Selection +// +// The optional syntax in field selection is denoted as `obj.?field`. In other +// words, if a field is set, return `optional.of(obj.field)“, else +// `optional.none()`. The optional field selection is viral in the sense that +// after the first optional selection all subsequent selections or indices +// are treated as optional, i.e. the following expressions are equivalent: +// +// obj.?field.subfield +// obj.?field.?subfield +// +// ## Indexing +// +// Similar to field selection, the optional syntax can be used in index +// expressions on maps and lists: +// +// list[?0] +// map[?key] +// +// ## Optional Field Setting +// +// When creating map or message literals, if a field may be optionally set +// based on its presence, then placing a `?` before the field name or key +// will ensure the type on the right-hand side must be optional(T) where T +// is the type of the field or key-value. +// +// The following returns a map with the key expression set only if the +// subfield is present, otherwise an empty map is created: +// +// {?key: obj.?field.subfield} +// +// ## Optional Element Setting +// +// When creating list literals, an element in the list may be optionally added +// when the element expression is preceded by a `?`: +// +// [a, ?b, ?c] // return a list with either [a], [a, b], [a, b, c], or [a, c] +// +// # Optional.Of +// +// Create an optional(T) value of a given value with type T. +// +// optional.of(10) +// +// # Optional.OfNonZeroValue +// +// Create an optional(T) value of a given value with type T if it is not a +// zero-value. A zero-value the default empty value for any given CEL type, +// including empty protobuf message types. If the value is empty, the result +// of this call will be optional.none(). +// +// optional.ofNonZeroValue([1, 2, 3]) // optional(list(int)) +// optional.ofNonZeroValue([]) // optional.none() +// optional.ofNonZeroValue(0) // optional.none() +// optional.ofNonZeroValue("") // optional.none() +// +// # Optional.None +// +// Create an empty optional value. +// +// # HasValue +// +// Determine whether the optional contains a value. +// +// optional.of(b'hello').hasValue() // true +// optional.ofNonZeroValue({}).hasValue() // false +// +// # Value +// +// Get the value contained by the optional. If the optional does not have a +// value, the result will be a CEL error. +// +// optional.of(b'hello').value() // b'hello' +// optional.ofNonZeroValue({}).value() // error +// +// # Or +// +// If the value on the left-hand side is optional.none(), the optional value +// on the right hand side is returned. If the value on the left-hand set is +// valued, then it is returned. This operation is short-circuiting and will +// only evaluate as many links in the `or` chain as are needed to return a +// non-empty optional value. +// +// obj.?field.or(m[?key]) +// l[?index].or(obj.?field.subfield).or(obj.?other) +// +// # OrValue +// +// Either return the value contained within the optional on the left-hand side +// or return the alternative value on the right hand side. +// +// m[?key].orValue("none") +// +// # OptMap +// +// Apply a transformation to the optional's underlying value if it is not empty +// and return an optional typed result based on the transformation. The +// transformation expression type must return a type T which is wrapped into +// an optional. +// +// msg.?elements.optMap(e, e.size()).orValue(0) +// +// # OptFlatMap +// +// Introduced in version: 1 +// +// Apply a transformation to the optional's underlying value if it is not empty +// and return the result. The transform expression must return an optional(T) +// rather than type T. This can be useful when dealing with zero values and +// conditionally generating an empty or non-empty result in ways which cannot +// be expressed with `optMap`. +// +// msg.?elements.optFlatMap(e, e[?0]) // return the first element if present. +func OptionalTypes(opts ...OptionalTypesOption) EnvOption { + lib := &optionalLib{version: math.MaxUint32} + for _, opt := range opts { + lib = opt(lib) + } + return Lib(lib) +} + +type optionalLib struct { + version uint32 +} + +// OptionalTypesOption is a functional interface for configuring the strings library. +type OptionalTypesOption func(*optionalLib) *optionalLib + +// OptionalTypesVersion configures the version of the optional type library. +// +// The version limits which functions are available. Only functions introduced +// below or equal to the given version included in the library. If this option +// is not set, all functions are available. +// +// See the library documentation to determine which version a function was introduced. +// If the documentation does not state which version a function was introduced, it can +// be assumed to be introduced at version 0, when the library was first created. +func OptionalTypesVersion(version uint32) OptionalTypesOption { + return func(lib *optionalLib) *optionalLib { + lib.version = version + return lib + } +} + +// LibraryName implements the SingletonLibrary interface method. +func (lib *optionalLib) LibraryName() string { + return "cel.lib.optional" +} + +// CompileOptions implements the Library interface method. +func (lib *optionalLib) CompileOptions() []EnvOption { + paramTypeK := TypeParamType("K") + paramTypeV := TypeParamType("V") + optionalTypeV := OptionalType(paramTypeV) + listTypeV := ListType(paramTypeV) + mapTypeKV := MapType(paramTypeK, paramTypeV) + + opts := []EnvOption{ + // Enable the optional syntax in the parser. + enableOptionalSyntax(), + + // Introduce the optional type. + Types(types.OptionalType), + + // Configure the optMap and optFlatMap macros. + Macros(ReceiverMacro(optMapMacro, 2, optMap)), + + // Global and member functions for working with optional values. + Function(optionalOfFunc, + Overload("optional_of", []*Type{paramTypeV}, optionalTypeV, + UnaryBinding(func(value ref.Val) ref.Val { + return types.OptionalOf(value) + }))), + Function(optionalOfNonZeroValueFunc, + Overload("optional_ofNonZeroValue", []*Type{paramTypeV}, optionalTypeV, + UnaryBinding(func(value ref.Val) ref.Val { + v, isZeroer := value.(traits.Zeroer) + if !isZeroer || !v.IsZeroValue() { + return types.OptionalOf(value) + } + return types.OptionalNone + }))), + Function(optionalNoneFunc, + Overload("optional_none", []*Type{}, optionalTypeV, + FunctionBinding(func(values ...ref.Val) ref.Val { + return types.OptionalNone + }))), + Function(valueFunc, + MemberOverload("optional_value", []*Type{optionalTypeV}, paramTypeV, + UnaryBinding(func(value ref.Val) ref.Val { + opt := value.(*types.Optional) + return opt.GetValue() + }))), + Function(hasValueFunc, + MemberOverload("optional_hasValue", []*Type{optionalTypeV}, BoolType, + UnaryBinding(func(value ref.Val) ref.Val { + opt := value.(*types.Optional) + return types.Bool(opt.HasValue()) + }))), + + // Implementation of 'or' and 'orValue' are special-cased to support short-circuiting in the + // evaluation chain. + Function("or", + MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV)), + Function("orValue", + MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV)), + + // OptSelect is handled specially by the type-checker, so the receiver's field type is used to determine the + // optput type. + Function(operators.OptSelect, + Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV)), + + // OptIndex is handled mostly like any other indexing operation on a list or map, so the type-checker can use + // these signatures to determine type-agreement without any special handling. + Function(operators.OptIndex, + Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV), + Overload("optional_list_optindex_optional_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV), + Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV), + Overload("optional_map_optindex_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)), + + // Index overloads to accommodate using an optional value as the operand. + Function(operators.Index, + Overload("optional_list_index_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV), + Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)), + } + if lib.version >= 1 { + opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap))) + } + return opts +} + +// ProgramOptions implements the Library interface method. +func (lib *optionalLib) ProgramOptions() []ProgramOption { + return []ProgramOption{ + CustomDecorator(decorateOptionalOr), + } +} + +func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) { + varIdent := args[0] + varName := "" + switch varIdent.Kind() { + case ast.IdentKind: + varName = varIdent.AsIdent() + default: + return nil, meh.NewError(varIdent.ID(), "optMap() variable name must be a simple identifier") + } + mapExpr := args[1] + return meh.NewCall( + operators.Conditional, + meh.NewMemberCall(hasValueFunc, target), + meh.NewCall(optionalOfFunc, + meh.NewComprehension( + meh.NewList(), + unusedIterVar, + varName, + meh.NewMemberCall(valueFunc, meh.Copy(target)), + meh.NewLiteral(types.False), + meh.NewIdent(varName), + mapExpr, + ), + ), + meh.NewCall(optionalNoneFunc), + ), nil +} + +func optFlatMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) { + varIdent := args[0] + varName := "" + switch varIdent.Kind() { + case ast.IdentKind: + varName = varIdent.AsIdent() + default: + return nil, meh.NewError(varIdent.ID(), "optFlatMap() variable name must be a simple identifier") + } + mapExpr := args[1] + return meh.NewCall( + operators.Conditional, + meh.NewMemberCall(hasValueFunc, target), + meh.NewComprehension( + meh.NewList(), + unusedIterVar, + varName, + meh.NewMemberCall(valueFunc, meh.Copy(target)), + meh.NewLiteral(types.False), + meh.NewIdent(varName), + mapExpr, + ), + meh.NewCall(optionalNoneFunc), + ), nil +} + +func enableOptionalSyntax() EnvOption { + return func(e *Env) (*Env, error) { + e.prsrOpts = append(e.prsrOpts, parser.EnableOptionalSyntax(true)) + return e, nil + } +} + +// EnableErrorOnBadPresenceTest enables error generation when a presence test or optional field +// selection is performed on a primitive type. +func EnableErrorOnBadPresenceTest(value bool) EnvOption { + return features(featureEnableErrorOnBadPresenceTest, value) +} + +func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable, error) { + call, ok := i.(interpreter.InterpretableCall) + if !ok { + return i, nil + } + args := call.Args() + if len(args) != 2 { + return i, nil + } + switch call.Function() { + case "or": + if call.OverloadID() != "" && call.OverloadID() != "optional_or_optional" { + return i, nil + } + return &evalOptionalOr{ + id: call.ID(), + lhs: args[0], + rhs: args[1], + }, nil + case "orValue": + if call.OverloadID() != "" && call.OverloadID() != "optional_orValue_value" { + return i, nil + } + return &evalOptionalOrValue{ + id: call.ID(), + lhs: args[0], + rhs: args[1], + }, nil + default: + return i, nil + } +} + +// evalOptionalOr selects between two optional values, either the first if it has a value, or +// the second optional expression is evaluated and returned. +type evalOptionalOr struct { + id int64 + lhs interpreter.Interpretable + rhs interpreter.Interpretable +} + +// ID implements the Interpretable interface method. +func (opt *evalOptionalOr) ID() int64 { + return opt.id +} + +// Eval evaluates the left-hand side optional to determine whether it contains a value, else +// proceeds with the right-hand side evaluation. +func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val { + // short-circuit lhs. + optLHS := opt.lhs.Eval(ctx) + optVal, ok := optLHS.(*types.Optional) + if !ok { + return optLHS + } + if optVal.HasValue() { + return optVal + } + return opt.rhs.Eval(ctx) +} + +// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value, +// its value is returned, otherwise the alternative value expression is evaluated and returned. +type evalOptionalOrValue struct { + id int64 + lhs interpreter.Interpretable + rhs interpreter.Interpretable +} + +// ID implements the Interpretable interface method. +func (opt *evalOptionalOrValue) ID() int64 { + return opt.id +} + +// Eval evaluates the left-hand side optional to determine whether it contains a value, else +// proceeds with the right-hand side evaluation. +func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val { + // short-circuit lhs. + optLHS := opt.lhs.Eval(ctx) + optVal, ok := optLHS.(*types.Optional) + if !ok { + return optLHS + } + if optVal.HasValue() { + return optVal.GetValue() + } + return opt.rhs.Eval(ctx) +} + +type timeUTCLibrary struct{} + +func (timeUTCLibrary) CompileOptions() []EnvOption { + return timeOverloadDeclarations +} + +func (timeUTCLibrary) ProgramOptions() []ProgramOption { + return []ProgramOption{} +} + +// Declarations and functions which enable using UTC on time.Time inputs when the timezone is unspecified +// in the CEL expression. +var ( + utcTZ = types.String("UTC") + + timeOverloadDeclarations = []EnvOption{ + Function(overloads.TimeGetHours, + MemberOverload(overloads.DurationToHours, []*Type{DurationType}, IntType, + UnaryBinding(types.DurationGetHours))), + Function(overloads.TimeGetMinutes, + MemberOverload(overloads.DurationToMinutes, []*Type{DurationType}, IntType, + UnaryBinding(types.DurationGetMinutes))), + Function(overloads.TimeGetSeconds, + MemberOverload(overloads.DurationToSeconds, []*Type{DurationType}, IntType, + UnaryBinding(types.DurationGetSeconds))), + Function(overloads.TimeGetMilliseconds, + MemberOverload(overloads.DurationToMilliseconds, []*Type{DurationType}, IntType, + UnaryBinding(types.DurationGetMilliseconds))), + Function(overloads.TimeGetFullYear, + MemberOverload(overloads.TimestampToYear, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetFullYear(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToYearWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetFullYear), + ), + ), + Function(overloads.TimeGetMonth, + MemberOverload(overloads.TimestampToMonth, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetMonth(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToMonthWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetMonth), + ), + ), + Function(overloads.TimeGetDayOfYear, + MemberOverload(overloads.TimestampToDayOfYear, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetDayOfYear(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToDayOfYearWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(func(ts, tz ref.Val) ref.Val { + return timestampGetDayOfYear(ts, tz) + }), + ), + ), + Function(overloads.TimeGetDayOfMonth, + MemberOverload(overloads.TimestampToDayOfMonthZeroBased, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetDayOfMonthZeroBased(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetDayOfMonthZeroBased), + ), + ), + Function(overloads.TimeGetDate, + MemberOverload(overloads.TimestampToDayOfMonthOneBased, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetDayOfMonthOneBased(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToDayOfMonthOneBasedWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetDayOfMonthOneBased), + ), + ), + Function(overloads.TimeGetDayOfWeek, + MemberOverload(overloads.TimestampToDayOfWeek, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetDayOfWeek(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToDayOfWeekWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetDayOfWeek), + ), + ), + Function(overloads.TimeGetHours, + MemberOverload(overloads.TimestampToHours, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetHours(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToHoursWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetHours), + ), + ), + Function(overloads.TimeGetMinutes, + MemberOverload(overloads.TimestampToMinutes, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetMinutes(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToMinutesWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetMinutes), + ), + ), + Function(overloads.TimeGetSeconds, + MemberOverload(overloads.TimestampToSeconds, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetSeconds(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToSecondsWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetSeconds), + ), + ), + Function(overloads.TimeGetMilliseconds, + MemberOverload(overloads.TimestampToMilliseconds, []*Type{TimestampType}, IntType, + UnaryBinding(func(ts ref.Val) ref.Val { + return timestampGetMilliseconds(ts, utcTZ) + }), + ), + MemberOverload(overloads.TimestampToMillisecondsWithTz, []*Type{TimestampType, StringType}, IntType, + BinaryBinding(timestampGetMilliseconds), + ), + ), + } +) + +func timestampGetFullYear(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.Year()) +} + +func timestampGetMonth(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + // CEL spec indicates that the month should be 0-based, but the Time value + // for Month() is 1-based. + return types.Int(t.Month() - 1) +} + +func timestampGetDayOfYear(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.YearDay() - 1) +} + +func timestampGetDayOfMonthZeroBased(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.Day() - 1) +} + +func timestampGetDayOfMonthOneBased(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.Day()) +} + +func timestampGetDayOfWeek(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.Weekday()) +} + +func timestampGetHours(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.Hour()) +} + +func timestampGetMinutes(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.Minute()) +} + +func timestampGetSeconds(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.Second()) +} + +func timestampGetMilliseconds(ts, tz ref.Val) ref.Val { + t, err := inTimeZone(ts, tz) + if err != nil { + return types.NewErr(err.Error()) + } + return types.Int(t.Nanosecond() / 1000000) +} + +func inTimeZone(ts, tz ref.Val) (time.Time, error) { + t := ts.(types.Timestamp) + val := string(tz.(types.String)) + ind := strings.Index(val, ":") + if ind == -1 { + loc, err := time.LoadLocation(val) + if err != nil { + return time.Time{}, err + } + return t.In(loc), nil + } + + // If the input is not the name of a timezone (for example, 'US/Central'), it should be a numerical offset from UTC + // in the format ^(+|-)(0[0-9]|1[0-4]):[0-5][0-9]$. The numerical input is parsed in terms of hours and minutes. + hr, err := strconv.Atoi(string(val[0:ind])) + if err != nil { + return time.Time{}, err + } + min, err := strconv.Atoi(string(val[ind+1:])) + if err != nil { + return time.Time{}, err + } + var offset int + if string(val[0]) == "-" { + offset = hr*60 - min + } else { + offset = hr*60 + min + } + secondsEastOfUTC := int((time.Duration(offset) * time.Minute).Seconds()) + timezone := time.FixedZone("", secondsEastOfUTC) + return t.In(timezone), nil +} diff --git a/vendor/github.com/authzed/cel-go/cel/macro.go b/vendor/github.com/authzed/cel-go/cel/macro.go new file mode 100644 index 0000000..26c7031 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/macro.go @@ -0,0 +1,576 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/authzed/cel-go/common" + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/parser" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// Macro describes a function signature to match and the MacroExpander to apply. +// +// Note: when a Macro should apply to multiple overloads (based on arg count) of a given function, +// a Macro should be created per arg-count or as a var arg macro. +type Macro = parser.Macro + +// MacroFactory defines an expansion function which converts a call and its arguments to a cel.Expr value. +type MacroFactory = parser.MacroExpander + +// MacroExprFactory assists with the creation of Expr values in a manner which is consistent +// the internal semantics and id generation behaviors of the parser and checker libraries. +type MacroExprFactory = parser.ExprHelper + +// MacroExpander converts a call and its associated arguments into a protobuf Expr representation. +// +// If the MacroExpander determines within the implementation that an expansion is not needed it may return +// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments +// are not well-formed, the result of the expansion will be an error. +// +// The MacroExpander accepts as arguments a MacroExprHelper as well as the arguments used in the function call +// and produces as output an Expr ast node. +// +// Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil. +type MacroExpander func(eh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) + +// MacroExprHelper exposes helper methods for creating new expressions within a CEL abstract syntax tree. +// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is +// consistent with the source position and expression id generation code leveraged by both +// the parser and type-checker. +type MacroExprHelper interface { + // Copy the input expression with a brand new set of identifiers. + Copy(*exprpb.Expr) *exprpb.Expr + + // LiteralBool creates an Expr value for a bool literal. + LiteralBool(value bool) *exprpb.Expr + + // LiteralBytes creates an Expr value for a byte literal. + LiteralBytes(value []byte) *exprpb.Expr + + // LiteralDouble creates an Expr value for double literal. + LiteralDouble(value float64) *exprpb.Expr + + // LiteralInt creates an Expr value for an int literal. + LiteralInt(value int64) *exprpb.Expr + + // LiteralString creates am Expr value for a string literal. + LiteralString(value string) *exprpb.Expr + + // LiteralUint creates an Expr value for a uint literal. + LiteralUint(value uint64) *exprpb.Expr + + // NewList creates a CreateList instruction where the list is comprised of the optional set + // of elements provided as arguments. + NewList(elems ...*exprpb.Expr) *exprpb.Expr + + // NewMap creates a CreateStruct instruction for a map where the map is comprised of the + // optional set of key, value entries. + NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + + // NewMapEntry creates a Map Entry for the key, value pair. + NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + + // NewObject creates a CreateStruct instruction for an object with a given type name and + // optional set of field initializers. + NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + + // NewObjectFieldInit creates a new Object field initializer from the field name and value. + NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + + // Fold creates a fold comprehension instruction. + // + // - iterVar is the iteration variable name. + // - iterRange represents the expression that resolves to a list or map where the elements or + // keys (respectively) will be iterated over. + // - accuVar is the accumulation variable name, typically parser.AccumulatorName. + // - accuInit is the initial expression whose value will be set for the accuVar prior to + // folding. + // - condition is the expression to test to determine whether to continue folding. + // - step is the expression to evaluation at the conclusion of a single fold iteration. + // - result is the computation to evaluate at the conclusion of the fold. + // + // The accuVar should not shadow variable names that you would like to reference within the + // environment in the step and condition expressions. Presently, the name __result__ is commonly + // used by built-in macros but this may change in the future. + Fold(iterVar string, + iterRange *exprpb.Expr, + accuVar string, + accuInit *exprpb.Expr, + condition *exprpb.Expr, + step *exprpb.Expr, + result *exprpb.Expr) *exprpb.Expr + + // Ident creates an identifier Expr value. + Ident(name string) *exprpb.Expr + + // AccuIdent returns an accumulator identifier for use with comprehension results. + AccuIdent() *exprpb.Expr + + // GlobalCall creates a function call Expr value for a global (free) function. + GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr + + // ReceiverCall creates a function call Expr value for a receiver-style function. + ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr + + // PresenceTest creates a Select TestOnly Expr value for modelling has() semantics. + PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr + + // Select create a field traversal Expr value. + Select(operand *exprpb.Expr, field string) *exprpb.Expr + + // OffsetLocation returns the Location of the expression identifier. + OffsetLocation(exprID int64) common.Location + + // NewError associates an error message with a given expression id. + NewError(exprID int64, message string) *Error +} + +// GlobalMacro creates a Macro for a global function with the specified arg count. +func GlobalMacro(function string, argCount int, factory MacroFactory) Macro { + return parser.NewGlobalMacro(function, argCount, factory) +} + +// ReceiverMacro creates a Macro for a receiver function matching the specified arg count. +func ReceiverMacro(function string, argCount int, factory MacroFactory) Macro { + return parser.NewReceiverMacro(function, argCount, factory) +} + +// GlobalVarArgMacro creates a Macro for a global function with a variable arg count. +func GlobalVarArgMacro(function string, factory MacroFactory) Macro { + return parser.NewGlobalVarArgMacro(function, factory) +} + +// ReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count. +func ReceiverVarArgMacro(function string, factory MacroFactory) Macro { + return parser.NewReceiverVarArgMacro(function, factory) +} + +// NewGlobalMacro creates a Macro for a global function with the specified arg count. +// +// Deprecated: use GlobalMacro +func NewGlobalMacro(function string, argCount int, expander MacroExpander) Macro { + expand := adaptingExpander{expander} + return parser.NewGlobalMacro(function, argCount, expand.Expander) +} + +// NewReceiverMacro creates a Macro for a receiver function matching the specified arg count. +// +// Deprecated: use ReceiverMacro +func NewReceiverMacro(function string, argCount int, expander MacroExpander) Macro { + expand := adaptingExpander{expander} + return parser.NewReceiverMacro(function, argCount, expand.Expander) +} + +// NewGlobalVarArgMacro creates a Macro for a global function with a variable arg count. +// +// Deprecated: use GlobalVarArgMacro +func NewGlobalVarArgMacro(function string, expander MacroExpander) Macro { + expand := adaptingExpander{expander} + return parser.NewGlobalVarArgMacro(function, expand.Expander) +} + +// NewReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count. +// +// Deprecated: use ReceiverVarArgMacro +func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro { + expand := adaptingExpander{expander} + return parser.NewReceiverVarArgMacro(function, expand.Expander) +} + +// HasMacroExpander expands the input call arguments into a presence test, e.g. has(<operand>.field) +func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + arg, err := adaptToExpr(args[0]) + if err != nil { + return nil, err + } + if arg.Kind() == ast.SelectKind { + s := arg.AsSelect() + return adaptToProto(ph.NewPresenceTest(s.Operand(), s.FieldName())) + } + return nil, ph.NewError(arg.ID(), "invalid argument to has() macro") +} + +// ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the +// elements in the range match the predicate expressions: +// <iterRange>.exists(<iterVar>, <predicate>) +func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeExists(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) +} + +// ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly +// one of the elements in the range match the predicate expressions: +// <iterRange>.exists_one(<iterVar>, <predicate>) +func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeExistsOne(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) +} + +// MapMacroExpander expands the input call arguments into a comprehension that transforms each element in the +// input to produce an output list. +// +// There are two call patterns supported by map: +// +// <iterRange>.map(<iterVar>, <transform>) +// <iterRange>.map(<iterVar>, <predicate>, <transform>) +// +// In the second form only iterVar values which return true when provided to the predicate expression +// are transformed. +func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeMap(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) +} + +// FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains +// only elements which match the provided predicate expression: +// <iterRange>.filter(<iterVar>, <predicate>) +func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeFilter(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) +} + +var ( + // Aliases to each macro in the CEL standard environment. + // Note: reassigning these macro variables may result in undefined behavior. + + // HasMacro expands "has(m.f)" which tests the presence of a field, avoiding the need to + // specify the field as a string. + HasMacro = parser.HasMacro + + // AllMacro expands "range.all(var, predicate)" into a comprehension which ensures that all + // elements in the range satisfy the predicate. + AllMacro = parser.AllMacro + + // ExistsMacro expands "range.exists(var, predicate)" into a comprehension which ensures that + // some element in the range satisfies the predicate. + ExistsMacro = parser.ExistsMacro + + // ExistsOneMacro expands "range.exists_one(var, predicate)", which is true if for exactly one + // element in range the predicate holds. + ExistsOneMacro = parser.ExistsOneMacro + + // MapMacro expands "range.map(var, function)" into a comprehension which applies the function + // to each element in the range to produce a new list. + MapMacro = parser.MapMacro + + // MapFilterMacro expands "range.map(var, predicate, function)" into a comprehension which + // first filters the elements in the range by the predicate, then applies the transform function + // to produce a new list. + MapFilterMacro = parser.MapFilterMacro + + // FilterMacro expands "range.filter(var, predicate)" into a comprehension which filters + // elements in the range, producing a new list from the elements that satisfy the predicate. + FilterMacro = parser.FilterMacro + + // StandardMacros provides an alias to all the CEL macros defined in the standard environment. + StandardMacros = []Macro{ + HasMacro, AllMacro, ExistsMacro, ExistsOneMacro, MapMacro, MapFilterMacro, FilterMacro, + } + + // NoMacros provides an alias to an empty list of macros + NoMacros = []Macro{} +) + +type adaptingExpander struct { + legacyExpander MacroExpander +} + +func (adapt *adaptingExpander) Expander(eh parser.ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + var legacyTarget *exprpb.Expr = nil + var err *Error = nil + if target != nil { + legacyTarget, err = adaptToProto(target) + if err != nil { + return nil, err + } + } + legacyArgs := make([]*exprpb.Expr, len(args)) + for i, arg := range args { + legacyArgs[i], err = adaptToProto(arg) + if err != nil { + return nil, err + } + } + ah := &adaptingHelper{modernHelper: eh} + legacyExpr, err := adapt.legacyExpander(ah, legacyTarget, legacyArgs) + if err != nil { + return nil, err + } + ex, err := adaptToExpr(legacyExpr) + if err != nil { + return nil, err + } + return ex, nil +} + +func wrapErr(id int64, message string, err error) *common.Error { + return &common.Error{ + Location: common.NoLocation, + Message: fmt.Sprintf("%s: %v", message, err), + ExprID: id, + } +} + +type adaptingHelper struct { + modernHelper parser.ExprHelper +} + +// Copy the input expression with a brand new set of identifiers. +func (ah *adaptingHelper) Copy(e *exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.Copy(mustAdaptToExpr(e))) +} + +// LiteralBool creates an Expr value for a bool literal. +func (ah *adaptingHelper) LiteralBool(value bool) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bool(value))) +} + +// LiteralBytes creates an Expr value for a byte literal. +func (ah *adaptingHelper) LiteralBytes(value []byte) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bytes(value))) +} + +// LiteralDouble creates an Expr value for double literal. +func (ah *adaptingHelper) LiteralDouble(value float64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Double(value))) +} + +// LiteralInt creates an Expr value for an int literal. +func (ah *adaptingHelper) LiteralInt(value int64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Int(value))) +} + +// LiteralString creates am Expr value for a string literal. +func (ah *adaptingHelper) LiteralString(value string) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.String(value))) +} + +// LiteralUint creates an Expr value for a uint literal. +func (ah *adaptingHelper) LiteralUint(value uint64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Uint(value))) +} + +// NewList creates a CreateList instruction where the list is comprised of the optional set +// of elements provided as arguments. +func (ah *adaptingHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewList(mustAdaptToExprs(elems)...)) +} + +// NewMap creates a CreateStruct instruction for a map where the map is comprised of the +// optional set of key, value entries. +func (ah *adaptingHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { + adaptedEntries := make([]ast.EntryExpr, len(entries)) + for i, e := range entries { + adaptedEntries[i] = mustAdaptToEntryExpr(e) + } + return mustAdaptToProto(ah.modernHelper.NewMap(adaptedEntries...)) +} + +// NewMapEntry creates a Map Entry for the key, value pair. +func (ah *adaptingHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { + return mustAdaptToProtoEntry( + ah.modernHelper.NewMapEntry(mustAdaptToExpr(key), mustAdaptToExpr(val), optional)) +} + +// NewObject creates a CreateStruct instruction for an object with a given type name and +// optional set of field initializers. +func (ah *adaptingHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { + adaptedEntries := make([]ast.EntryExpr, len(fieldInits)) + for i, e := range fieldInits { + adaptedEntries[i] = mustAdaptToEntryExpr(e) + } + return mustAdaptToProto(ah.modernHelper.NewStruct(typeName, adaptedEntries...)) +} + +// NewObjectFieldInit creates a new Object field initializer from the field name and value. +func (ah *adaptingHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { + return mustAdaptToProtoEntry( + ah.modernHelper.NewStructField(field, mustAdaptToExpr(init), optional)) +} + +// Fold creates a fold comprehension instruction. +// +// - iterVar is the iteration variable name. +// - iterRange represents the expression that resolves to a list or map where the elements or +// keys (respectively) will be iterated over. +// - accuVar is the accumulation variable name, typically parser.AccumulatorName. +// - accuInit is the initial expression whose value will be set for the accuVar prior to +// folding. +// - condition is the expression to test to determine whether to continue folding. +// - step is the expression to evaluation at the conclusion of a single fold iteration. +// - result is the computation to evaluate at the conclusion of the fold. +// +// The accuVar should not shadow variable names that you would like to reference within the +// environment in the step and condition expressions. Presently, the name __result__ is commonly +// used by built-in macros but this may change in the future. +func (ah *adaptingHelper) Fold(iterVar string, + iterRange *exprpb.Expr, + accuVar string, + accuInit *exprpb.Expr, + condition *exprpb.Expr, + step *exprpb.Expr, + result *exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto( + ah.modernHelper.NewComprehension( + mustAdaptToExpr(iterRange), + iterVar, + accuVar, + mustAdaptToExpr(accuInit), + mustAdaptToExpr(condition), + mustAdaptToExpr(step), + mustAdaptToExpr(result), + ), + ) +} + +// Ident creates an identifier Expr value. +func (ah *adaptingHelper) Ident(name string) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewIdent(name)) +} + +// AccuIdent returns an accumulator identifier for use with comprehension results. +func (ah *adaptingHelper) AccuIdent() *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewAccuIdent()) +} + +// GlobalCall creates a function call Expr value for a global (free) function. +func (ah *adaptingHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewCall(function, mustAdaptToExprs(args)...)) +} + +// ReceiverCall creates a function call Expr value for a receiver-style function. +func (ah *adaptingHelper) ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto( + ah.modernHelper.NewMemberCall(function, mustAdaptToExpr(target), mustAdaptToExprs(args)...)) +} + +// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics. +func (ah *adaptingHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr { + op := mustAdaptToExpr(operand) + return mustAdaptToProto(ah.modernHelper.NewPresenceTest(op, field)) +} + +// Select create a field traversal Expr value. +func (ah *adaptingHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr { + op := mustAdaptToExpr(operand) + return mustAdaptToProto(ah.modernHelper.NewSelect(op, field)) +} + +// OffsetLocation returns the Location of the expression identifier. +func (ah *adaptingHelper) OffsetLocation(exprID int64) common.Location { + return ah.modernHelper.OffsetLocation(exprID) +} + +// NewError associates an error message with a given expression id. +func (ah *adaptingHelper) NewError(exprID int64, message string) *Error { + return ah.modernHelper.NewError(exprID, message) +} + +func mustAdaptToExprs(exprs []*exprpb.Expr) []ast.Expr { + adapted := make([]ast.Expr, len(exprs)) + for i, e := range exprs { + adapted[i] = mustAdaptToExpr(e) + } + return adapted +} + +func mustAdaptToExpr(e *exprpb.Expr) ast.Expr { + out, _ := adaptToExpr(e) + return out +} + +func adaptToExpr(e *exprpb.Expr) (ast.Expr, *Error) { + if e == nil { + return nil, nil + } + out, err := ast.ProtoToExpr(e) + if err != nil { + return nil, wrapErr(e.GetId(), "proto conversion failure", err) + } + return out, nil +} + +func mustAdaptToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) ast.EntryExpr { + out, _ := ast.ProtoToEntryExpr(e) + return out +} + +func mustAdaptToProto(e ast.Expr) *exprpb.Expr { + out, _ := adaptToProto(e) + return out +} + +func adaptToProto(e ast.Expr) (*exprpb.Expr, *Error) { + if e == nil { + return nil, nil + } + out, err := ast.ExprToProto(e) + if err != nil { + return nil, wrapErr(e.ID(), "expr conversion failure", err) + } + return out, nil +} + +func mustAdaptToProtoEntry(e ast.EntryExpr) *exprpb.Expr_CreateStruct_Entry { + out, _ := ast.EntryExprToProto(e) + return out +} + +func toParserHelper(meh MacroExprHelper) (parser.ExprHelper, *Error) { + ah, ok := meh.(*adaptingHelper) + if !ok { + return nil, common.NewError(0, + fmt.Sprintf("unsupported macro helper: %v (%T)", meh, meh), + common.NoLocation) + } + return ah.modernHelper, nil +} diff --git a/vendor/github.com/authzed/cel-go/cel/optimizer.go b/vendor/github.com/authzed/cel-go/cel/optimizer.go new file mode 100644 index 0000000..27a6f78 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/optimizer.go @@ -0,0 +1,509 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/authzed/cel-go/common" + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" +) + +// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order. +// +// The static optimizer normalizes expression ids and type-checking run between optimization +// passes to ensure that the final optimized output is a valid expression with metadata consistent +// with what would have been generated from a parsed and checked expression. +// +// Note: source position information is best-effort and likely wrong, but optimized expressions +// should be suitable for calls to parser.Unparse. +type StaticOptimizer struct { + optimizers []ASTOptimizer +} + +// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied +// to a checked expression. +func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { + return &StaticOptimizer{ + optimizers: optimizers, + } +} + +// Optimize applies a sequence of optimizations to an Ast within a given environment. +// +// If issues are encountered, the Issues.Err() return value will be non-nil. +func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { + // Make a copy of the AST to be optimized. + optimized := ast.Copy(a.impl) + ids := newIDGenerator(ast.MaxID(a.impl)) + + // Create the optimizer context, could be pooled in the future. + issues := NewIssues(common.NewErrors(a.Source())) + baseFac := ast.NewExprFactory() + exprFac := &optimizerExprFactory{ + idGenerator: ids, + fac: baseFac, + sourceInfo: optimized.SourceInfo(), + } + ctx := &OptimizerContext{ + optimizerExprFactory: exprFac, + Env: env, + Issues: issues, + } + + // Apply the optimizations sequentially. + for _, o := range opt.optimizers { + optimized = o.Optimize(ctx, optimized) + if issues.Err() != nil { + return nil, issues + } + // Normalize expression id metadata including coordination with macro call metadata. + freshIDGen := newIDGenerator(0) + info := optimized.SourceInfo() + expr := optimized.Expr() + normalizeIDs(freshIDGen.renumberStable, expr, info) + cleanupMacroRefs(expr, info) + + // Recheck the updated expression for any possible type-agreement or validation errors. + parsed := &Ast{ + source: a.Source(), + impl: ast.NewAST(expr, info)} + checked, iss := ctx.Check(parsed) + if iss.Err() != nil { + return nil, iss + } + optimized = checked.impl + } + + // Return the optimized result. + return &Ast{ + source: a.Source(), + impl: optimized, + }, nil +} + +// normalizeIDs ensures that the metadata present with an AST is reset in a manner such +// that the ids within the expression correspond to the ids within macros. +func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) { + optimized.RenumberIDs(idGen) + + if len(info.MacroCalls()) == 0 { + return + } + + // First, update the macro call ids themselves. + callIDMap := map[int64]int64{} + for id := range info.MacroCalls() { + callIDMap[id] = idGen(id) + } + // Then update the macro call definitions which refer to these ids, but + // ensure that the updates don't collide and remove macro entries which haven't + // been visited / updated yet. + type macroUpdate struct { + id int64 + call ast.Expr + } + macroUpdates := []macroUpdate{} + for oldID, newID := range callIDMap { + call, found := info.GetMacroCall(oldID) + if !found { + continue + } + call.RenumberIDs(idGen) + macroUpdates = append(macroUpdates, macroUpdate{id: newID, call: call}) + info.ClearMacroCall(oldID) + } + for _, u := range macroUpdates { + info.SetMacroCall(u.id, u.call) + } +} + +func cleanupMacroRefs(expr ast.Expr, info *ast.SourceInfo) { + if len(info.MacroCalls()) == 0 { + return + } + // Sanitize the macro call references once the optimized expression has been computed + // and the ids normalized between the expression and the macros. + exprRefMap := make(map[int64]struct{}) + ast.PostOrderVisit(expr, ast.NewExprVisitor(func(e ast.Expr) { + if e.ID() == 0 { + return + } + exprRefMap[e.ID()] = struct{}{} + })) + // Update the macro call id references to ensure that macro pointers are + // updated consistently across macros. + for _, call := range info.MacroCalls() { + ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) { + if e.ID() == 0 { + return + } + exprRefMap[e.ID()] = struct{}{} + })) + } + for id := range info.MacroCalls() { + if _, found := exprRefMap[id]; !found { + info.ClearMacroCall(id) + } + } +} + +// newIDGenerator ensures that new ids are only created the first time they are encountered. +func newIDGenerator(seed int64) *idGenerator { + return &idGenerator{ + idMap: make(map[int64]int64), + seed: seed, + } +} + +type idGenerator struct { + idMap map[int64]int64 + seed int64 +} + +func (gen *idGenerator) nextID() int64 { + gen.seed++ + return gen.seed +} + +func (gen *idGenerator) renumberStable(id int64) int64 { + if id == 0 { + return 0 + } + if newID, found := gen.idMap[id]; found { + return newID + } + nextID := gen.nextID() + gen.idMap[id] = nextID + return nextID +} + +// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate +// subexpressions and report any errors encountered along the way. The context also embeds the +// optimizerExprFactory which can be used to generate new sub-expressions with expression ids +// consistent with the expectations of a parsed expression. +type OptimizerContext struct { + *Env + *optimizerExprFactory + *Issues +} + +// ASTOptimizer applies an optimization over an AST and returns the optimized result. +type ASTOptimizer interface { + // Optimize optimizes a type-checked AST within an Environment and accumulates any issues. + Optimize(*OptimizerContext, *ast.AST) *ast.AST +} + +type optimizerExprFactory struct { + *idGenerator + fac ast.ExprFactory + sourceInfo *ast.SourceInfo +} + +// NewAST creates an AST from the current expression using the tracked source info which +// is modified and managed by the OptimizerContext. +func (opt *optimizerExprFactory) NewAST(expr ast.Expr) *ast.AST { + return ast.NewAST(expr, opt.sourceInfo) +} + +// CopyAST creates a renumbered copy of `Expr` and `SourceInfo` values of the input AST, where the +// renumbering uses the same scheme as the core optimizer logic ensuring there are no collisions +// between copies. +// +// Use this method before attempting to merge the expression from AST into another. +func (opt *optimizerExprFactory) CopyAST(a *ast.AST) (ast.Expr, *ast.SourceInfo) { + idGen := newIDGenerator(opt.nextID()) + defer func() { opt.seed = idGen.nextID() }() + copyExpr := opt.fac.CopyExpr(a.Expr()) + copyInfo := ast.CopySourceInfo(a.SourceInfo()) + normalizeIDs(idGen.renumberStable, copyExpr, copyInfo) + return copyExpr, copyInfo +} + +// CopyASTAndMetadata copies the input AST and propagates the macro metadata into the AST being +// optimized. +func (opt *optimizerExprFactory) CopyASTAndMetadata(a *ast.AST) ast.Expr { + copyExpr, copyInfo := opt.CopyAST(a) + for macroID, call := range copyInfo.MacroCalls() { + opt.SetMacroCall(macroID, call) + } + return copyExpr +} + +// ClearMacroCall clears the macro at the given expression id. +func (opt *optimizerExprFactory) ClearMacroCall(id int64) { + opt.sourceInfo.ClearMacroCall(id) +} + +// SetMacroCall sets the macro call metadata for the given macro id within the tracked source info +// metadata. +func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) { + opt.sourceInfo.SetMacroCall(id, expr) +} + +// NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression +// representing the unexpanded call signature to be inserted into the source info macro call metadata. +func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) { + varID := opt.nextID() + remainingID := opt.nextID() + remaining = opt.fac.CopyExpr(remaining) + remaining.RenumberIDs(func(id int64) int64 { + if id == macroID { + return remainingID + } + return id + }) + if call, exists := opt.sourceInfo.GetMacroCall(macroID); exists { + opt.SetMacroCall(remainingID, opt.fac.CopyExpr(call)) + } + + astExpr = opt.fac.NewComprehension(macroID, + opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}), + "#unused", + varName, + opt.fac.CopyExpr(varInit), + opt.fac.NewLiteral(opt.nextID(), types.False), + opt.fac.NewIdent(varID, varName), + remaining) + + macroExpr = opt.fac.NewMemberCall(0, "bind", + opt.fac.NewIdent(opt.nextID(), "cel"), + opt.fac.NewIdent(varID, varName), + opt.fac.CopyExpr(varInit), + opt.fac.CopyExpr(remaining)) + opt.sanitizeMacro(macroID, macroExpr) + return +} + +// NewCall creates a global function call invocation expression. +// +// Example: +// +// countByField(list, fieldName) +// - function: countByField +// - args: [list, fieldName] +func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr { + return opt.fac.NewCall(opt.nextID(), function, args...) +} + +// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call. +// +// Example: +// +// list.countByField(fieldName) +// - function: countByField +// - target: list +// - args: [fieldName] +func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return opt.fac.NewMemberCall(opt.nextID(), function, target, args...) +} + +// NewIdent creates a new identifier expression. +// +// Examples: +// +// - simple_var_name +// - qualified.subpackage.var_name +func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr { + return opt.fac.NewIdent(opt.nextID(), name) +} + +// NewLiteral creates a new literal expression value. +// +// The range of valid values for a literal generated during optimization is different than for expressions +// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can +// be converted back to a literal-like form. +func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr { + return opt.fac.NewLiteral(opt.nextID(), value) +} + +// NewList creates a list expression with a set of optional indices. +// +// Examples: +// +// [a, b] +// - elems: [a, b] +// - optIndices: [] +// +// [a, ?b, ?c] +// - elems: [a, b, c] +// - optIndices: [1, 2] +func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr { + return opt.fac.NewList(opt.nextID(), elems, optIndices) +} + +// NewMap creates a map from a set of entry expressions which contain a key and value expression. +func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr { + return opt.fac.NewMap(opt.nextID(), entries) +} + +// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the +// entry is optional. +// +// Examples: +// +// {a: b} +// - key: a +// - value: b +// - optional: false +// +// {?a: ?b} +// - key: a +// - value: b +// - optional: true +func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) +} + +// NewHasMacro generates a test-only select expression to be included within an AST and an unexpanded +// has() macro call signature to be inserted into the source info macro call metadata. +func (opt *optimizerExprFactory) NewHasMacro(macroID int64, s ast.Expr) (astExpr, macroExpr ast.Expr) { + sel := s.AsSelect() + astExpr = opt.fac.NewPresenceTest(macroID, sel.Operand(), sel.FieldName()) + macroExpr = opt.fac.NewCall(0, "has", + opt.NewSelect(opt.fac.CopyExpr(sel.Operand()), sel.FieldName())) + opt.sanitizeMacro(macroID, macroExpr) + return +} + +// NewSelect creates a select expression where a field value is selected from an operand. +// +// Example: +// +// msg.field_name +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr { + return opt.fac.NewSelect(opt.nextID(), operand, field) +} + +// NewStruct creates a new typed struct value with an set of field initializations. +// +// Example: +// +// pkg.TypeName{field: value} +// - typeName: pkg.TypeName +// - fields: [{field: value}] +func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr { + return opt.fac.NewStruct(opt.nextID(), typeName, fields) +} + +// NewStructField creates a struct field initialization. +// +// Examples: +// +// {count: 3u} +// - field: count +// - value: 3u +// - optional: false +// +// {?count: x} +// - field: count +// - value: x +// - optional: true +func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewStructField(opt.nextID(), field, value, isOptional) +} + +// UpdateExpr updates the target expression with the updated content while preserving macro metadata. +// +// There are four scenarios during the update to consider: +// 1. target is not macro, updated is not macro +// 2. target is macro, updated is not macro +// 3. target is macro, updated is macro +// 4. target is not macro, updated is macro +// +// When the target is a macro already, it may either be updated to a new macro function +// body if the update is also a macro, or it may be removed altogether if the update is +// a macro. +// +// When the update is a macro, then the target references within other macros must be +// updated to point to the new updated macro. Otherwise, other macros which pointed to +// the target body must be replaced with copies of the updated expression body. +func (opt *optimizerExprFactory) UpdateExpr(target, updated ast.Expr) { + // Update the expression + target.SetKindCase(updated) + + // Early return if there's no macros present sa the source info reflects the + // macro set from the target and updated expressions. + if len(opt.sourceInfo.MacroCalls()) == 0 { + return + } + // Determine whether the target expression was a macro. + _, targetIsMacro := opt.sourceInfo.GetMacroCall(target.ID()) + + // Determine whether the updated expression was a macro. + updatedMacro, updatedIsMacro := opt.sourceInfo.GetMacroCall(updated.ID()) + + if updatedIsMacro { + // If the updated call was a macro, then updated id maps to target id, + // and the updated macro moves into the target id slot. + opt.sourceInfo.ClearMacroCall(updated.ID()) + opt.sourceInfo.SetMacroCall(target.ID(), updatedMacro) + } else if targetIsMacro { + // Otherwise if the target expr was a macro, but is no longer, clear + // the macro reference. + opt.sourceInfo.ClearMacroCall(target.ID()) + } + + // Punch holes in the updated value where macros references exist. + macroExpr := opt.fac.CopyExpr(target) + macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) { + if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists { + e.SetKindCase(nil) + } + }) + ast.PostOrderVisit(macroExpr, macroRefVisitor) + + // Update any references to the expression within a macro + macroVisitor := ast.NewExprVisitor(func(call ast.Expr) { + // Update the target expression to point to the macro expression which + // will be empty if the updated expression was a macro. + if call.ID() == target.ID() { + call.SetKindCase(opt.fac.CopyExpr(macroExpr)) + } + // Update the macro call expression if it refers to the updated expression + // id which has since been remapped to the target id. + if call.ID() == updated.ID() { + // Either ensure the expression is a macro reference or a populated with + // the relevant sub-expression if the updated expr was not a macro. + if updatedIsMacro { + call.SetKindCase(nil) + } else { + call.SetKindCase(opt.fac.CopyExpr(macroExpr)) + } + // Since SetKindCase does not renumber the id, ensure the references to + // the old 'updated' id are mapped to the target id. + call.RenumberIDs(func(id int64) int64 { + if id == updated.ID() { + return target.ID() + } + return id + }) + } + }) + for _, call := range opt.sourceInfo.MacroCalls() { + ast.PostOrderVisit(call, macroVisitor) + } +} + +func (opt *optimizerExprFactory) sanitizeMacro(macroID int64, macroExpr ast.Expr) { + macroRefVisitor := ast.NewExprVisitor(func(e ast.Expr) { + if _, exists := opt.sourceInfo.GetMacroCall(e.ID()); exists && e.ID() != macroID { + e.SetKindCase(nil) + } + }) + ast.PostOrderVisit(macroExpr, macroRefVisitor) +} diff --git a/vendor/github.com/authzed/cel-go/cel/options.go b/vendor/github.com/authzed/cel-go/cel/options.go new file mode 100644 index 0000000..6da048a --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/options.go @@ -0,0 +1,667 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/dynamicpb" + + "github.com/authzed/cel-go/checker" + "github.com/authzed/cel-go/common/containers" + "github.com/authzed/cel-go/common/functions" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/pb" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/interpreter" + "github.com/authzed/cel-go/parser" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + descpb "google.golang.org/protobuf/types/descriptorpb" +) + +// These constants beginning with "Feature" enable optional behavior in +// the library. See the documentation for each constant to see its +// effects, compatibility restrictions, and standard conformance. +const ( + _ = iota + + // Enable the tracking of function call expressions replaced by macros. + featureEnableMacroCallTracking + + // Enable the use of cross-type numeric comparisons at the type-checker. + featureCrossTypeNumericComparisons + + // Enable eager validation of declarations to ensure that Env values created + // with `Extend` inherit a validated list of declarations from the parent Env. + featureEagerlyValidateDeclarations + + // Enable the use of the default UTC timezone when a timezone is not specified + // on a CEL timestamp operation. This fixes the scenario where the input time + // is not already in UTC. + featureDefaultUTCTimeZone + + // Enable the serialization of logical operator ASTs as variadic calls, thus + // compressing the logic graph to a single call when multiple like-operator + // expressions occur: e.g. a && b && c && d -> call(_&&_, [a, b, c, d]) + featureVariadicLogicalASTs + + // Enable error generation when a presence test or optional field selection is + // performed on a primitive type. + featureEnableErrorOnBadPresenceTest +) + +// EnvOption is a functional interface for configuring the environment. +type EnvOption func(e *Env) (*Env, error) + +// ClearMacros options clears all parser macros. +// +// Clearing macros will ensure CEL expressions can only contain linear evaluation paths, as +// comprehensions such as `all` and `exists` are enabled only via macros. +func ClearMacros() EnvOption { + return func(e *Env) (*Env, error) { + e.macros = NoMacros + return e, nil + } +} + +// CustomTypeAdapter swaps the default types.Adapter implementation with a custom one. +// +// Note: This option must be specified before the Types and TypeDescs options when used together. +func CustomTypeAdapter(adapter types.Adapter) EnvOption { + return func(e *Env) (*Env, error) { + e.adapter = adapter + return e, nil + } +} + +// CustomTypeProvider replaces the types.Provider implementation with a custom one. +// +// The `provider` variable type may either be types.Provider or ref.TypeProvider (deprecated) +// +// Note: This option must be specified before the Types and TypeDescs options when used together. +func CustomTypeProvider(provider any) EnvOption { + return func(e *Env) (*Env, error) { + var err error + e.provider, err = maybeInteropProvider(provider) + return e, err + } +} + +// Declarations option extends the declaration set configured in the environment. +// +// Note: Declarations will by default be appended to the pre-existing declaration set configured +// for the environment. The NewEnv call builds on top of the standard CEL declarations. For a +// purely custom set of declarations use NewCustomEnv. +func Declarations(decls ...*exprpb.Decl) EnvOption { + declOpts := []EnvOption{} + var err error + var opt EnvOption + // Convert the declarations to `EnvOption` values ahead of time. + // Surface any errors in conversion when the options are applied. + for _, d := range decls { + opt, err = ExprDeclToDeclaration(d) + if err != nil { + break + } + declOpts = append(declOpts, opt) + } + return func(e *Env) (*Env, error) { + if err != nil { + return nil, err + } + for _, o := range declOpts { + e, err = o(e) + if err != nil { + return nil, err + } + } + return e, nil + } +} + +// EagerlyValidateDeclarations ensures that any collisions between configured declarations are caught +// at the time of the `NewEnv` call. +// +// Eagerly validating declarations is also useful for bootstrapping a base `cel.Env` value. +// Calls to base `Env.Extend()` will be significantly faster when declarations are eagerly validated +// as declarations will be collision-checked at most once and only incrementally by way of `Extend` +// +// Disabled by default as not all environments are used for type-checking. +func EagerlyValidateDeclarations(enabled bool) EnvOption { + return features(featureEagerlyValidateDeclarations, enabled) +} + +// HomogeneousAggregateLiterals disables mixed type list and map literal values. +// +// Note, it is still possible to have heterogeneous aggregates when provided as variables to the +// expression, as well as via conversion of well-known dynamic types, or with unchecked +// expressions. +func HomogeneousAggregateLiterals() EnvOption { + return ASTValidators(ValidateHomogeneousAggregateLiterals()) +} + +// variadicLogicalOperatorASTs flatten like-operator chained logical expressions into a single +// variadic call with N-terms. This behavior is useful when serializing to a protocol buffer as +// it will reduce the number of recursive calls needed to deserialize the AST later. +// +// For example, given the following expression the call graph will be rendered accordingly: +// +// expression: a && b && c && (d || e) +// ast: call(_&&_, [a, b, c, call(_||_, [d, e])]) +func variadicLogicalOperatorASTs() EnvOption { + return features(featureVariadicLogicalASTs, true) +} + +// Macros option extends the macro set configured in the environment. +// +// Note: This option must be specified after ClearMacros if used together. +func Macros(macros ...Macro) EnvOption { + return func(e *Env) (*Env, error) { + e.macros = append(e.macros, macros...) + return e, nil + } +} + +// Container sets the container for resolving variable names. Defaults to an empty container. +// +// If all references within an expression are relative to a protocol buffer package, then +// specifying a container of `google.type` would make it possible to write expressions such as +// `Expr{expression: 'a < b'}` instead of having to write `google.type.Expr{...}`. +func Container(name string) EnvOption { + return func(e *Env) (*Env, error) { + cont, err := e.Container.Extend(containers.Name(name)) + if err != nil { + return nil, err + } + e.Container = cont + return e, nil + } +} + +// Abbrevs configures a set of simple names as abbreviations for fully-qualified names. +// +// An abbreviation (abbrev for short) is a simple name that expands to a fully-qualified name. +// Abbreviations can be useful when working with variables, functions, and especially types from +// multiple namespaces: +// +// // CEL object construction +// qual.pkg.version.ObjTypeName{ +// field: alt.container.ver.FieldTypeName{value: ...} +// } +// +// Only one the qualified names above may be used as the CEL container, so at least one of these +// references must be a long qualified name within an otherwise short CEL program. Using the +// following abbreviations, the program becomes much simpler: +// +// // CEL Go option +// Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName") +// // Simplified Object construction +// ObjTypeName{field: FieldTypeName{value: ...}} +// +// There are a few rules for the qualified names and the simple abbreviations generated from them: +// - Qualified names must be dot-delimited, e.g. `package.subpkg.name`. +// - The last element in the qualified name is the abbreviation. +// - Abbreviations must not collide with each other. +// - The abbreviation must not collide with unqualified names in use. +// +// Abbreviations are distinct from container-based references in the following important ways: +// - Abbreviations must expand to a fully-qualified name. +// - Expanded abbreviations do not participate in namespace resolution. +// - Abbreviation expansion is done instead of the container search for a matching identifier. +// - Containers follow C++ namespace resolution rules with searches from the most qualified name +// +// to the least qualified name. +// +// - Container references within the CEL program may be relative, and are resolved to fully +// +// qualified names at either type-check time or program plan time, whichever comes first. +// +// If there is ever a case where an identifier could be in both the container and as an +// abbreviation, the abbreviation wins as this will ensure that the meaning of a program is +// preserved between compilations even as the container evolves. +func Abbrevs(qualifiedNames ...string) EnvOption { + return func(e *Env) (*Env, error) { + cont, err := e.Container.Extend(containers.Abbrevs(qualifiedNames...)) + if err != nil { + return nil, err + } + e.Container = cont + return e, nil + } +} + +// customTypeRegistry is an internal-only interface containing the minimum methods required to support +// custom types. It is a subset of methods from ref.TypeRegistry. +type customTypeRegistry interface { + RegisterDescriptor(protoreflect.FileDescriptor) error + RegisterType(...ref.Type) error +} + +// Types adds one or more type declarations to the environment, allowing for construction of +// type-literals whose definitions are included in the common expression built-in set. +// +// The input types may either be instances of `proto.Message` or `ref.Type`. Any other type +// provided to this option will result in an error. +// +// Well-known protobuf types within the `google.protobuf.*` package are included in the standard +// environment by default. +// +// Note: This option must be specified after the CustomTypeProvider option when used together. +func Types(addTypes ...any) EnvOption { + return func(e *Env) (*Env, error) { + reg, isReg := e.provider.(customTypeRegistry) + if !isReg { + return nil, fmt.Errorf("custom types not supported by provider: %T", e.provider) + } + for _, t := range addTypes { + switch v := t.(type) { + case proto.Message: + fdMap := pb.CollectFileDescriptorSet(v) + for _, fd := range fdMap { + err := reg.RegisterDescriptor(fd) + if err != nil { + return nil, err + } + } + case ref.Type: + err := reg.RegisterType(v) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported type: %T", t) + } + } + return e, nil + } +} + +// TypeDescs adds type declarations from any protoreflect.FileDescriptor, protoregistry.Files, +// google.protobuf.FileDescriptorProto or google.protobuf.FileDescriptorSet provided. +// +// Note that messages instantiated from these descriptors will be *dynamicpb.Message values +// rather than the concrete message type. +// +// TypeDescs are hermetic to a single Env object, but may be copied to other Env values via +// extension or by re-using the same EnvOption with another NewEnv() call. +func TypeDescs(descs ...any) EnvOption { + return func(e *Env) (*Env, error) { + reg, isReg := e.provider.(customTypeRegistry) + if !isReg { + return nil, fmt.Errorf("custom types not supported by provider: %T", e.provider) + } + // Scan the input descriptors for FileDescriptorProto messages and accumulate them into a + // synthetic FileDescriptorSet as the FileDescriptorProto messages may refer to each other + // and will not resolve properly unless they are part of the same set. + var fds *descpb.FileDescriptorSet + for _, d := range descs { + switch f := d.(type) { + case *descpb.FileDescriptorProto: + if fds == nil { + fds = &descpb.FileDescriptorSet{ + File: []*descpb.FileDescriptorProto{}, + } + } + fds.File = append(fds.File, f) + } + } + if fds != nil { + if err := registerFileSet(reg, fds); err != nil { + return nil, err + } + } + for _, d := range descs { + switch f := d.(type) { + case *protoregistry.Files: + if err := registerFiles(reg, f); err != nil { + return nil, err + } + case protoreflect.FileDescriptor: + if err := reg.RegisterDescriptor(f); err != nil { + return nil, err + } + case *descpb.FileDescriptorSet: + if err := registerFileSet(reg, f); err != nil { + return nil, err + } + case *descpb.FileDescriptorProto: + // skip, handled as a synthetic file descriptor set. + default: + return nil, fmt.Errorf("unsupported type descriptor: %T", d) + } + } + return e, nil + } +} + +func registerFileSet(reg customTypeRegistry, fileSet *descpb.FileDescriptorSet) error { + files, err := protodesc.NewFiles(fileSet) + if err != nil { + return fmt.Errorf("protodesc.NewFiles(%v) failed: %v", fileSet, err) + } + return registerFiles(reg, files) +} + +func registerFiles(reg customTypeRegistry, files *protoregistry.Files) error { + var err error + files.RangeFiles(func(fd protoreflect.FileDescriptor) bool { + err = reg.RegisterDescriptor(fd) + return err == nil + }) + return err +} + +// ProgramOption is a functional interface for configuring evaluation bindings and behaviors. +type ProgramOption func(p *prog) (*prog, error) + +// CustomDecorator appends an InterpreterDecorator to the program. +// +// InterpretableDecorators can be used to inspect, alter, or replace the Program plan. +func CustomDecorator(dec interpreter.InterpretableDecorator) ProgramOption { + return func(p *prog) (*prog, error) { + p.decorators = append(p.decorators, dec) + return p, nil + } +} + +// Functions adds function overloads that extend or override the set of CEL built-ins. +// +// Deprecated: use Function() instead to declare the function, its overload signatures, +// and the overload implementations. +func Functions(funcs ...*functions.Overload) ProgramOption { + return func(p *prog) (*prog, error) { + if err := p.dispatcher.Add(funcs...); err != nil { + return nil, err + } + return p, nil + } +} + +// Globals sets the global variable values for a given program. These values may be shadowed by +// variables with the same name provided to the Eval() call. If Globals is used in a Library with +// a Lib EnvOption, vars may shadow variables provided by previously added libraries. +// +// The vars value may either be an `interpreter.Activation` instance or a `map[string]any`. +func Globals(vars any) ProgramOption { + return func(p *prog) (*prog, error) { + defaultVars, err := interpreter.NewActivation(vars) + if err != nil { + return nil, err + } + if p.defaultVars != nil { + defaultVars = interpreter.NewHierarchicalActivation(p.defaultVars, defaultVars) + } + p.defaultVars = defaultVars + return p, nil + } +} + +// OptimizeRegex provides a way to replace the InterpretableCall for regex functions. This can be used +// to compile regex string constants at program creation time and report any errors and then use the +// compiled regex for all regex function invocations. +func OptimizeRegex(regexOptimizations ...*interpreter.RegexOptimization) ProgramOption { + return func(p *prog) (*prog, error) { + p.regexOptimizations = append(p.regexOptimizations, regexOptimizations...) + return p, nil + } +} + +// EvalOption indicates an evaluation option that may affect the evaluation behavior or information +// in the output result. +type EvalOption int + +const ( + // OptTrackState will cause the runtime to return an immutable EvalState value in the Result. + OptTrackState EvalOption = 1 << iota + + // OptExhaustiveEval causes the runtime to disable short-circuits and track state. + OptExhaustiveEval EvalOption = 1<<iota | OptTrackState + + // OptOptimize precomputes functions and operators with constants as arguments at program + // creation time. It also pre-compiles regex pattern constants passed to 'matches', reports any compilation errors + // at program creation and uses the compiled regex pattern for all 'matches' function invocations. + // This flag is useful when the expression will be evaluated repeatedly against + // a series of different inputs. + OptOptimize EvalOption = 1 << iota + + // OptPartialEval enables the evaluation of a partial state where the input data that may be + // known to be missing, either as top-level variables, or somewhere within a variable's object + // member graph. + // + // By itself, OptPartialEval does not change evaluation behavior unless the input to the + // Program Eval() call is created via PartialVars(). + OptPartialEval EvalOption = 1 << iota + + // OptTrackCost enables the runtime cost calculation while validation and return cost within evalDetails + // cost calculation is available via func ActualCost() + OptTrackCost EvalOption = 1 << iota + + // OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality. + // + // Deprecated: use ext.StringsValidateFormatCalls() as this option is now a no-op. + OptCheckStringFormat EvalOption = 1 << iota +) + +// EvalOptions sets one or more evaluation options which may affect the evaluation or Result. +func EvalOptions(opts ...EvalOption) ProgramOption { + return func(p *prog) (*prog, error) { + for _, opt := range opts { + p.evalOpts |= opt + } + return p, nil + } +} + +// InterruptCheckFrequency configures the number of iterations within a comprehension to evaluate +// before checking whether the function evaluation has been interrupted. +func InterruptCheckFrequency(checkFrequency uint) ProgramOption { + return func(p *prog) (*prog, error) { + p.interruptCheckFrequency = checkFrequency + return p, nil + } +} + +// CostEstimatorOptions configure type-check time options for estimating expression cost. +func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption { + return func(e *Env) (*Env, error) { + e.costOptions = append(e.costOptions, costOpts...) + return e, nil + } +} + +// CostTrackerOptions configures a set of options for cost-tracking. +// +// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled. +func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption { + return func(p *prog) (*prog, error) { + p.costOptions = append(p.costOptions, costOpts...) + return p, nil + } +} + +// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls. +func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption { + return func(p *prog) (*prog, error) { + p.callCostEstimator = costEstimator + p.evalOpts |= OptTrackCost + return p, nil + } +} + +// CostLimit enables cost tracking and sets configures program evaluation to exit early with a +// "runtime cost limit exceeded" error if the runtime cost exceeds the costLimit. +// The CostLimit is a metric that corresponds to the number and estimated expense of operations +// performed while evaluating an expression. It is indicative of CPU usage, not memory usage. +func CostLimit(costLimit uint64) ProgramOption { + return func(p *prog) (*prog, error) { + p.costLimit = &costLimit + p.evalOpts |= OptTrackCost + return p, nil + } +} + +func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) { + if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind { + msgName := (string)(field.Message().FullName()) + return ObjectType(msgName), nil + } + if primitiveType, found := types.ProtoCELPrimitives[field.Kind()]; found { + return primitiveType, nil + } + if field.Kind() == protoreflect.EnumKind { + return IntType, nil + } + return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String()) +} + +func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) { + name := string(field.Name()) + if field.IsMap() { + mapKey := field.MapKey() + mapValue := field.MapValue() + keyType, err := fieldToCELType(mapKey) + if err != nil { + return nil, err + } + valueType, err := fieldToCELType(mapValue) + if err != nil { + return nil, err + } + return Variable(name, MapType(keyType, valueType)), nil + } + if field.IsList() { + elemType, err := fieldToCELType(field) + if err != nil { + return nil, err + } + return Variable(name, ListType(elemType)), nil + } + celType, err := fieldToCELType(field) + if err != nil { + return nil, err + } + return Variable(name, celType), nil +} + +// DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto. +// Each field of the proto defines a variable of the same name in the environment. +// https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment +func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption { + return func(e *Env) (*Env, error) { + fields := descriptor.Fields() + for i := 0; i < fields.Len(); i++ { + field := fields.Get(i) + variable, err := fieldToVariable(field) + if err != nil { + return nil, err + } + e, err = variable(e) + if err != nil { + return nil, err + } + } + return Types(dynamicpb.NewMessage(descriptor))(e) + } +} + +// ContextProtoVars uses the fields of the input proto.Messages as top-level variables within an Activation. +// +// Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using +// protocol buffers. +func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) { + if ctx == nil || !ctx.ProtoReflect().IsValid() { + return interpreter.EmptyActivation(), nil + } + reg, err := types.NewRegistry(ctx) + if err != nil { + return nil, err + } + pbRef := ctx.ProtoReflect() + typeName := string(pbRef.Descriptor().FullName()) + fields := pbRef.Descriptor().Fields() + vars := make(map[string]any, fields.Len()) + for i := 0; i < fields.Len(); i++ { + field := fields.Get(i) + sft, found := reg.FindStructFieldType(typeName, field.TextName()) + if !found { + return nil, fmt.Errorf("no such field: %s", field.TextName()) + } + fieldVal, err := sft.GetFrom(ctx) + if err != nil { + return nil, err + } + vars[field.TextName()] = fieldVal + } + return interpreter.NewActivation(vars) +} + +// EnableMacroCallTracking ensures that call expressions which are replaced by macros +// are tracked in the `SourceInfo` of parsed and checked expressions. +func EnableMacroCallTracking() EnvOption { + return features(featureEnableMacroCallTracking, true) +} + +// CrossTypeNumericComparisons makes it possible to compare across numeric types, e.g. double < int +func CrossTypeNumericComparisons(enabled bool) EnvOption { + return features(featureCrossTypeNumericComparisons, enabled) +} + +// DefaultUTCTimeZone ensures that time-based operations use the UTC timezone rather than the +// input time's local timezone. +func DefaultUTCTimeZone(enabled bool) EnvOption { + return features(featureDefaultUTCTimeZone, enabled) +} + +// features sets the given feature flags. See list of Feature constants above. +func features(flag int, enabled bool) EnvOption { + return func(e *Env) (*Env, error) { + e.features[flag] = enabled + return e, nil + } +} + +// ParserRecursionLimit adjusts the AST depth the parser will tolerate. +// Defaults defined in the parser package. +func ParserRecursionLimit(limit int) EnvOption { + return func(e *Env) (*Env, error) { + e.prsrOpts = append(e.prsrOpts, parser.MaxRecursionDepth(limit)) + return e, nil + } +} + +// ParserExpressionSizeLimit adjusts the number of code points the expression parser is allowed to parse. +// Defaults defined in the parser package. +func ParserExpressionSizeLimit(limit int) EnvOption { + return func(e *Env) (*Env, error) { + e.prsrOpts = append(e.prsrOpts, parser.ExpressionSizeCodePointLimit(limit)) + return e, nil + } +} + +func maybeInteropProvider(provider any) (types.Provider, error) { + switch p := provider.(type) { + case types.Provider: + return p, nil + case ref.TypeProvider: + return &interopCELTypeProvider{TypeProvider: p}, nil + default: + return nil, fmt.Errorf("unsupported type provider: %T", provider) + } +} diff --git a/vendor/github.com/authzed/cel-go/cel/program.go b/vendor/github.com/authzed/cel-go/cel/program.go new file mode 100644 index 0000000..79ba9ed --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/program.go @@ -0,0 +1,545 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "context" + "fmt" + "sync" + + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/interpreter" +) + +// Program is an evaluable view of an Ast. +type Program interface { + // Eval returns the result of an evaluation of the Ast and environment against the input vars. + // + // The vars value may either be an `interpreter.Activation` or a `map[string]any`. + // + // If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will + // be non-nil. Given this caveat on `details`, the return state from evaluation will be: + // + // * `val`, `details`, `nil` - Successful evaluation of a non-error result. + // * `val`, `details`, `err` - Successful evaluation to an error result. + // * `nil`, `details`, `err` - Unsuccessful evaluation. + // + // An unsuccessful evaluation is typically the result of a series of incompatible `EnvOption` + // or `ProgramOption` values used in the creation of the evaluation environment or executable + // program. + Eval(any) (ref.Val, *EvalDetails, error) + + // ContextEval evaluates the program with a set of input variables and a context object in order + // to support cancellation and timeouts. This method must be used in conjunction with the + // InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation. + // + // The vars value may either be an `interpreter.Activation` or `map[string]any`. + // + // The output contract for `ContextEval` is otherwise identical to the `Eval` method. + ContextEval(context.Context, any) (ref.Val, *EvalDetails, error) +} + +// NoVars returns an empty Activation. +func NoVars() interpreter.Activation { + return interpreter.EmptyActivation() +} + +// PartialVars returns a PartialActivation which contains variables and a set of AttributePattern +// values that indicate variables or parts of variables whose value are not yet known. +// +// This method relies on manually configured sets of missing attribute patterns. For a method which +// infers the missing variables from the input and the configured environment, use Env.PartialVars(). +// +// The `vars` value may either be an interpreter.Activation or any valid input to the +// interpreter.NewActivation call. +func PartialVars(vars any, + unknowns ...*interpreter.AttributePattern) (interpreter.PartialActivation, error) { + return interpreter.NewPartialActivation(vars, unknowns...) +} + +// AttributePattern returns an AttributePattern that matches a top-level variable. The pattern is +// mutable, and its methods support the specification of one or more qualifier patterns. +// +// For example, the AttributePattern(`a`).QualString(`b`) represents a variable access `a` with a +// string field or index qualification `b`. This pattern will match Attributes `a`, and `a.b`, +// but not `a.c`. +// +// When using a CEL expression within a container, e.g. a package or namespace, the variable name +// in the pattern must match the qualified name produced during the variable namespace resolution. +// For example, when variable `a` is declared within an expression whose container is `ns.app`, the +// fully qualified variable name may be `ns.app.a`, `ns.a`, or `a` per the CEL namespace resolution +// rules. Pick the fully qualified variable name that makes sense within the container as the +// AttributePattern `varName` argument. +// +// See the interpreter.AttributePattern and interpreter.AttributeQualifierPattern for more info +// about how to create and manipulate AttributePattern values. +func AttributePattern(varName string) *interpreter.AttributePattern { + return interpreter.NewAttributePattern(varName) +} + +// EvalDetails holds additional information observed during the Eval() call. +type EvalDetails struct { + state interpreter.EvalState + costTracker *interpreter.CostTracker +} + +// State of the evaluation, non-nil if the OptTrackState or OptExhaustiveEval is specified +// within EvalOptions. +func (ed *EvalDetails) State() interpreter.EvalState { + return ed.state +} + +// ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled. +// Otherwise, returns nil if the cost was not enabled. +func (ed *EvalDetails) ActualCost() *uint64 { + if ed == nil || ed.costTracker == nil { + return nil + } + cost := ed.costTracker.ActualCost() + return &cost +} + +// prog is the internal implementation of the Program interface. +type prog struct { + *Env + evalOpts EvalOption + defaultVars interpreter.Activation + dispatcher interpreter.Dispatcher + interpreter interpreter.Interpreter + interruptCheckFrequency uint + + // Intermediate state used to configure the InterpretableDecorator set provided + // to the initInterpretable call. + decorators []interpreter.InterpretableDecorator + regexOptimizations []*interpreter.RegexOptimization + + // Interpretable configured from an Ast and aggregate decorator set based on program options. + interpretable interpreter.Interpretable + callCostEstimator interpreter.ActualCostEstimator + costOptions []interpreter.CostTrackerOption + costLimit *uint64 +} + +func (p *prog) clone() *prog { + costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions)) + copy(costOptsCopy, p.costOptions) + + return &prog{ + Env: p.Env, + evalOpts: p.evalOpts, + defaultVars: p.defaultVars, + dispatcher: p.dispatcher, + interpreter: p.interpreter, + interruptCheckFrequency: p.interruptCheckFrequency, + } +} + +// newProgram creates a program instance with an environment, an ast, and an optional list of +// ProgramOption values. +// +// If the program cannot be configured the prog will be nil, with a non-nil error response. +func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { + // Build the dispatcher, interpreter, and default program value. + disp := interpreter.NewDispatcher() + + // Ensure the default attribute factory is set after the adapter and provider are + // configured. + p := &prog{ + Env: e, + decorators: []interpreter.InterpretableDecorator{}, + dispatcher: disp, + costOptions: []interpreter.CostTrackerOption{}, + } + + // Configure the program via the ProgramOption values. + var err error + for _, opt := range opts { + p, err = opt(p) + if err != nil { + return nil, err + } + } + + // Add the function bindings created via Function() options. + for _, fn := range e.functions { + bindings, err := fn.Bindings() + if err != nil { + return nil, err + } + err = disp.Add(bindings...) + if err != nil { + return nil, err + } + } + + // Set the attribute factory after the options have been set. + var attrFactory interpreter.AttributeFactory + attrFactorOpts := []interpreter.AttrFactoryOption{ + interpreter.EnableErrorOnBadPresenceTest(p.HasFeature(featureEnableErrorOnBadPresenceTest)), + } + if p.evalOpts&OptPartialEval == OptPartialEval { + attrFactory = interpreter.NewPartialAttributeFactory(e.Container, e.adapter, e.provider, attrFactorOpts...) + } else { + attrFactory = interpreter.NewAttributeFactory(e.Container, e.adapter, e.provider, attrFactorOpts...) + } + interp := interpreter.NewInterpreter(disp, e.Container, e.provider, e.adapter, attrFactory) + p.interpreter = interp + + // Translate the EvalOption flags into InterpretableDecorator instances. + decorators := make([]interpreter.InterpretableDecorator, len(p.decorators)) + copy(decorators, p.decorators) + + // Enable interrupt checking if there's a non-zero check frequency + if p.interruptCheckFrequency > 0 { + decorators = append(decorators, interpreter.InterruptableEval()) + } + // Enable constant folding first. + if p.evalOpts&OptOptimize == OptOptimize { + decorators = append(decorators, interpreter.Optimize()) + p.regexOptimizations = append(p.regexOptimizations, interpreter.MatchesRegexOptimization) + } + // Enable regex compilation of constants immediately after folding constants. + if len(p.regexOptimizations) > 0 { + decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...)) + } + + // Enable exhaustive eval, state tracking and cost tracking last since they require a factory. + if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 { + factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) { + costTracker.Estimator = p.callCostEstimator + costTracker.Limit = p.costLimit + for _, costOpt := range p.costOptions { + err := costOpt(costTracker) + if err != nil { + return nil, err + } + } + // Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This + // prevents the underlying memory from being shared between factory function calls causing + // undesired mutations. + decs := decorators[:len(decorators):len(decorators)] + var observers []interpreter.EvalObserver + + if p.evalOpts&(OptExhaustiveEval|OptTrackState) != 0 { + // EvalStateObserver is required for OptExhaustiveEval. + observers = append(observers, interpreter.EvalStateObserver(state)) + } + if p.evalOpts&OptTrackCost == OptTrackCost { + observers = append(observers, interpreter.CostObserver(costTracker)) + } + + // Enable exhaustive eval over a basic observer since it offers a superset of features. + if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval { + decs = append(decs, interpreter.ExhaustiveEval(), interpreter.Observe(observers...)) + } else if len(observers) > 0 { + decs = append(decs, interpreter.Observe(observers...)) + } + + return p.clone().initInterpretable(a, decs) + } + return newProgGen(factory) + } + return p.initInterpretable(a, decorators) +} + +func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) { + // When the AST has been exprAST it contains metadata that can be used to speed up program execution. + interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...) + if err != nil { + return nil, err + } + p.interpretable = interpretable + return p, nil +} + +// Eval implements the Program interface method. +func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) { + // Configure error recovery for unexpected panics during evaluation. Note, the use of named + // return values makes it possible to modify the error response during the recovery + // function. + defer func() { + if r := recover(); r != nil { + switch t := r.(type) { + case interpreter.EvalCancelledError: + err = t + default: + err = fmt.Errorf("internal error: %v", r) + } + } + }() + // Build a hierarchical activation if there are default vars set. + var vars interpreter.Activation + switch v := input.(type) { + case interpreter.Activation: + vars = v + case map[string]any: + vars = activationPool.Setup(v) + defer activationPool.Put(vars) + default: + return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input) + } + if p.defaultVars != nil { + vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars) + } + v = p.interpretable.Eval(vars) + // The output of an internal Eval may have a value (`v`) that is a types.Err. This step + // translates the CEL value to a Go error response. This interface does not quite match the + // RPC signature which allows for multiple errors to be returned, but should be sufficient. + if types.IsError(v) { + err = v.(*types.Err) + } + return +} + +// ContextEval implements the Program interface. +func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) { + if ctx == nil { + return nil, nil, fmt.Errorf("context can not be nil") + } + // Configure the input, making sure to wrap Activation inputs in the special ctxActivation which + // exposes the #interrupted variable and manages rate-limited checks of the ctx.Done() state. + var vars interpreter.Activation + switch v := input.(type) { + case interpreter.Activation: + vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency) + defer ctxActivationPool.Put(vars) + case map[string]any: + rawVars := activationPool.Setup(v) + defer activationPool.Put(rawVars) + vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency) + defer ctxActivationPool.Put(vars) + default: + return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input) + } + return p.Eval(vars) +} + +// progFactory is a helper alias for marking a program creation factory function. +type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error) + +// progGen holds a reference to a progFactory instance and implements the Program interface. +type progGen struct { + factory progFactory +} + +// newProgGen tests the factory object by calling it once and returns a factory-based Program if +// the test is successful. +func newProgGen(factory progFactory) (Program, error) { + // Test the factory to make sure that configuration errors are spotted at config + tracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, err + } + _, err = factory(interpreter.NewEvalState(), tracker) + if err != nil { + return nil, err + } + return &progGen{factory: factory}, nil +} + +// Eval implements the Program interface method. +func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) { + // The factory based Eval() differs from the standard evaluation model in that it generates a + // new EvalState instance for each call to ensure that unique evaluations yield unique stateful + // results. + state := interpreter.NewEvalState() + costTracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, nil, err + } + det := &EvalDetails{state: state, costTracker: costTracker} + + // Generate a new instance of the interpretable using the factory configured during the call to + // newProgram(). It is incredibly unlikely that the factory call will generate an error given + // the factory test performed within the Program() call. + p, err := gen.factory(state, costTracker) + if err != nil { + return nil, det, err + } + + // Evaluate the input, returning the result and the 'state' within EvalDetails. + v, _, err := p.Eval(input) + if err != nil { + return v, det, err + } + return v, det, nil +} + +// ContextEval implements the Program interface method. +func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) { + if ctx == nil { + return nil, nil, fmt.Errorf("context can not be nil") + } + // The factory based Eval() differs from the standard evaluation model in that it generates a + // new EvalState instance for each call to ensure that unique evaluations yield unique stateful + // results. + state := interpreter.NewEvalState() + costTracker, err := interpreter.NewCostTracker(nil) + if err != nil { + return nil, nil, err + } + det := &EvalDetails{state: state, costTracker: costTracker} + + // Generate a new instance of the interpretable using the factory configured during the call to + // newProgram(). It is incredibly unlikely that the factory call will generate an error given + // the factory test performed within the Program() call. + p, err := gen.factory(state, costTracker) + if err != nil { + return nil, det, err + } + + // Evaluate the input, returning the result and the 'state' within EvalDetails. + v, _, err := p.ContextEval(ctx, input) + if err != nil { + return v, det, err + } + return v, det, nil +} + +type ctxEvalActivation struct { + parent interpreter.Activation + interrupt <-chan struct{} + interruptCheckCount uint + interruptCheckFrequency uint +} + +// ResolveName implements the Activation interface method, but adds a special #interrupted variable +// which is capable of testing whether a 'done' signal is provided from a context.Context channel. +func (a *ctxEvalActivation) ResolveName(name string) (any, bool) { + if name == "#interrupted" { + a.interruptCheckCount++ + if a.interruptCheckCount%a.interruptCheckFrequency == 0 { + select { + case <-a.interrupt: + return true, true + default: + return nil, false + } + } + return nil, false + } + return a.parent.ResolveName(name) +} + +func (a *ctxEvalActivation) Parent() interpreter.Activation { + return a.parent +} + +func newCtxEvalActivationPool() *ctxEvalActivationPool { + return &ctxEvalActivationPool{ + Pool: sync.Pool{ + New: func() any { + return &ctxEvalActivation{} + }, + }, + } +} + +type ctxEvalActivationPool struct { + sync.Pool +} + +// Setup initializes a pooled Activation with the ability check for context.Context cancellation +func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation { + a := p.Pool.Get().(*ctxEvalActivation) + a.parent = vars + a.interrupt = done + a.interruptCheckCount = 0 + a.interruptCheckFrequency = interruptCheckRate + return a +} + +type evalActivation struct { + vars map[string]any + lazyVars map[string]any +} + +// ResolveName looks up the value of the input variable name, if found. +// +// Lazy bindings may be supplied within the map-based input in either of the following forms: +// - func() any +// - func() ref.Val +// +// The lazy binding will only be invoked once per evaluation. +// +// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using +// the types.Adapter configured in the environment. +func (a *evalActivation) ResolveName(name string) (any, bool) { + v, found := a.vars[name] + if !found { + return nil, false + } + switch obj := v.(type) { + case func() ref.Val: + if resolved, found := a.lazyVars[name]; found { + return resolved, true + } + lazy := obj() + a.lazyVars[name] = lazy + return lazy, true + case func() any: + if resolved, found := a.lazyVars[name]; found { + return resolved, true + } + lazy := obj() + a.lazyVars[name] = lazy + return lazy, true + default: + return obj, true + } +} + +// Parent implements the interpreter.Activation interface +func (a *evalActivation) Parent() interpreter.Activation { + return nil +} + +func newEvalActivationPool() *evalActivationPool { + return &evalActivationPool{ + Pool: sync.Pool{ + New: func() any { + return &evalActivation{lazyVars: make(map[string]any)} + }, + }, + } +} + +type evalActivationPool struct { + sync.Pool +} + +// Setup initializes a pooled Activation object with the map input. +func (p *evalActivationPool) Setup(vars map[string]any) *evalActivation { + a := p.Pool.Get().(*evalActivation) + a.vars = vars + return a +} + +func (p *evalActivationPool) Put(value any) { + a := value.(*evalActivation) + for k := range a.lazyVars { + delete(a.lazyVars, k) + } + p.Pool.Put(a) +} + +var ( + // activationPool is an internally managed pool of Activation values that wrap map[string]any inputs + activationPool = newEvalActivationPool() + + // ctxActivationPool is an internally managed pool of Activation values that expose a special #interrupted variable + ctxActivationPool = newCtxEvalActivationPool() +) diff --git a/vendor/github.com/authzed/cel-go/cel/validator.go b/vendor/github.com/authzed/cel-go/cel/validator.go new file mode 100644 index 0000000..1664c13 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/cel/validator.go @@ -0,0 +1,375 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + "reflect" + "regexp" + + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/overloads" +) + +const ( + homogeneousValidatorName = "cel.lib.std.validate.types.homogeneous" + + // HomogeneousAggregateLiteralExemptFunctions is the ValidatorConfig key used to configure + // the set of function names which are exempt from homogeneous type checks. The expected type + // is a string list of function names. + // + // As an example, the `<string>.format([args])` call expects the input arguments list to be + // comprised of a variety of types which correspond to the types expected by the format control + // clauses; however, all other uses of a mixed element type list, would be unexpected. + HomogeneousAggregateLiteralExemptFunctions = homogeneousValidatorName + ".exempt" +) + +// ASTValidators configures a set of ASTValidator instances into the target environment. +// +// Validators are applied in the order in which the are specified and are treated as singletons. +// The same ASTValidator with a given name will not be applied more than once. +func ASTValidators(validators ...ASTValidator) EnvOption { + return func(e *Env) (*Env, error) { + for _, v := range validators { + if !e.HasValidator(v.Name()) { + e.validators = append(e.validators, v) + } + } + return e, nil + } +} + +// ASTValidator defines a singleton interface for validating a type-checked Ast against an environment. +// +// Note: the Issues argument is mutable in the sense that it is intended to collect errors which will be +// reported to the caller. +type ASTValidator interface { + // Name returns the name of the validator. Names must be unique. + Name() string + + // Validate validates a given Ast within an Environment and collects a set of potential issues. + // + // The ValidatorConfig is generated from the set of ASTValidatorConfigurer instances prior to + // the invocation of the Validate call. The expectation is that the validator configuration + // is created in sequence and immutable once provided to the Validate call. + // + // See individual validators for more information on their configuration keys and configuration + // properties. + Validate(*Env, ValidatorConfig, *ast.AST, *Issues) +} + +// ValidatorConfig provides an accessor method for querying validator configuration state. +type ValidatorConfig interface { + GetOrDefault(name string, value any) any +} + +// MutableValidatorConfig provides mutation methods for querying and updating validator configuration +// settings. +type MutableValidatorConfig interface { + ValidatorConfig + Set(name string, value any) error +} + +// ASTValidatorConfigurer indicates that this object, currently expected to be an ASTValidator, +// participates in validator configuration settings. +// +// This interface may be split from the expectation of being an ASTValidator instance in the future. +type ASTValidatorConfigurer interface { + Configure(MutableValidatorConfig) error +} + +// validatorConfig implements the ValidatorConfig and MutableValidatorConfig interfaces. +type validatorConfig struct { + data map[string]any +} + +// newValidatorConfig initializes the validator config with default values for core CEL validators. +func newValidatorConfig() *validatorConfig { + return &validatorConfig{ + data: map[string]any{ + HomogeneousAggregateLiteralExemptFunctions: []string{}, + }, + } +} + +// GetOrDefault returns the configured value for the name, if present, else the input default value. +// +// Note, the type-agreement between the input default and configured value is not checked on read. +func (config *validatorConfig) GetOrDefault(name string, value any) any { + v, found := config.data[name] + if !found { + return value + } + return v +} + +// Set configures a validator option with the given name and value. +// +// If the value had previously been set, the new value must have the same reflection type as the old one, +// or the call will error. +func (config *validatorConfig) Set(name string, value any) error { + v, found := config.data[name] + if found && reflect.TypeOf(v) != reflect.TypeOf(value) { + return fmt.Errorf("incompatible configuration type for %s, got %T, wanted %T", name, value, v) + } + config.data[name] = value + return nil +} + +// ExtendedValidations collects a set of common AST validations which reduce the likelihood of runtime errors. +// +// - Validate duration and timestamp literals +// - Ensure regex strings are valid +// - Disable mixed type list and map literals +func ExtendedValidations() EnvOption { + return ASTValidators( + ValidateDurationLiterals(), + ValidateTimestampLiterals(), + ValidateRegexLiterals(), + ValidateHomogeneousAggregateLiterals(), + ) +} + +// ValidateDurationLiterals ensures that duration literal arguments are valid immediately after type-check. +func ValidateDurationLiterals() ASTValidator { + return newFormatValidator(overloads.TypeConvertDuration, 0, evalCall) +} + +// ValidateTimestampLiterals ensures that timestamp literal arguments are valid immediately after type-check. +func ValidateTimestampLiterals() ASTValidator { + return newFormatValidator(overloads.TypeConvertTimestamp, 0, evalCall) +} + +// ValidateRegexLiterals ensures that regex patterns are validated after type-check. +func ValidateRegexLiterals() ASTValidator { + return newFormatValidator(overloads.Matches, 0, compileRegex) +} + +// ValidateHomogeneousAggregateLiterals checks that all list and map literals entries have the same types, i.e. +// no mixed list element types or mixed map key or map value types. +// +// Note: the string format call relies on a mixed element type list for ease of use, so this check skips all +// literals which occur within string format calls. +func ValidateHomogeneousAggregateLiterals() ASTValidator { + return homogeneousAggregateLiteralValidator{} +} + +// ValidateComprehensionNestingLimit ensures that comprehension nesting does not exceed the specified limit. +// +// This validator can be useful for preventing arbitrarily nested comprehensions which can take high polynomial +// time to complete. +// +// Note, this limit does not apply to comprehensions with an empty iteration range, as these comprehensions have +// no actual looping cost. The cel.bind() utilizes the comprehension structure to perform local variable +// assignments and supplies an empty iteration range, so they won't count against the nesting limit either. +func ValidateComprehensionNestingLimit(limit int) ASTValidator { + return nestingLimitValidator{limit: limit} +} + +type argChecker func(env *Env, call, arg ast.Expr) error + +func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator { + return formatValidator{ + funcName: funcName, + check: check, + argNum: argNum, + } +} + +type formatValidator struct { + funcName string + argNum int + check argChecker +} + +// Name returns the unique name of this function format validator. +func (v formatValidator) Name() string { + return fmt.Sprintf("cel.lib.std.validate.functions.%s", v.funcName) +} + +// Validate searches the AST for uses of a given function name with a constant argument and performs a check +// on whether the argument is a valid literal value. +func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) { + root := ast.NavigateAST(a) + funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName)) + for _, call := range funcCalls { + callArgs := call.AsCall().Args() + if len(callArgs) <= v.argNum { + continue + } + litArg := callArgs[v.argNum] + if litArg.Kind() != ast.LiteralKind { + continue + } + if err := v.check(e, call, litArg); err != nil { + iss.ReportErrorAtID(litArg.ID(), "invalid %s argument", v.funcName) + } + } +} + +func evalCall(env *Env, call, arg ast.Expr) error { + ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))} + prg, err := env.Program(ast) + if err != nil { + return err + } + _, _, err = prg.Eval(NoVars()) + return err +} + +func compileRegex(_ *Env, _, arg ast.Expr) error { + pattern := arg.AsLiteral().Value().(string) + _, err := regexp.Compile(pattern) + return err +} + +type homogeneousAggregateLiteralValidator struct{} + +// Name returns the unique name of the homogeneous type validator. +func (homogeneousAggregateLiteralValidator) Name() string { + return homogeneousValidatorName +} + +// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types. +// +// This validator makes an exception for list and map literals which occur at any level of nesting within +// string format calls. +func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) { + var exemptedFunctions []string + exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string) + root := ast.NavigateAST(a) + listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind)) + for _, listExpr := range listExprs { + if inExemptFunction(listExpr, exemptedFunctions) { + continue + } + l := listExpr.AsList() + elements := l.Elements() + optIndices := l.OptionalIndices() + var elemType *Type + for i, e := range elements { + et := a.GetType(e.ID()) + if isOptionalIndex(i, optIndices) { + et = et.Parameters()[0] + } + if elemType == nil { + elemType = et + continue + } + if !elemType.IsEquivalentType(et) { + v.typeMismatch(iss, e.ID(), elemType, et) + break + } + } + } + mapExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.MapKind)) + for _, mapExpr := range mapExprs { + if inExemptFunction(mapExpr, exemptedFunctions) { + continue + } + m := mapExpr.AsMap() + entries := m.Entries() + var keyType, valType *Type + for _, e := range entries { + mapEntry := e.AsMapEntry() + key, val := mapEntry.Key(), mapEntry.Value() + kt, vt := a.GetType(key.ID()), a.GetType(val.ID()) + if mapEntry.IsOptional() { + vt = vt.Parameters()[0] + } + if keyType == nil && valType == nil { + keyType, valType = kt, vt + continue + } + if !keyType.IsEquivalentType(kt) { + v.typeMismatch(iss, key.ID(), keyType, kt) + } + if !valType.IsEquivalentType(vt) { + v.typeMismatch(iss, val.ID(), valType, vt) + } + } + } +} + +func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool { + parent, found := e.Parent() + for found { + if parent.Kind() == ast.CallKind { + fnName := parent.AsCall().FunctionName() + for _, exempt := range exemptFunctions { + if exempt == fnName { + return true + } + } + } + parent, found = parent.Parent() + } + return false +} + +func isOptionalIndex(i int, optIndices []int32) bool { + for _, optInd := range optIndices { + if i == int(optInd) { + return true + } + } + return false +} + +func (homogeneousAggregateLiteralValidator) typeMismatch(iss *Issues, id int64, expected, actual *Type) { + iss.ReportErrorAtID(id, "expected type '%s' but found '%s'", FormatCELType(expected), FormatCELType(actual)) +} + +type nestingLimitValidator struct { + limit int +} + +func (v nestingLimitValidator) Name() string { + return "cel.lib.std.validate.comprehension_nesting_limit" +} + +func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) { + root := ast.NavigateAST(a) + comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) + if len(comprehensions) <= v.limit { + return + } + for _, comp := range comprehensions { + count := 0 + e := comp + hasParent := true + for hasParent { + // When the expression is not a comprehension, continue to the next ancestor. + if e.Kind() != ast.ComprehensionKind { + e, hasParent = e.Parent() + continue + } + // When the comprehension has an empty range, continue to the next ancestor + // as this comprehension does not have any associated cost. + iterRange := e.AsComprehension().IterRange() + if iterRange.Kind() == ast.ListKind && iterRange.AsList().Size() == 0 { + e, hasParent = e.Parent() + continue + } + // Otherwise check the nesting limit. + count++ + if count > v.limit { + iss.ReportErrorAtID(comp.ID(), "comprehension exceeds nesting limit") + break + } + e, hasParent = e.Parent() + } + } +} |
