diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/spicedb/pkg/caveats | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/caveats')
21 files changed, 1873 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/compile.go b/vendor/github.com/authzed/spicedb/pkg/caveats/compile.go new file mode 100644 index 0000000..312a701 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/compile.go @@ -0,0 +1,226 @@ +package caveats + +import ( + "fmt" + "strings" + + "github.com/authzed/cel-go/cel" + "github.com/authzed/cel-go/common" + "google.golang.org/protobuf/proto" + + "github.com/authzed/spicedb/pkg/caveats/replacer" + "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil/mapz" + impl "github.com/authzed/spicedb/pkg/proto/impl/v1" +) + +const anonymousCaveat = "" + +const maxCaveatExpressionSize = 100_000 // characters + +// CompiledCaveat is a compiled form of a caveat. +type CompiledCaveat struct { + // env is the environment under which the CEL program was compiled. + celEnv *cel.Env + + // ast is the AST form of the CEL program. + ast *cel.Ast + + // name of the caveat + name string +} + +// Name represents a user-friendly reference to a caveat +func (cc CompiledCaveat) Name() string { + return cc.name +} + +// ExprString returns the string-form of the caveat. +func (cc CompiledCaveat) ExprString() (string, error) { + return cel.AstToString(cc.ast) +} + +// RewriteVariable replaces the use of a variable with another variable in the compiled caveat. +func (cc CompiledCaveat) RewriteVariable(oldName, newName string) (CompiledCaveat, error) { + // Find the existing parameter name and get its type. + oldExpr, issues := cc.celEnv.Compile(oldName) + if issues.Err() != nil { + return CompiledCaveat{}, fmt.Errorf("failed to parse old variable name: %w", issues.Err()) + } + + oldType := oldExpr.OutputType() + + // Ensure the new variable name is not used. + _, niss := cc.celEnv.Compile(newName) + if niss.Err() == nil { + return CompiledCaveat{}, fmt.Errorf("variable name '%s' is already used", newName) + } + + // Extend the environment with the new variable name. + extended, err := cc.celEnv.Extend(cel.Variable(newName, oldType)) + if err != nil { + return CompiledCaveat{}, fmt.Errorf("failed to extend environment: %w", err) + } + + // Replace the variable in the AST. + updatedAst, err := replacer.ReplaceVariable(extended, cc.ast, oldName, newName) + if err != nil { + return CompiledCaveat{}, fmt.Errorf("failed to rewrite variable: %w", err) + } + + return CompiledCaveat{extended, updatedAst, cc.name}, nil +} + +// Serialize serializes the compiled caveat into a byte string for storage. +func (cc CompiledCaveat) Serialize() ([]byte, error) { + cexpr, err := cel.AstToCheckedExpr(cc.ast) + if err != nil { + return nil, err + } + + caveat := &impl.DecodedCaveat{ + KindOneof: &impl.DecodedCaveat_Cel{ + Cel: cexpr, + }, + Name: cc.name, + } + + // TODO(jschorr): change back to MarshalVT once stable is supported. + // See: https://github.com/planetscale/vtprotobuf/pull/133 + return proto.MarshalOptions{Deterministic: true}.Marshal(caveat) +} + +// ReferencedParameters returns the names of the parameters referenced in the expression. +func (cc CompiledCaveat) ReferencedParameters(parameters []string) (*mapz.Set[string], error) { + referencedParams := mapz.NewSet[string]() + definedParameters := mapz.NewSet[string]() + definedParameters.Extend(parameters) + + checked, err := cel.AstToCheckedExpr(cc.ast) + if err != nil { + return nil, err + } + + referencedParameters(definedParameters, checked.Expr, referencedParams) + return referencedParams, nil +} + +// CompileCaveatWithName compiles a caveat string into a compiled caveat with a given name, +// or returns the compilation errors. +func CompileCaveatWithName(env *Environment, exprString, name string) (*CompiledCaveat, error) { + c, err := CompileCaveatWithSource(env, name, common.NewStringSource(exprString, name), nil) + if err != nil { + return nil, err + } + c.name = name + return c, nil +} + +// CompileCaveatWithSource compiles a caveat source into a compiled caveat, or returns the compilation errors. +func CompileCaveatWithSource(env *Environment, name string, source common.Source, startPosition SourcePosition) (*CompiledCaveat, error) { + celEnv, err := env.asCelEnvironment() + if err != nil { + return nil, err + } + + if len(strings.TrimSpace(source.Content())) > maxCaveatExpressionSize { + return nil, fmt.Errorf("caveat expression provided exceeds maximum allowed size of %d characters", maxCaveatExpressionSize) + } + + ast, issues := celEnv.CompileSource(source) + if issues != nil && issues.Err() != nil { + if startPosition == nil { + return nil, MultipleCompilationError{issues.Err(), issues} + } + + // Construct errors with the source location adjusted based on the starting source position + // in the parent schema (if any). This ensures that the errors coming out of CEL show the correct + // *overall* location information.. + line, col, err := startPosition.LineAndColumn() + if err != nil { + return nil, err + } + + adjustedErrors := common.NewErrors(source) + for _, existingErr := range issues.Errors() { + location := existingErr.Location + + // NOTE: Our locations are zero-indexed while CEL is 1-indexed, so we need to adjust the line/column values accordingly. + if location.Line() == 1 { + location = common.NewLocation(line+location.Line(), col+location.Column()) + } else { + location = common.NewLocation(line+location.Line(), location.Column()) + } + + adjustedError := &common.Error{ + Message: existingErr.Message, + ExprID: existingErr.ExprID, + Location: location, + } + + adjustedErrors = adjustedErrors.Append([]*common.Error{ + adjustedError, + }) + } + + adjustedIssues := cel.NewIssues(adjustedErrors) + return nil, MultipleCompilationError{adjustedIssues.Err(), adjustedIssues} + } + + if ast.OutputType() != cel.BoolType { + return nil, MultipleCompilationError{fmt.Errorf("caveat expression must result in a boolean value: found `%s`", ast.OutputType().String()), nil} + } + + compiled := &CompiledCaveat{celEnv, ast, anonymousCaveat} + compiled.name = name + return compiled, nil +} + +// compileCaveat compiles a caveat string into a compiled caveat, or returns the compilation errors. +func compileCaveat(env *Environment, exprString string) (*CompiledCaveat, error) { + s := common.NewStringSource(exprString, "caveat") + return CompileCaveatWithSource(env, "caveat", s, nil) +} + +// DeserializeCaveat deserializes a byte-serialized caveat back into a CompiledCaveat. +func DeserializeCaveat(serialized []byte, parameterTypes map[string]types.VariableType) (*CompiledCaveat, error) { + env, err := EnvForVariables(parameterTypes) + if err != nil { + return nil, err + } + + return DeserializeCaveatWithEnviroment(env, serialized) +} + +// DeserializeCaveatWithTypeSet deserializes a byte-serialized caveat back into a CompiledCaveat. +func DeserializeCaveatWithTypeSet(ts *types.TypeSet, serialized []byte, parameterTypes map[string]types.VariableType) (*CompiledCaveat, error) { + env, err := EnvForVariablesWithTypeSet(ts, parameterTypes) + if err != nil { + return nil, err + } + + return DeserializeCaveatWithEnviroment(env, serialized) +} + +// DeserializeCaveatWithEnviroment deserializes a byte-serialized caveat back into a CompiledCaveat, +// using the provided environment. It is the responsibility of the caller to ensure that the environment +// has the parameters defined as variables. +func DeserializeCaveatWithEnviroment(env *Environment, serialized []byte) (*CompiledCaveat, error) { + if len(serialized) == 0 { + return nil, fmt.Errorf("given empty serialized") + } + + caveat := &impl.DecodedCaveat{} + err := caveat.UnmarshalVT(serialized) + if err != nil { + return nil, err + } + + celEnv, err := env.asCelEnvironment() + if err != nil { + return nil, err + } + + ast := cel.CheckedExprToAst(caveat.GetCel()) + return &CompiledCaveat{celEnv, ast, caveat.Name}, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/context.go b/vendor/github.com/authzed/spicedb/pkg/caveats/context.go new file mode 100644 index 0000000..5580eb4 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/context.go @@ -0,0 +1,45 @@ +package caveats + +import ( + "maps" + "time" + + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/spicedb/pkg/caveats/types" +) + +// ConvertContextToStruct converts the given context values into a context struct. +func ConvertContextToStruct(contextValues map[string]any) (*structpb.Struct, error) { + cloned := maps.Clone(contextValues) + cloned = convertCustomValues(cloned).(map[string]any) + return structpb.NewStruct(cloned) +} + +func convertCustomValues(value any) any { + switch v := value.(type) { + case map[string]any: + for key, value := range v { + v[key] = convertCustomValues(value) + } + return v + + case []any: + for index, current := range v { + v[index] = convertCustomValues(current) + } + return v + + case time.Time: + return v.Format(time.RFC3339) + + case time.Duration: + return v.String() + + case types.CustomType: + return v.SerializedString() + + default: + return v + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/context_hash.go b/vendor/github.com/authzed/spicedb/pkg/caveats/context_hash.go new file mode 100644 index 0000000..658b492 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/context_hash.go @@ -0,0 +1,90 @@ +package caveats + +import ( + "bytes" + "fmt" + "net/url" + "sort" + "strconv" + + "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/structpb" +) + +// HasherInterface is an interface for writing context to be hashed. +type HasherInterface interface { + WriteString(value string) +} + +// StableContextStringForHashing returns a stable string version of the context, for use in hashing. +func StableContextStringForHashing(context *structpb.Struct) string { + b := bytes.NewBufferString("") + hc := HashableContext{context} + hc.AppendToHash(wrappedBuffer{b}) + return b.String() +} + +type wrappedBuffer struct{ *bytes.Buffer } + +func (wb wrappedBuffer) WriteString(value string) { + wb.Buffer.WriteString(value) +} + +// HashableContext is a wrapper around a context Struct that provides hashing. +type HashableContext struct{ *structpb.Struct } + +func (hc HashableContext) AppendToHash(hasher HasherInterface) { + // NOTE: the order of keys in the Struct and its resulting JSON output are *unspecified*, + // as the go runtime randomizes iterator order to ensure that if relied upon, a sort is used. + // Therefore, we sort the keys here before adding them to the hash. + if hc.Struct == nil { + return + } + + fields := hc.Struct.Fields + keys := maps.Keys(fields) + sort.Strings(keys) + + for _, key := range keys { + hasher.WriteString("`") + hasher.WriteString(key) + hasher.WriteString("`:") + hashableStructValue{fields[key]}.AppendToHash(hasher) + hasher.WriteString(",\n") + } +} + +type hashableStructValue struct{ *structpb.Value } + +func (hsv hashableStructValue) AppendToHash(hasher HasherInterface) { + switch t := hsv.Kind.(type) { + case *structpb.Value_BoolValue: + hasher.WriteString(strconv.FormatBool(t.BoolValue)) + + case *structpb.Value_ListValue: + for _, value := range t.ListValue.Values { + hashableStructValue{value}.AppendToHash(hasher) + hasher.WriteString(",") + } + + case *structpb.Value_NullValue: + hasher.WriteString("null") + + case *structpb.Value_NumberValue: + // AFAICT, this is how Sprintf-style formats float64s + hasher.WriteString(strconv.FormatFloat(t.NumberValue, 'f', 6, 64)) + + case *structpb.Value_StringValue: + // NOTE: we escape the string value here to prevent accidental overlap in keys for string + // values that may themselves contain backticks. + hasher.WriteString("`" + url.PathEscape(t.StringValue) + "`") + + case *structpb.Value_StructValue: + hasher.WriteString("{") + HashableContext{t.StructValue}.AppendToHash(hasher) + hasher.WriteString("}") + + default: + panic(fmt.Sprintf("unknown struct value type: %T", t)) + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/doc.go b/vendor/github.com/authzed/spicedb/pkg/caveats/doc.go new file mode 100644 index 0000000..cb1d885 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/doc.go @@ -0,0 +1,2 @@ +// Package caveats contains code to compile caveats and to evaluate a caveat with a given context. +package caveats diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/env.go b/vendor/github.com/authzed/spicedb/pkg/caveats/env.go new file mode 100644 index 0000000..7e33183 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/env.go @@ -0,0 +1,111 @@ +package caveats + +import ( + "fmt" + + "github.com/authzed/cel-go/cel" + + "github.com/authzed/spicedb/pkg/caveats/types" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// Environment defines the evaluation environment for a caveat. +type Environment struct { + ts *types.TypeSet + variables map[string]types.VariableType +} + +// NewEnvironment creates and returns a new environment for compiling a caveat. +func NewEnvironment() *Environment { + return &Environment{ + ts: types.Default.TypeSet, + variables: map[string]types.VariableType{}, + } +} + +// NewEnvironmentWithTypeSet creates and returns a new environment for compiling a caveat +// with the given TypeSet. +func NewEnvironmentWithTypeSet(ts *types.TypeSet) *Environment { + return &Environment{ + ts: ts, + variables: map[string]types.VariableType{}, + } +} + +// EnvForVariables returns a new environment constructed for the given variables. +func EnvForVariables(vars map[string]types.VariableType) (*Environment, error) { + return EnvForVariablesWithTypeSet(types.Default.TypeSet, vars) +} + +// EnvForVariablesWithTypeSet returns a new environment constructed for the given variables. +func EnvForVariablesWithTypeSet(ts *types.TypeSet, vars map[string]types.VariableType) (*Environment, error) { + e := NewEnvironmentWithTypeSet(ts) + for varName, varType := range vars { + err := e.AddVariable(varName, varType) + if err != nil { + return nil, err + } + } + return e, nil +} + +// MustEnvForVariables returns a new environment constructed for the given variables +// or panics. +func MustEnvForVariables(vars map[string]types.VariableType) *Environment { + env, err := EnvForVariables(vars) + if err != nil { + panic(err) + } + return env +} + +// AddVariable adds a variable with the given type to the environment. +func (e *Environment) AddVariable(name string, varType types.VariableType) error { + if _, ok := e.variables[name]; ok { + return fmt.Errorf("variable `%s` already exists", name) + } + + e.variables[name] = varType + return nil +} + +// EncodedParametersTypes returns the map of encoded parameters for the environment. +func (e *Environment) EncodedParametersTypes() map[string]*core.CaveatTypeReference { + return types.EncodeParameterTypes(e.variables) +} + +// asCelEnvironment converts the exported Environment into an internal CEL environment. +func (e *Environment) asCelEnvironment(extraOptions ...cel.EnvOption) (*cel.Env, error) { + tsOptions, err := e.ts.EnvOptions() + if err != nil { + return nil, err + } + + opts := make([]cel.EnvOption, 0, len(extraOptions)+len(e.variables)+len(tsOptions)+2) + opts = append(opts, extraOptions...) + opts = append(opts, tsOptions...) + + // Set options. + // DefaultUTCTimeZone: ensure all timestamps are evaluated at UTC + opts = append(opts, cel.DefaultUTCTimeZone(true)) + + // OptionalTypes: enable optional typing syntax, e.g. `sometype?.foo` + // See: https://github.com/google/cel-spec/wiki/proposal-246 + opts = append(opts, cel.OptionalTypes(cel.OptionalTypesVersion(0))) + + // EnableMacroCallTracking: enables tracking of call macros so when we call AstToString we get + // back out the expected expressions. + // See: https://github.com/authzed/cel-go/issues/474 + opts = append(opts, cel.EnableMacroCallTracking()) + + // ParserExpressionSizeLimit: disable the size limit for codepoints in expressions. + // This has to be disabled due to us padding out the whitespace in expression parsing based on + // schema size. We instead do our own expression size check in the Compile method. + // TODO(jschorr): Remove this once the whitespace hack is removed. + opts = append(opts, cel.ParserExpressionSizeLimit(-1)) + + for name, varType := range e.variables { + opts = append(opts, cel.Variable(name, varType.CelType())) + } + return cel.NewEnv(opts...) +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/errors.go b/vendor/github.com/authzed/spicedb/pkg/caveats/errors.go new file mode 100644 index 0000000..9c49609 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/errors.go @@ -0,0 +1,76 @@ +package caveats + +import ( + "strconv" + + "github.com/authzed/cel-go/cel" + "github.com/rs/zerolog" +) + +// EvaluationError is an error in evaluation of a caveat expression. +type EvaluationError struct { + error +} + +// MarshalZerologObject implements zerolog.LogObjectMarshaler +func (err EvaluationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err EvaluationError) DetailsMetadata() map[string]string { + return map[string]string{} +} + +// ParameterConversionError is an error in type conversion of a supplied parameter. +type ParameterConversionError struct { + error + parameterName string +} + +// MarshalZerologObject implements zerolog.LogObjectMarshaler +func (err ParameterConversionError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("parameterName", err.parameterName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err ParameterConversionError) DetailsMetadata() map[string]string { + return map[string]string{ + "parameter_name": err.parameterName, + } +} + +// ParameterName is the name of the parameter. +func (err ParameterConversionError) ParameterName() string { + return err.parameterName +} + +// MultipleCompilationError is a wrapping error for containing compilation errors for a Caveat. +type MultipleCompilationError struct { + error + + issues *cel.Issues +} + +// LineNumber is the 0-indexed line number for compilation error. +func (err MultipleCompilationError) LineNumber() int { + return err.issues.Errors()[0].Location.Line() - 1 +} + +// ColumnPositionis the 0-indexed column position for compilation error. +func (err MultipleCompilationError) ColumnPosition() int { + return err.issues.Errors()[0].Location.Column() - 1 +} + +// MarshalZerologObject implements zerolog.LogObjectMarshaler +func (err MultipleCompilationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Int("lineNumber", err.LineNumber()).Int("columnPosition", err.ColumnPosition()) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err MultipleCompilationError) DetailsMetadata() map[string]string { + return map[string]string{ + "line_number": strconv.Itoa(err.LineNumber()), + "column_position": strconv.Itoa(err.ColumnPosition()), + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/eval.go b/vendor/github.com/authzed/spicedb/pkg/caveats/eval.go new file mode 100644 index 0000000..cef9a4b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/eval.go @@ -0,0 +1,156 @@ +package caveats + +import ( + "fmt" + + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/cel-go/cel" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" +) + +// EvaluationConfig is configuration given to an EvaluateCaveatWithConfig call. +type EvaluationConfig struct { + // MaxCost is the max cost of the caveat to be executed. + MaxCost uint64 +} + +// CaveatResult holds the result of evaluating a caveat. +type CaveatResult struct { + val ref.Val + details *cel.EvalDetails + parentCaveat *CompiledCaveat + contextValues map[string]any + missingVarNames []string + isPartial bool +} + +// Value returns the computed value for the result. +func (cr CaveatResult) Value() bool { + if cr.isPartial { + return false + } + + return cr.val.Value().(bool) +} + +// IsPartial returns true if the caveat was only partially evaluated. +func (cr CaveatResult) IsPartial() bool { + return cr.isPartial +} + +// PartialValue returns the partially evaluated caveat. Only applies if IsPartial is true. +func (cr CaveatResult) PartialValue() (*CompiledCaveat, error) { + if !cr.isPartial { + return nil, fmt.Errorf("result is fully evaluated") + } + + ast, err := cr.parentCaveat.celEnv.ResidualAst(cr.parentCaveat.ast, cr.details) + if err != nil { + return nil, err + } + + return &CompiledCaveat{cr.parentCaveat.celEnv, ast, cr.parentCaveat.name}, nil +} + +// ContextValues returns the context values used when computing this result. +func (cr CaveatResult) ContextValues() map[string]any { + return cr.contextValues +} + +// ContextStruct returns the context values used when computing this result as +// a structpb. +func (cr CaveatResult) ContextStruct() (*structpb.Struct, error) { + return ConvertContextToStruct(cr.contextValues) +} + +// ExpressionString returns the human-readable expression string for the evaluated expression. +func (cr CaveatResult) ExpressionString() (string, error) { + return cr.parentCaveat.ExprString() +} + +// ParentCaveat returns the caveat that was evaluated to produce this result. +func (cr CaveatResult) ParentCaveat() *CompiledCaveat { + return cr.parentCaveat +} + +// MissingVarNames returns the name(s) of the missing variables. +func (cr CaveatResult) MissingVarNames() ([]string, error) { + if !cr.isPartial { + return nil, fmt.Errorf("result is fully evaluated") + } + + return cr.missingVarNames, nil +} + +// EvaluateCaveat evaluates the compiled caveat with the specified values, and returns +// the result or an error. +func EvaluateCaveat(caveat *CompiledCaveat, contextValues map[string]any) (*CaveatResult, error) { + return EvaluateCaveatWithConfig(caveat, contextValues, nil) +} + +// EvaluateCaveatWithConfig evaluates the compiled caveat with the specified values, and returns +// the result or an error. +func EvaluateCaveatWithConfig(caveat *CompiledCaveat, contextValues map[string]any, config *EvaluationConfig) (*CaveatResult, error) { + env := caveat.celEnv + celopts := make([]cel.ProgramOption, 0, 3) + + // Option: enables partial evaluation and state tracking for partial evaluation. + celopts = append(celopts, cel.EvalOptions(cel.OptTrackState)) + celopts = append(celopts, cel.EvalOptions(cel.OptPartialEval)) + + // Option: Cost limit on the evaluation. + if config != nil && config.MaxCost > 0 { + celopts = append(celopts, cel.CostLimit(config.MaxCost)) + } + + prg, err := env.Program(caveat.ast, celopts...) + if err != nil { + return nil, err + } + + // Mark any unspecified variables as unknown, to ensure that partial application + // will result in producing a type of Unknown. + activation, err := env.PartialVars(contextValues) + if err != nil { + return nil, err + } + + val, details, err := prg.Eval(activation) + if err != nil { + return nil, EvaluationError{err} + } + + // If the value produced has Unknown type, then it means required context was missing. + if types.IsUnknown(val) { + unknownVal := val.(*types.Unknown) + missingVarNames := make([]string, 0, len(unknownVal.IDs())) + for _, id := range unknownVal.IDs() { + trails, ok := unknownVal.GetAttributeTrails(id) + if ok { + for _, attributeTrail := range trails { + missingVarNames = append(missingVarNames, attributeTrail.String()) + } + } + } + + return &CaveatResult{ + val: val, + details: details, + parentCaveat: caveat, + contextValues: contextValues, + missingVarNames: missingVarNames, + isPartial: true, + }, nil + } + + return &CaveatResult{ + val: val, + details: details, + parentCaveat: caveat, + contextValues: contextValues, + missingVarNames: nil, + isPartial: false, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/parameters.go b/vendor/github.com/authzed/spicedb/pkg/caveats/parameters.go new file mode 100644 index 0000000..cd3ef67 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/parameters.go @@ -0,0 +1,82 @@ +package caveats + +import ( + "fmt" + "strings" + + "github.com/authzed/spicedb/pkg/caveats/types" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// UnknownParameterOption is the option to ConvertContextToParameters around handling +// of unknown parameters. +type UnknownParameterOption int + +const ( + // SkipUnknownParameters indicates that unknown parameters should be skipped in conversion. + SkipUnknownParameters UnknownParameterOption = 0 + + // ErrorForUnknownParameters indicates that unknown parameters should return an error. + ErrorForUnknownParameters UnknownParameterOption = 1 +) + +// ConvertContextToParameters converts the given context into parameters of the types specified. +// Returns a type error if type conversion failed. +func ConvertContextToParameters( + ts *types.TypeSet, + contextMap map[string]any, + parameterTypes map[string]*core.CaveatTypeReference, + unknownParametersOption UnknownParameterOption, +) (map[string]any, error) { + if len(contextMap) == 0 { + return nil, nil + } + + if len(parameterTypes) == 0 { + return nil, fmt.Errorf("missing parameters for caveat") + } + + converted := make(map[string]any, len(contextMap)) + + for key, value := range contextMap { + paramType, ok := parameterTypes[key] + if !ok { + if unknownParametersOption == ErrorForUnknownParameters { + return nil, fmt.Errorf("unknown parameter `%s`", key) + } + + continue + } + + varType, err := types.DecodeParameterType(ts, paramType) + if err != nil { + return nil, err + } + + convertedParam, err := varType.ConvertValue(value) + if err != nil { + return nil, ParameterConversionError{fmt.Errorf("could not convert context parameter `%s`: %w", key, err), key} + } + + converted[key] = convertedParam + } + return converted, nil +} + +// ParameterTypeString returns the string form of the type reference. +func ParameterTypeString(typeRef *core.CaveatTypeReference) string { + var sb strings.Builder + sb.WriteString(typeRef.TypeName) + if len(typeRef.ChildTypes) > 0 { + sb.WriteString("<") + for idx, childType := range typeRef.ChildTypes { + if idx > 0 { + sb.WriteString(", ") + } + sb.WriteString(ParameterTypeString(childType)) + } + sb.WriteString(">") + } + + return sb.String() +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/replacer/inlining.go b/vendor/github.com/authzed/spicedb/pkg/caveats/replacer/inlining.go new file mode 100644 index 0000000..33d8b77 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/replacer/inlining.go @@ -0,0 +1,171 @@ +// Modified version of https://github.com/authzed/cel-go/blob/b707d132d96bb5450df92d138860126bf03f805f/cel/inlining.go +// which changes it so variable replacement is always done without cel.bind calls. See +// the "CHANGED" below. +// +// Original Copyright notice: +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package replacer + +import ( + "github.com/authzed/cel-go/cel" + "github.com/authzed/cel-go/common/ast" + "github.com/authzed/cel-go/common/containers" + "github.com/authzed/cel-go/common/operators" + "github.com/authzed/cel-go/common/overloads" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/traits" +) + +// InlineVariable holds a variable name to be matched and an AST representing +// the expression graph which should be used to replace it. +type InlineVariable struct { + name string + alias string + def *ast.AST +} + +// Name returns the qualified variable or field selection to replace. +func (v *InlineVariable) Name() string { + return v.name +} + +// Alias returns the alias to use when performing cel.bind() calls during inlining. +func (v *InlineVariable) Alias() string { + return v.alias +} + +// Expr returns the inlined expression value. +func (v *InlineVariable) Expr() ast.Expr { + return v.def.Expr() +} + +// Type indicates the inlined expression type. +func (v *InlineVariable) Type() *cel.Type { + return v.def.GetType(v.def.Expr().ID()) +} + +// newInlineVariable declares a variable name to be replaced by a checked expression. +func newInlineVariable(name string, definition *cel.Ast) *InlineVariable { + return newInlineVariableWithAlias(name, name, definition) +} + +// newInlineVariableWithAlias declares a variable name to be replaced by a checked expression. +// If the variable occurs more than once, the provided alias will be used to replace the expressions +// where the variable name occurs. +func newInlineVariableWithAlias(name, alias string, definition *cel.Ast) *InlineVariable { + return &InlineVariable{name: name, alias: alias, def: definition.NativeRep()} +} + +// newModifiedInliningOptimizer creates and optimizer which replaces variables with expression definitions. +func newModifiedInliningOptimizer(inlineVars ...*InlineVariable) cel.ASTOptimizer { + return &inliningOptimizer{variables: inlineVars} +} + +type inliningOptimizer struct { + variables []*InlineVariable +} + +func (opt *inliningOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + for _, inlineVar := range opt.variables { + matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name())) + // Skip cases where the variable isn't in the expression graph + if len(matches) == 0 { + continue + } + + // CHANGED: *ALWAYS* do a direct replacement of the expression sub-graph. + for _, match := range matches { + // Copy the inlined AST expr and source info. + copyExpr := ctx.CopyASTAndMetadata(inlineVar.def) + opt.inlineExpr(ctx, match, copyExpr, inlineVar.Type()) + } + } + return a +} + +// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining +// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is +// made to determine whether the inlined value can be presence or existence tested. +func (opt *inliningOptimizer) inlineExpr(ctx *cel.OptimizerContext, prev ast.NavigableExpr, inlined ast.Expr, inlinedType *cel.Type) { + switch prev.Kind() { + case ast.SelectKind: + sel := prev.AsSelect() + if !sel.IsTestOnly() { + ctx.UpdateExpr(prev, inlined) + return + } + opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType) + default: + ctx.UpdateExpr(prev, inlined) + } +} + +// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to type-safe +// expression appropriate for the inlined type, if possible. +// +// If the rewrite is not possible an error is reported at the inline expression site. +func (opt *inliningOptimizer) rewritePresenceExpr(ctx *cel.OptimizerContext, prev, inlined ast.Expr, inlinedType *cel.Type) { + // If the input inlined expression is not a select expression it won't work with the has() + // macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error. + if inlined.Kind() == ast.SelectKind { + presenceTest, hasMacro := ctx.NewHasMacro(prev.ID(), inlined) + ctx.UpdateExpr(prev, presenceTest) + ctx.SetMacroCall(prev.ID(), hasMacro) + return + } + + ctx.ClearMacroCall(prev.ID()) + if inlinedType.IsAssignableType(cel.NullType) { + ctx.UpdateExpr(prev, + ctx.NewCall(operators.NotEquals, + inlined, + ctx.NewLiteral(types.NullValue), + )) + return + } + if inlinedType.HasTrait(traits.SizerType) { + ctx.UpdateExpr(prev, + ctx.NewCall(operators.NotEquals, + ctx.NewMemberCall(overloads.Size, inlined), + ctx.NewLiteral(types.IntZero), + )) + return + } + ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType) +} + +// matchVariable matches simple identifiers, select expressions, and presence test expressions +// which match the (potentially) qualified variable name provided as input. +// +// Note, this function does not support inlining against select expressions which includes optional +// field selection. This may be a future refinement. +func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher { + return func(e ast.NavigableExpr) bool { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + return true + } + if e.Kind() == ast.SelectKind { + sel := e.AsSelect() + // While the `ToQualifiedName` call could take the select directly, this + // would skip presence tests from possible matches, which we would like + // to include. + qualName, found := containers.ToQualifiedName(sel.Operand()) + return found && qualName+"."+sel.FieldName() == varName + } + return false + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/replacer/replacer.go b/vendor/github.com/authzed/spicedb/pkg/caveats/replacer/replacer.go new file mode 100644 index 0000000..d57f40a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/replacer/replacer.go @@ -0,0 +1,25 @@ +package replacer + +import ( + "fmt" + + "github.com/authzed/cel-go/cel" +) + +func ReplaceVariable(e *cel.Env, existingAst *cel.Ast, oldVarName string, newVarName string) (*cel.Ast, error) { + newExpr, iss := e.Compile(newVarName) + if iss.Err() != nil { + return nil, fmt.Errorf("failed to compile new variable name: %w", iss.Err()) + } + + inlinedVars := []*InlineVariable{} + inlinedVars = append(inlinedVars, newInlineVariable(oldVarName, newExpr)) + + opt := cel.NewStaticOptimizer(newModifiedInliningOptimizer(inlinedVars...)) + optimized, iss := opt.Optimize(e, existingAst) + if iss.Err() != nil { + return nil, fmt.Errorf("failed to optimize: %w", iss.Err()) + } + + return optimized, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/source.go b/vendor/github.com/authzed/spicedb/pkg/caveats/source.go new file mode 100644 index 0000000..8550f03 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/source.go @@ -0,0 +1,19 @@ +package caveats + +import ( + "github.com/authzed/cel-go/common" +) + +// SourcePosition is an incoming source position. +type SourcePosition interface { + // LineAndColumn returns the 0-indexed line number and column position in the source file. + LineAndColumn() (int, int, error) + + // RunePosition returns the 0-indexed rune position in the source file. + RunePosition() (int, error) +} + +// NewSource creates a new source for compilation into a caveat. +func NewSource(expressionString string, name string) (common.Source, error) { + return common.NewStringSource(expressionString, name), nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/structure.go b/vendor/github.com/authzed/spicedb/pkg/caveats/structure.go new file mode 100644 index 0000000..d27602a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/structure.go @@ -0,0 +1,56 @@ +package caveats + +import ( + "fmt" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/authzed/spicedb/pkg/genutil/mapz" +) + +// referencedParameters traverses the expression given and finds all parameters which are referenced +// in the expression for the purpose of usage tracking. +func referencedParameters(definedParameters *mapz.Set[string], expr *exprpb.Expr, referencedParams *mapz.Set[string]) { + if expr == nil { + return + } + + switch t := expr.ExprKind.(type) { + case *exprpb.Expr_ConstExpr: + // nothing to do + + case *exprpb.Expr_IdentExpr: + if definedParameters.Has(t.IdentExpr.Name) { + referencedParams.Add(t.IdentExpr.Name) + } + + case *exprpb.Expr_SelectExpr: + referencedParameters(definedParameters, t.SelectExpr.Operand, referencedParams) + + case *exprpb.Expr_CallExpr: + referencedParameters(definedParameters, t.CallExpr.Target, referencedParams) + for _, arg := range t.CallExpr.Args { + referencedParameters(definedParameters, arg, referencedParams) + } + + case *exprpb.Expr_ListExpr: + for _, elem := range t.ListExpr.Elements { + referencedParameters(definedParameters, elem, referencedParams) + } + + case *exprpb.Expr_StructExpr: + for _, entry := range t.StructExpr.Entries { + referencedParameters(definedParameters, entry.Value, referencedParams) + } + + case *exprpb.Expr_ComprehensionExpr: + referencedParameters(definedParameters, t.ComprehensionExpr.AccuInit, referencedParams) + referencedParameters(definedParameters, t.ComprehensionExpr.IterRange, referencedParams) + referencedParameters(definedParameters, t.ComprehensionExpr.LoopCondition, referencedParams) + referencedParameters(definedParameters, t.ComprehensionExpr.LoopStep, referencedParams) + referencedParameters(definedParameters, t.ComprehensionExpr.Result, referencedParams) + + default: + panic(fmt.Sprintf("unknown CEL expression kind: %T", t)) + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/basic.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/basic.go new file mode 100644 index 0000000..688fc34 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/basic.go @@ -0,0 +1,262 @@ +package types + +import ( + "encoding/base64" + "fmt" + "math/big" + "time" + + "github.com/authzed/cel-go/cel" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// RegisterBasicTypes registers the basic types with the given keyword, CEL type, and converter. +func RegisterBasicTypes(sts *StandardTypeSet) error { + ts := sts.TypeSet + + // Register "any" type + anyType, err := RegisterBasicType(ts, "any", cel.DynType, func(value any) (any, error) { return value, nil }) + if err != nil { + return err + } + sts.AnyType = anyType + + // Register "bool" type + boolType, err := RegisterBasicType(ts, "bool", cel.BoolType, requireType[bool]) + if err != nil { + return err + } + sts.BooleanType = boolType + + // Register "string" type + stringType, err := RegisterBasicType(ts, "string", cel.StringType, requireType[string]) + if err != nil { + return err + } + sts.StringType = stringType + + // Register "int" type + intType, err := RegisterBasicType(ts, "int", cel.IntType, convertNumericType[int64]) + if err != nil { + return err + } + sts.IntType = intType + + // Register "uint" type + uintType, err := RegisterBasicType(ts, "uint", cel.IntType, convertNumericType[uint64]) + if err != nil { + return err + } + sts.UIntType = uintType + + // Register "double" type + doubleType, err := RegisterBasicType(ts, "double", cel.DoubleType, convertNumericType[float64]) + if err != nil { + return err + } + sts.DoubleType = doubleType + + // Register "bytes" type + bytesType, err := RegisterBasicType(ts, "bytes", cel.BytesType, func(value any) (any, error) { + vle, ok := value.(string) + if !ok { + return nil, fmt.Errorf("bytes requires a base64 unicode string, found: %T `%v`", value, value) + } + + decoded, err := base64.StdEncoding.DecodeString(vle) + if err != nil { + return nil, fmt.Errorf("bytes requires a base64 encoded string: %w", err) + } + + return decoded, nil + }) + if err != nil { + return err + } + sts.BytesType = bytesType + + // Register "duration" type + durationType, err := RegisterBasicType(ts, "duration", cel.DurationType, func(value any) (any, error) { + vle, ok := value.(string) + if !ok { + return nil, fmt.Errorf("durations requires a duration string, found: %T", value) + } + + d, err := time.ParseDuration(vle) + if err != nil { + return nil, fmt.Errorf("could not parse duration string `%s`: %w", vle, err) + } + + return d, nil + }) + if err != nil { + return err + } + sts.DurationType = durationType + + // Register "timestamp" type + timestampType, err := RegisterBasicType(ts, "timestamp", cel.TimestampType, func(value any) (any, error) { + vle, ok := value.(string) + if !ok { + return nil, fmt.Errorf("timestamps requires a RFC 3339 formatted timestamp string, found: %T `%v`", value, value) + } + + d, err := time.Parse(time.RFC3339, vle) + if err != nil { + return nil, fmt.Errorf("could not parse RFC 3339 formatted timestamp string `%s`: %w", vle, err) + } + + return d, nil + }) + if err != nil { + return err + } + sts.TimestampType = timestampType + + listTypeBuilder, err := RegisterGenericType(ts, "list", 1, + func(childTypes []VariableType) VariableType { + return VariableType{ + localName: "list", + celType: cel.ListType(childTypes[0].celType), + childTypes: childTypes, + converter: func(value any) (any, error) { + vle, ok := value.([]any) + if !ok { + return nil, fmt.Errorf("list requires a list, found: %T", value) + } + + converted := make([]any, 0, len(vle)) + for index, item := range vle { + convertedItem, err := childTypes[0].ConvertValue(item) + if err != nil { + return nil, fmt.Errorf("found an invalid value for item at index %d: %w", index, err) + } + converted = append(converted, convertedItem) + } + + return converted, nil + }, + } + }) + if err != nil { + return err + } + sts.listTypeBuilder = listTypeBuilder + + mapTypeBuilder, err := RegisterGenericType(ts, "map", 1, + func(childTypes []VariableType) VariableType { + return VariableType{ + localName: "map", + celType: cel.MapType(cel.StringType, childTypes[0].celType), + childTypes: childTypes, + converter: func(value any) (any, error) { + vle, ok := value.(map[string]any) + if !ok { + return nil, fmt.Errorf("map requires a map, found: %T", value) + } + + converted := make(map[string]any, len(vle)) + for key, item := range vle { + convertedItem, err := childTypes[0].ConvertValue(item) + if err != nil { + return nil, fmt.Errorf("found an invalid value for key `%s`: %w", key, err) + } + + converted[key] = convertedItem + } + + return converted, nil + }, + } + }, + ) + if err != nil { + return err + } + sts.mapTypeBuilder = mapTypeBuilder + + if err := RegisterMethodOnDefinedType(ts, cel.MapType(cel.StringType, cel.DynType), + "isSubtreeOf", + []*cel.Type{cel.MapType(cel.StringType, cel.DynType)}, + cel.BoolType, + func(arg ...ref.Val) ref.Val { + map0 := arg[0].Value().(map[string]any) + map1 := arg[1].Value().(map[string]any) + return types.Bool(subtree(map0, map1)) + }, + ); err != nil { + return err + } + + ipAddressType, err := RegisterIPAddressType(ts) + if err != nil { + return err + } + sts.IPAddressType = ipAddressType + + return nil +} + +func requireType[T any](value any) (any, error) { + vle, ok := value.(T) + if !ok { + return nil, fmt.Errorf("a %T value is required, but found %T `%v`", *new(T), value, value) + } + return vle, nil +} + +func convertNumericType[T int64 | uint64 | float64](value any) (any, error) { + directValue, ok := value.(T) + if ok { + return directValue, nil + } + + floatValue, ok := value.(float64) + bigFloat := big.NewFloat(floatValue) + if !ok { + stringValue, ok := value.(string) + if !ok { + return nil, fmt.Errorf("a %T value is required, but found %T `%v`", *new(T), value, value) + } + + f, _, err := big.ParseFloat(stringValue, 10, 64, 0) + if err != nil { + return nil, fmt.Errorf("a %T value is required, but found invalid string value `%v`", *new(T), value) + } + + bigFloat = f + } + + // Convert the float to the int or uint if necessary. + n := *new(T) + switch any(n).(type) { + case int64: + if !bigFloat.IsInt() { + return nil, fmt.Errorf("a int value is required, but found numeric value `%s`", bigFloat.String()) + } + + numericValue, _ := bigFloat.Int64() + return numericValue, nil + + case uint64: + if !bigFloat.IsInt() { + return nil, fmt.Errorf("a uint value is required, but found numeric value `%s`", bigFloat.String()) + } + + numericValue, _ := bigFloat.Int64() + if numericValue < 0 { + return nil, fmt.Errorf("a uint value is required, but found int64 value `%s`", bigFloat.String()) + } + return uint64(numericValue), nil + + case float64: + numericValue, _ := bigFloat.Float64() + return numericValue, nil + + default: + return nil, spiceerrors.MustBugf("unsupported numeric type in caveat number type conversion: %T", n) + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/custom.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/custom.go new file mode 100644 index 0000000..08c18f1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/custom.go @@ -0,0 +1,8 @@ +package types + +// CustomType is the interface for custom-defined types. +type CustomType interface { + // SerializedString returns the serialized string form of the data within + // this instance of the type. + SerializedString() string +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/encoding.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/encoding.go new file mode 100644 index 0000000..5d088c1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/encoding.go @@ -0,0 +1,72 @@ +package types + +import ( + "fmt" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// EncodeParameterTypes converts the map of internal caveat types into a map of types for storing +// the caveat in the core. +func EncodeParameterTypes(parametersAndTypes map[string]VariableType) map[string]*core.CaveatTypeReference { + encoded := make(map[string]*core.CaveatTypeReference, len(parametersAndTypes)) + for name, varType := range parametersAndTypes { + encoded[name] = EncodeParameterType(varType) + } + return encoded +} + +// EncodeParameterType converts an internal caveat type into a storable core type. +func EncodeParameterType(varType VariableType) *core.CaveatTypeReference { + childTypes := make([]*core.CaveatTypeReference, 0, len(varType.childTypes)) + for _, childType := range varType.childTypes { + childTypes = append(childTypes, EncodeParameterType(childType)) + } + + return &core.CaveatTypeReference{ + TypeName: varType.localName, + ChildTypes: childTypes, + } +} + +// DecodeParameterType decodes the core caveat parameter type into an internal caveat type. +func DecodeParameterType(ts *TypeSet, parameterType *core.CaveatTypeReference) (*VariableType, error) { + typeDef, ok := ts.definitions[parameterType.TypeName] + if !ok { + return nil, fmt.Errorf("unknown caveat parameter type `%s`", parameterType.TypeName) + } + + if len(parameterType.ChildTypes) != int(typeDef.childTypeCount) { + return nil, fmt.Errorf( + "caveat parameter type `%s` requires %d child types; found %d", + parameterType.TypeName, + len(parameterType.ChildTypes), + typeDef.childTypeCount, + ) + } + + childTypes := make([]VariableType, 0, typeDef.childTypeCount) + for _, encodedChildType := range parameterType.ChildTypes { + childType, err := DecodeParameterType(ts, encodedChildType) + if err != nil { + return nil, err + } + childTypes = append(childTypes, *childType) + } + + return typeDef.asVariableType(childTypes) +} + +// DecodeParameterTypes decodes the core caveat parameter types into internal caveat types. +func DecodeParameterTypes(ts *TypeSet, parameters map[string]*core.CaveatTypeReference) (map[string]VariableType, error) { + parameterTypes := make(map[string]VariableType, len(parameters)) + for paramName, paramType := range parameters { + decodedType, err := DecodeParameterType(ts, paramType) + if err != nil { + return nil, err + } + + parameterTypes[paramName] = *decodedType + } + return parameterTypes, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/ipaddress.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/ipaddress.go new file mode 100644 index 0000000..2a92697 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/ipaddress.go @@ -0,0 +1,114 @@ +package types + +import ( + "fmt" + "net/netip" + "reflect" + + "github.com/authzed/cel-go/cel" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" +) + +// ParseIPAddress parses the string form of an IP Address into an IPAddress object type. +func ParseIPAddress(ip string) (IPAddress, error) { + parsed, err := netip.ParseAddr(ip) + return IPAddress{parsed}, err +} + +// MustParseIPAddress parses the string form of an IP Address into an IPAddress object type. +func MustParseIPAddress(ip string) IPAddress { + ipAddress, err := ParseIPAddress(ip) + if err != nil { + panic(err) + } + return ipAddress +} + +var ipaddressCelType = cel.OpaqueType("IPAddress") + +// IPAddress defines a custom type for representing an IP Address in caveats. +type IPAddress struct { + ip netip.Addr +} + +func (ipa IPAddress) SerializedString() string { + return ipa.ip.String() +} + +func (ipa IPAddress) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + switch typeDesc { + case reflect.TypeOf(""): + return ipa.ip.String(), nil + } + return nil, fmt.Errorf("type conversion error from 'IPAddress' to '%v'", typeDesc) +} + +func (ipa IPAddress) ConvertToType(typeVal ref.Type) ref.Val { + switch typeVal { + case types.StringType: + return types.String(ipa.ip.String()) + case types.TypeType: + return ipaddressCelType + } + return types.NewErr("type conversion error from '%s' to '%s'", ipaddressCelType, typeVal) +} + +func (ipa IPAddress) Equal(other ref.Val) ref.Val { + o2, ok := other.(IPAddress) + if !ok { + return types.ValOrErr(other, "no such overload") + } + return types.Bool(ipa == o2) +} + +func (ipa IPAddress) Type() ref.Type { + return ipaddressCelType +} + +func (ipa IPAddress) Value() interface{} { + return ipa +} + +func RegisterIPAddressType(ts *TypeSet) (VariableType, error) { + return RegisterCustomType[IPAddress](ts, + "ipaddress", + cel.ObjectType("IPAddress"), + func(value any) (any, error) { + ipvalue, ok := value.(IPAddress) + if ok { + return ipvalue, nil + } + + vle, ok := value.(string) + if !ok { + return nil, fmt.Errorf("ipaddress requires an ipaddress string, found: %T `%v`", value, value) + } + + d, err := ParseIPAddress(vle) + if err != nil { + return nil, fmt.Errorf("could not parse ip address string `%s`: %w", vle, err) + } + + return d, nil + }, + cel.Function("in_cidr", + cel.MemberOverload("ipaddress_in_cidr_string", + []*cel.Type{cel.ObjectType("IPAddress"), cel.StringType}, + cel.BoolType, + cel.BinaryBinding(func(lhs, rhs ref.Val) ref.Val { + cidr, ok := rhs.Value().(string) + if !ok { + return types.NewErr("expected CIDR string") + } + + network, err := netip.ParsePrefix(cidr) + if err != nil { + return types.NewErr("invalid CIDR string: `%s`", cidr) + } + + return types.Bool(network.Contains(lhs.(IPAddress).ip)) + }), + ), + )) +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/map.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/map.go new file mode 100644 index 0000000..3d949cd --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/map.go @@ -0,0 +1,26 @@ +package types + +func subtree(map0 map[string]any, map1 map[string]any) bool { + for k, v := range map0 { + val, ok := map1[k] + if !ok { + return false + } + nestedMap0, ok := v.(map[string]any) + if ok { + nestedMap1, ok := val.(map[string]any) + if !ok { + return false + } + nestedResult := subtree(nestedMap0, nestedMap1) + if !nestedResult { + return false + } + } else { + if v != val { + return false + } + } + } + return true +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/registration.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/registration.go new file mode 100644 index 0000000..96663b1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/registration.go @@ -0,0 +1,145 @@ +package types + +import ( + "fmt" + + "github.com/authzed/cel-go/cel" + "github.com/authzed/cel-go/common/types/ref" + + "github.com/authzed/spicedb/pkg/genutil" +) + +type ( + typedValueConverter func(value any) (any, error) +) + +type typeDefinition struct { + // localName is the localized name/keyword for the type. + localName string + + // childTypeCount is the number of generics on the type, if any. + childTypeCount uint8 + + // asVariableType converts the type definition into a VariableType. + asVariableType func(childTypes []VariableType) (*VariableType, error) +} + +// RegisterBasicType registers a basic type with the given keyword, CEL type, and converter. +func RegisterBasicType(ts *TypeSet, keyword string, celType *cel.Type, converter typedValueConverter) (VariableType, error) { + if ts.isFrozen { + return VariableType{}, fmt.Errorf("cannot register new types after the TypeSet is frozen") + } + + varType := VariableType{ + localName: keyword, + celType: celType, + childTypes: nil, + converter: converter, + } + + ts.definitions[keyword] = typeDefinition{ + localName: keyword, + childTypeCount: 0, + asVariableType: func(childTypes []VariableType) (*VariableType, error) { + return &varType, nil + }, + } + return varType, nil +} + +func MustRegisterBasicType(ts *TypeSet, keyword string, celType *cel.Type, converter typedValueConverter) VariableType { + varType, err := RegisterBasicType(ts, keyword, celType, converter) + if err != nil { + panic(fmt.Sprintf("failed to register basic type %s: %v", keyword, err)) + } + return varType +} + +type GenericTypeBuilder func(childTypes ...VariableType) (VariableType, error) + +// RegisterGenericType registers a type with at least one generic. +func RegisterGenericType( + ts *TypeSet, + keyword string, + childTypeCount uint8, + asVariableType func(childTypes []VariableType) VariableType, +) (GenericTypeBuilder, error) { + if ts.isFrozen { + return nil, fmt.Errorf("cannot register new types after the TypeSet is frozen") + } + + ts.definitions[keyword] = typeDefinition{ + localName: keyword, + childTypeCount: childTypeCount, + asVariableType: func(childTypes []VariableType) (*VariableType, error) { + childTypeLength, err := genutil.EnsureUInt8(len(childTypes)) + if err != nil { + return nil, err + } + + if childTypeLength != childTypeCount { + return nil, fmt.Errorf("type `%s` requires %d generic types; found %d", keyword, childTypeCount, len(childTypes)) + } + + built := asVariableType(childTypes) + return &built, nil + }, + } + return func(childTypes ...VariableType) (VariableType, error) { + childTypeLength, err := genutil.EnsureUInt8(len(childTypes)) + if err != nil { + return VariableType{}, err + } + + if childTypeLength != childTypeCount { + return VariableType{}, fmt.Errorf("invalid number of parameters given to type constructor. expected: %d, found: %d", childTypeCount, len(childTypes)) + } + + return asVariableType(childTypes), nil + }, nil +} + +func MustRegisterGenericType( + ts *TypeSet, + keyword string, + childTypeCount uint8, + asVariableType func(childTypes []VariableType) VariableType, +) GenericTypeBuilder { + genericTypeBuilder, err := RegisterGenericType(ts, keyword, childTypeCount, asVariableType) + if err != nil { + panic(fmt.Sprintf("failed to register generic type %s: %v", keyword, err)) + } + return genericTypeBuilder +} + +// RegisterCustomType registers a custom type that wraps a base CEL type. +func RegisterCustomType[T CustomType](ts *TypeSet, keyword string, baseCelType *cel.Type, converter typedValueConverter, opts ...cel.EnvOption) (VariableType, error) { + if ts.isFrozen { + return VariableType{}, fmt.Errorf("cannot register new types after the TypeSet is frozen") + } + + if err := RegisterCustomCELOptions(ts, opts...); err != nil { + return VariableType{}, err + } + + return RegisterBasicType(ts, keyword, baseCelType, converter) +} + +// RegisterCustomTypeWithName registers a custom type with a specific name. +func RegisterMethodOnDefinedType(ts *TypeSet, baseType *cel.Type, name string, args []*cel.Type, returnType *cel.Type, binding func(arg ...ref.Val) ref.Val) error { + finalArgs := make([]*cel.Type, 0, len(args)+1) + finalArgs = append(finalArgs, baseType) + finalArgs = append(finalArgs, args...) + method := cel.Function(name, cel.MemberOverload(name, finalArgs, returnType, cel.FunctionBinding(binding))) + + return RegisterCustomCELOptions(ts, method) +} + +// RegisterCustomOptions registers custom CEL environment options for the TypeSet. +func RegisterCustomCELOptions(ts *TypeSet, opts ...cel.EnvOption) error { + if ts.isFrozen { + return fmt.Errorf("cannot register new options after the TypeSet is frozen") + } + ts.customOptions = append(ts.customOptions, opts...) + return nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/standard.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/standard.go new file mode 100644 index 0000000..b5adf6a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/standard.go @@ -0,0 +1,84 @@ +package types + +// Default is the default set of types to be used in caveats, coming with all +// the standard types pre-registered. This set is frozen and cannot be modified. +var Default *StandardTypeSet + +func init() { + Default = MustNewStandardTypeSet() + Default.Freeze() +} + +// TypeSetOrDefault returns the provided TypeSet if it is not nil, otherwise it +// returns the default TypeSet. This is useful for functions that accept a +// TypeSet parameter but want to use the default if none is provided. +func TypeSetOrDefault(ts *TypeSet) *TypeSet { + if ts == nil { + return Default.TypeSet + } + return ts +} + +func MustNewStandardTypeSet() *StandardTypeSet { + sts, err := NewStandardTypeSet() + if err != nil { + panic(err) + } + return sts +} + +// NewStandardTypeSet creates a new TypeSet with all the standard types pre-registered. +func NewStandardTypeSet() (*StandardTypeSet, error) { + sts := &StandardTypeSet{ + TypeSet: NewTypeSet(), + } + + if err := RegisterBasicTypes(sts); err != nil { + return nil, err + } + return sts, nil +} + +// StandardTypeSet is a TypeSet that contains all the standard types and provides nice accessors +// for each. +type StandardTypeSet struct { + *TypeSet + + AnyType VariableType + BooleanType VariableType + StringType VariableType + IntType VariableType + UIntType VariableType + DoubleType VariableType + BytesType VariableType + DurationType VariableType + TimestampType VariableType + IPAddressType VariableType + + listTypeBuilder GenericTypeBuilder + mapTypeBuilder GenericTypeBuilder +} + +func (sts *StandardTypeSet) ListType(childTypes ...VariableType) (VariableType, error) { + return sts.listTypeBuilder(childTypes...) +} + +func (sts *StandardTypeSet) MapType(childTypes ...VariableType) (VariableType, error) { + return sts.mapTypeBuilder(childTypes...) +} + +func (sts *StandardTypeSet) MustListType(childTypes ...VariableType) VariableType { + v, err := sts.ListType(childTypes...) + if err != nil { + panic(err) + } + return v +} + +func (sts *StandardTypeSet) MustMapType(childTypes ...VariableType) VariableType { + v, err := sts.MapType(childTypes...) + if err != nil { + panic(err) + } + return v +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/types.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/types.go new file mode 100644 index 0000000..14fa081 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/types.go @@ -0,0 +1,44 @@ +package types + +import ( + "fmt" + "strings" + + "github.com/authzed/cel-go/cel" +) + +// VariableType defines the supported types of variables in caveats. +type VariableType struct { + localName string + celType *cel.Type + childTypes []VariableType + converter typedValueConverter +} + +// CelType returns the underlying CEL type for the variable type. +func (vt VariableType) CelType() *cel.Type { + return vt.celType +} + +func (vt VariableType) String() string { + if len(vt.childTypes) > 0 { + childTypeStrings := make([]string, 0, len(vt.childTypes)) + for _, childType := range vt.childTypes { + childTypeStrings = append(childTypeStrings, childType.String()) + } + + return vt.localName + "<" + strings.Join(childTypeStrings, ", ") + ">" + } + + return vt.localName +} + +// ConvertValue converts the given value into one expected by this variable type. +func (vt VariableType) ConvertValue(value any) (any, error) { + converted, err := vt.converter(value) + if err != nil { + return nil, fmt.Errorf("for %s: %w", vt.String(), err) + } + + return converted, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/caveats/types/typeset.go b/vendor/github.com/authzed/spicedb/pkg/caveats/types/typeset.go new file mode 100644 index 0000000..99b8bfe --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/caveats/types/typeset.go @@ -0,0 +1,59 @@ +package types + +import ( + "fmt" + + "github.com/authzed/cel-go/cel" +) + +// TypeSet defines a set of types that can be used in caveats. It is used to register custom types +// and methods that can be used in caveats. The types are registered by calling the RegisterType +// function. The types are then used to build the CEL environment for the caveat. +type TypeSet struct { + // definitions holds the set of all types defined and exported by this package, by name. + definitions map[string]typeDefinition + + // customOptions holds a set of custom options that can be used to create a CEL environment + // for the caveat. + customOptions []cel.EnvOption + + // isFrozen indicates whether the TypeSet is frozen. A frozen TypeSet cannot be modified. + isFrozen bool +} + +// Freeze marks the TypeSet as frozen. A frozen TypeSet cannot be modified. +func (ts *TypeSet) Freeze() { + ts.isFrozen = true +} + +// EnvOptions returns the set of environment options that can be used to create a CEL environment +// for the caveat. This includes the custom types and methods defined in the TypeSet. +func (ts *TypeSet) EnvOptions() ([]cel.EnvOption, error) { + if !ts.isFrozen { + return nil, fmt.Errorf("cannot get env options from a non-frozen TypeSet") + } + return ts.customOptions, nil +} + +// BuildType builds a variable type from its name and child types. +func (ts *TypeSet) BuildType(name string, childTypes []VariableType) (*VariableType, error) { + if !ts.isFrozen { + return nil, fmt.Errorf("cannot build types from a non-frozen TypeSet") + } + + typeDef, ok := ts.definitions[name] + if !ok { + return nil, fmt.Errorf("unknown type `%s`", name) + } + + return typeDef.asVariableType(childTypes) +} + +// NewTypeSet creates a new TypeSet. The TypeSet is not frozen and can be modified. +func NewTypeSet() *TypeSet { + return &TypeSet{ + definitions: map[string]typeDefinition{}, + customOptions: []cel.EnvOption{}, + isFrozen: false, + } +} |
