summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/cel-go/cel/validator.go
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-24 17:58:01 -0600
committermo khan <mo@mokhan.ca>2025-07-24 17:58:01 -0600
commit72296119fc9755774719f8f625ad03e0e0ec457a (patch)
treeed236ddee12a20fb55b7cfecf13f62d3a000dcb5 /vendor/github.com/authzed/cel-go/cel/validator.go
parenta920a8cfe415858bb2777371a77018599ffed23f (diff)
parenteaa1bd3b8e12934aed06413d75e7482ac58d805a (diff)
Merge branch 'the-spice-must-flow' into 'main'
Add SpiceDB Authorization See merge request gitlab-org/software-supply-chain-security/authorization/sparkled!19
Diffstat (limited to 'vendor/github.com/authzed/cel-go/cel/validator.go')
-rw-r--r--vendor/github.com/authzed/cel-go/cel/validator.go375
1 files changed, 375 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/cel-go/cel/validator.go b/vendor/github.com/authzed/cel-go/cel/validator.go
new file mode 100644
index 0000000..1664c13
--- /dev/null
+++ b/vendor/github.com/authzed/cel-go/cel/validator.go
@@ -0,0 +1,375 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cel
+
+import (
+ "fmt"
+ "reflect"
+ "regexp"
+
+ "github.com/authzed/cel-go/common/ast"
+ "github.com/authzed/cel-go/common/overloads"
+)
+
+const (
+ homogeneousValidatorName = "cel.lib.std.validate.types.homogeneous"
+
+ // HomogeneousAggregateLiteralExemptFunctions is the ValidatorConfig key used to configure
+ // the set of function names which are exempt from homogeneous type checks. The expected type
+ // is a string list of function names.
+ //
+ // As an example, the `<string>.format([args])` call expects the input arguments list to be
+ // comprised of a variety of types which correspond to the types expected by the format control
+ // clauses; however, all other uses of a mixed element type list, would be unexpected.
+ HomogeneousAggregateLiteralExemptFunctions = homogeneousValidatorName + ".exempt"
+)
+
+// ASTValidators configures a set of ASTValidator instances into the target environment.
+//
+// Validators are applied in the order in which the are specified and are treated as singletons.
+// The same ASTValidator with a given name will not be applied more than once.
+func ASTValidators(validators ...ASTValidator) EnvOption {
+ return func(e *Env) (*Env, error) {
+ for _, v := range validators {
+ if !e.HasValidator(v.Name()) {
+ e.validators = append(e.validators, v)
+ }
+ }
+ return e, nil
+ }
+}
+
+// ASTValidator defines a singleton interface for validating a type-checked Ast against an environment.
+//
+// Note: the Issues argument is mutable in the sense that it is intended to collect errors which will be
+// reported to the caller.
+type ASTValidator interface {
+ // Name returns the name of the validator. Names must be unique.
+ Name() string
+
+ // Validate validates a given Ast within an Environment and collects a set of potential issues.
+ //
+ // The ValidatorConfig is generated from the set of ASTValidatorConfigurer instances prior to
+ // the invocation of the Validate call. The expectation is that the validator configuration
+ // is created in sequence and immutable once provided to the Validate call.
+ //
+ // See individual validators for more information on their configuration keys and configuration
+ // properties.
+ Validate(*Env, ValidatorConfig, *ast.AST, *Issues)
+}
+
+// ValidatorConfig provides an accessor method for querying validator configuration state.
+type ValidatorConfig interface {
+ GetOrDefault(name string, value any) any
+}
+
+// MutableValidatorConfig provides mutation methods for querying and updating validator configuration
+// settings.
+type MutableValidatorConfig interface {
+ ValidatorConfig
+ Set(name string, value any) error
+}
+
+// ASTValidatorConfigurer indicates that this object, currently expected to be an ASTValidator,
+// participates in validator configuration settings.
+//
+// This interface may be split from the expectation of being an ASTValidator instance in the future.
+type ASTValidatorConfigurer interface {
+ Configure(MutableValidatorConfig) error
+}
+
+// validatorConfig implements the ValidatorConfig and MutableValidatorConfig interfaces.
+type validatorConfig struct {
+ data map[string]any
+}
+
+// newValidatorConfig initializes the validator config with default values for core CEL validators.
+func newValidatorConfig() *validatorConfig {
+ return &validatorConfig{
+ data: map[string]any{
+ HomogeneousAggregateLiteralExemptFunctions: []string{},
+ },
+ }
+}
+
+// GetOrDefault returns the configured value for the name, if present, else the input default value.
+//
+// Note, the type-agreement between the input default and configured value is not checked on read.
+func (config *validatorConfig) GetOrDefault(name string, value any) any {
+ v, found := config.data[name]
+ if !found {
+ return value
+ }
+ return v
+}
+
+// Set configures a validator option with the given name and value.
+//
+// If the value had previously been set, the new value must have the same reflection type as the old one,
+// or the call will error.
+func (config *validatorConfig) Set(name string, value any) error {
+ v, found := config.data[name]
+ if found && reflect.TypeOf(v) != reflect.TypeOf(value) {
+ return fmt.Errorf("incompatible configuration type for %s, got %T, wanted %T", name, value, v)
+ }
+ config.data[name] = value
+ return nil
+}
+
+// ExtendedValidations collects a set of common AST validations which reduce the likelihood of runtime errors.
+//
+// - Validate duration and timestamp literals
+// - Ensure regex strings are valid
+// - Disable mixed type list and map literals
+func ExtendedValidations() EnvOption {
+ return ASTValidators(
+ ValidateDurationLiterals(),
+ ValidateTimestampLiterals(),
+ ValidateRegexLiterals(),
+ ValidateHomogeneousAggregateLiterals(),
+ )
+}
+
+// ValidateDurationLiterals ensures that duration literal arguments are valid immediately after type-check.
+func ValidateDurationLiterals() ASTValidator {
+ return newFormatValidator(overloads.TypeConvertDuration, 0, evalCall)
+}
+
+// ValidateTimestampLiterals ensures that timestamp literal arguments are valid immediately after type-check.
+func ValidateTimestampLiterals() ASTValidator {
+ return newFormatValidator(overloads.TypeConvertTimestamp, 0, evalCall)
+}
+
+// ValidateRegexLiterals ensures that regex patterns are validated after type-check.
+func ValidateRegexLiterals() ASTValidator {
+ return newFormatValidator(overloads.Matches, 0, compileRegex)
+}
+
+// ValidateHomogeneousAggregateLiterals checks that all list and map literals entries have the same types, i.e.
+// no mixed list element types or mixed map key or map value types.
+//
+// Note: the string format call relies on a mixed element type list for ease of use, so this check skips all
+// literals which occur within string format calls.
+func ValidateHomogeneousAggregateLiterals() ASTValidator {
+ return homogeneousAggregateLiteralValidator{}
+}
+
+// ValidateComprehensionNestingLimit ensures that comprehension nesting does not exceed the specified limit.
+//
+// This validator can be useful for preventing arbitrarily nested comprehensions which can take high polynomial
+// time to complete.
+//
+// Note, this limit does not apply to comprehensions with an empty iteration range, as these comprehensions have
+// no actual looping cost. The cel.bind() utilizes the comprehension structure to perform local variable
+// assignments and supplies an empty iteration range, so they won't count against the nesting limit either.
+func ValidateComprehensionNestingLimit(limit int) ASTValidator {
+ return nestingLimitValidator{limit: limit}
+}
+
+type argChecker func(env *Env, call, arg ast.Expr) error
+
+func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
+ return formatValidator{
+ funcName: funcName,
+ check: check,
+ argNum: argNum,
+ }
+}
+
+type formatValidator struct {
+ funcName string
+ argNum int
+ check argChecker
+}
+
+// Name returns the unique name of this function format validator.
+func (v formatValidator) Name() string {
+ return fmt.Sprintf("cel.lib.std.validate.functions.%s", v.funcName)
+}
+
+// Validate searches the AST for uses of a given function name with a constant argument and performs a check
+// on whether the argument is a valid literal value.
+func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
+ root := ast.NavigateAST(a)
+ funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
+ for _, call := range funcCalls {
+ callArgs := call.AsCall().Args()
+ if len(callArgs) <= v.argNum {
+ continue
+ }
+ litArg := callArgs[v.argNum]
+ if litArg.Kind() != ast.LiteralKind {
+ continue
+ }
+ if err := v.check(e, call, litArg); err != nil {
+ iss.ReportErrorAtID(litArg.ID(), "invalid %s argument", v.funcName)
+ }
+ }
+}
+
+func evalCall(env *Env, call, arg ast.Expr) error {
+ ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))}
+ prg, err := env.Program(ast)
+ if err != nil {
+ return err
+ }
+ _, _, err = prg.Eval(NoVars())
+ return err
+}
+
+func compileRegex(_ *Env, _, arg ast.Expr) error {
+ pattern := arg.AsLiteral().Value().(string)
+ _, err := regexp.Compile(pattern)
+ return err
+}
+
+type homogeneousAggregateLiteralValidator struct{}
+
+// Name returns the unique name of the homogeneous type validator.
+func (homogeneousAggregateLiteralValidator) Name() string {
+ return homogeneousValidatorName
+}
+
+// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
+//
+// This validator makes an exception for list and map literals which occur at any level of nesting within
+// string format calls.
+func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) {
+ var exemptedFunctions []string
+ exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string)
+ root := ast.NavigateAST(a)
+ listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
+ for _, listExpr := range listExprs {
+ if inExemptFunction(listExpr, exemptedFunctions) {
+ continue
+ }
+ l := listExpr.AsList()
+ elements := l.Elements()
+ optIndices := l.OptionalIndices()
+ var elemType *Type
+ for i, e := range elements {
+ et := a.GetType(e.ID())
+ if isOptionalIndex(i, optIndices) {
+ et = et.Parameters()[0]
+ }
+ if elemType == nil {
+ elemType = et
+ continue
+ }
+ if !elemType.IsEquivalentType(et) {
+ v.typeMismatch(iss, e.ID(), elemType, et)
+ break
+ }
+ }
+ }
+ mapExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.MapKind))
+ for _, mapExpr := range mapExprs {
+ if inExemptFunction(mapExpr, exemptedFunctions) {
+ continue
+ }
+ m := mapExpr.AsMap()
+ entries := m.Entries()
+ var keyType, valType *Type
+ for _, e := range entries {
+ mapEntry := e.AsMapEntry()
+ key, val := mapEntry.Key(), mapEntry.Value()
+ kt, vt := a.GetType(key.ID()), a.GetType(val.ID())
+ if mapEntry.IsOptional() {
+ vt = vt.Parameters()[0]
+ }
+ if keyType == nil && valType == nil {
+ keyType, valType = kt, vt
+ continue
+ }
+ if !keyType.IsEquivalentType(kt) {
+ v.typeMismatch(iss, key.ID(), keyType, kt)
+ }
+ if !valType.IsEquivalentType(vt) {
+ v.typeMismatch(iss, val.ID(), valType, vt)
+ }
+ }
+ }
+}
+
+func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
+ parent, found := e.Parent()
+ for found {
+ if parent.Kind() == ast.CallKind {
+ fnName := parent.AsCall().FunctionName()
+ for _, exempt := range exemptFunctions {
+ if exempt == fnName {
+ return true
+ }
+ }
+ }
+ parent, found = parent.Parent()
+ }
+ return false
+}
+
+func isOptionalIndex(i int, optIndices []int32) bool {
+ for _, optInd := range optIndices {
+ if i == int(optInd) {
+ return true
+ }
+ }
+ return false
+}
+
+func (homogeneousAggregateLiteralValidator) typeMismatch(iss *Issues, id int64, expected, actual *Type) {
+ iss.ReportErrorAtID(id, "expected type '%s' but found '%s'", FormatCELType(expected), FormatCELType(actual))
+}
+
+type nestingLimitValidator struct {
+ limit int
+}
+
+func (v nestingLimitValidator) Name() string {
+ return "cel.lib.std.validate.comprehension_nesting_limit"
+}
+
+func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
+ root := ast.NavigateAST(a)
+ comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))
+ if len(comprehensions) <= v.limit {
+ return
+ }
+ for _, comp := range comprehensions {
+ count := 0
+ e := comp
+ hasParent := true
+ for hasParent {
+ // When the expression is not a comprehension, continue to the next ancestor.
+ if e.Kind() != ast.ComprehensionKind {
+ e, hasParent = e.Parent()
+ continue
+ }
+ // When the comprehension has an empty range, continue to the next ancestor
+ // as this comprehension does not have any associated cost.
+ iterRange := e.AsComprehension().IterRange()
+ if iterRange.Kind() == ast.ListKind && iterRange.AsList().Size() == 0 {
+ e, hasParent = e.Parent()
+ continue
+ }
+ // Otherwise check the nesting limit.
+ count++
+ if count > v.limit {
+ iss.ReportErrorAtID(comp.ID(), "comprehension exceeds nesting limit")
+ break
+ }
+ e, hasParent = e.Parent()
+ }
+ }
+}