diff options
Diffstat (limited to 'vendor/github.com/authzed/cel-go/interpreter')
13 files changed, 5669 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/cel-go/interpreter/BUILD.bazel b/vendor/github.com/authzed/cel-go/interpreter/BUILD.bazel new file mode 100644 index 0000000..9778a62 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/BUILD.bazel @@ -0,0 +1,74 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +go_library( + name = "go_default_library", + srcs = [ + "activation.go", + "attribute_patterns.go", + "attributes.go", + "decorators.go", + "dispatcher.go", + "evalstate.go", + "interpretable.go", + "interpreter.go", + "optimizations.go", + "planner.go", + "prune.go", + "runtimecost.go", + ], + importpath = "github.com/authzed/cel-go/interpreter", + deps = [ + "//common:go_default_library", + "//common/ast:go_default_library", + "//common/containers:go_default_library", + "//common/functions:go_default_library", + "//common/operators:go_default_library", + "//common/overloads:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//common/types/traits:go_default_library", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//types/known/durationpb:go_default_library", + "@org_golang_google_protobuf//types/known/structpb:go_default_library", + "@org_golang_google_protobuf//types/known/timestamppb:go_default_library", + "@org_golang_google_protobuf//types/known/wrapperspb:go_default_library", + ], +) + +go_test( + name = "go_default_test", + srcs = [ + "activation_test.go", + "attribute_patterns_test.go", + "attributes_test.go", + "interpreter_test.go", + "prune_test.go", + "runtimecost_test.go", + ], + embed = [ + ":go_default_library", + ], + deps = [ + "//checker:go_default_library", + "//common/containers:go_default_library", + "//common/debug:go_default_library", + "//common/decls:go_default_library", + "//common/functions:go_default_library", + "//common/operators:go_default_library", + "//common/stdlib:go_default_library", + "//common/types:go_default_library", + "//parser:go_default_library", + "//test:go_default_library", + "//test/proto2pb:go_default_library", + "//test/proto3pb:go_default_library", + "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//types/known/anypb:go_default_library", + ], +) diff --git a/vendor/github.com/authzed/cel-go/interpreter/activation.go b/vendor/github.com/authzed/cel-go/interpreter/activation.go new file mode 100644 index 0000000..1d1bc5a --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/activation.go @@ -0,0 +1,201 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "errors" + "fmt" + "sync" + + "github.com/authzed/cel-go/common/types/ref" +) + +// Activation used to resolve identifiers by name and references by id. +// +// An Activation is the primary mechanism by which a caller supplies input into a CEL program. +type Activation interface { + // ResolveName returns a value from the activation by qualified name, or false if the name + // could not be found. + ResolveName(name string) (any, bool) + + // Parent returns the parent of the current activation, may be nil. + // If non-nil, the parent will be searched during resolve calls. + Parent() Activation +} + +// EmptyActivation returns a variable-free activation. +func EmptyActivation() Activation { + return emptyActivation{} +} + +// emptyActivation is a variable-free activation. +type emptyActivation struct{} + +func (emptyActivation) ResolveName(string) (any, bool) { return nil, false } +func (emptyActivation) Parent() Activation { return nil } + +// NewActivation returns an activation based on a map-based binding where the map keys are +// expected to be qualified names used with ResolveName calls. +// +// The input `bindings` may either be of type `Activation` or `map[string]any`. +// +// Lazy bindings may be supplied within the map-based input in either of the following forms: +// - func() any +// - func() ref.Val +// +// The output of the lazy binding will overwrite the variable reference in the internal map. +// +// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using +// the types.Adapter configured in the environment. +func NewActivation(bindings any) (Activation, error) { + if bindings == nil { + return nil, errors.New("bindings must be non-nil") + } + a, isActivation := bindings.(Activation) + if isActivation { + return a, nil + } + m, isMap := bindings.(map[string]any) + if !isMap { + return nil, fmt.Errorf( + "activation input must be an activation or map[string]interface: got %T", + bindings) + } + return &mapActivation{bindings: m}, nil +} + +// mapActivation which implements Activation and maps of named values. +// +// Named bindings may lazily supply values by providing a function which accepts no arguments and +// produces an interface value. +type mapActivation struct { + bindings map[string]any +} + +// Parent implements the Activation interface method. +func (a *mapActivation) Parent() Activation { + return nil +} + +// ResolveName implements the Activation interface method. +func (a *mapActivation) ResolveName(name string) (any, bool) { + obj, found := a.bindings[name] + if !found { + return nil, false + } + fn, isLazy := obj.(func() ref.Val) + if isLazy { + obj = fn() + a.bindings[name] = obj + } + fnRaw, isLazy := obj.(func() any) + if isLazy { + obj = fnRaw() + a.bindings[name] = obj + } + return obj, found +} + +// hierarchicalActivation which implements Activation and contains a parent and +// child activation. +type hierarchicalActivation struct { + parent Activation + child Activation +} + +// Parent implements the Activation interface method. +func (a *hierarchicalActivation) Parent() Activation { + return a.parent +} + +// ResolveName implements the Activation interface method. +func (a *hierarchicalActivation) ResolveName(name string) (any, bool) { + if object, found := a.child.ResolveName(name); found { + return object, found + } + return a.parent.ResolveName(name) +} + +// NewHierarchicalActivation takes two activations and produces a new one which prioritizes +// resolution in the child first and parent(s) second. +func NewHierarchicalActivation(parent Activation, child Activation) Activation { + return &hierarchicalActivation{parent, child} +} + +// NewPartialActivation returns an Activation which contains a list of AttributePattern values +// representing field and index operations that should result in a 'types.Unknown' result. +// +// The `bindings` value may be any value type supported by the interpreter.NewActivation call, +// but is typically either an existing Activation or map[string]any. +func NewPartialActivation(bindings any, + unknowns ...*AttributePattern) (PartialActivation, error) { + a, err := NewActivation(bindings) + if err != nil { + return nil, err + } + return &partActivation{Activation: a, unknowns: unknowns}, nil +} + +// PartialActivation extends the Activation interface with a set of UnknownAttributePatterns. +type PartialActivation interface { + Activation + + // UnknownAttributePaths returns a set of AttributePattern values which match Attribute + // expressions for data accesses whose values are not yet known. + UnknownAttributePatterns() []*AttributePattern +} + +// partActivation is the default implementations of the PartialActivation interface. +type partActivation struct { + Activation + unknowns []*AttributePattern +} + +// UnknownAttributePatterns implements the PartialActivation interface method. +func (a *partActivation) UnknownAttributePatterns() []*AttributePattern { + return a.unknowns +} + +// varActivation represents a single mutable variable binding. +// +// This activation type should only be used within folds as the fold loop controls the object +// life-cycle. +type varActivation struct { + parent Activation + name string + val ref.Val +} + +// Parent implements the Activation interface method. +func (v *varActivation) Parent() Activation { + return v.parent +} + +// ResolveName implements the Activation interface method. +func (v *varActivation) ResolveName(name string) (any, bool) { + if name == v.name { + return v.val, true + } + return v.parent.ResolveName(name) +} + +var ( + // pool of var activations to reduce allocations during folds. + varActivationPool = &sync.Pool{ + New: func() any { + return &varActivation{} + }, + } +) diff --git a/vendor/github.com/authzed/cel-go/interpreter/attribute_patterns.go b/vendor/github.com/authzed/cel-go/interpreter/attribute_patterns.go new file mode 100644 index 0000000..7b6ec99 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/attribute_patterns.go @@ -0,0 +1,397 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "fmt" + + "github.com/authzed/cel-go/common/containers" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" +) + +// AttributePattern represents a top-level variable with an optional set of qualifier patterns. +// +// When using a CEL expression within a container, e.g. a package or namespace, the variable name +// in the pattern must match the qualified name produced during the variable namespace resolution. +// For example, if variable `c` appears in an expression whose container is `a.b`, the variable +// name supplied to the pattern must be `a.b.c` +// +// The qualifier patterns for attribute matching must be one of the following: +// +// - valid map key type: string, int, uint, bool +// - wildcard (*) +// +// Examples: +// +// 1. ns.myvar["complex-value"] +// 2. ns.myvar["complex-value"][0] +// 3. ns.myvar["complex-value"].*.name +// +// The first example is simple: match an attribute where the variable is 'ns.myvar' with a +// field access on 'complex-value'. The second example expands the match to indicate that only +// a specific index `0` should match. And lastly, the third example matches any indexed access +// that later selects the 'name' field. +type AttributePattern struct { + variable string + qualifierPatterns []*AttributeQualifierPattern +} + +// NewAttributePattern produces a new mutable AttributePattern based on a variable name. +func NewAttributePattern(variable string) *AttributePattern { + return &AttributePattern{ + variable: variable, + qualifierPatterns: []*AttributeQualifierPattern{}, + } +} + +// QualString adds a string qualifier pattern to the AttributePattern. The string may be a valid +// identifier, or string map key including empty string. +func (apat *AttributePattern) QualString(pattern string) *AttributePattern { + apat.qualifierPatterns = append(apat.qualifierPatterns, + &AttributeQualifierPattern{value: pattern}) + return apat +} + +// QualInt adds an int qualifier pattern to the AttributePattern. The index may be either a map or +// list index. +func (apat *AttributePattern) QualInt(pattern int64) *AttributePattern { + apat.qualifierPatterns = append(apat.qualifierPatterns, + &AttributeQualifierPattern{value: pattern}) + return apat +} + +// QualUint adds an uint qualifier pattern for a map index operation to the AttributePattern. +func (apat *AttributePattern) QualUint(pattern uint64) *AttributePattern { + apat.qualifierPatterns = append(apat.qualifierPatterns, + &AttributeQualifierPattern{value: pattern}) + return apat +} + +// QualBool adds a bool qualifier pattern for a map index operation to the AttributePattern. +func (apat *AttributePattern) QualBool(pattern bool) *AttributePattern { + apat.qualifierPatterns = append(apat.qualifierPatterns, + &AttributeQualifierPattern{value: pattern}) + return apat +} + +// Wildcard adds a special sentinel qualifier pattern that will match any single qualifier. +func (apat *AttributePattern) Wildcard() *AttributePattern { + apat.qualifierPatterns = append(apat.qualifierPatterns, + &AttributeQualifierPattern{wildcard: true}) + return apat +} + +// VariableMatches returns true if the fully qualified variable matches the AttributePattern +// fully qualified variable name. +func (apat *AttributePattern) VariableMatches(variable string) bool { + return apat.variable == variable +} + +// QualifierPatterns returns the set of AttributeQualifierPattern values on the AttributePattern. +func (apat *AttributePattern) QualifierPatterns() []*AttributeQualifierPattern { + return apat.qualifierPatterns +} + +// AttributeQualifierPattern holds a wildcard or valued qualifier pattern. +type AttributeQualifierPattern struct { + wildcard bool + value any +} + +// Matches returns true if the qualifier pattern is a wildcard, or the Qualifier implements the +// qualifierValueEquator interface and its IsValueEqualTo returns true for the qualifier pattern. +func (qpat *AttributeQualifierPattern) Matches(q Qualifier) bool { + if qpat.wildcard { + return true + } + qve, ok := q.(qualifierValueEquator) + return ok && qve.QualifierValueEquals(qpat.value) +} + +// qualifierValueEquator defines an interface for determining if an input value, of valid map key +// type, is equal to the value held in the Qualifier. This interface is used by the +// AttributeQualifierPattern to determine pattern matches for non-wildcard qualifier patterns. +// +// Note: Attribute values are also Qualifier values; however, Attributes are resolved before +// qualification happens. This is an implementation detail, but one relevant to why the Attribute +// types do not surface in the list of implementations. +// +// See: partialAttributeFactory.matchesUnknownPatterns for more details on how this interface is +// used. +type qualifierValueEquator interface { + // QualifierValueEquals returns true if the input value is equal to the value held in the + // Qualifier. + QualifierValueEquals(value any) bool +} + +// QualifierValueEquals implementation for boolean qualifiers. +func (q *boolQualifier) QualifierValueEquals(value any) bool { + bval, ok := value.(bool) + return ok && q.value == bval +} + +// QualifierValueEquals implementation for field qualifiers. +func (q *fieldQualifier) QualifierValueEquals(value any) bool { + sval, ok := value.(string) + return ok && q.Name == sval +} + +// QualifierValueEquals implementation for string qualifiers. +func (q *stringQualifier) QualifierValueEquals(value any) bool { + sval, ok := value.(string) + return ok && q.value == sval +} + +// QualifierValueEquals implementation for int qualifiers. +func (q *intQualifier) QualifierValueEquals(value any) bool { + return numericValueEquals(value, q.celValue) +} + +// QualifierValueEquals implementation for uint qualifiers. +func (q *uintQualifier) QualifierValueEquals(value any) bool { + return numericValueEquals(value, q.celValue) +} + +// QualifierValueEquals implementation for double qualifiers. +func (q *doubleQualifier) QualifierValueEquals(value any) bool { + return numericValueEquals(value, q.celValue) +} + +// numericValueEquals uses CEL equality to determine whether two number values are +func numericValueEquals(value any, celValue ref.Val) bool { + val := types.DefaultTypeAdapter.NativeToValue(value) + return celValue.Equal(val) == types.True +} + +// NewPartialAttributeFactory returns an AttributeFactory implementation capable of performing +// AttributePattern matches with PartialActivation inputs. +func NewPartialAttributeFactory(container *containers.Container, adapter types.Adapter, provider types.Provider, opts ...AttrFactoryOption) AttributeFactory { + fac := NewAttributeFactory(container, adapter, provider, opts...) + return &partialAttributeFactory{ + AttributeFactory: fac, + container: container, + adapter: adapter, + provider: provider, + } +} + +type partialAttributeFactory struct { + AttributeFactory + container *containers.Container + adapter types.Adapter + provider types.Provider +} + +// AbsoluteAttribute implementation of the AttributeFactory interface which wraps the +// NamespacedAttribute resolution in an internal attributeMatcher object to dynamically match +// unknown patterns from PartialActivation inputs if given. +func (fac *partialAttributeFactory) AbsoluteAttribute(id int64, names ...string) NamespacedAttribute { + attr := fac.AttributeFactory.AbsoluteAttribute(id, names...) + return &attributeMatcher{fac: fac, NamespacedAttribute: attr} +} + +// MaybeAttribute implementation of the AttributeFactory interface which ensure that the set of +// 'maybe' NamespacedAttribute values are produced using the partialAttributeFactory rather than +// the base AttributeFactory implementation. +func (fac *partialAttributeFactory) MaybeAttribute(id int64, name string) Attribute { + return &maybeAttribute{ + id: id, + attrs: []NamespacedAttribute{ + fac.AbsoluteAttribute(id, fac.container.ResolveCandidateNames(name)...), + }, + adapter: fac.adapter, + provider: fac.provider, + fac: fac, + } +} + +// matchesUnknownPatterns returns true if the variable names and qualifiers for a given +// Attribute value match any of the ActivationPattern objects in the set of unknown activation +// patterns on the given PartialActivation. +// +// For example, in the expression `a.b`, the Attribute is composed of variable `a`, with string +// qualifier `b`. When a PartialActivation is supplied, it indicates that some or all of the data +// provided in the input is unknown by specifying unknown AttributePatterns. An AttributePattern +// that refers to variable `a` with a string qualifier of `c` will not match `a.b`; however, any +// of the following patterns will match Attribute `a.b`: +// +// - `AttributePattern("a")` +// - `AttributePattern("a").Wildcard()` +// - `AttributePattern("a").QualString("b")` +// - `AttributePattern("a").QualString("b").QualInt(0)` +// +// Any AttributePattern which overlaps an Attribute or vice-versa will produce an Unknown result +// for the last pattern matched variable or qualifier in the Attribute. In the first matching +// example, the expression id representing variable `a` would be listed in the Unknown result, +// whereas in the other pattern examples, the qualifier `b` would be returned as the Unknown. +func (fac *partialAttributeFactory) matchesUnknownPatterns( + vars PartialActivation, + attrID int64, + variableNames []string, + qualifiers []Qualifier) (*types.Unknown, error) { + patterns := vars.UnknownAttributePatterns() + candidateIndices := map[int]struct{}{} + for _, variable := range variableNames { + for i, pat := range patterns { + if pat.VariableMatches(variable) { + if len(qualifiers) == 0 { + return types.NewUnknown(attrID, types.NewAttributeTrail(variable)), nil + } + candidateIndices[i] = struct{}{} + } + } + } + // Determine whether to return early if there are no candidate unknown patterns. + if len(candidateIndices) == 0 { + return nil, nil + } + // Resolve the attribute qualifiers into a static set. This prevents more dynamic + // Attribute resolutions than necessary when there are multiple unknown patterns + // that traverse the same Attribute-based qualifier field. + newQuals := make([]Qualifier, len(qualifiers)) + for i, qual := range qualifiers { + attr, isAttr := qual.(Attribute) + if isAttr { + val, err := attr.Resolve(vars) + if err != nil { + return nil, err + } + // If this resolution behavior ever changes, new implementations of the + // qualifierValueEquator may be required to handle proper resolution. + qual, err = fac.NewQualifier(nil, qual.ID(), val, attr.IsOptional()) + if err != nil { + return nil, err + } + } + newQuals[i] = qual + } + // Determine whether any of the unknown patterns match. + for patIdx := range candidateIndices { + pat := patterns[patIdx] + isUnk := true + matchExprID := attrID + qualPats := pat.QualifierPatterns() + for i, qual := range newQuals { + if i >= len(qualPats) { + break + } + matchExprID = qual.ID() + qualPat := qualPats[i] + // Note, the AttributeQualifierPattern relies on the input Qualifier not being an + // Attribute, since there is no way to resolve the Attribute with the information + // provided to the Matches call. + if !qualPat.Matches(qual) { + isUnk = false + break + } + } + if isUnk { + attr := types.NewAttributeTrail(pat.variable) + for i := 0; i < len(qualPats) && i < len(newQuals); i++ { + if qual, ok := newQuals[i].(ConstantQualifier); ok { + switch v := qual.Value().Value().(type) { + case bool: + types.QualifyAttribute[bool](attr, v) + case float64: + types.QualifyAttribute[int64](attr, int64(v)) + case int64: + types.QualifyAttribute[int64](attr, v) + case string: + types.QualifyAttribute[string](attr, v) + case uint64: + types.QualifyAttribute[uint64](attr, v) + default: + types.QualifyAttribute[string](attr, fmt.Sprintf("%v", v)) + } + } else { + types.QualifyAttribute[string](attr, "*") + } + } + return types.NewUnknown(matchExprID, attr), nil + } + } + return nil, nil +} + +// attributeMatcher embeds the NamespacedAttribute interface which allows it to participate in +// AttributePattern matching against Attribute values without having to modify the code paths that +// identify Attributes in expressions. +type attributeMatcher struct { + NamespacedAttribute + qualifiers []Qualifier + fac *partialAttributeFactory +} + +// AddQualifier implements the Attribute interface method. +func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) { + // Add the qualifier to the embedded NamespacedAttribute. If the input to the Resolve + // method is not a PartialActivation, or does not match an unknown attribute pattern, the + // Resolve method is directly invoked on the underlying NamespacedAttribute. + _, err := m.NamespacedAttribute.AddQualifier(qual) + if err != nil { + return nil, err + } + // The attributeMatcher overloads TryResolve and will attempt to match unknown patterns against + // the variable name and qualifier set contained within the Attribute. These values are not + // directly inspectable on the top-level NamespacedAttribute interface and so are tracked within + // the attributeMatcher. + m.qualifiers = append(m.qualifiers, qual) + return m, nil +} + +// Resolve is an implementation of the NamespacedAttribute interface method which tests +// for matching unknown attribute patterns and returns types.Unknown if present. Otherwise, +// the standard Resolve logic applies. +func (m *attributeMatcher) Resolve(vars Activation) (any, error) { + id := m.NamespacedAttribute.ID() + // Bug in how partial activation is resolved, should search parents as well. + partial, isPartial := toPartialActivation(vars) + if isPartial { + unk, err := m.fac.matchesUnknownPatterns( + partial, + id, + m.CandidateVariableNames(), + m.qualifiers) + if err != nil { + return nil, err + } + if unk != nil { + return unk, nil + } + } + return m.NamespacedAttribute.Resolve(vars) +} + +// Qualify is an implementation of the Qualifier interface method. +func (m *attributeMatcher) Qualify(vars Activation, obj any) (any, error) { + return attrQualify(m.fac, vars, obj, m) +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (m *attributeMatcher) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(m.fac, vars, obj, m, presenceOnly) +} + +func toPartialActivation(vars Activation) (PartialActivation, bool) { + pv, ok := vars.(PartialActivation) + if ok { + return pv, true + } + if vars.Parent() != nil { + return toPartialActivation(vars.Parent()) + } + return nil, false +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/attributes.go b/vendor/github.com/authzed/cel-go/interpreter/attributes.go new file mode 100644 index 0000000..9b2ab09 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/attributes.go @@ -0,0 +1,1436 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "fmt" + "strings" + + "github.com/authzed/cel-go/common/containers" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" +) + +// AttributeFactory provides methods creating Attribute and Qualifier values. +type AttributeFactory interface { + // AbsoluteAttribute creates an attribute that refers to a top-level variable name. + // + // Checked expressions generate absolute attribute with a single name. + // Parse-only expressions may have more than one possible absolute identifier when the + // expression is created within a container, e.g. package or namespace. + // + // When there is more than one name supplied to the AbsoluteAttribute call, the names + // must be in CEL's namespace resolution order. The name arguments provided here are + // returned in the same order as they were provided by the NamespacedAttribute + // CandidateVariableNames method. + AbsoluteAttribute(id int64, names ...string) NamespacedAttribute + + // ConditionalAttribute creates an attribute with two Attribute branches, where the Attribute + // that is resolved depends on the boolean evaluation of the input 'expr'. + ConditionalAttribute(id int64, expr Interpretable, t, f Attribute) Attribute + + // MaybeAttribute creates an attribute that refers to either a field selection or a namespaced + // variable name. + // + // Only expressions which have not been type-checked may generate oneof attributes. + MaybeAttribute(id int64, name string) Attribute + + // RelativeAttribute creates an attribute whose value is a qualification of a dynamic + // computation rather than a static variable reference. + RelativeAttribute(id int64, operand Interpretable) Attribute + + // NewQualifier creates a qualifier on the target object with a given value. + // + // The 'val' may be an Attribute or any proto-supported map key type: bool, int, string, uint. + // + // The qualifier may consider the object type being qualified, if present. If absent, the + // qualification should be considered dynamic and the qualification should still work, though + // it may be sub-optimal. + NewQualifier(objType *types.Type, qualID int64, val any, opt bool) (Qualifier, error) +} + +// Qualifier marker interface for designating different qualifier values and where they appear +// within field selections and index call expressions (`_[_]`). +type Qualifier interface { + // ID where the qualifier appears within an expression. + ID() int64 + + // IsOptional specifies whether the qualifier is optional. + // Instead of a direct qualification, an optional qualifier will be resolved via QualifyIfPresent + // rather than Qualify. A non-optional qualifier may also be resolved through QualifyIfPresent if + // the object to qualify is itself optional. + IsOptional() bool + + // Qualify performs a qualification, e.g. field selection, on the input object and returns + // the value of the access and whether the value was set. A non-nil value with a false presence + // test result indicates that the value being returned is the default value. + Qualify(vars Activation, obj any) (any, error) + + // QualifyIfPresent qualifies the object if the qualifier is declared or defined on the object. + // The 'presenceOnly' flag indicates that the value is not necessary, just a boolean status as + // to whether the qualifier is present. + QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) +} + +// ConstantQualifier interface embeds the Qualifier interface and provides an option to inspect the +// qualifier's constant value. +// +// Non-constant qualifiers are of Attribute type. +type ConstantQualifier interface { + Qualifier + + // Value returns the constant value associated with the qualifier. + Value() ref.Val +} + +// Attribute values are a variable or value with an optional set of qualifiers, such as field, key, +// or index accesses. +type Attribute interface { + Qualifier + + // AddQualifier adds a qualifier on the Attribute or error if the qualification is not a valid qualifier type. + AddQualifier(Qualifier) (Attribute, error) + + // Resolve returns the value of the Attribute and whether it was present given an Activation. + // For objects which support safe traversal, the value may be non-nil and the presence flag be false. + // + // If an error is encountered during attribute resolution, it will be returned immediately. + // If the attribute cannot be resolved within the Activation, the result must be: `nil`, `error` + // with the error indicating which variable was missing. + Resolve(Activation) (any, error) +} + +// NamespacedAttribute values are a variable within a namespace, and an optional set of qualifiers +// such as field, key, or index accesses. +type NamespacedAttribute interface { + Attribute + + // CandidateVariableNames returns the possible namespaced variable names for this Attribute in + // the CEL namespace resolution order. + CandidateVariableNames() []string + + // Qualifiers returns the list of qualifiers associated with the Attribute. + Qualifiers() []Qualifier +} + +// AttrFactoryOption specifies a functional option for configuring an attribute factory. +type AttrFactoryOption func(*attrFactory) *attrFactory + +// EnableErrorOnBadPresenceTest error generation when a presence test or optional field selection +// is performed on a primitive type. +func EnableErrorOnBadPresenceTest(value bool) AttrFactoryOption { + return func(fac *attrFactory) *attrFactory { + fac.errorOnBadPresenceTest = value + return fac + } +} + +// NewAttributeFactory returns a default AttributeFactory which is produces Attribute values +// capable of resolving types by simple names and qualify the values using the supported qualifier +// types: bool, int, string, and uint. +func NewAttributeFactory(cont *containers.Container, a types.Adapter, p types.Provider, opts ...AttrFactoryOption) AttributeFactory { + fac := &attrFactory{ + container: cont, + adapter: a, + provider: p, + } + for _, o := range opts { + fac = o(fac) + } + return fac +} + +type attrFactory struct { + container *containers.Container + adapter types.Adapter + provider types.Provider + + errorOnBadPresenceTest bool +} + +// AbsoluteAttribute refers to a variable value and an optional qualifier path. +// +// The namespaceNames represent the names the variable could have based on namespace +// resolution rules. +func (r *attrFactory) AbsoluteAttribute(id int64, names ...string) NamespacedAttribute { + return &absoluteAttribute{ + id: id, + namespaceNames: names, + qualifiers: []Qualifier{}, + adapter: r.adapter, + provider: r.provider, + fac: r, + errorOnBadPresenceTest: r.errorOnBadPresenceTest, + } +} + +// ConditionalAttribute supports the case where an attribute selection may occur on a conditional +// expression, e.g. (cond ? a : b).c +func (r *attrFactory) ConditionalAttribute(id int64, expr Interpretable, t, f Attribute) Attribute { + return &conditionalAttribute{ + id: id, + expr: expr, + truthy: t, + falsy: f, + adapter: r.adapter, + fac: r, + } +} + +// MaybeAttribute collects variants of unchecked AbsoluteAttribute values which could either be +// direct variable accesses or some combination of variable access with qualification. +func (r *attrFactory) MaybeAttribute(id int64, name string) Attribute { + return &maybeAttribute{ + id: id, + attrs: []NamespacedAttribute{ + r.AbsoluteAttribute(id, r.container.ResolveCandidateNames(name)...), + }, + adapter: r.adapter, + provider: r.provider, + fac: r, + } +} + +// RelativeAttribute refers to an expression and an optional qualifier path. +func (r *attrFactory) RelativeAttribute(id int64, operand Interpretable) Attribute { + return &relativeAttribute{ + id: id, + operand: operand, + qualifiers: []Qualifier{}, + adapter: r.adapter, + fac: r, + errorOnBadPresenceTest: r.errorOnBadPresenceTest, + } +} + +// NewQualifier is an implementation of the AttributeFactory interface. +func (r *attrFactory) NewQualifier(objType *types.Type, qualID int64, val any, opt bool) (Qualifier, error) { + // Before creating a new qualifier check to see if this is a protobuf message field access. + // If so, use the precomputed GetFrom qualification method rather than the standard + // stringQualifier. + str, isStr := val.(string) + if isStr && objType != nil && objType.Kind() == types.StructKind { + ft, found := r.provider.FindStructFieldType(objType.TypeName(), str) + if found && ft.IsSet != nil && ft.GetFrom != nil { + return &fieldQualifier{ + id: qualID, + Name: str, + FieldType: ft, + adapter: r.adapter, + optional: opt, + }, nil + } + } + return newQualifier(r.adapter, qualID, val, opt, r.errorOnBadPresenceTest) +} + +type absoluteAttribute struct { + id int64 + // namespaceNames represent the names the variable could have based on declared container + // (package) of the expression. + namespaceNames []string + qualifiers []Qualifier + adapter types.Adapter + provider types.Provider + fac AttributeFactory + + errorOnBadPresenceTest bool +} + +// ID implements the Attribute interface method. +func (a *absoluteAttribute) ID() int64 { + qualCount := len(a.qualifiers) + if qualCount == 0 { + return a.id + } + return a.qualifiers[qualCount-1].ID() +} + +// IsOptional returns trivially false for an attribute as the attribute represents a fully +// qualified variable name. If the attribute is used in an optional manner, then an attrQualifier +// is created and marks the attribute as optional. +func (a *absoluteAttribute) IsOptional() bool { + return false +} + +// AddQualifier implements the Attribute interface method. +func (a *absoluteAttribute) AddQualifier(qual Qualifier) (Attribute, error) { + a.qualifiers = append(a.qualifiers, qual) + return a, nil +} + +// CandidateVariableNames implements the NamespaceAttribute interface method. +func (a *absoluteAttribute) CandidateVariableNames() []string { + return a.namespaceNames +} + +// Qualifiers returns the list of Qualifier instances associated with the namespaced attribute. +func (a *absoluteAttribute) Qualifiers() []Qualifier { + return a.qualifiers +} + +// Qualify is an implementation of the Qualifier interface method. +func (a *absoluteAttribute) Qualify(vars Activation, obj any) (any, error) { + return attrQualify(a.fac, vars, obj, a) +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (a *absoluteAttribute) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(a.fac, vars, obj, a, presenceOnly) +} + +// String implements the Stringer interface method. +func (a *absoluteAttribute) String() string { + return fmt.Sprintf("id: %v, names: %v", a.id, a.namespaceNames) +} + +// Resolve returns the resolved Attribute value given the Activation, or error if the Attribute +// variable is not found, or if its Qualifiers cannot be applied successfully. +// +// If the variable name cannot be found as an Activation variable or in the TypeProvider as +// a type, then the result is `nil`, `error` with the error indicating the name of the first +// variable searched as missing. +func (a *absoluteAttribute) Resolve(vars Activation) (any, error) { + for _, nm := range a.namespaceNames { + // If the variable is found, process it. Otherwise, wait until the checks to + // determine whether the type is unknown before returning. + obj, found := vars.ResolveName(nm) + if found { + if celErr, ok := obj.(*types.Err); ok { + return nil, celErr.Unwrap() + } + obj, isOpt, err := applyQualifiers(vars, obj, a.qualifiers) + if err != nil { + return nil, err + } + if isOpt { + val := a.adapter.NativeToValue(obj) + if types.IsUnknown(val) { + return val, nil + } + return types.OptionalOf(val), nil + } + return obj, nil + } + // Attempt to resolve the qualified type name if the name is not a variable identifier. + typ, found := a.provider.FindIdent(nm) + if found { + if len(a.qualifiers) == 0 { + return typ, nil + } + } + } + var attrNames strings.Builder + for i, nm := range a.namespaceNames { + if i != 0 { + attrNames.WriteString(", ") + } + attrNames.WriteString(nm) + } + return nil, missingAttribute(attrNames.String()) +} + +type conditionalAttribute struct { + id int64 + expr Interpretable + truthy Attribute + falsy Attribute + adapter types.Adapter + fac AttributeFactory +} + +// ID is an implementation of the Attribute interface method. +func (a *conditionalAttribute) ID() int64 { + // There's a field access after the conditional. + if a.truthy.ID() == a.falsy.ID() { + return a.truthy.ID() + } + // Otherwise return the conditional id as the consistent id being tracked. + return a.id +} + +// IsOptional returns trivially false for an attribute as the attribute represents a fully +// qualified variable name. If the attribute is used in an optional manner, then an attrQualifier +// is created and marks the attribute as optional. +func (a *conditionalAttribute) IsOptional() bool { + return false +} + +// AddQualifier appends the same qualifier to both sides of the conditional, in effect managing +// the qualification of alternate attributes. +func (a *conditionalAttribute) AddQualifier(qual Qualifier) (Attribute, error) { + _, err := a.truthy.AddQualifier(qual) + if err != nil { + return nil, err + } + _, err = a.falsy.AddQualifier(qual) + if err != nil { + return nil, err + } + return a, nil +} + +// Qualify is an implementation of the Qualifier interface method. +func (a *conditionalAttribute) Qualify(vars Activation, obj any) (any, error) { + return attrQualify(a.fac, vars, obj, a) +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (a *conditionalAttribute) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(a.fac, vars, obj, a, presenceOnly) +} + +// Resolve evaluates the condition, and then resolves the truthy or falsy branch accordingly. +func (a *conditionalAttribute) Resolve(vars Activation) (any, error) { + val := a.expr.Eval(vars) + if val == types.True { + return a.truthy.Resolve(vars) + } + if val == types.False { + return a.falsy.Resolve(vars) + } + if types.IsUnknown(val) { + return val, nil + } + return nil, types.MaybeNoSuchOverloadErr(val).(*types.Err) +} + +// String is an implementation of the Stringer interface method. +func (a *conditionalAttribute) String() string { + return fmt.Sprintf("id: %v, truthy attribute: %v, falsy attribute: %v", a.id, a.truthy, a.falsy) +} + +type maybeAttribute struct { + id int64 + attrs []NamespacedAttribute + adapter types.Adapter + provider types.Provider + fac AttributeFactory +} + +// ID is an implementation of the Attribute interface method. +func (a *maybeAttribute) ID() int64 { + return a.attrs[0].ID() +} + +// IsOptional returns trivially false for an attribute as the attribute represents a fully +// qualified variable name. If the attribute is used in an optional manner, then an attrQualifier +// is created and marks the attribute as optional. +func (a *maybeAttribute) IsOptional() bool { + return false +} + +// AddQualifier adds a qualifier to each possible attribute variant, and also creates +// a new namespaced variable from the qualified value. +// +// The algorithm for building the maybe attribute is as follows: +// +// 1. Create a maybe attribute from a simple identifier when it occurs in a parsed-only expression +// +// mb = MaybeAttribute(<id>, "a") +// +// Initializing the maybe attribute creates an absolute attribute internally which includes the +// possible namespaced names of the attribute. In this example, let's assume we are in namespace +// 'ns', then the maybe is either one of the following variable names: +// +// possible variables names -- ns.a, a +// +// 2. Adding a qualifier to the maybe means that the variable name could be a longer qualified +// name, or a field selection on one of the possible variable names produced earlier: +// +// mb.AddQualifier("b") +// +// possible variables names -- ns.a.b, a.b +// possible field selection -- ns.a['b'], a['b'] +// +// If none of the attributes within the maybe resolves a value, the result is an error. +func (a *maybeAttribute) AddQualifier(qual Qualifier) (Attribute, error) { + str := "" + isStr := false + cq, isConst := qual.(ConstantQualifier) + if isConst { + str, isStr = cq.Value().Value().(string) + } + var augmentedNames []string + // First add the qualifier to all existing attributes in the oneof. + for _, attr := range a.attrs { + if isStr && len(attr.Qualifiers()) == 0 { + candidateVars := attr.CandidateVariableNames() + augmentedNames = make([]string, len(candidateVars)) + for i, name := range candidateVars { + augmentedNames[i] = fmt.Sprintf("%s.%s", name, str) + } + } + _, err := attr.AddQualifier(qual) + if err != nil { + return nil, err + } + } + // Next, ensure the most specific variable / type reference is searched first. + if len(augmentedNames) != 0 { + a.attrs = append([]NamespacedAttribute{a.fac.AbsoluteAttribute(qual.ID(), augmentedNames...)}, a.attrs...) + } + return a, nil +} + +// Qualify is an implementation of the Qualifier interface method. +func (a *maybeAttribute) Qualify(vars Activation, obj any) (any, error) { + return attrQualify(a.fac, vars, obj, a) +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (a *maybeAttribute) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(a.fac, vars, obj, a, presenceOnly) +} + +// Resolve follows the variable resolution rules to determine whether the attribute is a variable +// or a field selection. +func (a *maybeAttribute) Resolve(vars Activation) (any, error) { + var maybeErr error + for _, attr := range a.attrs { + obj, err := attr.Resolve(vars) + // Return an error if one is encountered. + if err != nil { + resErr, ok := err.(*resolutionError) + if !ok { + return nil, err + } + // If this was not a missing variable error, return it. + if !resErr.isMissingAttribute() { + return nil, err + } + // When the variable is missing in a maybe attribute we defer erroring. + if maybeErr == nil { + maybeErr = resErr + } + // Continue attempting to resolve possible variables. + continue + } + return obj, nil + } + // Else, produce a no such attribute error. + return nil, maybeErr +} + +// String is an implementation of the Stringer interface method. +func (a *maybeAttribute) String() string { + return fmt.Sprintf("id: %v, attributes: %v", a.id, a.attrs) +} + +type relativeAttribute struct { + id int64 + operand Interpretable + qualifiers []Qualifier + adapter types.Adapter + fac AttributeFactory + + errorOnBadPresenceTest bool +} + +// ID is an implementation of the Attribute interface method. +func (a *relativeAttribute) ID() int64 { + qualCount := len(a.qualifiers) + if qualCount == 0 { + return a.id + } + return a.qualifiers[qualCount-1].ID() +} + +// IsOptional returns trivially false for an attribute as the attribute represents a fully +// qualified variable name. If the attribute is used in an optional manner, then an attrQualifier +// is created and marks the attribute as optional. +func (a *relativeAttribute) IsOptional() bool { + return false +} + +// AddQualifier implements the Attribute interface method. +func (a *relativeAttribute) AddQualifier(qual Qualifier) (Attribute, error) { + a.qualifiers = append(a.qualifiers, qual) + return a, nil +} + +// Qualify is an implementation of the Qualifier interface method. +func (a *relativeAttribute) Qualify(vars Activation, obj any) (any, error) { + return attrQualify(a.fac, vars, obj, a) +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (a *relativeAttribute) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return attrQualifyIfPresent(a.fac, vars, obj, a, presenceOnly) +} + +// Resolve expression value and qualifier relative to the expression result. +func (a *relativeAttribute) Resolve(vars Activation) (any, error) { + // First, evaluate the operand. + v := a.operand.Eval(vars) + if types.IsError(v) { + return nil, v.(*types.Err) + } + if types.IsUnknown(v) { + return v, nil + } + obj, isOpt, err := applyQualifiers(vars, v, a.qualifiers) + if err != nil { + return nil, err + } + if isOpt { + val := a.adapter.NativeToValue(obj) + if types.IsUnknown(val) { + return val, nil + } + return types.OptionalOf(val), nil + } + return obj, nil +} + +// String is an implementation of the Stringer interface method. +func (a *relativeAttribute) String() string { + return fmt.Sprintf("id: %v, operand: %v", a.id, a.operand) +} + +func newQualifier(adapter types.Adapter, id int64, v any, opt, errorOnBadPresenceTest bool) (Qualifier, error) { + var qual Qualifier + switch val := v.(type) { + case Attribute: + // Note, attributes are initially identified as non-optional since they represent a top-level + // field access; however, when used as a relative qualifier, e.g. a[?b.c], then an attrQualifier + // is created which intercepts the IsOptional check for the attribute in order to return the + // correct result. + return &attrQualifier{ + id: id, + Attribute: val, + optional: opt, + }, nil + case string: + qual = &stringQualifier{ + id: id, + value: val, + celValue: types.String(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case int: + qual = &intQualifier{ + id: id, + value: int64(val), + celValue: types.Int(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case int32: + qual = &intQualifier{ + id: id, + value: int64(val), + celValue: types.Int(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case int64: + qual = &intQualifier{ + id: id, + value: val, + celValue: types.Int(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case uint: + qual = &uintQualifier{ + id: id, + value: uint64(val), + celValue: types.Uint(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case uint32: + qual = &uintQualifier{ + id: id, + value: uint64(val), + celValue: types.Uint(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case uint64: + qual = &uintQualifier{ + id: id, + value: val, + celValue: types.Uint(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case bool: + qual = &boolQualifier{ + id: id, + value: val, + celValue: types.Bool(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case float32: + qual = &doubleQualifier{ + id: id, + value: float64(val), + celValue: types.Double(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case float64: + qual = &doubleQualifier{ + id: id, + value: val, + celValue: types.Double(val), + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case types.String: + qual = &stringQualifier{ + id: id, + value: string(val), + celValue: val, + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case types.Int: + qual = &intQualifier{ + id: id, + value: int64(val), + celValue: val, + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case types.Uint: + qual = &uintQualifier{ + id: id, + value: uint64(val), + celValue: val, + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case types.Bool: + qual = &boolQualifier{ + id: id, + value: bool(val), + celValue: val, + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case types.Double: + qual = &doubleQualifier{ + id: id, + value: float64(val), + celValue: val, + adapter: adapter, + optional: opt, + errorOnBadPresenceTest: errorOnBadPresenceTest, + } + case *types.Unknown: + qual = &unknownQualifier{id: id, value: val} + default: + if q, ok := v.(Qualifier); ok { + return q, nil + } + return nil, fmt.Errorf("invalid qualifier type: %T", v) + } + return qual, nil +} + +type attrQualifier struct { + id int64 + Attribute + optional bool +} + +// ID implements the Qualifier interface method and returns the qualification instruction id +// rather than the attribute id. +func (q *attrQualifier) ID() int64 { + return q.id +} + +// IsOptional implements the Qualifier interface method. +func (q *attrQualifier) IsOptional() bool { + return q.optional +} + +type stringQualifier struct { + id int64 + value string + celValue ref.Val + adapter types.Adapter + optional bool + errorOnBadPresenceTest bool +} + +// ID is an implementation of the Qualifier interface method. +func (q *stringQualifier) ID() int64 { + return q.id +} + +// IsOptional implements the Qualifier interface method. +func (q *stringQualifier) IsOptional() bool { + return q.optional +} + +// Qualify implements the Qualifier interface method. +func (q *stringQualifier) Qualify(vars Activation, obj any) (any, error) { + val, _, err := q.qualifyInternal(vars, obj, false, false) + return val, err +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (q *stringQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return q.qualifyInternal(vars, obj, true, presenceOnly) +} + +func (q *stringQualifier) qualifyInternal(vars Activation, obj any, presenceTest, presenceOnly bool) (any, bool, error) { + s := q.value + switch o := obj.(type) { + case map[string]any: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]string: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]int: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]int32: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]int64: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]uint: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]uint32: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]uint64: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]float32: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]float64: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + case map[string]bool: + obj, isKey := o[s] + if isKey { + return obj, true, nil + } + default: + return refQualify(q.adapter, obj, q.celValue, presenceTest, presenceOnly, q.errorOnBadPresenceTest) + } + if presenceTest { + return nil, false, nil + } + return nil, false, missingKey(q.celValue) +} + +// Value implements the ConstantQualifier interface +func (q *stringQualifier) Value() ref.Val { + return q.celValue +} + +type intQualifier struct { + id int64 + value int64 + celValue ref.Val + adapter types.Adapter + optional bool + errorOnBadPresenceTest bool +} + +// ID is an implementation of the Qualifier interface method. +func (q *intQualifier) ID() int64 { + return q.id +} + +// IsOptional implements the Qualifier interface method. +func (q *intQualifier) IsOptional() bool { + return q.optional +} + +// Qualify implements the Qualifier interface method. +func (q *intQualifier) Qualify(vars Activation, obj any) (any, error) { + val, _, err := q.qualifyInternal(vars, obj, false, false) + return val, err +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (q *intQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return q.qualifyInternal(vars, obj, true, presenceOnly) +} + +func (q *intQualifier) qualifyInternal(vars Activation, obj any, presenceTest, presenceOnly bool) (any, bool, error) { + i := q.value + var isMap bool + switch o := obj.(type) { + // The specialized map types supported by an int qualifier are considerably fewer than the set + // of specialized map types supported by string qualifiers since they are less frequently used + // than string-based map keys. Additional specializations may be added in the future if + // desired. + case map[int]any: + isMap = true + obj, isKey := o[int(i)] + if isKey { + return obj, true, nil + } + case map[int32]any: + isMap = true + obj, isKey := o[int32(i)] + if isKey { + return obj, true, nil + } + case map[int64]any: + isMap = true + obj, isKey := o[i] + if isKey { + return obj, true, nil + } + case []any: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []string: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []int: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []int32: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []int64: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []uint: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []uint32: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []uint64: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []float32: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []float64: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + case []bool: + isIndex := i >= 0 && i < int64(len(o)) + if isIndex { + return o[i], true, nil + } + default: + return refQualify(q.adapter, obj, q.celValue, presenceTest, presenceOnly, q.errorOnBadPresenceTest) + } + if presenceTest { + return nil, false, nil + } + if isMap { + return nil, false, missingKey(q.celValue) + } + return nil, false, missingIndex(q.celValue) +} + +// Value implements the ConstantQualifier interface +func (q *intQualifier) Value() ref.Val { + return q.celValue +} + +type uintQualifier struct { + id int64 + value uint64 + celValue ref.Val + adapter types.Adapter + optional bool + errorOnBadPresenceTest bool +} + +// ID is an implementation of the Qualifier interface method. +func (q *uintQualifier) ID() int64 { + return q.id +} + +// IsOptional implements the Qualifier interface method. +func (q *uintQualifier) IsOptional() bool { + return q.optional +} + +// Qualify implements the Qualifier interface method. +func (q *uintQualifier) Qualify(vars Activation, obj any) (any, error) { + val, _, err := q.qualifyInternal(vars, obj, false, false) + return val, err +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (q *uintQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return q.qualifyInternal(vars, obj, true, presenceOnly) +} + +func (q *uintQualifier) qualifyInternal(vars Activation, obj any, presenceTest, presenceOnly bool) (any, bool, error) { + u := q.value + switch o := obj.(type) { + // The specialized map types supported by a uint qualifier are considerably fewer than the set + // of specialized map types supported by string qualifiers since they are less frequently used + // than string-based map keys. Additional specializations may be added in the future if + // desired. + case map[uint]any: + obj, isKey := o[uint(u)] + if isKey { + return obj, true, nil + } + case map[uint32]any: + obj, isKey := o[uint32(u)] + if isKey { + return obj, true, nil + } + case map[uint64]any: + obj, isKey := o[u] + if isKey { + return obj, true, nil + } + default: + return refQualify(q.adapter, obj, q.celValue, presenceTest, presenceOnly, q.errorOnBadPresenceTest) + } + if presenceTest { + return nil, false, nil + } + return nil, false, missingKey(q.celValue) +} + +// Value implements the ConstantQualifier interface +func (q *uintQualifier) Value() ref.Val { + return q.celValue +} + +type boolQualifier struct { + id int64 + value bool + celValue ref.Val + adapter types.Adapter + optional bool + errorOnBadPresenceTest bool +} + +// ID is an implementation of the Qualifier interface method. +func (q *boolQualifier) ID() int64 { + return q.id +} + +// IsOptional implements the Qualifier interface method. +func (q *boolQualifier) IsOptional() bool { + return q.optional +} + +// Qualify implements the Qualifier interface method. +func (q *boolQualifier) Qualify(vars Activation, obj any) (any, error) { + val, _, err := q.qualifyInternal(vars, obj, false, false) + return val, err +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (q *boolQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return q.qualifyInternal(vars, obj, true, presenceOnly) +} + +func (q *boolQualifier) qualifyInternal(vars Activation, obj any, presenceTest, presenceOnly bool) (any, bool, error) { + b := q.value + switch o := obj.(type) { + case map[bool]any: + obj, isKey := o[b] + if isKey { + return obj, true, nil + } + default: + return refQualify(q.adapter, obj, q.celValue, presenceTest, presenceOnly, q.errorOnBadPresenceTest) + } + if presenceTest { + return nil, false, nil + } + return nil, false, missingKey(q.celValue) +} + +// Value implements the ConstantQualifier interface +func (q *boolQualifier) Value() ref.Val { + return q.celValue +} + +// fieldQualifier indicates that the qualification is a well-defined field with a known +// field type. When the field type is known this can be used to improve the speed and +// efficiency of field resolution. +type fieldQualifier struct { + id int64 + Name string + FieldType *types.FieldType + adapter types.Adapter + optional bool +} + +// ID is an implementation of the Qualifier interface method. +func (q *fieldQualifier) ID() int64 { + return q.id +} + +// IsOptional implements the Qualifier interface method. +func (q *fieldQualifier) IsOptional() bool { + return q.optional +} + +// Qualify implements the Qualifier interface method. +func (q *fieldQualifier) Qualify(vars Activation, obj any) (any, error) { + if rv, ok := obj.(ref.Val); ok { + obj = rv.Value() + } + val, err := q.FieldType.GetFrom(obj) + if err != nil { + return nil, err + } + return val, nil +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (q *fieldQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + if rv, ok := obj.(ref.Val); ok { + obj = rv.Value() + } + if !q.FieldType.IsSet(obj) { + return nil, false, nil + } + if presenceOnly { + return nil, true, nil + } + val, err := q.FieldType.GetFrom(obj) + if err != nil { + return nil, false, err + } + return val, true, nil +} + +// Value implements the ConstantQualifier interface +func (q *fieldQualifier) Value() ref.Val { + return types.String(q.Name) +} + +// doubleQualifier qualifies a CEL object, map, or list using a double value. +// +// This qualifier is used for working with dynamic data like JSON or protobuf.Any where the value +// type may not be known ahead of time and may not conform to the standard types supported as valid +// protobuf map key types. +type doubleQualifier struct { + id int64 + value float64 + celValue ref.Val + adapter types.Adapter + optional bool + errorOnBadPresenceTest bool +} + +// ID is an implementation of the Qualifier interface method. +func (q *doubleQualifier) ID() int64 { + return q.id +} + +// IsOptional implements the Qualifier interface method. +func (q *doubleQualifier) IsOptional() bool { + return q.optional +} + +// Qualify implements the Qualifier interface method. +func (q *doubleQualifier) Qualify(vars Activation, obj any) (any, error) { + val, _, err := q.qualifyInternal(vars, obj, false, false) + return val, err +} + +func (q *doubleQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return q.qualifyInternal(vars, obj, true, presenceOnly) +} + +func (q *doubleQualifier) qualifyInternal(vars Activation, obj any, presenceTest, presenceOnly bool) (any, bool, error) { + return refQualify(q.adapter, obj, q.celValue, presenceTest, presenceOnly, q.errorOnBadPresenceTest) +} + +// Value implements the ConstantQualifier interface +func (q *doubleQualifier) Value() ref.Val { + return q.celValue +} + +// unknownQualifier is a simple qualifier which always returns a preconfigured set of unknown values +// for any value subject to qualification. This is consistent with CEL's unknown handling elsewhere. +type unknownQualifier struct { + id int64 + value *types.Unknown +} + +// ID is an implementation of the Qualifier interface method. +func (q *unknownQualifier) ID() int64 { + return q.id +} + +// IsOptional returns trivially false as an the unknown value is always returned. +func (q *unknownQualifier) IsOptional() bool { + return false +} + +// Qualify returns the unknown value associated with this qualifier. +func (q *unknownQualifier) Qualify(vars Activation, obj any) (any, error) { + return q.value, nil +} + +// QualifyIfPresent is an implementation of the Qualifier interface method. +func (q *unknownQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + return q.value, true, nil +} + +// Value implements the ConstantQualifier interface +func (q *unknownQualifier) Value() ref.Val { + return q.value +} + +func applyQualifiers(vars Activation, obj any, qualifiers []Qualifier) (any, bool, error) { + optObj, isOpt := obj.(*types.Optional) + if isOpt { + if !optObj.HasValue() { + return optObj, false, nil + } + obj = optObj.GetValue().Value() + } + + var err error + for _, qual := range qualifiers { + var qualObj any + isOpt = isOpt || qual.IsOptional() + if isOpt { + var present bool + qualObj, present, err = qual.QualifyIfPresent(vars, obj, false) + if err != nil { + return nil, false, err + } + if !present { + // We return optional none here with a presence of 'false' as the layers + // above will attempt to call types.OptionalOf() on a present value if any + // of the qualifiers is optional. + return types.OptionalNone, false, nil + } + } else { + qualObj, err = qual.Qualify(vars, obj) + if err != nil { + return nil, false, err + } + } + obj = qualObj + } + return obj, isOpt, nil +} + +// attrQualify performs a qualification using the result of an attribute evaluation. +func attrQualify(fac AttributeFactory, vars Activation, obj any, qualAttr Attribute) (any, error) { + val, err := qualAttr.Resolve(vars) + if err != nil { + return nil, err + } + qual, err := fac.NewQualifier(nil, qualAttr.ID(), val, qualAttr.IsOptional()) + if err != nil { + return nil, err + } + return qual.Qualify(vars, obj) +} + +// attrQualifyIfPresent conditionally performs the qualification of the result of attribute is present +// on the target object. +func attrQualifyIfPresent(fac AttributeFactory, vars Activation, obj any, qualAttr Attribute, + presenceOnly bool) (any, bool, error) { + val, err := qualAttr.Resolve(vars) + if err != nil { + return nil, false, err + } + qual, err := fac.NewQualifier(nil, qualAttr.ID(), val, qualAttr.IsOptional()) + if err != nil { + return nil, false, err + } + return qual.QualifyIfPresent(vars, obj, presenceOnly) +} + +// refQualify attempts to convert the value to a CEL value and then uses reflection methods to try and +// apply the qualifier with the option to presence test field accesses before retrieving field values. +func refQualify(adapter types.Adapter, obj any, idx ref.Val, presenceTest, presenceOnly, errorOnBadPresenceTest bool) (ref.Val, bool, error) { + celVal := adapter.NativeToValue(obj) + switch v := celVal.(type) { + case *types.Unknown: + return v, true, nil + case *types.Err: + return nil, false, v + case traits.Mapper: + val, found := v.Find(idx) + // If the index is of the wrong type for the map, then it is possible + // for the Find call to produce an error. + if types.IsError(val) { + return nil, false, val.(*types.Err) + } + if found { + return val, true, nil + } + if presenceTest { + return nil, false, nil + } + return nil, false, missingKey(idx) + case traits.Lister: + // If the index argument is not a valid numeric type, then it is possible + // for the index operation to produce an error. + i, err := types.IndexOrError(idx) + if err != nil { + return nil, false, err + } + celIndex := types.Int(i) + if i >= 0 && celIndex < v.Size().(types.Int) { + return v.Get(idx), true, nil + } + if presenceTest { + return nil, false, nil + } + return nil, false, missingIndex(idx) + case traits.Indexer: + if presenceTest { + ft, ok := v.(traits.FieldTester) + if ok { + presence := ft.IsSet(idx) + if types.IsError(presence) { + return nil, false, presence.(*types.Err) + } + // If not found or presence only test, then return. + // Otherwise, if found, obtain the value later on. + if presenceOnly || presence == types.False { + return nil, presence == types.True, nil + } + } + } + val := v.Get(idx) + if types.IsError(val) { + return nil, false, val.(*types.Err) + } + return val, true, nil + default: + if presenceTest && !errorOnBadPresenceTest { + return nil, false, nil + } + return nil, false, missingKey(idx) + } +} + +// resolutionError is a custom error type which encodes the different error states which may +// occur during attribute resolution. +type resolutionError struct { + missingAttribute string + missingIndex ref.Val + missingKey ref.Val +} + +func (e *resolutionError) isMissingAttribute() bool { + return e.missingAttribute != "" +} + +func missingIndex(missing ref.Val) *resolutionError { + return &resolutionError{ + missingIndex: missing, + } +} + +func missingKey(missing ref.Val) *resolutionError { + return &resolutionError{ + missingKey: missing, + } +} + +func missingAttribute(attr string) *resolutionError { + return &resolutionError{ + missingAttribute: attr, + } +} + +// Error implements the error interface method. +func (e *resolutionError) Error() string { + if e.missingKey != nil { + return fmt.Sprintf("no such key: %v", e.missingKey) + } + if e.missingIndex != nil { + return fmt.Sprintf("index out of bounds: %v", e.missingIndex) + } + if e.missingAttribute != "" { + return fmt.Sprintf("no such attribute(s): %s", e.missingAttribute) + } + return "invalid attribute" +} + +// Is implements the errors.Is() method used by more recent versions of Go. +func (e *resolutionError) Is(err error) bool { + return err.Error() == e.Error() +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/decorators.go b/vendor/github.com/authzed/cel-go/interpreter/decorators.go new file mode 100644 index 0000000..760aa96 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/decorators.go @@ -0,0 +1,272 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" +) + +// InterpretableDecorator is a functional interface for decorating or replacing +// Interpretable expression nodes at construction time. +type InterpretableDecorator func(Interpretable) (Interpretable, error) + +// decObserveEval records evaluation state into an EvalState object. +func decObserveEval(observer EvalObserver) InterpretableDecorator { + return func(i Interpretable) (Interpretable, error) { + switch inst := i.(type) { + case *evalWatch, *evalWatchAttr, *evalWatchConst, *evalWatchConstructor: + // these instruction are already watching, return straight-away. + return i, nil + case InterpretableAttribute: + return &evalWatchAttr{ + InterpretableAttribute: inst, + observer: observer, + }, nil + case InterpretableConst: + return &evalWatchConst{ + InterpretableConst: inst, + observer: observer, + }, nil + case InterpretableConstructor: + return &evalWatchConstructor{ + constructor: inst, + observer: observer, + }, nil + default: + return &evalWatch{ + Interpretable: i, + observer: observer, + }, nil + } + } +} + +// decInterruptFolds creates an intepretable decorator which marks comprehensions as interruptable +// where the interrupt state is communicated via a hidden variable on the Activation. +func decInterruptFolds() InterpretableDecorator { + return func(i Interpretable) (Interpretable, error) { + fold, ok := i.(*evalFold) + if !ok { + return i, nil + } + fold.interruptable = true + return fold, nil + } +} + +// decDisableShortcircuits ensures that all branches of an expression will be evaluated, no short-circuiting. +func decDisableShortcircuits() InterpretableDecorator { + return func(i Interpretable) (Interpretable, error) { + switch expr := i.(type) { + case *evalOr: + return &evalExhaustiveOr{ + id: expr.id, + terms: expr.terms, + }, nil + case *evalAnd: + return &evalExhaustiveAnd{ + id: expr.id, + terms: expr.terms, + }, nil + case *evalFold: + expr.exhaustive = true + return expr, nil + case InterpretableAttribute: + cond, isCond := expr.Attr().(*conditionalAttribute) + if isCond { + return &evalExhaustiveConditional{ + id: cond.id, + attr: cond, + adapter: expr.Adapter(), + }, nil + } + } + return i, nil + } +} + +// decOptimize optimizes the program plan by looking for common evaluation patterns and +// conditionally precomputing the result. +// - build list and map values with constant elements. +// - convert 'in' operations to set membership tests if possible. +func decOptimize() InterpretableDecorator { + return func(i Interpretable) (Interpretable, error) { + switch inst := i.(type) { + case *evalList: + return maybeBuildListLiteral(i, inst) + case *evalMap: + return maybeBuildMapLiteral(i, inst) + case InterpretableCall: + if inst.OverloadID() == overloads.InList { + return maybeOptimizeSetMembership(i, inst) + } + if overloads.IsTypeConversionFunction(inst.Function()) { + return maybeOptimizeConstUnary(i, inst) + } + } + return i, nil + } +} + +// decRegexOptimizer compiles regex pattern string constants. +func decRegexOptimizer(regexOptimizations ...*RegexOptimization) InterpretableDecorator { + functionMatchMap := make(map[string]*RegexOptimization) + overloadMatchMap := make(map[string]*RegexOptimization) + for _, m := range regexOptimizations { + functionMatchMap[m.Function] = m + if m.OverloadID != "" { + overloadMatchMap[m.OverloadID] = m + } + } + + return func(i Interpretable) (Interpretable, error) { + call, ok := i.(InterpretableCall) + if !ok { + return i, nil + } + + var matcher *RegexOptimization + var found bool + if call.OverloadID() != "" { + matcher, found = overloadMatchMap[call.OverloadID()] + } + if !found { + matcher, found = functionMatchMap[call.Function()] + } + if !found || matcher.RegexIndex >= len(call.Args()) { + return i, nil + } + args := call.Args() + regexArg := args[matcher.RegexIndex] + regexStr, isConst := regexArg.(InterpretableConst) + if !isConst { + return i, nil + } + pattern, ok := regexStr.Value().(types.String) + if !ok { + return i, nil + } + return matcher.Factory(call, string(pattern)) + } +} + +func maybeOptimizeConstUnary(i Interpretable, call InterpretableCall) (Interpretable, error) { + args := call.Args() + if len(args) != 1 { + return i, nil + } + _, isConst := args[0].(InterpretableConst) + if !isConst { + return i, nil + } + val := call.Eval(EmptyActivation()) + if types.IsError(val) { + return nil, val.(*types.Err) + } + return NewConstValue(call.ID(), val), nil +} + +func maybeBuildListLiteral(i Interpretable, l *evalList) (Interpretable, error) { + for _, elem := range l.elems { + _, isConst := elem.(InterpretableConst) + if !isConst { + return i, nil + } + } + return NewConstValue(l.ID(), l.Eval(EmptyActivation())), nil +} + +func maybeBuildMapLiteral(i Interpretable, mp *evalMap) (Interpretable, error) { + for idx, key := range mp.keys { + _, isConst := key.(InterpretableConst) + if !isConst { + return i, nil + } + _, isConst = mp.vals[idx].(InterpretableConst) + if !isConst { + return i, nil + } + } + return NewConstValue(mp.ID(), mp.Eval(EmptyActivation())), nil +} + +// maybeOptimizeSetMembership may convert an 'in' operation against a list to map key membership +// test if the following conditions are true: +// - the list is a constant with homogeneous element types. +// - the elements are all of primitive type. +func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Interpretable, error) { + args := inlist.Args() + lhs := args[0] + rhs := args[1] + l, isConst := rhs.(InterpretableConst) + if !isConst { + return i, nil + } + // When the incoming binary call is flagged with as the InList overload, the value will + // always be convertible to a `traits.Lister` type. + list := l.Value().(traits.Lister) + if list.Size() == types.IntZero { + return NewConstValue(inlist.ID(), types.False), nil + } + it := list.Iterator() + valueSet := make(map[ref.Val]ref.Val) + for it.HasNext() == types.True { + elem := it.Next() + if !types.IsPrimitiveType(elem) || elem.Type() == types.BytesType { + // Note, non-primitive type are not yet supported, and []byte isn't hashable. + return i, nil + } + valueSet[elem] = types.True + switch ev := elem.(type) { + case types.Double: + iv := ev.ConvertToType(types.IntType) + // Ensure that only lossless conversions are added to the set + if !types.IsError(iv) && iv.Equal(ev) == types.True { + valueSet[iv] = types.True + } + // Ensure that only lossless conversions are added to the set + uv := ev.ConvertToType(types.UintType) + if !types.IsError(uv) && uv.Equal(ev) == types.True { + valueSet[uv] = types.True + } + case types.Int: + dv := ev.ConvertToType(types.DoubleType) + if !types.IsError(dv) { + valueSet[dv] = types.True + } + uv := ev.ConvertToType(types.UintType) + if !types.IsError(uv) { + valueSet[uv] = types.True + } + case types.Uint: + dv := ev.ConvertToType(types.DoubleType) + if !types.IsError(dv) { + valueSet[dv] = types.True + } + iv := ev.ConvertToType(types.IntType) + if !types.IsError(iv) { + valueSet[iv] = types.True + } + } + } + return &evalSetMembership{ + inst: inlist, + arg: lhs, + valueSet: valueSet, + }, nil +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/dispatcher.go b/vendor/github.com/authzed/cel-go/interpreter/dispatcher.go new file mode 100644 index 0000000..b2a1eca --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/dispatcher.go @@ -0,0 +1,100 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "fmt" + + "github.com/authzed/cel-go/common/functions" +) + +// Dispatcher resolves function calls to their appropriate overload. +type Dispatcher interface { + // Add one or more overloads, returning an error if any Overload has the same Overload#Name. + Add(overloads ...*functions.Overload) error + + // FindOverload returns an Overload definition matching the provided name. + FindOverload(overload string) (*functions.Overload, bool) + + // OverloadIds returns the set of all overload identifiers configured for dispatch. + OverloadIds() []string +} + +// NewDispatcher returns an empty Dispatcher instance. +func NewDispatcher() Dispatcher { + return &defaultDispatcher{ + overloads: make(map[string]*functions.Overload)} +} + +// ExtendDispatcher returns a Dispatcher which inherits the overloads of its parent, and +// provides an isolation layer between built-ins and extension functions which is useful +// for forward compatibility. +func ExtendDispatcher(parent Dispatcher) Dispatcher { + return &defaultDispatcher{ + parent: parent, + overloads: make(map[string]*functions.Overload)} +} + +// overloadMap helper type for indexing overloads by function name. +type overloadMap map[string]*functions.Overload + +// defaultDispatcher struct which contains an overload map. +type defaultDispatcher struct { + parent Dispatcher + overloads overloadMap +} + +// Add implements the Dispatcher.Add interface method. +func (d *defaultDispatcher) Add(overloads ...*functions.Overload) error { + for _, o := range overloads { + // add the overload unless an overload of the same name has already been provided. + if _, found := d.overloads[o.Operator]; found { + return fmt.Errorf("overload already exists '%s'", o.Operator) + } + // index the overload by function name. + d.overloads[o.Operator] = o + } + return nil +} + +// FindOverload implements the Dispatcher.FindOverload interface method. +func (d *defaultDispatcher) FindOverload(overload string) (*functions.Overload, bool) { + o, found := d.overloads[overload] + // Attempt to dispatch to an overload defined in the parent. + if !found && d.parent != nil { + return d.parent.FindOverload(overload) + } + return o, found +} + +// OverloadIds implements the Dispatcher interface method. +func (d *defaultDispatcher) OverloadIds() []string { + i := 0 + overloads := make([]string, len(d.overloads)) + for name := range d.overloads { + overloads[i] = name + i++ + } + if d.parent == nil { + return overloads + } + parentOverloads := d.parent.OverloadIds() + for _, pName := range parentOverloads { + if _, found := d.overloads[pName]; !found { + overloads = append(overloads, pName) + } + } + return overloads +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/evalstate.go b/vendor/github.com/authzed/cel-go/interpreter/evalstate.go new file mode 100644 index 0000000..d0b8094 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/evalstate.go @@ -0,0 +1,79 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "github.com/authzed/cel-go/common/types/ref" +) + +// EvalState tracks the values associated with expression ids during execution. +type EvalState interface { + // IDs returns the list of ids with recorded values. + IDs() []int64 + + // Value returns the observed value of the given expression id if found, and a nil false + // result if not. + Value(int64) (ref.Val, bool) + + // SetValue sets the observed value of the expression id. + SetValue(int64, ref.Val) + + // Reset clears the previously recorded expression values. + Reset() +} + +// evalState permits the mutation of evaluation state for a given expression id. +type evalState struct { + values map[int64]ref.Val +} + +// NewEvalState returns an EvalState instanced used to observe the intermediate +// evaluations of an expression. +func NewEvalState() EvalState { + return &evalState{ + values: make(map[int64]ref.Val), + } +} + +// IDs implements the EvalState interface method. +func (s *evalState) IDs() []int64 { + var ids []int64 + for k, v := range s.values { + if v != nil { + ids = append(ids, k) + } + } + return ids +} + +// Value is an implementation of the EvalState interface method. +func (s *evalState) Value(exprID int64) (ref.Val, bool) { + val, found := s.values[exprID] + return val, found +} + +// SetValue is an implementation of the EvalState interface method. +func (s *evalState) SetValue(exprID int64, val ref.Val) { + if val == nil { + delete(s.values, exprID) + } else { + s.values[exprID] = val + } +} + +// Reset implements the EvalState interface method. +func (s *evalState) Reset() { + s.values = map[int64]ref.Val{} +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/interpretable.go b/vendor/github.com/authzed/cel-go/interpreter/interpretable.go new file mode 100644 index 0000000..2d583ab --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/interpretable.go @@ -0,0 +1,1264 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "fmt" + + "github.com/authzed/cel-go/common/functions" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" +) + +// Interpretable can accept a given Activation and produce a value along with +// an accompanying EvalState which can be used to inspect whether additional +// data might be necessary to complete the evaluation. +type Interpretable interface { + // ID value corresponding to the expression node. + ID() int64 + + // Eval an Activation to produce an output. + Eval(activation Activation) ref.Val +} + +// InterpretableConst interface for tracking whether the Interpretable is a constant value. +type InterpretableConst interface { + Interpretable + + // Value returns the constant value of the instruction. + Value() ref.Val +} + +// InterpretableAttribute interface for tracking whether the Interpretable is an attribute. +type InterpretableAttribute interface { + Interpretable + + // Attr returns the Attribute value. + Attr() Attribute + + // Adapter returns the type adapter to be used for adapting resolved Attribute values. + Adapter() types.Adapter + + // AddQualifier proxies the Attribute.AddQualifier method. + // + // Note, this method may mutate the current attribute state. If the desire is to clone the + // Attribute, the Attribute should first be copied before adding the qualifier. Attributes + // are not copyable by default, so this is a capable that would need to be added to the + // AttributeFactory or specifically to the underlying Attribute implementation. + AddQualifier(Qualifier) (Attribute, error) + + // Qualify replicates the Attribute.Qualify method to permit extension and interception + // of object qualification. + Qualify(vars Activation, obj any) (any, error) + + // QualifyIfPresent qualifies the object if the qualifier is declared or defined on the object. + // The 'presenceOnly' flag indicates that the value is not necessary, just a boolean status as + // to whether the qualifier is present. + QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) + + // IsOptional indicates whether the resulting value is an optional type. + IsOptional() bool + + // Resolve returns the value of the Attribute given the current Activation. + Resolve(Activation) (any, error) +} + +// InterpretableCall interface for inspecting Interpretable instructions related to function calls. +type InterpretableCall interface { + Interpretable + + // Function returns the function name as it appears in text or mangled operator name as it + // appears in the operators.go file. + Function() string + + // OverloadID returns the overload id associated with the function specialization. + // Overload ids are stable across language boundaries and can be treated as synonymous with a + // unique function signature. + OverloadID() string + + // Args returns the normalized arguments to the function overload. + // For receiver-style functions, the receiver target is arg 0. + Args() []Interpretable +} + +// InterpretableConstructor interface for inspecting Interpretable instructions that initialize a list, map +// or struct. +type InterpretableConstructor interface { + Interpretable + + // InitVals returns all the list elements, map key and values or struct field values. + InitVals() []Interpretable + + // Type returns the type constructed. + Type() ref.Type +} + +// Core Interpretable implementations used during the program planning phase. + +type evalTestOnly struct { + id int64 + InterpretableAttribute +} + +// ID implements the Interpretable interface method. +func (test *evalTestOnly) ID() int64 { + return test.id +} + +// Eval implements the Interpretable interface method. +func (test *evalTestOnly) Eval(ctx Activation) ref.Val { + val, err := test.Resolve(ctx) + // Return an error if the resolve step fails + if err != nil { + return types.LabelErrNode(test.id, types.WrapErr(err)) + } + if optVal, isOpt := val.(*types.Optional); isOpt { + return types.Bool(optVal.HasValue()) + } + return test.Adapter().NativeToValue(val) +} + +// AddQualifier appends a qualifier that will always and only perform a presence test. +func (test *evalTestOnly) AddQualifier(q Qualifier) (Attribute, error) { + cq, ok := q.(ConstantQualifier) + if !ok { + return nil, fmt.Errorf("test only expressions must have constant qualifiers: %v", q) + } + return test.InterpretableAttribute.AddQualifier(&testOnlyQualifier{ConstantQualifier: cq}) +} + +type testOnlyQualifier struct { + ConstantQualifier +} + +// Qualify determines whether the test-only qualifier is present on the input object. +func (q *testOnlyQualifier) Qualify(vars Activation, obj any) (any, error) { + out, present, err := q.ConstantQualifier.QualifyIfPresent(vars, obj, true) + if err != nil { + return nil, err + } + if unk, isUnk := out.(types.Unknown); isUnk { + return unk, nil + } + if opt, isOpt := out.(types.Optional); isOpt { + return opt.HasValue(), nil + } + return present, nil +} + +// QualifyIfPresent returns whether the target field in the test-only expression is present. +func (q *testOnlyQualifier) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + // Only ever test for presence. + return q.ConstantQualifier.QualifyIfPresent(vars, obj, true) +} + +// QualifierValueEquals determines whether the test-only constant qualifier equals the input value. +func (q *testOnlyQualifier) QualifierValueEquals(value any) bool { + // The input qualifier will always be of type string + return q.ConstantQualifier.Value().Value() == value +} + +// NewConstValue creates a new constant valued Interpretable. +func NewConstValue(id int64, val ref.Val) InterpretableConst { + return &evalConst{ + id: id, + val: val, + } +} + +type evalConst struct { + id int64 + val ref.Val +} + +// ID implements the Interpretable interface method. +func (cons *evalConst) ID() int64 { + return cons.id +} + +// Eval implements the Interpretable interface method. +func (cons *evalConst) Eval(ctx Activation) ref.Val { + return cons.val +} + +// Value implements the InterpretableConst interface method. +func (cons *evalConst) Value() ref.Val { + return cons.val +} + +type evalOr struct { + id int64 + terms []Interpretable +} + +// ID implements the Interpretable interface method. +func (or *evalOr) ID() int64 { + return or.id +} + +// Eval implements the Interpretable interface method. +func (or *evalOr) Eval(ctx Activation) ref.Val { + var err ref.Val = nil + var unk *types.Unknown + for _, term := range or.terms { + val := term.Eval(ctx) + boolVal, ok := val.(types.Bool) + // short-circuit on true. + if ok && boolVal == types.True { + return types.True + } + if !ok { + isUnk := false + unk, isUnk = types.MaybeMergeUnknowns(val, unk) + if !isUnk && err == nil { + if types.IsError(val) { + err = val + } else { + err = types.MaybeNoSuchOverloadErr(val) + } + err = types.LabelErrNode(or.id, err) + } + } + } + if unk != nil { + return unk + } + if err != nil { + return err + } + return types.False +} + +type evalAnd struct { + id int64 + terms []Interpretable +} + +// ID implements the Interpretable interface method. +func (and *evalAnd) ID() int64 { + return and.id +} + +// Eval implements the Interpretable interface method. +func (and *evalAnd) Eval(ctx Activation) ref.Val { + var err ref.Val = nil + var unk *types.Unknown + for _, term := range and.terms { + val := term.Eval(ctx) + boolVal, ok := val.(types.Bool) + // short-circuit on false. + if ok && boolVal == types.False { + return types.False + } + if !ok { + isUnk := false + unk, isUnk = types.MaybeMergeUnknowns(val, unk) + if !isUnk && err == nil { + if types.IsError(val) { + err = val + } else { + err = types.MaybeNoSuchOverloadErr(val) + } + err = types.LabelErrNode(and.id, err) + } + } + } + if unk != nil { + return unk + } + if err != nil { + return err + } + return types.True +} + +type evalEq struct { + id int64 + lhs Interpretable + rhs Interpretable +} + +// ID implements the Interpretable interface method. +func (eq *evalEq) ID() int64 { + return eq.id +} + +// Eval implements the Interpretable interface method. +func (eq *evalEq) Eval(ctx Activation) ref.Val { + lVal := eq.lhs.Eval(ctx) + rVal := eq.rhs.Eval(ctx) + if types.IsUnknownOrError(lVal) { + return lVal + } + if types.IsUnknownOrError(rVal) { + return rVal + } + return types.Equal(lVal, rVal) +} + +// Function implements the InterpretableCall interface method. +func (*evalEq) Function() string { + return operators.Equals +} + +// OverloadID implements the InterpretableCall interface method. +func (*evalEq) OverloadID() string { + return overloads.Equals +} + +// Args implements the InterpretableCall interface method. +func (eq *evalEq) Args() []Interpretable { + return []Interpretable{eq.lhs, eq.rhs} +} + +type evalNe struct { + id int64 + lhs Interpretable + rhs Interpretable +} + +// ID implements the Interpretable interface method. +func (ne *evalNe) ID() int64 { + return ne.id +} + +// Eval implements the Interpretable interface method. +func (ne *evalNe) Eval(ctx Activation) ref.Val { + lVal := ne.lhs.Eval(ctx) + rVal := ne.rhs.Eval(ctx) + if types.IsUnknownOrError(lVal) { + return lVal + } + if types.IsUnknownOrError(rVal) { + return rVal + } + return types.Bool(types.Equal(lVal, rVal) != types.True) +} + +// Function implements the InterpretableCall interface method. +func (*evalNe) Function() string { + return operators.NotEquals +} + +// OverloadID implements the InterpretableCall interface method. +func (*evalNe) OverloadID() string { + return overloads.NotEquals +} + +// Args implements the InterpretableCall interface method. +func (ne *evalNe) Args() []Interpretable { + return []Interpretable{ne.lhs, ne.rhs} +} + +type evalZeroArity struct { + id int64 + function string + overload string + impl functions.FunctionOp +} + +// ID implements the Interpretable interface method. +func (zero *evalZeroArity) ID() int64 { + return zero.id +} + +// Eval implements the Interpretable interface method. +func (zero *evalZeroArity) Eval(ctx Activation) ref.Val { + return types.LabelErrNode(zero.id, zero.impl()) +} + +// Function implements the InterpretableCall interface method. +func (zero *evalZeroArity) Function() string { + return zero.function +} + +// OverloadID implements the InterpretableCall interface method. +func (zero *evalZeroArity) OverloadID() string { + return zero.overload +} + +// Args returns the argument to the unary function. +func (zero *evalZeroArity) Args() []Interpretable { + return []Interpretable{} +} + +type evalUnary struct { + id int64 + function string + overload string + arg Interpretable + trait int + impl functions.UnaryOp + nonStrict bool +} + +// ID implements the Interpretable interface method. +func (un *evalUnary) ID() int64 { + return un.id +} + +// Eval implements the Interpretable interface method. +func (un *evalUnary) Eval(ctx Activation) ref.Val { + argVal := un.arg.Eval(ctx) + // Early return if the argument to the function is unknown or error. + strict := !un.nonStrict + if strict && types.IsUnknownOrError(argVal) { + return argVal + } + // If the implementation is bound and the argument value has the right traits required to + // invoke it, then call the implementation. + if un.impl != nil && (un.trait == 0 || (!strict && types.IsUnknownOrError(argVal)) || argVal.Type().HasTrait(un.trait)) { + return types.LabelErrNode(un.id, un.impl(argVal)) + } + // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the + // operand (arg0). + if argVal.Type().HasTrait(traits.ReceiverType) { + return types.LabelErrNode(un.id, argVal.(traits.Receiver).Receive(un.function, un.overload, []ref.Val{})) + } + return types.NewErrWithNodeID(un.id, "no such overload: %s", un.function) +} + +// Function implements the InterpretableCall interface method. +func (un *evalUnary) Function() string { + return un.function +} + +// OverloadID implements the InterpretableCall interface method. +func (un *evalUnary) OverloadID() string { + return un.overload +} + +// Args returns the argument to the unary function. +func (un *evalUnary) Args() []Interpretable { + return []Interpretable{un.arg} +} + +type evalBinary struct { + id int64 + function string + overload string + lhs Interpretable + rhs Interpretable + trait int + impl functions.BinaryOp + nonStrict bool +} + +// ID implements the Interpretable interface method. +func (bin *evalBinary) ID() int64 { + return bin.id +} + +// Eval implements the Interpretable interface method. +func (bin *evalBinary) Eval(ctx Activation) ref.Val { + lVal := bin.lhs.Eval(ctx) + rVal := bin.rhs.Eval(ctx) + // Early return if any argument to the function is unknown or error. + strict := !bin.nonStrict + if strict { + if types.IsUnknownOrError(lVal) { + return lVal + } + if types.IsUnknownOrError(rVal) { + return rVal + } + } + // If the implementation is bound and the argument value has the right traits required to + // invoke it, then call the implementation. + if bin.impl != nil && (bin.trait == 0 || (!strict && types.IsUnknownOrError(lVal)) || lVal.Type().HasTrait(bin.trait)) { + return types.LabelErrNode(bin.id, bin.impl(lVal, rVal)) + } + // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the + // operand (arg0). + if lVal.Type().HasTrait(traits.ReceiverType) { + return types.LabelErrNode(bin.id, lVal.(traits.Receiver).Receive(bin.function, bin.overload, []ref.Val{rVal})) + } + return types.NewErrWithNodeID(bin.id, "no such overload: %s", bin.function) +} + +// Function implements the InterpretableCall interface method. +func (bin *evalBinary) Function() string { + return bin.function +} + +// OverloadID implements the InterpretableCall interface method. +func (bin *evalBinary) OverloadID() string { + return bin.overload +} + +// Args returns the argument to the unary function. +func (bin *evalBinary) Args() []Interpretable { + return []Interpretable{bin.lhs, bin.rhs} +} + +type evalVarArgs struct { + id int64 + function string + overload string + args []Interpretable + trait int + impl functions.FunctionOp + nonStrict bool +} + +// NewCall creates a new call Interpretable. +func NewCall(id int64, function, overload string, args []Interpretable, impl functions.FunctionOp) InterpretableCall { + return &evalVarArgs{ + id: id, + function: function, + overload: overload, + args: args, + impl: impl, + } +} + +// ID implements the Interpretable interface method. +func (fn *evalVarArgs) ID() int64 { + return fn.id +} + +// Eval implements the Interpretable interface method. +func (fn *evalVarArgs) Eval(ctx Activation) ref.Val { + argVals := make([]ref.Val, len(fn.args)) + // Early return if any argument to the function is unknown or error. + strict := !fn.nonStrict + for i, arg := range fn.args { + argVals[i] = arg.Eval(ctx) + if strict && types.IsUnknownOrError(argVals[i]) { + return argVals[i] + } + } + // If the implementation is bound and the argument value has the right traits required to + // invoke it, then call the implementation. + arg0 := argVals[0] + if fn.impl != nil && (fn.trait == 0 || (!strict && types.IsUnknownOrError(arg0)) || arg0.Type().HasTrait(fn.trait)) { + return types.LabelErrNode(fn.id, fn.impl(argVals...)) + } + // Otherwise, if the argument is a ReceiverType attempt to invoke the receiver method on the + // operand (arg0). + if arg0.Type().HasTrait(traits.ReceiverType) { + return types.LabelErrNode(fn.id, arg0.(traits.Receiver).Receive(fn.function, fn.overload, argVals[1:])) + } + return types.NewErrWithNodeID(fn.id, "no such overload: %s %d", fn.function, fn.id) +} + +// Function implements the InterpretableCall interface method. +func (fn *evalVarArgs) Function() string { + return fn.function +} + +// OverloadID implements the InterpretableCall interface method. +func (fn *evalVarArgs) OverloadID() string { + return fn.overload +} + +// Args returns the argument to the unary function. +func (fn *evalVarArgs) Args() []Interpretable { + return fn.args +} + +type evalList struct { + id int64 + elems []Interpretable + optionals []bool + hasOptionals bool + adapter types.Adapter +} + +// ID implements the Interpretable interface method. +func (l *evalList) ID() int64 { + return l.id +} + +// Eval implements the Interpretable interface method. +func (l *evalList) Eval(ctx Activation) ref.Val { + elemVals := make([]ref.Val, 0, len(l.elems)) + // If any argument is unknown or error early terminate. + for i, elem := range l.elems { + elemVal := elem.Eval(ctx) + if types.IsUnknownOrError(elemVal) { + return elemVal + } + if l.hasOptionals && l.optionals[i] { + optVal, ok := elemVal.(*types.Optional) + if !ok { + return types.LabelErrNode(l.id, invalidOptionalElementInit(elemVal)) + } + if !optVal.HasValue() { + continue + } + elemVal = optVal.GetValue() + } + elemVals = append(elemVals, elemVal) + } + return l.adapter.NativeToValue(elemVals) +} + +func (l *evalList) InitVals() []Interpretable { + return l.elems +} + +func (l *evalList) Type() ref.Type { + return types.ListType +} + +type evalMap struct { + id int64 + keys []Interpretable + vals []Interpretable + optionals []bool + hasOptionals bool + adapter types.Adapter +} + +// ID implements the Interpretable interface method. +func (m *evalMap) ID() int64 { + return m.id +} + +// Eval implements the Interpretable interface method. +func (m *evalMap) Eval(ctx Activation) ref.Val { + entries := make(map[ref.Val]ref.Val) + // If any argument is unknown or error early terminate. + for i, key := range m.keys { + keyVal := key.Eval(ctx) + if types.IsUnknownOrError(keyVal) { + return keyVal + } + valVal := m.vals[i].Eval(ctx) + if types.IsUnknownOrError(valVal) { + return valVal + } + if m.hasOptionals && m.optionals[i] { + optVal, ok := valVal.(*types.Optional) + if !ok { + return types.LabelErrNode(m.id, invalidOptionalEntryInit(keyVal, valVal)) + } + if !optVal.HasValue() { + delete(entries, keyVal) + continue + } + valVal = optVal.GetValue() + } + entries[keyVal] = valVal + } + return m.adapter.NativeToValue(entries) +} + +func (m *evalMap) InitVals() []Interpretable { + if len(m.keys) != len(m.vals) { + return nil + } + result := make([]Interpretable, len(m.keys)+len(m.vals)) + idx := 0 + for i, k := range m.keys { + v := m.vals[i] + result[idx] = k + idx++ + result[idx] = v + idx++ + } + return result +} + +func (m *evalMap) Type() ref.Type { + return types.MapType +} + +type evalObj struct { + id int64 + typeName string + fields []string + vals []Interpretable + optionals []bool + hasOptionals bool + provider types.Provider +} + +// ID implements the Interpretable interface method. +func (o *evalObj) ID() int64 { + return o.id +} + +// Eval implements the Interpretable interface method. +func (o *evalObj) Eval(ctx Activation) ref.Val { + fieldVals := make(map[string]ref.Val) + // If any argument is unknown or error early terminate. + for i, field := range o.fields { + val := o.vals[i].Eval(ctx) + if types.IsUnknownOrError(val) { + return val + } + if o.hasOptionals && o.optionals[i] { + optVal, ok := val.(*types.Optional) + if !ok { + return types.LabelErrNode(o.id, invalidOptionalEntryInit(field, val)) + } + if !optVal.HasValue() { + delete(fieldVals, field) + continue + } + val = optVal.GetValue() + } + fieldVals[field] = val + } + return types.LabelErrNode(o.id, o.provider.NewValue(o.typeName, fieldVals)) +} + +func (o *evalObj) InitVals() []Interpretable { + return o.vals +} + +func (o *evalObj) Type() ref.Type { + return types.NewObjectTypeValue(o.typeName) +} + +type evalFold struct { + id int64 + accuVar string + iterVar string + iterRange Interpretable + accu Interpretable + cond Interpretable + step Interpretable + result Interpretable + adapter types.Adapter + exhaustive bool + interruptable bool +} + +// ID implements the Interpretable interface method. +func (fold *evalFold) ID() int64 { + return fold.id +} + +// Eval implements the Interpretable interface method. +func (fold *evalFold) Eval(ctx Activation) ref.Val { + foldRange := fold.iterRange.Eval(ctx) + if !foldRange.Type().HasTrait(traits.IterableType) { + return types.ValOrErr(foldRange, "got '%T', expected iterable type", foldRange) + } + // Configure the fold activation with the accumulator initial value. + accuCtx := varActivationPool.Get().(*varActivation) + accuCtx.parent = ctx + accuCtx.name = fold.accuVar + accuCtx.val = fold.accu.Eval(ctx) + // If the accumulator starts as an empty list, then the comprehension will build a list + // so create a mutable list to optimize the cost of the inner loop. + l, ok := accuCtx.val.(traits.Lister) + buildingList := false + if !fold.exhaustive && ok && l.Size() == types.IntZero { + buildingList = true + accuCtx.val = types.NewMutableList(fold.adapter) + } + iterCtx := varActivationPool.Get().(*varActivation) + iterCtx.parent = accuCtx + iterCtx.name = fold.iterVar + + interrupted := false + it := foldRange.(traits.Iterable).Iterator() + for it.HasNext() == types.True { + // Modify the iter var in the fold activation. + iterCtx.val = it.Next() + + // Evaluate the condition, terminate the loop if false. + cond := fold.cond.Eval(iterCtx) + condBool, ok := cond.(types.Bool) + if !fold.exhaustive && ok && condBool != types.True { + break + } + // Evaluate the evaluation step into accu var. + accuCtx.val = fold.step.Eval(iterCtx) + if fold.interruptable { + if stop, found := ctx.ResolveName("#interrupted"); found && stop == true { + interrupted = true + break + } + } + } + varActivationPool.Put(iterCtx) + if interrupted { + varActivationPool.Put(accuCtx) + return types.NewErr("operation interrupted") + } + + // Compute the result. + res := fold.result.Eval(accuCtx) + varActivationPool.Put(accuCtx) + // Convert a mutable list to an immutable one, if the comprehension has generated a list as a result. + if !types.IsUnknownOrError(res) && buildingList { + if _, ok := res.(traits.MutableLister); ok { + res = res.(traits.MutableLister).ToImmutableList() + } + } + return res +} + +// Optional Interpretable implementations that specialize, subsume, or extend the core evaluation +// plan via decorators. + +// evalSetMembership is an Interpretable implementation which tests whether an input value +// exists within the set of map keys used to model a set. +type evalSetMembership struct { + inst Interpretable + arg Interpretable + valueSet map[ref.Val]ref.Val +} + +// ID implements the Interpretable interface method. +func (e *evalSetMembership) ID() int64 { + return e.inst.ID() +} + +// Eval implements the Interpretable interface method. +func (e *evalSetMembership) Eval(ctx Activation) ref.Val { + val := e.arg.Eval(ctx) + if types.IsUnknownOrError(val) { + return val + } + if ret, found := e.valueSet[val]; found { + return ret + } + return types.False +} + +// evalWatch is an Interpretable implementation that wraps the execution of a given +// expression so that it may observe the computed value and send it to an observer. +type evalWatch struct { + Interpretable + observer EvalObserver +} + +// Eval implements the Interpretable interface method. +func (e *evalWatch) Eval(ctx Activation) ref.Val { + val := e.Interpretable.Eval(ctx) + e.observer(e.ID(), e.Interpretable, val) + return val +} + +// evalWatchAttr describes a watcher of an InterpretableAttribute Interpretable. +// +// Since the watcher may be selected against at a later stage in program planning, the watcher +// must implement the InterpretableAttribute interface by proxy. +type evalWatchAttr struct { + InterpretableAttribute + observer EvalObserver +} + +// AddQualifier creates a wrapper over the incoming qualifier which observes the qualification +// result. +func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) { + switch qual := q.(type) { + // By default, the qualifier is either a constant or an attribute + // There may be some custom cases where the attribute is neither. + case ConstantQualifier: + // Expose a method to test whether the qualifier matches the input pattern. + q = &evalWatchConstQual{ + ConstantQualifier: qual, + observer: e.observer, + adapter: e.Adapter(), + } + case *evalWatchAttr: + // Unwrap the evalWatchAttr since the observation will be applied during Qualify or + // QualifyIfPresent rather than Eval. + q = &evalWatchAttrQual{ + Attribute: qual.InterpretableAttribute, + observer: e.observer, + adapter: e.Adapter(), + } + case Attribute: + // Expose methods which intercept the qualification prior to being applied as a qualifier. + // Using this interface ensures that the qualifier is converted to a constant value one + // time during attribute pattern matching as the method embeds the Attribute interface + // needed to trip the conversion to a constant. + q = &evalWatchAttrQual{ + Attribute: qual, + observer: e.observer, + adapter: e.Adapter(), + } + default: + // This is likely a custom qualifier type. + q = &evalWatchQual{ + Qualifier: qual, + observer: e.observer, + adapter: e.Adapter(), + } + } + _, err := e.InterpretableAttribute.AddQualifier(q) + return e, err +} + +// Eval implements the Interpretable interface method. +func (e *evalWatchAttr) Eval(vars Activation) ref.Val { + val := e.InterpretableAttribute.Eval(vars) + e.observer(e.ID(), e.InterpretableAttribute, val) + return val +} + +// evalWatchConstQual observes the qualification of an object using a constant boolean, int, +// string, or uint. +type evalWatchConstQual struct { + ConstantQualifier + observer EvalObserver + adapter types.Adapter +} + +// Qualify observes the qualification of a object via a constant boolean, int, string, or uint. +func (e *evalWatchConstQual) Qualify(vars Activation, obj any) (any, error) { + out, err := e.ConstantQualifier.Qualify(vars, obj) + var val ref.Val + if err != nil { + val = types.LabelErrNode(e.ID(), types.WrapErr(err)) + } else { + val = e.adapter.NativeToValue(out) + } + e.observer(e.ID(), e.ConstantQualifier, val) + return out, err +} + +// QualifyIfPresent conditionally qualifies the variable and only records a value if one is present. +func (e *evalWatchConstQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + out, present, err := e.ConstantQualifier.QualifyIfPresent(vars, obj, presenceOnly) + var val ref.Val + if err != nil { + val = types.LabelErrNode(e.ID(), types.WrapErr(err)) + } else if out != nil { + val = e.adapter.NativeToValue(out) + } else if presenceOnly { + val = types.Bool(present) + } + if present || presenceOnly { + e.observer(e.ID(), e.ConstantQualifier, val) + } + return out, present, err +} + +// QualifierValueEquals tests whether the incoming value is equal to the qualifying constant. +func (e *evalWatchConstQual) QualifierValueEquals(value any) bool { + qve, ok := e.ConstantQualifier.(qualifierValueEquator) + return ok && qve.QualifierValueEquals(value) +} + +// evalWatchAttrQual observes the qualification of an object by a value computed at runtime. +type evalWatchAttrQual struct { + Attribute + observer EvalObserver + adapter ref.TypeAdapter +} + +// Qualify observes the qualification of a object via a value computed at runtime. +func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) { + out, err := e.Attribute.Qualify(vars, obj) + var val ref.Val + if err != nil { + val = types.LabelErrNode(e.ID(), types.WrapErr(err)) + } else { + val = e.adapter.NativeToValue(out) + } + e.observer(e.ID(), e.Attribute, val) + return out, err +} + +// QualifyIfPresent conditionally qualifies the variable and only records a value if one is present. +func (e *evalWatchAttrQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + out, present, err := e.Attribute.QualifyIfPresent(vars, obj, presenceOnly) + var val ref.Val + if err != nil { + val = types.LabelErrNode(e.ID(), types.WrapErr(err)) + } else if out != nil { + val = e.adapter.NativeToValue(out) + } else if presenceOnly { + val = types.Bool(present) + } + if present || presenceOnly { + e.observer(e.ID(), e.Attribute, val) + } + return out, present, err +} + +// evalWatchQual observes the qualification of an object by a value computed at runtime. +type evalWatchQual struct { + Qualifier + observer EvalObserver + adapter types.Adapter +} + +// Qualify observes the qualification of a object via a value computed at runtime. +func (e *evalWatchQual) Qualify(vars Activation, obj any) (any, error) { + out, err := e.Qualifier.Qualify(vars, obj) + var val ref.Val + if err != nil { + val = types.LabelErrNode(e.ID(), types.WrapErr(err)) + } else { + val = e.adapter.NativeToValue(out) + } + e.observer(e.ID(), e.Qualifier, val) + return out, err +} + +// QualifyIfPresent conditionally qualifies the variable and only records a value if one is present. +func (e *evalWatchQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + out, present, err := e.Qualifier.QualifyIfPresent(vars, obj, presenceOnly) + var val ref.Val + if err != nil { + val = types.LabelErrNode(e.ID(), types.WrapErr(err)) + } else if out != nil { + val = e.adapter.NativeToValue(out) + } else if presenceOnly { + val = types.Bool(present) + } + if present || presenceOnly { + e.observer(e.ID(), e.Qualifier, val) + } + return out, present, err +} + +// evalWatchConst describes a watcher of an instConst Interpretable. +type evalWatchConst struct { + InterpretableConst + observer EvalObserver +} + +// Eval implements the Interpretable interface method. +func (e *evalWatchConst) Eval(vars Activation) ref.Val { + val := e.Value() + e.observer(e.ID(), e.InterpretableConst, val) + return val +} + +// evalExhaustiveOr is just like evalOr, but does not short-circuit argument evaluation. +type evalExhaustiveOr struct { + id int64 + terms []Interpretable +} + +// ID implements the Interpretable interface method. +func (or *evalExhaustiveOr) ID() int64 { + return or.id +} + +// Eval implements the Interpretable interface method. +func (or *evalExhaustiveOr) Eval(ctx Activation) ref.Val { + var err ref.Val = nil + var unk *types.Unknown + isTrue := false + for _, term := range or.terms { + val := term.Eval(ctx) + boolVal, ok := val.(types.Bool) + // flag the result as true + if ok && boolVal == types.True { + isTrue = true + } + if !ok && !isTrue { + isUnk := false + unk, isUnk = types.MaybeMergeUnknowns(val, unk) + if !isUnk && err == nil { + if types.IsError(val) { + err = val + } else { + err = types.MaybeNoSuchOverloadErr(val) + } + } + } + } + if isTrue { + return types.True + } + if unk != nil { + return unk + } + if err != nil { + return err + } + return types.False +} + +// evalExhaustiveAnd is just like evalAnd, but does not short-circuit argument evaluation. +type evalExhaustiveAnd struct { + id int64 + terms []Interpretable +} + +// ID implements the Interpretable interface method. +func (and *evalExhaustiveAnd) ID() int64 { + return and.id +} + +// Eval implements the Interpretable interface method. +func (and *evalExhaustiveAnd) Eval(ctx Activation) ref.Val { + var err ref.Val = nil + var unk *types.Unknown + isFalse := false + for _, term := range and.terms { + val := term.Eval(ctx) + boolVal, ok := val.(types.Bool) + // short-circuit on false. + if ok && boolVal == types.False { + isFalse = true + } + if !ok && !isFalse { + isUnk := false + unk, isUnk = types.MaybeMergeUnknowns(val, unk) + if !isUnk && err == nil { + if types.IsError(val) { + err = val + } else { + err = types.MaybeNoSuchOverloadErr(val) + } + } + } + } + if isFalse { + return types.False + } + if unk != nil { + return unk + } + if err != nil { + return err + } + return types.True +} + +// evalExhaustiveConditional is like evalConditional, but does not short-circuit argument +// evaluation. +type evalExhaustiveConditional struct { + id int64 + adapter types.Adapter + attr *conditionalAttribute +} + +// ID implements the Interpretable interface method. +func (cond *evalExhaustiveConditional) ID() int64 { + return cond.id +} + +// Eval implements the Interpretable interface method. +func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val { + cVal := cond.attr.expr.Eval(ctx) + tVal, tErr := cond.attr.truthy.Resolve(ctx) + fVal, fErr := cond.attr.falsy.Resolve(ctx) + cBool, ok := cVal.(types.Bool) + if !ok { + return types.ValOrErr(cVal, "no such overload") + } + if cBool { + if tErr != nil { + return types.LabelErrNode(cond.id, types.WrapErr(tErr)) + } + return cond.adapter.NativeToValue(tVal) + } + if fErr != nil { + return types.LabelErrNode(cond.id, types.WrapErr(fErr)) + } + return cond.adapter.NativeToValue(fVal) +} + +// evalAttr evaluates an Attribute value. +type evalAttr struct { + adapter types.Adapter + attr Attribute + optional bool +} + +var _ InterpretableAttribute = &evalAttr{} + +// ID of the attribute instruction. +func (a *evalAttr) ID() int64 { + return a.attr.ID() +} + +// AddQualifier implements the InterpretableAttribute interface method. +func (a *evalAttr) AddQualifier(qual Qualifier) (Attribute, error) { + attr, err := a.attr.AddQualifier(qual) + a.attr = attr + return attr, err +} + +// Attr implements the InterpretableAttribute interface method. +func (a *evalAttr) Attr() Attribute { + return a.attr +} + +// Adapter implements the InterpretableAttribute interface method. +func (a *evalAttr) Adapter() types.Adapter { + return a.adapter +} + +// Eval implements the Interpretable interface method. +func (a *evalAttr) Eval(ctx Activation) ref.Val { + v, err := a.attr.Resolve(ctx) + if err != nil { + return types.LabelErrNode(a.ID(), types.WrapErr(err)) + } + return a.adapter.NativeToValue(v) +} + +// Qualify proxies to the Attribute's Qualify method. +func (a *evalAttr) Qualify(ctx Activation, obj any) (any, error) { + return a.attr.Qualify(ctx, obj) +} + +// QualifyIfPresent proxies to the Attribute's QualifyIfPresent method. +func (a *evalAttr) QualifyIfPresent(ctx Activation, obj any, presenceOnly bool) (any, bool, error) { + return a.attr.QualifyIfPresent(ctx, obj, presenceOnly) +} + +func (a *evalAttr) IsOptional() bool { + return a.optional +} + +// Resolve proxies to the Attribute's Resolve method. +func (a *evalAttr) Resolve(ctx Activation) (any, error) { + return a.attr.Resolve(ctx) +} + +type evalWatchConstructor struct { + constructor InterpretableConstructor + observer EvalObserver +} + +// InitVals implements the InterpretableConstructor InitVals function. +func (c *evalWatchConstructor) InitVals() []Interpretable { + return c.constructor.InitVals() +} + +// Type implements the InterpretableConstructor Type function. +func (c *evalWatchConstructor) Type() ref.Type { + return c.constructor.Type() +} + +// ID implements the Interpretable ID function. +func (c *evalWatchConstructor) ID() int64 { + return c.constructor.ID() +} + +// Eval implements the Interpretable Eval function. +func (c *evalWatchConstructor) Eval(ctx Activation) ref.Val { + val := c.constructor.Eval(ctx) + c.observer(c.ID(), c.constructor, val) + return val +} + +func invalidOptionalEntryInit(field any, value ref.Val) ref.Val { + return types.NewErr("cannot initialize optional entry '%v' from non-optional value %v", field, value) +} + +func invalidOptionalElementInit(value ref.Val) ref.Val { + return types.NewErr("cannot initialize optional list element from non-optional value %v", value) +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/interpreter.go b/vendor/github.com/authzed/cel-go/interpreter/interpreter.go new file mode 100644 index 0000000..96eea0e --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/interpreter.go @@ -0,0 +1,185 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package interpreter provides functions to evaluate parsed expressions with +// the option to augment the evaluation with inputs and functions supplied at +// evaluation time. +package interpreter + +import ( + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/containers" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" +) + +// Interpreter generates a new Interpretable from a checked or unchecked expression. +type Interpreter interface { + // NewInterpretable creates an Interpretable from a checked expression and an + // optional list of InterpretableDecorator values. + NewInterpretable(exprAST *ast.AST, decorators ...InterpretableDecorator) (Interpretable, error) +} + +// EvalObserver is a functional interface that accepts an expression id and an observed value. +// The id identifies the expression that was evaluated, the programStep is the Interpretable or Qualifier that +// was evaluated and value is the result of the evaluation. +type EvalObserver func(id int64, programStep any, value ref.Val) + +// Observe constructs a decorator that calls all the provided observers in order after evaluating each Interpretable +// or Qualifier during program evaluation. +func Observe(observers ...EvalObserver) InterpretableDecorator { + if len(observers) == 1 { + return decObserveEval(observers[0]) + } + observeFn := func(id int64, programStep any, val ref.Val) { + for _, observer := range observers { + observer(id, programStep, val) + } + } + return decObserveEval(observeFn) +} + +// EvalCancelledError represents a cancelled program evaluation operation. +type EvalCancelledError struct { + Message string + // Type identifies the cause of the cancellation. + Cause CancellationCause +} + +func (e EvalCancelledError) Error() string { + return e.Message +} + +// CancellationCause enumerates the ways a program evaluation operation can be cancelled. +type CancellationCause int + +const ( + // ContextCancelled indicates that the operation was cancelled in response to a Golang context cancellation. + ContextCancelled CancellationCause = iota + + // CostLimitExceeded indicates that the operation was cancelled in response to the actual cost limit being + // exceeded. + CostLimitExceeded +) + +// TODO: Replace all usages of TrackState with EvalStateObserver + +// TrackState decorates each expression node with an observer which records the value +// associated with the given expression id. EvalState must be provided to the decorator. +// This decorator is not thread-safe, and the EvalState must be reset between Eval() +// calls. +// DEPRECATED: Please use EvalStateObserver instead. It composes gracefully with additional observers. +func TrackState(state EvalState) InterpretableDecorator { + return Observe(EvalStateObserver(state)) +} + +// EvalStateObserver provides an observer which records the value +// associated with the given expression id. EvalState must be provided to the observer. +// This decorator is not thread-safe, and the EvalState must be reset between Eval() +// calls. +func EvalStateObserver(state EvalState) EvalObserver { + return func(id int64, programStep any, val ref.Val) { + state.SetValue(id, val) + } +} + +// ExhaustiveEval replaces operations that short-circuit with versions that evaluate +// expressions and couples this behavior with the TrackState() decorator to provide +// insight into the evaluation state of the entire expression. EvalState must be +// provided to the decorator. This decorator is not thread-safe, and the EvalState +// must be reset between Eval() calls. +func ExhaustiveEval() InterpretableDecorator { + ex := decDisableShortcircuits() + return func(i Interpretable) (Interpretable, error) { + return ex(i) + } +} + +// InterruptableEval annotates comprehension loops with information that indicates they +// should check the `#interrupted` state within a custom Activation. +// +// The custom activation is currently managed higher up in the stack within the 'cel' package +// and should not require any custom support on behalf of callers. +func InterruptableEval() InterpretableDecorator { + return decInterruptFolds() +} + +// Optimize will pre-compute operations such as list and map construction and optimize +// call arguments to set membership tests. The set of optimizations will increase over time. +func Optimize() InterpretableDecorator { + return decOptimize() +} + +// RegexOptimization provides a way to replace an InterpretableCall for a regex function when the +// RegexIndex argument is a string constant. Typically, the Factory would compile the regex pattern at +// RegexIndex and report any errors (at program creation time) and then use the compiled regex for +// all regex function invocations. +type RegexOptimization struct { + // Function is the name of the function to optimize. + Function string + // OverloadID is the ID of the overload to optimize. + OverloadID string + // RegexIndex is the index position of the regex pattern argument. Only calls to the function where this argument is + // a string constant will be delegated to this optimizer. + RegexIndex int + // Factory constructs a replacement InterpretableCall node that optimizes the regex function call. Factory is + // provided with the unoptimized regex call and the string constant at the RegexIndex argument. + // The Factory may compile the regex for use across all invocations of the call, return any errors and + // return an interpreter.NewCall with the desired regex optimized function impl. + Factory func(call InterpretableCall, regexPattern string) (InterpretableCall, error) +} + +// CompileRegexConstants compiles regex pattern string constants at program creation time and reports any regex pattern +// compile errors. +func CompileRegexConstants(regexOptimizations ...*RegexOptimization) InterpretableDecorator { + return decRegexOptimizer(regexOptimizations...) +} + +type exprInterpreter struct { + dispatcher Dispatcher + container *containers.Container + provider types.Provider + adapter types.Adapter + attrFactory AttributeFactory +} + +// NewInterpreter builds an Interpreter from a Dispatcher and TypeProvider which will be used +// throughout the Eval of all Interpretable instances generated from it. +func NewInterpreter(dispatcher Dispatcher, + container *containers.Container, + provider types.Provider, + adapter types.Adapter, + attrFactory AttributeFactory) Interpreter { + return &exprInterpreter{ + dispatcher: dispatcher, + container: container, + provider: provider, + adapter: adapter, + attrFactory: attrFactory} +} + +// NewIntepretable implements the Interpreter interface method. +func (i *exprInterpreter) NewInterpretable( + checked *ast.AST, + decorators ...InterpretableDecorator) (Interpretable, error) { + p := newPlanner( + i.dispatcher, + i.provider, + i.adapter, + i.attrFactory, + i.container, + checked, + decorators...) + return p.Plan(checked.Expr()) +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/optimizations.go b/vendor/github.com/authzed/cel-go/interpreter/optimizations.go new file mode 100644 index 0000000..5a90513 --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/optimizations.go @@ -0,0 +1,46 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "regexp" + + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" +) + +// MatchesRegexOptimization optimizes the 'matches' standard library function by compiling the regex pattern and +// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function +// call invocations. +var MatchesRegexOptimization = &RegexOptimization{ + Function: "matches", + RegexIndex: 1, + Factory: func(call InterpretableCall, regexPattern string) (InterpretableCall, error) { + compiledRegex, err := regexp.Compile(regexPattern) + if err != nil { + return nil, err + } + return NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(values ...ref.Val) ref.Val { + if len(values) != 2 { + return types.NoSuchOverloadErr() + } + in, ok := values[0].Value().(string) + if !ok { + return types.NoSuchOverloadErr() + } + return types.Bool(compiledRegex.MatchString(in)) + }), nil + }, +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/planner.go b/vendor/github.com/authzed/cel-go/interpreter/planner.go new file mode 100644 index 0000000..4451f7e --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/planner.go @@ -0,0 +1,756 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "fmt" + "strings" + + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/containers" + "github.com/authzed/cel-go/common/functions" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/types" +) + +// interpretablePlanner creates an Interpretable evaluation plan from a proto Expr value. +type interpretablePlanner interface { + // Plan generates an Interpretable value (or error) from the input proto Expr. + Plan(expr ast.Expr) (Interpretable, error) +} + +// newPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider, +// TypeAdapter, Container, and CheckedExpr value. These pieces of data are used to resolve +// functions, types, and namespaced identifiers at plan time rather than at runtime since +// it only needs to be done once and may be semi-expensive to compute. +func newPlanner(disp Dispatcher, + provider types.Provider, + adapter types.Adapter, + attrFactory AttributeFactory, + cont *containers.Container, + exprAST *ast.AST, + decorators ...InterpretableDecorator) interpretablePlanner { + return &planner{ + disp: disp, + provider: provider, + adapter: adapter, + attrFactory: attrFactory, + container: cont, + refMap: exprAST.ReferenceMap(), + typeMap: exprAST.TypeMap(), + decorators: decorators, + } +} + +// planner is an implementation of the interpretablePlanner interface. +type planner struct { + disp Dispatcher + provider types.Provider + adapter types.Adapter + attrFactory AttributeFactory + container *containers.Container + refMap map[int64]*ast.ReferenceInfo + typeMap map[int64]*types.Type + decorators []InterpretableDecorator +} + +// Plan implements the interpretablePlanner interface. This implementation of the Plan method also +// applies decorators to each Interpretable generated as part of the overall plan. Decorators are +// useful for layering functionality into the evaluation that is not natively understood by CEL, +// such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of +// repeated expressions. +func (p *planner) Plan(expr ast.Expr) (Interpretable, error) { + switch expr.Kind() { + case ast.CallKind: + return p.decorate(p.planCall(expr)) + case ast.IdentKind: + return p.decorate(p.planIdent(expr)) + case ast.LiteralKind: + return p.decorate(p.planConst(expr)) + case ast.SelectKind: + return p.decorate(p.planSelect(expr)) + case ast.ListKind: + return p.decorate(p.planCreateList(expr)) + case ast.MapKind: + return p.decorate(p.planCreateMap(expr)) + case ast.StructKind: + return p.decorate(p.planCreateStruct(expr)) + case ast.ComprehensionKind: + return p.decorate(p.planComprehension(expr)) + } + return nil, fmt.Errorf("unsupported expr: %v", expr) +} + +// decorate applies the InterpretableDecorator functions to the given Interpretable. +// Both the Interpretable and error generated by a Plan step are accepted as arguments +// for convenience. +func (p *planner) decorate(i Interpretable, err error) (Interpretable, error) { + if err != nil { + return nil, err + } + for _, dec := range p.decorators { + i, err = dec(i) + if err != nil { + return nil, err + } + } + return i, nil +} + +// planIdent creates an Interpretable that resolves an identifier from an Activation. +func (p *planner) planIdent(expr ast.Expr) (Interpretable, error) { + // Establish whether the identifier is in the reference map. + if identRef, found := p.refMap[expr.ID()]; found { + return p.planCheckedIdent(expr.ID(), identRef) + } + // Create the possible attribute list for the unresolved reference. + ident := expr.AsIdent() + return &evalAttr{ + adapter: p.adapter, + attr: p.attrFactory.MaybeAttribute(expr.ID(), ident), + }, nil +} + +func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Interpretable, error) { + // Plan a constant reference if this is the case for this simple identifier. + if identRef.Value != nil { + return NewConstValue(id, identRef.Value), nil + } + + // Check to see whether the type map indicates this is a type name. All types should be + // registered with the provider. + cType := p.typeMap[id] + if cType.Kind() == types.TypeKind { + cVal, found := p.provider.FindIdent(identRef.Name) + if !found { + return nil, fmt.Errorf("reference to undefined type: %s", identRef.Name) + } + return NewConstValue(id, cVal), nil + } + + // Otherwise, return the attribute for the resolved identifier name. + return &evalAttr{ + adapter: p.adapter, + attr: p.attrFactory.AbsoluteAttribute(id, identRef.Name), + }, nil +} + +// planSelect creates an Interpretable with either: +// +// a) selects a field from a map or proto. +// b) creates a field presence test for a select within a has() macro. +// c) resolves the select expression to a namespaced identifier. +func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) { + // If the Select id appears in the reference map from the CheckedExpr proto then it is either + // a namespaced identifier or enum value. + if identRef, found := p.refMap[expr.ID()]; found { + return p.planCheckedIdent(expr.ID(), identRef) + } + + sel := expr.AsSelect() + // Plan the operand evaluation. + op, err := p.Plan(sel.Operand()) + if err != nil { + return nil, err + } + opType := p.typeMap[sel.Operand().ID()] + + // If the Select was marked TestOnly, this is a presence test. + // + // Note: presence tests are defined for structured (e.g. proto) and dynamic values (map, json) + // as follows: + // - True if the object field has a non-default value, e.g. obj.str != "" + // - True if the dynamic value has the field defined, e.g. key in map + // + // However, presence tests are not defined for qualified identifier names with primitive types. + // If a string named 'a.b.c' is declared in the environment and referenced within `has(a.b.c)`, + // it is not clear whether has should error or follow the convention defined for structured + // values. + + // Establish the attribute reference. + attr, isAttr := op.(InterpretableAttribute) + if !isAttr { + attr, err = p.relativeAttr(op.ID(), op, false) + if err != nil { + return nil, err + } + } + + // Build a qualifier for the attribute. + qual, err := p.attrFactory.NewQualifier(opType, expr.ID(), sel.FieldName(), false) + if err != nil { + return nil, err + } + // Modify the attribute to be test-only. + if sel.IsTestOnly() { + attr = &evalTestOnly{ + id: expr.ID(), + InterpretableAttribute: attr, + } + } + // Append the qualifier on the attribute. + _, err = attr.AddQualifier(qual) + return attr, err +} + +// planCall creates a callable Interpretable while specializing for common functions and invocation +// patterns. Specifically, conditional operators &&, ||, ?:, and (in)equality functions result in +// optimized Interpretable values. +func (p *planner) planCall(expr ast.Expr) (Interpretable, error) { + call := expr.AsCall() + target, fnName, oName := p.resolveFunction(expr) + argCount := len(call.Args()) + var offset int + if target != nil { + argCount++ + offset++ + } + + args := make([]Interpretable, argCount) + if target != nil { + arg, err := p.Plan(target) + if err != nil { + return nil, err + } + args[0] = arg + } + for i, argExpr := range call.Args() { + arg, err := p.Plan(argExpr) + if err != nil { + return nil, err + } + args[i+offset] = arg + } + + // Generate specialized Interpretable operators by function name if possible. + switch fnName { + case operators.LogicalAnd: + return p.planCallLogicalAnd(expr, args) + case operators.LogicalOr: + return p.planCallLogicalOr(expr, args) + case operators.Conditional: + return p.planCallConditional(expr, args) + case operators.Equals: + return p.planCallEqual(expr, args) + case operators.NotEquals: + return p.planCallNotEqual(expr, args) + case operators.Index: + return p.planCallIndex(expr, args, false) + case operators.OptSelect, operators.OptIndex: + return p.planCallIndex(expr, args, true) + } + + // Otherwise, generate Interpretable calls specialized by argument count. + // Try to find the specific function by overload id. + var fnDef *functions.Overload + if oName != "" { + fnDef, _ = p.disp.FindOverload(oName) + } + // If the overload id couldn't resolve the function, try the simple function name. + if fnDef == nil { + fnDef, _ = p.disp.FindOverload(fnName) + } + switch argCount { + case 0: + return p.planCallZero(expr, fnName, oName, fnDef) + case 1: + // If the FunctionOp has been used, then use it as it may exist for the purposes + // of dynamic dispatch within a singleton function implementation. + if fnDef != nil && fnDef.Unary == nil && fnDef.Function != nil { + return p.planCallVarArgs(expr, fnName, oName, fnDef, args) + } + return p.planCallUnary(expr, fnName, oName, fnDef, args) + case 2: + // If the FunctionOp has been used, then use it as it may exist for the purposes + // of dynamic dispatch within a singleton function implementation. + if fnDef != nil && fnDef.Binary == nil && fnDef.Function != nil { + return p.planCallVarArgs(expr, fnName, oName, fnDef, args) + } + return p.planCallBinary(expr, fnName, oName, fnDef, args) + default: + return p.planCallVarArgs(expr, fnName, oName, fnDef, args) + } +} + +// planCallZero generates a zero-arity callable Interpretable. +func (p *planner) planCallZero(expr ast.Expr, + function string, + overload string, + impl *functions.Overload) (Interpretable, error) { + if impl == nil || impl.Function == nil { + return nil, fmt.Errorf("no such overload: %s()", function) + } + return &evalZeroArity{ + id: expr.ID(), + function: function, + overload: overload, + impl: impl.Function, + }, nil +} + +// planCallUnary generates a unary callable Interpretable. +func (p *planner) planCallUnary(expr ast.Expr, + function string, + overload string, + impl *functions.Overload, + args []Interpretable) (Interpretable, error) { + var fn functions.UnaryOp + var trait int + var nonStrict bool + if impl != nil { + if impl.Unary == nil { + return nil, fmt.Errorf("no such overload: %s(arg)", function) + } + fn = impl.Unary + trait = impl.OperandTrait + nonStrict = impl.NonStrict + } + return &evalUnary{ + id: expr.ID(), + function: function, + overload: overload, + arg: args[0], + trait: trait, + impl: fn, + nonStrict: nonStrict, + }, nil +} + +// planCallBinary generates a binary callable Interpretable. +func (p *planner) planCallBinary(expr ast.Expr, + function string, + overload string, + impl *functions.Overload, + args []Interpretable) (Interpretable, error) { + var fn functions.BinaryOp + var trait int + var nonStrict bool + if impl != nil { + if impl.Binary == nil { + return nil, fmt.Errorf("no such overload: %s(lhs, rhs)", function) + } + fn = impl.Binary + trait = impl.OperandTrait + nonStrict = impl.NonStrict + } + return &evalBinary{ + id: expr.ID(), + function: function, + overload: overload, + lhs: args[0], + rhs: args[1], + trait: trait, + impl: fn, + nonStrict: nonStrict, + }, nil +} + +// planCallVarArgs generates a variable argument callable Interpretable. +func (p *planner) planCallVarArgs(expr ast.Expr, + function string, + overload string, + impl *functions.Overload, + args []Interpretable) (Interpretable, error) { + var fn functions.FunctionOp + var trait int + var nonStrict bool + if impl != nil { + if impl.Function == nil { + return nil, fmt.Errorf("no such overload: %s(...)", function) + } + fn = impl.Function + trait = impl.OperandTrait + nonStrict = impl.NonStrict + } + return &evalVarArgs{ + id: expr.ID(), + function: function, + overload: overload, + args: args, + trait: trait, + impl: fn, + nonStrict: nonStrict, + }, nil +} + +// planCallEqual generates an equals (==) Interpretable. +func (p *planner) planCallEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { + return &evalEq{ + id: expr.ID(), + lhs: args[0], + rhs: args[1], + }, nil +} + +// planCallNotEqual generates a not equals (!=) Interpretable. +func (p *planner) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { + return &evalNe{ + id: expr.ID(), + lhs: args[0], + rhs: args[1], + }, nil +} + +// planCallLogicalAnd generates a logical and (&&) Interpretable. +func (p *planner) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Interpretable, error) { + return &evalAnd{ + id: expr.ID(), + terms: args, + }, nil +} + +// planCallLogicalOr generates a logical or (||) Interpretable. +func (p *planner) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interpretable, error) { + return &evalOr{ + id: expr.ID(), + terms: args, + }, nil +} + +// planCallConditional generates a conditional / ternary (c ? t : f) Interpretable. +func (p *planner) planCallConditional(expr ast.Expr, args []Interpretable) (Interpretable, error) { + cond := args[0] + t := args[1] + var tAttr Attribute + truthyAttr, isTruthyAttr := t.(InterpretableAttribute) + if isTruthyAttr { + tAttr = truthyAttr.Attr() + } else { + tAttr = p.attrFactory.RelativeAttribute(t.ID(), t) + } + + f := args[2] + var fAttr Attribute + falsyAttr, isFalsyAttr := f.(InterpretableAttribute) + if isFalsyAttr { + fAttr = falsyAttr.Attr() + } else { + fAttr = p.attrFactory.RelativeAttribute(f.ID(), f) + } + + return &evalAttr{ + adapter: p.adapter, + attr: p.attrFactory.ConditionalAttribute(expr.ID(), cond, tAttr, fAttr), + }, nil +} + +// planCallIndex either extends an attribute with the argument to the index operation, or creates +// a relative attribute based on the return of a function call or operation. +func (p *planner) planCallIndex(expr ast.Expr, args []Interpretable, optional bool) (Interpretable, error) { + op := args[0] + ind := args[1] + opType := p.typeMap[op.ID()] + + // Establish the attribute reference. + var err error + attr, isAttr := op.(InterpretableAttribute) + if !isAttr { + attr, err = p.relativeAttr(op.ID(), op, false) + if err != nil { + return nil, err + } + } + + // Construct the qualifier type. + var qual Qualifier + switch ind := ind.(type) { + case InterpretableConst: + qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind.Value(), optional) + case InterpretableAttribute: + qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind, optional) + default: + qual, err = p.relativeAttr(expr.ID(), ind, optional) + } + if err != nil { + return nil, err + } + + // Add the qualifier to the attribute + _, err = attr.AddQualifier(qual) + return attr, err +} + +// planCreateList generates a list construction Interpretable. +func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) { + list := expr.AsList() + optionalIndices := list.OptionalIndices() + elements := list.Elements() + optionals := make([]bool, len(elements)) + for _, index := range optionalIndices { + if index < 0 || index >= int32(len(elements)) { + return nil, fmt.Errorf("optional index %d out of element bounds [0, %d]", index, len(elements)) + } + optionals[index] = true + } + elems := make([]Interpretable, len(elements)) + for i, elem := range elements { + elemVal, err := p.Plan(elem) + if err != nil { + return nil, err + } + elems[i] = elemVal + } + return &evalList{ + id: expr.ID(), + elems: elems, + optionals: optionals, + hasOptionals: len(optionals) != 0, + adapter: p.adapter, + }, nil +} + +// planCreateStruct generates a map or object construction Interpretable. +func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) { + m := expr.AsMap() + entries := m.Entries() + optionals := make([]bool, len(entries)) + keys := make([]Interpretable, len(entries)) + vals := make([]Interpretable, len(entries)) + for i, e := range entries { + entry := e.AsMapEntry() + keyVal, err := p.Plan(entry.Key()) + if err != nil { + return nil, err + } + keys[i] = keyVal + + valVal, err := p.Plan(entry.Value()) + if err != nil { + return nil, err + } + vals[i] = valVal + optionals[i] = entry.IsOptional() + } + return &evalMap{ + id: expr.ID(), + keys: keys, + vals: vals, + optionals: optionals, + hasOptionals: len(optionals) != 0, + adapter: p.adapter, + }, nil +} + +// planCreateObj generates an object construction Interpretable. +func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) { + obj := expr.AsStruct() + typeName, defined := p.resolveTypeName(obj.TypeName()) + if !defined { + return nil, fmt.Errorf("unknown type: %s", obj.TypeName()) + } + objFields := obj.Fields() + optionals := make([]bool, len(objFields)) + fields := make([]string, len(objFields)) + vals := make([]Interpretable, len(objFields)) + for i, f := range objFields { + field := f.AsStructField() + fields[i] = field.Name() + val, err := p.Plan(field.Value()) + if err != nil { + return nil, err + } + vals[i] = val + optionals[i] = field.IsOptional() + } + return &evalObj{ + id: expr.ID(), + typeName: typeName, + fields: fields, + vals: vals, + optionals: optionals, + hasOptionals: len(optionals) != 0, + provider: p.provider, + }, nil +} + +// planComprehension generates an Interpretable fold operation. +func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { + fold := expr.AsComprehension() + accu, err := p.Plan(fold.AccuInit()) + if err != nil { + return nil, err + } + iterRange, err := p.Plan(fold.IterRange()) + if err != nil { + return nil, err + } + cond, err := p.Plan(fold.LoopCondition()) + if err != nil { + return nil, err + } + step, err := p.Plan(fold.LoopStep()) + if err != nil { + return nil, err + } + result, err := p.Plan(fold.Result()) + if err != nil { + return nil, err + } + return &evalFold{ + id: expr.ID(), + accuVar: fold.AccuVar(), + accu: accu, + iterVar: fold.IterVar(), + iterRange: iterRange, + cond: cond, + step: step, + result: result, + adapter: p.adapter, + }, nil +} + +// planConst generates a constant valued Interpretable. +func (p *planner) planConst(expr ast.Expr) (Interpretable, error) { + return NewConstValue(expr.ID(), expr.AsLiteral()), nil +} + +// resolveTypeName takes a qualified string constructed at parse time, applies the proto +// namespace resolution rules to it in a scan over possible matching types in the TypeProvider. +func (p *planner) resolveTypeName(typeName string) (string, bool) { + for _, qualifiedTypeName := range p.container.ResolveCandidateNames(typeName) { + if _, found := p.provider.FindStructType(qualifiedTypeName); found { + return qualifiedTypeName, true + } + } + return "", false +} + +// resolveFunction determines the call target, function name, and overload name from a given Expr +// value. +// +// The resolveFunction resolves ambiguities where a function may either be a receiver-style +// invocation or a qualified global function name. +// - The target expression may only consist of ident and select expressions. +// - The function is declared in the environment using its fully-qualified name. +// - The fully-qualified function name matches the string serialized target value. +func (p *planner) resolveFunction(expr ast.Expr) (ast.Expr, string, string) { + // Note: similar logic exists within the `checker/checker.go`. If making changes here + // please consider the impact on checker.go and consolidate implementations or mirror code + // as appropriate. + call := expr.AsCall() + var target ast.Expr = nil + if call.IsMemberFunction() { + target = call.Target() + } + fnName := call.FunctionName() + + // Checked expressions always have a reference map entry, and _should_ have the fully qualified + // function name as the fnName value. + oRef, hasOverload := p.refMap[expr.ID()] + if hasOverload { + if len(oRef.OverloadIDs) == 1 { + return target, fnName, oRef.OverloadIDs[0] + } + // Note, this namespaced function name will not appear as a fully qualified name in ASTs + // built and stored before cel-go v0.5.0; however, this functionality did not work at all + // before the v0.5.0 release. + return target, fnName, "" + } + + // Parse-only expressions need to handle the same logic as is normally performed at check time, + // but with potentially much less information. The only reliable source of information about + // which functions are configured is the dispatcher. + if target == nil { + // If the user has a parse-only expression, then it should have been configured as such in + // the interpreter dispatcher as it may have been omitted from the checker environment. + for _, qualifiedName := range p.container.ResolveCandidateNames(fnName) { + _, found := p.disp.FindOverload(qualifiedName) + if found { + return nil, qualifiedName, "" + } + } + // It's possible that the overload was not found, but this situation is accounted for in + // the planCall phase; however, the leading dot used for denoting fully-qualified + // namespaced identifiers must be stripped, as all declarations already use fully-qualified + // names. This stripping behavior is handled automatically by the ResolveCandidateNames + // call. + return target, stripLeadingDot(fnName), "" + } + + // Handle the situation where the function target actually indicates a qualified function name. + qualifiedPrefix, maybeQualified := p.toQualifiedName(target) + if maybeQualified { + maybeQualifiedName := qualifiedPrefix + "." + fnName + for _, qualifiedName := range p.container.ResolveCandidateNames(maybeQualifiedName) { + _, found := p.disp.FindOverload(qualifiedName) + if found { + // Clear the target to ensure the proper arity is used for finding the + // implementation. + return nil, qualifiedName, "" + } + } + } + // In the default case, the function is exactly as it was advertised: a receiver call on with + // an expression-based target with the given simple function name. + return target, fnName, "" +} + +// relativeAttr indicates that the attribute in this case acts as a qualifier and as such needs to +// be observed to ensure that it's evaluation value is properly recorded for state tracking. +func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (InterpretableAttribute, error) { + eAttr, ok := eval.(InterpretableAttribute) + if !ok { + eAttr = &evalAttr{ + adapter: p.adapter, + attr: p.attrFactory.RelativeAttribute(id, eval), + optional: opt, + } + } + // This looks like it should either decorate the new evalAttr node, or early return the InterpretableAttribute + decAttr, err := p.decorate(eAttr, nil) + if err != nil { + return nil, err + } + eAttr, ok = decAttr.(InterpretableAttribute) + if !ok { + return nil, fmt.Errorf("invalid attribute decoration: %v(%T)", decAttr, decAttr) + } + return eAttr, nil +} + +// toQualifiedName converts an expression AST into a qualified name if possible, with a boolean +// 'found' value that indicates if the conversion is successful. +func (p *planner) toQualifiedName(operand ast.Expr) (string, bool) { + // If the checker identified the expression as an attribute by the type-checker, then it can't + // possibly be part of qualified name in a namespace. + _, isAttr := p.refMap[operand.ID()] + if isAttr { + return "", false + } + // Since functions cannot be both namespaced and receiver functions, if the operand is not an + // qualified variable name, return the (possibly) qualified name given the expressions. + switch operand.Kind() { + case ast.IdentKind: + id := operand.AsIdent() + return id, true + case ast.SelectKind: + sel := operand.AsSelect() + // Test only expressions are not valid as qualified names. + if sel.IsTestOnly() { + return "", false + } + if qual, found := p.toQualifiedName(sel.Operand()); found { + return qual + "." + sel.FieldName(), true + } + } + return "", false +} + +func stripLeadingDot(name string) string { + if strings.HasPrefix(name, ".") { + return name[1:] + } + return name +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/prune.go b/vendor/github.com/authzed/cel-go/interpreter/prune.go new file mode 100644 index 0000000..b093f2a --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/prune.go @@ -0,0 +1,543 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" +) + +type astPruner struct { + ast.ExprFactory + expr ast.Expr + macroCalls map[int64]ast.Expr + state EvalState + nextExprID int64 +} + +// TODO Consider having a separate walk of the AST that finds common +// subexpressions. This can be called before or after constant folding to find +// common subexpressions. + +// PruneAst prunes the given AST based on the given EvalState and generates a new AST. +// Given AST is copied on write and a new AST is returned. +// Couple of typical use cases this interface would be: +// +// A) +// 1) Evaluate expr with some unknowns, +// 2) If result is unknown: +// +// a) PruneAst +// b) Goto 1 +// +// Functional call results which are known would be effectively cached across +// iterations. +// +// B) +// 1) Compile the expression (maybe via a service and maybe after checking a +// +// compiled expression does not exists in local cache) +// +// 2) Prepare the environment and the interpreter. Activation might be empty. +// 3) Eval the expression. This might return unknown or error or a concrete +// +// value. +// +// 4) PruneAst +// 4) Maybe cache the expression +// This is effectively constant folding the expression. How the environment is +// prepared in step 2 is flexible. For example, If the caller caches the +// compiled and constant folded expressions, but is not willing to constant +// fold(and thus cache results of) some external calls, then they can prepare +// the overloads accordingly. +func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST { + pruneState := NewEvalState() + for _, id := range state.IDs() { + v, _ := state.Value(id) + pruneState.SetValue(id, v) + } + pruner := &astPruner{ + ExprFactory: ast.NewExprFactory(), + expr: expr, + macroCalls: macroCalls, + state: pruneState, + nextExprID: getMaxID(expr)} + newExpr, _ := pruner.maybePrune(expr) + newInfo := ast.NewSourceInfo(nil) + for id, call := range pruner.macroCalls { + newInfo.SetMacroCall(id, call) + } + return ast.NewAST(newExpr, newInfo) +} + +func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) { + switch v := val.(type) { + case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint: + p.state.SetValue(id, val) + return p.NewLiteral(id, val), true + case types.Duration: + p.state.SetValue(id, val) + durationString := v.ConvertToType(types.StringType).(types.String) + return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true + case types.Timestamp: + timestampString := v.ConvertToType(types.StringType).(types.String) + return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true + } + + // Attempt to build a list literal. + if list, isList := val.(traits.Lister); isList { + sz := list.Size().(types.Int) + elemExprs := make([]ast.Expr, sz) + for i := types.Int(0); i < sz; i++ { + elem := list.Get(i) + if types.IsUnknownOrError(elem) { + return nil, false + } + elemExpr, ok := p.maybeCreateLiteral(p.nextID(), elem) + if !ok { + return nil, false + } + elemExprs[i] = elemExpr + } + p.state.SetValue(id, val) + return p.NewList(id, elemExprs, []int32{}), true + } + + // Create a map literal if possible. + if mp, isMap := val.(traits.Mapper); isMap { + it := mp.Iterator() + entries := make([]ast.EntryExpr, mp.Size().(types.Int)) + i := 0 + for it.HasNext() != types.False { + key := it.Next() + val := mp.Get(key) + if types.IsUnknownOrError(key) || types.IsUnknownOrError(val) { + return nil, false + } + keyExpr, ok := p.maybeCreateLiteral(p.nextID(), key) + if !ok { + return nil, false + } + valExpr, ok := p.maybeCreateLiteral(p.nextID(), val) + if !ok { + return nil, false + } + entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false) + entries[i] = entry + i++ + } + p.state.SetValue(id, val) + return p.NewMap(id, entries), true + } + + // TODO(issues/377) To construct message literals, the type provider will need to support + // the enumeration the fields for a given message. + return nil, false +} + +func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) { + elemVal, found := p.value(elem.ID()) + if found && elemVal.Type() == types.OptionalType { + opt := elemVal.(*types.Optional) + if !opt.HasValue() { + return nil, true + } + if newElem, pruned := p.maybeCreateLiteral(elem.ID(), opt.GetValue()); pruned { + return newElem, true + } + } + return elem, false +} + +func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) { + // elem in list + call := node.AsCall() + val, exists := p.maybeValue(call.Args()[1].ID()) + if !exists { + return nil, false + } + if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { + return p.maybeCreateLiteral(node.ID(), types.False) + } + return nil, false +} + +func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + arg := call.Args()[0] + val, exists := p.maybeValue(arg.ID()) + if !exists { + return nil, false + } + if b, ok := val.(types.Bool); ok { + return p.maybeCreateLiteral(node.ID(), !b) + } + return nil, false +} + +func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + // We know result is unknown, so we have at least one unknown arg + // and if one side is a known value, we know we can ignore it. + if v, exists := p.maybeValue(call.Args()[0].ID()); exists { + if v == types.True { + return p.maybeCreateLiteral(node.ID(), types.True) + } + return call.Args()[1], true + } + if v, exists := p.maybeValue(call.Args()[1].ID()); exists { + if v == types.True { + return p.maybeCreateLiteral(node.ID(), types.True) + } + return call.Args()[0], true + } + return nil, false +} + +func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + // We know result is unknown, so we have at least one unknown arg + // and if one side is a known value, we know we can ignore it. + if v, exists := p.maybeValue(call.Args()[0].ID()); exists { + if v == types.False { + return p.maybeCreateLiteral(node.ID(), types.False) + } + return call.Args()[1], true + } + if v, exists := p.maybeValue(call.Args()[1].ID()); exists { + if v == types.False { + return p.maybeCreateLiteral(node.ID(), types.False) + } + return call.Args()[0], true + } + return nil, false +} + +func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + cond, exists := p.maybeValue(call.Args()[0].ID()) + if !exists { + return nil, false + } + if cond.Value().(bool) { + return call.Args()[1], true + } + return call.Args()[2], true +} + +func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) { + if _, exists := p.value(node.ID()); !exists { + return nil, false + } + call := node.AsCall() + if call.FunctionName() == operators.LogicalOr { + return p.maybePruneOr(node) + } + if call.FunctionName() == operators.LogicalAnd { + return p.maybePruneAnd(node) + } + if call.FunctionName() == operators.Conditional { + return p.maybePruneConditional(node) + } + if call.FunctionName() == operators.In { + return p.maybePruneIn(node) + } + if call.FunctionName() == operators.LogicalNot { + return p.maybePruneLogicalNot(node) + } + return nil, false +} + +func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) { + return p.prune(node) +} + +func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) { + if node == nil { + return node, false + } + val, valueExists := p.maybeValue(node.ID()) + if valueExists { + if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok { + delete(p.macroCalls, node.ID()) + return newNode, true + } + } + if macro, found := p.macroCalls[node.ID()]; found { + // Ensure that intermediate values for the comprehension are cleared during pruning + if node.Kind() == ast.ComprehensionKind { + compre := node.AsComprehension() + visit(macro, clearIterVarVisitor(compre.IterVar(), p.state)) + } + // prune the expression in terms of the macro call instead of the expanded form. + if newMacro, pruned := p.prune(macro); pruned { + p.macroCalls[node.ID()] = newMacro + } + } + + // We have either an unknown/error value, or something we don't want to + // transform, or expression was not evaluated. If possible, drill down + // more. + switch node.Kind() { + case ast.SelectKind: + sel := node.AsSelect() + if operand, isPruned := p.maybePrune(sel.Operand()); isPruned { + if sel.IsTestOnly() { + return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true + } + return p.NewSelect(node.ID(), operand, sel.FieldName()), true + } + case ast.CallKind: + argsPruned := false + call := node.AsCall() + args := call.Args() + newArgs := make([]ast.Expr, len(args)) + for i, a := range args { + newArgs[i] = a + if arg, isPruned := p.maybePrune(a); isPruned { + argsPruned = true + newArgs[i] = arg + } + } + if !call.IsMemberFunction() { + newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...) + if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { + return prunedCall, true + } + return newCall, argsPruned + } + newTarget := call.Target() + targetPruned := false + if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned { + targetPruned = true + newTarget = prunedTarget + } + newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...) + if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { + return prunedCall, true + } + return newCall, targetPruned || argsPruned + case ast.ListKind: + l := node.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + optIndexMap := map[int32]bool{} + for _, i := range optIndices { + optIndexMap[i] = true + } + newOptIndexMap := make(map[int32]bool, len(optIndexMap)) + newElems := make([]ast.Expr, 0, len(elems)) + var listPruned bool + prunedIdx := 0 + for i, elem := range elems { + _, isOpt := optIndexMap[int32(i)] + if isOpt { + newElem, pruned := p.maybePruneOptional(elem) + if pruned { + listPruned = true + if newElem != nil { + newElems = append(newElems, newElem) + prunedIdx++ + } + continue + } + newOptIndexMap[int32(prunedIdx)] = true + } + if newElem, prunedElem := p.maybePrune(elem); prunedElem { + newElems = append(newElems, newElem) + listPruned = true + } else { + newElems = append(newElems, elem) + } + prunedIdx++ + } + optIndices = make([]int32, len(newOptIndexMap)) + idx := 0 + for i := range newOptIndexMap { + optIndices[idx] = i + idx++ + } + if listPruned { + return p.NewList(node.ID(), newElems, optIndices), true + } + case ast.MapKind: + var mapPruned bool + m := node.AsMap() + entries := m.Entries() + newEntries := make([]ast.EntryExpr, len(entries)) + for i, entry := range entries { + newEntries[i] = entry + e := entry.AsMapEntry() + newKey, keyPruned := p.maybePrune(e.Key()) + newValue, valuePruned := p.maybePrune(e.Value()) + if !keyPruned && !valuePruned { + continue + } + mapPruned = true + newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional()) + newEntries[i] = newEntry + } + if mapPruned { + return p.NewMap(node.ID(), newEntries), true + } + case ast.StructKind: + var structPruned bool + obj := node.AsStruct() + fields := obj.Fields() + newFields := make([]ast.EntryExpr, len(fields)) + for i, field := range fields { + newFields[i] = field + f := field.AsStructField() + newValue, prunedValue := p.maybePrune(f.Value()) + if !prunedValue { + continue + } + structPruned = true + newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional()) + newFields[i] = newEntry + } + if structPruned { + return p.NewStruct(node.ID(), obj.TypeName(), newFields), true + } + case ast.ComprehensionKind: + compre := node.AsComprehension() + // Only the range of the comprehension is pruned since the state tracking only records + // the last iteration of the comprehension and not each step in the evaluation which + // means that the any residuals computed in between might be inaccurate. + if newRange, pruned := p.maybePrune(compre.IterRange()); pruned { + return p.NewComprehension( + node.ID(), + newRange, + compre.IterVar(), + compre.AccuVar(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result(), + ), true + } + } + return node, false +} + +func (p *astPruner) value(id int64) (ref.Val, bool) { + val, found := p.state.Value(id) + return val, (found && val != nil) +} + +func (p *astPruner) maybeValue(id int64) (ref.Val, bool) { + val, found := p.value(id) + if !found || types.IsUnknownOrError(val) { + return nil, false + } + return val, true +} + +func (p *astPruner) nextID() int64 { + next := p.nextExprID + p.nextExprID++ + return next +} + +type astVisitor struct { + // visitEntry is called on every expr node, including those within a map/struct entry. + visitExpr func(expr ast.Expr) + // visitEntry is called before entering the key, value of a map/struct entry. + visitEntry func(entry ast.EntryExpr) +} + +func getMaxID(expr ast.Expr) int64 { + maxID := int64(1) + visit(expr, maxIDVisitor(&maxID)) + return maxID +} + +func clearIterVarVisitor(varName string, state EvalState) astVisitor { + return astVisitor{ + visitExpr: func(e ast.Expr) { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + state.SetValue(e.ID(), nil) + } + }, + } +} + +func maxIDVisitor(maxID *int64) astVisitor { + return astVisitor{ + visitExpr: func(e ast.Expr) { + if e.ID() >= *maxID { + *maxID = e.ID() + 1 + } + }, + visitEntry: func(e ast.EntryExpr) { + if e.ID() >= *maxID { + *maxID = e.ID() + 1 + } + }, + } +} + +func visit(expr ast.Expr, visitor astVisitor) { + exprs := []ast.Expr{expr} + for len(exprs) != 0 { + e := exprs[0] + if visitor.visitExpr != nil { + visitor.visitExpr(e) + } + exprs = exprs[1:] + switch e.Kind() { + case ast.SelectKind: + exprs = append(exprs, e.AsSelect().Operand()) + case ast.CallKind: + call := e.AsCall() + if call.Target() != nil { + exprs = append(exprs, call.Target()) + } + exprs = append(exprs, call.Args()...) + case ast.ComprehensionKind: + compre := e.AsComprehension() + exprs = append(exprs, + compre.IterRange(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result()) + case ast.ListKind: + list := e.AsList() + exprs = append(exprs, list.Elements()...) + case ast.MapKind: + for _, entry := range e.AsMap().Entries() { + e := entry.AsMapEntry() + if visitor.visitEntry != nil { + visitor.visitEntry(entry) + } + exprs = append(exprs, e.Key()) + exprs = append(exprs, e.Value()) + } + case ast.StructKind: + for _, entry := range e.AsStruct().Fields() { + f := entry.AsStructField() + if visitor.visitEntry != nil { + visitor.visitEntry(entry) + } + exprs = append(exprs, f.Value()) + } + } + } +} diff --git a/vendor/github.com/authzed/cel-go/interpreter/runtimecost.go b/vendor/github.com/authzed/cel-go/interpreter/runtimecost.go new file mode 100644 index 0000000..905aeea --- /dev/null +++ b/vendor/github.com/authzed/cel-go/interpreter/runtimecost.go @@ -0,0 +1,316 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interpreter + +import ( + "math" + + "github.com/authzed/cel-go/common" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "github.com/authzed/cel-go/common/types/traits" +) + +// WARNING: Any changes to cost calculations in this file require a corresponding change in checker/cost.go + +// ActualCostEstimator provides function call cost estimations at runtime +// CallCost returns an estimated cost for the function overload invocation with the given args, or nil if it has no +// estimate to provide. CEL attempts to provide reasonable estimates for its standard function library, so CallCost +// should typically not need to provide an estimate for CELs standard function. +type ActualCostEstimator interface { + CallCost(function, overloadID string, args []ref.Val, result ref.Val) *uint64 +} + +// CostObserver provides an observer that tracks runtime cost. +func CostObserver(tracker *CostTracker) EvalObserver { + observer := func(id int64, programStep any, val ref.Val) { + switch t := programStep.(type) { + case ConstantQualifier: + // TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them + // and enable the below pop. Once enabled this can case can be collapsed into the Qualifier case. + tracker.cost++ + case InterpretableConst: + // zero cost + case InterpretableAttribute: + switch a := t.Attr().(type) { + case *conditionalAttribute: + // Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions. + tracker.stack.drop(a.falsy.ID(), a.truthy.ID(), a.expr.ID()) + default: + tracker.stack.drop(t.Attr().ID()) + tracker.cost += common.SelectAndIdentCost + } + if !tracker.presenceTestHasCost { + if _, isTestOnly := programStep.(*evalTestOnly); isTestOnly { + tracker.cost -= common.SelectAndIdentCost + } + } + case *evalExhaustiveConditional: + // Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions. + tracker.stack.drop(t.attr.falsy.ID(), t.attr.truthy.ID(), t.attr.expr.ID()) + + // While the field names are identical, the boolean operation eval structs do not share an interface and so + // must be handled individually. + case *evalOr: + for _, term := range t.terms { + tracker.stack.drop(term.ID()) + } + case *evalAnd: + for _, term := range t.terms { + tracker.stack.drop(term.ID()) + } + case *evalExhaustiveOr: + for _, term := range t.terms { + tracker.stack.drop(term.ID()) + } + case *evalExhaustiveAnd: + for _, term := range t.terms { + tracker.stack.drop(term.ID()) + } + case *evalFold: + tracker.stack.drop(t.iterRange.ID()) + case Qualifier: + tracker.cost++ + case InterpretableCall: + if argVals, ok := tracker.stack.dropArgs(t.Args()); ok { + tracker.cost += tracker.costCall(t, argVals, val) + } + case InterpretableConstructor: + tracker.stack.dropArgs(t.InitVals()) + switch t.Type() { + case types.ListType: + tracker.cost += common.ListCreateBaseCost + case types.MapType: + tracker.cost += common.MapCreateBaseCost + default: + tracker.cost += common.StructCreateBaseCost + } + } + tracker.stack.push(val, id) + + if tracker.Limit != nil && tracker.cost > *tracker.Limit { + panic(EvalCancelledError{Cause: CostLimitExceeded, Message: "operation cancelled: actual cost limit exceeded"}) + } + } + return observer +} + +// CostTrackerOption configures the behavior of CostTracker objects. +type CostTrackerOption func(*CostTracker) error + +// CostTrackerLimit sets the runtime limit on the evaluation cost during execution and will terminate the expression +// evaluation if the limit is exceeded. +func CostTrackerLimit(limit uint64) CostTrackerOption { + return func(tracker *CostTracker) error { + tracker.Limit = &limit + return nil + } +} + +// PresenceTestHasCost determines whether presence testing has a cost of one or zero. +// Defaults to presence test has a cost of one. +func PresenceTestHasCost(hasCost bool) CostTrackerOption { + return func(tracker *CostTracker) error { + tracker.presenceTestHasCost = hasCost + return nil + } +} + +// NewCostTracker creates a new CostTracker with a given estimator and a set of functional CostTrackerOption values. +func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) { + tracker := &CostTracker{ + Estimator: estimator, + overloadTrackers: map[string]FunctionTracker{}, + presenceTestHasCost: true, + } + for _, opt := range opts { + err := opt(tracker) + if err != nil { + return nil, err + } + } + return tracker, nil +} + +// OverloadCostTracker binds an overload ID to a runtime FunctionTracker implementation. +// +// OverloadCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or +// optional cost tracking changes. +func OverloadCostTracker(overloadID string, fnTracker FunctionTracker) CostTrackerOption { + return func(tracker *CostTracker) error { + tracker.overloadTrackers[overloadID] = fnTracker + return nil + } +} + +// FunctionTracker computes the actual cost of evaluating the functions with the given arguments and result. +type FunctionTracker func(args []ref.Val, result ref.Val) *uint64 + +// CostTracker represents the information needed for tracking runtime cost. +type CostTracker struct { + Estimator ActualCostEstimator + overloadTrackers map[string]FunctionTracker + Limit *uint64 + presenceTestHasCost bool + + cost uint64 + stack refValStack +} + +// ActualCost returns the runtime cost +func (c *CostTracker) ActualCost() uint64 { + return c.cost +} + +func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result ref.Val) uint64 { + var cost uint64 + if len(c.overloadTrackers) != 0 { + if tracker, found := c.overloadTrackers[call.OverloadID()]; found { + callCost := tracker(args, result) + if callCost != nil { + cost += *callCost + return cost + } + } + } + if c.Estimator != nil { + callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), args, result) + if callCost != nil { + cost += *callCost + return cost + } + } + // if user didn't specify, the default way of calculating runtime cost would be used. + // if user has their own implementation of ActualCostEstimator, make sure to cover the mapping between overloadId and cost calculation + switch call.OverloadID() { + // O(n) functions + case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString: + cost += uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor)) + case overloads.InList: + // If a list is composed entirely of constant values this is O(1), but we don't account for that here. + // We just assume all list containment checks are O(n). + cost += c.actualSize(args[1]) + // O(min(m, n)) functions + case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString, + overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes, + overloads.Equals, overloads.NotEquals: + // When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.), + // the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost + // of 1. + lhsSize := c.actualSize(args[0]) + rhsSize := c.actualSize(args[1]) + minSize := lhsSize + if rhsSize < minSize { + minSize = rhsSize + } + cost += uint64(math.Ceil(float64(minSize) * common.StringTraversalCostFactor)) + // O(m+n) functions + case overloads.AddString, overloads.AddBytes: + // In the worst case scenario, we would need to reallocate a new backing store and copy both operands over. + cost += uint64(math.Ceil(float64(c.actualSize(args[0])+c.actualSize(args[1])) * common.StringTraversalCostFactor)) + // O(nm) functions + case overloads.MatchesString: + // https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL + // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 + // in case where string is empty but regex is still expensive. + strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(args[0]))) * common.StringTraversalCostFactor)) + // We don't know how many expressions are in the regex, just the string length (a huge + // improvement here would be to somehow get a count the number of expressions in the regex or + // how many states are in the regex state machine and use that to measure regex cost). + // For now, we're making a guess that each expression in a regex is typically at least 4 chars + // in length. + regexCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.RegexStringLengthCostFactor)) + cost += strCost * regexCost + case overloads.ContainsString: + strCost := uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor)) + substrCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.StringTraversalCostFactor)) + cost += strCost * substrCost + + default: + // The following operations are assumed to have O(1) complexity. + // - AddList due to the implementation. Index lookup can be O(c) the + // number of concatenated lists, but we don't track that is cost calculations. + // - Conversions, since none perform a traversal of a type of unbound length. + // - Computing the size of strings, byte sequences, lists and maps. + // - Logical operations and all operators on fixed width scalars (comparisons, equality) + // - Any functions that don't have a declared cost either here or in provided ActualCostEstimator. + cost++ + + } + return cost +} + +// actualSize returns the size of value +func (c *CostTracker) actualSize(value ref.Val) uint64 { + if sz, ok := value.(traits.Sizer); ok { + return uint64(sz.Size().(types.Int)) + } + return 1 +} + +type stackVal struct { + Val ref.Val + ID int64 +} + +// refValStack keeps track of values of the stack for cost calculation purposes +type refValStack []stackVal + +func (s *refValStack) push(val ref.Val, id int64) { + value := stackVal{Val: val, ID: id} + *s = append(*s, value) +} + +// TODO: Allowing drop and dropArgs to remove stack items above the IDs they are provided is a workaround. drop and dropArgs +// should find and remove only the stack items matching the provided IDs once all attributes are properly pushed and popped from stack. + +// drop searches the stack for each ID and removes the ID and all stack items above it. +// If none of the IDs are found, the stack is not modified. +// WARNING: It is possible for multiple expressions with the same ID to exist (due to how macros are implemented) so it's +// possible that a dropped ID will remain on the stack. They should be removed when IDs on the stack are popped. +func (s *refValStack) drop(ids ...int64) { + for _, id := range ids { + for idx := len(*s) - 1; idx >= 0; idx-- { + if (*s)[idx].ID == id { + *s = (*s)[:idx] + break + } + } + } +} + +// dropArgs searches the stack for all the args by their IDs, accumulates their associated ref.Vals and drops any +// stack items above any of the arg IDs. If any of the IDs are not found the stack, false is returned. +// Args are assumed to be found in the stack in reverse order, i.e. the last arg is expected to be found highest in +// the stack. +// WARNING: It is possible for multiple expressions with the same ID to exist (due to how macros are implemented) so it's +// possible that a dropped ID will remain on the stack. They should be removed when IDs on the stack are popped. +func (s *refValStack) dropArgs(args []Interpretable) ([]ref.Val, bool) { + result := make([]ref.Val, len(args)) +argloop: + for nIdx := len(args) - 1; nIdx >= 0; nIdx-- { + for idx := len(*s) - 1; idx >= 0; idx-- { + if (*s)[idx].ID == args[nIdx].ID() { + el := (*s)[idx] + *s = (*s)[:idx] + result[nIdx] = el.Val + continue argloop + } + } + return nil, false + } + return result, true +} |
