diff options
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal')
123 files changed, 23894 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/builder.go b/vendor/github.com/authzed/spicedb/internal/caveats/builder.go new file mode 100644 index 0000000..0c93d39 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/caveats/builder.go @@ -0,0 +1,152 @@ +package caveats + +import ( + "google.golang.org/protobuf/types/known/structpb" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// CaveatAsExpr wraps a contextualized caveat into a caveat expression. +func CaveatAsExpr(caveat *core.ContextualizedCaveat) *core.CaveatExpression { + if caveat == nil { + return nil + } + + return &core.CaveatExpression{ + OperationOrCaveat: &core.CaveatExpression_Caveat{ + Caveat: caveat, + }, + } +} + +// CaveatForTesting returns a new ContextualizedCaveat for testing, with empty context. +func CaveatForTesting(name string) *core.ContextualizedCaveat { + return &core.ContextualizedCaveat{ + CaveatName: name, + } +} + +// CaveatExprForTesting returns a CaveatExpression referencing a caveat with the given name and +// empty context. +func CaveatExprForTesting(name string) *core.CaveatExpression { + return &core.CaveatExpression{ + OperationOrCaveat: &core.CaveatExpression_Caveat{ + Caveat: CaveatForTesting(name), + }, + } +} + +// MustCaveatExprForTestingWithContext returns a CaveatExpression referencing a caveat with the given name and +// given context. +func MustCaveatExprForTestingWithContext(name string, context map[string]any) *core.CaveatExpression { + contextStruct, err := structpb.NewStruct(context) + if err != nil { + panic(err) + } + + return &core.CaveatExpression{ + OperationOrCaveat: &core.CaveatExpression_Caveat{ + Caveat: &core.ContextualizedCaveat{ + CaveatName: name, + Context: contextStruct, + }, + }, + } +} + +// ShortcircuitedOr combines two caveat expressions via an `||`. If one of the expressions is nil, +// then the entire expression is *short-circuited*, and a nil is returned. +func ShortcircuitedOr(first *core.CaveatExpression, second *core.CaveatExpression) *core.CaveatExpression { + if first == nil || second == nil { + return nil + } + + return Or(first, second) +} + +// Or `||`'s together two caveat expressions. If one expression is nil, the other is returned. +func Or(first *core.CaveatExpression, second *core.CaveatExpression) *core.CaveatExpression { + if first == nil { + return second + } + + if second == nil { + return first + } + + if first.EqualVT(second) { + return first + } + + return &core.CaveatExpression{ + OperationOrCaveat: &core.CaveatExpression_Operation{ + Operation: &core.CaveatOperation{ + Op: core.CaveatOperation_OR, + Children: []*core.CaveatExpression{first, second}, + }, + }, + } +} + +// And `&&`'s together two caveat expressions. If one expression is nil, the other is returned. +func And(first *core.CaveatExpression, second *core.CaveatExpression) *core.CaveatExpression { + if first == nil { + return second + } + + if second == nil { + return first + } + + if first.EqualVT(second) { + return first + } + + return &core.CaveatExpression{ + OperationOrCaveat: &core.CaveatExpression_Operation{ + Operation: &core.CaveatOperation{ + Op: core.CaveatOperation_AND, + Children: []*core.CaveatExpression{first, second}, + }, + }, + } +} + +// Invert returns the caveat expression with a `!` placed in front of it. If the expression is +// nil, returns nil. +func Invert(ce *core.CaveatExpression) *core.CaveatExpression { + if ce == nil { + return nil + } + + return &core.CaveatExpression{ + OperationOrCaveat: &core.CaveatExpression_Operation{ + Operation: &core.CaveatOperation{ + Op: core.CaveatOperation_NOT, + Children: []*core.CaveatExpression{ce}, + }, + }, + } +} + +// Subtract returns a caveat expression representing the subtracted expression subtracted from the given +// expression. +func Subtract(caveat *core.CaveatExpression, subtracted *core.CaveatExpression) *core.CaveatExpression { + inversion := Invert(subtracted) + if caveat == nil { + return inversion + } + + if subtracted == nil { + return caveat + } + + return &core.CaveatExpression{ + OperationOrCaveat: &core.CaveatExpression_Operation{ + Operation: &core.CaveatOperation{ + Op: core.CaveatOperation_AND, + Children: []*core.CaveatExpression{caveat, inversion}, + }, + }, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/debug.go b/vendor/github.com/authzed/spicedb/internal/caveats/debug.go new file mode 100644 index 0000000..bfcf62b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/caveats/debug.go @@ -0,0 +1,164 @@ +package caveats + +import ( + "fmt" + "maps" + "strconv" + "strings" + + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/spicedb/pkg/caveats" + "github.com/authzed/spicedb/pkg/genutil/mapz" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// BuildDebugInformation returns a human-readable string representation of the given +// ExpressionResult and a Struct representation of the context values used in the expression. +func BuildDebugInformation(exprResult ExpressionResult) (string, *structpb.Struct, error) { + // If a concrete result, return its information directly. + if concrete, ok := exprResult.(*caveats.CaveatResult); ok { + exprString, err := concrete.ParentCaveat().ExprString() + if err != nil { + return "", nil, err + } + + contextStruct, err := concrete.ContextStruct() + if err != nil { + return "", nil, err + } + + return exprString, contextStruct, nil + } + + // Collect parameters which are shared across expressions. + syntheticResult, ok := exprResult.(syntheticResult) + if !ok { + return "", nil, spiceerrors.MustBugf("unknown ExpressionResult type: %T", exprResult) + } + + resultsByParam := mapz.NewMultiMap[string, *caveats.CaveatResult]() + if err := collectParameterUsage(syntheticResult, resultsByParam); err != nil { + return "", nil, err + } + + // Build the synthetic debug information. + exprString, contextMap, err := buildDebugInformation(syntheticResult, resultsByParam) + if err != nil { + return "", nil, err + } + + // Convert the context map to a struct. + contextStruct, err := caveats.ConvertContextToStruct(contextMap) + if err != nil { + return "", nil, err + } + + return exprString, contextStruct, nil +} + +func buildDebugInformation(sr syntheticResult, resultsByParam *mapz.MultiMap[string, *caveats.CaveatResult]) (string, map[string]any, error) { + childExprStrings := make([]string, 0, len(sr.exprResultsForDebug)) + combinedContext := map[string]any{} + + for _, child := range sr.exprResultsForDebug { + if _, ok := child.(*caveats.CaveatResult); ok { + childExprString, contextMap, err := buildDebugInformationForConcrete(child.(*caveats.CaveatResult), resultsByParam) + if err != nil { + return "", nil, err + } + + childExprStrings = append(childExprStrings, "("+childExprString+")") + maps.Copy(combinedContext, contextMap) + continue + } + + childExprString, contextMap, err := buildDebugInformation(child.(syntheticResult), resultsByParam) + if err != nil { + return "", nil, err + } + + childExprStrings = append(childExprStrings, "("+childExprString+")") + maps.Copy(combinedContext, contextMap) + } + + var combinedExprString string + switch sr.op { + case corev1.CaveatOperation_AND: + combinedExprString = strings.Join(childExprStrings, " && ") + + case corev1.CaveatOperation_OR: + combinedExprString = strings.Join(childExprStrings, " || ") + + case corev1.CaveatOperation_NOT: + if len(childExprStrings) != 1 { + return "", nil, spiceerrors.MustBugf("NOT operator must have exactly one child") + } + + combinedExprString = "!" + childExprStrings[0] + + default: + return "", nil, fmt.Errorf("unknown operator: %v", sr.op) + } + + return combinedExprString, combinedContext, nil +} + +func buildDebugInformationForConcrete(cr *caveats.CaveatResult, resultsByParam *mapz.MultiMap[string, *caveats.CaveatResult]) (string, map[string]any, error) { + // For each paramter used in the context of the caveat, check if it is shared across multiple + // caveats. If so, rewrite the parameter to a unique name. + existingContextMap := cr.ContextValues() + contextMap := make(map[string]any, len(existingContextMap)) + + caveat := *cr.ParentCaveat() + + for paramName, paramValue := range existingContextMap { + index := mapz.IndexOfValueInMultimap(resultsByParam, paramName, cr) + if resultsByParam.CountOf(paramName) > 1 { + newName := paramName + "__" + strconv.Itoa(index) + if resultsByParam.Has(newName) { + return "", nil, fmt.Errorf("failed to generate unique name for parameter: %s", newName) + } + + rewritten, err := caveat.RewriteVariable(paramName, newName) + if err != nil { + return "", nil, err + } + + caveat = rewritten + contextMap[newName] = paramValue + continue + } + + contextMap[paramName] = paramValue + } + + exprString, err := caveat.ExprString() + if err != nil { + return "", nil, err + } + + return exprString, contextMap, nil +} + +func collectParameterUsage(sr syntheticResult, resultsByParam *mapz.MultiMap[string, *caveats.CaveatResult]) error { + for _, exprResult := range sr.exprResultsForDebug { + if concrete, ok := exprResult.(*caveats.CaveatResult); ok { + for paramName := range concrete.ContextValues() { + resultsByParam.Add(paramName, concrete) + } + } else { + cast, ok := exprResult.(syntheticResult) + if !ok { + return spiceerrors.MustBugf("unknown ExpressionResult type: %T", exprResult) + } + + if err := collectParameterUsage(cast, resultsByParam); err != nil { + return err + } + } + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/doc.go b/vendor/github.com/authzed/spicedb/internal/caveats/doc.go new file mode 100644 index 0000000..587d7d8 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/caveats/doc.go @@ -0,0 +1,2 @@ +// Package caveats contains code to evaluate a caveat with a given context. +package caveats diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/errors.go b/vendor/github.com/authzed/spicedb/internal/caveats/errors.go new file mode 100644 index 0000000..06284f2 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/caveats/errors.go @@ -0,0 +1,107 @@ +package caveats + +import ( + "errors" + "fmt" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/caveats" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// EvaluationError is an error in evaluation of a caveat expression. +type EvaluationError struct { + error + caveatExpr *core.CaveatExpression + evalErr caveats.EvaluationError +} + +// MarshalZerologObject implements zerolog.LogObjectMarshaler +func (err EvaluationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("caveat_name", err.caveatExpr.GetCaveat().CaveatName).Interface("context", err.caveatExpr.GetCaveat().Context) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err EvaluationError) DetailsMetadata() map[string]string { + return spiceerrors.CombineMetadata(err.evalErr, map[string]string{ + "caveat_name": err.caveatExpr.GetCaveat().CaveatName, + }) +} + +func (err EvaluationError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_CAVEAT_EVALUATION_ERROR, + err.DetailsMetadata(), + ), + ) +} + +func NewEvaluationError(caveatExpr *core.CaveatExpression, err caveats.EvaluationError) EvaluationError { + return EvaluationError{ + fmt.Errorf("evaluation error for caveat %s: %w", caveatExpr.GetCaveat().CaveatName, err), caveatExpr, err, + } +} + +// ParameterTypeError is a type error in constructing a parameter from a value. +type ParameterTypeError struct { + error + caveatExpr *core.CaveatExpression + conversionError *caveats.ParameterConversionError +} + +// MarshalZerologObject implements zerolog.LogObjectMarshaler +func (err ParameterTypeError) MarshalZerologObject(e *zerolog.Event) { + evt := e.Err(err.error). + Str("caveat_name", err.caveatExpr.GetCaveat().CaveatName). + Interface("context", err.caveatExpr.GetCaveat().Context) + + if err.conversionError != nil { + evt.Str("parameter_name", err.conversionError.ParameterName()) + } +} + +// DetailsMetadata returns the metadata for details for this error. +func (err ParameterTypeError) DetailsMetadata() map[string]string { + if err.conversionError != nil { + return spiceerrors.CombineMetadata(err.conversionError, map[string]string{ + "caveat_name": err.caveatExpr.GetCaveat().CaveatName, + }) + } + + return map[string]string{ + "caveat_name": err.caveatExpr.GetCaveat().CaveatName, + } +} + +func (err ParameterTypeError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_CAVEAT_PARAMETER_TYPE_ERROR, + err.DetailsMetadata(), + ), + ) +} + +func NewParameterTypeError(caveatExpr *core.CaveatExpression, err error) ParameterTypeError { + conversionError := &caveats.ParameterConversionError{} + if !errors.As(err, conversionError) { + conversionError = nil + } + + return ParameterTypeError{ + fmt.Errorf("type error for parameters for caveat `%s`: %w", caveatExpr.GetCaveat().CaveatName, err), + caveatExpr, + conversionError, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/run.go b/vendor/github.com/authzed/spicedb/internal/caveats/run.go new file mode 100644 index 0000000..1aed483 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/caveats/run.go @@ -0,0 +1,427 @@ +package caveats + +import ( + "context" + "errors" + "fmt" + "maps" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +var tracer = otel.Tracer("spicedb/internal/caveats/run") + +// RunCaveatExpressionDebugOption are the options for running caveat expression evaluation +// with debugging enabled or disabled. +type RunCaveatExpressionDebugOption int + +const ( + // RunCaveatExpressionNoDebugging runs the evaluation without debugging enabled. + RunCaveatExpressionNoDebugging RunCaveatExpressionDebugOption = 0 + + // RunCaveatExpressionWithDebugInformation runs the evaluation with debugging enabled. + RunCaveatExpressionWithDebugInformation RunCaveatExpressionDebugOption = 1 +) + +// RunSingleCaveatExpression runs a caveat expression over the given context and returns the result. +// This instantiates its own CaveatRunner, and should therefore only be used in one-off situations. +func RunSingleCaveatExpression( + ctx context.Context, + ts *caveattypes.TypeSet, + expr *core.CaveatExpression, + context map[string]any, + reader datastore.CaveatReader, + debugOption RunCaveatExpressionDebugOption, +) (ExpressionResult, error) { + runner := NewCaveatRunner(ts) + return runner.RunCaveatExpression(ctx, expr, context, reader, debugOption) +} + +// CaveatRunner is a helper for running caveats, providing a cache for deserialized caveats. +type CaveatRunner struct { + caveatTypeSet *caveattypes.TypeSet + caveatDefs map[string]*core.CaveatDefinition + deserializedCaveats map[string]*caveats.CompiledCaveat +} + +// NewCaveatRunner creates a new CaveatRunner. +func NewCaveatRunner(ts *caveattypes.TypeSet) *CaveatRunner { + return &CaveatRunner{ + caveatTypeSet: ts, + caveatDefs: map[string]*core.CaveatDefinition{}, + deserializedCaveats: map[string]*caveats.CompiledCaveat{}, + } +} + +// RunCaveatExpression runs a caveat expression over the given context and returns the result. +func (cr *CaveatRunner) RunCaveatExpression( + ctx context.Context, + expr *core.CaveatExpression, + context map[string]any, + reader datastore.CaveatReader, + debugOption RunCaveatExpressionDebugOption, +) (ExpressionResult, error) { + ctx, span := tracer.Start(ctx, "RunCaveatExpression") + defer span.End() + + if err := cr.PopulateCaveatDefinitionsForExpr(ctx, expr, reader); err != nil { + return nil, err + } + + env := caveats.NewEnvironment() + return cr.runExpressionWithCaveats(ctx, env, expr, context, debugOption) +} + +// PopulateCaveatDefinitionsForExpr populates the CaveatRunner's cache with the definitions +// referenced in the given caveat expression. +func (cr *CaveatRunner) PopulateCaveatDefinitionsForExpr(ctx context.Context, expr *core.CaveatExpression, reader datastore.CaveatReader) error { + ctx, span := tracer.Start(ctx, "PopulateCaveatDefinitions") + defer span.End() + + // Collect all referenced caveat definitions in the expression. + caveatNames := mapz.NewSet[string]() + collectCaveatNames(expr, caveatNames) + + span.AddEvent("collected caveat names") + span.SetAttributes(attribute.StringSlice("caveat-names", caveatNames.AsSlice())) + + if caveatNames.IsEmpty() { + return fmt.Errorf("received empty caveat expression") + } + + // Remove any caveats already loaded. + for name := range cr.caveatDefs { + caveatNames.Delete(name) + } + + if caveatNames.IsEmpty() { + return nil + } + + // Bulk lookup all of the referenced caveat definitions. + caveatDefs, err := reader.LookupCaveatsWithNames(ctx, caveatNames.AsSlice()) + if err != nil { + return err + } + span.AddEvent("looked up caveats") + + for _, cd := range caveatDefs { + cr.caveatDefs[cd.Definition.GetName()] = cd.Definition + } + + return nil +} + +// get retrieves a caveat definition and its deserialized form. The caveat name must be +// present in the CaveatRunner's cache. +func (cr *CaveatRunner) get(caveatDefName string) (*core.CaveatDefinition, *caveats.CompiledCaveat, error) { + caveat, ok := cr.caveatDefs[caveatDefName] + if !ok { + return nil, nil, datastore.NewCaveatNameNotFoundErr(caveatDefName) + } + + deserialized, ok := cr.deserializedCaveats[caveatDefName] + if ok { + return caveat, deserialized, nil + } + + parameterTypes, err := caveattypes.DecodeParameterTypes(cr.caveatTypeSet, caveat.ParameterTypes) + if err != nil { + return nil, nil, err + } + + justDeserialized, err := caveats.DeserializeCaveatWithTypeSet(cr.caveatTypeSet, caveat.SerializedExpression, parameterTypes) + if err != nil { + return caveat, nil, err + } + + cr.deserializedCaveats[caveatDefName] = justDeserialized + return caveat, justDeserialized, nil +} + +func collectCaveatNames(expr *core.CaveatExpression, caveatNames *mapz.Set[string]) { + if expr.GetCaveat() != nil { + caveatNames.Add(expr.GetCaveat().CaveatName) + return + } + + cop := expr.GetOperation() + for _, child := range cop.Children { + collectCaveatNames(child, caveatNames) + } +} + +func (cr *CaveatRunner) runExpressionWithCaveats( + ctx context.Context, + env *caveats.Environment, + expr *core.CaveatExpression, + context map[string]any, + debugOption RunCaveatExpressionDebugOption, +) (ExpressionResult, error) { + ctx, span := tracer.Start(ctx, "runExpressionWithCaveats") + defer span.End() + + if expr.GetCaveat() != nil { + span.SetAttributes(attribute.String("caveat-name", expr.GetCaveat().CaveatName)) + + caveat, compiled, err := cr.get(expr.GetCaveat().CaveatName) + if err != nil { + return nil, err + } + + // Create a combined context, with the written context taking precedence over that specified. + untypedFullContext := maps.Clone(context) + if untypedFullContext == nil { + untypedFullContext = map[string]any{} + } + + relationshipContext := expr.GetCaveat().GetContext().AsMap() + maps.Copy(untypedFullContext, relationshipContext) + + // Perform type checking and conversion on the context map. + typedParameters, err := caveats.ConvertContextToParameters( + cr.caveatTypeSet, + untypedFullContext, + caveat.ParameterTypes, + caveats.SkipUnknownParameters, + ) + if err != nil { + return nil, NewParameterTypeError(expr, err) + } + + result, err := caveats.EvaluateCaveat(compiled, typedParameters) + if err != nil { + var evalErr caveats.EvaluationError + if errors.As(err, &evalErr) { + return nil, NewEvaluationError(expr, evalErr) + } + + return nil, err + } + + return result, nil + } + + cop := expr.GetOperation() + span.SetAttributes(attribute.String("caveat-operation", cop.Op.String())) + + var currentResult ExpressionResult = syntheticResult{ + value: cop.Op == core.CaveatOperation_AND, + isPartialResult: false, + } + + var exprResultsForDebug []ExpressionResult + if debugOption == RunCaveatExpressionWithDebugInformation { + exprResultsForDebug = make([]ExpressionResult, 0, len(cop.Children)) + } + + var missingVarNames *mapz.Set[string] + if debugOption == RunCaveatExpressionNoDebugging { + missingVarNames = mapz.NewSet[string]() + } + + and := func(existing ExpressionResult, found ExpressionResult) (ExpressionResult, error) { + if !existing.IsPartial() && !existing.Value() { + return syntheticResult{ + value: false, + op: core.CaveatOperation_AND, + exprResultsForDebug: exprResultsForDebug, + isPartialResult: false, + missingVarNames: nil, + }, nil + } + + if !found.IsPartial() && !found.Value() { + return syntheticResult{ + value: false, + op: core.CaveatOperation_AND, + exprResultsForDebug: exprResultsForDebug, + isPartialResult: false, + missingVarNames: nil, + }, nil + } + + value := existing.Value() && found.Value() + if existing.IsPartial() || found.IsPartial() { + value = false + } + + return syntheticResult{ + value: value, + op: core.CaveatOperation_AND, + exprResultsForDebug: exprResultsForDebug, + isPartialResult: existing.IsPartial() || found.IsPartial(), + missingVarNames: missingVarNames, + }, nil + } + + or := func(existing ExpressionResult, found ExpressionResult) (ExpressionResult, error) { + if !existing.IsPartial() && existing.Value() { + return syntheticResult{ + value: true, + op: core.CaveatOperation_OR, + exprResultsForDebug: exprResultsForDebug, + isPartialResult: false, + missingVarNames: nil, + }, nil + } + + if !found.IsPartial() && found.Value() { + return syntheticResult{ + value: true, + op: core.CaveatOperation_OR, + exprResultsForDebug: exprResultsForDebug, + isPartialResult: false, + missingVarNames: nil, + }, nil + } + + value := existing.Value() || found.Value() + if existing.IsPartial() || found.IsPartial() { + value = false + } + + return syntheticResult{ + value: value, + op: core.CaveatOperation_OR, + exprResultsForDebug: exprResultsForDebug, + isPartialResult: existing.IsPartial() || found.IsPartial(), + missingVarNames: missingVarNames, + }, nil + } + + invert := func(existing ExpressionResult) (ExpressionResult, error) { + value := !existing.Value() + if existing.IsPartial() { + value = false + } + + return syntheticResult{ + value: value, + op: core.CaveatOperation_NOT, + exprResultsForDebug: exprResultsForDebug, + isPartialResult: existing.IsPartial(), + missingVarNames: missingVarNames, + }, nil + } + + for _, child := range cop.Children { + childResult, err := cr.runExpressionWithCaveats(ctx, env, child, context, debugOption) + if err != nil { + return nil, err + } + + if debugOption != RunCaveatExpressionNoDebugging { + exprResultsForDebug = append(exprResultsForDebug, childResult) + } else if childResult.IsPartial() { + missingVars, err := childResult.MissingVarNames() + if err != nil { + return nil, err + } + + missingVarNames.Extend(missingVars) + } + + switch cop.Op { + case core.CaveatOperation_AND: + cr, err := and(currentResult, childResult) + if err != nil { + return nil, err + } + + currentResult = cr + if debugOption == RunCaveatExpressionNoDebugging && isFalseResult(currentResult) { + return currentResult, nil + } + + case core.CaveatOperation_OR: + cr, err := or(currentResult, childResult) + if err != nil { + return nil, err + } + + currentResult = cr + if debugOption == RunCaveatExpressionNoDebugging && isTrueResult(currentResult) { + return currentResult, nil + } + + case core.CaveatOperation_NOT: + return invert(childResult) + + default: + return nil, spiceerrors.MustBugf("unknown caveat operation: %v", cop.Op) + } + } + + return currentResult, nil +} + +// ExpressionResult is the result of a caveat expression being run. +// See also caveats.CaveatResult +type ExpressionResult interface { + // Value is the resolved value for the expression. For partially applied expressions, this value will be false. + Value() bool + + // IsPartial returns whether the expression was only partially applied. + IsPartial() bool + + // MissingVarNames returns the names of the parameters missing from the context. + MissingVarNames() ([]string, error) +} + +type syntheticResult struct { + value bool + isPartialResult bool + + op core.CaveatOperation_Operation + exprResultsForDebug []ExpressionResult + missingVarNames *mapz.Set[string] +} + +func (sr syntheticResult) Value() bool { + return sr.value +} + +func (sr syntheticResult) IsPartial() bool { + return sr.isPartialResult +} + +func (sr syntheticResult) MissingVarNames() ([]string, error) { + if sr.isPartialResult { + if sr.missingVarNames != nil { + return sr.missingVarNames.AsSlice(), nil + } + + missingVarNames := mapz.NewSet[string]() + for _, exprResult := range sr.exprResultsForDebug { + if exprResult.IsPartial() { + found, err := exprResult.MissingVarNames() + if err != nil { + return nil, err + } + + missingVarNames.Extend(found) + } + } + + return missingVarNames.AsSlice(), nil + } + + return nil, fmt.Errorf("not a partial value") +} + +func isFalseResult(result ExpressionResult) bool { + return !result.Value() && !result.IsPartial() +} + +func isTrueResult(result ExpressionResult) bool { + return result.Value() && !result.IsPartial() +} diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/basesubjectset.go b/vendor/github.com/authzed/spicedb/internal/datasets/basesubjectset.go new file mode 100644 index 0000000..80ab666 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datasets/basesubjectset.go @@ -0,0 +1,856 @@ +package datasets + +import ( + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/internal/caveats" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +var ( + caveatAnd = caveats.And + caveatOr = caveats.Or + caveatInvert = caveats.Invert + shortcircuitedOr = caveats.ShortcircuitedOr +) + +// Subject is a subject that can be placed into a BaseSubjectSet. It is defined in a generic +// manner to allow implementations that wrap BaseSubjectSet to add their own additional bookkeeping +// to the base implementation. +type Subject[T any] interface { + // GetSubjectId returns the ID of the subject. For wildcards, this should be `*`. + GetSubjectId() string + + // GetCaveatExpression returns the caveat expression for this subject, if it is conditional. + GetCaveatExpression() *core.CaveatExpression + + // GetExcludedSubjects returns the list of subjects excluded. Must only have values + // for wildcards and must never be nested. + GetExcludedSubjects() []T +} + +// BaseSubjectSet defines a set that tracks accessible subjects, their exclusions (if wildcards), +// and all conditional expressions applied due to caveats. +// +// It is generic to allow other implementations to define the kind of tracking information +// associated with each subject. +// +// NOTE: Unlike a traditional set, unions between wildcards and a concrete subject will result +// in *both* being present in the set, to maintain the proper set semantics around wildcards. +type BaseSubjectSet[T Subject[T]] struct { + constructor constructor[T] + concrete map[string]T + wildcard *handle[T] +} + +// NewBaseSubjectSet creates a new base subject set for use underneath well-typed implementation. +// +// The constructor function returns a new instance of type T for a particular subject ID. +func NewBaseSubjectSet[T Subject[T]](constructor constructor[T]) BaseSubjectSet[T] { + return BaseSubjectSet[T]{ + constructor: constructor, + concrete: map[string]T{}, + wildcard: newHandle[T](), + } +} + +// constructor defines a function for constructing a new instance of the Subject type T for +// a subject ID, its (optional) conditional expression, any excluded subjects, and any sources +// for bookkeeping. The sources are those other subjects that were combined to create the current +// subject. +type constructor[T Subject[T]] func(subjectID string, conditionalExpression *core.CaveatExpression, excludedSubjects []T, sources ...T) T + +// MustAdd adds the found subject to the set. This is equivalent to a Union operation between the +// existing set of subjects and a set containing the single subject, but modifies the set +// *in place*. +func (bss BaseSubjectSet[T]) MustAdd(foundSubject T) { + err := bss.Add(foundSubject) + if err != nil { + panic(err) + } +} + +// Add adds the found subject to the set. This is equivalent to a Union operation between the +// existing set of subjects and a set containing the single subject, but modifies the set +// *in place*. +func (bss BaseSubjectSet[T]) Add(foundSubject T) error { + if foundSubject.GetSubjectId() == tuple.PublicWildcard { + existing := bss.wildcard.getOrNil() + updated, err := unionWildcardWithWildcard(existing, foundSubject, bss.constructor) + if err != nil { + return err + } + + bss.wildcard.setOrNil(updated) + + for _, concrete := range bss.concrete { + updated = unionWildcardWithConcrete(updated, concrete, bss.constructor) + } + bss.wildcard.setOrNil(updated) + return nil + } + + var updatedOrNil *T + if updated, ok := bss.concrete[foundSubject.GetSubjectId()]; ok { + updatedOrNil = &updated + } + bss.setConcrete(foundSubject.GetSubjectId(), unionConcreteWithConcrete(updatedOrNil, &foundSubject, bss.constructor)) + + wildcard := bss.wildcard.getOrNil() + wildcard = unionWildcardWithConcrete(wildcard, foundSubject, bss.constructor) + bss.wildcard.setOrNil(wildcard) + return nil +} + +func (bss BaseSubjectSet[T]) setConcrete(subjectID string, subjectOrNil *T) { + if subjectOrNil == nil { + delete(bss.concrete, subjectID) + return + } + + subject := *subjectOrNil + bss.concrete[subject.GetSubjectId()] = subject +} + +// Subtract subtracts the given subject found the set. +func (bss BaseSubjectSet[T]) Subtract(toRemove T) { + if toRemove.GetSubjectId() == tuple.PublicWildcard { + for _, concrete := range bss.concrete { + bss.setConcrete(concrete.GetSubjectId(), subtractWildcardFromConcrete(concrete, toRemove, bss.constructor)) + } + + existing := bss.wildcard.getOrNil() + updatedWildcard, concretesToAdd := subtractWildcardFromWildcard(existing, toRemove, bss.constructor) + bss.wildcard.setOrNil(updatedWildcard) + for _, concrete := range concretesToAdd { + concrete := concrete + bss.setConcrete(concrete.GetSubjectId(), &concrete) + } + return + } + + if existing, ok := bss.concrete[toRemove.GetSubjectId()]; ok { + bss.setConcrete(toRemove.GetSubjectId(), subtractConcreteFromConcrete(existing, toRemove, bss.constructor)) + } + + wildcard, ok := bss.wildcard.get() + if ok { + bss.wildcard.setOrNil(subtractConcreteFromWildcard(wildcard, toRemove, bss.constructor)) + } +} + +// SubtractAll subtracts the other set of subjects from this set of subtracts, modifying this +// set *in place*. +func (bss BaseSubjectSet[T]) SubtractAll(other BaseSubjectSet[T]) { + for _, otherSubject := range other.AsSlice() { + bss.Subtract(otherSubject) + } +} + +// MustIntersectionDifference performs an intersection between this set and the other set, modifying +// this set *in place*. +func (bss BaseSubjectSet[T]) MustIntersectionDifference(other BaseSubjectSet[T]) { + err := bss.IntersectionDifference(other) + if err != nil { + panic(err) + } +} + +// IntersectionDifference performs an intersection between this set and the other set, modifying +// this set *in place*. +func (bss BaseSubjectSet[T]) IntersectionDifference(other BaseSubjectSet[T]) error { + // Intersect the wildcards of the sets, if any. + existingWildcard := bss.wildcard.getOrNil() + otherWildcard := other.wildcard.getOrNil() + + intersection, err := intersectWildcardWithWildcard(existingWildcard, otherWildcard, bss.constructor) + if err != nil { + return err + } + + bss.wildcard.setOrNil(intersection) + + // Intersect the concretes of each set, as well as with the wildcards. + updatedConcretes := make(map[string]T, len(bss.concrete)) + + for _, concreteSubject := range bss.concrete { + var otherConcreteOrNil *T + if otherConcrete, ok := other.concrete[concreteSubject.GetSubjectId()]; ok { + otherConcreteOrNil = &otherConcrete + } + + concreteIntersected := intersectConcreteWithConcrete(concreteSubject, otherConcreteOrNil, bss.constructor) + otherWildcardIntersected, err := intersectConcreteWithWildcard(concreteSubject, otherWildcard, bss.constructor) + if err != nil { + return err + } + + result := unionConcreteWithConcrete(concreteIntersected, otherWildcardIntersected, bss.constructor) + if result != nil { + updatedConcretes[concreteSubject.GetSubjectId()] = *result + } + } + + if existingWildcard != nil { + for _, otherSubject := range other.concrete { + existingWildcardIntersect, err := intersectConcreteWithWildcard(otherSubject, existingWildcard, bss.constructor) + if err != nil { + return err + } + + if existingUpdated, ok := updatedConcretes[otherSubject.GetSubjectId()]; ok { + result := unionConcreteWithConcrete(&existingUpdated, existingWildcardIntersect, bss.constructor) + updatedConcretes[otherSubject.GetSubjectId()] = *result + } else if existingWildcardIntersect != nil { + updatedConcretes[otherSubject.GetSubjectId()] = *existingWildcardIntersect + } + } + } + + clear(bss.concrete) + maps.Copy(bss.concrete, updatedConcretes) + return nil +} + +// UnionWith adds the given subjects to this set, via a union call. +func (bss BaseSubjectSet[T]) UnionWith(foundSubjects []T) error { + for _, fs := range foundSubjects { + err := bss.Add(fs) + if err != nil { + return err + } + } + return nil +} + +// UnionWithSet performs a union operation between this set and the other set, modifying this +// set *in place*. +func (bss BaseSubjectSet[T]) UnionWithSet(other BaseSubjectSet[T]) error { + return bss.UnionWith(other.AsSlice()) +} + +// MustUnionWithSet performs a union operation between this set and the other set, modifying this +// set *in place*. +func (bss BaseSubjectSet[T]) MustUnionWithSet(other BaseSubjectSet[T]) { + err := bss.UnionWithSet(other) + if err != nil { + panic(err) + } +} + +// Get returns the found subject with the given ID in the set, if any. +func (bss BaseSubjectSet[T]) Get(id string) (T, bool) { + if id == tuple.PublicWildcard { + return bss.wildcard.get() + } + + found, ok := bss.concrete[id] + return found, ok +} + +// IsEmpty returns whether the subject set is empty. +func (bss BaseSubjectSet[T]) IsEmpty() bool { + return bss.wildcard.getOrNil() == nil && len(bss.concrete) == 0 +} + +// AsSlice returns the contents of the subject set as a slice of found subjects. +func (bss BaseSubjectSet[T]) AsSlice() []T { + values := maps.Values(bss.concrete) + if wildcard, ok := bss.wildcard.get(); ok { + values = append(values, wildcard) + } + return values +} + +// Clone returns a clone of this subject set. Note that this is a shallow clone. +// NOTE: Should only be used when performance is not a concern. +func (bss BaseSubjectSet[T]) Clone() BaseSubjectSet[T] { + return BaseSubjectSet[T]{ + constructor: bss.constructor, + concrete: maps.Clone(bss.concrete), + wildcard: bss.wildcard.clone(), + } +} + +// UnsafeRemoveExact removes the *exact* matching subject, with no wildcard handling. +// This should ONLY be used for testing. +func (bss BaseSubjectSet[T]) UnsafeRemoveExact(foundSubject T) { + if foundSubject.GetSubjectId() == tuple.PublicWildcard { + bss.wildcard.clear() + return + } + + delete(bss.concrete, foundSubject.GetSubjectId()) +} + +// WithParentCaveatExpression returns a copy of the subject set with the parent caveat expression applied +// to all members of this set. +func (bss BaseSubjectSet[T]) WithParentCaveatExpression(parentCaveatExpr *core.CaveatExpression) BaseSubjectSet[T] { + clone := bss.Clone() + + // Apply the parent caveat expression to the wildcard, if any. + if wildcard, ok := clone.wildcard.get(); ok { + constructed := bss.constructor( + tuple.PublicWildcard, + caveatAnd(parentCaveatExpr, wildcard.GetCaveatExpression()), + wildcard.GetExcludedSubjects(), + wildcard, + ) + clone.wildcard.setOrNil(&constructed) + } + + // Apply the parent caveat expression to each concrete. + for subjectID, concrete := range clone.concrete { + clone.concrete[subjectID] = bss.constructor( + subjectID, + caveatAnd(parentCaveatExpr, concrete.GetCaveatExpression()), + nil, + concrete, + ) + } + + return clone +} + +// unionWildcardWithWildcard performs a union operation over two wildcards, returning the updated +// wildcard (if any). +func unionWildcardWithWildcard[T Subject[T]](existing *T, adding T, constructor constructor[T]) (*T, error) { + // If there is no existing wildcard, return the added one. + if existing == nil { + return &adding, nil + } + + // Otherwise, union together the conditionals for the wildcards and *intersect* their exclusion + // sets. + existingWildcard := *existing + expression := shortcircuitedOr(existingWildcard.GetCaveatExpression(), adding.GetCaveatExpression()) + + // Exclusion sets are intersected because if an exclusion is missing from one wildcard + // but not the other, the missing element will be, by definition, in that other wildcard. + // + // Examples: + // + // {*} + {*} => {*} + // {* - {user:tom}} + {*} => {*} + // {* - {user:tom}} + {* - {user:sarah}} => {*} + // {* - {user:tom, user:sarah}} + {* - {user:sarah}} => {* - {user:sarah}} + // {*}[c1] + {*} => {*} + // {*}[c1] + {*}[c2] => {*}[c1 || c2] + + // NOTE: since we're only using concretes here, it is safe to reuse the BaseSubjectSet itself. + exisingConcreteExclusions := NewBaseSubjectSet(constructor) + for _, excludedSubject := range existingWildcard.GetExcludedSubjects() { + if excludedSubject.GetSubjectId() == tuple.PublicWildcard { + return nil, spiceerrors.MustBugf("wildcards are not allowed in exclusions") + } + + err := exisingConcreteExclusions.Add(excludedSubject) + if err != nil { + return nil, err + } + } + + foundConcreteExclusions := NewBaseSubjectSet(constructor) + for _, excludedSubject := range adding.GetExcludedSubjects() { + if excludedSubject.GetSubjectId() == tuple.PublicWildcard { + return nil, spiceerrors.MustBugf("wildcards are not allowed in exclusions") + } + + err := foundConcreteExclusions.Add(excludedSubject) + if err != nil { + return nil, err + } + } + + err := exisingConcreteExclusions.IntersectionDifference(foundConcreteExclusions) + if err != nil { + return nil, err + } + + constructed := constructor( + tuple.PublicWildcard, + expression, + exisingConcreteExclusions.AsSlice(), + *existing, + adding) + return &constructed, nil +} + +// unionWildcardWithConcrete performs a union operation between a wildcard and a concrete subject +// being added to the set, returning the updated wildcard (if applicable). +func unionWildcardWithConcrete[T Subject[T]](existing *T, adding T, constructor constructor[T]) *T { + // If there is no existing wildcard, nothing more to do. + if existing == nil { + return nil + } + + // If the concrete is in the exclusion set, remove it if not conditional. Otherwise, mark + // it as conditional. + // + // Examples: + // {*} | {user:tom} => {*} (and user:tom in the concrete) + // {* - {user:tom}} | {user:tom} => {*} (and user:tom in the concrete) + // {* - {user:tom}[c1]} | {user:tom}[c2] => {* - {user:tom}[c1 && !c2]} (and user:tom in the concrete) + existingWildcard := *existing + updatedExclusions := make([]T, 0, len(existingWildcard.GetExcludedSubjects())) + for _, existingExclusion := range existingWildcard.GetExcludedSubjects() { + if existingExclusion.GetSubjectId() == adding.GetSubjectId() { + // If the conditional on the concrete is empty, then the concrete is always present, so + // we remove the exclusion entirely. + if adding.GetCaveatExpression() == nil { + continue + } + + // Otherwise, the conditional expression for the new exclusion is the existing expression && + // the *inversion* of the concrete's expression, as the exclusion will only apply if the + // concrete subject is not present and the exclusion's expression is true. + exclusionConditionalExpression := caveatAnd( + existingExclusion.GetCaveatExpression(), + caveatInvert(adding.GetCaveatExpression()), + ) + + updatedExclusions = append(updatedExclusions, constructor( + adding.GetSubjectId(), + exclusionConditionalExpression, + nil, + existingExclusion, + adding), + ) + } else { + updatedExclusions = append(updatedExclusions, existingExclusion) + } + } + + constructed := constructor( + tuple.PublicWildcard, + existingWildcard.GetCaveatExpression(), + updatedExclusions, + existingWildcard) + return &constructed +} + +// unionConcreteWithConcrete performs a union operation between two concrete subjects and returns +// the concrete subject produced, if any. +func unionConcreteWithConcrete[T Subject[T]](existing *T, adding *T, constructor constructor[T]) *T { + // Check for union with other concretes. + if existing == nil { + return adding + } + + if adding == nil { + return existing + } + + existingConcrete := *existing + addingConcrete := *adding + + // A union of a concrete subjects has the conditionals of each concrete merged. + constructed := constructor( + existingConcrete.GetSubjectId(), + shortcircuitedOr( + existingConcrete.GetCaveatExpression(), + addingConcrete.GetCaveatExpression(), + ), + nil, + existingConcrete, addingConcrete) + return &constructed +} + +// subtractWildcardFromWildcard performs a subtraction operation of wildcard from another, returning +// the updated wildcard (if any), as well as any concrete subjects produced by the subtraction +// operation due to exclusions. +func subtractWildcardFromWildcard[T Subject[T]](existing *T, toRemove T, constructor constructor[T]) (*T, []T) { + // If there is no existing wildcard, nothing more to do. + if existing == nil { + return nil, nil + } + + // If there is no condition on the wildcard and the new wildcard has no exclusions, then this wildcard goes away. + // Example: {*} - {*} => {} + if toRemove.GetCaveatExpression() == nil && len(toRemove.GetExcludedSubjects()) == 0 { + return nil, nil + } + + // Otherwise, we construct a new wildcard and return any concrete subjects that might result from this subtraction. + existingWildcard := *existing + existingExclusions := exclusionsMapFor(existingWildcard) + + // Calculate the exclusions which turn into concrete subjects. + // This occurs when a wildcard with exclusions is subtracted from a wildcard + // (with, or without *matching* exclusions). + // + // Example: + // Given the two wildcards `* - {user:sarah}` and `* - {user:tom, user:amy, user:sarah}`, + // the resulting concrete subjects are {user:tom, user:amy} because the first set contains + // `tom` and `amy` (but not `sarah`) and the second set contains all three. + resultingConcreteSubjects := make([]T, 0, len(toRemove.GetExcludedSubjects())) + for _, excludedSubject := range toRemove.GetExcludedSubjects() { + if existingExclusion, isExistingExclusion := existingExclusions[excludedSubject.GetSubjectId()]; !isExistingExclusion || existingExclusion.GetCaveatExpression() != nil { + // The conditional expression for the now-concrete subject type is the conditional on the provided exclusion + // itself. + // + // As an example, subtracting the wildcards + // {*[caveat1] - {user:tom}} + // - + // {*[caveat3] - {user:sarah[caveat4]}} + // + // the resulting expression to produce a *concrete* `user:sarah` is + // `caveat1 && caveat3 && caveat4`, because the concrete subject only appears if the first + // wildcard applies, the *second* wildcard applies and its exclusion applies. + exclusionConditionalExpression := caveatAnd( + caveatAnd( + existingWildcard.GetCaveatExpression(), + toRemove.GetCaveatExpression(), + ), + excludedSubject.GetCaveatExpression(), + ) + + // If there is an existing exclusion, then its caveat expression is added as well, but inverted. + // + // As an example, subtracting the wildcards + // {*[caveat1] - {user:tom[caveat2]}} + // - + // {*[caveat3] - {user:sarah[caveat4]}} + // + // the resulting expression to produce a *concrete* `user:sarah` is + // `caveat1 && !caveat2 && caveat3 && caveat4`, because the concrete subject only appears + // if the first wildcard applies, the *second* wildcard applies, the first exclusion + // does *not* apply (ensuring the concrete is in the first wildcard) and the second exclusion + // *does* apply (ensuring it is not in the second wildcard). + if existingExclusion.GetCaveatExpression() != nil { + exclusionConditionalExpression = caveatAnd( + caveatAnd( + caveatAnd( + existingWildcard.GetCaveatExpression(), + toRemove.GetCaveatExpression(), + ), + caveatInvert(existingExclusion.GetCaveatExpression()), + ), + excludedSubject.GetCaveatExpression(), + ) + } + + resultingConcreteSubjects = append(resultingConcreteSubjects, constructor( + excludedSubject.GetSubjectId(), + exclusionConditionalExpression, + nil, excludedSubject)) + } + } + + // Create the combined conditional: the wildcard can only exist when it is present and the other wildcard is not. + combinedConditionalExpression := caveatAnd(existingWildcard.GetCaveatExpression(), caveatInvert(toRemove.GetCaveatExpression())) + if combinedConditionalExpression != nil { + constructed := constructor( + tuple.PublicWildcard, + combinedConditionalExpression, + existingWildcard.GetExcludedSubjects(), + existingWildcard, + toRemove) + return &constructed, resultingConcreteSubjects + } + + return nil, resultingConcreteSubjects +} + +// subtractWildcardFromConcrete subtracts a wildcard from a concrete element, returning the updated +// concrete subject, if any. +func subtractWildcardFromConcrete[T Subject[T]](existingConcrete T, wildcardToRemove T, constructor constructor[T]) *T { + // Subtraction of a wildcard removes *all* elements of the concrete set, except those that + // are found in the excluded list. If the wildcard *itself* is conditional, then instead of + // items being removed, they are made conditional on the inversion of the wildcard's expression, + // and the exclusion's conditional, if any. + // + // Examples: + // {user:sarah, user:tom} - {*} => {} + // {user:sarah, user:tom} - {*[somecaveat]} => {user:sarah[!somecaveat], user:tom[!somecaveat]} + // {user:sarah, user:tom} - {* - {user:tom}} => {user:tom} + // {user:sarah, user:tom} - {*[somecaveat] - {user:tom}} => {user:sarah[!somecaveat], user:tom} + // {user:sarah, user:tom} - {* - {user:tom[c2]}}[somecaveat] => {user:sarah[!somecaveat], user:tom[c2]} + // {user:sarah[c1], user:tom} - {*[somecaveat] - {user:tom}} => {user:sarah[c1 && !somecaveat], user:tom} + exclusions := exclusionsMapFor(wildcardToRemove) + exclusion, isExcluded := exclusions[existingConcrete.GetSubjectId()] + if !isExcluded { + // If the subject was not excluded within the wildcard, it is either removed directly + // (in the case where the wildcard is not conditional), or has its condition updated to + // reflect that it is only present when the condition for the wildcard is *false*. + if wildcardToRemove.GetCaveatExpression() == nil { + return nil + } + + constructed := constructor( + existingConcrete.GetSubjectId(), + caveatAnd(existingConcrete.GetCaveatExpression(), caveatInvert(wildcardToRemove.GetCaveatExpression())), + nil, + existingConcrete) + return &constructed + } + + // If the exclusion is not conditional, then the subject is always present. + if exclusion.GetCaveatExpression() == nil { + return &existingConcrete + } + + // The conditional of the exclusion is that of the exclusion itself OR the caveatInverted case of + // the wildcard, which would mean the wildcard itself does not apply. + exclusionConditional := caveatOr(caveatInvert(wildcardToRemove.GetCaveatExpression()), exclusion.GetCaveatExpression()) + + constructed := constructor( + existingConcrete.GetSubjectId(), + caveatAnd(existingConcrete.GetCaveatExpression(), exclusionConditional), + nil, + existingConcrete) + return &constructed +} + +// subtractConcreteFromConcrete subtracts a concrete subject from another concrete subject. +func subtractConcreteFromConcrete[T Subject[T]](existingConcrete T, toRemove T, constructor constructor[T]) *T { + // Subtraction of a concrete type removes the entry from the concrete list + // *unless* the subtraction is conditional, in which case the conditional is updated + // to remove the element when it is true. + // + // Examples: + // {user:sarah} - {user:tom} => {user:sarah} + // {user:tom} - {user:tom} => {} + // {user:tom[c1]} - {user:tom} => {user:tom} + // {user:tom} - {user:tom[c2]} => {user:tom[!c2]} + // {user:tom[c1]} - {user:tom[c2]} => {user:tom[c1 && !c2]} + if toRemove.GetCaveatExpression() == nil { + return nil + } + + // Otherwise, adjust the conditional of the existing item to remove it if it is true. + expression := caveatAnd( + existingConcrete.GetCaveatExpression(), + caveatInvert( + toRemove.GetCaveatExpression(), + ), + ) + + constructed := constructor( + existingConcrete.GetSubjectId(), + expression, + nil, + existingConcrete, toRemove) + return &constructed +} + +// subtractConcreteFromWildcard subtracts a concrete element from a wildcard. +func subtractConcreteFromWildcard[T Subject[T]](wildcard T, concreteToRemove T, constructor constructor[T]) *T { + // Subtracting a concrete type from a wildcard adds the concrete to the exclusions for the wildcard. + // Examples: + // {*} - {user:tom} => {* - {user:tom}} + // {*} - {user:tom[c1]} => {* - {user:tom[c1]}} + // {* - {user:tom[c1]}} - {user:tom} => {* - {user:tom}} + // {* - {user:tom[c1]}} - {user:tom[c2]} => {* - {user:tom[c1 || c2]}} + updatedExclusions := make([]T, 0, len(wildcard.GetExcludedSubjects())+1) + wasFound := false + for _, existingExclusion := range wildcard.GetExcludedSubjects() { + if existingExclusion.GetSubjectId() == concreteToRemove.GetSubjectId() { + // The conditional expression for the exclusion is a combination on the existing exclusion or + // the new expression. The caveat is short-circuited here because if either the exclusion or + // the concrete is non-caveated, then the whole exclusion is non-caveated. + exclusionConditionalExpression := shortcircuitedOr( + existingExclusion.GetCaveatExpression(), + concreteToRemove.GetCaveatExpression(), + ) + + updatedExclusions = append(updatedExclusions, constructor( + concreteToRemove.GetSubjectId(), + exclusionConditionalExpression, + nil, + existingExclusion, + concreteToRemove), + ) + wasFound = true + } else { + updatedExclusions = append(updatedExclusions, existingExclusion) + } + } + + if !wasFound { + updatedExclusions = append(updatedExclusions, concreteToRemove) + } + + constructed := constructor( + tuple.PublicWildcard, + wildcard.GetCaveatExpression(), + updatedExclusions, + wildcard) + return &constructed +} + +// intersectConcreteWithConcrete performs intersection between two concrete subjects, returning the +// resolved concrete subject, if any. +func intersectConcreteWithConcrete[T Subject[T]](first T, second *T, constructor constructor[T]) *T { + // Intersection of concrete subjects is a standard intersection operation, where subjects + // must be in both sets, with a combination of the two elements into one for conditionals. + // Otherwise, `and` together conditionals. + if second == nil { + return nil + } + + secondConcrete := *second + constructed := constructor( + first.GetSubjectId(), + caveatAnd(first.GetCaveatExpression(), secondConcrete.GetCaveatExpression()), + nil, + first, + secondConcrete) + + return &constructed +} + +// intersectWildcardWithWildcard performs intersection between two wildcards, returning the resolved +// wildcard subject, if any. +func intersectWildcardWithWildcard[T Subject[T]](first *T, second *T, constructor constructor[T]) (*T, error) { + // If either wildcard does not exist, then no wildcard is placed into the resulting set. + if first == nil || second == nil { + return nil, nil + } + + // If the other wildcard exists, then the intersection between the two wildcards is an && of + // their conditionals, and a *union* of their exclusions. + firstWildcard := *first + secondWildcard := *second + + concreteExclusions := NewBaseSubjectSet(constructor) + for _, excludedSubject := range firstWildcard.GetExcludedSubjects() { + if excludedSubject.GetSubjectId() == tuple.PublicWildcard { + return nil, spiceerrors.MustBugf("wildcards are not allowed in exclusions") + } + + err := concreteExclusions.Add(excludedSubject) + if err != nil { + return nil, err + } + } + + for _, excludedSubject := range secondWildcard.GetExcludedSubjects() { + if excludedSubject.GetSubjectId() == tuple.PublicWildcard { + return nil, spiceerrors.MustBugf("wildcards are not allowed in exclusions") + } + + err := concreteExclusions.Add(excludedSubject) + if err != nil { + return nil, err + } + } + + constructed := constructor( + tuple.PublicWildcard, + caveatAnd(firstWildcard.GetCaveatExpression(), secondWildcard.GetCaveatExpression()), + concreteExclusions.AsSlice(), + firstWildcard, + secondWildcard) + return &constructed, nil +} + +// intersectConcreteWithWildcard performs intersection between a concrete subject and a wildcard +// subject, returning the concrete, if any. +func intersectConcreteWithWildcard[T Subject[T]](concrete T, wildcard *T, constructor constructor[T]) (*T, error) { + // If no wildcard exists, then the concrete cannot exist (for this branch) + if wildcard == nil { + return nil, nil + } + + wildcardToIntersect := *wildcard + exclusionsMap := exclusionsMapFor(wildcardToIntersect) + exclusion, isExcluded := exclusionsMap[concrete.GetSubjectId()] + + // Cases: + // - The concrete subject is not excluded and the wildcard is not conditional => concrete is kept + // - The concrete subject is excluded and the wildcard is not conditional but the exclusion *is* conditional => concrete is made conditional + // - The concrete subject is excluded and the wildcard is not conditional => concrete is removed + // - The concrete subject is not excluded but the wildcard is conditional => concrete is kept, but made conditional + // - The concrete subject is excluded and the wildcard is conditional => concrete is removed, since it is always excluded + // - The concrete subject is excluded and the wildcard is conditional and the exclusion is conditional => combined conditional + switch { + case !isExcluded && wildcardToIntersect.GetCaveatExpression() == nil: + // If the concrete is not excluded and the wildcard conditional is empty, then the concrete is always found. + // Example: {user:tom} & {*} => {user:tom} + return &concrete, nil + + case !isExcluded && wildcardToIntersect.GetCaveatExpression() != nil: + // The concrete subject is only included if the wildcard's caveat is true. + // Example: {user:tom}[acaveat] & {* - user:tom}[somecaveat] => {user:tom}[acaveat && somecaveat] + constructed := constructor( + concrete.GetSubjectId(), + caveatAnd(concrete.GetCaveatExpression(), wildcardToIntersect.GetCaveatExpression()), + nil, + concrete, + wildcardToIntersect) + return &constructed, nil + + case isExcluded && exclusion.GetCaveatExpression() == nil: + // If the concrete is excluded and the exclusion is not conditional, then the concrete can never show up, + // regardless of whether the wildcard is conditional. + // Example: {user:tom} & {* - user:tom}[somecaveat] => {} + return nil, nil + + case isExcluded && exclusion.GetCaveatExpression() != nil: + // NOTE: whether the wildcard is itself conditional or not is handled within the expression combinators below. + // The concrete subject is included if the wildcard's caveat is true and the exclusion's caveat is *false*. + // Example: {user:tom}[acaveat] & {* - user:tom[ecaveat]}[wcaveat] => {user:tom[acaveat && wcaveat && !ecaveat]} + constructed := constructor( + concrete.GetSubjectId(), + caveatAnd( + concrete.GetCaveatExpression(), + caveatAnd( + wildcardToIntersect.GetCaveatExpression(), + caveatInvert(exclusion.GetCaveatExpression()), + )), + nil, + concrete, + wildcardToIntersect, + exclusion) + return &constructed, nil + + default: + return nil, spiceerrors.MustBugf("unhandled case in basesubjectset intersectConcreteWithWildcard: %v & %v", concrete, wildcardToIntersect) + } +} + +type handle[T any] struct { + value *T +} + +func newHandle[T any]() *handle[T] { + return &handle[T]{} +} + +func (h *handle[T]) getOrNil() *T { + return h.value +} + +func (h *handle[T]) setOrNil(value *T) { + h.value = value +} + +func (h *handle[T]) get() (T, bool) { + if h.value != nil { + return *h.value, true + } + + return *new(T), false +} + +func (h *handle[T]) clear() { + h.value = nil +} + +func (h *handle[T]) clone() *handle[T] { + return &handle[T]{ + value: h.value, + } +} + +// exclusionsMapFor creates a map of all the exclusions on a wildcard, by subject ID. +func exclusionsMapFor[T Subject[T]](wildcard T) map[string]T { + exclusions := make(map[string]T, len(wildcard.GetExcludedSubjects())) + for _, excludedSubject := range wildcard.GetExcludedSubjects() { + exclusions[excludedSubject.GetSubjectId()] = excludedSubject + } + return exclusions +} diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/doc.go b/vendor/github.com/authzed/spicedb/internal/datasets/doc.go new file mode 100644 index 0000000..6ff324c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datasets/doc.go @@ -0,0 +1,2 @@ +// Package datasets defines operations with sets of subjects. +package datasets diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/subjectset.go b/vendor/github.com/authzed/spicedb/internal/datasets/subjectset.go new file mode 100644 index 0000000..551bfaa --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datasets/subjectset.go @@ -0,0 +1,65 @@ +package datasets + +import ( + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" +) + +// SubjectSet defines a set that tracks accessible subjects. +// +// NOTE: Unlike a traditional set, unions between wildcards and a concrete subject will result +// in *both* being present in the set, to maintain the proper set semantics around wildcards. +type SubjectSet struct { + BaseSubjectSet[*v1.FoundSubject] +} + +// NewSubjectSet creates and returns a new subject set. +func NewSubjectSet() SubjectSet { + return SubjectSet{ + BaseSubjectSet: NewBaseSubjectSet(subjectSetConstructor), + } +} + +func (ss SubjectSet) SubtractAll(other SubjectSet) { + ss.BaseSubjectSet.SubtractAll(other.BaseSubjectSet) +} + +func (ss SubjectSet) MustIntersectionDifference(other SubjectSet) { + ss.BaseSubjectSet.MustIntersectionDifference(other.BaseSubjectSet) +} + +func (ss SubjectSet) IntersectionDifference(other SubjectSet) error { + return ss.BaseSubjectSet.IntersectionDifference(other.BaseSubjectSet) +} + +func (ss SubjectSet) MustUnionWithSet(other SubjectSet) { + ss.BaseSubjectSet.MustUnionWithSet(other.BaseSubjectSet) +} + +func (ss SubjectSet) Clone() SubjectSet { + return SubjectSet{ss.BaseSubjectSet.Clone()} +} + +func (ss SubjectSet) UnionWithSet(other SubjectSet) error { + return ss.BaseSubjectSet.UnionWithSet(other.BaseSubjectSet) +} + +// WithParentCaveatExpression returns a copy of the subject set with the parent caveat expression applied +// to all members of this set. +func (ss SubjectSet) WithParentCaveatExpression(parentCaveatExpr *core.CaveatExpression) SubjectSet { + return SubjectSet{ss.BaseSubjectSet.WithParentCaveatExpression(parentCaveatExpr)} +} + +func (ss SubjectSet) AsFoundSubjects() *v1.FoundSubjects { + return &v1.FoundSubjects{ + FoundSubjects: ss.AsSlice(), + } +} + +func subjectSetConstructor(subjectID string, caveatExpression *core.CaveatExpression, excludedSubjects []*v1.FoundSubject, _ ...*v1.FoundSubject) *v1.FoundSubject { + return &v1.FoundSubject{ + SubjectId: subjectID, + CaveatExpression: caveatExpression, + ExcludedSubjects: excludedSubjects, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbyresourceid.go b/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbyresourceid.go new file mode 100644 index 0000000..5b1ba13 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbyresourceid.go @@ -0,0 +1,117 @@ +package datasets + +import ( + "fmt" + + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +// NewSubjectSetByResourceID creates and returns a map of subject sets, indexed by resource ID. +func NewSubjectSetByResourceID() SubjectSetByResourceID { + return SubjectSetByResourceID{ + subjectSetByResourceID: map[string]SubjectSet{}, + } +} + +// SubjectSetByResourceID defines a helper type which maps from a resource ID to its associated found +// subjects, in the form of a subject set per resource ID. +type SubjectSetByResourceID struct { + subjectSetByResourceID map[string]SubjectSet +} + +func (ssr SubjectSetByResourceID) add(resourceID string, subject *v1.FoundSubject) error { + if subject == nil { + return fmt.Errorf("cannot add a nil subject to SubjectSetByResourceID") + } + + _, ok := ssr.subjectSetByResourceID[resourceID] + if !ok { + ssr.subjectSetByResourceID[resourceID] = NewSubjectSet() + } + return ssr.subjectSetByResourceID[resourceID].Add(subject) +} + +// AddFromRelationship adds the subject found in the given relationship to this map, indexed at +// the resource ID specified in the relationship. +func (ssr SubjectSetByResourceID) AddFromRelationship(relationship tuple.Relationship) error { + return ssr.add(relationship.Resource.ObjectID, &v1.FoundSubject{ + SubjectId: relationship.Subject.ObjectID, + CaveatExpression: wrapCaveat(relationship.OptionalCaveat), + }) +} + +// UnionWith unions the map's sets with the other map of sets provided. +func (ssr SubjectSetByResourceID) UnionWith(other map[string]*v1.FoundSubjects) error { + for resourceID, subjects := range other { + if subjects == nil { + return fmt.Errorf("received nil FoundSubjects in other map of SubjectSetByResourceID's UnionWith for key %s", resourceID) + } + + for _, subject := range subjects.FoundSubjects { + if err := ssr.add(resourceID, subject); err != nil { + return err + } + } + } + + return nil +} + +// IntersectionDifference performs an in-place intersection between the two maps' sets. +func (ssr SubjectSetByResourceID) IntersectionDifference(other SubjectSetByResourceID) error { + for otherResourceID, otherSubjectSet := range other.subjectSetByResourceID { + existing, ok := ssr.subjectSetByResourceID[otherResourceID] + if !ok { + continue + } + + err := existing.IntersectionDifference(otherSubjectSet) + if err != nil { + return err + } + + if existing.IsEmpty() { + delete(ssr.subjectSetByResourceID, otherResourceID) + } + } + + for existingResourceID := range ssr.subjectSetByResourceID { + _, ok := other.subjectSetByResourceID[existingResourceID] + if !ok { + delete(ssr.subjectSetByResourceID, existingResourceID) + continue + } + } + + return nil +} + +// SubtractAll subtracts all sets in the other map from this map's sets. +func (ssr SubjectSetByResourceID) SubtractAll(other SubjectSetByResourceID) { + for otherResourceID, otherSubjectSet := range other.subjectSetByResourceID { + existing, ok := ssr.subjectSetByResourceID[otherResourceID] + if !ok { + continue + } + + existing.SubtractAll(otherSubjectSet) + if existing.IsEmpty() { + delete(ssr.subjectSetByResourceID, otherResourceID) + } + } +} + +// IsEmpty returns true if the map is empty. +func (ssr SubjectSetByResourceID) IsEmpty() bool { + return len(ssr.subjectSetByResourceID) == 0 +} + +// AsMap converts the map into a map for storage in a proto. +func (ssr SubjectSetByResourceID) AsMap() map[string]*v1.FoundSubjects { + mapped := make(map[string]*v1.FoundSubjects, len(ssr.subjectSetByResourceID)) + for resourceID, subjectsSet := range ssr.subjectSetByResourceID { + mapped[resourceID] = subjectsSet.AsFoundSubjects() + } + return mapped +} diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbytype.go b/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbytype.go new file mode 100644 index 0000000..8882a2e --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbytype.go @@ -0,0 +1,113 @@ +package datasets + +import ( + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +// SubjectByTypeSet is a set of SubjectSet's, grouped by their subject types. +type SubjectByTypeSet struct { + byType map[string]SubjectSet +} + +// NewSubjectByTypeSet creates and returns a new SubjectByTypeSet. +func NewSubjectByTypeSet() *SubjectByTypeSet { + return &SubjectByTypeSet{ + byType: map[string]SubjectSet{}, + } +} + +// AddSubjectOf adds the subject found in the given relationship, along with its caveat. +func (s *SubjectByTypeSet) AddSubjectOf(relationship tuple.Relationship) error { + return s.AddSubject(relationship.Subject, relationship.OptionalCaveat) +} + +// AddConcreteSubject adds a non-caveated subject to the set. +func (s *SubjectByTypeSet) AddConcreteSubject(subject tuple.ObjectAndRelation) error { + return s.AddSubject(subject, nil) +} + +// AddSubject adds the specified subject to the set. +func (s *SubjectByTypeSet) AddSubject(subject tuple.ObjectAndRelation, caveat *core.ContextualizedCaveat) error { + key := tuple.JoinRelRef(subject.ObjectType, subject.Relation) + if _, ok := s.byType[key]; !ok { + s.byType[key] = NewSubjectSet() + } + + return s.byType[key].Add(&v1.FoundSubject{ + SubjectId: subject.ObjectID, + CaveatExpression: wrapCaveat(caveat), + }) +} + +// ForEachType invokes the handler for each type of ObjectAndRelation found in the set, along +// with all IDs of objects of that type. +func (s *SubjectByTypeSet) ForEachType(handler func(rr *core.RelationReference, subjects SubjectSet)) { + for key, subjects := range s.byType { + ns, rel := tuple.MustSplitRelRef(key) + handler(&core.RelationReference{ + Namespace: ns, + Relation: rel, + }, subjects) + } +} + +// Map runs the mapper function over each type of object in the set, returning a new SubjectByTypeSet with +// the object type replaced by that returned by the mapper function. +func (s *SubjectByTypeSet) Map(mapper func(rr *core.RelationReference) (*core.RelationReference, error)) (*SubjectByTypeSet, error) { + mapped := NewSubjectByTypeSet() + for key, subjectset := range s.byType { + ns, rel := tuple.MustSplitRelRef(key) + updatedType, err := mapper(&core.RelationReference{ + Namespace: ns, + Relation: rel, + }) + if err != nil { + return nil, err + } + if updatedType == nil { + continue + } + + key := tuple.JoinRelRef(updatedType.Namespace, updatedType.Relation) + if existing, ok := mapped.byType[key]; ok { + cloned := subjectset.Clone() + if err := cloned.UnionWithSet(existing); err != nil { + return nil, err + } + mapped.byType[key] = cloned + } else { + mapped.byType[key] = subjectset + } + } + return mapped, nil +} + +// IsEmpty returns true if the set is empty. +func (s *SubjectByTypeSet) IsEmpty() bool { + return len(s.byType) == 0 +} + +// Len returns the number of keys in the set. +func (s *SubjectByTypeSet) Len() int { + return len(s.byType) +} + +// SubjectSetForType returns the subject set associated with the given subject type, if any. +func (s *SubjectByTypeSet) SubjectSetForType(rr *core.RelationReference) (SubjectSet, bool) { + found, ok := s.byType[tuple.JoinRelRef(rr.Namespace, rr.Relation)] + return found, ok +} + +func wrapCaveat(caveat *core.ContextualizedCaveat) *core.CaveatExpression { + if caveat == nil { + return nil + } + + return &core.CaveatExpression{ + OperationOrCaveat: &core.CaveatExpression_Caveat{ + Caveat: caveat, + }, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go new file mode 100644 index 0000000..291abb5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go @@ -0,0 +1,352 @@ +package common + +import ( + "context" + "sort" + + "golang.org/x/exp/maps" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/ccoveille/go-safecast" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +const ( + nsPrefix = "n$" + caveatPrefix = "c$" +) + +// Changes represents a set of datastore mutations that are kept self-consistent +// across one or more transaction revisions. +type Changes[R datastore.Revision, K comparable] struct { + records map[K]changeRecord[R] + keyFunc func(R) K + content datastore.WatchContent + maxByteSize uint64 + currentByteSize int64 +} + +type changeRecord[R datastore.Revision] struct { + rev R + relTouches map[string]tuple.Relationship + relDeletes map[string]tuple.Relationship + definitionsChanged map[string]datastore.SchemaDefinition + namespacesDeleted map[string]struct{} + caveatsDeleted map[string]struct{} + metadata map[string]any +} + +// NewChanges creates a new Changes object for change tracking and de-duplication. +func NewChanges[R datastore.Revision, K comparable](keyFunc func(R) K, content datastore.WatchContent, maxByteSize uint64) *Changes[R, K] { + return &Changes[R, K]{ + records: make(map[K]changeRecord[R], 0), + keyFunc: keyFunc, + content: content, + maxByteSize: maxByteSize, + currentByteSize: 0, + } +} + +// IsEmpty returns if the change set is empty. +func (ch *Changes[R, K]) IsEmpty() bool { + return len(ch.records) == 0 +} + +// AddRelationshipChange adds a specific change to the complete list of tracked changes +func (ch *Changes[R, K]) AddRelationshipChange( + ctx context.Context, + rev R, + rel tuple.Relationship, + op tuple.UpdateOperation, +) error { + if ch.content&datastore.WatchRelationships != datastore.WatchRelationships { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + key := tuple.StringWithoutCaveatOrExpiration(rel) + + switch op { + case tuple.UpdateOperationTouch: + // If there was a delete for the same tuple at the same revision, drop it + existing, ok := record.relDeletes[key] + if ok { + delete(record.relDeletes, key) + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + + record.relTouches[key] = rel + if err := ch.adjustByteSize(rel, 1); err != nil { + return err + } + + case tuple.UpdateOperationDelete: + _, alreadyTouched := record.relTouches[key] + if !alreadyTouched { + record.relDeletes[key] = rel + if err := ch.adjustByteSize(rel, 1); err != nil { + return err + } + } + + default: + return spiceerrors.MustBugf("unknown change operation") + } + + return nil +} + +type sized interface { + SizeVT() int +} + +func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error { + if ch.maxByteSize == 0 { + return nil + } + + size := item.SizeVT() * delta + ch.currentByteSize += int64(size) + if ch.currentByteSize < 0 { + return spiceerrors.MustBugf("byte size underflow") + } + + currentByteSize, err := safecast.ToUint64(ch.currentByteSize) + if err != nil { + return spiceerrors.MustBugf("could not cast currentByteSize to uint64: %v", err) + } + + if currentByteSize > ch.maxByteSize { + return datastore.NewMaximumChangesSizeExceededError(ch.maxByteSize) + } + + return nil +} + +// SetRevisionMetadata sets the metadata for the given revision. +func (ch *Changes[R, K]) SetRevisionMetadata(ctx context.Context, rev R, metadata map[string]any) error { + if len(metadata) == 0 { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + if len(record.metadata) > 0 { + return spiceerrors.MustBugf("metadata already set for revision") + } + + maps.Copy(record.metadata, metadata) + return nil +} + +func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) { + k := ch.keyFunc(rev) + revisionChanges, ok := ch.records[k] + if !ok { + revisionChanges = changeRecord[R]{ + rev, + make(map[string]tuple.Relationship), + make(map[string]tuple.Relationship), + make(map[string]datastore.SchemaDefinition), + make(map[string]struct{}), + make(map[string]struct{}), + make(map[string]any), + } + ch.records[k] = revisionChanges + } + + return revisionChanges, nil +} + +// AddDeletedNamespace adds a change indicating that the namespace with the name was deleted. +func (ch *Changes[R, K]) AddDeletedNamespace( + _ context.Context, + rev R, + namespaceName string, +) error { + if ch.content&datastore.WatchSchema != datastore.WatchSchema { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + // if a delete happens in the same transaction as a change, we assume it was a change in the first place + // because that's how namespace changes are implemented in the MVCC + if _, ok := record.definitionsChanged[nsPrefix+namespaceName]; ok { + return nil + } + + delete(record.definitionsChanged, nsPrefix+namespaceName) + record.namespacesDeleted[namespaceName] = struct{}{} + return nil +} + +// AddDeletedCaveat adds a change indicating that the caveat with the name was deleted. +func (ch *Changes[R, K]) AddDeletedCaveat( + _ context.Context, + rev R, + caveatName string, +) error { + if ch.content&datastore.WatchSchema != datastore.WatchSchema { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + // if a delete happens in the same transaction as a change, we assume it was a change in the first place + // because that's how namespace changes are implemented in the MVCC + if _, ok := record.definitionsChanged[caveatPrefix+caveatName]; ok { + return nil + } + + delete(record.definitionsChanged, caveatPrefix+caveatName) + record.caveatsDeleted[caveatName] = struct{}{} + return nil +} + +// AddChangedDefinition adds a change indicating that the schema definition (namespace or caveat) +// was changed to the definition given. +func (ch *Changes[R, K]) AddChangedDefinition( + ctx context.Context, + rev R, + def datastore.SchemaDefinition, +) error { + if ch.content&datastore.WatchSchema != datastore.WatchSchema { + return nil + } + + record, err := ch.recordForRevision(rev) + if err != nil { + return err + } + + switch t := def.(type) { + case *core.NamespaceDefinition: + delete(record.namespacesDeleted, t.Name) + + if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok { + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + + record.definitionsChanged[nsPrefix+t.Name] = t + + if err := ch.adjustByteSize(t, 1); err != nil { + return err + } + + case *core.CaveatDefinition: + delete(record.caveatsDeleted, t.Name) + + if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok { + if err := ch.adjustByteSize(existing, -1); err != nil { + return err + } + } + + record.definitionsChanged[caveatPrefix+t.Name] = t + + if err := ch.adjustByteSize(t, 1); err != nil { + return err + } + + default: + log.Ctx(ctx).Fatal().Msg("unknown schema definition kind") + } + + return nil +} + +// AsRevisionChanges returns the list of changes processed so far as a datastore watch +// compatible, ordered, changelist. +func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) ([]datastore.RevisionChanges, error) { + return ch.revisionChanges(lessThanFunc, *new(R), false) +} + +// FilterAndRemoveRevisionChanges filters a list of changes processed up to the bound revision from the changes list, removing them +// and returning the filtered changes. +func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) ([]datastore.RevisionChanges, error) { + changes, err := ch.revisionChanges(lessThanFunc, boundRev, true) + if err != nil { + return nil, err + } + + ch.removeAllChangesBefore(boundRev) + return changes, nil +} + +func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) ([]datastore.RevisionChanges, error) { + if ch.IsEmpty() { + return nil, nil + } + + revisionsWithChanges := make([]K, 0, len(ch.records)) + for rk, cr := range ch.records { + if !withBound || boundRev.GreaterThan(cr.rev) { + revisionsWithChanges = append(revisionsWithChanges, rk) + } + } + + if len(revisionsWithChanges) == 0 { + return nil, nil + } + + sort.Slice(revisionsWithChanges, func(i int, j int) bool { + return lessThanFunc(revisionsWithChanges[i], revisionsWithChanges[j]) + }) + + changes := make([]datastore.RevisionChanges, len(revisionsWithChanges)) + for i, k := range revisionsWithChanges { + revisionChangeRecord := ch.records[k] + changes[i].Revision = revisionChangeRecord.rev + for _, rel := range revisionChangeRecord.relTouches { + changes[i].RelationshipChanges = append(changes[i].RelationshipChanges, tuple.Touch(rel)) + } + for _, rel := range revisionChangeRecord.relDeletes { + changes[i].RelationshipChanges = append(changes[i].RelationshipChanges, tuple.Delete(rel)) + } + changes[i].ChangedDefinitions = maps.Values(revisionChangeRecord.definitionsChanged) + changes[i].DeletedNamespaces = maps.Keys(revisionChangeRecord.namespacesDeleted) + changes[i].DeletedCaveats = maps.Keys(revisionChangeRecord.caveatsDeleted) + + if len(revisionChangeRecord.metadata) > 0 { + metadata, err := structpb.NewStruct(revisionChangeRecord.metadata) + if err != nil { + return nil, spiceerrors.MustBugf("failed to convert metadata to structpb: %v", err) + } + + changes[i].Metadata = metadata + } + } + + return changes, nil +} + +func (ch *Changes[R, K]) removeAllChangesBefore(boundRev R) { + for rk, cr := range ch.records { + if boundRev.GreaterThan(cr.rev) { + delete(ch.records, rk) + } + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go new file mode 100644 index 0000000..af0b229 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go @@ -0,0 +1,154 @@ +package common + +import ( + "context" + "errors" + "fmt" + "regexp" + "strings" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// SerializationError is returned when there's been a serialization +// error while performing a datastore operation +type SerializationError struct { + error +} + +func (err SerializationError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.Aborted, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_SERIALIZATION_FAILURE, + map[string]string{}, + ), + ) +} + +func (err SerializationError) Unwrap() error { + return err.error +} + +// NewSerializationError creates a new SerializationError +func NewSerializationError(err error) error { + return SerializationError{err} +} + +// ReadOnlyTransactionError is returned when an otherwise read-write +// transaction fails on writes with an error indicating that the datastore +// is currently in a read-only mode. +type ReadOnlyTransactionError struct { + error +} + +func (err ReadOnlyTransactionError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.Aborted, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_SERVICE_READ_ONLY, + map[string]string{}, + ), + ) +} + +// NewReadOnlyTransactionError creates a new ReadOnlyTransactionError. +func NewReadOnlyTransactionError(err error) error { + return ReadOnlyTransactionError{ + fmt.Errorf("could not perform write operation, as the datastore is currently in read-only mode: %w. This may indicate that the datastore has been put into maintenance mode", err), + } +} + +// CreateRelationshipExistsError is an error returned when attempting to CREATE an already-existing +// relationship. +type CreateRelationshipExistsError struct { + error + + // Relationship is the relationship that caused the error. May be nil, depending on the datastore. + Relationship *tuple.Relationship +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err CreateRelationshipExistsError) GRPCStatus() *status.Status { + if err.Relationship == nil { + return spiceerrors.WithCodeAndDetails( + err, + codes.AlreadyExists, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_ATTEMPT_TO_RECREATE_RELATIONSHIP, + map[string]string{}, + ), + ) + } + + relationship := tuple.ToV1Relationship(*err.Relationship) + return spiceerrors.WithCodeAndDetails( + err, + codes.AlreadyExists, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_ATTEMPT_TO_RECREATE_RELATIONSHIP, + map[string]string{ + "relationship": tuple.V1StringRelationshipWithoutCaveatOrExpiration(relationship), + "resource_type": relationship.Resource.ObjectType, + "resource_object_id": relationship.Resource.ObjectId, + "resource_relation": relationship.Relation, + "subject_type": relationship.Subject.Object.ObjectType, + "subject_object_id": relationship.Subject.Object.ObjectId, + "subject_relation": relationship.Subject.OptionalRelation, + }, + ), + ) +} + +// NewCreateRelationshipExistsError creates a new CreateRelationshipExistsError. +func NewCreateRelationshipExistsError(relationship *tuple.Relationship) error { + msg := "could not CREATE one or more relationships, as they already existed. If this is persistent, please switch to TOUCH operations or specify a precondition" + if relationship != nil { + msg = fmt.Sprintf("could not CREATE relationship `%s`, as it already existed. If this is persistent, please switch to TOUCH operations or specify a precondition", tuple.StringWithoutCaveatOrExpiration(*relationship)) + } + + return CreateRelationshipExistsError{ + errors.New(msg), + relationship, + } +} + +var ( + portMatchRegex = regexp.MustCompile("invalid port \\\"(.+)\\\" after host") + parseMatchRegex = regexp.MustCompile("parse \\\"(.+)\\\":") +) + +// RedactAndLogSensitiveConnString elides the given error, logging it only at trace +// level (after being redacted). +func RedactAndLogSensitiveConnString(ctx context.Context, baseErr string, err error, pgURL string) error { + if err == nil { + return errors.New(baseErr) + } + + // See: https://github.com/jackc/pgx/issues/1271 + filtered := err.Error() + filtered = strings.ReplaceAll(filtered, pgURL, "(redacted)") + filtered = portMatchRegex.ReplaceAllString(filtered, "(redacted)") + filtered = parseMatchRegex.ReplaceAllString(filtered, "(redacted)") + log.Ctx(ctx).Trace().Msg(baseErr + ": " + filtered) + return fmt.Errorf("%s. To view details of this error (that may contain sensitive information), please run with --log-level=trace", baseErr) +} + +// RevisionUnavailableError is returned when a revision is not available on a replica. +type RevisionUnavailableError struct { + error +} + +// NewRevisionUnavailableError creates a new RevisionUnavailableError. +func NewRevisionUnavailableError(err error) error { + return RevisionUnavailableError{err} +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go new file mode 100644 index 0000000..5788134 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go @@ -0,0 +1,269 @@ +package common + +import ( + "context" + "fmt" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" +) + +var ( + gcDurationHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_duration_seconds", + Help: "The duration of datastore garbage collection.", + Buckets: []float64{0.01, 0.1, 0.5, 1, 5, 10, 25, 60, 120}, + }) + + gcRelationshipsCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_relationships_total", + Help: "The number of stale relationships deleted by the datastore garbage collection.", + }) + + gcExpiredRelationshipsCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_expired_relationships_total", + Help: "The number of expired relationships deleted by the datastore garbage collection.", + }) + + gcTransactionsCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_transactions_total", + Help: "The number of stale transactions deleted by the datastore garbage collection.", + }) + + gcNamespacesCounter = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_namespaces_total", + Help: "The number of stale namespaces deleted by the datastore garbage collection.", + }) + + gcFailureCounterConfig = prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "datastore", + Name: "gc_failure_total", + Help: "The number of failed runs of the datastore garbage collection.", + } + gcFailureCounter = prometheus.NewCounter(gcFailureCounterConfig) +) + +// RegisterGCMetrics registers garbage collection metrics to the default +// registry. +func RegisterGCMetrics() error { + for _, metric := range []prometheus.Collector{ + gcDurationHistogram, + gcRelationshipsCounter, + gcTransactionsCounter, + gcNamespacesCounter, + gcFailureCounter, + } { + if err := prometheus.Register(metric); err != nil { + return err + } + } + + return nil +} + +// GarbageCollector represents any datastore that supports external garbage +// collection. +type GarbageCollector interface { + // HasGCRun returns true if a garbage collection run has been completed. + HasGCRun() bool + + // MarkGCCompleted marks that a garbage collection run has been completed. + MarkGCCompleted() + + // ResetGCCompleted resets the state of the garbage collection run. + ResetGCCompleted() + + // LockForGCRun attempts to acquire a lock for garbage collection. This lock + // is typically done at the datastore level, to ensure that no other nodes are + // running garbage collection at the same time. + LockForGCRun(ctx context.Context) (bool, error) + + // UnlockAfterGCRun releases the lock after a garbage collection run. + // NOTE: this method does not take a context, as the context used for the + // reset of the GC run can be canceled/timed out and the unlock will still need to happen. + UnlockAfterGCRun() error + + // ReadyState returns the current state of the datastore. + ReadyState(context.Context) (datastore.ReadyState, error) + + // Now returns the current time from the datastore. + Now(context.Context) (time.Time, error) + + // TxIDBefore returns the highest transaction ID before the provided time. + TxIDBefore(context.Context, time.Time) (datastore.Revision, error) + + // DeleteBeforeTx deletes all data before the provided transaction ID. + DeleteBeforeTx(ctx context.Context, txID datastore.Revision) (DeletionCounts, error) + + // DeleteExpiredRels deletes all relationships that have expired. + DeleteExpiredRels(ctx context.Context) (int64, error) +} + +// DeletionCounts tracks the amount of deletions that occurred when calling +// DeleteBeforeTx. +type DeletionCounts struct { + Relationships int64 + Transactions int64 + Namespaces int64 +} + +func (g DeletionCounts) MarshalZerologObject(e *zerolog.Event) { + e. + Int64("relationships", g.Relationships). + Int64("transactions", g.Transactions). + Int64("namespaces", g.Namespaces) +} + +var MaxGCInterval = 60 * time.Minute + +// StartGarbageCollector loops forever until the context is canceled and +// performs garbage collection on the provided interval. +func StartGarbageCollector(ctx context.Context, gc GarbageCollector, interval, window, timeout time.Duration) error { + return startGarbageCollectorWithMaxElapsedTime(ctx, gc, interval, window, 0, timeout, gcFailureCounter) +} + +func startGarbageCollectorWithMaxElapsedTime(ctx context.Context, gc GarbageCollector, interval, window, maxElapsedTime, timeout time.Duration, failureCounter prometheus.Counter) error { + backoffInterval := backoff.NewExponentialBackOff() + backoffInterval.InitialInterval = interval + backoffInterval.MaxInterval = max(MaxGCInterval, interval) + backoffInterval.MaxElapsedTime = maxElapsedTime + backoffInterval.Reset() + + nextInterval := interval + + log.Ctx(ctx).Info(). + Dur("interval", nextInterval). + Msg("datastore garbage collection worker started") + + for { + select { + case <-ctx.Done(): + log.Ctx(ctx).Info(). + Msg("shutting down datastore garbage collection worker") + return ctx.Err() + + case <-time.After(nextInterval): + log.Ctx(ctx).Info(). + Dur("interval", nextInterval). + Dur("window", window). + Dur("timeout", timeout). + Msg("running garbage collection worker") + + err := RunGarbageCollection(gc, window, timeout) + if err != nil { + failureCounter.Inc() + nextInterval = backoffInterval.NextBackOff() + log.Ctx(ctx).Warn().Err(err). + Dur("next-attempt-in", nextInterval). + Msg("error attempting to perform garbage collection") + continue + } + + backoffInterval.Reset() + nextInterval = interval + + log.Ctx(ctx).Debug(). + Dur("next-run-in", interval). + Msg("datastore garbage collection scheduled for next run") + } + } +} + +// RunGarbageCollection runs garbage collection for the datastore. +func RunGarbageCollection(gc GarbageCollector, window, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + ctx, span := tracer.Start(ctx, "RunGarbageCollection") + defer span.End() + + // Before attempting anything, check if the datastore is ready. + startTime := time.Now() + ready, err := gc.ReadyState(ctx) + if err != nil { + return err + } + if !ready.IsReady { + log.Ctx(ctx).Warn(). + Msgf("datastore wasn't ready when attempting garbage collection: %s", ready.Message) + return nil + } + + ok, err := gc.LockForGCRun(ctx) + if err != nil { + return fmt.Errorf("error locking for gc run: %w", err) + } + + if !ok { + log.Info(). + Msg("datastore garbage collection already in progress on another node") + return nil + } + + defer func() { + err := gc.UnlockAfterGCRun() + if err != nil { + log.Error(). + Err(err). + Msg("error unlocking after gc run") + } + }() + + now, err := gc.Now(ctx) + if err != nil { + return fmt.Errorf("error retrieving now: %w", err) + } + + watermark, err := gc.TxIDBefore(ctx, now.Add(-1*window)) + if err != nil { + return fmt.Errorf("error retrieving watermark: %w", err) + } + + collected, err := gc.DeleteBeforeTx(ctx, watermark) + + expiredRelationshipsCount, eerr := gc.DeleteExpiredRels(ctx) + + // even if an error happened, garbage would have been collected. This makes sure these are reflected even if the + // worker eventually fails or times out. + gcRelationshipsCounter.Add(float64(collected.Relationships)) + gcTransactionsCounter.Add(float64(collected.Transactions)) + gcNamespacesCounter.Add(float64(collected.Namespaces)) + gcExpiredRelationshipsCounter.Add(float64(expiredRelationshipsCount)) + collectionDuration := time.Since(startTime) + gcDurationHistogram.Observe(collectionDuration.Seconds()) + + if err != nil { + return fmt.Errorf("error deleting in gc: %w", err) + } + + if eerr != nil { + return fmt.Errorf("error deleting expired relationships in gc: %w", eerr) + } + + log.Ctx(ctx).Info(). + Stringer("highestTxID", watermark). + Dur("duration", collectionDuration). + Time("nowTime", now). + Interface("collected", collected). + Int64("expiredRelationships", expiredRelationshipsCount). + Msg("datastore garbage collection completed successfully") + + gc.MarkGCCompleted() + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go new file mode 100644 index 0000000..8f34134 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go @@ -0,0 +1,49 @@ +package common + +import ( + "context" + "fmt" + + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +// WriteRelationships is a convenience method to perform the same update operation on a set of relationships +func WriteRelationships(ctx context.Context, ds datastore.Datastore, op tuple.UpdateOperation, rels ...tuple.Relationship) (datastore.Revision, error) { + updates := make([]tuple.RelationshipUpdate, 0, len(rels)) + for _, rel := range rels { + ru := tuple.RelationshipUpdate{ + Operation: op, + Relationship: rel, + } + updates = append(updates, ru) + } + return UpdateRelationshipsInDatastore(ctx, ds, updates...) +} + +// UpdateRelationshipsInDatastore is a convenience method to perform multiple relation update operations on a Datastore +func UpdateRelationshipsInDatastore(ctx context.Context, ds datastore.Datastore, updates ...tuple.RelationshipUpdate) (datastore.Revision, error) { + return ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.WriteRelationships(ctx, updates) + }) +} + +// ContextualizedCaveatFrom convenience method that handles creation of a contextualized caveat +// given the possibility of arguments with zero-values. +func ContextualizedCaveatFrom(name string, context map[string]any) (*core.ContextualizedCaveat, error) { + var caveat *core.ContextualizedCaveat + if name != "" { + strct, err := structpb.NewStruct(context) + if err != nil { + return nil, fmt.Errorf("malformed caveat context: %w", err) + } + caveat = &core.ContextualizedCaveat{ + CaveatName: name, + Context: strct, + } + } + return caveat, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go new file mode 100644 index 0000000..1eb64d1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go @@ -0,0 +1,28 @@ +package common + +import "github.com/authzed/spicedb/pkg/datastore/queryshape" + +// IndexDefinition is a definition of an index for a datastore. +type IndexDefinition struct { + // Name is the unique name for the index. + Name string + + // ColumnsSQL is the SQL fragment of the columns over which this index will apply. + ColumnsSQL string + + // Shapes are those query shapes for which this index should be used. + Shapes []queryshape.Shape + + // IsDeprecated is true if this index is deprecated and should not be used. + IsDeprecated bool +} + +// matchesShape returns true if the index matches the given shape. +func (id IndexDefinition) matchesShape(shape queryshape.Shape) bool { + for _, s := range id.Shapes { + if s == shape { + return true + } + } + return false +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go new file mode 100644 index 0000000..6e84549 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go @@ -0,0 +1,15 @@ +package common + +import ( + "context" + + log "github.com/authzed/spicedb/internal/logging" +) + +// LogOnError executes the function and logs the error. +// Useful to avoid silently ignoring errors in defer statements +func LogOnError(ctx context.Context, f func() error) { + if err := f(); err != nil { + log.Ctx(ctx).Err(err).Msg("datastore error") + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go new file mode 100644 index 0000000..304f62c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go @@ -0,0 +1,42 @@ +package common + +import ( + "fmt" + "slices" + "strings" + + "github.com/authzed/spicedb/pkg/datastore" +) + +type MigrationValidator struct { + additionalAllowedMigrations []string + headMigration string +} + +func NewMigrationValidator(headMigration string, additionalAllowedMigrations []string) *MigrationValidator { + return &MigrationValidator{ + additionalAllowedMigrations: additionalAllowedMigrations, + headMigration: headMigration, + } +} + +// MigrationReadyState returns the readiness of the datastore for the given version. +func (mv *MigrationValidator) MigrationReadyState(version string) datastore.ReadyState { + if version == mv.headMigration { + return datastore.ReadyState{IsReady: true} + } + if slices.Contains(mv.additionalAllowedMigrations, version) { + return datastore.ReadyState{IsReady: true} + } + var msgBuilder strings.Builder + msgBuilder.WriteString(fmt.Sprintf("datastore is not migrated: currently at revision %q, but requires %q", version, mv.headMigration)) + + if len(mv.additionalAllowedMigrations) > 0 { + msgBuilder.WriteString(fmt.Sprintf(" (additional allowed migrations: %v)", mv.additionalAllowedMigrations)) + } + msgBuilder.WriteString(". Please run \"spicedb datastore migrate\".") + return datastore.ReadyState{ + Message: msgBuilder.String(), + IsReady: false, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go new file mode 100644 index 0000000..dee0ad5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go @@ -0,0 +1,214 @@ +package common + +import ( + "context" + "database/sql" + "fmt" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/pkg/datastore" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const errUnableToQueryRels = "unable to query relationships: %w" + +// Querier is an interface for querying the database. +type Querier[R Rows] interface { + QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error +} + +// Rows is a common interface for database rows reading. +type Rows interface { + Scan(dest ...any) error + Next() bool + Err() error +} + +type closeRowsWithError interface { + Rows + Close() error +} + +type closeRows interface { + Rows + Close() +} + +func runExplainIfNecessary[R Rows](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R], explainable datastore.Explainable) error { + if builder.SQLExplainCallbackForTest == nil { + return nil + } + + // Determine the expected index names via the schema. + expectedIndexes := builder.Schema.expectedIndexesForShape(builder.queryShape) + + // Run any pre-explain statements. + for _, statement := range explainable.PreExplainStatements() { + if err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + rows.Next() + return nil + }, statement); err != nil { + return fmt.Errorf(errUnableToQueryRels, err) + } + } + + // Run the query with EXPLAIN ANALYZE. + sqlString, args, err := builder.SelectSQL() + if err != nil { + return fmt.Errorf(errUnableToQueryRels, err) + } + + explainSQL, explainArgs, err := explainable.BuildExplainQuery(sqlString, args) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, err) + } + + err = tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + explainString := "" + for rows.Next() { + var explain string + if err := rows.Scan(&explain); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) + } + explainString += explain + "\n" + } + if explainString == "" { + return fmt.Errorf("received empty explain") + } + + return builder.SQLExplainCallbackForTest(ctx, sqlString, args, builder.queryShape, explainString, expectedIndexes) + }, explainSQL, explainArgs...) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, err) + } + + return nil +} + +// QueryRelationships queries relationships for the given query and transaction. +func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R], explainable datastore.Explainable) (datastore.RelationshipIterator, error) { + span := trace.SpanFromContext(ctx) + sqlString, args, err := builder.SelectSQL() + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) + } + + if err := runExplainIfNecessary(ctx, builder, tx, explainable); err != nil { + return nil, err + } + + var resourceObjectType string + var resourceObjectID string + var resourceRelation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName sql.NullString + var caveatCtx C + var expiration *time.Time + + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + span.AddEvent("Selecting columns") + colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) + } + + span.AddEvent("Returning iterator", trace.WithAttributes(attribute.Int("column-count", len(colsToSelect)))) + return func(yield func(tuple.Relationship, error) bool) { + span.AddEvent("Issuing query to database") + err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + span.AddEvent("Query issued to database") + + var r Rows = rows + if crwe, ok := r.(closeRowsWithError); ok { + defer LogOnError(ctx, crwe.Close) + } else if cr, ok := r.(closeRows); ok { + defer cr.Close() + } + + relCount := 0 + for rows.Next() { + if relCount == 0 { + span.AddEvent("First row returned") + } + + if err := rows.Scan(colsToSelect...); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) + } + + if relCount == 0 { + span.AddEvent("First row scanned") + } + + var caveat *corev1.ContextualizedCaveat + if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + if caveatName.Valid { + var err error + caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err)) + } + } + } + + var integrity *corev1.RelationshipIntegrity + if integrityKeyID != "" { + integrity = &corev1.RelationshipIntegrity{ + KeyId: integrityKeyID, + Hash: integrityHash, + HashedAt: timestamppb.New(timestamp), + } + } + + if expiration != nil { + // Ensure the expiration is always read in UTC, since some datastores (like CRDB) + // will normalize to local time. + t := expiration.UTC() + expiration = &t + } + + relCount++ + if !yield(tuple.Relationship{ + RelationshipReference: tuple.RelationshipReference{ + Resource: tuple.ObjectAndRelation{ + ObjectType: resourceObjectType, + ObjectID: resourceObjectID, + Relation: resourceRelation, + }, + Subject: tuple.ObjectAndRelation{ + ObjectType: subjectObjectType, + ObjectID: subjectObjectID, + Relation: subjectRelation, + }, + }, + OptionalCaveat: caveat, + OptionalExpiration: expiration, + OptionalIntegrity: integrity, + }, nil) { + return nil + } + } + + span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) + if err := rows.Err(); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err)) + } + + return nil + }, sqlString, args...) + if err != nil { + if !yield(tuple.Relationship{}, err) { + return + } + } + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go new file mode 100644 index 0000000..6e44d0b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go @@ -0,0 +1,188 @@ +package common + +import ( + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +const ( + relationshipStandardColumnCount = 6 // ColNamespace, ColObjectID, ColRelation, ColUsersetNamespace, ColUsersetObjectID, ColUsersetRelation + relationshipCaveatColumnCount = 2 // ColCaveatName, ColCaveatContext + relationshipExpirationColumnCount = 1 // ColExpiration + relationshipIntegrityColumnCount = 3 // ColIntegrityKeyID, ColIntegrityHash, ColIntegrityTimestamp +) + +// SchemaInformation holds the schema information from the SQL datastore implementation. +// +//go:generate go run github.com/ecordell/optgen -output zz_generated.schema_options.go . SchemaInformation +type SchemaInformation struct { + RelationshipTableName string `debugmap:"visible"` + + ColNamespace string `debugmap:"visible"` + ColObjectID string `debugmap:"visible"` + ColRelation string `debugmap:"visible"` + ColUsersetNamespace string `debugmap:"visible"` + ColUsersetObjectID string `debugmap:"visible"` + ColUsersetRelation string `debugmap:"visible"` + + ColCaveatName string `debugmap:"visible"` + ColCaveatContext string `debugmap:"visible"` + + ColExpiration string `debugmap:"visible"` + + ColIntegrityKeyID string `debugmap:"visible"` + ColIntegrityHash string `debugmap:"visible"` + ColIntegrityTimestamp string `debugmap:"visible"` + + // Indexes are the indexes to use for this schema. + Indexes []IndexDefinition `debugmap:"visible"` + + // PaginationFilterType is the type of pagination filter to use for this schema. + PaginationFilterType PaginationFilterType `debugmap:"visible"` + + // PlaceholderFormat is the format of placeholders to use for this schema. + PlaceholderFormat sq.PlaceholderFormat `debugmap:"visible"` + + // NowFunction is the function to use to get the current time in the datastore. + NowFunction string `debugmap:"visible"` + + // ColumnOptimization is the optimization to use for columns in the schema, if any. + ColumnOptimization ColumnOptimizationOption `debugmap:"visible"` + + // IntegrityEnabled is a flag to indicate if the schema has integrity columns. + IntegrityEnabled bool `debugmap:"visible"` + + // ExpirationDisabled is a flag to indicate whether expiration support is disabled. + ExpirationDisabled bool `debugmap:"visible"` + + // SortByResourceColumnOrder is the order of the resource columns in the schema to use + // when sorting by resource. If unspecified, the default will be used. + SortByResourceColumnOrder []string `debugmap:"visible"` + + // SortBySubjectColumnOrder is the order of the subject columns in the schema to use + // when sorting by subject. If unspecified, the default will be used. + SortBySubjectColumnOrder []string `debugmap:"visible"` +} + +// expectedIndexesForShape returns the expected index names for a given query shape. +func (si SchemaInformation) expectedIndexesForShape(shape queryshape.Shape) options.SQLIndexInformation { + expectedIndexes := options.SQLIndexInformation{} + for _, index := range si.Indexes { + if index.matchesShape(shape) { + expectedIndexes.ExpectedIndexNames = append(expectedIndexes.ExpectedIndexNames, index.Name) + } + } + return expectedIndexes +} + +func (si SchemaInformation) debugValidate() { + spiceerrors.DebugAssert(func() bool { + si.mustValidate() + return true + }, "SchemaInformation failed to validate") +} + +func (si SchemaInformation) sortByResourceColumnOrderColumns() []string { + if len(si.SortByResourceColumnOrder) > 0 { + return si.SortByResourceColumnOrder + } + + return []string{ + si.ColNamespace, + si.ColObjectID, + si.ColRelation, + si.ColUsersetNamespace, + si.ColUsersetObjectID, + si.ColUsersetRelation, + } +} + +func (si SchemaInformation) sortBySubjectColumnOrderColumns() []string { + if len(si.SortBySubjectColumnOrder) > 0 { + return si.SortBySubjectColumnOrder + } + + return []string{ + si.ColUsersetNamespace, + si.ColUsersetObjectID, + si.ColUsersetRelation, + si.ColNamespace, + si.ColObjectID, + si.ColRelation, + } +} + +func (si SchemaInformation) mustValidate() { + if si.RelationshipTableName == "" { + panic("RelationshipTableName is required") + } + + if si.ColNamespace == "" { + panic("ColNamespace is required") + } + + if si.ColObjectID == "" { + panic("ColObjectID is required") + } + + if si.ColRelation == "" { + panic("ColRelation is required") + } + + if si.ColUsersetNamespace == "" { + panic("ColUsersetNamespace is required") + } + + if si.ColUsersetObjectID == "" { + panic("ColUsersetObjectID is required") + } + + if si.ColUsersetRelation == "" { + panic("ColUsersetRelation is required") + } + + if si.ColCaveatName == "" { + panic("ColCaveatName is required") + } + + if si.ColCaveatContext == "" { + panic("ColCaveatContext is required") + } + + if si.ColExpiration == "" { + panic("ColExpiration is required") + } + + if si.IntegrityEnabled { + if si.ColIntegrityKeyID == "" { + panic("ColIntegrityKeyID is required") + } + + if si.ColIntegrityHash == "" { + panic("ColIntegrityHash is required") + } + + if si.ColIntegrityTimestamp == "" { + panic("ColIntegrityTimestamp is required") + } + } + + if si.NowFunction == "" { + panic("NowFunction is required") + } + + if si.ColumnOptimization == ColumnOptimizationOptionUnknown { + panic("ColumnOptimization is required") + } + + if si.PaginationFilterType == PaginationFilterTypeUnknown { + panic("PaginationFilterType is required") + } + + if si.PlaceholderFormat == nil { + panic("PlaceholderFormat is required") + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go new file mode 100644 index 0000000..4972700 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go @@ -0,0 +1,17 @@ +package common + +import ( + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/tuple" +) + +// NewSliceRelationshipIterator creates a datastore.RelationshipIterator instance from a materialized slice of tuples. +func NewSliceRelationshipIterator(rels []tuple.Relationship) datastore.RelationshipIterator { + return func(yield func(tuple.Relationship, error) bool) { + for _, rel := range rels { + if !yield(rel, nil) { + break + } + } + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go new file mode 100644 index 0000000..ba9c4f6 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go @@ -0,0 +1,961 @@ +package common + +import ( + "context" + "fmt" + "maps" + "math" + "strings" + "time" + + sq "github.com/Masterminds/squirrel" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/jzelinskie/stringz" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +var ( + // CaveatNameKey is a tracing attribute representing a caveat name + CaveatNameKey = attribute.Key("authzed.com/spicedb/sql/caveatName") + + // ObjNamespaceNameKey is a tracing attribute representing the resource + // object type. + ObjNamespaceNameKey = attribute.Key("authzed.com/spicedb/sql/objNamespaceName") + + // ObjRelationNameKey is a tracing attribute representing the resource + // relation. + ObjRelationNameKey = attribute.Key("authzed.com/spicedb/sql/objRelationName") + + // ObjIDKey is a tracing attribute representing the resource object ID. + ObjIDKey = attribute.Key("authzed.com/spicedb/sql/objId") + + // SubNamespaceNameKey is a tracing attribute representing the subject object + // type. + SubNamespaceNameKey = attribute.Key("authzed.com/spicedb/sql/subNamespaceName") + + // SubRelationNameKey is a tracing attribute representing the subject + // relation. + SubRelationNameKey = attribute.Key("authzed.com/spicedb/sql/subRelationName") + + // SubObjectIDKey is a tracing attribute representing the the subject object + // ID. + SubObjectIDKey = attribute.Key("authzed.com/spicedb/sql/subObjectId") + + tracer = otel.Tracer("spicedb/internal/datastore/common") +) + +// PaginationFilterType is an enumerator for pagination filter types. +type PaginationFilterType uint8 + +const ( + PaginationFilterTypeUnknown PaginationFilterType = iota + + // TupleComparison uses a comparison with a compound key, + // e.g. (namespace, object_id, relation) > ('ns', '123', 'viewer') + // which is not compatible with all datastores. + TupleComparison = 1 + + // ExpandedLogicComparison comparison uses a nested tree of ANDs and ORs to properly + // filter out already received relationships. Useful for databases that do not support + // tuple comparison, or do not execute it efficiently + ExpandedLogicComparison = 2 +) + +// ColumnOptimizationOption is an enumerator for column optimization options. +type ColumnOptimizationOption int + +const ( + ColumnOptimizationOptionUnknown ColumnOptimizationOption = iota + + // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns. + ColumnOptimizationOptionNone + + // ColumnOptimizationOptionStaticValues is an option that optimizes columns for static values. + ColumnOptimizationOptionStaticValues +) + +type columnTracker struct { + SingleValue *string +} + +type columnTrackerMap map[string]columnTracker + +func (ctm columnTrackerMap) hasStaticValue(columnName string) bool { + if r, ok := ctm[columnName]; ok && r.SingleValue != nil { + return true + } + return false +} + +// SchemaQueryFilterer wraps a SchemaInformation and SelectBuilder to give an opinionated +// way to build query objects. +type SchemaQueryFilterer struct { + schema SchemaInformation + queryBuilder sq.SelectBuilder + filteringColumnTracker columnTrackerMap + filterMaximumIDCount uint16 + isCustomQuery bool + extraFields []string + fromSuffix string + fromTable string + indexingHint IndexingHint +} + +// IndexingHint is an interface that can be implemented to provide a hint for the SQL query. +type IndexingHint interface { + // SQLPrefix returns the SQL prefix to be used for the indexing hint, if any. + SQLPrefix() (string, error) + + // FromTable returns the table name to be used for the indexing hint, if any. + FromTable(existingTableName string) (string, error) + + // FromSQLSuffix returns the suffix to be used for the indexing hint, if any. + FromSQLSuffix() (string, error) +} + +// NewSchemaQueryFiltererForRelationshipsSelect creates a new SchemaQueryFilterer object for selecting +// relationships. This method will automatically filter the columns retrieved from the database, only +// selecting the columns that are not already specified with a single static value in the query. +func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer { + schema.debugValidate() + + if filterMaximumIDCount == 0 { + filterMaximumIDCount = 100 + log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") + } + + queryBuilder := sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFormat).Select() + return SchemaQueryFilterer{ + schema: schema, + queryBuilder: queryBuilder, + filteringColumnTracker: map[string]columnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: false, + extraFields: extraFields, + fromTable: "", + } +} + +// NewSchemaQueryFiltererWithStartingQuery creates a new SchemaQueryFilterer object for selecting +// relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect, +// this method will not auto-filter the columns retrieved from the database. +func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { + schema.debugValidate() + + if filterMaximumIDCount == 0 { + filterMaximumIDCount = 100 + log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") + } + + return SchemaQueryFilterer{ + schema: schema, + queryBuilder: startingQuery, + filteringColumnTracker: map[string]columnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: true, + extraFields: nil, + fromTable: "", + } +} + +// WithAdditionalFilter returns the SchemaQueryFilterer with an additional filter applied to the query. +func (sqf SchemaQueryFilterer) WithAdditionalFilter(filter func(original sq.SelectBuilder) sq.SelectBuilder) SchemaQueryFilterer { + sqf.queryBuilder = filter(sqf.queryBuilder) + return sqf +} + +// WithFromTable returns the SchemaQueryFilterer with a custom FROM table. +func (sqf SchemaQueryFilterer) WithFromTable(fromTable string) SchemaQueryFilterer { + sqf.fromTable = fromTable + return sqf +} + +// WithFromSuffix returns the SchemaQueryFilterer with a suffix added to the FROM clause. +func (sqf SchemaQueryFilterer) WithFromSuffix(fromSuffix string) SchemaQueryFilterer { + sqf.fromSuffix = fromSuffix + return sqf +} + +// WithIndexingHint returns the SchemaQueryFilterer with an indexing hint applied to the query. +func (sqf SchemaQueryFilterer) WithIndexingHint(indexingHint IndexingHint) SchemaQueryFilterer { + sqf.indexingHint = indexingHint + return sqf +} + +func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { + spiceerrors.DebugAssert(func() bool { + return sqf.isCustomQuery + }, "UnderlyingQueryBuilder should only be called on custom queries") + return sqf.queryBuilderWithMaybeExpirationFilter(false) +} + +// queryBuilderWithMaybeExpirationFilter returns the query builder with the expiration filter applied, when necessary. +// Note that this adds the clause to the existing builder. +func (sqf SchemaQueryFilterer) queryBuilderWithMaybeExpirationFilter(skipExpiration bool) sq.SelectBuilder { + if sqf.schema.ExpirationDisabled || skipExpiration { + return sqf.queryBuilder + } + + // Filter out any expired relationships. + return sqf.queryBuilder.Where(sq.Or{ + sq.Eq{sqf.schema.ColExpiration: nil}, + sq.Expr(sqf.schema.ColExpiration + " > " + sqf.schema.NowFunction + "()"), + }) +} + +func (sqf SchemaQueryFilterer) TupleOrder(order options.SortOrder) SchemaQueryFilterer { + switch order { + case options.ByResource: + sqf.queryBuilder = sqf.queryBuilder.OrderBy(sqf.schema.sortByResourceColumnOrderColumns()...) + + case options.BySubject: + sqf.queryBuilder = sqf.queryBuilder.OrderBy(sqf.schema.sortBySubjectColumnOrderColumns()...) + } + + return sqf +} + +type nameAndValue struct { + name string + value string +} + +func columnsAndValuesForSort( + order options.SortOrder, + schema SchemaInformation, + cursor options.Cursor, +) ([]nameAndValue, error) { + var columnNames []string + + switch order { + case options.ByResource: + columnNames = schema.sortByResourceColumnOrderColumns() + + case options.BySubject: + columnNames = schema.sortBySubjectColumnOrderColumns() + + default: + return nil, spiceerrors.MustBugf("invalid sort order %q", order) + } + + nameAndValues := make([]nameAndValue, 0, len(columnNames)) + for _, columnName := range columnNames { + switch columnName { + case schema.ColNamespace: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Resource.ObjectType, + }) + + case schema.ColObjectID: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Resource.ObjectID, + }) + + case schema.ColRelation: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Resource.Relation, + }) + + case schema.ColUsersetNamespace: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Subject.ObjectType, + }) + + case schema.ColUsersetObjectID: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Subject.ObjectID, + }) + + case schema.ColUsersetRelation: + nameAndValues = append(nameAndValues, nameAndValue{ + name: columnName, + value: cursor.Subject.Relation, + }) + + default: + return nil, spiceerrors.MustBugf("invalid column name %q", columnName) + } + } + + return nameAndValues, nil +} + +func (sqf SchemaQueryFilterer) MustAfter(cursor options.Cursor, order options.SortOrder) SchemaQueryFilterer { + updated, err := sqf.After(cursor, order) + if err != nil { + panic(err) + } + return updated +} + +func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOrder) (SchemaQueryFilterer, error) { + spiceerrors.DebugAssertNotNil(cursor, "cursor cannot be nil") + + // NOTE: The ordering of these columns can affect query performance, be aware when changing. + columnsAndValues, err := columnsAndValuesForSort(order, sqf.schema, cursor) + if err != nil { + return sqf, err + } + + switch sqf.schema.PaginationFilterType { + case TupleComparison: + // For performance reasons, remove any column names that have static values in the query. + columnNames := make([]string, 0, len(columnsAndValues)) + valueSlots := make([]any, 0, len(columnsAndValues)) + comparisonSlotCount := 0 + + for _, cav := range columnsAndValues { + if !sqf.filteringColumnTracker.hasStaticValue(cav.name) { + columnNames = append(columnNames, cav.name) + valueSlots = append(valueSlots, cav.value) + comparisonSlotCount++ + } + } + + if comparisonSlotCount > 0 { + comparisonTuple := "(" + strings.Join(columnNames, ",") + ") > (" + strings.Repeat(",?", comparisonSlotCount)[1:] + ")" + sqf.queryBuilder = sqf.queryBuilder.Where( + comparisonTuple, + valueSlots..., + ) + } + + case ExpandedLogicComparison: + // For performance reasons, remove any column names that have static values in the query. + orClause := sq.Or{} + + for index, cav := range columnsAndValues { + if !sqf.filteringColumnTracker.hasStaticValue(cav.name) { + andClause := sq.And{} + for _, previous := range columnsAndValues[0:index] { + if !sqf.filteringColumnTracker.hasStaticValue(previous.name) { + andClause = append(andClause, sq.Eq{previous.name: previous.value}) + } + } + + andClause = append(andClause, sq.Gt{cav.name: cav.value}) + orClause = append(orClause, andClause) + } + } + + if len(orClause) > 0 { + sqf.queryBuilder = sqf.queryBuilder.Where(orClause) + } + } + + return sqf, nil +} + +// FilterToResourceType returns a new SchemaQueryFilterer that is limited to resources of the +// specified type. +func (sqf SchemaQueryFilterer) FilterToResourceType(resourceType string) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColNamespace: resourceType}) + sqf.recordColumnValue(sqf.schema.ColNamespace, resourceType) + return sqf +} + +func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string) { + existing, ok := sqf.filteringColumnTracker[colName] + if ok { + if existing.SingleValue != nil && *existing.SingleValue != colValue { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} + } + } else { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: &colValue} + } +} + +func (sqf SchemaQueryFilterer) recordVaryingColumnValue(colName string) { + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} +} + +// FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the +// specified ID. +func (sqf SchemaQueryFilterer) FilterToResourceID(objectID string) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColObjectID: objectID}) + sqf.recordColumnValue(sqf.schema.ColObjectID, objectID) + return sqf +} + +func (sqf SchemaQueryFilterer) MustFilterToResourceIDs(resourceIds []string) SchemaQueryFilterer { + updated, err := sqf.FilterToResourceIDs(resourceIds) + if err != nil { + panic(err) + } + return updated +} + +// FilterWithResourceIDPrefix returns new SchemaQueryFilterer that is limited to resources whose ID +// starts with the specified prefix. +func (sqf SchemaQueryFilterer) FilterWithResourceIDPrefix(prefix string) (SchemaQueryFilterer, error) { + if strings.Contains(prefix, "%") { + return sqf, spiceerrors.MustBugf("prefix cannot contain the percent sign") + } + if prefix == "" { + return sqf, spiceerrors.MustBugf("prefix cannot be empty") + } + + prefix = strings.ReplaceAll(prefix, `\`, `\\`) + prefix = strings.ReplaceAll(prefix, "_", `\_`) + + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.ColObjectID: prefix + "%"}) + + // NOTE: we do *not* record the use of the resource ID column here, because it is not used + // statically and thus is necessary for sorting operations. + return sqf, nil +} + +func (sqf SchemaQueryFilterer) MustFilterWithResourceIDPrefix(prefix string) SchemaQueryFilterer { + updated, err := sqf.FilterWithResourceIDPrefix(prefix) + if err != nil { + panic(err) + } + return updated +} + +// FilterToResourceIDs returns a new SchemaQueryFilterer that is limited to resources with any of the +// specified IDs. +func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (SchemaQueryFilterer, error) { + spiceerrors.DebugAssert(func() bool { + return len(resourceIds) <= int(sqf.filterMaximumIDCount) + }, "cannot have more than %d resource IDs in a single filter", sqf.filterMaximumIDCount) + + var builder strings.Builder + builder.WriteString(sqf.schema.ColObjectID) + builder.WriteString(" IN (") + args := make([]any, 0, len(resourceIds)) + + for _, resourceID := range resourceIds { + if len(resourceID) == 0 { + return sqf, spiceerrors.MustBugf("got empty resource ID") + } + + args = append(args, resourceID) + sqf.recordColumnValue(sqf.schema.ColObjectID, resourceID) + } + + builder.WriteString("?") + if len(resourceIds) > 1 { + builder.WriteString(strings.Repeat(",?", len(resourceIds)-1)) + } + builder.WriteString(")") + + sqf.queryBuilder = sqf.queryBuilder.Where(builder.String(), args...) + return sqf, nil +} + +// FilterToRelation returns a new SchemaQueryFilterer that is limited to resources with the +// specified relation. +func (sqf SchemaQueryFilterer) FilterToRelation(relation string) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColRelation: relation}) + sqf.recordColumnValue(sqf.schema.ColRelation, relation) + return sqf +} + +// MustFilterWithRelationshipsFilter returns a new SchemaQueryFilterer that is limited to resources with +// resources that match the specified filter. +func (sqf SchemaQueryFilterer) MustFilterWithRelationshipsFilter(filter datastore.RelationshipsFilter) SchemaQueryFilterer { + updated, err := sqf.FilterWithRelationshipsFilter(filter) + if err != nil { + panic(err) + } + return updated +} + +func (sqf SchemaQueryFilterer) FilterWithRelationshipsFilter(filter datastore.RelationshipsFilter) (SchemaQueryFilterer, error) { + csqf := sqf + + if filter.OptionalResourceType != "" { + csqf = csqf.FilterToResourceType(filter.OptionalResourceType) + } + + if filter.OptionalResourceRelation != "" { + csqf = csqf.FilterToRelation(filter.OptionalResourceRelation) + } + + if len(filter.OptionalResourceIds) > 0 && filter.OptionalResourceIDPrefix != "" { + return csqf, spiceerrors.MustBugf("cannot filter by both resource IDs and ID prefix") + } + + if len(filter.OptionalResourceIds) > 0 { + usqf, err := csqf.FilterToResourceIDs(filter.OptionalResourceIds) + if err != nil { + return csqf, err + } + csqf = usqf + } + + if len(filter.OptionalResourceIDPrefix) > 0 { + usqf, err := csqf.FilterWithResourceIDPrefix(filter.OptionalResourceIDPrefix) + if err != nil { + return csqf, err + } + csqf = usqf + } + + if len(filter.OptionalSubjectsSelectors) > 0 { + usqf, err := csqf.FilterWithSubjectsSelectors(filter.OptionalSubjectsSelectors...) + if err != nil { + return csqf, err + } + csqf = usqf + } + + switch filter.OptionalCaveatNameFilter.Option { + case datastore.CaveatFilterOptionHasMatchingCaveat: + spiceerrors.DebugAssert(func() bool { + return filter.OptionalCaveatNameFilter.CaveatName != "" + }, "caveat name must be set when using HasMatchingCaveat") + csqf = csqf.FilterWithCaveatName(filter.OptionalCaveatNameFilter.CaveatName) + + case datastore.CaveatFilterOptionNoCaveat: + csqf = csqf.FilterWithNoCaveat() + + case datastore.CaveatFilterOptionNone: + // No action needed, as this is the default behavior. + + default: + return csqf, spiceerrors.MustBugf("unknown caveat filter option: %v", filter.OptionalCaveatNameFilter.Option) + } + + if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionHasExpiration { + csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.ColExpiration: nil}) + spiceerrors.DebugAssert(func() bool { return !sqf.schema.ExpirationDisabled }, "expiration filter requested but schema does not support expiration") + } else if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionNoExpiration { + csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.ColExpiration: nil}) + } + + return csqf, nil +} + +// MustFilterWithSubjectsSelectors returns a new SchemaQueryFilterer that is limited to resources with +// subjects that match the specified selector(s). +func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) SchemaQueryFilterer { + usqf, err := sqf.FilterWithSubjectsSelectors(selectors...) + if err != nil { + panic(err) + } + return usqf +} + +// FilterWithSubjectsSelectors returns a new SchemaQueryFilterer that is limited to resources with +// subjects that match the specified selector(s). +func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) { + selectorsOrClause := sq.Or{} + + // If there is more than a single filter, record all the subjects as varying, as the subjects returned + // can differ for each branch. + // TODO(jschorr): Optimize this further where applicable. + if len(selectors) > 1 { + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetNamespace) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetObjectID) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation) + } + + for _, selector := range selectors { + selectorClause := sq.And{} + + if len(selector.OptionalSubjectType) > 0 { + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetNamespace: selector.OptionalSubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, selector.OptionalSubjectType) + } + + if len(selector.OptionalSubjectIds) > 0 { + spiceerrors.DebugAssert(func() bool { + return len(selector.OptionalSubjectIds) <= int(sqf.filterMaximumIDCount) + }, "cannot have more than %d subject IDs in a single filter", sqf.filterMaximumIDCount) + + var builder strings.Builder + builder.WriteString(sqf.schema.ColUsersetObjectID) + builder.WriteString(" IN (") + args := make([]any, 0, len(selector.OptionalSubjectIds)) + + for _, subjectID := range selector.OptionalSubjectIds { + if len(subjectID) == 0 { + return sqf, spiceerrors.MustBugf("got empty subject ID") + } + + args = append(args, subjectID) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, subjectID) + } + + builder.WriteString("?") + if len(selector.OptionalSubjectIds) > 1 { + builder.WriteString(strings.Repeat(",?", len(selector.OptionalSubjectIds)-1)) + } + + builder.WriteString(")") + selectorClause = append(selectorClause, sq.Expr(builder.String(), args...)) + } + + if !selector.RelationFilter.IsEmpty() { + if selector.RelationFilter.OnlyNonEllipsisRelations { + selectorClause = append(selectorClause, sq.NotEq{sqf.schema.ColUsersetRelation: datastore.Ellipsis}) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation) + } else { + relations := make([]string, 0, 2) + if selector.RelationFilter.IncludeEllipsisRelation { + relations = append(relations, datastore.Ellipsis) + } + + if selector.RelationFilter.NonEllipsisRelation != "" { + relations = append(relations, selector.RelationFilter.NonEllipsisRelation) + } + + if len(relations) == 1 { + relName := relations[0] + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetRelation: relName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, relName) + } else { + orClause := sq.Or{} + for _, relationName := range relations { + dsRelationName := stringz.DefaultEmpty(relationName, datastore.Ellipsis) + orClause = append(orClause, sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, dsRelationName) + } + + selectorClause = append(selectorClause, orClause) + } + } + } + + selectorsOrClause = append(selectorsOrClause, selectorClause) + } + + sqf.queryBuilder = sqf.queryBuilder.Where(selectorsOrClause) + return sqf, nil +} + +// FilterToSubjectFilter returns a new SchemaQueryFilterer that is limited to resources with +// subjects that match the specified filter. +func (sqf SchemaQueryFilterer) FilterToSubjectFilter(filter *v1.SubjectFilter) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetNamespace: filter.SubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, filter.SubjectType) + + if filter.OptionalSubjectId != "" { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetObjectID: filter.OptionalSubjectId}) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, filter.OptionalSubjectId) + } + + if filter.OptionalRelation != nil { + dsRelationName := stringz.DefaultEmpty(filter.OptionalRelation.Relation, datastore.Ellipsis) + + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, datastore.Ellipsis) + } + + return sqf +} + +// FilterWithCaveatName returns a new SchemaQueryFilterer that is limited to resources with the +// specified caveat name. +func (sqf SchemaQueryFilterer) FilterWithCaveatName(caveatName string) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColCaveatName: caveatName}) + sqf.recordColumnValue(sqf.schema.ColCaveatName, caveatName) + return sqf +} + +// FilterWithNoCaveat returns a new SchemaQueryFilterer that is limited to resources with no caveat. +func (sqf SchemaQueryFilterer) FilterWithNoCaveat() SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Where( + sq.Or{ + sq.Eq{sqf.schema.ColCaveatName: nil}, + sq.Eq{sqf.schema.ColCaveatName: ""}, + }) + sqf.recordVaryingColumnValue(sqf.schema.ColCaveatName) + return sqf +} + +// Limit returns a new SchemaQueryFilterer which is limited to the specified number of results. +func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer { + sqf.queryBuilder = sqf.queryBuilder.Limit(limit) + return sqf +} + +// QueryRelationshipsExecutor is a relationships query runner shared by SQL implementations of the datastore. +type QueryRelationshipsExecutor struct { + Executor ExecuteReadRelsQueryFunc +} + +// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. +type ExecuteReadRelsQueryFunc func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) + +// ExecuteQuery executes the query. +func (exc QueryRelationshipsExecutor) ExecuteQuery( + ctx context.Context, + query SchemaQueryFilterer, + opts ...options.QueryOptionsOption, +) (datastore.RelationshipIterator, error) { + if query.isCustomQuery { + return nil, spiceerrors.MustBugf("ExecuteQuery should not be called on custom queries") + } + + queryOpts := options.NewQueryOptionsWithOptions(opts...) + + // Add sort order. + query = query.TupleOrder(queryOpts.Sort) + + // Add cursor. + if queryOpts.After != nil { + if queryOpts.Sort == options.Unsorted { + return nil, datastore.ErrCursorsWithoutSorting + } + + q, err := query.After(queryOpts.After, queryOpts.Sort) + if err != nil { + return nil, err + } + query = q + } + + // Add limit. + var limit uint64 + // NOTE: we use a uint here because it lines up with the + // assignments in this function, but we set it to MaxInt64 + // because that's the biggest value that postgres and friends + // treat as valid. + limit = math.MaxInt64 + if queryOpts.Limit != nil { + limit = *queryOpts.Limit + } + + if limit < math.MaxInt64 { + query = query.limit(limit) + } + + // Add FROM clause. + from := query.schema.RelationshipTableName + if query.fromTable != "" { + from = query.fromTable + } + + // Add index hints, if any. + if query.indexingHint != nil { + // Check for a SQL prefix (pg_hint_plan). + sqlPrefix, err := query.indexingHint.SQLPrefix() + if err != nil { + return nil, fmt.Errorf("error getting SQL prefix for indexing hint: %w", err) + } + + if sqlPrefix != "" { + query.queryBuilder = query.queryBuilder.Prefix(sqlPrefix) + } + + // Check for an adjusting FROM table name (CRDB). + fromTableName, err := query.indexingHint.FromTable(from) + if err != nil { + return nil, fmt.Errorf("error getting FROM table name for indexing hint: %w", err) + } + from = fromTableName + + // Check for a SQL suffix (MySQL, Spanner). + fromSuffix, err := query.indexingHint.FromSQLSuffix() + if err != nil { + return nil, fmt.Errorf("error getting SQL suffix for indexing hint: %w", err) + } + + if fromSuffix != "" { + from += " " + fromSuffix + } + } + + if query.fromSuffix != "" { + from += " " + query.fromSuffix + } + + query.queryBuilder = query.queryBuilder.From(from) + + builder := RelationshipsQueryBuilder{ + Schema: query.schema, + SkipCaveats: queryOpts.SkipCaveats, + SkipExpiration: queryOpts.SkipExpiration, + SQLCheckAssertionForTest: queryOpts.SQLCheckAssertionForTest, + SQLExplainCallbackForTest: queryOpts.SQLExplainCallbackForTest, + filteringValues: query.filteringColumnTracker, + queryShape: queryOpts.QueryShape, + baseQueryBuilder: query, + } + + return exc.Executor(ctx, builder) +} + +// RelationshipsQueryBuilder is a builder for producing the SQL and arguments necessary for reading +// relationships. +type RelationshipsQueryBuilder struct { + Schema SchemaInformation + SkipCaveats bool + SkipExpiration bool + + filteringValues columnTrackerMap + baseQueryBuilder SchemaQueryFilterer + SQLCheckAssertionForTest options.SQLCheckAssertionForTest + SQLExplainCallbackForTest options.SQLExplainCallbackForTest + queryShape queryshape.Shape +} + +// withCaveats returns true if caveats should be included in the query. +func (b RelationshipsQueryBuilder) withCaveats() bool { + return !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone +} + +// withExpiration returns true if expiration should be included in the query. +func (b RelationshipsQueryBuilder) withExpiration() bool { + return !b.SkipExpiration && !b.Schema.ExpirationDisabled +} + +// integrityEnabled returns true if integrity columns should be included in the query. +func (b RelationshipsQueryBuilder) integrityEnabled() bool { + return b.Schema.IntegrityEnabled +} + +// columnCount returns the number of columns that will be selected in the query. +func (b RelationshipsQueryBuilder) columnCount() int { + columnCount := relationshipStandardColumnCount + if b.withCaveats() { + columnCount += relationshipCaveatColumnCount + } + if b.withExpiration() { + columnCount += relationshipExpirationColumnCount + } + if b.integrityEnabled() { + columnCount += relationshipIntegrityColumnCount + } + return columnCount +} + +// SelectSQL returns the SQL and arguments necessary for reading relationships. +func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { + // Set the column names to select. + columnNamesToSelect := make([]string, 0, b.columnCount()) + + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColRelation) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetRelation) + + if b.withCaveats() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext) + } + + if b.withExpiration() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) + } + + if b.integrityEnabled() { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp) + } + + if len(columnNamesToSelect) == 0 { + columnNamesToSelect = append(columnNamesToSelect, "1") + } + + sqlBuilder := b.baseQueryBuilder.queryBuilderWithMaybeExpirationFilter(b.SkipExpiration) + sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) + + sql, args, err := sqlBuilder.ToSql() + if err != nil { + return "", nil, err + } + + if b.SQLCheckAssertionForTest != nil { + b.SQLCheckAssertionForTest(sql) + } + + return sql, args, nil +} + +// FilteringValuesForTesting returns the filtering values. For test use only. +func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]columnTracker { + return maps.Clone(b.filteringValues) +} + +func (b RelationshipsQueryBuilder) checkColumn(columns []string, colName string) []string { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + return append(columns, colName) + } + + if !b.filteringValues.hasStaticValue(colName) { + return append(columns, colName) + } + + return columns +} + +func (b RelationshipsQueryBuilder) staticValueOrAddColumnForSelect(colsToSelect []any, colName string, field *string) []any { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + // If column optimization is disabled, always add the column to the list of columns to select. + colsToSelect = append(colsToSelect, field) + return colsToSelect + } + + // If the value is static, set the field to it and return. + if found, ok := b.filteringValues[colName]; ok && found.SingleValue != nil { + *field = *found.SingleValue + return colsToSelect + } + + // Otherwise, add the column to the list of columns to select, as the value is not static. + colsToSelect = append(colsToSelect, field) + return colsToSelect +} + +// ColumnsToSelect returns the columns to select for a given query. The columns provided are +// the references to the slots in which the values for each relationship will be placed. +func ColumnsToSelect[CN any, CC any, EC any]( + b RelationshipsQueryBuilder, + resourceObjectType *string, + resourceObjectID *string, + resourceRelation *string, + subjectObjectType *string, + subjectObjectID *string, + subjectRelation *string, + caveatName *CN, + caveatCtx *CC, + expiration EC, + + integrityKeyID *string, + integrityHash *[]byte, + timestamp *time.Time, +) ([]any, error) { + colsToSelect := make([]any, 0, b.columnCount()) + + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColNamespace, resourceObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColObjectID, resourceObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColRelation, resourceRelation) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetNamespace, subjectObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetObjectID, subjectObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetRelation, subjectRelation) + + if b.withCaveats() { + colsToSelect = append(colsToSelect, caveatName, caveatCtx) + } + + if b.withExpiration() { + colsToSelect = append(colsToSelect, expiration) + } + + if b.Schema.IntegrityEnabled { + colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp) + } + + if len(colsToSelect) == 0 { + var unused int64 + colsToSelect = append(colsToSelect, &unused) + } + + return colsToSelect, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go new file mode 100644 index 0000000..fa23efc --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go @@ -0,0 +1,31 @@ +package common + +import ( + "context" + "errors" + "strings" +) + +// IsCancellationError determines if an error returned by pgx has been caused by context cancellation. +func IsCancellationError(err error) bool { + if errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) || + err.Error() == "conn closed" { // conns are sometimes closed async upon cancellation + return true + } + return false +} + +// IsResettableError returns whether the given error is a resettable error. +func IsResettableError(err error) bool { + // detect when an error is likely due to a node taken out of service + if strings.Contains(err.Error(), "broken pipe") || + strings.Contains(err.Error(), "unexpected EOF") || + strings.Contains(err.Error(), "conn closed") || + strings.Contains(err.Error(), "connection refused") || + strings.Contains(err.Error(), "connection reset by peer") { + return true + } + + return false +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go new file mode 100644 index 0000000..be665ed --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go @@ -0,0 +1,19 @@ +package common + +import ( + "errors" + "net/url" +) + +// MetricsIDFromURL extracts the metrics ID from a given datastore URL. +func MetricsIDFromURL(dsURL string) (string, error) { + if dsURL == "" { + return "", errors.New("datastore URL is empty") + } + + u, err := url.Parse(dsURL) + if err != nil { + return "", errors.New("could not parse datastore URL") + } + return u.Host + u.Path, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go new file mode 100644 index 0000000..2caa57a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go @@ -0,0 +1,276 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package common + +import ( + squirrel "github.com/Masterminds/squirrel" + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" +) + +type SchemaInformationOption func(s *SchemaInformation) + +// NewSchemaInformationWithOptions creates a new SchemaInformation with the passed in options set +func NewSchemaInformationWithOptions(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + for _, o := range opts { + o(s) + } + return s +} + +// NewSchemaInformationWithOptionsAndDefaults creates a new SchemaInformation with the passed in options set starting from the defaults +func NewSchemaInformationWithOptionsAndDefaults(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + defaults.MustSet(s) + for _, o := range opts { + o(s) + } + return s +} + +// ToOption returns a new SchemaInformationOption that sets the values from the passed in SchemaInformation +func (s *SchemaInformation) ToOption() SchemaInformationOption { + return func(to *SchemaInformation) { + to.RelationshipTableName = s.RelationshipTableName + to.ColNamespace = s.ColNamespace + to.ColObjectID = s.ColObjectID + to.ColRelation = s.ColRelation + to.ColUsersetNamespace = s.ColUsersetNamespace + to.ColUsersetObjectID = s.ColUsersetObjectID + to.ColUsersetRelation = s.ColUsersetRelation + to.ColCaveatName = s.ColCaveatName + to.ColCaveatContext = s.ColCaveatContext + to.ColExpiration = s.ColExpiration + to.ColIntegrityKeyID = s.ColIntegrityKeyID + to.ColIntegrityHash = s.ColIntegrityHash + to.ColIntegrityTimestamp = s.ColIntegrityTimestamp + to.Indexes = s.Indexes + to.PaginationFilterType = s.PaginationFilterType + to.PlaceholderFormat = s.PlaceholderFormat + to.NowFunction = s.NowFunction + to.ColumnOptimization = s.ColumnOptimization + to.IntegrityEnabled = s.IntegrityEnabled + to.ExpirationDisabled = s.ExpirationDisabled + to.SortByResourceColumnOrder = s.SortByResourceColumnOrder + to.SortBySubjectColumnOrder = s.SortBySubjectColumnOrder + } +} + +// DebugMap returns a map form of SchemaInformation for debugging +func (s SchemaInformation) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["RelationshipTableName"] = helpers.DebugValue(s.RelationshipTableName, false) + debugMap["ColNamespace"] = helpers.DebugValue(s.ColNamespace, false) + debugMap["ColObjectID"] = helpers.DebugValue(s.ColObjectID, false) + debugMap["ColRelation"] = helpers.DebugValue(s.ColRelation, false) + debugMap["ColUsersetNamespace"] = helpers.DebugValue(s.ColUsersetNamespace, false) + debugMap["ColUsersetObjectID"] = helpers.DebugValue(s.ColUsersetObjectID, false) + debugMap["ColUsersetRelation"] = helpers.DebugValue(s.ColUsersetRelation, false) + debugMap["ColCaveatName"] = helpers.DebugValue(s.ColCaveatName, false) + debugMap["ColCaveatContext"] = helpers.DebugValue(s.ColCaveatContext, false) + debugMap["ColExpiration"] = helpers.DebugValue(s.ColExpiration, false) + debugMap["ColIntegrityKeyID"] = helpers.DebugValue(s.ColIntegrityKeyID, false) + debugMap["ColIntegrityHash"] = helpers.DebugValue(s.ColIntegrityHash, false) + debugMap["ColIntegrityTimestamp"] = helpers.DebugValue(s.ColIntegrityTimestamp, false) + debugMap["Indexes"] = helpers.DebugValue(s.Indexes, false) + debugMap["PaginationFilterType"] = helpers.DebugValue(s.PaginationFilterType, false) + debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false) + debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false) + debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false) + debugMap["IntegrityEnabled"] = helpers.DebugValue(s.IntegrityEnabled, false) + debugMap["ExpirationDisabled"] = helpers.DebugValue(s.ExpirationDisabled, false) + debugMap["SortByResourceColumnOrder"] = helpers.DebugValue(s.SortByResourceColumnOrder, false) + debugMap["SortBySubjectColumnOrder"] = helpers.DebugValue(s.SortBySubjectColumnOrder, false) + return debugMap +} + +// SchemaInformationWithOptions configures an existing SchemaInformation with the passed in options set +func SchemaInformationWithOptions(s *SchemaInformation, opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithOptions configures the receiver SchemaInformation with the passed in options set +func (s *SchemaInformation) WithOptions(opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithRelationshipTableName returns an option that can set RelationshipTableName on a SchemaInformation +func WithRelationshipTableName(relationshipTableName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.RelationshipTableName = relationshipTableName + } +} + +// WithColNamespace returns an option that can set ColNamespace on a SchemaInformation +func WithColNamespace(colNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColNamespace = colNamespace + } +} + +// WithColObjectID returns an option that can set ColObjectID on a SchemaInformation +func WithColObjectID(colObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColObjectID = colObjectID + } +} + +// WithColRelation returns an option that can set ColRelation on a SchemaInformation +func WithColRelation(colRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColRelation = colRelation + } +} + +// WithColUsersetNamespace returns an option that can set ColUsersetNamespace on a SchemaInformation +func WithColUsersetNamespace(colUsersetNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetNamespace = colUsersetNamespace + } +} + +// WithColUsersetObjectID returns an option that can set ColUsersetObjectID on a SchemaInformation +func WithColUsersetObjectID(colUsersetObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetObjectID = colUsersetObjectID + } +} + +// WithColUsersetRelation returns an option that can set ColUsersetRelation on a SchemaInformation +func WithColUsersetRelation(colUsersetRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetRelation = colUsersetRelation + } +} + +// WithColCaveatName returns an option that can set ColCaveatName on a SchemaInformation +func WithColCaveatName(colCaveatName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatName = colCaveatName + } +} + +// WithColCaveatContext returns an option that can set ColCaveatContext on a SchemaInformation +func WithColCaveatContext(colCaveatContext string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatContext = colCaveatContext + } +} + +// WithColExpiration returns an option that can set ColExpiration on a SchemaInformation +func WithColExpiration(colExpiration string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColExpiration = colExpiration + } +} + +// WithColIntegrityKeyID returns an option that can set ColIntegrityKeyID on a SchemaInformation +func WithColIntegrityKeyID(colIntegrityKeyID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityKeyID = colIntegrityKeyID + } +} + +// WithColIntegrityHash returns an option that can set ColIntegrityHash on a SchemaInformation +func WithColIntegrityHash(colIntegrityHash string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityHash = colIntegrityHash + } +} + +// WithColIntegrityTimestamp returns an option that can set ColIntegrityTimestamp on a SchemaInformation +func WithColIntegrityTimestamp(colIntegrityTimestamp string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityTimestamp = colIntegrityTimestamp + } +} + +// WithIndexes returns an option that can append Indexess to SchemaInformation.Indexes +func WithIndexes(indexes IndexDefinition) SchemaInformationOption { + return func(s *SchemaInformation) { + s.Indexes = append(s.Indexes, indexes) + } +} + +// SetIndexes returns an option that can set Indexes on a SchemaInformation +func SetIndexes(indexes []IndexDefinition) SchemaInformationOption { + return func(s *SchemaInformation) { + s.Indexes = indexes + } +} + +// WithPaginationFilterType returns an option that can set PaginationFilterType on a SchemaInformation +func WithPaginationFilterType(paginationFilterType PaginationFilterType) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PaginationFilterType = paginationFilterType + } +} + +// WithPlaceholderFormat returns an option that can set PlaceholderFormat on a SchemaInformation +func WithPlaceholderFormat(placeholderFormat squirrel.PlaceholderFormat) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PlaceholderFormat = placeholderFormat + } +} + +// WithNowFunction returns an option that can set NowFunction on a SchemaInformation +func WithNowFunction(nowFunction string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.NowFunction = nowFunction + } +} + +// WithColumnOptimization returns an option that can set ColumnOptimization on a SchemaInformation +func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColumnOptimization = columnOptimization + } +} + +// WithIntegrityEnabled returns an option that can set IntegrityEnabled on a SchemaInformation +func WithIntegrityEnabled(integrityEnabled bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.IntegrityEnabled = integrityEnabled + } +} + +// WithExpirationDisabled returns an option that can set ExpirationDisabled on a SchemaInformation +func WithExpirationDisabled(expirationDisabled bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ExpirationDisabled = expirationDisabled + } +} + +// WithSortByResourceColumnOrder returns an option that can append SortByResourceColumnOrders to SchemaInformation.SortByResourceColumnOrder +func WithSortByResourceColumnOrder(sortByResourceColumnOrder string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.SortByResourceColumnOrder = append(s.SortByResourceColumnOrder, sortByResourceColumnOrder) + } +} + +// SetSortByResourceColumnOrder returns an option that can set SortByResourceColumnOrder on a SchemaInformation +func SetSortByResourceColumnOrder(sortByResourceColumnOrder []string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.SortByResourceColumnOrder = sortByResourceColumnOrder + } +} + +// WithSortBySubjectColumnOrder returns an option that can append SortBySubjectColumnOrders to SchemaInformation.SortBySubjectColumnOrder +func WithSortBySubjectColumnOrder(sortBySubjectColumnOrder string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.SortBySubjectColumnOrder = append(s.SortBySubjectColumnOrder, sortBySubjectColumnOrder) + } +} + +// SetSortBySubjectColumnOrder returns an option that can set SortBySubjectColumnOrder on a SchemaInformation +func SetSortBySubjectColumnOrder(sortBySubjectColumnOrder []string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.SortBySubjectColumnOrder = sortBySubjectColumnOrder + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md new file mode 100644 index 0000000..de32e34 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md @@ -0,0 +1,23 @@ +# MemDB Datastore Implementation + +The `memdb` datastore implementation is based on Hashicorp's [go-memdb library](https://github.com/hashicorp/go-memdb). +Its implementation most closely mimics that of `spanner`, or `crdb`, where there is a single immutable datastore that supports querying at any point in time. +The `memdb` datastore is used for validating and rapidly iterating on concepts from consumers of other datastores. +It is 100% compliant with the datastore acceptance test suite and it should be possible to use it in place of any other datastore for development purposes. +Differences between the `memdb` datastore and other implementations that manifest themselves as differences visible to the caller should be reported as bugs. + +**The memdb datastore can NOT be used in a production setting!** + +## Implementation Caveats + +### No Garbage Collection + +This implementation of the datastore has no garbage collection, meaning that memory usage will grow monotonically with mutations. + +### No Durable Storage + +The `memdb` datastore, as its name implies, stores information entirely in memory, and therefore will lose all data when the host process terminates. + +### Cannot be used for multi-node dispatch + +If you attempt to run SpiceDB with multi-node dispatch enabled using the memory datastore, each independent node will get a separate copy of the datastore, and you will end up very confused. diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go new file mode 100644 index 0000000..2b4baca --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go @@ -0,0 +1,156 @@ +package memdb + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-memdb" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +const tableCaveats = "caveats" + +type caveat struct { + name string + definition []byte + revision datastore.Revision +} + +func (c *caveat) Unwrap() (*core.CaveatDefinition, error) { + definition := core.CaveatDefinition{} + err := definition.UnmarshalVT(c.definition) + return &definition, err +} + +func (r *memdbReader) ReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, datastore.NoRevision, err + } + return r.readUnwrappedCaveatByName(tx, name) +} + +func (r *memdbReader) readCaveatByName(tx *memdb.Txn, name string) (*caveat, datastore.Revision, error) { + found, err := tx.First(tableCaveats, indexID, name) + if err != nil { + return nil, datastore.NoRevision, err + } + if found == nil { + return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) + } + cvt := found.(*caveat) + return cvt, cvt.revision, nil +} + +func (r *memdbReader) readUnwrappedCaveatByName(tx *memdb.Txn, name string) (*core.CaveatDefinition, datastore.Revision, error) { + c, rev, err := r.readCaveatByName(tx, name) + if err != nil { + return nil, datastore.NoRevision, err + } + unwrapped, err := c.Unwrap() + if err != nil { + return nil, datastore.NoRevision, err + } + return unwrapped, rev, nil +} + +func (r *memdbReader) ListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) { + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + var caveats []datastore.RevisionedCaveat + it, err := tx.LowerBound(tableCaveats, indexID) + if err != nil { + return nil, err + } + + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + rawCaveat := foundRaw.(*caveat) + definition, err := rawCaveat.Unwrap() + if err != nil { + return nil, err + } + caveats = append(caveats, datastore.RevisionedCaveat{ + Definition: definition, + LastWrittenRevision: rawCaveat.revision, + }) + } + + return caveats, nil +} + +func (r *memdbReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { + allCaveats, err := r.ListAllCaveats(ctx) + if err != nil { + return nil, err + } + + allowedCaveatNames := mapz.NewSet[string]() + allowedCaveatNames.Extend(caveatNames) + + toReturn := make([]datastore.RevisionedCaveat, 0, len(caveatNames)) + for _, caveat := range allCaveats { + if allowedCaveatNames.Has(caveat.Definition.Name) { + toReturn = append(toReturn, caveat) + } + } + return toReturn, nil +} + +func (rwt *memdbReadWriteTx) WriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error { + rwt.mustLock() + defer rwt.Unlock() + tx, err := rwt.txSource() + if err != nil { + return err + } + return rwt.writeCaveat(tx, caveats) +} + +func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.CaveatDefinition) error { + caveatNames := mapz.NewSet[string]() + for _, coreCaveat := range caveats { + if !caveatNames.Add(coreCaveat.Name) { + return fmt.Errorf("duplicate caveat %s", coreCaveat.Name) + } + marshalled, err := coreCaveat.MarshalVT() + if err != nil { + return err + } + c := caveat{ + name: coreCaveat.Name, + definition: marshalled, + revision: rwt.newRevision, + } + if err := tx.Insert(tableCaveats, &c); err != nil { + return err + } + } + return nil +} + +func (rwt *memdbReadWriteTx) DeleteCaveats(_ context.Context, names []string) error { + rwt.mustLock() + defer rwt.Unlock() + tx, err := rwt.txSource() + if err != nil { + return err + } + for _, name := range names { + if err := tx.Delete(tableCaveats, caveat{name: name}); err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go new file mode 100644 index 0000000..0ef4b8b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go @@ -0,0 +1,37 @@ +package memdb + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// SerializationMaxRetriesReachedError occurs when a write request has reached its maximum number +// of retries due to serialization errors. +type SerializationMaxRetriesReachedError struct { + error +} + +// NewSerializationMaxRetriesReachedErr constructs a new max retries reached error. +func NewSerializationMaxRetriesReachedErr(baseErr error) error { + return SerializationMaxRetriesReachedError{ + error: baseErr, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err SerializationMaxRetriesReachedError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.DeadlineExceeded, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{ + "details": "too many updates were made to the in-memory datastore at once; this datastore has limited write throughput capability", + }, + ), + ) +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go new file mode 100644 index 0000000..61eba84 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go @@ -0,0 +1,386 @@ +package memdb + +import ( + "context" + "errors" + "fmt" + "math" + "sort" + "sync" + "time" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + + "github.com/google/uuid" + "github.com/hashicorp/go-memdb" + + "github.com/authzed/spicedb/internal/datastore/revisions" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +const ( + Engine = "memory" + defaultWatchBufferLength = 128 + numAttempts = 10 +) + +var ( + ErrMemDBIsClosed = errors.New("datastore is closed") + ErrSerialization = errors.New("serialization error") +) + +// DisableGC is a convenient constant for setting the garbage collection +// interval high enough that it will never run. +const DisableGC = time.Duration(math.MaxInt64) + +// NewMemdbDatastore creates a new Datastore compliant datastore backed by memdb. +// +// If the watchBufferLength value of 0 is set then a default value of 128 will be used. +func NewMemdbDatastore( + watchBufferLength uint16, + revisionQuantization, + gcWindow time.Duration, +) (datastore.Datastore, error) { + if revisionQuantization > gcWindow { + return nil, errors.New("gc window must be larger than quantization interval") + } + + if revisionQuantization <= 1 { + revisionQuantization = 1 + } + + db, err := memdb.NewMemDB(schema) + if err != nil { + return nil, err + } + + if watchBufferLength == 0 { + watchBufferLength = defaultWatchBufferLength + } + + uniqueID := uuid.NewString() + return &memdbDatastore{ + CommonDecoder: revisions.CommonDecoder{ + Kind: revisions.Timestamp, + }, + db: db, + revisions: []snapshot{ + { + revision: nowRevision(), + db: db, + }, + }, + + negativeGCWindow: gcWindow.Nanoseconds() * -1, + quantizationPeriod: revisionQuantization.Nanoseconds(), + watchBufferLength: watchBufferLength, + watchBufferWriteTimeout: 100 * time.Millisecond, + uniqueID: uniqueID, + }, nil +} + +type memdbDatastore struct { + sync.RWMutex + revisions.CommonDecoder + + // NOTE: call checkNotClosed before using + db *memdb.MemDB // GUARDED_BY(RWMutex) + revisions []snapshot // GUARDED_BY(RWMutex) + activeWriteTxn *memdb.Txn // GUARDED_BY(RWMutex) + + negativeGCWindow int64 + quantizationPeriod int64 + watchBufferLength uint16 + watchBufferWriteTimeout time.Duration + uniqueID string +} + +type snapshot struct { + revision revisions.TimestampRevision + db *memdb.MemDB +} + +func (mdb *memdbDatastore) MetricsID() (string, error) { + return "memdb", nil +} + +func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision) datastore.Reader { + mdb.RLock() + defer mdb.RUnlock() + + if err := mdb.checkNotClosed(); err != nil { + return &memdbReader{nil, nil, err, time.Now()} + } + + if len(mdb.revisions) == 0 { + return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is not ready"), time.Now()} + } + + if err := mdb.checkRevisionLocalCallerMustLock(dr); err != nil { + return &memdbReader{nil, nil, err, time.Now()} + } + + revIndex := sort.Search(len(mdb.revisions), func(i int) bool { + return mdb.revisions[i].revision.GreaterThan(dr) || mdb.revisions[i].revision.Equal(dr) + }) + + // handle the case when there is no revision snapshot newer than the requested revision + if revIndex == len(mdb.revisions) { + revIndex = len(mdb.revisions) - 1 + } + + rev := mdb.revisions[revIndex] + if rev.db == nil { + return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is already closed"), time.Now()} + } + + roTxn := rev.db.Txn(false) + txSrc := func() (*memdb.Txn, error) { + return roTxn, nil + } + + return &memdbReader{noopTryLocker{}, txSrc, nil, time.Now()} +} + +func (mdb *memdbDatastore) SupportsIntegrity() bool { + return true +} + +func (mdb *memdbDatastore) ReadWriteTx( + ctx context.Context, + f datastore.TxUserFunc, + opts ...options.RWTOptionsOption, +) (datastore.Revision, error) { + config := options.NewRWTOptionsWithOptions(opts...) + txNumAttempts := numAttempts + if config.DisableRetries { + txNumAttempts = 1 + } + + for i := 0; i < txNumAttempts; i++ { + var tx *memdb.Txn + createTxOnce := sync.Once{} + txSrc := func() (*memdb.Txn, error) { + var err error + createTxOnce.Do(func() { + mdb.Lock() + defer mdb.Unlock() + + if mdb.activeWriteTxn != nil { + err = ErrSerialization + return + } + + if err = mdb.checkNotClosed(); err != nil { + return + } + + tx = mdb.db.Txn(true) + tx.TrackChanges() + mdb.activeWriteTxn = tx + }) + + return tx, err + } + + newRevision := mdb.newRevisionID() + rwt := &memdbReadWriteTx{memdbReader{&sync.Mutex{}, txSrc, nil, time.Now()}, newRevision} + if err := f(ctx, rwt); err != nil { + mdb.Lock() + if tx != nil { + tx.Abort() + mdb.activeWriteTxn = nil + } + + // If the error was a serialization error, retry the transaction + if errors.Is(err, ErrSerialization) { + mdb.Unlock() + + // If we don't sleep here, we run out of retries instantaneously + time.Sleep(1 * time.Millisecond) + continue + } + defer mdb.Unlock() + + // We *must* return the inner error unmodified in case it's not an error type + // that supports unwrapping (e.g. gRPC errors) + return datastore.NoRevision, err + } + + mdb.Lock() + defer mdb.Unlock() + + tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0) + if tx != nil { + if config.Metadata != nil && len(config.Metadata.GetFields()) > 0 { + if err := tracked.SetRevisionMetadata(ctx, newRevision, config.Metadata.AsMap()); err != nil { + return datastore.NoRevision, err + } + } + + for _, change := range tx.Changes() { + switch change.Table { + case tableRelationship: + if change.After != nil { + rt, err := change.After.(*relationship).Relationship() + if err != nil { + return datastore.NoRevision, err + } + + if err := tracked.AddRelationshipChange(ctx, newRevision, rt, tuple.UpdateOperationTouch); err != nil { + return datastore.NoRevision, err + } + } else if change.After == nil && change.Before != nil { + rt, err := change.Before.(*relationship).Relationship() + if err != nil { + return datastore.NoRevision, err + } + + if err := tracked.AddRelationshipChange(ctx, newRevision, rt, tuple.UpdateOperationDelete); err != nil { + return datastore.NoRevision, err + } + } else { + return datastore.NoRevision, spiceerrors.MustBugf("unexpected relationship change") + } + case tableNamespace: + if change.After != nil { + loaded := &corev1.NamespaceDefinition{} + if err := loaded.UnmarshalVT(change.After.(*namespace).configBytes); err != nil { + return datastore.NoRevision, err + } + + err := tracked.AddChangedDefinition(ctx, newRevision, loaded) + if err != nil { + return datastore.NoRevision, err + } + } else if change.After == nil && change.Before != nil { + err := tracked.AddDeletedNamespace(ctx, newRevision, change.Before.(*namespace).name) + if err != nil { + return datastore.NoRevision, err + } + } else { + return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change") + } + case tableCaveats: + if change.After != nil { + loaded := &corev1.CaveatDefinition{} + if err := loaded.UnmarshalVT(change.After.(*caveat).definition); err != nil { + return datastore.NoRevision, err + } + + err := tracked.AddChangedDefinition(ctx, newRevision, loaded) + if err != nil { + return datastore.NoRevision, err + } + } else if change.After == nil && change.Before != nil { + err := tracked.AddDeletedCaveat(ctx, newRevision, change.Before.(*caveat).name) + if err != nil { + return datastore.NoRevision, err + } + } else { + return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change") + } + } + } + + var rc datastore.RevisionChanges + changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc) + if err != nil { + return datastore.NoRevision, err + } + + if len(changes) > 1 { + return datastore.NoRevision, spiceerrors.MustBugf("unexpected MemDB transaction with multiple revision changes") + } else if len(changes) == 1 { + rc = changes[0] + } + + change := &changelog{ + revisionNanos: newRevision.TimestampNanoSec(), + changes: rc, + } + if err := tx.Insert(tableChangelog, change); err != nil { + return datastore.NoRevision, fmt.Errorf("error writing changelog: %w", err) + } + + tx.Commit() + } + mdb.activeWriteTxn = nil + + if err := mdb.checkNotClosed(); err != nil { + return datastore.NoRevision, err + } + + // Create a snapshot and add it to the revisions slice + snap := mdb.db.Snapshot() + mdb.revisions = append(mdb.revisions, snapshot{newRevision, snap}) + return newRevision, nil + } + + return datastore.NoRevision, NewSerializationMaxRetriesReachedErr(errors.New("serialization max retries exceeded; please reduce your parallel writes")) +} + +func (mdb *memdbDatastore) ReadyState(_ context.Context) (datastore.ReadyState, error) { + mdb.RLock() + defer mdb.RUnlock() + + return datastore.ReadyState{ + Message: "missing expected initial revision", + IsReady: len(mdb.revisions) > 0, + }, nil +} + +func (mdb *memdbDatastore) OfflineFeatures() (*datastore.Features, error) { + return &datastore.Features{ + Watch: datastore.Feature{ + Status: datastore.FeatureSupported, + }, + IntegrityData: datastore.Feature{ + Status: datastore.FeatureSupported, + }, + ContinuousCheckpointing: datastore.Feature{ + Status: datastore.FeatureUnsupported, + }, + WatchEmitsImmediately: datastore.Feature{ + Status: datastore.FeatureUnsupported, + }, + }, nil +} + +func (mdb *memdbDatastore) Features(_ context.Context) (*datastore.Features, error) { + return mdb.OfflineFeatures() +} + +func (mdb *memdbDatastore) Close() error { + mdb.Lock() + defer mdb.Unlock() + + if db := mdb.db; db != nil { + mdb.revisions = []snapshot{ + { + revision: nowRevision(), + db: db, + }, + } + } else { + mdb.revisions = []snapshot{} + } + + mdb.db = nil + + return nil +} + +// This code assumes that the RWMutex has been acquired. +func (mdb *memdbDatastore) checkNotClosed() error { + if mdb.db == nil { + return ErrMemDBIsClosed + } + return nil +} + +var _ datastore.Datastore = &memdbDatastore{} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go new file mode 100644 index 0000000..fdd224a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go @@ -0,0 +1,597 @@ +package memdb + +import ( + "context" + "fmt" + "slices" + "sort" + "strings" + "time" + + "github.com/hashicorp/go-memdb" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +type txFactory func() (*memdb.Txn, error) + +type memdbReader struct { + TryLocker + txSource txFactory + initErr error + now time.Time +} + +func (r *memdbReader) CountRelationships(ctx context.Context, name string) (int, error) { + counters, err := r.LookupCounters(ctx) + if err != nil { + return 0, err + } + + var found *core.RelationshipFilter + for _, counter := range counters { + if counter.Name == name { + found = counter.Filter + break + } + } + + if found == nil { + return 0, datastore.NewCounterNotRegisteredErr(name) + } + + coreFilter, err := datastore.RelationshipsFilterFromCoreFilter(found) + if err != nil { + return 0, err + } + + iter, err := r.QueryRelationships(ctx, coreFilter) + if err != nil { + return 0, err + } + + count := 0 + for _, err := range iter { + if err != nil { + return 0, err + } + + count++ + } + return count, nil +} + +func (r *memdbReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) { + if r.initErr != nil { + return nil, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + var counters []datastore.RelationshipCounter + + it, err := tx.LowerBound(tableCounters, indexID) + if err != nil { + return nil, err + } + + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + found := foundRaw.(*counter) + + loaded := &core.RelationshipFilter{} + if err := loaded.UnmarshalVT(found.filterBytes); err != nil { + return nil, err + } + + counters = append(counters, datastore.RelationshipCounter{ + Name: found.name, + Filter: loaded, + Count: found.count, + ComputedAtRevision: found.updated, + }) + } + + return counters, nil +} + +// QueryRelationships reads relationships starting from the resource side. +func (r *memdbReader) QueryRelationships( + _ context.Context, + filter datastore.RelationshipsFilter, + opts ...options.QueryOptionsOption, +) (datastore.RelationshipIterator, error) { + if r.initErr != nil { + return nil, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + queryOpts := options.NewQueryOptionsWithOptions(opts...) + + bestIterator, err := iteratorForFilter(tx, filter) + if err != nil { + return nil, err + } + + if queryOpts.After != nil && queryOpts.Sort == options.Unsorted { + return nil, datastore.ErrCursorsWithoutSorting + } + + matchingRelationshipsFilterFunc := filterFuncForFilters( + filter.OptionalResourceType, + filter.OptionalResourceIds, + filter.OptionalResourceIDPrefix, + filter.OptionalResourceRelation, + filter.OptionalSubjectsSelectors, + filter.OptionalCaveatNameFilter, + filter.OptionalExpirationOption, + makeCursorFilterFn(queryOpts.After, queryOpts.Sort), + ) + filteredIterator := memdb.NewFilterIterator(bestIterator, matchingRelationshipsFilterFunc) + + switch queryOpts.Sort { + case options.Unsorted: + fallthrough + + case options.ByResource: + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration) + return iter, nil + + case options.BySubject: + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration) + + default: + return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort) + } +} + +// ReverseQueryRelationships reads relationships starting from the subject. +func (r *memdbReader) ReverseQueryRelationships( + _ context.Context, + subjectsFilter datastore.SubjectsFilter, + opts ...options.ReverseQueryOptionsOption, +) (datastore.RelationshipIterator, error) { + if r.initErr != nil { + return nil, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + queryOpts := options.NewReverseQueryOptionsWithOptions(opts...) + + iterator, err := tx.Get( + tableRelationship, + indexSubjectNamespace, + subjectsFilter.SubjectType, + ) + if err != nil { + return nil, err + } + + filterObjectType, filterRelation := "", "" + if queryOpts.ResRelation != nil { + filterObjectType = queryOpts.ResRelation.Namespace + filterRelation = queryOpts.ResRelation.Relation + } + + matchingRelationshipsFilterFunc := filterFuncForFilters( + filterObjectType, + nil, + "", + filterRelation, + []datastore.SubjectsSelector{subjectsFilter.AsSelector()}, + datastore.CaveatNameFilter{}, + datastore.ExpirationFilterOptionNone, + makeCursorFilterFn(queryOpts.AfterForReverse, queryOpts.SortForReverse), + ) + filteredIterator := memdb.NewFilterIterator(iterator, matchingRelationshipsFilterFunc) + + switch queryOpts.SortForReverse { + case options.Unsorted: + fallthrough + + case options.ByResource: + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false) + return iter, nil + + case options.BySubject: + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false) + + default: + return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.SortForReverse) + } +} + +// ReadNamespace reads a namespace definition and version and returns it, and the revision at +// which it was created or last written, if found. +func (r *memdbReader) ReadNamespaceByName(_ context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) { + if r.initErr != nil { + return nil, datastore.NoRevision, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, datastore.NoRevision, err + } + + foundRaw, err := tx.First(tableNamespace, indexID, nsName) + if err != nil { + return nil, datastore.NoRevision, err + } + + if foundRaw == nil { + return nil, datastore.NoRevision, datastore.NewNamespaceNotFoundErr(nsName) + } + + found := foundRaw.(*namespace) + + loaded := &core.NamespaceDefinition{} + if err := loaded.UnmarshalVT(found.configBytes); err != nil { + return nil, datastore.NoRevision, err + } + + return loaded, found.updated, nil +} + +// ListNamespaces lists all namespaces defined. +func (r *memdbReader) ListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) { + if r.initErr != nil { + return nil, r.initErr + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + var nsDefs []datastore.RevisionedNamespace + + it, err := tx.LowerBound(tableNamespace, indexID) + if err != nil { + return nil, err + } + + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + found := foundRaw.(*namespace) + + loaded := &core.NamespaceDefinition{} + if err := loaded.UnmarshalVT(found.configBytes); err != nil { + return nil, err + } + + nsDefs = append(nsDefs, datastore.RevisionedNamespace{ + Definition: loaded, + LastWrittenRevision: found.updated, + }) + } + + return nsDefs, nil +} + +func (r *memdbReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { + if r.initErr != nil { + return nil, r.initErr + } + + if len(nsNames) == 0 { + return nil, nil + } + + r.mustLock() + defer r.Unlock() + + tx, err := r.txSource() + if err != nil { + return nil, err + } + + it, err := tx.LowerBound(tableNamespace, indexID) + if err != nil { + return nil, err + } + + nsNameMap := make(map[string]struct{}, len(nsNames)) + for _, nsName := range nsNames { + nsNameMap[nsName] = struct{}{} + } + + nsDefs := make([]datastore.RevisionedNamespace, 0, len(nsNames)) + + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + found := foundRaw.(*namespace) + + loaded := &core.NamespaceDefinition{} + if err := loaded.UnmarshalVT(found.configBytes); err != nil { + return nil, err + } + + if _, ok := nsNameMap[loaded.Name]; ok { + nsDefs = append(nsDefs, datastore.RevisionedNamespace{ + Definition: loaded, + LastWrittenRevision: found.updated, + }) + } + } + + return nsDefs, nil +} + +func (r *memdbReader) mustLock() { + if !r.TryLock() { + panic("detected concurrent use of ReadWriteTransaction") + } +} + +func iteratorForFilter(txn *memdb.Txn, filter datastore.RelationshipsFilter) (memdb.ResultIterator, error) { + // "_prefix" is a specialized index suffix used by github.com/hashicorp/go-memdb to match on + // a prefix of a string. + // See: https://github.com/hashicorp/go-memdb/blob/9940d4a14258e3b887bfb4bc6ebc28f65461a01c/txn.go#L531 + index := indexNamespace + "_prefix" + + var args []any + if filter.OptionalResourceType != "" { + args = append(args, filter.OptionalResourceType) + index = indexNamespace + } else { + args = append(args, "") + } + + if filter.OptionalResourceType != "" && filter.OptionalResourceRelation != "" { + args = append(args, filter.OptionalResourceRelation) + index = indexNamespaceAndRelation + } + + if len(args) == 0 { + return nil, spiceerrors.MustBugf("cannot specify an empty filter") + } + + iter, err := txn.Get(tableRelationship, index, args...) + if err != nil { + return nil, fmt.Errorf("unable to get iterator for filter: %w", err) + } + + return iter, err +} + +func filterFuncForFilters( + optionalResourceType string, + optionalResourceIds []string, + optionalResourceIDPrefix string, + optionalRelation string, + optionalSubjectsSelectors []datastore.SubjectsSelector, + optionalCaveatFilter datastore.CaveatNameFilter, + optionalExpirationFilter datastore.ExpirationFilterOption, + cursorFilter func(*relationship) bool, +) memdb.FilterFunc { + return func(tupleRaw interface{}) bool { + tuple := tupleRaw.(*relationship) + + switch { + case optionalResourceType != "" && optionalResourceType != tuple.namespace: + return true + case len(optionalResourceIds) > 0 && !slices.Contains(optionalResourceIds, tuple.resourceID): + return true + case optionalResourceIDPrefix != "" && !strings.HasPrefix(tuple.resourceID, optionalResourceIDPrefix): + return true + case optionalRelation != "" && optionalRelation != tuple.relation: + return true + case optionalCaveatFilter.Option == datastore.CaveatFilterOptionHasMatchingCaveat && (tuple.caveat == nil || tuple.caveat.caveatName != optionalCaveatFilter.CaveatName): + return true + case optionalCaveatFilter.Option == datastore.CaveatFilterOptionNoCaveat && (tuple.caveat != nil && tuple.caveat.caveatName != ""): + return true + case optionalExpirationFilter == datastore.ExpirationFilterOptionHasExpiration && tuple.expiration == nil: + return true + case optionalExpirationFilter == datastore.ExpirationFilterOptionNoExpiration && tuple.expiration != nil: + return true + } + + applySubjectSelector := func(selector datastore.SubjectsSelector) bool { + switch { + case len(selector.OptionalSubjectType) > 0 && selector.OptionalSubjectType != tuple.subjectNamespace: + return false + case len(selector.OptionalSubjectIds) > 0 && !slices.Contains(selector.OptionalSubjectIds, tuple.subjectObjectID): + return false + } + + if selector.RelationFilter.OnlyNonEllipsisRelations { + return tuple.subjectRelation != datastore.Ellipsis + } + + relations := make([]string, 0, 2) + if selector.RelationFilter.IncludeEllipsisRelation { + relations = append(relations, datastore.Ellipsis) + } + + if selector.RelationFilter.NonEllipsisRelation != "" { + relations = append(relations, selector.RelationFilter.NonEllipsisRelation) + } + + return len(relations) == 0 || slices.Contains(relations, tuple.subjectRelation) + } + + if len(optionalSubjectsSelectors) > 0 { + hasMatchingSelector := false + for _, selector := range optionalSubjectsSelectors { + if applySubjectSelector(selector) { + hasMatchingSelector = true + break + } + } + + if !hasMatchingSelector { + return true + } + } + + return cursorFilter(tuple) + } +} + +func makeCursorFilterFn(after options.Cursor, order options.SortOrder) func(tpl *relationship) bool { + if after != nil { + switch order { + case options.ByResource: + return func(tpl *relationship) bool { + return less(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) || + (eq(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) && + (less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) || + eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject))) + } + case options.BySubject: + return func(tpl *relationship) bool { + return less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) || + (eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) && + (less(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) || + eq(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource))) + } + } + } + return noopCursorFilter +} + +func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) (datastore.RelationshipIterator, error) { + results := make([]tuple.Relationship, 0) + + // Coalesce all of the results into memory + for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { + rt, err := foundRaw.(*relationship).Relationship() + if err != nil { + return nil, err + } + + if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { + continue + } + + if skipCaveats && rt.OptionalCaveat != nil { + return nil, spiceerrors.MustBugf("unexpected caveat in result for relationship: %v", rt) + } + + if skipExpiration && rt.OptionalExpiration != nil { + return nil, spiceerrors.MustBugf("unexpected expiration in result for relationship: %v", rt) + } + + results = append(results, rt) + } + + // Sort them by subject + sort.Slice(results, func(i, j int) bool { + lhsRes := results[i].Resource + lhsSub := results[i].Subject + rhsRes := results[j].Resource + rhsSub := results[j].Subject + return less(lhsSub.ObjectType, lhsSub.ObjectID, lhsSub.Relation, rhsSub) || + (eq(lhsSub.ObjectType, lhsSub.ObjectID, lhsSub.Relation, rhsSub) && + (less(lhsRes.ObjectType, lhsRes.ObjectID, lhsRes.Relation, rhsRes))) + }) + + // Limit them if requested + if limit != nil && uint64(len(results)) > *limit { + results = results[0:*limit] + } + + return common.NewSliceRelationshipIterator(results), nil +} + +func noopCursorFilter(_ *relationship) bool { + return false +} + +func less(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelation) bool { + return lhsNamespace < rhs.ObjectType || + (lhsNamespace == rhs.ObjectType && lhsObjectID < rhs.ObjectID) || + (lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation < rhs.Relation) +} + +func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelation) bool { + return lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation == rhs.Relation +} + +func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) datastore.RelationshipIterator { + var count uint64 + return func(yield func(tuple.Relationship, error) bool) { + for { + foundRaw := it.Next() + if foundRaw == nil { + return + } + + if limit != nil && count >= *limit { + return + } + + rt, err := foundRaw.(*relationship).Relationship() + if err != nil { + if !yield(tuple.Relationship{}, err) { + return + } + continue + } + + if skipCaveats && rt.OptionalCaveat != nil { + yield(rt, fmt.Errorf("unexpected caveat in result for relationship: %v", rt)) + return + } + + if skipExpiration && rt.OptionalExpiration != nil { + yield(rt, fmt.Errorf("unexpected expiration in result for relationship: %v", rt)) + return + } + + if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { + continue + } + + if !yield(rt, err) { + return + } + count++ + } + } +} + +var _ datastore.Reader = &memdbReader{} + +type TryLocker interface { + TryLock() bool + Unlock() +} + +type noopTryLocker struct{} + +func (ntl noopTryLocker) TryLock() bool { + return true +} + +func (ntl noopTryLocker) Unlock() {} + +var _ TryLocker = noopTryLocker{} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go new file mode 100644 index 0000000..8929e84 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go @@ -0,0 +1,386 @@ +package memdb + +import ( + "context" + "fmt" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/hashicorp/go-memdb" + "github.com/jzelinskie/stringz" + + "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +type memdbReadWriteTx struct { + memdbReader + newRevision datastore.Revision +} + +func (rwt *memdbReadWriteTx) WriteRelationships(_ context.Context, mutations []tuple.RelationshipUpdate) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + return rwt.write(tx, mutations...) +} + +func (rwt *memdbReadWriteTx) toIntegrity(mutation tuple.RelationshipUpdate) *relationshipIntegrity { + var ri *relationshipIntegrity + if mutation.Relationship.OptionalIntegrity != nil { + ri = &relationshipIntegrity{ + keyID: mutation.Relationship.OptionalIntegrity.KeyId, + hash: mutation.Relationship.OptionalIntegrity.Hash, + timestamp: mutation.Relationship.OptionalIntegrity.HashedAt.AsTime(), + } + } + return ri +} + +// Caller must already hold the concurrent access lock! +func (rwt *memdbReadWriteTx) write(tx *memdb.Txn, mutations ...tuple.RelationshipUpdate) error { + // Apply the mutations + for _, mutation := range mutations { + rel := &relationship{ + mutation.Relationship.Resource.ObjectType, + mutation.Relationship.Resource.ObjectID, + mutation.Relationship.Resource.Relation, + mutation.Relationship.Subject.ObjectType, + mutation.Relationship.Subject.ObjectID, + mutation.Relationship.Subject.Relation, + rwt.toCaveatReference(mutation), + rwt.toIntegrity(mutation), + mutation.Relationship.OptionalExpiration, + } + + found, err := tx.First( + tableRelationship, + indexID, + rel.namespace, + rel.resourceID, + rel.relation, + rel.subjectNamespace, + rel.subjectObjectID, + rel.subjectRelation, + ) + if err != nil { + return fmt.Errorf("error loading existing relationship: %w", err) + } + + var existing *relationship + if found != nil { + existing = found.(*relationship) + } + + switch mutation.Operation { + case tuple.UpdateOperationCreate: + if existing != nil { + rt, err := existing.Relationship() + if err != nil { + return err + } + return common.NewCreateRelationshipExistsError(&rt) + } + if err := tx.Insert(tableRelationship, rel); err != nil { + return fmt.Errorf("error inserting relationship: %w", err) + } + + case tuple.UpdateOperationTouch: + if existing != nil { + rt, err := existing.Relationship() + if err != nil { + return err + } + if tuple.MustString(rt) == tuple.MustString(mutation.Relationship) { + continue + } + } + + if err := tx.Insert(tableRelationship, rel); err != nil { + return fmt.Errorf("error inserting relationship: %w", err) + } + + case tuple.UpdateOperationDelete: + if existing != nil { + if err := tx.Delete(tableRelationship, existing); err != nil { + return fmt.Errorf("error deleting relationship: %w", err) + } + } + default: + return spiceerrors.MustBugf("unknown tuple mutation operation type: %v", mutation.Operation) + } + } + + return nil +} + +func (rwt *memdbReadWriteTx) toCaveatReference(mutation tuple.RelationshipUpdate) *contextualizedCaveat { + var cr *contextualizedCaveat + if mutation.Relationship.OptionalCaveat != nil { + cr = &contextualizedCaveat{ + caveatName: mutation.Relationship.OptionalCaveat.CaveatName, + context: mutation.Relationship.OptionalCaveat.Context.AsMap(), + } + } + return cr +} + +func (rwt *memdbReadWriteTx) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return 0, false, err + } + + delOpts := options.NewDeleteOptionsWithOptionsAndDefaults(opts...) + var delLimit uint64 + if delOpts.DeleteLimit != nil && *delOpts.DeleteLimit > 0 { + delLimit = *delOpts.DeleteLimit + } + + return rwt.deleteWithLock(tx, filter, delLimit) +} + +// caller must already hold the concurrent access lock +func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.RelationshipFilter, limit uint64) (uint64, bool, error) { + // Create an iterator to find the relevant tuples + dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter) + if err != nil { + return 0, false, err + } + + bestIter, err := iteratorForFilter(tx, dsFilter) + if err != nil { + return 0, false, err + } + filteredIter := memdb.NewFilterIterator(bestIter, relationshipFilterFilterFunc(filter)) + + // Collect the tuples into a slice of mutations for the changelog + var mutations []tuple.RelationshipUpdate + var counter uint64 + + metLimit := false + for row := filteredIter.Next(); row != nil; row = filteredIter.Next() { + rt, err := row.(*relationship).Relationship() + if err != nil { + return 0, false, err + } + mutations = append(mutations, tuple.Delete(rt)) + counter++ + + if limit > 0 && counter == limit { + metLimit = true + break + } + } + + return counter, metLimit, rwt.write(tx, mutations...) +} + +func (rwt *memdbReadWriteTx) RegisterCounter(ctx context.Context, name string, filter *core.RelationshipFilter) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + foundRaw, err := tx.First(tableCounters, indexID, name) + if err != nil { + return err + } + + if foundRaw != nil { + return datastore.NewCounterAlreadyRegisteredErr(name, filter) + } + + filterBytes, err := filter.MarshalVT() + if err != nil { + return err + } + + // Insert the counter + counter := &counter{ + name, + filterBytes, + 0, + datastore.NoRevision, + } + + return tx.Insert(tableCounters, counter) +} + +func (rwt *memdbReadWriteTx) UnregisterCounter(ctx context.Context, name string) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + // Check if the counter exists + foundRaw, err := tx.First(tableCounters, indexID, name) + if err != nil { + return err + } + + if foundRaw == nil { + return datastore.NewCounterNotRegisteredErr(name) + } + + return tx.Delete(tableCounters, foundRaw) +} + +func (rwt *memdbReadWriteTx) StoreCounterValue(ctx context.Context, name string, value int, computedAtRevision datastore.Revision) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + // Check if the counter exists + foundRaw, err := tx.First(tableCounters, indexID, name) + if err != nil { + return err + } + + if foundRaw == nil { + return datastore.NewCounterNotRegisteredErr(name) + } + + counter := foundRaw.(*counter) + counter.count = value + counter.updated = computedAtRevision + + return tx.Insert(tableCounters, counter) +} + +func (rwt *memdbReadWriteTx) WriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + for _, newConfig := range newConfigs { + serialized, err := newConfig.MarshalVT() + if err != nil { + return err + } + + newConfigEntry := &namespace{newConfig.Name, serialized, rwt.newRevision} + + err = tx.Insert(tableNamespace, newConfigEntry) + if err != nil { + return err + } + } + + return nil +} + +func (rwt *memdbReadWriteTx) DeleteNamespaces(_ context.Context, nsNames ...string) error { + rwt.mustLock() + defer rwt.Unlock() + + tx, err := rwt.txSource() + if err != nil { + return err + } + + for _, nsName := range nsNames { + foundRaw, err := tx.First(tableNamespace, indexID, nsName) + if err != nil { + return err + } + + if foundRaw == nil { + return fmt.Errorf("namespace not found") + } + + if err := tx.Delete(tableNamespace, foundRaw); err != nil { + return err + } + + // Delete the relationships from the namespace + if _, _, err := rwt.deleteWithLock(tx, &v1.RelationshipFilter{ + ResourceType: nsName, + }, 0); err != nil { + return fmt.Errorf("unable to delete relationships from deleted namespace: %w", err) + } + } + + return nil +} + +func (rwt *memdbReadWriteTx) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { + var numCopied uint64 + var next *tuple.Relationship + var err error + + updates := []tuple.RelationshipUpdate{{ + Operation: tuple.UpdateOperationCreate, + }} + + for next, err = iter.Next(ctx); next != nil && err == nil; next, err = iter.Next(ctx) { + updates[0].Relationship = *next + if err := rwt.WriteRelationships(ctx, updates); err != nil { + return 0, err + } + numCopied++ + } + + return numCopied, err +} + +func relationshipFilterFilterFunc(filter *v1.RelationshipFilter) func(interface{}) bool { + return func(tupleRaw interface{}) bool { + tuple := tupleRaw.(*relationship) + + // If it doesn't match one of the resource filters, filter it. + switch { + case filter.ResourceType != "" && filter.ResourceType != tuple.namespace: + return true + case filter.OptionalResourceId != "" && filter.OptionalResourceId != tuple.resourceID: + return true + case filter.OptionalResourceIdPrefix != "" && !strings.HasPrefix(tuple.resourceID, filter.OptionalResourceIdPrefix): + return true + case filter.OptionalRelation != "" && filter.OptionalRelation != tuple.relation: + return true + } + + // If it doesn't match one of the subject filters, filter it. + if subjectFilter := filter.OptionalSubjectFilter; subjectFilter != nil { + switch { + case subjectFilter.SubjectType != tuple.subjectNamespace: + return true + case subjectFilter.OptionalSubjectId != "" && subjectFilter.OptionalSubjectId != tuple.subjectObjectID: + return true + case subjectFilter.OptionalRelation != nil && + stringz.DefaultEmpty(subjectFilter.OptionalRelation.Relation, datastore.Ellipsis) != tuple.subjectRelation: + return true + } + } + + return false + } +} + +var _ datastore.ReadWriteTransaction = &memdbReadWriteTx{} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go new file mode 100644 index 0000000..be79771 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go @@ -0,0 +1,118 @@ +package memdb + +import ( + "context" + "time" + + "github.com/authzed/spicedb/internal/datastore/revisions" + "github.com/authzed/spicedb/pkg/datastore" +) + +var ParseRevisionString = revisions.RevisionParser(revisions.Timestamp) + +func nowRevision() revisions.TimestampRevision { + return revisions.NewForTime(time.Now().UTC()) +} + +func (mdb *memdbDatastore) newRevisionID() revisions.TimestampRevision { + mdb.Lock() + defer mdb.Unlock() + + existing := mdb.revisions[len(mdb.revisions)-1].revision + created := nowRevision() + + // NOTE: The time.Now().UTC() only appears to have *microsecond* level + // precision on macOS Monterey in Go 1.19.1. This means that HeadRevision + // and the result of a ReadWriteTx could return the *same* transaction ID + // if both are executed in sequence without any other forms of delay on + // macOS. We therefore check if the created transaction ID matches that + // previously created and, if not, add to it. + // + // See: https://github.com/golang/go/issues/22037 which appeared to fix + // this in Go 1.9.2, but there appears to have been a reversion with either + // the new version of macOS or Go. + if created.Equal(existing) { + return revisions.NewForTimestamp(created.TimestampNanoSec() + 1) + } + + return created +} + +func (mdb *memdbDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) { + mdb.RLock() + defer mdb.RUnlock() + if err := mdb.checkNotClosed(); err != nil { + return nil, err + } + + return mdb.headRevisionNoLock(), nil +} + +func (mdb *memdbDatastore) SquashRevisionsForTesting() { + mdb.revisions = []snapshot{ + { + revision: nowRevision(), + db: mdb.db, + }, + } +} + +func (mdb *memdbDatastore) headRevisionNoLock() revisions.TimestampRevision { + return mdb.revisions[len(mdb.revisions)-1].revision +} + +func (mdb *memdbDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) { + mdb.RLock() + defer mdb.RUnlock() + if err := mdb.checkNotClosed(); err != nil { + return nil, err + } + + now := nowRevision() + return revisions.NewForTimestamp(now.TimestampNanoSec() - now.TimestampNanoSec()%mdb.quantizationPeriod), nil +} + +func (mdb *memdbDatastore) CheckRevision(_ context.Context, dr datastore.Revision) error { + mdb.RLock() + defer mdb.RUnlock() + if err := mdb.checkNotClosed(); err != nil { + return err + } + + return mdb.checkRevisionLocalCallerMustLock(dr) +} + +func (mdb *memdbDatastore) checkRevisionLocalCallerMustLock(dr datastore.Revision) error { + now := nowRevision() + + // Ensure the revision has not fallen outside of the GC window. If it has, it is considered + // invalid. + if mdb.revisionOutsideGCWindow(now, dr) { + return datastore.NewInvalidRevisionErr(dr, datastore.RevisionStale) + } + + // If the revision <= now and later than the GC window, it is assumed to be valid, even if + // HEAD revision is behind it. + if dr.GreaterThan(now) { + // If the revision is in the "future", then check to ensure that it is <= of HEAD to handle + // the microsecond granularity on macos (see comment above in newRevisionID) + headRevision := mdb.headRevisionNoLock() + if dr.LessThan(headRevision) || dr.Equal(headRevision) { + return nil + } + + return datastore.NewInvalidRevisionErr(dr, datastore.CouldNotDetermineRevision) + } + + return nil +} + +func (mdb *memdbDatastore) revisionOutsideGCWindow(now revisions.TimestampRevision, revisionRaw datastore.Revision) bool { + // make an exception for head revision - it will be acceptable even if outside GC Window + if revisionRaw.Equal(mdb.headRevisionNoLock()) { + return false + } + + oldest := revisions.NewForTimestamp(now.TimestampNanoSec() + mdb.negativeGCWindow) + return revisionRaw.LessThan(oldest) +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go new file mode 100644 index 0000000..7905d48 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go @@ -0,0 +1,232 @@ +package memdb + +import ( + "time" + + "github.com/hashicorp/go-memdb" + "github.com/rs/zerolog" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const ( + tableNamespace = "namespace" + + tableRelationship = "relationship" + indexID = "id" + indexNamespace = "namespace" + indexNamespaceAndRelation = "namespaceAndRelation" + indexSubjectNamespace = "subjectNamespace" + + tableCounters = "counters" + + tableChangelog = "changelog" + indexRevision = "id" +) + +type namespace struct { + name string + configBytes []byte + updated datastore.Revision +} + +func (ns namespace) MarshalZerologObject(e *zerolog.Event) { + e.Stringer("rev", ns.updated).Str("name", ns.name) +} + +type counter struct { + name string + filterBytes []byte + count int + updated datastore.Revision +} + +type relationship struct { + namespace string + resourceID string + relation string + subjectNamespace string + subjectObjectID string + subjectRelation string + caveat *contextualizedCaveat + integrity *relationshipIntegrity + expiration *time.Time +} + +type relationshipIntegrity struct { + keyID string + hash []byte + timestamp time.Time +} + +func (ri relationshipIntegrity) MarshalZerologObject(e *zerolog.Event) { + e.Str("keyID", ri.keyID).Bytes("hash", ri.hash).Time("timestamp", ri.timestamp) +} + +func (ri relationshipIntegrity) RelationshipIntegrity() *core.RelationshipIntegrity { + return &core.RelationshipIntegrity{ + KeyId: ri.keyID, + Hash: ri.hash, + HashedAt: timestamppb.New(ri.timestamp), + } +} + +type contextualizedCaveat struct { + caveatName string + context map[string]any +} + +func (cr *contextualizedCaveat) ContextualizedCaveat() (*core.ContextualizedCaveat, error) { + if cr == nil { + return nil, nil + } + v, err := structpb.NewStruct(cr.context) + if err != nil { + return nil, err + } + return &core.ContextualizedCaveat{ + CaveatName: cr.caveatName, + Context: v, + }, nil +} + +func (r relationship) String() string { + caveat := "" + if r.caveat != nil { + caveat = "[" + r.caveat.caveatName + "]" + } + + expiration := "" + if r.expiration != nil { + expiration = "[expiration:" + r.expiration.Format(time.RFC3339Nano) + "]" + } + + return r.namespace + ":" + r.resourceID + "#" + r.relation + "@" + r.subjectNamespace + ":" + r.subjectObjectID + "#" + r.subjectRelation + caveat + expiration +} + +func (r relationship) MarshalZerologObject(e *zerolog.Event) { + e.Str("rel", r.String()) +} + +func (r relationship) Relationship() (tuple.Relationship, error) { + cr, err := r.caveat.ContextualizedCaveat() + if err != nil { + return tuple.Relationship{}, err + } + + var ig *core.RelationshipIntegrity + if r.integrity != nil { + ig = r.integrity.RelationshipIntegrity() + } + + return tuple.Relationship{ + RelationshipReference: tuple.RelationshipReference{ + Resource: tuple.ObjectAndRelation{ + ObjectType: r.namespace, + ObjectID: r.resourceID, + Relation: r.relation, + }, + Subject: tuple.ObjectAndRelation{ + ObjectType: r.subjectNamespace, + ObjectID: r.subjectObjectID, + Relation: r.subjectRelation, + }, + }, + OptionalCaveat: cr, + OptionalIntegrity: ig, + OptionalExpiration: r.expiration, + }, nil +} + +type changelog struct { + revisionNanos int64 + changes datastore.RevisionChanges +} + +var schema = &memdb.DBSchema{ + Tables: map[string]*memdb.TableSchema{ + tableNamespace: { + Name: tableNamespace, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, + tableChangelog: { + Name: tableChangelog, + Indexes: map[string]*memdb.IndexSchema{ + indexRevision: { + Name: indexRevision, + Unique: true, + Indexer: &memdb.IntFieldIndex{Field: "revisionNanos"}, + }, + }, + }, + tableRelationship: { + Name: tableRelationship, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{Field: "namespace"}, + &memdb.StringFieldIndex{Field: "resourceID"}, + &memdb.StringFieldIndex{Field: "relation"}, + &memdb.StringFieldIndex{Field: "subjectNamespace"}, + &memdb.StringFieldIndex{Field: "subjectObjectID"}, + &memdb.StringFieldIndex{Field: "subjectRelation"}, + }, + }, + }, + indexNamespace: { + Name: indexNamespace, + Unique: false, + Indexer: &memdb.StringFieldIndex{Field: "namespace"}, + }, + indexNamespaceAndRelation: { + Name: indexNamespaceAndRelation, + Unique: false, + Indexer: &memdb.CompoundIndex{ + Indexes: []memdb.Indexer{ + &memdb.StringFieldIndex{Field: "namespace"}, + &memdb.StringFieldIndex{Field: "relation"}, + }, + }, + }, + indexSubjectNamespace: { + Name: indexSubjectNamespace, + Unique: false, + Indexer: &memdb.StringFieldIndex{Field: "subjectNamespace"}, + }, + }, + }, + tableCaveats: { + Name: tableCaveats, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, + tableCounters: { + Name: tableCounters, + Indexes: map[string]*memdb.IndexSchema{ + indexID: { + Name: indexID, + Unique: true, + Indexer: &memdb.StringFieldIndex{Field: "name"}, + }, + }, + }, + }, +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go new file mode 100644 index 0000000..33665a1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go @@ -0,0 +1,51 @@ +package memdb + +import ( + "context" + "fmt" + + "github.com/authzed/spicedb/pkg/datastore" +) + +func (mdb *memdbDatastore) Statistics(ctx context.Context) (datastore.Stats, error) { + head, err := mdb.HeadRevision(ctx) + if err != nil { + return datastore.Stats{}, fmt.Errorf("unable to compute head revision: %w", err) + } + + count, err := mdb.countRelationships(ctx) + if err != nil { + return datastore.Stats{}, fmt.Errorf("unable to count relationships: %w", err) + } + + objTypes, err := mdb.SnapshotReader(head).ListAllNamespaces(ctx) + if err != nil { + return datastore.Stats{}, fmt.Errorf("unable to list object types: %w", err) + } + + return datastore.Stats{ + UniqueID: mdb.uniqueID, + EstimatedRelationshipCount: count, + ObjectTypeStatistics: datastore.ComputeObjectTypeStats(objTypes), + }, nil +} + +func (mdb *memdbDatastore) countRelationships(_ context.Context) (uint64, error) { + mdb.RLock() + defer mdb.RUnlock() + + txn := mdb.db.Txn(false) + defer txn.Abort() + + it, err := txn.LowerBound(tableRelationship, indexID) + if err != nil { + return 0, err + } + + var count uint64 + for row := it.Next(); row != nil; row = it.Next() { + count++ + } + + return count, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go new file mode 100644 index 0000000..eaa4812 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go @@ -0,0 +1,148 @@ +package memdb + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/hashicorp/go-memdb" + + "github.com/authzed/spicedb/internal/datastore/revisions" + "github.com/authzed/spicedb/pkg/datastore" +) + +const errWatchError = "watch error: %w" + +func (mdb *memdbDatastore) Watch(ctx context.Context, ar datastore.Revision, options datastore.WatchOptions) (<-chan datastore.RevisionChanges, <-chan error) { + watchBufferLength := options.WatchBufferLength + if watchBufferLength == 0 { + watchBufferLength = mdb.watchBufferLength + } + + updates := make(chan datastore.RevisionChanges, watchBufferLength) + errs := make(chan error, 1) + + if options.EmissionStrategy == datastore.EmitImmediatelyStrategy { + close(updates) + errs <- errors.New("emit immediately strategy is unsupported in MemDB") + return updates, errs + } + + watchBufferWriteTimeout := options.WatchBufferWriteTimeout + if watchBufferWriteTimeout == 0 { + watchBufferWriteTimeout = mdb.watchBufferWriteTimeout + } + + sendChange := func(change datastore.RevisionChanges) bool { + select { + case updates <- change: + return true + + default: + // If we cannot immediately write, setup the timer and try again. + } + + timer := time.NewTimer(watchBufferWriteTimeout) + defer timer.Stop() + + select { + case updates <- change: + return true + + case <-timer.C: + errs <- datastore.NewWatchDisconnectedErr() + return false + } + } + + go func() { + defer close(updates) + defer close(errs) + + currentTxn := ar.(revisions.TimestampRevision).TimestampNanoSec() + + for { + var stagedUpdates []datastore.RevisionChanges + var watchChan <-chan struct{} + var err error + stagedUpdates, currentTxn, watchChan, err = mdb.loadChanges(ctx, currentTxn, options) + if err != nil { + errs <- err + return + } + + // Write the staged updates to the channel + for _, changeToWrite := range stagedUpdates { + if !sendChange(changeToWrite) { + return + } + } + + // Wait for new changes + ws := memdb.NewWatchSet() + ws.Add(watchChan) + + err = ws.WatchCtx(ctx) + if err != nil { + switch { + case errors.Is(err, context.Canceled): + errs <- datastore.NewWatchCanceledErr() + default: + errs <- fmt.Errorf(errWatchError, err) + } + return + } + } + }() + + return updates, errs +} + +func (mdb *memdbDatastore) loadChanges(_ context.Context, currentTxn int64, options datastore.WatchOptions) ([]datastore.RevisionChanges, int64, <-chan struct{}, error) { + mdb.RLock() + defer mdb.RUnlock() + + if err := mdb.checkNotClosed(); err != nil { + return nil, 0, nil, err + } + + loadNewTxn := mdb.db.Txn(false) + defer loadNewTxn.Abort() + + it, err := loadNewTxn.LowerBound(tableChangelog, indexRevision, currentTxn+1) + if err != nil { + return nil, 0, nil, fmt.Errorf(errWatchError, err) + } + + var changes []datastore.RevisionChanges + lastRevision := currentTxn + for changeRaw := it.Next(); changeRaw != nil; changeRaw = it.Next() { + change := changeRaw.(*changelog) + + if options.Content&datastore.WatchRelationships == datastore.WatchRelationships && len(change.changes.RelationshipChanges) > 0 { + changes = append(changes, change.changes) + } + + if options.Content&datastore.WatchSchema == datastore.WatchSchema && + len(change.changes.ChangedDefinitions) > 0 || len(change.changes.DeletedCaveats) > 0 || len(change.changes.DeletedNamespaces) > 0 { + changes = append(changes, change.changes) + } + + if options.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints && change.revisionNanos > lastRevision { + changes = append(changes, datastore.RevisionChanges{ + Revision: revisions.NewForTimestamp(change.revisionNanos), + IsCheckpoint: true, + }) + } + + lastRevision = change.revisionNanos + } + + watchChan, _, err := loadNewTxn.LastWatch(tableChangelog, indexRevision) + if err != nil { + return nil, 0, nil, fmt.Errorf(errWatchError, err) + } + + return changes, lastRevision, watchChan, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go new file mode 100644 index 0000000..7092728 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go @@ -0,0 +1,79 @@ +package revisions + +import ( + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// RevisionKind is an enum of the different kinds of revisions that can be used. +type RevisionKind string + +const ( + // Timestamp is a revision that is a timestamp. + Timestamp RevisionKind = "timestamp" + + // TransactionID is a revision that is a transaction ID. + TransactionID = "txid" + + // HybridLogicalClock is a revision that is a hybrid logical clock. + HybridLogicalClock = "hlc" +) + +// ParsingFunc is a function that can parse a string into a revision. +type ParsingFunc func(revisionStr string) (rev datastore.Revision, err error) + +// RevisionParser returns a ParsingFunc for the given RevisionKind. +func RevisionParser(kind RevisionKind) ParsingFunc { + switch kind { + case TransactionID: + return parseTransactionIDRevisionString + + case Timestamp: + return parseTimestampRevisionString + + case HybridLogicalClock: + return parseHLCRevisionString + + default: + return func(revisionStr string) (rev datastore.Revision, err error) { + return nil, spiceerrors.MustBugf("unknown revision kind: %v", kind) + } + } +} + +// CommonDecoder is a revision decoder that can decode revisions of a given kind. +type CommonDecoder struct { + Kind RevisionKind +} + +func (cd CommonDecoder) RevisionFromString(s string) (datastore.Revision, error) { + switch cd.Kind { + case TransactionID: + return parseTransactionIDRevisionString(s) + + case Timestamp: + return parseTimestampRevisionString(s) + + case HybridLogicalClock: + return parseHLCRevisionString(s) + + default: + return nil, spiceerrors.MustBugf("unknown revision kind in decoder: %v", cd.Kind) + } +} + +// WithInexactFloat64 is an interface that can be implemented by a revision to +// provide an inexact float64 representation of the revision. +type WithInexactFloat64 interface { + // InexactFloat64 returns a float64 that is an inexact representation of the + // revision. + InexactFloat64() float64 +} + +// WithTimestampRevision is an interface that can be implemented by a revision to +// provide a timestamp. +type WithTimestampRevision interface { + datastore.Revision + TimestampNanoSec() int64 + ConstructForTimestamp(timestampNanoSec int64) WithTimestampRevision +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go new file mode 100644 index 0000000..e4f7fc6 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go @@ -0,0 +1,166 @@ +package revisions + +import ( + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/ccoveille/go-safecast" + "github.com/shopspring/decimal" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +var zeroHLC = HLCRevision{} + +// NOTE: This *must* match the length defined in CRDB or the implementation below will break. +const logicalClockLength = 10 + +var logicalClockOffset = uint32(math.Pow10(logicalClockLength + 1)) + +// HLCRevision is a revision that is a hybrid logical clock, stored as two integers. +// The first integer is the timestamp in nanoseconds, and the second integer is the +// logical clock defined as 11 digits, with the first digit being ignored to ensure +// precision of the given logical clock. +type HLCRevision struct { + time int64 + logicalclock uint32 +} + +// parseHLCRevisionString parses a string into a hybrid logical clock revision. +func parseHLCRevisionString(revisionStr string) (datastore.Revision, error) { + pieces := strings.Split(revisionStr, ".") + if len(pieces) == 1 { + // If there is no decimal point, assume the revision is a timestamp. + timestamp, err := strconv.ParseInt(pieces[0], 10, 64) + if err != nil { + return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) + } + return HLCRevision{timestamp, logicalClockOffset}, nil + } + + if len(pieces) != 2 { + return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) + } + + timestamp, err := strconv.ParseInt(pieces[0], 10, 64) + if err != nil { + return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) + } + + if len(pieces[1]) > logicalClockLength { + return datastore.NoRevision, spiceerrors.MustBugf("invalid revision string due to unexpected logical clock size (%d): %q", len(pieces[1]), revisionStr) + } + + paddedLogicalClockStr := pieces[1] + strings.Repeat("0", logicalClockLength-len(pieces[1])) + logicalclock, err := strconv.ParseUint(paddedLogicalClockStr, 10, 64) + if err != nil { + return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr) + } + + if logicalclock > math.MaxUint32 { + return datastore.NoRevision, spiceerrors.MustBugf("received logical lock that exceeds MaxUint32 (%d > %d): revision %q", logicalclock, math.MaxUint32, revisionStr) + } + + uintLogicalClock, err := safecast.ToUint32(logicalclock) + if err != nil { + return datastore.NoRevision, spiceerrors.MustBugf("could not cast logicalclock to uint32: %v", err) + } + + return HLCRevision{timestamp, uintLogicalClock + logicalClockOffset}, nil +} + +// HLCRevisionFromString parses a string into a hybrid logical clock revision. +func HLCRevisionFromString(revisionStr string) (HLCRevision, error) { + rev, err := parseHLCRevisionString(revisionStr) + if err != nil { + return zeroHLC, err + } + + return rev.(HLCRevision), nil +} + +// NewForHLC creates a new revision for the given hybrid logical clock. +func NewForHLC(decimal decimal.Decimal) (HLCRevision, error) { + rev, err := HLCRevisionFromString(decimal.String()) + if err != nil { + return zeroHLC, fmt.Errorf("invalid HLC decimal: %v (%s) => %w", decimal, decimal.String(), err) + } + + return rev, nil +} + +// NewHLCForTime creates a new revision for the given time. +func NewHLCForTime(time time.Time) HLCRevision { + return HLCRevision{time.UnixNano(), logicalClockOffset} +} + +func (hlc HLCRevision) ByteSortable() bool { + return true +} + +func (hlc HLCRevision) Equal(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroHLC + } + + rhsHLC := rhs.(HLCRevision) + return hlc.time == rhsHLC.time && hlc.logicalclock == rhsHLC.logicalclock +} + +func (hlc HLCRevision) GreaterThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroHLC + } + + rhsHLC := rhs.(HLCRevision) + return hlc.time > rhsHLC.time || (hlc.time == rhsHLC.time && hlc.logicalclock > rhsHLC.logicalclock) +} + +func (hlc HLCRevision) LessThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroHLC + } + + rhsHLC := rhs.(HLCRevision) + return hlc.time < rhsHLC.time || (hlc.time == rhsHLC.time && hlc.logicalclock < rhsHLC.logicalclock) +} + +func (hlc HLCRevision) String() string { + logicalClockString := strconv.FormatInt(int64(hlc.logicalclock)-int64(logicalClockOffset), 10) + return strconv.FormatInt(hlc.time, 10) + "." + strings.Repeat("0", logicalClockLength-len(logicalClockString)) + logicalClockString +} + +func (hlc HLCRevision) TimestampNanoSec() int64 { + return hlc.time +} + +func (hlc HLCRevision) InexactFloat64() float64 { + return float64(hlc.time) + float64(hlc.logicalclock-logicalClockOffset)/math.Pow10(logicalClockLength) +} + +func (hlc HLCRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevision { + return HLCRevision{timestamp, logicalClockOffset} +} + +func (hlc HLCRevision) AsDecimal() (decimal.Decimal, error) { + return decimal.NewFromString(hlc.String()) +} + +var ( + _ datastore.Revision = HLCRevision{} + _ WithTimestampRevision = HLCRevision{} +) + +// HLCKeyFunc is used to convert a simple HLC for use in maps. +func HLCKeyFunc(r HLCRevision) HLCRevision { + return r +} + +// HLCKeyLessThanFunc is used to compare keys created by the HLCKeyFunc. +func HLCKeyLessThanFunc(lhs, rhs HLCRevision) bool { + return lhs.time < rhs.time || (lhs.time == rhs.time && lhs.logicalclock < rhs.logicalclock) +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go new file mode 100644 index 0000000..3a5a919 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go @@ -0,0 +1,118 @@ +package revisions + +import ( + "context" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/benbjohnson/clock" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/singleflight" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" +) + +var tracer = otel.Tracer("spicedb/internal/datastore/common/revisions") + +// OptimizedRevisionFunction instructs the datastore to compute its own current +// optimized revision given the specific quantization, and return for how long +// it will remain valid. +type OptimizedRevisionFunction func(context.Context) (rev datastore.Revision, validFor time.Duration, err error) + +// NewCachedOptimizedRevisions returns a CachedOptimizedRevisions for the given configuration +func NewCachedOptimizedRevisions(maxRevisionStaleness time.Duration) *CachedOptimizedRevisions { + return &CachedOptimizedRevisions{ + maxRevisionStaleness: maxRevisionStaleness, + clockFn: clock.New(), + } +} + +// SetOptimizedRevisionFunc must be called after construction, and is the method +// by which one specializes this helper for a specific datastore. +func (cor *CachedOptimizedRevisions) SetOptimizedRevisionFunc(revisionFunc OptimizedRevisionFunction) { + cor.optimizedFunc = revisionFunc +} + +func (cor *CachedOptimizedRevisions) OptimizedRevision(ctx context.Context) (datastore.Revision, error) { + span := trace.SpanFromContext(ctx) + localNow := cor.clockFn.Now() + + // Subtract a random amount of time from now, to let barely expired candidates get selected + adjustedNow := localNow + if cor.maxRevisionStaleness > 0 { + // nolint:gosec + // G404 use of non cryptographically secure random number generator is not a security concern here, + // as we are using it to introduce randomness to the accepted staleness of a revision and reduce the odds of + // a thundering herd to the datastore + adjustedNow = localNow.Add(-1 * time.Duration(rand.Int63n(cor.maxRevisionStaleness.Nanoseconds())) * time.Nanosecond) + } + + cor.RLock() + for _, candidate := range cor.candidates { + if candidate.validThrough.After(adjustedNow) { + cor.RUnlock() + log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", candidate.validThrough).Msg("returning cached revision") + span.AddEvent("returning cached revision") + return candidate.revision, nil + } + } + cor.RUnlock() + + newQuantizedRevision, err, _ := cor.updateGroup.Do("", func() (interface{}, error) { + log.Ctx(ctx).Debug().Time("now", localNow).Msg("computing new revision") + span.AddEvent("computing new revision") + + optimized, validFor, err := cor.optimizedFunc(ctx) + if err != nil { + return nil, fmt.Errorf("unable to compute optimized revision: %w", err) + } + + rvt := localNow.Add(validFor) + + // Prune the candidates that have definitely expired + cor.Lock() + var numToDrop uint + for _, candidate := range cor.candidates { + if candidate.validThrough.Add(cor.maxRevisionStaleness).Before(localNow) { + numToDrop++ + } else { + break + } + } + + cor.candidates = cor.candidates[numToDrop:] + cor.candidates = append(cor.candidates, validRevision{optimized, rvt}) + cor.Unlock() + + log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", rvt).Stringer("validFor", validFor).Msg("setting valid through") + return optimized, nil + }) + if err != nil { + return datastore.NoRevision, err + } + return newQuantizedRevision.(datastore.Revision), err +} + +// CachedOptimizedRevisions does caching and deduplication for requests for optimized revisions. +type CachedOptimizedRevisions struct { + sync.RWMutex + + maxRevisionStaleness time.Duration + optimizedFunc OptimizedRevisionFunction + clockFn clock.Clock + + // these values are read and set by multiple consumers + candidates []validRevision // GUARDED_BY(RWMutex) + + // the updategroup consolidates concurrent requests to the database into 1 + updateGroup singleflight.Group +} + +type validRevision struct { + revision datastore.Revision + validThrough time.Time +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go new file mode 100644 index 0000000..ef793c8 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go @@ -0,0 +1,125 @@ +package revisions + +import ( + "context" + "time" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// RemoteNowFunction queries the datastore to get a current revision. +type RemoteNowFunction func(context.Context) (datastore.Revision, error) + +// RemoteClockRevisions handles revision calculation for datastores that provide +// their own clocks. +type RemoteClockRevisions struct { + *CachedOptimizedRevisions + + gcWindowNanos int64 + nowFunc RemoteNowFunction + followerReadDelayNanos int64 + quantizationNanos int64 +} + +// NewRemoteClockRevisions returns a RemoteClockRevisions for the given configuration +func NewRemoteClockRevisions(gcWindow, maxRevisionStaleness, followerReadDelay, quantization time.Duration) *RemoteClockRevisions { + // Ensure the max revision staleness never exceeds the GC window. + if maxRevisionStaleness > gcWindow { + log.Warn(). + Dur("maxRevisionStaleness", maxRevisionStaleness). + Dur("gcWindow", gcWindow). + Msg("the configured maximum revision staleness exceeds the configured gc window, so capping to gcWindow") + maxRevisionStaleness = gcWindow - 1 + } + + revisions := &RemoteClockRevisions{ + CachedOptimizedRevisions: NewCachedOptimizedRevisions( + maxRevisionStaleness, + ), + gcWindowNanos: gcWindow.Nanoseconds(), + followerReadDelayNanos: followerReadDelay.Nanoseconds(), + quantizationNanos: quantization.Nanoseconds(), + } + + revisions.SetOptimizedRevisionFunc(revisions.optimizedRevisionFunc) + + return revisions +} + +func (rcr *RemoteClockRevisions) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, error) { + nowRev, err := rcr.nowFunc(ctx) + if err != nil { + return datastore.NoRevision, 0, err + } + + if nowRev == datastore.NoRevision { + return datastore.NoRevision, 0, datastore.NewInvalidRevisionErr(nowRev, datastore.CouldNotDetermineRevision) + } + + nowTS, ok := nowRev.(WithTimestampRevision) + if !ok { + return datastore.NoRevision, 0, spiceerrors.MustBugf("expected with-timestamp revision, got %T", nowRev) + } + + delayedNow := nowTS.TimestampNanoSec() - rcr.followerReadDelayNanos + quantized := delayedNow + validForNanos := int64(0) + if rcr.quantizationNanos > 0 { + afterLastQuantization := delayedNow % rcr.quantizationNanos + quantized -= afterLastQuantization + validForNanos = rcr.quantizationNanos - afterLastQuantization + } + log.Ctx(ctx).Debug(). + Time("quantized", time.Unix(0, quantized)). + Int64("readSkew", rcr.followerReadDelayNanos). + Int64("totalSkew", nowTS.TimestampNanoSec()-quantized). + Msg("revision skews") + + return nowTS.ConstructForTimestamp(quantized), time.Duration(validForNanos) * time.Nanosecond, nil +} + +// SetNowFunc sets the function used to determine the head revision +func (rcr *RemoteClockRevisions) SetNowFunc(nowFunc RemoteNowFunction) { + rcr.nowFunc = nowFunc +} + +func (rcr *RemoteClockRevisions) CheckRevision(ctx context.Context, dsRevision datastore.Revision) error { + if dsRevision == datastore.NoRevision { + return datastore.NewInvalidRevisionErr(dsRevision, datastore.CouldNotDetermineRevision) + } + + revision := dsRevision.(WithTimestampRevision) + + ctx, span := tracer.Start(ctx, "CheckRevision") + defer span.End() + + // Make sure the system time indicated is within the software GC window + now, err := rcr.nowFunc(ctx) + if err != nil { + return err + } + + nowTS, ok := now.(WithTimestampRevision) + if !ok { + return spiceerrors.MustBugf("expected HLC revision, got %T", now) + } + + nowNanos := nowTS.TimestampNanoSec() + revisionNanos := revision.TimestampNanoSec() + + isStale := revisionNanos < (nowNanos - rcr.gcWindowNanos) + if isStale { + log.Ctx(ctx).Debug().Stringer("now", now).Stringer("revision", revision).Msg("stale revision") + return datastore.NewInvalidRevisionErr(revision, datastore.RevisionStale) + } + + isUnknown := revisionNanos > nowNanos + if isUnknown { + log.Ctx(ctx).Debug().Stringer("now", now).Stringer("revision", revision).Msg("unknown revision") + return datastore.NewInvalidRevisionErr(revision, datastore.CouldNotDetermineRevision) + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go new file mode 100644 index 0000000..fc2a250 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go @@ -0,0 +1,97 @@ +package revisions + +import ( + "fmt" + "strconv" + "time" + + "github.com/authzed/spicedb/pkg/datastore" +) + +// TimestampRevision is a revision that is a timestamp. +type TimestampRevision int64 + +var zeroTimestampRevision = TimestampRevision(0) + +// NewForTime creates a new revision for the given time. +func NewForTime(time time.Time) TimestampRevision { + return TimestampRevision(time.UnixNano()) +} + +// NewForTimestamp creates a new revision for the given timestamp. +func NewForTimestamp(timestampNanosec int64) TimestampRevision { + return TimestampRevision(timestampNanosec) +} + +// parseTimestampRevisionString parses a string into a timestamp revision. +func parseTimestampRevisionString(revisionStr string) (rev datastore.Revision, err error) { + parsed, err := strconv.ParseInt(revisionStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid integer revision: %w", err) + } + + return TimestampRevision(parsed), nil +} + +func (ir TimestampRevision) ByteSortable() bool { + return true +} + +func (ir TimestampRevision) Equal(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTimestampRevision + } + + return int64(ir) == int64(rhs.(TimestampRevision)) +} + +func (ir TimestampRevision) GreaterThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTimestampRevision + } + + return int64(ir) > int64(rhs.(TimestampRevision)) +} + +func (ir TimestampRevision) LessThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTimestampRevision + } + + return int64(ir) < int64(rhs.(TimestampRevision)) +} + +func (ir TimestampRevision) TimestampNanoSec() int64 { + return int64(ir) +} + +func (ir TimestampRevision) String() string { + return strconv.FormatInt(int64(ir), 10) +} + +func (ir TimestampRevision) Time() time.Time { + return time.Unix(0, int64(ir)) +} + +func (ir TimestampRevision) WithInexactFloat64() float64 { + return float64(ir) +} + +func (ir TimestampRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevision { + return TimestampRevision(timestamp) +} + +var ( + _ datastore.Revision = TimestampRevision(0) + _ WithTimestampRevision = TimestampRevision(0) +) + +// TimestampIDKeyFunc is used to create keys for timestamps. +func TimestampIDKeyFunc(r TimestampRevision) int64 { + return int64(r) +} + +// TimestampIDKeyLessThanFunc is used to create keys for timestamps. +func TimestampIDKeyLessThanFunc(l, r int64) bool { + return l < r +} diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go new file mode 100644 index 0000000..31d837f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go @@ -0,0 +1,80 @@ +package revisions + +import ( + "fmt" + "strconv" + + "github.com/authzed/spicedb/pkg/datastore" +) + +// TransactionIDRevision is a revision that is a transaction ID. +type TransactionIDRevision uint64 + +var zeroTransactionIDRevision = TransactionIDRevision(0) + +// NewForTransactionID creates a new revision for the given transaction ID. +func NewForTransactionID(transactionID uint64) TransactionIDRevision { + return TransactionIDRevision(transactionID) +} + +// parseTransactionIDRevisionString parses a string into a transaction ID revision. +func parseTransactionIDRevisionString(revisionStr string) (rev datastore.Revision, err error) { + parsed, err := strconv.ParseUint(revisionStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid integer revision: %w", err) + } + + return TransactionIDRevision(parsed), nil +} + +func (ir TransactionIDRevision) ByteSortable() bool { + return true +} + +func (ir TransactionIDRevision) Equal(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTransactionIDRevision + } + + return uint64(ir) == uint64(rhs.(TransactionIDRevision)) +} + +func (ir TransactionIDRevision) GreaterThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTransactionIDRevision + } + + return uint64(ir) > uint64(rhs.(TransactionIDRevision)) +} + +func (ir TransactionIDRevision) LessThan(rhs datastore.Revision) bool { + if rhs == datastore.NoRevision { + rhs = zeroTransactionIDRevision + } + + return uint64(ir) < uint64(rhs.(TransactionIDRevision)) +} + +func (ir TransactionIDRevision) TransactionID() uint64 { + return uint64(ir) +} + +func (ir TransactionIDRevision) String() string { + return strconv.FormatUint(uint64(ir), 10) +} + +func (ir TransactionIDRevision) WithInexactFloat64() float64 { + return float64(ir) +} + +var _ datastore.Revision = TransactionIDRevision(0) + +// TransactionIDKeyFunc is used to create keys for transaction IDs. +func TransactionIDKeyFunc(r TransactionIDRevision) uint64 { + return uint64(r) +} + +// TransactionIDKeyLessThanFunc is used to create keys for transaction IDs. +func TransactionIDKeyLessThanFunc(l, r uint64) bool { + return l < r +} diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/doc.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/doc.go new file mode 100644 index 0000000..325981d --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/doc.go @@ -0,0 +1,2 @@ +// Package developmentmembership defines operations with sets. To be used in tests only. +package developmentmembership diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/foundsubject.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/foundsubject.go new file mode 100644 index 0000000..bf93d56 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/foundsubject.go @@ -0,0 +1,127 @@ +package developmentmembership + +import ( + "sort" + "strings" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + + "github.com/authzed/spicedb/pkg/tuple" +) + +// NewFoundSubject creates a new FoundSubject for a subject and a set of its resources. +func NewFoundSubject(subject *core.DirectSubject, resources ...tuple.ObjectAndRelation) FoundSubject { + return FoundSubject{tuple.FromCoreObjectAndRelation(subject.Subject), nil, subject.CaveatExpression, NewONRSet(resources...)} +} + +// FoundSubject contains a single found subject and all the relationships in which that subject +// is a member which were found via the ONRs expansion. +type FoundSubject struct { + // subject is the subject found. + subject tuple.ObjectAndRelation + + // excludedSubjects are any subjects excluded. Only should be set if subject is a wildcard. + excludedSubjects []FoundSubject + + // caveatExpression is the conditional expression on the found subject. + caveatExpression *core.CaveatExpression + + // resources are the resources under which the subject lives that informed the locating + // of this subject for the root ONR. + resources ONRSet +} + +// GetSubjectId is named to match the Subject interface for the BaseSubjectSet. +// +//nolint:all +func (fs FoundSubject) GetSubjectId() string { + return fs.subject.ObjectID +} + +func (fs FoundSubject) GetCaveatExpression() *core.CaveatExpression { + return fs.caveatExpression +} + +func (fs FoundSubject) GetExcludedSubjects() []FoundSubject { + return fs.excludedSubjects +} + +// Subject returns the Subject of the FoundSubject. +func (fs FoundSubject) Subject() tuple.ObjectAndRelation { + return fs.subject +} + +// WildcardType returns the object type for the wildcard subject, if this is a wildcard subject. +func (fs FoundSubject) WildcardType() (string, bool) { + if fs.subject.ObjectID == tuple.PublicWildcard { + return fs.subject.ObjectType, true + } + + return "", false +} + +// ExcludedSubjectsFromWildcard returns those subjects excluded from the wildcard subject. +// If not a wildcard subject, returns false. +func (fs FoundSubject) ExcludedSubjectsFromWildcard() ([]FoundSubject, bool) { + if fs.subject.ObjectID == tuple.PublicWildcard { + return fs.excludedSubjects, true + } + + return nil, false +} + +func (fs FoundSubject) excludedSubjectStrings() []string { + excludedStrings := make([]string, 0, len(fs.excludedSubjects)) + for _, excludedSubject := range fs.excludedSubjects { + excludedSubjectString := tuple.StringONR(excludedSubject.subject) + if excludedSubject.GetCaveatExpression() != nil { + excludedSubjectString += "[...]" + } + excludedStrings = append(excludedStrings, excludedSubjectString) + } + + sort.Strings(excludedStrings) + return excludedStrings +} + +// ToValidationString returns the FoundSubject in a format that is consumable by the validationfile +// package. +func (fs FoundSubject) ToValidationString() string { + onrString := tuple.StringONR(fs.Subject()) + validationString := onrString + if fs.caveatExpression != nil { + validationString = validationString + "[...]" + } + + excluded, isWildcard := fs.ExcludedSubjectsFromWildcard() + if isWildcard && len(excluded) > 0 { + validationString = validationString + " - {" + strings.Join(fs.excludedSubjectStrings(), ", ") + "}" + } + + return validationString +} + +func (fs FoundSubject) String() string { + return fs.ToValidationString() +} + +// ParentResources returns all the resources in which the subject was found as per the expand. +func (fs FoundSubject) ParentResources() []tuple.ObjectAndRelation { + return fs.resources.AsSlice() +} + +// FoundSubjects contains the subjects found for a specific ONR. +type FoundSubjects struct { + // subjects is a map from the Subject ONR (as a string) to the FoundSubject information. + subjects *TrackingSubjectSet +} + +// ListFound returns a slice of all the FoundSubject's. +func (fs FoundSubjects) ListFound() []FoundSubject { + return fs.subjects.ToSlice() +} + +// LookupSubject returns the FoundSubject for a matching subject, if any. +func (fs FoundSubjects) LookupSubject(subject tuple.ObjectAndRelation) (FoundSubject, bool) { + return fs.subjects.Get(subject) +} diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/membership.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/membership.go new file mode 100644 index 0000000..caa5e8f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/membership.go @@ -0,0 +1,167 @@ +package developmentmembership + +import ( + "fmt" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// Set represents the set of membership for one or more ONRs, based on expansion +// trees. +type Set struct { + // objectsAndRelations is a map from an ONR (as a string) to the subjects found for that ONR. + objectsAndRelations map[string]FoundSubjects +} + +// SubjectsByONR returns a map from ONR (as a string) to the FoundSubjects for that ONR. +func (ms *Set) SubjectsByONR() map[string]FoundSubjects { + return ms.objectsAndRelations +} + +// NewMembershipSet constructs a new membership set. +// +// NOTE: This is designed solely for the developer API and should *not* be used in any performance +// sensitive code. +func NewMembershipSet() *Set { + return &Set{ + objectsAndRelations: map[string]FoundSubjects{}, + } +} + +// AddExpansion adds the expansion of an ONR to the membership set. Returns false if the ONR was already added. +// +// NOTE: The expansion tree *should* be the fully recursive expansion. +func (ms *Set) AddExpansion(onr tuple.ObjectAndRelation, expansion *core.RelationTupleTreeNode) (FoundSubjects, bool, error) { + onrString := tuple.StringONR(onr) + existing, ok := ms.objectsAndRelations[onrString] + if ok { + return existing, false, nil + } + + tss, err := populateFoundSubjects(onr, expansion) + if err != nil { + return FoundSubjects{}, false, err + } + + fs := tss.ToFoundSubjects() + ms.objectsAndRelations[onrString] = fs + return fs, true, nil +} + +// AccessibleExpansionSubjects returns a TrackingSubjectSet representing the set of accessible subjects in the expansion. +func AccessibleExpansionSubjects(treeNode *core.RelationTupleTreeNode) (*TrackingSubjectSet, error) { + return populateFoundSubjects(tuple.FromCoreObjectAndRelation(treeNode.Expanded), treeNode) +} + +func populateFoundSubjects(rootONR tuple.ObjectAndRelation, treeNode *core.RelationTupleTreeNode) (*TrackingSubjectSet, error) { + resource := rootONR + if treeNode.Expanded != nil { + resource = tuple.FromCoreObjectAndRelation(treeNode.Expanded) + } + + switch typed := treeNode.NodeType.(type) { + case *core.RelationTupleTreeNode_IntermediateNode: + switch typed.IntermediateNode.Operation { + case core.SetOperationUserset_UNION: + toReturn := NewTrackingSubjectSet() + for _, child := range typed.IntermediateNode.ChildNodes { + tss, err := populateFoundSubjects(resource, child) + if err != nil { + return nil, err + } + + err = toReturn.AddFrom(tss) + if err != nil { + return nil, err + } + } + + toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression) + return toReturn, nil + + case core.SetOperationUserset_INTERSECTION: + if len(typed.IntermediateNode.ChildNodes) == 0 { + return nil, fmt.Errorf("found intersection with no children") + } + + firstChildSet, err := populateFoundSubjects(rootONR, typed.IntermediateNode.ChildNodes[0]) + if err != nil { + return nil, err + } + + toReturn := NewTrackingSubjectSet() + err = toReturn.AddFrom(firstChildSet) + if err != nil { + return nil, err + } + + for _, child := range typed.IntermediateNode.ChildNodes[1:] { + childSet, err := populateFoundSubjects(rootONR, child) + if err != nil { + return nil, err + } + + updated, err := toReturn.Intersect(childSet) + if err != nil { + return nil, err + } + + toReturn = updated + } + + toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression) + return toReturn, nil + + case core.SetOperationUserset_EXCLUSION: + if len(typed.IntermediateNode.ChildNodes) == 0 { + return nil, fmt.Errorf("found exclusion with no children") + } + + firstChildSet, err := populateFoundSubjects(rootONR, typed.IntermediateNode.ChildNodes[0]) + if err != nil { + return nil, err + } + + toReturn := NewTrackingSubjectSet() + err = toReturn.AddFrom(firstChildSet) + if err != nil { + return nil, err + } + + for _, child := range typed.IntermediateNode.ChildNodes[1:] { + childSet, err := populateFoundSubjects(rootONR, child) + if err != nil { + return nil, err + } + toReturn = toReturn.Exclude(childSet) + } + + toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression) + return toReturn, nil + + default: + return nil, spiceerrors.MustBugf("unknown expand operation") + } + + case *core.RelationTupleTreeNode_LeafNode: + toReturn := NewTrackingSubjectSet() + for _, subject := range typed.LeafNode.Subjects { + fs := NewFoundSubject(subject) + err := toReturn.Add(fs) + if err != nil { + return nil, err + } + + fs.resources.Add(resource) + } + + toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression) + return toReturn, nil + + default: + return nil, spiceerrors.MustBugf("unknown TreeNode type") + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/onrset.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/onrset.go new file mode 100644 index 0000000..ad7fcfd --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/onrset.go @@ -0,0 +1,87 @@ +package developmentmembership + +import ( + "github.com/ccoveille/go-safecast" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/tuple" +) + +// TODO(jschorr): Replace with the generic set over tuple.ObjectAndRelation + +// ONRSet is a set of ObjectAndRelation's. +type ONRSet struct { + onrs *mapz.Set[tuple.ObjectAndRelation] +} + +// NewONRSet creates a new set. +func NewONRSet(onrs ...tuple.ObjectAndRelation) ONRSet { + created := ONRSet{ + onrs: mapz.NewSet[tuple.ObjectAndRelation](), + } + created.Update(onrs) + return created +} + +// Length returns the size of the set. +func (ons ONRSet) Length() uint64 { + // This is the length of a set so we should never fall out of bounds. + length, _ := safecast.ToUint64(ons.onrs.Len()) + return length +} + +// IsEmpty returns whether the set is empty. +func (ons ONRSet) IsEmpty() bool { + return ons.onrs.IsEmpty() +} + +// Has returns true if the set contains the given ONR. +func (ons ONRSet) Has(onr tuple.ObjectAndRelation) bool { + return ons.onrs.Has(onr) +} + +// Add adds the given ONR to the set. Returns true if the object was not in the set before this +// call and false otherwise. +func (ons ONRSet) Add(onr tuple.ObjectAndRelation) bool { + return ons.onrs.Add(onr) +} + +// Update updates the set by adding the given ONRs to it. +func (ons ONRSet) Update(onrs []tuple.ObjectAndRelation) { + for _, onr := range onrs { + ons.Add(onr) + } +} + +// UpdateFrom updates the set by adding the ONRs found in the other set to it. +func (ons ONRSet) UpdateFrom(otherSet ONRSet) { + if otherSet.onrs == nil { + return + } + ons.onrs.Merge(otherSet.onrs) +} + +// Intersect returns an intersection between this ONR set and the other set provided. +func (ons ONRSet) Intersect(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Intersect(otherSet.onrs)} +} + +// Subtract returns a subtraction from this ONR set of the other set provided. +func (ons ONRSet) Subtract(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Subtract(otherSet.onrs)} +} + +// Union returns a copy of this ONR set with the other set's elements added in. +func (ons ONRSet) Union(otherSet ONRSet) ONRSet { + return ONRSet{ons.onrs.Union(otherSet.onrs)} +} + +// AsSlice returns the ONRs found in the set as a slice. +func (ons ONRSet) AsSlice() []tuple.ObjectAndRelation { + slice := make([]tuple.ObjectAndRelation, 0, ons.Length()) + _ = ons.onrs.ForEach(func(onr tuple.ObjectAndRelation) error { + slice = append(slice, onr) + return nil + }) + return slice +} diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/trackingsubjectset.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/trackingsubjectset.go new file mode 100644 index 0000000..00f8836 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/trackingsubjectset.go @@ -0,0 +1,235 @@ +package developmentmembership + +import ( + "github.com/authzed/spicedb/internal/datasets" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +// TrackingSubjectSet defines a set that tracks accessible subjects and their associated +// relationships. +// +// NOTE: This is designed solely for the developer API and testing and should *not* be used in any +// performance sensitive code. +type TrackingSubjectSet struct { + setByType map[tuple.RelationReference]datasets.BaseSubjectSet[FoundSubject] +} + +// NewTrackingSubjectSet creates a new TrackingSubjectSet +func NewTrackingSubjectSet() *TrackingSubjectSet { + tss := &TrackingSubjectSet{ + setByType: map[tuple.RelationReference]datasets.BaseSubjectSet[FoundSubject]{}, + } + return tss +} + +// MustNewTrackingSubjectSetWith creates a new TrackingSubjectSet, and adds the specified +// subjects to it. +func MustNewTrackingSubjectSetWith(subjects ...FoundSubject) *TrackingSubjectSet { + tss := NewTrackingSubjectSet() + for _, subject := range subjects { + err := tss.Add(subject) + if err != nil { + panic(err) + } + } + return tss +} + +// AddFrom adds the subjects found in the other set to this set. +func (tss *TrackingSubjectSet) AddFrom(otherSet *TrackingSubjectSet) error { + for key, oss := range otherSet.setByType { + err := tss.getSetForKey(key).UnionWithSet(oss) + if err != nil { + return err + } + } + return nil +} + +// MustAddFrom adds the subjects found in the other set to this set. +func (tss *TrackingSubjectSet) MustAddFrom(otherSet *TrackingSubjectSet) { + err := tss.AddFrom(otherSet) + if err != nil { + panic(err) + } +} + +// RemoveFrom removes any subjects found in the other set from this set. +func (tss *TrackingSubjectSet) RemoveFrom(otherSet *TrackingSubjectSet) { + for key, oss := range otherSet.setByType { + tss.getSetForKey(key).SubtractAll(oss) + } +} + +// MustAdd adds the given subjects to this set. +func (tss *TrackingSubjectSet) MustAdd(subjectsAndResources ...FoundSubject) { + err := tss.Add(subjectsAndResources...) + if err != nil { + panic(err) + } +} + +// Add adds the given subjects to this set. +func (tss *TrackingSubjectSet) Add(subjectsAndResources ...FoundSubject) error { + for _, fs := range subjectsAndResources { + err := tss.getSet(fs).Add(fs) + if err != nil { + return err + } + } + return nil +} + +func (tss *TrackingSubjectSet) getSetForKey(key tuple.RelationReference) datasets.BaseSubjectSet[FoundSubject] { + if existing, ok := tss.setByType[key]; ok { + return existing + } + + created := datasets.NewBaseSubjectSet( + func(subjectID string, caveatExpression *core.CaveatExpression, excludedSubjects []FoundSubject, sources ...FoundSubject) FoundSubject { + fs := NewFoundSubject(&core.DirectSubject{ + Subject: &core.ObjectAndRelation{ + Namespace: key.ObjectType, + ObjectId: subjectID, + Relation: key.Relation, + }, + CaveatExpression: caveatExpression, + }) + fs.excludedSubjects = excludedSubjects + fs.caveatExpression = caveatExpression + for _, source := range sources { + fs.resources.UpdateFrom(source.resources) + } + return fs + }, + ) + tss.setByType[key] = created + return created +} + +func (tss *TrackingSubjectSet) getSet(fs FoundSubject) datasets.BaseSubjectSet[FoundSubject] { + return tss.getSetForKey(fs.subject.RelationReference()) +} + +// Get returns the found subject in the set, if any. +func (tss *TrackingSubjectSet) Get(subject tuple.ObjectAndRelation) (FoundSubject, bool) { + set, ok := tss.setByType[subject.RelationReference()] + if !ok { + return FoundSubject{}, false + } + + return set.Get(subject.ObjectID) +} + +// Contains returns true if the set contains the given subject. +func (tss *TrackingSubjectSet) Contains(subject tuple.ObjectAndRelation) bool { + _, ok := tss.Get(subject) + return ok +} + +// Exclude returns a new set that contains the items in this set minus those in the other set. +func (tss *TrackingSubjectSet) Exclude(otherSet *TrackingSubjectSet) *TrackingSubjectSet { + newSet := NewTrackingSubjectSet() + + for key, bss := range tss.setByType { + cloned := bss.Clone() + if oss, ok := otherSet.setByType[key]; ok { + cloned.SubtractAll(oss) + } + + newSet.setByType[key] = cloned + } + + return newSet +} + +// MustIntersect returns a new set that contains the items in this set *and* the other set. Note that +// if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found +// on the other side of the intersection. +func (tss *TrackingSubjectSet) MustIntersect(otherSet *TrackingSubjectSet) *TrackingSubjectSet { + updated, err := tss.Intersect(otherSet) + if err != nil { + panic(err) + } + return updated +} + +// Intersect returns a new set that contains the items in this set *and* the other set. Note that +// if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found +// on the other side of the intersection. +func (tss *TrackingSubjectSet) Intersect(otherSet *TrackingSubjectSet) (*TrackingSubjectSet, error) { + newSet := NewTrackingSubjectSet() + + for key, bss := range tss.setByType { + if oss, ok := otherSet.setByType[key]; ok { + cloned := bss.Clone() + err := cloned.IntersectionDifference(oss) + if err != nil { + return nil, err + } + + newSet.setByType[key] = cloned + } + } + + return newSet, nil +} + +// ApplyParentCaveatExpression applies the given parent caveat expression (if any) to each subject set. +func (tss *TrackingSubjectSet) ApplyParentCaveatExpression(parentCaveatExpr *core.CaveatExpression) { + if parentCaveatExpr == nil { + return + } + + for key, bss := range tss.setByType { + tss.setByType[key] = bss.WithParentCaveatExpression(parentCaveatExpr) + } +} + +// removeExact removes the given subject(s) from the set. If the subject is a wildcard, only +// the exact matching wildcard will be removed. +func (tss *TrackingSubjectSet) removeExact(subjects ...tuple.ObjectAndRelation) { + for _, subject := range subjects { + if set, ok := tss.setByType[subject.RelationReference()]; ok { + set.UnsafeRemoveExact(FoundSubject{ + subject: subject, + }) + } + } +} + +func (tss *TrackingSubjectSet) getSubjects() []string { + var subjects []string + for _, subjectSet := range tss.setByType { + for _, foundSubject := range subjectSet.AsSlice() { + subjects = append(subjects, tuple.StringONR(foundSubject.subject)) + } + } + return subjects +} + +// ToSlice returns a slice of all subjects found in the set. +func (tss *TrackingSubjectSet) ToSlice() []FoundSubject { + subjects := []FoundSubject{} + for _, bss := range tss.setByType { + subjects = append(subjects, bss.AsSlice()...) + } + + return subjects +} + +// ToFoundSubjects returns the set as a FoundSubjects struct. +func (tss *TrackingSubjectSet) ToFoundSubjects() FoundSubjects { + return FoundSubjects{tss} +} + +// IsEmpty returns true if the tracking subject set is empty. +func (tss *TrackingSubjectSet) IsEmpty() bool { + for _, bss := range tss.setByType { + if !bss.IsEmpty() { + return false + } + } + return true +} diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/dispatch.go b/vendor/github.com/authzed/spicedb/internal/dispatch/dispatch.go new file mode 100644 index 0000000..95a231a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/dispatch/dispatch.go @@ -0,0 +1,98 @@ +package dispatch + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + + log "github.com/authzed/spicedb/internal/logging" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" +) + +// ReadyState represents the ready state of the dispatcher. +type ReadyState struct { + // Message is a human-readable status message for the current state. + Message string + + // IsReady indicates whether the datastore is ready. + IsReady bool +} + +// Dispatcher interface describes a method for passing subchecks off to additional machines. +type Dispatcher interface { + Check + Expand + LookupSubjects + LookupResources2 + + // Close closes the dispatcher. + Close() error + + // ReadyState returns true when dispatcher is able to respond to requests + ReadyState() ReadyState +} + +// Check interface describes just the methods required to dispatch check requests. +type Check interface { + // DispatchCheck submits a single check request and returns its result. + DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) +} + +// Expand interface describes just the methods required to dispatch expand requests. +type Expand interface { + // DispatchExpand submits a single expand request and returns its result. + // If an error is returned, DispatchExpandResponse will still contain Metadata. + DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) +} + +type LookupResources2Stream = Stream[*v1.DispatchLookupResources2Response] + +type LookupResources2 interface { + DispatchLookupResources2( + req *v1.DispatchLookupResources2Request, + stream LookupResources2Stream, + ) error +} + +// LookupSubjectsStream is an alias for the stream to which found subjects will be written. +type LookupSubjectsStream = Stream[*v1.DispatchLookupSubjectsResponse] + +// LookupSubjects interface describes just the methods required to dispatch lookup subjects requests. +type LookupSubjects interface { + // DispatchLookupSubjects submits a single lookup subjects request, writing its results to the specified stream. + DispatchLookupSubjects( + req *v1.DispatchLookupSubjectsRequest, + stream LookupSubjectsStream, + ) error +} + +// DispatchableRequest is an interface for requests. +type DispatchableRequest interface { + zerolog.LogObjectMarshaler + + GetMetadata() *v1.ResolverMeta +} + +// CheckDepth returns ErrMaxDepth if there is insufficient depth remaining to dispatch. +func CheckDepth(ctx context.Context, req DispatchableRequest) error { + metadata := req.GetMetadata() + if metadata == nil { + log.Ctx(ctx).Warn().Object("request", req).Msg("request missing metadata") + return fmt.Errorf("request missing metadata") + } + + if metadata.DepthRemaining == 0 { + return NewMaxDepthExceededError(req) + } + + return nil +} + +// AddResponseMetadata adds the metadata found in the incoming metadata to the existing +// metadata, *modifying it in place*. +func AddResponseMetadata(existing *v1.ResponseMeta, incoming *v1.ResponseMeta) { + existing.DispatchCount += incoming.DispatchCount + existing.CachedDispatchCount += incoming.CachedDispatchCount + existing.DepthRequired = max(existing.DepthRequired, incoming.DepthRequired) +} diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/doc.go b/vendor/github.com/authzed/spicedb/internal/dispatch/doc.go new file mode 100644 index 0000000..8b88bb0 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/dispatch/doc.go @@ -0,0 +1,2 @@ +// Package dispatch contains logic to dispatch requests locally or to other nodes. +package dispatch diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/errors.go b/vendor/github.com/authzed/spicedb/internal/dispatch/errors.go new file mode 100644 index 0000000..17cec3f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/dispatch/errors.go @@ -0,0 +1,39 @@ +package dispatch + +import ( + "fmt" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// MaxDepthExceededError is an error returned when the maximum depth for dispatching has been exceeded. +type MaxDepthExceededError struct { + error + + // Request is the request that exceeded the maximum depth. + Request DispatchableRequest +} + +// NewMaxDepthExceededError creates a new MaxDepthExceededError. +func NewMaxDepthExceededError(req DispatchableRequest) error { + return MaxDepthExceededError{ + fmt.Errorf("max depth exceeded: this usually indicates a recursive or too deep data dependency. See: https://spicedb.dev/d/debug-max-depth"), + req, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err MaxDepthExceededError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.ResourceExhausted, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_MAXIMUM_DEPTH_EXCEEDED, + map[string]string{}, + ), + ) +} diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/graph/errors.go b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/errors.go new file mode 100644 index 0000000..ecaf59a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/errors.go @@ -0,0 +1,77 @@ +package graph + +import ( + "fmt" + + "github.com/rs/zerolog" +) + +// NamespaceNotFoundError occurs when a namespace was not found. +type NamespaceNotFoundError struct { + error + namespaceName string +} + +// NotFoundNamespaceName returns the name of the namespace that was not found. +func (err NamespaceNotFoundError) NotFoundNamespaceName() string { + return err.namespaceName +} + +// MarshalZerologObject implements zerolog.LogObjectMarshaler +func (err NamespaceNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err NamespaceNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + } +} + +// NewNamespaceNotFoundErr constructs a new namespace not found error. +func NewNamespaceNotFoundErr(nsName string) error { + return NamespaceNotFoundError{ + error: fmt.Errorf("object definition `%s` not found", nsName), + namespaceName: nsName, + } +} + +// RelationNotFoundError occurs when a relation was not found under a namespace. +type RelationNotFoundError struct { + error + namespaceName string + relationName string +} + +// NamespaceName returns the name of the namespace in which the relation was not found. +func (err RelationNotFoundError) NamespaceName() string { + return err.namespaceName +} + +// NotFoundRelationName returns the name of the relation not found. +func (err RelationNotFoundError) NotFoundRelationName() string { + return err.relationName +} + +// MarshalZerologObject implements zerolog.LogObjectMarshaler +func (err RelationNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err RelationNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + "relation_or_permission_name": err.relationName, + } +} + +// NewRelationNotFoundErr constructs a new relation not found error. +func NewRelationNotFoundErr(nsName string, relationName string) error { + return RelationNotFoundError{ + error: fmt.Errorf("relation/permission `%s` not found under definition `%s`", relationName, nsName), + namespaceName: nsName, + relationName: relationName, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/graph/graph.go b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/graph.go new file mode 100644 index 0000000..232b1e7 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/graph.go @@ -0,0 +1,437 @@ +package graph + +import ( + "context" + "errors" + "fmt" + + "github.com/rs/zerolog" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph" + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/middleware/nodeid" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const errDispatch = "error dispatching request: %w" + +var tracer = otel.Tracer("spicedb/internal/dispatch/local") + +// ConcurrencyLimits defines per-dispatch-type concurrency limits. +// +//go:generate go run github.com/ecordell/optgen -output zz_generated.options.go . ConcurrencyLimits +type ConcurrencyLimits struct { + Check uint16 `debugmap:"visible"` + ReachableResources uint16 `debugmap:"visible"` + LookupResources uint16 `debugmap:"visible"` + LookupSubjects uint16 `debugmap:"visible"` +} + +const defaultConcurrencyLimit = 50 + +// WithOverallDefaultLimit sets the overall default limit for any unspecified limits +// and returns a new struct. +func (cl ConcurrencyLimits) WithOverallDefaultLimit(overallDefaultLimit uint16) ConcurrencyLimits { + return limitsOrDefaults(cl, overallDefaultLimit) +} + +func (cl ConcurrencyLimits) MarshalZerologObject(e *zerolog.Event) { + e.Uint16("concurrency-limit-check-permission", cl.Check) + e.Uint16("concurrency-limit-lookup-resources", cl.LookupResources) + e.Uint16("concurrency-limit-lookup-subjects", cl.LookupSubjects) + e.Uint16("concurrency-limit-reachable-resources", cl.ReachableResources) +} + +func limitsOrDefaults(limits ConcurrencyLimits, overallDefaultLimit uint16) ConcurrencyLimits { + limits.Check = limitOrDefault(limits.Check, overallDefaultLimit) + limits.LookupResources = limitOrDefault(limits.LookupResources, overallDefaultLimit) + limits.LookupSubjects = limitOrDefault(limits.LookupSubjects, overallDefaultLimit) + limits.ReachableResources = limitOrDefault(limits.ReachableResources, overallDefaultLimit) + return limits +} + +func limitOrDefault(limit uint16, defaultLimit uint16) uint16 { + if limit <= 0 { + return defaultLimit + } + return limit +} + +// SharedConcurrencyLimits returns a ConcurrencyLimits struct with the limit +// set to that provided for each operation. +func SharedConcurrencyLimits(concurrencyLimit uint16) ConcurrencyLimits { + return ConcurrencyLimits{ + Check: concurrencyLimit, + ReachableResources: concurrencyLimit, + LookupResources: concurrencyLimit, + LookupSubjects: concurrencyLimit, + } +} + +// NewLocalOnlyDispatcher creates a dispatcher that consults with the graph to formulate a response. +func NewLocalOnlyDispatcher(typeSet *caveattypes.TypeSet, concurrencyLimit uint16, dispatchChunkSize uint16) dispatch.Dispatcher { + return NewLocalOnlyDispatcherWithLimits(typeSet, SharedConcurrencyLimits(concurrencyLimit), dispatchChunkSize) +} + +// NewLocalOnlyDispatcherWithLimits creates a dispatcher thatg consults with the graph to formulate a response +// and has the defined concurrency limits per dispatch type. +func NewLocalOnlyDispatcherWithLimits(typeSet *caveattypes.TypeSet, concurrencyLimits ConcurrencyLimits, dispatchChunkSize uint16) dispatch.Dispatcher { + d := &localDispatcher{} + + concurrencyLimits = limitsOrDefaults(concurrencyLimits, defaultConcurrencyLimit) + chunkSize := dispatchChunkSize + if chunkSize == 0 { + chunkSize = 100 + log.Warn().Msgf("LocalOnlyDispatcher: dispatchChunkSize not set, defaulting to %d", chunkSize) + } + + d.checker = graph.NewConcurrentChecker(d, concurrencyLimits.Check, chunkSize) + d.expander = graph.NewConcurrentExpander(d) + d.lookupSubjectsHandler = graph.NewConcurrentLookupSubjects(d, concurrencyLimits.LookupSubjects, chunkSize) + d.lookupResourcesHandler2 = graph.NewCursoredLookupResources2(d, d, typeSet, concurrencyLimits.LookupResources, chunkSize) + + return d +} + +// NewDispatcher creates a dispatcher that consults with the graph and redispatches subproblems to +// the provided redispatcher. +func NewDispatcher(redispatcher dispatch.Dispatcher, typeSet *caveattypes.TypeSet, concurrencyLimits ConcurrencyLimits, dispatchChunkSize uint16) dispatch.Dispatcher { + concurrencyLimits = limitsOrDefaults(concurrencyLimits, defaultConcurrencyLimit) + chunkSize := dispatchChunkSize + if chunkSize == 0 { + chunkSize = 100 + log.Warn().Msgf("Dispatcher: dispatchChunkSize not set, defaulting to %d", chunkSize) + } + + checker := graph.NewConcurrentChecker(redispatcher, concurrencyLimits.Check, chunkSize) + expander := graph.NewConcurrentExpander(redispatcher) + lookupSubjectsHandler := graph.NewConcurrentLookupSubjects(redispatcher, concurrencyLimits.LookupSubjects, chunkSize) + lookupResourcesHandler2 := graph.NewCursoredLookupResources2(redispatcher, redispatcher, typeSet, concurrencyLimits.LookupResources, chunkSize) + + return &localDispatcher{ + checker: checker, + expander: expander, + lookupSubjectsHandler: lookupSubjectsHandler, + lookupResourcesHandler2: lookupResourcesHandler2, + } +} + +type localDispatcher struct { + checker *graph.ConcurrentChecker + expander *graph.ConcurrentExpander + lookupSubjectsHandler *graph.ConcurrentLookupSubjects + lookupResourcesHandler2 *graph.CursoredLookupResources2 +} + +func (ld *localDispatcher) loadNamespace(ctx context.Context, nsName string, revision datastore.Revision) (*core.NamespaceDefinition, error) { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(revision) + + // Load namespace and relation from the datastore + ns, _, err := ds.ReadNamespaceByName(ctx, nsName) + if err != nil { + return nil, rewriteNamespaceError(err) + } + + return ns, err +} + +func (ld *localDispatcher) parseRevision(ctx context.Context, s string) (datastore.Revision, error) { + ds := datastoremw.MustFromContext(ctx) + return ds.RevisionFromString(s) +} + +func (ld *localDispatcher) lookupRelation(_ context.Context, ns *core.NamespaceDefinition, relationName string) (*core.Relation, error) { + var relation *core.Relation + for _, candidate := range ns.Relation { + if candidate.Name == relationName { + relation = candidate + break + } + } + + if relation == nil { + return nil, NewRelationNotFoundErr(ns.Name, relationName) + } + + return relation, nil +} + +// DispatchCheck implements dispatch.Check interface +func (ld *localDispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) { + resourceType := tuple.StringCoreRR(req.ResourceRelation) + spanName := "DispatchCheck → " + resourceType + "@" + req.Subject.Namespace + "#" + req.Subject.Relation + + nodeID, err := nodeid.FromContext(ctx) + if err != nil { + log.Err(err).Msg("failed to get node ID") + } + + ctx, span := tracer.Start(ctx, spanName, trace.WithAttributes( + attribute.String("resource-type", resourceType), + attribute.StringSlice("resource-ids", req.ResourceIds), + attribute.String("subject", tuple.StringCoreONR(req.Subject)), + attribute.String("node-id", nodeID), + )) + defer span.End() + + if err := dispatch.CheckDepth(ctx, req); err != nil { + if req.Debug != v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING { + return &v1.DispatchCheckResponse{ + Metadata: &v1.ResponseMeta{ + DispatchCount: 0, + }, + }, rewriteError(ctx, err) + } + + // NOTE: we return debug information here to ensure tooling can see the cycle. + nodeID, nerr := nodeid.FromContext(ctx) + if nerr != nil { + log.Err(nerr).Msg("failed to get nodeID from context") + } + + return &v1.DispatchCheckResponse{ + Metadata: &v1.ResponseMeta{ + DispatchCount: 0, + DebugInfo: &v1.DebugInformation{ + Check: &v1.CheckDebugTrace{ + Request: req, + SourceId: nodeID, + }, + }, + }, + }, rewriteError(ctx, err) + } + + revision, err := ld.parseRevision(ctx, req.Metadata.AtRevision) + if err != nil { + return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err) + } + + ns, err := ld.loadNamespace(ctx, req.ResourceRelation.Namespace, revision) + if err != nil { + return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err) + } + + relation, err := ld.lookupRelation(ctx, ns, req.ResourceRelation.Relation) + if err != nil { + return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err) + } + + // If the relation is aliasing another one and the subject does not have the same type as + // resource, load the aliased relation and dispatch to it. We cannot use the alias if the + // resource and subject types are the same because a check on the *exact same* resource and + // subject must pass, and we don't know how many intermediate steps may hit that case. + if relation.AliasingRelation != "" && req.ResourceRelation.Namespace != req.Subject.Namespace { + relation, err := ld.lookupRelation(ctx, ns, relation.AliasingRelation) + if err != nil { + return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err) + } + + // Rewrite the request over the aliased relation. + validatedReq := graph.ValidatedCheckRequest{ + DispatchCheckRequest: &v1.DispatchCheckRequest{ + ResourceRelation: &core.RelationReference{ + Namespace: req.ResourceRelation.Namespace, + Relation: relation.Name, + }, + ResourceIds: req.ResourceIds, + Subject: req.Subject, + Metadata: req.Metadata, + Debug: req.Debug, + CheckHints: req.CheckHints, + }, + Revision: revision, + OriginalRelationName: req.ResourceRelation.Relation, + } + + resp, err := ld.checker.Check(ctx, validatedReq, relation) + return resp, rewriteError(ctx, err) + } + + resp, err := ld.checker.Check(ctx, graph.ValidatedCheckRequest{ + DispatchCheckRequest: req, + Revision: revision, + }, relation) + return resp, rewriteError(ctx, err) +} + +// DispatchExpand implements dispatch.Expand interface +func (ld *localDispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) { + nodeID, err := nodeid.FromContext(ctx) + if err != nil { + log.Err(err).Msg("failed to get node ID") + } + + ctx, span := tracer.Start(ctx, "DispatchExpand", trace.WithAttributes( + attribute.String("start", tuple.StringCoreONR(req.ResourceAndRelation)), + attribute.String("node-id", nodeID), + )) + defer span.End() + + if err := dispatch.CheckDepth(ctx, req); err != nil { + return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err + } + + revision, err := ld.parseRevision(ctx, req.Metadata.AtRevision) + if err != nil { + return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err + } + + ns, err := ld.loadNamespace(ctx, req.ResourceAndRelation.Namespace, revision) + if err != nil { + return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err + } + + relation, err := ld.lookupRelation(ctx, ns, req.ResourceAndRelation.Relation) + if err != nil { + return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err + } + + return ld.expander.Expand(ctx, graph.ValidatedExpandRequest{ + DispatchExpandRequest: req, + Revision: revision, + }, relation) +} + +func (ld *localDispatcher) DispatchLookupResources2( + req *v1.DispatchLookupResources2Request, + stream dispatch.LookupResources2Stream, +) error { + nodeID, err := nodeid.FromContext(stream.Context()) + if err != nil { + log.Err(err).Msg("failed to get node ID") + } + + ctx, span := tracer.Start(stream.Context(), "DispatchLookupResources2", trace.WithAttributes( + attribute.String("resource-type", tuple.StringCoreRR(req.ResourceRelation)), + attribute.String("subject-type", tuple.StringCoreRR(req.SubjectRelation)), + attribute.StringSlice("subject-ids", req.SubjectIds), + attribute.String("terminal-subject", tuple.StringCoreONR(req.TerminalSubject)), + attribute.String("node-id", nodeID), + )) + defer span.End() + + if err := dispatch.CheckDepth(ctx, req); err != nil { + return err + } + + revision, err := ld.parseRevision(ctx, req.Metadata.AtRevision) + if err != nil { + return err + } + + return ld.lookupResourcesHandler2.LookupResources2( + graph.ValidatedLookupResources2Request{ + DispatchLookupResources2Request: req, + Revision: revision, + }, + dispatch.StreamWithContext(ctx, stream), + ) +} + +// DispatchLookupSubjects implements dispatch.LookupSubjects interface +func (ld *localDispatcher) DispatchLookupSubjects( + req *v1.DispatchLookupSubjectsRequest, + stream dispatch.LookupSubjectsStream, +) error { + nodeID, err := nodeid.FromContext(stream.Context()) + if err != nil { + log.Err(err).Msg("failed to get node ID") + } + + resourceType := tuple.StringCoreRR(req.ResourceRelation) + subjectRelation := tuple.StringCoreRR(req.SubjectRelation) + spanName := "DispatchLookupSubjects → " + resourceType + "@" + subjectRelation + + ctx, span := tracer.Start(stream.Context(), spanName, trace.WithAttributes( + attribute.String("resource-type", resourceType), + attribute.String("subject-type", subjectRelation), + attribute.StringSlice("resource-ids", req.ResourceIds), + attribute.String("node-id", nodeID), + )) + defer span.End() + + if err := dispatch.CheckDepth(ctx, req); err != nil { + return err + } + + revision, err := ld.parseRevision(ctx, req.Metadata.AtRevision) + if err != nil { + return err + } + + return ld.lookupSubjectsHandler.LookupSubjects( + graph.ValidatedLookupSubjectsRequest{ + DispatchLookupSubjectsRequest: req, + Revision: revision, + }, + dispatch.StreamWithContext(ctx, stream), + ) +} + +func (ld *localDispatcher) Close() error { + return nil +} + +func (ld *localDispatcher) ReadyState() dispatch.ReadyState { + return dispatch.ReadyState{IsReady: true} +} + +func rewriteNamespaceError(original error) error { + nsNotFound := datastore.NamespaceNotFoundError{} + + switch { + case errors.As(original, &nsNotFound): + return NewNamespaceNotFoundErr(nsNotFound.NotFoundNamespaceName()) + case errors.As(original, &NamespaceNotFoundError{}): + fallthrough + case errors.As(original, &RelationNotFoundError{}): + return original + default: + return fmt.Errorf(errDispatch, original) + } +} + +// rewriteError transforms graph errors into a gRPC Status +func rewriteError(ctx context.Context, err error) error { + if err == nil { + return nil + } + + // Check if the error can be directly used. + if st, ok := status.FromError(err); ok { + return st.Err() + } + + switch { + case errors.Is(err, context.DeadlineExceeded): + return status.Errorf(codes.DeadlineExceeded, "%s", err) + case errors.Is(err, context.Canceled): + err := context.Cause(ctx) + if err != nil { + if _, ok := status.FromError(err); ok { + return err + } + } + + return status.Errorf(codes.Canceled, "%s", err) + default: + log.Ctx(ctx).Err(err).Msg("received unexpected graph error") + return err + } +} + +var emptyMetadata = &v1.ResponseMeta{ + DispatchCount: 0, +} diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/graph/zz_generated.options.go b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/zz_generated.options.go new file mode 100644 index 0000000..9a0a7fc --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/zz_generated.options.go @@ -0,0 +1,92 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package graph + +import ( + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" +) + +type ConcurrencyLimitsOption func(c *ConcurrencyLimits) + +// NewConcurrencyLimitsWithOptions creates a new ConcurrencyLimits with the passed in options set +func NewConcurrencyLimitsWithOptions(opts ...ConcurrencyLimitsOption) *ConcurrencyLimits { + c := &ConcurrencyLimits{} + for _, o := range opts { + o(c) + } + return c +} + +// NewConcurrencyLimitsWithOptionsAndDefaults creates a new ConcurrencyLimits with the passed in options set starting from the defaults +func NewConcurrencyLimitsWithOptionsAndDefaults(opts ...ConcurrencyLimitsOption) *ConcurrencyLimits { + c := &ConcurrencyLimits{} + defaults.MustSet(c) + for _, o := range opts { + o(c) + } + return c +} + +// ToOption returns a new ConcurrencyLimitsOption that sets the values from the passed in ConcurrencyLimits +func (c *ConcurrencyLimits) ToOption() ConcurrencyLimitsOption { + return func(to *ConcurrencyLimits) { + to.Check = c.Check + to.ReachableResources = c.ReachableResources + to.LookupResources = c.LookupResources + to.LookupSubjects = c.LookupSubjects + } +} + +// DebugMap returns a map form of ConcurrencyLimits for debugging +func (c ConcurrencyLimits) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["Check"] = helpers.DebugValue(c.Check, false) + debugMap["ReachableResources"] = helpers.DebugValue(c.ReachableResources, false) + debugMap["LookupResources"] = helpers.DebugValue(c.LookupResources, false) + debugMap["LookupSubjects"] = helpers.DebugValue(c.LookupSubjects, false) + return debugMap +} + +// ConcurrencyLimitsWithOptions configures an existing ConcurrencyLimits with the passed in options set +func ConcurrencyLimitsWithOptions(c *ConcurrencyLimits, opts ...ConcurrencyLimitsOption) *ConcurrencyLimits { + for _, o := range opts { + o(c) + } + return c +} + +// WithOptions configures the receiver ConcurrencyLimits with the passed in options set +func (c *ConcurrencyLimits) WithOptions(opts ...ConcurrencyLimitsOption) *ConcurrencyLimits { + for _, o := range opts { + o(c) + } + return c +} + +// WithCheck returns an option that can set Check on a ConcurrencyLimits +func WithCheck(check uint16) ConcurrencyLimitsOption { + return func(c *ConcurrencyLimits) { + c.Check = check + } +} + +// WithReachableResources returns an option that can set ReachableResources on a ConcurrencyLimits +func WithReachableResources(reachableResources uint16) ConcurrencyLimitsOption { + return func(c *ConcurrencyLimits) { + c.ReachableResources = reachableResources + } +} + +// WithLookupResources returns an option that can set LookupResources on a ConcurrencyLimits +func WithLookupResources(lookupResources uint16) ConcurrencyLimitsOption { + return func(c *ConcurrencyLimits) { + c.LookupResources = lookupResources + } +} + +// WithLookupSubjects returns an option that can set LookupSubjects on a ConcurrencyLimits +func WithLookupSubjects(lookupSubjects uint16) ConcurrencyLimitsOption { + return func(c *ConcurrencyLimits) { + c.LookupSubjects = lookupSubjects + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go b/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go new file mode 100644 index 0000000..1d6636c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go @@ -0,0 +1,187 @@ +package dispatch + +import ( + "context" + "sync" + "sync/atomic" + + grpc "google.golang.org/grpc" +) + +// Stream defines the interface generically matching a streaming dispatch response. +type Stream[T any] interface { + // Publish publishes the result to the stream. + Publish(T) error + + // Context returns the context for the stream. + Context() context.Context +} + +type grpcStream[T any] interface { + grpc.ServerStream + Send(T) error +} + +// WrapGRPCStream wraps a gRPC result stream with a concurrent-safe dispatch stream. This is +// necessary because gRPC response streams are *not concurrent safe*. +// See: https://groups.google.com/g/grpc-io/c/aI6L6M4fzQ0?pli=1 +func WrapGRPCStream[R any, S grpcStream[R]](grpcStream S) Stream[R] { + return &concurrentSafeStream[R]{ + grpcStream: grpcStream, + mu: sync.Mutex{}, + } +} + +type concurrentSafeStream[T any] struct { + grpcStream grpcStream[T] // GUARDED_BY(mu) + mu sync.Mutex +} + +func (s *concurrentSafeStream[T]) Context() context.Context { + return s.grpcStream.Context() +} + +func (s *concurrentSafeStream[T]) Publish(result T) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.grpcStream.Send(result) +} + +// NewCollectingDispatchStream creates a new CollectingDispatchStream. +func NewCollectingDispatchStream[T any](ctx context.Context) *CollectingDispatchStream[T] { + return &CollectingDispatchStream[T]{ + ctx: ctx, + results: nil, + mu: sync.Mutex{}, + } +} + +// CollectingDispatchStream is a dispatch stream that collects results in memory. +type CollectingDispatchStream[T any] struct { + ctx context.Context + results []T // GUARDED_BY(mu) + mu sync.Mutex +} + +func (s *CollectingDispatchStream[T]) Context() context.Context { + return s.ctx +} + +func (s *CollectingDispatchStream[T]) Results() []T { + return s.results +} + +func (s *CollectingDispatchStream[T]) Publish(result T) error { + s.mu.Lock() + defer s.mu.Unlock() + s.results = append(s.results, result) + return nil +} + +// WrappedDispatchStream is a dispatch stream that wraps another dispatch stream, and performs +// an operation on each result before puppeting back up to the parent stream. +type WrappedDispatchStream[T any] struct { + Stream Stream[T] + Ctx context.Context + Processor func(result T) (T, bool, error) +} + +func (s *WrappedDispatchStream[T]) Publish(result T) error { + if s.Processor == nil { + return s.Stream.Publish(result) + } + + processed, ok, err := s.Processor(result) + if err != nil { + return err + } + if !ok { + return nil + } + + return s.Stream.Publish(processed) +} + +func (s *WrappedDispatchStream[T]) Context() context.Context { + return s.Ctx +} + +// StreamWithContext returns the given dispatch stream, wrapped to return the given context. +func StreamWithContext[T any](context context.Context, stream Stream[T]) Stream[T] { + return &WrappedDispatchStream[T]{ + Stream: stream, + Ctx: context, + Processor: nil, + } +} + +// HandlingDispatchStream is a dispatch stream that executes a handler for each item published. +// It uses an internal mutex to ensure it is thread safe. +type HandlingDispatchStream[T any] struct { + ctx context.Context + processor func(result T) error // GUARDED_BY(mu) + mu sync.Mutex +} + +// NewHandlingDispatchStream returns a new handling dispatch stream. +func NewHandlingDispatchStream[T any](ctx context.Context, processor func(result T) error) Stream[T] { + return &HandlingDispatchStream[T]{ + ctx: ctx, + processor: processor, + mu: sync.Mutex{}, + } +} + +func (s *HandlingDispatchStream[T]) Publish(result T) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.processor == nil { + return nil + } + + return s.processor(result) +} + +func (s *HandlingDispatchStream[T]) Context() context.Context { + return s.ctx +} + +// CountingDispatchStream is a dispatch stream that counts the number of items published. +// It uses an internal atomic int to ensure it is thread safe. +type CountingDispatchStream[T any] struct { + Stream Stream[T] + count *atomic.Uint64 +} + +func NewCountingDispatchStream[T any](wrapped Stream[T]) *CountingDispatchStream[T] { + return &CountingDispatchStream[T]{ + Stream: wrapped, + count: &atomic.Uint64{}, + } +} + +func (s *CountingDispatchStream[T]) PublishedCount() uint64 { + return s.count.Load() +} + +func (s *CountingDispatchStream[T]) Publish(result T) error { + err := s.Stream.Publish(result) + if err != nil { + return err + } + + s.count.Add(1) + return nil +} + +func (s *CountingDispatchStream[T]) Context() context.Context { + return s.Stream.Context() +} + +// Ensure the streams implement the interface. +var ( + _ Stream[any] = &CollectingDispatchStream[any]{} + _ Stream[any] = &WrappedDispatchStream[any]{} + _ Stream[any] = &CountingDispatchStream[any]{} +) diff --git a/vendor/github.com/authzed/spicedb/internal/graph/check.go b/vendor/github.com/authzed/spicedb/internal/graph/check.go new file mode 100644 index 0000000..65bcb50 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/check.go @@ -0,0 +1,1354 @@ +package graph + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/samber/lo" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph/hints" + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/taskrunner" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/middleware/nodeid" + nspkg "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +var tracer = otel.Tracer("spicedb/internal/graph/check") + +var dispatchChunkCountHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "spicedb_check_dispatch_chunk_count", + Help: "number of chunks when dispatching in check", + Buckets: []float64{1, 2, 3, 5, 10, 25, 100, 250}, +}) + +var directDispatchQueryHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "spicedb_check_direct_dispatch_query_count", + Help: "number of queries made per direct dispatch", + Buckets: []float64{1, 2}, +}) + +const noOriginalRelation = "" + +func init() { + prometheus.MustRegister(directDispatchQueryHistogram) + prometheus.MustRegister(dispatchChunkCountHistogram) +} + +// NewConcurrentChecker creates an instance of ConcurrentChecker. +func NewConcurrentChecker(d dispatch.Check, concurrencyLimit uint16, dispatchChunkSize uint16) *ConcurrentChecker { + return &ConcurrentChecker{d, concurrencyLimit, dispatchChunkSize} +} + +// ConcurrentChecker exposes a method to perform Check requests, and delegates subproblems to the +// provided dispatch.Check instance. +type ConcurrentChecker struct { + d dispatch.Check + concurrencyLimit uint16 + dispatchChunkSize uint16 +} + +// ValidatedCheckRequest represents a request after it has been validated and parsed for internal +// consumption. +type ValidatedCheckRequest struct { + *v1.DispatchCheckRequest + Revision datastore.Revision + + // OriginalRelationName is the original relation/permission name that was used in the request, + // before being changed due to aliasing. + OriginalRelationName string +} + +// currentRequestContext holds context information for the current request being +// processed. +type currentRequestContext struct { + // parentReq is the parent request being processed. + parentReq ValidatedCheckRequest + + // filteredResourceIDs are those resource IDs to be checked after filtering for + // any resource IDs found directly matching the incoming subject. + // + // For example, a check of resources `user:{tom,sarah,fred}` and subject `user:sarah` will + // result in this slice containing `tom` and `fred`, but not `sarah`, as she was found as a + // match. + // + // This check and filter occurs via the filterForFoundMemberResource function in the + // checkInternal function before the rest of the checking logic is run. This slice should never + // be empty. + filteredResourceIDs []string + + // resultsSetting is the results setting to use for this request and all subsequent + // requests. + resultsSetting v1.DispatchCheckRequest_ResultsSetting + + // dispatchChunkSize is the maximum number of resource IDs that can be specified in each dispatch. + dispatchChunkSize uint16 +} + +// Check performs a check request with the provided request and context +func (cc *ConcurrentChecker) Check(ctx context.Context, req ValidatedCheckRequest, relation *core.Relation) (*v1.DispatchCheckResponse, error) { + var startTime *time.Time + if req.Debug != v1.DispatchCheckRequest_NO_DEBUG { + now := time.Now() + startTime = &now + } + + resolved := cc.checkInternal(ctx, req, relation) + resolved.Resp.Metadata = addCallToResponseMetadata(resolved.Resp.Metadata) + if req.Debug == v1.DispatchCheckRequest_NO_DEBUG { + return resolved.Resp, resolved.Err + } + + nodeID, err := nodeid.FromContext(ctx) + if err != nil { + // NOTE: we ignore this error here as if the node ID is missing, the debug + // trace is still valid. + log.Err(err).Msg("failed to get node ID") + } + + // Add debug information if requested. + debugInfo := resolved.Resp.Metadata.DebugInfo + if debugInfo == nil { + debugInfo = &v1.DebugInformation{ + Check: &v1.CheckDebugTrace{ + TraceId: NewTraceID(), + SourceId: nodeID, + }, + } + } else if debugInfo.Check != nil && debugInfo.Check.SourceId == "" { + debugInfo.Check.SourceId = nodeID + } + + // Remove the traversal bloom from the debug request to save some data over the + // wire. + clonedRequest := req.DispatchCheckRequest.CloneVT() + clonedRequest.Metadata.TraversalBloom = nil + + debugInfo.Check.Request = clonedRequest + debugInfo.Check.Duration = durationpb.New(time.Since(*startTime)) + + if nspkg.GetRelationKind(relation) == iv1.RelationMetadata_PERMISSION { + debugInfo.Check.ResourceRelationType = v1.CheckDebugTrace_PERMISSION + } else if nspkg.GetRelationKind(relation) == iv1.RelationMetadata_RELATION { + debugInfo.Check.ResourceRelationType = v1.CheckDebugTrace_RELATION + } + + // Build the results for the debug trace. + results := make(map[string]*v1.ResourceCheckResult, len(req.DispatchCheckRequest.ResourceIds)) + for _, resourceID := range req.DispatchCheckRequest.ResourceIds { + if found, ok := resolved.Resp.ResultsByResourceId[resourceID]; ok { + results[resourceID] = found + } + } + debugInfo.Check.Results = results + + // If there is existing debug information in the error, then place it as the subproblem of the current + // debug information. + if existingDebugInfo, ok := spiceerrors.GetDetails[*v1.DebugInformation](resolved.Err); ok { + debugInfo.Check.SubProblems = []*v1.CheckDebugTrace{existingDebugInfo.Check} + } + + resolved.Resp.Metadata.DebugInfo = debugInfo + + // If there is an error and it is already a gRPC error, add the debug information + // into the details portion of the payload. This allows the client to see the debug + // information, as gRPC will only return the error. + updatedErr := spiceerrors.WithReplacedDetails(resolved.Err, debugInfo) + return resolved.Resp, updatedErr +} + +func (cc *ConcurrentChecker) checkInternal(ctx context.Context, req ValidatedCheckRequest, relation *core.Relation) CheckResult { + spiceerrors.DebugAssert(func() bool { + return relation.GetUsersetRewrite() != nil || relation.GetTypeInformation() != nil + }, "found relation without type information") + + // Ensure that we have at least one resource ID for which to execute the check. + if len(req.ResourceIds) == 0 { + return checkResultError( + spiceerrors.MustBugf("empty resource IDs given to dispatched check"), + emptyMetadata, + ) + } + + // Ensure that we are not performing a check for a wildcard as the subject. + if req.Subject.ObjectId == tuple.PublicWildcard { + return checkResultError(NewWildcardNotAllowedErr("cannot perform check on wildcard subject", "subject.object_id"), emptyMetadata) + } + + // Deduplicate any incoming resource IDs. + resourceIds := lo.Uniq(req.ResourceIds) + + // Filter the incoming resource IDs for any which match the subject directly. For example, if we receive + // a check for resource `user:{tom, fred, sarah}#...` and a subject of `user:sarah#...`, then we know + // that `user:sarah#...` is a valid "member" of the resource, as it matches exactly. + // + // If the filtering results in no further resource IDs to check, or a result is found and a single + // result is allowed, we terminate early. + membershipSet, filteredResourcesIds := filterForFoundMemberResource(req.ResourceRelation, resourceIds, req.Subject) + if membershipSet.HasDeterminedMember() && req.DispatchCheckRequest.ResultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT { + return checkResultsForMembership(membershipSet, emptyMetadata) + } + + // Filter for check hints, if any. + if len(req.CheckHints) > 0 { + subject := tuple.FromCoreObjectAndRelation(req.Subject) + filteredResourcesIdsSet := mapz.NewSet(filteredResourcesIds...) + for _, checkHint := range req.CheckHints { + resourceID, ok := hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.ResourceRelation.Relation, subject) + if ok { + filteredResourcesIdsSet.Delete(resourceID) + continue + } + + if req.OriginalRelationName != "" { + resourceID, ok = hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.OriginalRelationName, subject) + if ok { + filteredResourcesIdsSet.Delete(resourceID) + } + } + } + filteredResourcesIds = filteredResourcesIdsSet.AsSlice() + } + + if len(filteredResourcesIds) == 0 { + return combineWithCheckHints(combineResultWithFoundResources(noMembers(), membershipSet), req) + } + + // NOTE: We can always allow a single result if we're only trying to find the results for a + // single resource ID. This "reset" allows for short circuiting of downstream dispatched calls. + resultsSetting := req.ResultsSetting + if len(filteredResourcesIds) == 1 { + resultsSetting = v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT + } + + crc := currentRequestContext{ + parentReq: req, + filteredResourceIDs: filteredResourcesIds, + resultsSetting: resultsSetting, + dispatchChunkSize: cc.dispatchChunkSize, + } + + if req.Debug == v1.DispatchCheckRequest_ENABLE_TRACE_DEBUGGING { + crc.dispatchChunkSize = 1 + } + + if relation.UsersetRewrite == nil { + return combineWithCheckHints(combineResultWithFoundResources(cc.checkDirect(ctx, crc, relation), membershipSet), req) + } + + return combineWithCheckHints(combineResultWithFoundResources(cc.checkUsersetRewrite(ctx, crc, relation.UsersetRewrite), membershipSet), req) +} + +func combineWithComputedHints(result CheckResult, hints map[string]*v1.ResourceCheckResult) CheckResult { + if len(hints) == 0 { + return result + } + + for resourceID, hint := range hints { + if _, ok := result.Resp.ResultsByResourceId[resourceID]; ok { + return checkResultError( + spiceerrors.MustBugf("check hint for resource ID %q, which already exists", resourceID), + emptyMetadata, + ) + } + + if result.Resp.ResultsByResourceId == nil { + result.Resp.ResultsByResourceId = make(map[string]*v1.ResourceCheckResult) + } + result.Resp.ResultsByResourceId[resourceID] = hint + } + + return result +} + +func combineWithCheckHints(result CheckResult, req ValidatedCheckRequest) CheckResult { + if len(req.CheckHints) == 0 { + return result + } + + subject := tuple.FromCoreObjectAndRelation(req.Subject) + for _, checkHint := range req.CheckHints { + resourceID, ok := hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.ResourceRelation.Relation, subject) + if !ok { + if req.OriginalRelationName != "" { + resourceID, ok = hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.OriginalRelationName, subject) + } + + if !ok { + continue + } + } + + if result.Resp.ResultsByResourceId == nil { + result.Resp.ResultsByResourceId = make(map[string]*v1.ResourceCheckResult) + } + + if _, ok := result.Resp.ResultsByResourceId[resourceID]; ok { + return checkResultError( + spiceerrors.MustBugf("check hint for resource ID %q, which already exists", resourceID), + emptyMetadata, + ) + } + + result.Resp.ResultsByResourceId[resourceID] = checkHint.Result + } + + return result +} + +func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequestContext, relation *core.Relation) CheckResult { + ctx, span := tracer.Start(ctx, "checkDirect") + defer span.End() + + // Build a filter for finding the direct relationships for the check. There are three + // classes of relationships to be found: + // 1) the target subject itself, if allowed on this relation + // 2) the wildcard form of the target subject, if a wildcard is allowed on this relation + // 3) Otherwise, any non-terminal (non-`...`) subjects, if allowed on this relation, to be + // redispatched outward + totalNonTerminals := 0 + totalDirectSubjects := 0 + totalWildcardSubjects := 0 + + defer func() { + if totalNonTerminals > 0 { + span.SetName("non terminal") + } else if totalDirectSubjects > 0 { + span.SetName("terminal") + } else { + span.SetName("wildcard subject") + } + }() + log.Ctx(ctx).Trace().Object("direct", crc.parentReq).Send() + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + + directSubjectsAndWildcardsWithoutCaveats := 0 + directSubjectsAndWildcardsWithoutExpiration := 0 + nonTerminalsWithoutCaveats := 0 + nonTerminalsWithoutExpiration := 0 + + for _, allowedDirectRelation := range relation.GetTypeInformation().GetAllowedDirectRelations() { + // If the namespace of the allowed direct relation matches the subject type, there are two + // cases to optimize: + // 1) Finding the target subject itself, as a direct lookup + // 2) Finding a wildcard for the subject type+relation + if allowedDirectRelation.GetNamespace() == crc.parentReq.Subject.Namespace { + if allowedDirectRelation.GetPublicWildcard() != nil { + totalWildcardSubjects++ + } else if allowedDirectRelation.GetRelation() == crc.parentReq.Subject.Relation { + totalDirectSubjects++ + } + + if allowedDirectRelation.RequiredCaveat == nil { + directSubjectsAndWildcardsWithoutCaveats++ + } + + if allowedDirectRelation.RequiredExpiration == nil { + directSubjectsAndWildcardsWithoutExpiration++ + } + } + + // If the relation found is not an ellipsis, then this is a nested relation that + // might need to be followed, so indicate that such relationships should be returned + // + // TODO(jschorr): Use type information to *further* optimize this query around which nested + // relations can reach the target subject type. + if allowedDirectRelation.GetRelation() != tuple.Ellipsis { + totalNonTerminals++ + if allowedDirectRelation.RequiredCaveat == nil { + nonTerminalsWithoutCaveats++ + } + if allowedDirectRelation.RequiredExpiration == nil { + nonTerminalsWithoutExpiration++ + } + } + } + + nonTerminalsCanHaveCaveats := totalNonTerminals != nonTerminalsWithoutCaveats + nonTerminalsCanHaveExpiration := totalNonTerminals != nonTerminalsWithoutExpiration + hasNonTerminals := totalNonTerminals > 0 + + foundResources := NewMembershipSet() + + // If the direct subject or a wildcard form can be found, issue a query for just that + // subject. + var queryCount float64 + defer func() { + directDispatchQueryHistogram.Observe(queryCount) + }() + + hasDirectSubject := totalDirectSubjects > 0 + hasWildcardSubject := totalWildcardSubjects > 0 + if hasDirectSubject || hasWildcardSubject { + directSubjectOrWildcardCanHaveCaveats := directSubjectsAndWildcardsWithoutCaveats != (totalDirectSubjects + totalWildcardSubjects) + directSubjectOrWildcardCanHaveExpiration := directSubjectsAndWildcardsWithoutExpiration != (totalDirectSubjects + totalWildcardSubjects) + + subjectSelectors := []datastore.SubjectsSelector{} + + if hasDirectSubject { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: crc.parentReq.Subject.Namespace, + OptionalSubjectIds: []string{crc.parentReq.Subject.ObjectId}, + RelationFilter: datastore.SubjectRelationFilter{}.WithRelation(crc.parentReq.Subject.Relation), + }) + } + + if hasWildcardSubject { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: crc.parentReq.Subject.Namespace, + OptionalSubjectIds: []string{tuple.PublicWildcard}, + RelationFilter: datastore.SubjectRelationFilter{}.WithEllipsisRelation(), + }) + } + + filter := datastore.RelationshipsFilter{ + OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, + OptionalResourceIds: crc.filteredResourceIDs, + OptionalResourceRelation: crc.parentReq.ResourceRelation.Relation, + OptionalSubjectsSelectors: subjectSelectors, + } + + it, err := ds.QueryRelationships(ctx, filter, + options.WithSkipCaveats(!directSubjectOrWildcardCanHaveCaveats), + options.WithSkipExpiration(!directSubjectOrWildcardCanHaveExpiration), + options.WithQueryShape(queryshape.CheckPermissionSelectDirectSubjects), + ) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + queryCount += 1.0 + + // Find the matching subject(s). + for rel, err := range it { + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + // If the subject of the relationship matches the target subject, then we've found + // a result. + foundResources.AddDirectMember(rel.Resource.ObjectID, rel.OptionalCaveat) + if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT && foundResources.HasDeterminedMember() { + return checkResultsForMembership(foundResources, emptyMetadata) + } + } + } + + // Filter down the resource IDs for further dispatch based on whether they exist as found + // subjects in the existing membership set. + furtherFilteredResourceIDs := make([]string, 0, len(crc.filteredResourceIDs)-foundResources.Size()) + for _, resourceID := range crc.filteredResourceIDs { + if foundResources.HasConcreteResourceID(resourceID) { + continue + } + + furtherFilteredResourceIDs = append(furtherFilteredResourceIDs, resourceID) + } + + // If there are no possible non-terminals, then the check is completed. + if !hasNonTerminals || len(furtherFilteredResourceIDs) == 0 { + return checkResultsForMembership(foundResources, emptyMetadata) + } + + // Otherwise, for any remaining resource IDs, query for redispatch. + filter := datastore.RelationshipsFilter{ + OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, + OptionalResourceIds: furtherFilteredResourceIDs, + OptionalResourceRelation: crc.parentReq.ResourceRelation.Relation, + OptionalSubjectsSelectors: []datastore.SubjectsSelector{ + { + RelationFilter: datastore.SubjectRelationFilter{}.WithOnlyNonEllipsisRelations(), + }, + }, + } + + it, err := ds.QueryRelationships(ctx, filter, + options.WithSkipCaveats(!nonTerminalsCanHaveCaveats), + options.WithSkipExpiration(!nonTerminalsCanHaveExpiration), + options.WithQueryShape(queryshape.CheckPermissionSelectIndirectSubjects), + ) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + queryCount += 1.0 + + // Build the set of subjects over which to dispatch, along with metadata for + // mapping over caveats (if any). + checksToDispatch := newCheckDispatchSet() + for rel, err := range it { + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + checksToDispatch.addForRelationship(rel) + } + + // Dispatch and map to the associated resource ID(s). + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) + result := union(ctx, crc, toDispatch, func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) CheckResult { + // If there are caveats on any of the incoming relationships for the subjects to dispatch, then we must require all + // results to be found, as we need to ensure that all caveats are used for building the final expression. + resultsSetting := crc.resultsSetting + if dd.hasIncomingCaveats { + resultsSetting = v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS + } + + childResult := cc.dispatch(ctx, crc, ValidatedCheckRequest{ + &v1.DispatchCheckRequest{ + ResourceRelation: dd.resourceType.ToCoreRR(), + ResourceIds: dd.resourceIds, + Subject: crc.parentReq.Subject, + ResultsSetting: resultsSetting, + + Metadata: decrementDepth(crc.parentReq.Metadata), + Debug: crc.parentReq.Debug, + CheckHints: crc.parentReq.CheckHints, + }, + crc.parentReq.Revision, + noOriginalRelation, + }) + + if childResult.Err != nil { + return childResult + } + + return mapFoundResources(childResult, dd.resourceType, checksToDispatch) + }, cc.concurrencyLimit) + + return combineResultWithFoundResources(result, foundResources) +} + +func mapFoundResources(result CheckResult, resourceType tuple.RelationReference, checksToDispatch *checkDispatchSet) CheckResult { + // Map any resources found to the parent resource IDs. + membershipSet := NewMembershipSet() + for foundResourceID, result := range result.Resp.ResultsByResourceId { + resourceIDAndCaveats := checksToDispatch.mappingsForSubject(resourceType.ObjectType, foundResourceID, resourceType.Relation) + + spiceerrors.DebugAssert(func() bool { + return len(resourceIDAndCaveats) > 0 + }, "found resource ID without associated caveats") + + for _, riac := range resourceIDAndCaveats { + membershipSet.AddMemberWithParentCaveat(riac.resourceID, result.Expression, riac.caveat) + } + } + + if membershipSet.IsEmpty() { + return noMembersWithMetadata(result.Resp.Metadata) + } + + return checkResultsForMembership(membershipSet, result.Resp.Metadata) +} + +func (cc *ConcurrentChecker) checkUsersetRewrite(ctx context.Context, crc currentRequestContext, rewrite *core.UsersetRewrite) CheckResult { + switch rw := rewrite.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + if len(rw.Union.Child) > 1 { + var span trace.Span + ctx, span = tracer.Start(ctx, "+") + defer span.End() + } + return union(ctx, crc, rw.Union.Child, cc.runSetOperation, cc.concurrencyLimit) + case *core.UsersetRewrite_Intersection: + ctx, span := tracer.Start(ctx, "&") + defer span.End() + return all(ctx, crc, rw.Intersection.Child, cc.runSetOperation, cc.concurrencyLimit) + case *core.UsersetRewrite_Exclusion: + ctx, span := tracer.Start(ctx, "-") + defer span.End() + return difference(ctx, crc, rw.Exclusion.Child, cc.runSetOperation, cc.concurrencyLimit) + default: + return checkResultError(spiceerrors.MustBugf("unknown userset rewrite operator"), emptyMetadata) + } +} + +func (cc *ConcurrentChecker) dispatch(ctx context.Context, _ currentRequestContext, req ValidatedCheckRequest) CheckResult { + log.Ctx(ctx).Trace().Object("dispatch", req).Send() + result, err := cc.d.DispatchCheck(ctx, req.DispatchCheckRequest) + return CheckResult{result, err} +} + +func (cc *ConcurrentChecker) runSetOperation(ctx context.Context, crc currentRequestContext, childOneof *core.SetOperation_Child) CheckResult { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_XThis: + return checkResultError(spiceerrors.MustBugf("use of _this is unsupported; please rewrite your schema"), emptyMetadata) + case *core.SetOperation_Child_ComputedUserset: + return cc.checkComputedUserset(ctx, crc, child.ComputedUserset, nil, nil) + case *core.SetOperation_Child_UsersetRewrite: + return cc.checkUsersetRewrite(ctx, crc, child.UsersetRewrite) + case *core.SetOperation_Child_TupleToUserset: + return checkTupleToUserset(ctx, cc, crc, child.TupleToUserset) + case *core.SetOperation_Child_FunctionedTupleToUserset: + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + return checkTupleToUserset(ctx, cc, crc, child.FunctionedTupleToUserset) + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + return checkIntersectionTupleToUserset(ctx, cc, crc, child.FunctionedTupleToUserset) + + default: + return checkResultError(spiceerrors.MustBugf("unknown userset function `%s`", child.FunctionedTupleToUserset.Function), emptyMetadata) + } + + case *core.SetOperation_Child_XNil: + return noMembers() + default: + return checkResultError(spiceerrors.MustBugf("unknown set operation child `%T` in check", child), emptyMetadata) + } +} + +func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc currentRequestContext, cu *core.ComputedUserset, rr *tuple.RelationReference, resourceIds []string) CheckResult { + ctx, span := tracer.Start(ctx, cu.Relation) + defer span.End() + + var startNamespace string + var targetResourceIds []string + if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT { + if rr == nil || len(resourceIds) == 0 { + return checkResultError(spiceerrors.MustBugf("computed userset for tupleset without tuples"), emptyMetadata) + } + + startNamespace = rr.ObjectType + targetResourceIds = resourceIds + } else if cu.Object == core.ComputedUserset_TUPLE_OBJECT { + if rr != nil { + return checkResultError(spiceerrors.MustBugf("computed userset for tupleset with wrong object type"), emptyMetadata) + } + + startNamespace = crc.parentReq.ResourceRelation.Namespace + targetResourceIds = crc.filteredResourceIDs + } + + targetRR := &core.RelationReference{ + Namespace: startNamespace, + Relation: cu.Relation, + } + + // If we will be dispatching to the goal's ONR, then we know that the ONR is a member. + membershipSet, updatedTargetResourceIds := filterForFoundMemberResource(targetRR, targetResourceIds, crc.parentReq.Subject) + if (membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT) || len(updatedTargetResourceIds) == 0 { + return checkResultsForMembership(membershipSet, emptyMetadata) + } + + // Check if the target relation exists. If not, return nothing. This is only necessary + // for TTU-based computed usersets, as directly computed ones reference relations within + // the same namespace as the caller, and thus must be fully typed checked. + if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + err := namespace.CheckNamespaceAndRelation(ctx, targetRR.Namespace, targetRR.Relation, true, ds) + if err != nil { + if errors.As(err, &namespace.RelationNotFoundError{}) { + return noMembers() + } + + return checkResultError(err, emptyMetadata) + } + } + + result := cc.dispatch(ctx, crc, ValidatedCheckRequest{ + &v1.DispatchCheckRequest{ + ResourceRelation: targetRR, + ResourceIds: updatedTargetResourceIds, + Subject: crc.parentReq.Subject, + ResultsSetting: crc.resultsSetting, + Metadata: decrementDepth(crc.parentReq.Metadata), + Debug: crc.parentReq.Debug, + CheckHints: crc.parentReq.CheckHints, + }, + crc.parentReq.Revision, + noOriginalRelation, + }) + return combineResultWithFoundResources(result, membershipSet) +} + +type Traits struct { + HasCaveats bool + HasExpiration bool +} + +// TraitsForArrowRelation returns traits such as HasCaveats and HasExpiration if *any* of the subject +// types of the given relation support caveats or expiration. +func TraitsForArrowRelation(ctx context.Context, reader datastore.Reader, namespaceName string, relationName string) (Traits, error) { + // TODO(jschorr): Change to use the type system once we wire it through Check dispatch. + nsDef, _, err := reader.ReadNamespaceByName(ctx, namespaceName) + if err != nil { + return Traits{}, err + } + + var relation *core.Relation + for _, rel := range nsDef.Relation { + if rel.Name == relationName { + relation = rel + break + } + } + + if relation == nil || relation.TypeInformation == nil { + return Traits{}, fmt.Errorf("relation %q not found", relationName) + } + + hasCaveats := false + hasExpiration := false + + for _, allowedDirectRelation := range relation.TypeInformation.GetAllowedDirectRelations() { + if allowedDirectRelation.RequiredCaveat != nil { + hasCaveats = true + } + + if allowedDirectRelation.RequiredExpiration != nil { + hasExpiration = true + } + } + + return Traits{ + HasCaveats: hasCaveats, + HasExpiration: hasExpiration, + }, nil +} + +func queryOptionsForArrowRelation(ctx context.Context, ds datastore.Reader, namespaceName string, relationName string) ([]options.QueryOptionsOption, error) { + opts := make([]options.QueryOptionsOption, 0, 3) + opts = append(opts, options.WithQueryShape(queryshape.AllSubjectsForResources)) + + traits, err := TraitsForArrowRelation(ctx, ds, namespaceName, relationName) + if err != nil { + return nil, err + } + + if !traits.HasCaveats { + opts = append(opts, options.WithSkipCaveats(true)) + } + + if !traits.HasExpiration { + opts = append(opts, options.WithSkipExpiration(true)) + } + + return opts, nil +} + +func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) (*MembershipSet, []string) { + if resourceRelation.Namespace != subject.Namespace || resourceRelation.Relation != subject.Relation { + return nil, resourceIds + } + + for index, resourceID := range resourceIds { + if subject.ObjectId == resourceID { + membershipSet := NewMembershipSet() + membershipSet.AddDirectMember(resourceID, nil) + return membershipSet, removeIndexFromSlice(resourceIds, index) + } + } + + return nil, resourceIds +} + +func removeIndexFromSlice[T any](s []T, index int) []T { + cpy := make([]T, 0, len(s)-1) + cpy = append(cpy, s[:index]...) + return append(cpy, s[index+1:]...) +} + +type relation interface { + GetRelation() string +} + +type ttu[T relation] interface { + GetComputedUserset() *core.ComputedUserset + GetTupleset() T +} + +type checkResultWithType struct { + CheckResult + + relationType tuple.RelationReference +} + +func checkIntersectionTupleToUserset( + ctx context.Context, + cc *ConcurrentChecker, + crc currentRequestContext, + ttu *core.FunctionedTupleToUserset, +) CheckResult { + // TODO(jschorr): use check hints here + ctx, span := tracer.Start(ctx, ttu.GetTupleset().GetRelation()+"-(all)->"+ttu.GetComputedUserset().Relation) + defer span.End() + + // Query for the subjects over which to walk the TTU. + log.Ctx(ctx).Trace().Object("intersectionttu", crc.parentReq).Send() + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, + OptionalResourceIds: crc.filteredResourceIDs, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + }, queryOpts...) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + checksToDispatch := newCheckDispatchSet() + subjectsByResourceID := mapz.NewMultiMap[string, tuple.ObjectAndRelation]() + for rel, err := range it { + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + checksToDispatch.addForRelationship(rel) + subjectsByResourceID.Add(rel.Resource.ObjectID, rel.Subject) + } + + // Convert the subjects into batched requests. + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) + if len(toDispatch) == 0 { + return noMembers() + } + + // Run the dispatch for all the chunks. Unlike a standard TTU, we do *not* perform mapping here, + // as we need to access the results on a per subject basis. Instead, we keep each result and map + // by the relation type of the dispatched subject. + chunkResults, err := run( + ctx, + currentRequestContext{ + parentReq: crc.parentReq, + filteredResourceIDs: crc.filteredResourceIDs, + resultsSetting: v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS, + dispatchChunkSize: crc.dispatchChunkSize, + }, + toDispatch, + func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) checkResultWithType { + resourceType := dd.resourceType + childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), &resourceType, dd.resourceIds) + return checkResultWithType{ + CheckResult: childResult, + relationType: dd.resourceType, + } + }, + cc.concurrencyLimit, + ) + if err != nil { + return checkResultError(err, emptyMetadata) + } + + // Create a membership set per-subject-type, representing the membership for each of the dispatched subjects. + resultsByDispatchedSubject := map[tuple.RelationReference]*MembershipSet{} + combinedMetadata := emptyMetadata + for _, result := range chunkResults { + if result.Err != nil { + return checkResultError(result.Err, emptyMetadata) + } + + if _, ok := resultsByDispatchedSubject[result.relationType]; !ok { + resultsByDispatchedSubject[result.relationType] = NewMembershipSet() + } + + resultsByDispatchedSubject[result.relationType].UnionWith(result.Resp.ResultsByResourceId) + combinedMetadata = combineResponseMetadata(ctx, combinedMetadata, result.Resp.Metadata) + } + + // For each resource ID, check that there exist some sort of permission for *each* subject. If not, then the + // intersection for that resource fails. If all subjects have some sort of permission, then the resource ID is + // a member, perhaps caveated. + resourcesFound := NewMembershipSet() + for _, resourceID := range subjectsByResourceID.Keys() { + subjects, _ := subjectsByResourceID.Get(resourceID) + if len(subjects) == 0 { + return checkResultError(spiceerrors.MustBugf("no subjects found for resource ID %s", resourceID), emptyMetadata) + } + + hasAllSubjects := true + caveats := make([]*core.CaveatExpression, 0, len(subjects)) + + // Check each of the subjects found for the resource ID and ensure that membership (at least caveated) + // was found for each. If any are not found, then the resource ID is not a member. + // We also collect up the caveats for each subject, as they will be added to the final result. + for _, subject := range subjects { + subjectTypeKey := subject.RelationReference() + results, ok := resultsByDispatchedSubject[subjectTypeKey] + if !ok { + hasAllSubjects = false + break + } + + hasMembership, caveat := results.GetResourceID(subject.ObjectID) + if !hasMembership { + hasAllSubjects = false + break + } + + if caveat != nil { + caveats = append(caveats, caveat) + } + + // Add any caveats on the subject from the starting relationship(s) as well. + resourceIDAndCaveats := checksToDispatch.mappingsForSubject(subject.ObjectType, subject.ObjectID, subject.Relation) + for _, riac := range resourceIDAndCaveats { + if riac.caveat != nil { + caveats = append(caveats, wrapCaveat(riac.caveat)) + } + } + } + + if !hasAllSubjects { + continue + } + + // Add the member to the membership set, with the caveats for each (if any). + resourcesFound.AddMemberWithOptionalCaveats(resourceID, caveats) + } + + return checkResultsForMembership(resourcesFound, combinedMetadata) +} + +func checkTupleToUserset[T relation]( + ctx context.Context, + cc *ConcurrentChecker, + crc currentRequestContext, + ttu ttu[T], +) CheckResult { + filteredResourceIDs := crc.filteredResourceIDs + hintsToReturn := make(map[string]*v1.ResourceCheckResult, len(crc.parentReq.CheckHints)) + if len(crc.parentReq.CheckHints) > 0 { + filteredResourcesIdsSet := mapz.NewSet(crc.filteredResourceIDs...) + + for _, checkHint := range crc.parentReq.CheckHints { + resourceID, ok := hints.AsCheckHintForArrow( + checkHint, + crc.parentReq.ResourceRelation.Namespace, + ttu.GetTupleset().GetRelation(), + ttu.GetComputedUserset().Relation, + tuple.FromCoreObjectAndRelation(crc.parentReq.Subject), + ) + if !ok { + continue + } + + filteredResourcesIdsSet.Delete(resourceID) + hintsToReturn[resourceID] = checkHint.Result + } + + filteredResourceIDs = filteredResourcesIdsSet.AsSlice() + } + + if len(filteredResourceIDs) == 0 { + return combineWithComputedHints(noMembers(), hintsToReturn) + } + + ctx, span := tracer.Start(ctx, ttu.GetTupleset().GetRelation()+"->"+ttu.GetComputedUserset().Relation) + defer span.End() + + log.Ctx(ctx).Trace().Object("ttu", crc.parentReq).Send() + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + + queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, + OptionalResourceIds: filteredResourceIDs, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + }, queryOpts...) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + checksToDispatch := newCheckDispatchSet() + for rel, err := range it { + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + checksToDispatch.addForRelationship(rel) + } + + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) + return combineWithComputedHints(union( + ctx, + crc, + toDispatch, + func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) CheckResult { + resourceType := dd.resourceType + childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), &resourceType, dd.resourceIds) + if childResult.Err != nil { + return childResult + } + + return mapFoundResources(childResult, dd.resourceType, checksToDispatch) + }, + cc.concurrencyLimit, + ), hintsToReturn) +} + +func withDistinctMetadata(ctx context.Context, result CheckResult) CheckResult { + // NOTE: This is necessary to ensure unique debug information on the request and that debug + // information from the child metadata is *not* copied over. + clonedResp := result.Resp.CloneVT() + clonedResp.Metadata = combineResponseMetadata(ctx, emptyMetadata, clonedResp.Metadata) + return CheckResult{ + Resp: clonedResp, + Err: result.Err, + } +} + +// run runs all the children in parallel and returns the full set of results. +func run[T any, R withError]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) R, + concurrencyLimit uint16, +) ([]R, error) { + if len(children) == 0 { + return nil, nil + } + + if len(children) == 1 { + return []R{handler(ctx, crc, children[0])}, nil + } + + resultChan := make(chan R, len(children)) + childCtx, cancelFn := context.WithCancel(ctx) + dispatchAllAsync(childCtx, crc, children, handler, resultChan, concurrencyLimit) + defer cancelFn() + + results := make([]R, 0, len(children)) + for i := 0; i < len(children); i++ { + select { + case result := <-resultChan: + results = append(results, result) + + case <-ctx.Done(): + log.Ctx(ctx).Trace().Msg("anyCanceled") + return nil, ctx.Err() + } + } + + return results, nil +} + +// union returns whether any one of the lazy checks pass, and is used for union. +func union[T any]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult, + concurrencyLimit uint16, +) CheckResult { + if len(children) == 0 { + return noMembers() + } + + if len(children) == 1 { + return withDistinctMetadata(ctx, handler(ctx, crc, children[0])) + } + + resultChan := make(chan CheckResult, len(children)) + childCtx, cancelFn := context.WithCancel(ctx) + dispatchAllAsync(childCtx, crc, children, handler, resultChan, concurrencyLimit) + defer cancelFn() + + responseMetadata := emptyMetadata + membershipSet := NewMembershipSet() + + for i := 0; i < len(children); i++ { + select { + case result := <-resultChan: + log.Ctx(ctx).Trace().Object("anyResult", result.Resp).Send() + responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata) + if result.Err != nil { + return checkResultError(result.Err, responseMetadata) + } + + membershipSet.UnionWith(result.Resp.ResultsByResourceId) + if membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT { + return checkResultsForMembership(membershipSet, responseMetadata) + } + + case <-ctx.Done(): + log.Ctx(ctx).Trace().Msg("anyCanceled") + return checkResultError(context.Canceled, responseMetadata) + } + } + + return checkResultsForMembership(membershipSet, responseMetadata) +} + +// all returns whether all of the lazy checks pass, and is used for intersection. +func all[T any]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult, + concurrencyLimit uint16, +) CheckResult { + if len(children) == 0 { + return noMembers() + } + + if len(children) == 1 { + return withDistinctMetadata(ctx, handler(ctx, crc, children[0])) + } + + responseMetadata := emptyMetadata + + resultChan := make(chan CheckResult, len(children)) + childCtx, cancelFn := context.WithCancel(ctx) + dispatchAllAsync(childCtx, currentRequestContext{ + parentReq: crc.parentReq, + filteredResourceIDs: crc.filteredResourceIDs, + resultsSetting: v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS, + dispatchChunkSize: crc.dispatchChunkSize, + }, children, handler, resultChan, concurrencyLimit) + defer cancelFn() + + var membershipSet *MembershipSet + for i := 0; i < len(children); i++ { + select { + case result := <-resultChan: + responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata) + if result.Err != nil { + return checkResultError(result.Err, responseMetadata) + } + + if membershipSet == nil { + membershipSet = NewMembershipSet() + membershipSet.UnionWith(result.Resp.ResultsByResourceId) + } else { + membershipSet.IntersectWith(result.Resp.ResultsByResourceId) + } + + if membershipSet.IsEmpty() { + return noMembersWithMetadata(responseMetadata) + } + case <-ctx.Done(): + return checkResultError(context.Canceled, responseMetadata) + } + } + + return checkResultsForMembership(membershipSet, responseMetadata) +} + +// difference returns whether the first lazy check passes and none of the subsequent checks pass. +func difference[T any]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult, + concurrencyLimit uint16, +) CheckResult { + if len(children) == 0 { + return noMembers() + } + + if len(children) == 1 { + return checkResultError(spiceerrors.MustBugf("difference requires more than a single child"), emptyMetadata) + } + + childCtx, cancelFn := context.WithCancel(ctx) + baseChan := make(chan CheckResult, 1) + othersChan := make(chan CheckResult, len(children)-1) + + go func() { + result := handler(childCtx, currentRequestContext{ + parentReq: crc.parentReq, + filteredResourceIDs: crc.filteredResourceIDs, + resultsSetting: v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS, + dispatchChunkSize: crc.dispatchChunkSize, + }, children[0]) + baseChan <- result + }() + + dispatchAllAsync(childCtx, currentRequestContext{ + parentReq: crc.parentReq, + filteredResourceIDs: crc.filteredResourceIDs, + resultsSetting: v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS, + dispatchChunkSize: crc.dispatchChunkSize, + }, children[1:], handler, othersChan, concurrencyLimit-1) + defer cancelFn() + + responseMetadata := emptyMetadata + membershipSet := NewMembershipSet() + + // Wait for the base set to return. + select { + case base := <-baseChan: + responseMetadata = combineResponseMetadata(ctx, responseMetadata, base.Resp.Metadata) + + if base.Err != nil { + return checkResultError(base.Err, responseMetadata) + } + + membershipSet.UnionWith(base.Resp.ResultsByResourceId) + if membershipSet.IsEmpty() { + return noMembersWithMetadata(responseMetadata) + } + + case <-ctx.Done(): + return checkResultError(context.Canceled, responseMetadata) + } + + // Subtract the remaining sets. + for i := 1; i < len(children); i++ { + select { + case sub := <-othersChan: + responseMetadata = combineResponseMetadata(ctx, responseMetadata, sub.Resp.Metadata) + + if sub.Err != nil { + return checkResultError(sub.Err, responseMetadata) + } + + membershipSet.Subtract(sub.Resp.ResultsByResourceId) + if membershipSet.IsEmpty() { + return noMembersWithMetadata(responseMetadata) + } + + case <-ctx.Done(): + return checkResultError(context.Canceled, responseMetadata) + } + } + + return checkResultsForMembership(membershipSet, responseMetadata) +} + +type withError interface { + ResultError() error +} + +func dispatchAllAsync[T any, R withError]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) R, + resultChan chan<- R, + concurrencyLimit uint16, +) { + tr := taskrunner.NewPreloadedTaskRunner(ctx, concurrencyLimit, len(children)) + for _, currentChild := range children { + currentChild := currentChild + tr.Add(func(ctx context.Context) error { + result := handler(ctx, crc, currentChild) + resultChan <- result + return result.ResultError() + }) + } + + tr.Start() +} + +func noMembers() CheckResult { + return CheckResult{ + &v1.DispatchCheckResponse{ + Metadata: emptyMetadata, + }, + nil, + } +} + +func noMembersWithMetadata(metadata *v1.ResponseMeta) CheckResult { + return CheckResult{ + &v1.DispatchCheckResponse{ + Metadata: metadata, + }, + nil, + } +} + +func checkResultsForMembership(foundMembership *MembershipSet, subProblemMetadata *v1.ResponseMeta) CheckResult { + return CheckResult{ + &v1.DispatchCheckResponse{ + Metadata: ensureMetadata(subProblemMetadata), + ResultsByResourceId: foundMembership.AsCheckResultsMap(), + }, + nil, + } +} + +func checkResultError(err error, subProblemMetadata *v1.ResponseMeta) CheckResult { + return CheckResult{ + &v1.DispatchCheckResponse{ + Metadata: ensureMetadata(subProblemMetadata), + }, + err, + } +} + +func combineResultWithFoundResources(result CheckResult, foundResources *MembershipSet) CheckResult { + if result.Err != nil { + return result + } + + if foundResources.IsEmpty() { + return result + } + + foundResources.UnionWith(result.Resp.ResultsByResourceId) + return CheckResult{ + Resp: &v1.DispatchCheckResponse{ + ResultsByResourceId: foundResources.AsCheckResultsMap(), + Metadata: result.Resp.Metadata, + }, + Err: result.Err, + } +} + +func combineResponseMetadata(ctx context.Context, existing *v1.ResponseMeta, responseMetadata *v1.ResponseMeta) *v1.ResponseMeta { + combined := &v1.ResponseMeta{ + DispatchCount: existing.DispatchCount + responseMetadata.DispatchCount, + DepthRequired: max(existing.DepthRequired, responseMetadata.DepthRequired), + CachedDispatchCount: existing.CachedDispatchCount + responseMetadata.CachedDispatchCount, + } + + if existing.DebugInfo == nil && responseMetadata.DebugInfo == nil { + return combined + } + + nodeID, err := nodeid.FromContext(ctx) + if err != nil { + log.Err(err).Msg("failed to get nodeID from context") + } + + debugInfo := &v1.DebugInformation{ + Check: &v1.CheckDebugTrace{ + TraceId: NewTraceID(), + SourceId: nodeID, + }, + } + + if existing.DebugInfo != nil { + if existing.DebugInfo.Check.Request != nil { + debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, existing.DebugInfo.Check) + } else { + debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, existing.DebugInfo.Check.SubProblems...) + } + } + + if responseMetadata.DebugInfo != nil { + if responseMetadata.DebugInfo.Check.Request != nil { + debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, responseMetadata.DebugInfo.Check) + } else { + debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, responseMetadata.DebugInfo.Check.SubProblems...) + } + } + + combined.DebugInfo = debugInfo + return combined +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go b/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go new file mode 100644 index 0000000..ed3f3cb --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go @@ -0,0 +1,144 @@ +package graph + +import ( + "sort" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/genutil/slicez" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// checkDispatchSet is the set of subjects over which check will need to dispatch +// as subproblems in order to answer the parent problem. +type checkDispatchSet struct { + // bySubjectType is a map from the type of subject to the set of subjects of that type + // over which to dispatch, along with information indicating whether caveats are present + // for that chunk. + bySubjectType map[tuple.RelationReference]map[string]bool + + // bySubject is a map from the subject to the set of resources for which the subject + // has a relationship, along with the caveats that apply to that relationship. + bySubject *mapz.MultiMap[tuple.ObjectAndRelation, resourceIDAndCaveat] +} + +// checkDispatchChunk is a chunk of subjects over which to dispatch a check operation. +type checkDispatchChunk struct { + // resourceType is the type of the subjects in this chunk. + resourceType tuple.RelationReference + + // resourceIds is the set of subjects in this chunk. + resourceIds []string + + // hasIncomingCaveats is true if any of the subjects in this chunk have incoming caveats. + // This is used to determine whether the check operation should be dispatched requiring + // all results. + hasIncomingCaveats bool +} + +// subjectIDAndHasCaveat is a tuple of a subject ID and whether it has a caveat. +type subjectIDAndHasCaveat struct { + // objectID is the ID of the subject. + objectID string + + // hasIncomingCaveats is true if the subject has a caveat. + hasIncomingCaveats bool +} + +// resourceIDAndCaveat is a tuple of a resource ID and a caveat. +type resourceIDAndCaveat struct { + // resourceID is the ID of the resource. + resourceID string + + // caveat is the caveat that applies to the relationship between the subject and the resource. + // May be nil. + caveat *core.ContextualizedCaveat +} + +// newCheckDispatchSet creates and returns a new checkDispatchSet. +func newCheckDispatchSet() *checkDispatchSet { + return &checkDispatchSet{ + bySubjectType: map[tuple.RelationReference]map[string]bool{}, + bySubject: mapz.NewMultiMap[tuple.ObjectAndRelation, resourceIDAndCaveat](), + } +} + +// Add adds the specified ObjectAndRelation to the set. +func (s *checkDispatchSet) addForRelationship(rel tuple.Relationship) { + // Add an entry for the subject pointing to the resource ID and caveat for the subject. + riac := resourceIDAndCaveat{ + resourceID: rel.Resource.ObjectID, + caveat: rel.OptionalCaveat, + } + s.bySubject.Add(rel.Subject, riac) + + // Add the subject ID to the map of subjects for the type of subject. + siac := subjectIDAndHasCaveat{ + objectID: rel.Subject.ObjectID, + hasIncomingCaveats: rel.OptionalCaveat != nil && rel.OptionalCaveat.CaveatName != "", + } + + subjectIDsForType, ok := s.bySubjectType[rel.Subject.RelationReference()] + if !ok { + subjectIDsForType = make(map[string]bool) + s.bySubjectType[rel.Subject.RelationReference()] = subjectIDsForType + } + + // If a caveat exists for the subject ID in any branch, the whole branch is considered caveated. + subjectIDsForType[rel.Subject.ObjectID] = siac.hasIncomingCaveats || subjectIDsForType[rel.Subject.ObjectID] +} + +func (s *checkDispatchSet) dispatchChunks(dispatchChunkSize uint16) []checkDispatchChunk { + // Start with an estimate of one chunk per type, plus one for the remainder. + expectedNumberOfChunks := len(s.bySubjectType) + 1 + toDispatch := make([]checkDispatchChunk, 0, expectedNumberOfChunks) + + // For each type of subject, create chunks of the IDs over which to dispatch. + for subjectType, subjectIDsAndHasCaveats := range s.bySubjectType { + entries := make([]subjectIDAndHasCaveat, 0, len(subjectIDsAndHasCaveats)) + for objectID, hasIncomingCaveats := range subjectIDsAndHasCaveats { + entries = append(entries, subjectIDAndHasCaveat{objectID: objectID, hasIncomingCaveats: hasIncomingCaveats}) + } + + // Sort the list of subject IDs by whether they have caveats and then the ID itself. + sort.Slice(entries, func(i, j int) bool { + iHasCaveat := entries[i].hasIncomingCaveats + jHasCaveat := entries[j].hasIncomingCaveats + if iHasCaveat == jHasCaveat { + return entries[i].objectID < entries[j].objectID + } + return iHasCaveat && !jHasCaveat + }) + + chunkCount := 0.0 + slicez.ForEachChunk(entries, dispatchChunkSize, func(subjectIdChunk []subjectIDAndHasCaveat) { + chunkCount++ + + subjectIDsToDispatch := make([]string, 0, len(subjectIdChunk)) + hasIncomingCaveats := false + for _, entry := range subjectIdChunk { + subjectIDsToDispatch = append(subjectIDsToDispatch, entry.objectID) + hasIncomingCaveats = hasIncomingCaveats || entry.hasIncomingCaveats + } + + toDispatch = append(toDispatch, checkDispatchChunk{ + resourceType: subjectType, + resourceIds: subjectIDsToDispatch, + hasIncomingCaveats: hasIncomingCaveats, + }) + }) + dispatchChunkCountHistogram.Observe(chunkCount) + } + + return toDispatch +} + +// mappingsForSubject returns the mappings that apply to the relationship between the specified +// subject and any of its resources. The returned caveats include the resource ID of the resource +// that the subject has a relationship with. +func (s *checkDispatchSet) mappingsForSubject(subjectType string, subjectObjectID string, subjectRelation string) []resourceIDAndCaveat { + results, ok := s.bySubject.Get(tuple.ONR(subjectType, subjectObjectID, subjectRelation)) + spiceerrors.DebugAssert(func() bool { return ok }, "no caveats found for subject %s:%s:%s", subjectType, subjectObjectID, subjectRelation) + return results +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go b/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go new file mode 100644 index 0000000..0bf20b5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go @@ -0,0 +1,205 @@ +package computed + +import ( + "context" + + cexpr "github.com/authzed/spicedb/internal/caveats" + "github.com/authzed/spicedb/internal/dispatch" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/slicez" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// DebugOption defines the various debug level options for Checks. +type DebugOption int + +const ( + // NoDebugging indicates that debug information should be retained + // while performing the Check. + NoDebugging DebugOption = 0 + + // BasicDebuggingEnabled indicates that basic debug information, such + // as which steps were taken, should be retained while performing the + // Check and returned to the caller. + // + // NOTE: This has a minor performance impact. + BasicDebuggingEnabled DebugOption = 1 + + // TraceDebuggingEnabled indicates that the Check is being issued for + // tracing the exact calls made for debugging, which means that not only + // should debug information be recorded and returned, but that optimizations + // such as batching should be disabled. + // + // WARNING: This has a fairly significant performance impact and should only + // be used in tooling! + TraceDebuggingEnabled DebugOption = 2 +) + +// CheckParameters are the parameters for the ComputeCheck call. *All* are required. +type CheckParameters struct { + ResourceType tuple.RelationReference + Subject tuple.ObjectAndRelation + CaveatContext map[string]any + AtRevision datastore.Revision + MaximumDepth uint32 + DebugOption DebugOption + CheckHints []*v1.CheckHint +} + +// ComputeCheck computes a check result for the given resource and subject, computing any +// caveat expressions found. +func ComputeCheck( + ctx context.Context, + d dispatch.Check, + ts *caveattypes.TypeSet, + params CheckParameters, + resourceID string, + dispatchChunkSize uint16, +) (*v1.ResourceCheckResult, *v1.ResponseMeta, error) { + resultsMap, meta, di, err := computeCheck(ctx, d, ts, params, []string{resourceID}, dispatchChunkSize) + if err != nil { + return nil, meta, err + } + + spiceerrors.DebugAssert(func() bool { + return (len(di) == 0 && meta.DebugInfo == nil) || (len(di) == 1 && meta.DebugInfo != nil) + }, "mismatch in debug information returned from computeCheck") + + return resultsMap[resourceID], meta, err +} + +// ComputeBulkCheck computes a check result for the given resources and subject, computing any +// caveat expressions found. +func ComputeBulkCheck( + ctx context.Context, + d dispatch.Check, + ts *caveattypes.TypeSet, + params CheckParameters, + resourceIDs []string, + dispatchChunkSize uint16, +) (map[string]*v1.ResourceCheckResult, *v1.ResponseMeta, []*v1.DebugInformation, error) { + return computeCheck(ctx, d, ts, params, resourceIDs, dispatchChunkSize) +} + +func computeCheck(ctx context.Context, + d dispatch.Check, + ts *caveattypes.TypeSet, + params CheckParameters, + resourceIDs []string, + dispatchChunkSize uint16, +) (map[string]*v1.ResourceCheckResult, *v1.ResponseMeta, []*v1.DebugInformation, error) { + debugging := v1.DispatchCheckRequest_NO_DEBUG + if params.DebugOption == BasicDebuggingEnabled { + debugging = v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING + } else if params.DebugOption == TraceDebuggingEnabled { + debugging = v1.DispatchCheckRequest_ENABLE_TRACE_DEBUGGING + } + + setting := v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS + if len(resourceIDs) == 1 { + setting = v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT + } + + // Ensure that the number of resources IDs given to each dispatch call is not in excess of the maximum. + results := make(map[string]*v1.ResourceCheckResult, len(resourceIDs)) + metadata := &v1.ResponseMeta{} + + bf, err := v1.NewTraversalBloomFilter(uint(params.MaximumDepth)) + if err != nil { + return nil, nil, nil, spiceerrors.MustBugf("failed to create new traversal bloom filter") + } + + caveatRunner := cexpr.NewCaveatRunner(ts) + + // TODO(jschorr): Should we make this run in parallel via the preloadedTaskRunner? + debugInfo := make([]*v1.DebugInformation, 0) + _, err = slicez.ForEachChunkUntil(resourceIDs, dispatchChunkSize, func(resourceIDsToCheck []string) (bool, error) { + checkResult, err := d.DispatchCheck(ctx, &v1.DispatchCheckRequest{ + ResourceRelation: params.ResourceType.ToCoreRR(), + ResourceIds: resourceIDsToCheck, + ResultsSetting: setting, + Subject: params.Subject.ToCoreONR(), + Metadata: &v1.ResolverMeta{ + AtRevision: params.AtRevision.String(), + DepthRemaining: params.MaximumDepth, + TraversalBloom: bf, + }, + Debug: debugging, + CheckHints: params.CheckHints, + }) + + if checkResult.Metadata.DebugInfo != nil { + debugInfo = append(debugInfo, checkResult.Metadata.DebugInfo) + } + + if len(resourceIDs) == 1 { + metadata = checkResult.Metadata + } else { + metadata = &v1.ResponseMeta{ + DispatchCount: metadata.DispatchCount + checkResult.Metadata.DispatchCount, + DepthRequired: max(metadata.DepthRequired, checkResult.Metadata.DepthRequired), + CachedDispatchCount: metadata.CachedDispatchCount + checkResult.Metadata.CachedDispatchCount, + DebugInfo: nil, + } + } + + if err != nil { + return false, err + } + + for _, resourceID := range resourceIDsToCheck { + computed, err := computeCaveatedCheckResult(ctx, caveatRunner, params, resourceID, checkResult) + if err != nil { + return false, err + } + results[resourceID] = computed + } + + return true, nil + }) + return results, metadata, debugInfo, err +} + +func computeCaveatedCheckResult(ctx context.Context, runner *cexpr.CaveatRunner, params CheckParameters, resourceID string, checkResult *v1.DispatchCheckResponse) (*v1.ResourceCheckResult, error) { + result, ok := checkResult.ResultsByResourceId[resourceID] + if !ok { + return &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_NOT_MEMBER, + }, nil + } + + if result.Membership == v1.ResourceCheckResult_MEMBER { + return result, nil + } + + ds := datastoremw.MustFromContext(ctx) + reader := ds.SnapshotReader(params.AtRevision) + + caveatResult, err := runner.RunCaveatExpression(ctx, result.Expression, params.CaveatContext, reader, cexpr.RunCaveatExpressionNoDebugging) + if err != nil { + return nil, err + } + + if caveatResult.IsPartial() { + missingFields, _ := caveatResult.MissingVarNames() + return &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_CAVEATED_MEMBER, + Expression: result.Expression, + MissingExprFields: missingFields, + }, nil + } + + if caveatResult.Value() { + return &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_MEMBER, + }, nil + } + + return &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_NOT_MEMBER, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/context.go b/vendor/github.com/authzed/spicedb/internal/graph/context.go new file mode 100644 index 0000000..1485fa0 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/context.go @@ -0,0 +1,33 @@ +package graph + +import ( + "context" + + "go.opentelemetry.io/otel/trace" + + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/pkg/middleware/requestid" +) + +// branchContext returns a context disconnected from the parent context, but populated with the datastore. +// Also returns a function for canceling the newly created context, without canceling the parent context. +func branchContext(ctx context.Context) (context.Context, func(cancelErr error)) { + // Add tracing to the context. + span := trace.SpanFromContext(ctx) + detachedContext := trace.ContextWithSpan(context.Background(), span) + + // Add datastore to the context. + ds := datastoremw.FromContext(ctx) + detachedContext = datastoremw.ContextWithDatastore(detachedContext, ds) + + // Add logging to the context. + loggerFromContext := log.Ctx(ctx) + if loggerFromContext != nil { + detachedContext = loggerFromContext.WithContext(detachedContext) + } + + detachedContext = requestid.PropagateIfExists(ctx, detachedContext) + + return context.WithCancelCause(detachedContext) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/cursors.go b/vendor/github.com/authzed/spicedb/internal/graph/cursors.go new file mode 100644 index 0000000..ad3d705 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/cursors.go @@ -0,0 +1,542 @@ +package graph + +import ( + "context" + "errors" + "strconv" + "sync" + + "github.com/ccoveille/go-safecast" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/taskrunner" + "github.com/authzed/spicedb/pkg/datastore/options" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// cursorInformation is a struct which holds information about the current incoming cursor (if any) +// and the sections to be added to the *outgoing* partial cursor. +type cursorInformation struct { + // currentCursor is the current incoming cursor. This may be nil. + currentCursor *v1.Cursor + + // outgoingCursorSections are the sections to be added to the outgoing *partial* cursor. + // It is the responsibility of the *caller* to append together the incoming cursors to form + // the final cursor. + // + // A `section` is a portion of the cursor, representing a section of code that was + // executed to produce the section of the cursor. + outgoingCursorSections []string + + // limits is the limits tracker for the call over which the cursor is being used. + limits *limitTracker + + // dispatchCursorVersion is the version of the dispatch to be stored in the cursor. + dispatchCursorVersion uint32 +} + +// newCursorInformation constructs a new cursorInformation struct from the incoming cursor (which +// may be nil) +func newCursorInformation(incomingCursor *v1.Cursor, limits *limitTracker, dispatchCursorVersion uint32) (cursorInformation, error) { + if incomingCursor != nil && incomingCursor.DispatchVersion != dispatchCursorVersion { + return cursorInformation{}, NewInvalidCursorErr(dispatchCursorVersion, incomingCursor) + } + + if dispatchCursorVersion == 0 { + return cursorInformation{}, spiceerrors.MustBugf("invalid dispatch cursor version") + } + + return cursorInformation{ + currentCursor: incomingCursor, + outgoingCursorSections: nil, + limits: limits, + dispatchCursorVersion: dispatchCursorVersion, + }, nil +} + +// responsePartialCursor is the *partial* cursor to return in a response. +func (ci cursorInformation) responsePartialCursor() *v1.Cursor { + return &v1.Cursor{ + DispatchVersion: ci.dispatchCursorVersion, + Sections: ci.outgoingCursorSections, + } +} + +// withClonedLimits returns the cursor, but with its limits tracker cloned. +func (ci cursorInformation) withClonedLimits() cursorInformation { + return cursorInformation{ + currentCursor: ci.currentCursor, + outgoingCursorSections: ci.outgoingCursorSections, + limits: ci.limits.clone(), + dispatchCursorVersion: ci.dispatchCursorVersion, + } +} + +// headSectionValue returns the string value found at the head of the incoming cursor. +// If the incoming cursor is empty, returns empty. +func (ci cursorInformation) headSectionValue() (string, bool) { + if ci.currentCursor == nil || len(ci.currentCursor.Sections) < 1 { + return "", false + } + + return ci.currentCursor.Sections[0], true +} + +// integerSectionValue returns the *integer* found at the head of the incoming cursor. +// If the incoming cursor is empty, returns 0. If the incoming cursor does not start with an +// int value, fails with an error. +func (ci cursorInformation) integerSectionValue() (int, error) { + valueStr, hasValue := ci.headSectionValue() + if !hasValue { + return 0, nil + } + + if valueStr == "" { + return 0, nil + } + + return strconv.Atoi(valueStr) +} + +// withOutgoingSection returns cursorInformation updated with the given optional +// value appended to the outgoingCursorSections for the current cursor. If the current +// cursor already begins with any values, those values are replaced. +func (ci cursorInformation) withOutgoingSection(value string) (cursorInformation, error) { + ocs := make([]string, 0, len(ci.outgoingCursorSections)+1) + ocs = append(ocs, ci.outgoingCursorSections...) + ocs = append(ocs, value) + + if ci.currentCursor != nil && len(ci.currentCursor.Sections) > 0 { + // If the cursor already has values, replace them with those specified. + return cursorInformation{ + currentCursor: &v1.Cursor{ + DispatchVersion: ci.dispatchCursorVersion, + Sections: ci.currentCursor.Sections[1:], + }, + outgoingCursorSections: ocs, + limits: ci.limits, + dispatchCursorVersion: ci.dispatchCursorVersion, + }, nil + } + + return cursorInformation{ + currentCursor: nil, + outgoingCursorSections: ocs, + limits: ci.limits, + dispatchCursorVersion: ci.dispatchCursorVersion, + }, nil +} + +func (ci cursorInformation) clearIncoming() cursorInformation { + return cursorInformation{ + currentCursor: nil, + outgoingCursorSections: ci.outgoingCursorSections, + limits: ci.limits, + dispatchCursorVersion: ci.dispatchCursorVersion, + } +} + +type cursorHandler func(c cursorInformation) error + +// itemAndPostCursor represents an item and the cursor to be used for all items after it. +type itemAndPostCursor[T any] struct { + item T + cursor options.Cursor +} + +// withDatastoreCursorInCursor executes the given lookup function to retrieve items from the datastore, +// and then executes the handler on each of the produced items *in parallel*, streaming the results +// in the correct order to the parent stream. +func withDatastoreCursorInCursor[T any, Q any]( + ctx context.Context, + ci cursorInformation, + parentStream dispatch.Stream[Q], + concurrencyLimit uint16, + lookup func(queryCursor options.Cursor) ([]itemAndPostCursor[T], error), + handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error, +) error { + // Retrieve the *datastore* cursor, if one is found at the head of the incoming cursor. + var datastoreCursor options.Cursor + datastoreCursorString, _ := ci.headSectionValue() + if datastoreCursorString != "" { + datastoreCursor = options.ToCursor(tuple.MustParse(datastoreCursorString)) + } + + if ci.limits.hasExhaustedLimit() { + return nil + } + + // Execute the lookup to call the database and find items for processing. + itemsToBeProcessed, err := lookup(datastoreCursor) + if err != nil { + return err + } + + if len(itemsToBeProcessed) == 0 { + return nil + } + + itemsToRun := make([]T, 0, len(itemsToBeProcessed)) + for _, itemAndCursor := range itemsToBeProcessed { + itemsToRun = append(itemsToRun, itemAndCursor.item) + } + + getItemCursor := func(taskIndex int) (cursorInformation, error) { + // Create an updated cursor referencing the current item's cursor, so that any items returned know to resume from this point. + cursorRel := options.ToRelationship(itemsToBeProcessed[taskIndex].cursor) + cursorSection := "" + if cursorRel != nil { + cursorSection = tuple.StringWithoutCaveatOrExpiration(*cursorRel) + } + + currentCursor, err := ci.withOutgoingSection(cursorSection) + if err != nil { + return currentCursor, err + } + + // If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top + // of the cursor. + if taskIndex > 0 { + currentCursor = currentCursor.clearIncoming() + } + + return currentCursor, nil + } + + return withInternalParallelizedStreamingIterableInCursor( + ctx, + ci, + itemsToRun, + parentStream, + concurrencyLimit, + getItemCursor, + handler, + ) +} + +type afterResponseCursor func(nextOffset int) *v1.Cursor + +// withSubsetInCursor executes the given handler with the offset index found at the beginning of the +// cursor. If the offset is not found, executes with 0. The handler is given the current offset as +// well as a callback to mint the cursor with the next offset. +func withSubsetInCursor( + ci cursorInformation, + handler func(currentOffset int, nextCursorWith afterResponseCursor) error, + next cursorHandler, +) error { + if ci.limits.hasExhaustedLimit() { + return nil + } + + afterIndex, err := ci.integerSectionValue() + if err != nil { + return err + } + + if afterIndex >= 0 { + var foundCerr error + err = handler(afterIndex, func(nextOffset int) *v1.Cursor { + cursor, cerr := ci.withOutgoingSection(strconv.Itoa(nextOffset)) + foundCerr = cerr + if cerr != nil { + return nil + } + + return cursor.responsePartialCursor() + }) + if err != nil { + return err + } + if foundCerr != nil { + return foundCerr + } + } + + if ci.limits.hasExhaustedLimit() { + return nil + } + + // -1 means that the handler has been completed. + uci, err := ci.withOutgoingSection("-1") + if err != nil { + return err + } + return next(uci) +} + +// combineCursors combines the given cursors into one resulting cursor. +func combineCursors(cursor *v1.Cursor, toAdd *v1.Cursor) (*v1.Cursor, error) { + if toAdd == nil || len(toAdd.Sections) == 0 { + return nil, spiceerrors.MustBugf("supplied toAdd cursor was nil or empty") + } + + if cursor == nil || len(cursor.Sections) == 0 { + return toAdd, nil + } + + sections := make([]string, 0, len(cursor.Sections)+len(toAdd.Sections)) + sections = append(sections, cursor.Sections...) + sections = append(sections, toAdd.Sections...) + + return &v1.Cursor{ + DispatchVersion: toAdd.DispatchVersion, + Sections: sections, + }, nil +} + +// withParallelizedStreamingIterableInCursor executes the given handler for each item in the items list, skipping any +// items marked as completed at the head of the cursor and injecting a cursor representing the current +// item. +// +// For example, if items contains 3 items, and the cursor returned was within the handler for item +// index #1, then item index #0 will be skipped on subsequent invocation. +// +// The next index is executed in parallel with the current index, with its results stored in a CollectingStream +// until the next iteration. +func withParallelizedStreamingIterableInCursor[T any, Q any]( + ctx context.Context, + ci cursorInformation, + items []T, + parentStream dispatch.Stream[Q], + concurrencyLimit uint16, + handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error, +) error { + // Check the cursor for a starting index, before which any items will be skipped. + startingIndex, err := ci.integerSectionValue() + if err != nil { + return err + } + + if startingIndex < 0 || startingIndex > len(items) { + return spiceerrors.MustBugf("invalid cursor in withParallelizedStreamingIterableInCursor: found starting index %d for items %v", startingIndex, items) + } + + itemsToRun := items[startingIndex:] + if len(itemsToRun) == 0 { + return nil + } + + getItemCursor := func(taskIndex int) (cursorInformation, error) { + // Create an updated cursor referencing the current item's index, so that any items returned know to resume from this point. + currentCursor, err := ci.withOutgoingSection(strconv.Itoa(taskIndex + startingIndex)) + if err != nil { + return currentCursor, err + } + + // If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top + // of the cursor. + if taskIndex > 0 { + currentCursor = currentCursor.clearIncoming() + } + + return currentCursor, nil + } + + return withInternalParallelizedStreamingIterableInCursor( + ctx, + ci, + itemsToRun, + parentStream, + concurrencyLimit, + getItemCursor, + handler, + ) +} + +func withInternalParallelizedStreamingIterableInCursor[T any, Q any]( + ctx context.Context, + ci cursorInformation, + itemsToRun []T, + parentStream dispatch.Stream[Q], + concurrencyLimit uint16, + getItemCursor func(taskIndex int) (cursorInformation, error), + handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error, +) error { + // Queue up each iteration's worth of items to be run by the task runner. + tr := taskrunner.NewPreloadedTaskRunner(ctx, concurrencyLimit, len(itemsToRun)) + stream, err := newParallelLimitedIndexedStream(ctx, ci, parentStream, len(itemsToRun)) + if err != nil { + return err + } + + // Schedule a task to be invoked for each item to be run. + for taskIndex, item := range itemsToRun { + taskIndex := taskIndex + item := item + tr.Add(func(ctx context.Context) error { + stream.lock.Lock() + if ci.limits.hasExhaustedLimit() { + stream.lock.Unlock() + return nil + } + stream.lock.Unlock() + + ici, err := getItemCursor(taskIndex) + if err != nil { + return err + } + + // Invoke the handler with the current item's index in the outgoing cursor, indicating that + // subsequent invocations should jump right to this item. + ictx, istream, icursor := stream.forTaskIndex(ctx, taskIndex, ici) + + err = handler(ictx, icursor, item, istream) + if err != nil { + // If the branch was canceled explicitly by *this* streaming iterable because other branches have fulfilled + // the configured limit, then we can safely ignore this error. + if errors.Is(context.Cause(ictx), stream.errCanceledBecauseFulfilled) { + return nil + } + return err + } + + return stream.completedTaskIndex(taskIndex) + }) + } + + err = tr.StartAndWait() + if err != nil { + return err + } + return nil +} + +// parallelLimitedIndexedStream is a specialization of a dispatch.Stream that collects results from multiple +// tasks running in parallel, and emits them in the order of the tasks. The first task's results are directly +// emitted to the parent stream, while subsequent tasks' results are emitted in the defined order of the tasks +// to ensure cursors and limits work as expected. +type parallelLimitedIndexedStream[Q any] struct { + lock sync.Mutex + + ctx context.Context + ci cursorInformation + parentStream dispatch.Stream[Q] + + streamCount int + toPublishTaskIndex int + countingStream *dispatch.CountingDispatchStream[Q] // GUARDED_BY(lock) + childStreams map[int]*dispatch.CollectingDispatchStream[Q] // GUARDED_BY(lock) + childContextCancels map[int]func(cause error) // GUARDED_BY(lock) + completedTaskIndexes map[int]bool // GUARDED_BY(lock) + errCanceledBecauseFulfilled error +} + +func newParallelLimitedIndexedStream[Q any]( + ctx context.Context, + ci cursorInformation, + parentStream dispatch.Stream[Q], + streamCount int, +) (*parallelLimitedIndexedStream[Q], error) { + if streamCount <= 0 { + return nil, spiceerrors.MustBugf("got invalid stream count") + } + + return ¶llelLimitedIndexedStream[Q]{ + ctx: ctx, + ci: ci, + parentStream: parentStream, + countingStream: nil, + childStreams: map[int]*dispatch.CollectingDispatchStream[Q]{}, + childContextCancels: map[int]func(cause error){}, + completedTaskIndexes: map[int]bool{}, + toPublishTaskIndex: 0, + streamCount: streamCount, + + // NOTE: we mint a new error here to ensure that we only skip cancelations from this very instance. + errCanceledBecauseFulfilled: errors.New("canceled because other branches fulfilled limit"), + }, nil +} + +// forTaskIndex returns a new context, stream and cursor for invoking the task at the specific index and publishing its results. +func (ls *parallelLimitedIndexedStream[Q]) forTaskIndex(ctx context.Context, index int, currentCursor cursorInformation) (context.Context, dispatch.Stream[Q], cursorInformation) { + ls.lock.Lock() + defer ls.lock.Unlock() + + // Create a new cursor with cloned limits, because each child task which executes (in parallel) will need its own + // limit tracking. The overall limit on the original cursor is managed in completedTaskIndex. + childCI := currentCursor.withClonedLimits() + childContext, cancelDispatch := branchContext(ctx) + + ls.childContextCancels[index] = cancelDispatch + + // If executing for the first index, it can stream directly to the parent stream, but we need to count the number + // of items streamed to adjust the overall limits. + if index == 0 { + countingStream := dispatch.NewCountingDispatchStream(ls.parentStream) + ls.countingStream = countingStream + return childContext, countingStream, childCI + } + + // Otherwise, create a child stream with an adjusted limits on the cursor. We have to clone the cursor's + // limits here to ensure that the child's publishing doesn't affect the first branch. + childStream := dispatch.NewCollectingDispatchStream[Q](childContext) + ls.childStreams[index] = childStream + + return childContext, childStream, childCI +} + +// cancelRemainingDispatches cancels the contexts for each dispatched branch, indicating that no additional results +// are necessary. +func (ls *parallelLimitedIndexedStream[Q]) cancelRemainingDispatches() { + for _, cancel := range ls.childContextCancels { + cancel(ls.errCanceledBecauseFulfilled) + } +} + +// completedTaskIndex indicates the the task at the specific index has completed successfully and that its collected +// results should be published to the parent stream, so long as all previous tasks have been completed and published as well. +func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error { + ls.lock.Lock() + defer ls.lock.Unlock() + + // Mark the task as completed, but not yet published. + ls.completedTaskIndexes[index] = true + + // If the overall limit has been reached, nothing more to do. + if ls.ci.limits.hasExhaustedLimit() { + ls.cancelRemainingDispatches() + return nil + } + + // Otherwise, publish any results from previous completed tasks up, and including, this task. This loop ensures + // that the collected results for each task are published to the parent stream in the correct order. + for { + if !ls.completedTaskIndexes[ls.toPublishTaskIndex] { + return nil + } + + if ls.toPublishTaskIndex == 0 { + // Remove the already emitted data from the overall limits. + publishedCount, err := safecast.ToUint32(ls.countingStream.PublishedCount()) + if err != nil { + return spiceerrors.MustBugf("cannot cast published count to uint32: %v", err) + } + if err := ls.ci.limits.markAlreadyPublished(publishedCount); err != nil { + return err + } + + if ls.ci.limits.hasExhaustedLimit() { + ls.cancelRemainingDispatches() + } + } else { + // Publish, to the parent stream, the results produced by the task and stored in the child stream. + childStream := ls.childStreams[ls.toPublishTaskIndex] + for _, result := range childStream.Results() { + if !ls.ci.limits.prepareForPublishing() { + ls.cancelRemainingDispatches() + return nil + } + + err := ls.parentStream.Publish(result) + if err != nil { + return err + } + } + ls.childStreams[ls.toPublishTaskIndex] = nil + } + + ls.toPublishTaskIndex++ + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/doc.go b/vendor/github.com/authzed/spicedb/internal/graph/doc.go new file mode 100644 index 0000000..904b216 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/doc.go @@ -0,0 +1,2 @@ +// Package graph contains the code to traverse a relationship graph to solve requests like Checks, Expansions and Lookups. +package graph diff --git a/vendor/github.com/authzed/spicedb/internal/graph/errors.go b/vendor/github.com/authzed/spicedb/internal/graph/errors.go new file mode 100644 index 0000000..31577d9 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/errors.go @@ -0,0 +1,213 @@ +package graph + +import ( + "errors" + "fmt" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/sharederrors" + dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// CheckFailureError occurs when check failed in some manner. Note this should not apply to +// namespaces and relations not being found. +type CheckFailureError struct { + error +} + +func (e CheckFailureError) Unwrap() error { + return e.error +} + +// NewCheckFailureErr constructs a new check failed error. +func NewCheckFailureErr(baseErr error) error { + return CheckFailureError{ + error: fmt.Errorf("error performing check: %w", baseErr), + } +} + +// ExpansionFailureError occurs when expansion failed in some manner. Note this should not apply to +// namespaces and relations not being found. +type ExpansionFailureError struct { + error +} + +func (e ExpansionFailureError) Unwrap() error { + return e.error +} + +// NewExpansionFailureErr constructs a new expansion failed error. +func NewExpansionFailureErr(baseErr error) error { + return ExpansionFailureError{ + error: fmt.Errorf("error performing expand: %w", baseErr), + } +} + +// AlwaysFailError is returned when an internal error leads to an operation +// guaranteed to fail. +type AlwaysFailError struct { + error +} + +// NewAlwaysFailErr constructs a new always fail error. +func NewAlwaysFailErr() error { + return AlwaysFailError{ + error: errors.New("always fail"), + } +} + +// RelationNotFoundError occurs when a relation was not found under a namespace. +type RelationNotFoundError struct { + error + namespaceName string + relationName string +} + +// NamespaceName returns the name of the namespace in which the relation was not found. +func (err RelationNotFoundError) NamespaceName() string { + return err.namespaceName +} + +// NotFoundRelationName returns the name of the relation not found. +func (err RelationNotFoundError) NotFoundRelationName() string { + return err.relationName +} + +func (err RelationNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err RelationNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + "relation_or_permission_name": err.relationName, + } +} + +// NewRelationNotFoundErr constructs a new relation not found error. +func NewRelationNotFoundErr(nsName string, relationName string) error { + return RelationNotFoundError{ + error: fmt.Errorf("relation/permission `%s` not found under definition `%s`", relationName, nsName), + namespaceName: nsName, + relationName: relationName, + } +} + +var _ sharederrors.UnknownRelationError = RelationNotFoundError{} + +// RelationMissingTypeInfoError defines an error for when type information is missing from a relation +// during a lookup. +type RelationMissingTypeInfoError struct { + error + namespaceName string + relationName string +} + +// NamespaceName returns the name of the namespace in which the relation was found. +func (err RelationMissingTypeInfoError) NamespaceName() string { + return err.namespaceName +} + +// RelationName returns the name of the relation missing type information. +func (err RelationMissingTypeInfoError) RelationName() string { + return err.relationName +} + +func (err RelationMissingTypeInfoError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err RelationMissingTypeInfoError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + "relation_name": err.relationName, + } +} + +// NewRelationMissingTypeInfoErr constructs a new relation not missing type information error. +func NewRelationMissingTypeInfoErr(nsName string, relationName string) error { + return RelationMissingTypeInfoError{ + error: fmt.Errorf("relation/permission `%s` under definition `%s` is missing type information", relationName, nsName), + namespaceName: nsName, + relationName: relationName, + } +} + +// WildcardNotAllowedError occurs when a request sent has an invalid wildcard argument. +type WildcardNotAllowedError struct { + error + + fieldName string +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err WildcardNotAllowedError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_WILDCARD_NOT_ALLOWED, + map[string]string{ + "field": err.fieldName, + }, + ), + ) +} + +// NewWildcardNotAllowedErr constructs an error indicating that a wildcard was not allowed. +func NewWildcardNotAllowedErr(message string, fieldName string) error { + return WildcardNotAllowedError{ + error: fmt.Errorf("invalid argument: %s", message), + fieldName: fieldName, + } +} + +// UnimplementedError is returned when some functionality is not yet supported. +type UnimplementedError struct { + error +} + +// NewUnimplementedErr constructs a new unimplemented error. +func NewUnimplementedErr(baseErr error) error { + return UnimplementedError{ + error: baseErr, + } +} + +func (e UnimplementedError) Unwrap() error { + return e.error +} + +// InvalidCursorError is returned when a cursor is no longer valid. +type InvalidCursorError struct { + error +} + +// NewInvalidCursorErr constructs a new unimplemented error. +func NewInvalidCursorErr(dispatchCursorVersion uint32, cursor *dispatch.Cursor) error { + return InvalidCursorError{ + error: fmt.Errorf("the supplied cursor is no longer valid: found version %d, expected version %d", cursor.DispatchVersion, dispatchCursorVersion), + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err InvalidCursorError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_INVALID_CURSOR, + map[string]string{ + "details": "cursor was used against an incompatible version of SpiceDB", + }, + ), + ) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/expand.go b/vendor/github.com/authzed/spicedb/internal/graph/expand.go new file mode 100644 index 0000000..9418bec --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/expand.go @@ -0,0 +1,436 @@ +package graph + +import ( + "context" + "errors" + + "github.com/authzed/spicedb/internal/caveats" + + "github.com/authzed/spicedb/internal/dispatch" + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// NewConcurrentExpander creates an instance of ConcurrentExpander +func NewConcurrentExpander(d dispatch.Expand) *ConcurrentExpander { + return &ConcurrentExpander{d: d} +} + +// ConcurrentExpander exposes a method to perform Expand requests, and delegates subproblems to the +// provided dispatch.Expand instance. +type ConcurrentExpander struct { + d dispatch.Expand +} + +// ValidatedExpandRequest represents a request after it has been validated and parsed for internal +// consumption. +type ValidatedExpandRequest struct { + *v1.DispatchExpandRequest + Revision datastore.Revision +} + +// Expand performs an expand request with the provided request and context. +func (ce *ConcurrentExpander) Expand(ctx context.Context, req ValidatedExpandRequest, relation *core.Relation) (*v1.DispatchExpandResponse, error) { + log.Ctx(ctx).Trace().Object("expand", req).Send() + + var directFunc ReduceableExpandFunc + if relation.UsersetRewrite == nil { + directFunc = ce.expandDirect(ctx, req) + } else { + directFunc = ce.expandUsersetRewrite(ctx, req, relation.UsersetRewrite) + } + + resolved := expandOne(ctx, directFunc) + resolved.Resp.Metadata = addCallToResponseMetadata(resolved.Resp.Metadata) + return resolved.Resp, resolved.Err +} + +func (ce *ConcurrentExpander) expandDirect( + ctx context.Context, + req ValidatedExpandRequest, +) ReduceableExpandFunc { + log.Ctx(ctx).Trace().Object("direct", req).Send() + return func(ctx context.Context, resultChan chan<- ExpandResult) { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: req.ResourceAndRelation.Namespace, + OptionalResourceIds: []string{req.ResourceAndRelation.ObjectId}, + OptionalResourceRelation: req.ResourceAndRelation.Relation, + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata) + return + } + + var foundNonTerminalUsersets []*core.DirectSubject + var foundTerminalUsersets []*core.DirectSubject + for rel, err := range it { + if err != nil { + resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata) + return + } + + ds := &core.DirectSubject{ + Subject: rel.Subject.ToCoreONR(), + CaveatExpression: caveats.CaveatAsExpr(rel.OptionalCaveat), + } + + if rel.Subject.Relation == Ellipsis { + foundTerminalUsersets = append(foundTerminalUsersets, ds) + } else { + foundNonTerminalUsersets = append(foundNonTerminalUsersets, ds) + } + } + + // If only shallow expansion was required, or there are no non-terminal subjects found, + // nothing more to do. + if req.ExpansionMode == v1.DispatchExpandRequest_SHALLOW || len(foundNonTerminalUsersets) == 0 { + resultChan <- expandResult( + &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_LeafNode{ + LeafNode: &core.DirectSubjects{ + Subjects: append(foundTerminalUsersets, foundNonTerminalUsersets...), + }, + }, + Expanded: req.ResourceAndRelation, + }, + emptyMetadata, + ) + return + } + + // Otherwise, recursively issue expansion and collect the results from that, plus the + // found terminals together. + var requestsToDispatch []ReduceableExpandFunc + for _, nonTerminalUser := range foundNonTerminalUsersets { + toDispatch := ce.dispatch(ValidatedExpandRequest{ + &v1.DispatchExpandRequest{ + ResourceAndRelation: nonTerminalUser.Subject, + Metadata: decrementDepth(req.Metadata), + ExpansionMode: req.ExpansionMode, + }, + req.Revision, + }) + + requestsToDispatch = append(requestsToDispatch, decorateWithCaveatIfNecessary(toDispatch, nonTerminalUser.CaveatExpression)) + } + + result := expandAny(ctx, req.ResourceAndRelation, requestsToDispatch) + if result.Err != nil { + resultChan <- result + return + } + + unionNode := result.Resp.TreeNode.GetIntermediateNode() + unionNode.ChildNodes = append(unionNode.ChildNodes, &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_LeafNode{ + LeafNode: &core.DirectSubjects{ + Subjects: append(foundTerminalUsersets, foundNonTerminalUsersets...), + }, + }, + Expanded: req.ResourceAndRelation, + }) + resultChan <- result + } +} + +func decorateWithCaveatIfNecessary(toDispatch ReduceableExpandFunc, caveatExpr *core.CaveatExpression) ReduceableExpandFunc { + // If no caveat expression, simply return the func unmodified. + if caveatExpr == nil { + return toDispatch + } + + // Otherwise return a wrapped function that expands the underlying func to be dispatched, and then decorates + // the resulting node with the caveat expression. + // + // TODO(jschorr): This will generate a lot of function closures, so we should change Expand to avoid them + // like we did in Check. + return func(ctx context.Context, resultChan chan<- ExpandResult) { + result := expandOne(ctx, toDispatch) + if result.Err != nil { + resultChan <- result + return + } + + result.Resp.TreeNode.CaveatExpression = caveatExpr + resultChan <- result + } +} + +func (ce *ConcurrentExpander) expandUsersetRewrite(ctx context.Context, req ValidatedExpandRequest, usr *core.UsersetRewrite) ReduceableExpandFunc { + switch rw := usr.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + log.Ctx(ctx).Trace().Msg("union") + return ce.expandSetOperation(ctx, req, rw.Union, expandAny) + case *core.UsersetRewrite_Intersection: + log.Ctx(ctx).Trace().Msg("intersection") + return ce.expandSetOperation(ctx, req, rw.Intersection, expandAll) + case *core.UsersetRewrite_Exclusion: + log.Ctx(ctx).Trace().Msg("exclusion") + return ce.expandSetOperation(ctx, req, rw.Exclusion, expandDifference) + default: + return alwaysFailExpand + } +} + +func (ce *ConcurrentExpander) expandSetOperation(ctx context.Context, req ValidatedExpandRequest, so *core.SetOperation, reducer ExpandReducer) ReduceableExpandFunc { + var requests []ReduceableExpandFunc + for _, childOneof := range so.Child { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_XThis: + return expandError(errors.New("use of _this is unsupported; please rewrite your schema")) + case *core.SetOperation_Child_ComputedUserset: + requests = append(requests, ce.expandComputedUserset(ctx, req, child.ComputedUserset, nil)) + case *core.SetOperation_Child_UsersetRewrite: + requests = append(requests, ce.expandUsersetRewrite(ctx, req, child.UsersetRewrite)) + case *core.SetOperation_Child_TupleToUserset: + requests = append(requests, expandTupleToUserset(ctx, ce, req, child.TupleToUserset, expandAny)) + case *core.SetOperation_Child_FunctionedTupleToUserset: + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + requests = append(requests, expandTupleToUserset(ctx, ce, req, child.FunctionedTupleToUserset, expandAny)) + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + requests = append(requests, expandTupleToUserset(ctx, ce, req, child.FunctionedTupleToUserset, expandAll)) + + default: + return expandError(spiceerrors.MustBugf("unknown function `%s` in expand", child.FunctionedTupleToUserset.Function)) + } + case *core.SetOperation_Child_XNil: + requests = append(requests, emptyExpansion(req.ResourceAndRelation)) + default: + return expandError(spiceerrors.MustBugf("unknown set operation child `%T` in expand", child)) + } + } + return func(ctx context.Context, resultChan chan<- ExpandResult) { + resultChan <- reducer(ctx, req.ResourceAndRelation, requests) + } +} + +func (ce *ConcurrentExpander) dispatch(req ValidatedExpandRequest) ReduceableExpandFunc { + return func(ctx context.Context, resultChan chan<- ExpandResult) { + log.Ctx(ctx).Trace().Object("dispatchExpand", req).Send() + result, err := ce.d.DispatchExpand(ctx, req.DispatchExpandRequest) + resultChan <- ExpandResult{result, err} + } +} + +func (ce *ConcurrentExpander) expandComputedUserset(ctx context.Context, req ValidatedExpandRequest, cu *core.ComputedUserset, rel *tuple.Relationship) ReduceableExpandFunc { + log.Ctx(ctx).Trace().Str("relation", cu.Relation).Msg("computed userset") + var start tuple.ObjectAndRelation + if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT { + if rel == nil { + return expandError(spiceerrors.MustBugf("computed userset for tupleset without tuple")) + } + + start = rel.Subject + } else if cu.Object == core.ComputedUserset_TUPLE_OBJECT { + if rel != nil { + start = rel.Resource + } else { + start = tuple.FromCoreObjectAndRelation(req.ResourceAndRelation) + } + } + + // Check if the target relation exists. If not, return nothing. + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + err := namespace.CheckNamespaceAndRelation(ctx, start.ObjectType, cu.Relation, true, ds) + if err != nil { + if errors.As(err, &namespace.RelationNotFoundError{}) { + return emptyExpansion(req.ResourceAndRelation) + } + + return expandError(err) + } + + return ce.dispatch(ValidatedExpandRequest{ + &v1.DispatchExpandRequest{ + ResourceAndRelation: &core.ObjectAndRelation{ + Namespace: start.ObjectType, + ObjectId: start.ObjectID, + Relation: cu.Relation, + }, + Metadata: decrementDepth(req.Metadata), + ExpansionMode: req.ExpansionMode, + }, + req.Revision, + }) +} + +type expandFunc func(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult + +func expandTupleToUserset[T relation]( + _ context.Context, + ce *ConcurrentExpander, + req ValidatedExpandRequest, + ttu ttu[T], + expandFunc expandFunc, +) ReduceableExpandFunc { + return func(ctx context.Context, resultChan chan<- ExpandResult) { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: req.ResourceAndRelation.Namespace, + OptionalResourceIds: []string{req.ResourceAndRelation.ObjectId}, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata) + return + } + + var requestsToDispatch []ReduceableExpandFunc + for rel, err := range it { + if err != nil { + resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata) + return + } + + toDispatch := ce.expandComputedUserset(ctx, req, ttu.GetComputedUserset(), &rel) + requestsToDispatch = append(requestsToDispatch, decorateWithCaveatIfNecessary(toDispatch, caveats.CaveatAsExpr(rel.OptionalCaveat))) + } + + resultChan <- expandFunc(ctx, req.ResourceAndRelation, requestsToDispatch) + } +} + +func setResult( + op core.SetOperationUserset_Operation, + start *core.ObjectAndRelation, + children []*core.RelationTupleTreeNode, + metadata *v1.ResponseMeta, +) ExpandResult { + return expandResult( + &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_IntermediateNode{ + IntermediateNode: &core.SetOperationUserset{ + Operation: op, + ChildNodes: children, + }, + }, + Expanded: start, + }, + metadata, + ) +} + +func expandSetOperation( + ctx context.Context, + start *core.ObjectAndRelation, + requests []ReduceableExpandFunc, + op core.SetOperationUserset_Operation, +) ExpandResult { + children := make([]*core.RelationTupleTreeNode, 0, len(requests)) + + if len(requests) == 0 { + return setResult(op, start, children, emptyMetadata) + } + + childCtx, cancelFn := context.WithCancel(ctx) + defer cancelFn() + + resultChans := make([]chan ExpandResult, 0, len(requests)) + for _, req := range requests { + resultChan := make(chan ExpandResult, 1) + resultChans = append(resultChans, resultChan) + go req(childCtx, resultChan) + } + + responseMetadata := emptyMetadata + for _, resultChan := range resultChans { + select { + case result := <-resultChan: + responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata) + if result.Err != nil { + return expandResultError(result.Err, responseMetadata) + } + children = append(children, result.Resp.TreeNode) + case <-ctx.Done(): + return expandResultError(context.Canceled, responseMetadata) + } + } + + return setResult(op, start, children, responseMetadata) +} + +// emptyExpansion returns an empty expansion. +func emptyExpansion(start *core.ObjectAndRelation) ReduceableExpandFunc { + return func(ctx context.Context, resultChan chan<- ExpandResult) { + resultChan <- expandResult(&core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_LeafNode{ + LeafNode: &core.DirectSubjects{}, + }, + Expanded: start, + }, emptyMetadata) + } +} + +// expandError returns the error. +func expandError(err error) ReduceableExpandFunc { + return func(ctx context.Context, resultChan chan<- ExpandResult) { + resultChan <- expandResultError(err, emptyMetadata) + } +} + +// expandAll returns a tree with all of the children and an intersection node type. +func expandAll(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult { + return expandSetOperation(ctx, start, requests, core.SetOperationUserset_INTERSECTION) +} + +// expandAny returns a tree with all of the children and a union node type. +func expandAny(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult { + return expandSetOperation(ctx, start, requests, core.SetOperationUserset_UNION) +} + +// expandDifference returns a tree with all of the children and an exclusion node type. +func expandDifference(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult { + return expandSetOperation(ctx, start, requests, core.SetOperationUserset_EXCLUSION) +} + +// expandOne waits for exactly one response +func expandOne(ctx context.Context, request ReduceableExpandFunc) ExpandResult { + resultChan := make(chan ExpandResult, 1) + go request(ctx, resultChan) + + select { + case result := <-resultChan: + if result.Err != nil { + return result + } + return result + case <-ctx.Done(): + return expandResultError(context.Canceled, emptyMetadata) + } +} + +var errAlwaysFailExpand = errors.New("always fail") + +func alwaysFailExpand(_ context.Context, resultChan chan<- ExpandResult) { + resultChan <- expandResultError(errAlwaysFailExpand, emptyMetadata) +} + +func expandResult(treeNode *core.RelationTupleTreeNode, subProblemMetadata *v1.ResponseMeta) ExpandResult { + return ExpandResult{ + &v1.DispatchExpandResponse{ + Metadata: ensureMetadata(subProblemMetadata), + TreeNode: treeNode, + }, + nil, + } +} + +func expandResultError(err error, subProblemMetadata *v1.ResponseMeta) ExpandResult { + return ExpandResult{ + &v1.DispatchExpandResponse{ + Metadata: ensureMetadata(subProblemMetadata), + }, + err, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/graph.go b/vendor/github.com/authzed/spicedb/internal/graph/graph.go new file mode 100644 index 0000000..2a44189 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/graph.go @@ -0,0 +1,89 @@ +package graph + +import ( + "context" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" +) + +// Ellipsis relation is used to signify a semantic-free relationship. +const Ellipsis = "..." + +// CheckResult is the data that is returned by a single check or sub-check. +type CheckResult struct { + Resp *v1.DispatchCheckResponse + Err error +} + +func (cr CheckResult) ResultError() error { + return cr.Err +} + +// ExpandResult is the data that is returned by a single expand or sub-expand. +type ExpandResult struct { + Resp *v1.DispatchExpandResponse + Err error +} + +func (er ExpandResult) ResultError() error { + return er.Err +} + +// ReduceableExpandFunc is a function that can be bound to a execution context. +type ReduceableExpandFunc func(ctx context.Context, resultChan chan<- ExpandResult) + +// AlwaysFailExpand is a ReduceableExpandFunc which will always fail when reduced. +func AlwaysFailExpand(_ context.Context, resultChan chan<- ExpandResult) { + resultChan <- expandResultError(NewAlwaysFailErr(), emptyMetadata) +} + +// ExpandReducer is a type for the functions Any and All which combine check results. +type ExpandReducer func( + ctx context.Context, + start *core.ObjectAndRelation, + requests []ReduceableExpandFunc, +) ExpandResult + +func decrementDepth(md *v1.ResolverMeta) *v1.ResolverMeta { + return &v1.ResolverMeta{ + AtRevision: md.AtRevision, + DepthRemaining: md.DepthRemaining - 1, + TraversalBloom: md.TraversalBloom, + } +} + +var emptyMetadata = &v1.ResponseMeta{} + +func ensureMetadata(subProblemMetadata *v1.ResponseMeta) *v1.ResponseMeta { + if subProblemMetadata == nil { + subProblemMetadata = emptyMetadata + } + + return &v1.ResponseMeta{ + DispatchCount: subProblemMetadata.DispatchCount, + DepthRequired: subProblemMetadata.DepthRequired, + CachedDispatchCount: subProblemMetadata.CachedDispatchCount, + DebugInfo: subProblemMetadata.DebugInfo, + } +} + +func addCallToResponseMetadata(metadata *v1.ResponseMeta) *v1.ResponseMeta { + // + 1 for the current call. + return &v1.ResponseMeta{ + DispatchCount: metadata.DispatchCount + 1, + DepthRequired: metadata.DepthRequired + 1, + CachedDispatchCount: metadata.CachedDispatchCount, + DebugInfo: metadata.DebugInfo, + } +} + +func addAdditionalDepthRequired(metadata *v1.ResponseMeta) *v1.ResponseMeta { + return &v1.ResponseMeta{ + DispatchCount: metadata.DispatchCount, + DepthRequired: metadata.DepthRequired + 1, + CachedDispatchCount: metadata.CachedDispatchCount, + DebugInfo: metadata.DebugInfo, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go b/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go new file mode 100644 index 0000000..485a3bf --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go @@ -0,0 +1,96 @@ +package hints + +import ( + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// CheckHintForComputedUserset creates a CheckHint for a relation and a subject. +func CheckHintForComputedUserset(resourceType string, resourceID string, relation string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) *v1.CheckHint { + return &v1.CheckHint{ + Resource: &core.ObjectAndRelation{ + Namespace: resourceType, + ObjectId: resourceID, + Relation: relation, + }, + Subject: subject.ToCoreONR(), + Result: result, + } +} + +// CheckHintForArrow creates a CheckHint for an arrow and a subject. +func CheckHintForArrow(resourceType string, resourceID string, tuplesetRelation string, computedUsersetRelation string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) *v1.CheckHint { + return &v1.CheckHint{ + Resource: &core.ObjectAndRelation{ + Namespace: resourceType, + ObjectId: resourceID, + Relation: tuplesetRelation, + }, + TtuComputedUsersetRelation: computedUsersetRelation, + Subject: subject.ToCoreONR(), + Result: result, + } +} + +// AsCheckHintForComputedUserset returns the resourceID if the checkHint is for the given relation and subject. +func AsCheckHintForComputedUserset(checkHint *v1.CheckHint, resourceType string, relationName string, subject tuple.ObjectAndRelation) (string, bool) { + if checkHint.TtuComputedUsersetRelation != "" { + return "", false + } + + if checkHint.Resource.Namespace == resourceType && checkHint.Resource.Relation == relationName && checkHint.Subject.EqualVT(subject.ToCoreONR()) { + return checkHint.Resource.ObjectId, true + } + + return "", false +} + +// AsCheckHintForArrow returns the resourceID if the checkHint is for the given arrow and subject. +func AsCheckHintForArrow(checkHint *v1.CheckHint, resourceType string, tuplesetRelation string, computedUsersetRelation string, subject tuple.ObjectAndRelation) (string, bool) { + if checkHint.TtuComputedUsersetRelation != computedUsersetRelation { + return "", false + } + + if checkHint.Resource.Namespace == resourceType && checkHint.Resource.Relation == tuplesetRelation && checkHint.Subject.EqualVT(subject.ToCoreONR()) { + return checkHint.Resource.ObjectId, true + } + + return "", false +} + +// HintForEntrypoint returns a CheckHint for the given reachability graph entrypoint and associated subject and result. +func HintForEntrypoint(re schema.ReachabilityEntrypoint, resourceID string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) (*v1.CheckHint, error) { + switch re.EntrypointKind() { + case core.ReachabilityEntrypoint_RELATION_ENTRYPOINT: + return nil, spiceerrors.MustBugf("cannot call CheckHintForResource for kind %v", re.EntrypointKind()) + + case core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT: + namespace := re.TargetNamespace() + tuplesetRelation, err := re.TuplesetRelation() + if err != nil { + return nil, err + } + + computedUsersetRelation, err := re.ComputedUsersetRelation() + if err != nil { + return nil, err + } + + return CheckHintForArrow(namespace, resourceID, tuplesetRelation, computedUsersetRelation, subject, result), nil + + case core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT: + namespace := re.TargetNamespace() + relation, err := re.ComputedUsersetRelation() + if err != nil { + return nil, err + } + + return CheckHintForComputedUserset(namespace, resourceID, relation, subject, result), nil + + default: + return nil, spiceerrors.MustBugf("unknown relation entrypoint kind") + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/limits.go b/vendor/github.com/authzed/spicedb/internal/graph/limits.go new file mode 100644 index 0000000..6b2e2bd --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/limits.go @@ -0,0 +1,80 @@ +package graph + +import ( + "fmt" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +var ErrLimitReached = fmt.Errorf("limit has been reached") + +// limitTracker is a helper struct for tracking the limit requested by a caller and decrementing +// that limit as results are published. +type limitTracker struct { + hasLimit bool + currentLimit uint32 +} + +// newLimitTracker creates a new limit tracker, returning the tracker. +func newLimitTracker(optionalLimit uint32) *limitTracker { + return &limitTracker{ + currentLimit: optionalLimit, + hasLimit: optionalLimit > 0, + } +} + +// clone creates a copy of the limitTracker, inheriting the current limit. +func (lt *limitTracker) clone() *limitTracker { + return &limitTracker{ + currentLimit: lt.currentLimit, + hasLimit: lt.hasLimit, + } +} + +// prepareForPublishing asks the limit tracker to remove an element from the limit requested, +// returning whether that element can be published. +// +// Example usage: +// +// okay := limits.prepareForPublishing() +// if okay { ... publish ... } +func (lt *limitTracker) prepareForPublishing() bool { + // if there is no limit defined, then the count is always allowed. + if !lt.hasLimit { + return true + } + + // if the limit has been reached, allow no further items to be published. + if lt.currentLimit == 0 { + return false + } + + // otherwise, remove the element from the limit. + lt.currentLimit-- + return true +} + +// markAlreadyPublished marks that the given count of results has already been published. If the count is +// greater than the limit, returns a spiceerror. +func (lt *limitTracker) markAlreadyPublished(count uint32) error { + if !lt.hasLimit { + return nil + } + + if count > lt.currentLimit { + return spiceerrors.MustBugf("given published count of %d exceeds the remaining limit of %d", count, lt.currentLimit) + } + + lt.currentLimit -= count + if lt.currentLimit == 0 { + return nil + } + + return nil +} + +// hasExhaustedLimit returns true if the limit has been reached and all items allowable have been +// published. +func (lt *limitTracker) hasExhaustedLimit() bool { + return lt.hasLimit && lt.currentLimit == 0 +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go b/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go new file mode 100644 index 0000000..57acb49 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go @@ -0,0 +1,681 @@ +package graph + +import ( + "context" + "slices" + "sort" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/authzed/spicedb/internal/caveats" + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph/computed" + "github.com/authzed/spicedb/internal/graph/hints" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// dispatchVersion defines the "version" of this dispatcher. Must be incremented +// anytime an incompatible change is made to the dispatcher itself or its cursor +// production. +const dispatchVersion = 1 + +func NewCursoredLookupResources2(dl dispatch.LookupResources2, dc dispatch.Check, caveatTypeSet *caveattypes.TypeSet, concurrencyLimit uint16, dispatchChunkSize uint16) *CursoredLookupResources2 { + return &CursoredLookupResources2{dl, dc, caveatTypeSet, concurrencyLimit, dispatchChunkSize} +} + +type CursoredLookupResources2 struct { + dl dispatch.LookupResources2 + dc dispatch.Check + caveatTypeSet *caveattypes.TypeSet + concurrencyLimit uint16 + dispatchChunkSize uint16 +} + +type ValidatedLookupResources2Request struct { + *v1.DispatchLookupResources2Request + Revision datastore.Revision +} + +func (crr *CursoredLookupResources2) LookupResources2( + req ValidatedLookupResources2Request, + stream dispatch.LookupResources2Stream, +) error { + ctx, span := tracer.Start(stream.Context(), "lookupResources2") + defer span.End() + + if req.TerminalSubject == nil { + return spiceerrors.MustBugf("no terminal subject given to lookup resources dispatch") + } + + if slices.Contains(req.SubjectIds, tuple.PublicWildcard) { + return NewWildcardNotAllowedErr("cannot perform lookup resources on wildcard", "subject_id") + } + + if len(req.SubjectIds) == 0 { + return spiceerrors.MustBugf("no subjects ids given to lookup resources dispatch") + } + + // Sort for stability. + if len(req.SubjectIds) > 1 { + sort.Strings(req.SubjectIds) + } + + limits := newLimitTracker(req.OptionalLimit) + ci, err := newCursorInformation(req.OptionalCursor, limits, dispatchVersion) + if err != nil { + return err + } + + return withSubsetInCursor(ci, + func(currentOffset int, nextCursorWith afterResponseCursor) error { + // If the resource type matches the subject type, yield directly as a one-to-one result + // for each subjectID. + if req.SubjectRelation.Namespace == req.ResourceRelation.Namespace && + req.SubjectRelation.Relation == req.ResourceRelation.Relation { + for index, subjectID := range req.SubjectIds { + if index < currentOffset { + continue + } + + if !ci.limits.prepareForPublishing() { + return nil + } + + err := stream.Publish(&v1.DispatchLookupResources2Response{ + Resource: &v1.PossibleResource{ + ResourceId: subjectID, + ForSubjectIds: []string{subjectID}, + }, + Metadata: emptyMetadata, + AfterResponseCursor: nextCursorWith(index + 1), + }) + if err != nil { + return err + } + } + } + return nil + }, func(ci cursorInformation) error { + // Once done checking for the matching subject type, yield by dispatching over entrypoints. + return crr.afterSameType(ctx, ci, req, stream) + }) +} + +func (crr *CursoredLookupResources2) afterSameType( + ctx context.Context, + ci cursorInformation, + req ValidatedLookupResources2Request, + parentStream dispatch.LookupResources2Stream, +) error { + reachabilityForString := req.ResourceRelation.Namespace + "#" + req.ResourceRelation.Relation + ctx, span := tracer.Start(ctx, "reachability: "+reachabilityForString) + defer span.End() + + dispatched := NewSyncONRSet() + + // Load the type system and reachability graph to find the entrypoints for the reachability. + ds := datastoremw.MustFromContext(ctx) + reader := ds.SnapshotReader(req.Revision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(reader)) + vdef, err := ts.GetValidatedDefinition(ctx, req.ResourceRelation.Namespace) + if err != nil { + return err + } + + rg := vdef.Reachability() + entrypoints, err := rg.FirstEntrypointsForSubjectToResource(ctx, &core.RelationReference{ + Namespace: req.SubjectRelation.Namespace, + Relation: req.SubjectRelation.Relation, + }, req.ResourceRelation) + if err != nil { + return err + } + + // For each entrypoint, load the necessary data and re-dispatch if a subproblem was found. + return withParallelizedStreamingIterableInCursor(ctx, ci, entrypoints, parentStream, crr.concurrencyLimit, + func(ctx context.Context, ci cursorInformation, entrypoint schema.ReachabilityEntrypoint, stream dispatch.LookupResources2Stream) error { + ds, err := entrypoint.DebugString() + spiceerrors.DebugAssert(func() bool { + return err == nil + }, "Error in entrypoint.DebugString()") + ctx, span := tracer.Start(ctx, "entrypoint: "+ds, trace.WithAttributes()) + defer span.End() + + switch entrypoint.EntrypointKind() { + case core.ReachabilityEntrypoint_RELATION_ENTRYPOINT: + return crr.lookupRelationEntrypoint(ctx, ci, entrypoint, rg, ts, reader, req, stream, dispatched) + + case core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT: + containingRelation := entrypoint.ContainingRelationOrPermission() + rewrittenSubjectRelation := &core.RelationReference{ + Namespace: containingRelation.Namespace, + Relation: containingRelation.Relation, + } + + rsm := subjectIDsToResourcesMap2(rewrittenSubjectRelation, req.SubjectIds) + drsm := rsm.asReadOnly() + + return crr.redispatchOrReport( + ctx, + ci, + rewrittenSubjectRelation, + drsm, + rg, + entrypoint, + stream, + req, + dispatched, + ) + + case core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT: + return crr.lookupTTUEntrypoint(ctx, ci, entrypoint, rg, ts, reader, req, stream, dispatched) + + default: + return spiceerrors.MustBugf("Unknown kind of entrypoint: %v", entrypoint.EntrypointKind()) + } + }) +} + +func (crr *CursoredLookupResources2) lookupRelationEntrypoint( + ctx context.Context, + ci cursorInformation, + entrypoint schema.ReachabilityEntrypoint, + rg *schema.DefinitionReachability, + ts *schema.TypeSystem, + reader datastore.Reader, + req ValidatedLookupResources2Request, + stream dispatch.LookupResources2Stream, + dispatched *syncONRSet, +) error { + relationReference, err := entrypoint.DirectRelation() + if err != nil { + return err + } + + relDefinition, err := ts.GetValidatedDefinition(ctx, relationReference.Namespace) + if err != nil { + return err + } + + // Build the list of subjects to lookup based on the type information available. + isDirectAllowed, err := relDefinition.IsAllowedDirectRelation( + relationReference.Relation, + req.SubjectRelation.Namespace, + req.SubjectRelation.Relation, + ) + if err != nil { + return err + } + + subjectIds := make([]string, 0, len(req.SubjectIds)+1) + if isDirectAllowed == schema.DirectRelationValid { + subjectIds = append(subjectIds, req.SubjectIds...) + } + + if req.SubjectRelation.Relation == tuple.Ellipsis { + isWildcardAllowed, err := relDefinition.IsAllowedPublicNamespace(relationReference.Relation, req.SubjectRelation.Namespace) + if err != nil { + return err + } + + if isWildcardAllowed == schema.PublicSubjectAllowed { + subjectIds = append(subjectIds, "*") + } + } + + // Lookup the subjects and then redispatch/report results. + relationFilter := datastore.SubjectRelationFilter{ + NonEllipsisRelation: req.SubjectRelation.Relation, + } + + if req.SubjectRelation.Relation == tuple.Ellipsis { + relationFilter = datastore.SubjectRelationFilter{ + IncludeEllipsisRelation: true, + } + } + + subjectsFilter := datastore.SubjectsFilter{ + SubjectType: req.SubjectRelation.Namespace, + OptionalSubjectIds: subjectIds, + RelationFilter: relationFilter, + } + + return crr.redispatchOrReportOverDatabaseQuery( + ctx, + redispatchOverDatabaseConfig2{ + ci: ci, + ts: ts, + reader: reader, + subjectsFilter: subjectsFilter, + sourceResourceType: relationReference, + foundResourceType: relationReference, + entrypoint: entrypoint, + rg: rg, + concurrencyLimit: crr.concurrencyLimit, + parentStream: stream, + parentRequest: req, + dispatched: dispatched, + }, + ) +} + +type redispatchOverDatabaseConfig2 struct { + ci cursorInformation + + ts *schema.TypeSystem + + // Direct reader for reverse ReverseQueryRelationships + reader datastore.Reader + + subjectsFilter datastore.SubjectsFilter + sourceResourceType *core.RelationReference + foundResourceType *core.RelationReference + + entrypoint schema.ReachabilityEntrypoint + rg *schema.DefinitionReachability + + concurrencyLimit uint16 + parentStream dispatch.LookupResources2Stream + parentRequest ValidatedLookupResources2Request + dispatched *syncONRSet +} + +func (crr *CursoredLookupResources2) redispatchOrReportOverDatabaseQuery( + ctx context.Context, + config redispatchOverDatabaseConfig2, +) error { + ctx, span := tracer.Start(ctx, "datastorequery", trace.WithAttributes( + attribute.String("source-resource-type-namespace", config.sourceResourceType.Namespace), + attribute.String("source-resource-type-relation", config.sourceResourceType.Relation), + attribute.String("subjects-filter-subject-type", config.subjectsFilter.SubjectType), + attribute.Int("subjects-filter-subject-ids-count", len(config.subjectsFilter.OptionalSubjectIds)), + )) + defer span.End() + + return withDatastoreCursorInCursor(ctx, config.ci, config.parentStream, config.concurrencyLimit, + // Find the target resources for the subject. + func(queryCursor options.Cursor) ([]itemAndPostCursor[dispatchableResourcesSubjectMap2], error) { + it, err := config.reader.ReverseQueryRelationships( + ctx, + config.subjectsFilter, + options.WithResRelation(&options.ResourceRelation{ + Namespace: config.sourceResourceType.Namespace, + Relation: config.sourceResourceType.Relation, + }), + options.WithSortForReverse(options.BySubject), + options.WithAfterForReverse(queryCursor), + options.WithQueryShapeForReverse(queryshape.MatchingResourcesForSubject), + ) + if err != nil { + return nil, err + } + + // Chunk based on the FilterMaximumIDCount, to ensure we never send more than that amount of + // results to a downstream dispatch. + rsm := newResourcesSubjectMap2WithCapacity(config.sourceResourceType, uint32(crr.dispatchChunkSize)) + toBeHandled := make([]itemAndPostCursor[dispatchableResourcesSubjectMap2], 0) + currentCursor := queryCursor + caveatRunner := caveats.NewCaveatRunner(crr.caveatTypeSet) + + for rel, err := range it { + if err != nil { + return nil, err + } + + var missingContextParameters []string + + // If a caveat exists on the relationship, run it and filter the results, marking those that have missing context. + if rel.OptionalCaveat != nil && rel.OptionalCaveat.CaveatName != "" { + caveatExpr := caveats.CaveatAsExpr(rel.OptionalCaveat) + runResult, err := caveatRunner.RunCaveatExpression(ctx, caveatExpr, config.parentRequest.Context.AsMap(), config.reader, caveats.RunCaveatExpressionNoDebugging) + if err != nil { + return nil, err + } + + // If a partial result is returned, collect the missing context parameters. + if runResult.IsPartial() { + missingNames, err := runResult.MissingVarNames() + if err != nil { + return nil, err + } + + missingContextParameters = missingNames + } else if !runResult.Value() { + // If the run result shows the caveat does not apply, skip. This shears the tree of results early. + continue + } + } + + if err := rsm.addRelationship(rel, missingContextParameters); err != nil { + return nil, err + } + + if rsm.len() == int(crr.dispatchChunkSize) { + toBeHandled = append(toBeHandled, itemAndPostCursor[dispatchableResourcesSubjectMap2]{ + item: rsm.asReadOnly(), + cursor: currentCursor, + }) + rsm = newResourcesSubjectMap2WithCapacity(config.sourceResourceType, uint32(crr.dispatchChunkSize)) + currentCursor = options.ToCursor(rel) + } + } + + if rsm.len() > 0 { + toBeHandled = append(toBeHandled, itemAndPostCursor[dispatchableResourcesSubjectMap2]{ + item: rsm.asReadOnly(), + cursor: currentCursor, + }) + } + + return toBeHandled, nil + }, + + // Redispatch or report the results. + func( + ctx context.Context, + ci cursorInformation, + drsm dispatchableResourcesSubjectMap2, + currentStream dispatch.LookupResources2Stream, + ) error { + return crr.redispatchOrReport( + ctx, + ci, + config.foundResourceType, + drsm, + config.rg, + config.entrypoint, + currentStream, + config.parentRequest, + config.dispatched, + ) + }, + ) +} + +func (crr *CursoredLookupResources2) lookupTTUEntrypoint(ctx context.Context, + ci cursorInformation, + entrypoint schema.ReachabilityEntrypoint, + rg *schema.DefinitionReachability, + ts *schema.TypeSystem, + reader datastore.Reader, + req ValidatedLookupResources2Request, + stream dispatch.LookupResources2Stream, + dispatched *syncONRSet, +) error { + containingRelation := entrypoint.ContainingRelationOrPermission() + + ttuDef, err := ts.GetValidatedDefinition(ctx, containingRelation.Namespace) + if err != nil { + return err + } + + tuplesetRelation, err := entrypoint.TuplesetRelation() + if err != nil { + return err + } + + // Determine whether this TTU should be followed, which will be the case if the subject relation's namespace + // is allowed in any form on the relation; since arrows ignore the subject's relation (if any), we check + // for the subject namespace as a whole. + allowedRelations, err := ttuDef.GetAllowedDirectNamespaceSubjectRelations(tuplesetRelation, req.SubjectRelation.Namespace) + if err != nil { + return err + } + + if allowedRelations == nil { + return nil + } + + // Search for the resolved subjects in the tupleset of the TTU. + subjectsFilter := datastore.SubjectsFilter{ + SubjectType: req.SubjectRelation.Namespace, + OptionalSubjectIds: req.SubjectIds, + } + + // Optimization: if there is a single allowed relation, pass it as a subject relation filter to make things faster + // on querying. + if allowedRelations.Len() == 1 { + allowedRelationName := allowedRelations.AsSlice()[0] + subjectsFilter.RelationFilter = datastore.SubjectRelationFilter{}.WithRelation(allowedRelationName) + } + + tuplesetRelationReference := &core.RelationReference{ + Namespace: containingRelation.Namespace, + Relation: tuplesetRelation, + } + + return crr.redispatchOrReportOverDatabaseQuery( + ctx, + redispatchOverDatabaseConfig2{ + ci: ci, + ts: ts, + reader: reader, + subjectsFilter: subjectsFilter, + sourceResourceType: tuplesetRelationReference, + foundResourceType: containingRelation, + entrypoint: entrypoint, + rg: rg, + parentStream: stream, + parentRequest: req, + dispatched: dispatched, + }, + ) +} + +type possibleResourceAndIndex struct { + resource *v1.PossibleResource + index int +} + +// redispatchOrReport checks if further redispatching is necessary for the found resource +// type. If not, and the found resource type+relation matches the target resource type+relation, +// the resource is reported to the parent stream. +func (crr *CursoredLookupResources2) redispatchOrReport( + ctx context.Context, + ci cursorInformation, + foundResourceType *core.RelationReference, + foundResources dispatchableResourcesSubjectMap2, + rg *schema.DefinitionReachability, + entrypoint schema.ReachabilityEntrypoint, + parentStream dispatch.LookupResources2Stream, + parentRequest ValidatedLookupResources2Request, + dispatched *syncONRSet, +) error { + if foundResources.isEmpty() { + // Nothing more to do. + return nil + } + + ctx, span := tracer.Start(ctx, "redispatchOrReport", trace.WithAttributes( + attribute.Int("found-resources-count", foundResources.len()), + )) + defer span.End() + + // Check for entrypoints for the new found resource type. + hasResourceEntrypoints, err := rg.HasOptimizedEntrypointsForSubjectToResource(ctx, foundResourceType, parentRequest.ResourceRelation) + if err != nil { + return err + } + + return withSubsetInCursor(ci, + func(currentOffset int, nextCursorWith afterResponseCursor) error { + if !hasResourceEntrypoints { + // If the found resource matches the target resource type and relation, potentially yield the resource. + if foundResourceType.Namespace == parentRequest.ResourceRelation.Namespace && foundResourceType.Relation == parentRequest.ResourceRelation.Relation { + resources := foundResources.asPossibleResources() + if len(resources) == 0 { + return nil + } + + if currentOffset >= len(resources) { + return nil + } + + offsetted := resources[currentOffset:] + if len(offsetted) == 0 { + return nil + } + + filtered := make([]possibleResourceAndIndex, 0, len(offsetted)) + for index, resource := range offsetted { + filtered = append(filtered, possibleResourceAndIndex{ + resource: resource, + index: index, + }) + } + + metadata := emptyMetadata + + // If the entrypoint is not a direct result, issue a check to further filter the results on the intersection or exclusion. + if !entrypoint.IsDirectResult() { + resourceIDs := make([]string, 0, len(offsetted)) + checkHints := make([]*v1.CheckHint, 0, len(offsetted)) + for _, resource := range offsetted { + resourceIDs = append(resourceIDs, resource.ResourceId) + + checkHint, err := hints.HintForEntrypoint( + entrypoint, + resource.ResourceId, + tuple.FromCoreObjectAndRelation(parentRequest.TerminalSubject), + &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_MEMBER, + }) + if err != nil { + return err + } + checkHints = append(checkHints, checkHint) + } + + resultsByResourceID, checkMetadata, _, err := computed.ComputeBulkCheck(ctx, crr.dc, crr.caveatTypeSet, computed.CheckParameters{ + ResourceType: tuple.FromCoreRelationReference(parentRequest.ResourceRelation), + Subject: tuple.FromCoreObjectAndRelation(parentRequest.TerminalSubject), + CaveatContext: parentRequest.Context.AsMap(), + AtRevision: parentRequest.Revision, + MaximumDepth: parentRequest.Metadata.DepthRemaining - 1, + DebugOption: computed.NoDebugging, + CheckHints: checkHints, + }, resourceIDs, crr.dispatchChunkSize) + if err != nil { + return err + } + + metadata = addCallToResponseMetadata(checkMetadata) + + filtered = make([]possibleResourceAndIndex, 0, len(offsetted)) + for index, resource := range offsetted { + result, ok := resultsByResourceID[resource.ResourceId] + if !ok { + continue + } + + switch result.Membership { + case v1.ResourceCheckResult_MEMBER: + filtered = append(filtered, possibleResourceAndIndex{ + resource: resource, + index: index, + }) + + case v1.ResourceCheckResult_CAVEATED_MEMBER: + missingContextParams := mapz.NewSet(result.MissingExprFields...) + missingContextParams.Extend(resource.MissingContextParams) + + filtered = append(filtered, possibleResourceAndIndex{ + resource: &v1.PossibleResource{ + ResourceId: resource.ResourceId, + ForSubjectIds: resource.ForSubjectIds, + MissingContextParams: missingContextParams.AsSlice(), + }, + index: index, + }) + + case v1.ResourceCheckResult_NOT_MEMBER: + // Skip. + + default: + return spiceerrors.MustBugf("unexpected result from check: %v", result.Membership) + } + } + } + + for _, resourceAndIndex := range filtered { + if !ci.limits.prepareForPublishing() { + return nil + } + + err := parentStream.Publish(&v1.DispatchLookupResources2Response{ + Resource: resourceAndIndex.resource, + Metadata: metadata, + AfterResponseCursor: nextCursorWith(currentOffset + resourceAndIndex.index + 1), + }) + if err != nil { + return err + } + + metadata = emptyMetadata + } + return nil + } + } + return nil + }, func(ci cursorInformation) error { + if !hasResourceEntrypoints { + return nil + } + + // The new subject type for dispatching was the found type of the *resource*. + newSubjectType := foundResourceType + + // To avoid duplicate work, remove any subjects already dispatched. + filteredSubjectIDs := foundResources.filterSubjectIDsToDispatch(dispatched, newSubjectType) + if len(filteredSubjectIDs) == 0 { + return nil + } + + // If the entrypoint is a direct result then we can simply dispatch directly and map + // all found results, as no further filtering will be needed. + if entrypoint.IsDirectResult() { + stream := unfilteredLookupResourcesDispatchStreamForEntrypoint(ctx, foundResources, parentStream, ci) + return crr.dl.DispatchLookupResources2(&v1.DispatchLookupResources2Request{ + ResourceRelation: parentRequest.ResourceRelation, + SubjectRelation: newSubjectType, + SubjectIds: filteredSubjectIDs, + TerminalSubject: parentRequest.TerminalSubject, + Metadata: &v1.ResolverMeta{ + AtRevision: parentRequest.Revision.String(), + DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + }, + OptionalCursor: ci.currentCursor, + OptionalLimit: parentRequest.OptionalLimit, + Context: parentRequest.Context, + }, stream) + } + + // Otherwise, we need to filter results by batch checking along the way before dispatching. + return runCheckerAndDispatch( + ctx, + parentRequest, + foundResources, + ci, + parentStream, + newSubjectType, + filteredSubjectIDs, + entrypoint, + crr.dl, + crr.dc, + crr.caveatTypeSet, + crr.concurrencyLimit, + crr.dispatchChunkSize, + ) + }) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go b/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go new file mode 100644 index 0000000..7560847 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go @@ -0,0 +1,803 @@ +package graph + +import ( + "context" + "errors" + "fmt" + "sync" + + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/authzed/spicedb/internal/datasets" + "github.com/authzed/spicedb/internal/dispatch" + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/taskrunner" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/genutil/slicez" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ValidatedLookupSubjectsRequest represents a request after it has been validated and parsed for internal +// consumption. +type ValidatedLookupSubjectsRequest struct { + *v1.DispatchLookupSubjectsRequest + Revision datastore.Revision +} + +// NewConcurrentLookupSubjects creates an instance of ConcurrentLookupSubjects. +func NewConcurrentLookupSubjects(d dispatch.LookupSubjects, concurrencyLimit uint16, dispatchChunkSize uint16) *ConcurrentLookupSubjects { + return &ConcurrentLookupSubjects{d, concurrencyLimit, dispatchChunkSize} +} + +type ConcurrentLookupSubjects struct { + d dispatch.LookupSubjects + concurrencyLimit uint16 + dispatchChunkSize uint16 +} + +func (cl *ConcurrentLookupSubjects) LookupSubjects( + req ValidatedLookupSubjectsRequest, + stream dispatch.LookupSubjectsStream, +) error { + ctx := stream.Context() + + if len(req.ResourceIds) == 0 { + return fmt.Errorf("no resources ids given to lookupsubjects dispatch") + } + + // If the resource type matches the subject type, yield directly. + if req.SubjectRelation.Namespace == req.ResourceRelation.Namespace && + req.SubjectRelation.Relation == req.ResourceRelation.Relation { + if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: subjectsForConcreteIds(req.ResourceIds), + Metadata: emptyMetadata, + }); err != nil { + return err + } + } + + ds := datastoremw.MustFromContext(ctx) + reader := ds.SnapshotReader(req.Revision) + _, relation, err := namespace.ReadNamespaceAndRelation( + ctx, + req.ResourceRelation.Namespace, + req.ResourceRelation.Relation, + reader) + if err != nil { + return err + } + + if relation.UsersetRewrite == nil { + // Direct lookup of subjects. + return cl.lookupDirectSubjects(ctx, req, stream, relation, reader) + } + + return cl.lookupViaRewrite(ctx, req, stream, relation.UsersetRewrite) +} + +func subjectsForConcreteIds(subjectIds []string) map[string]*v1.FoundSubjects { + foundSubjects := make(map[string]*v1.FoundSubjects, len(subjectIds)) + for _, subjectID := range subjectIds { + foundSubjects[subjectID] = &v1.FoundSubjects{ + FoundSubjects: []*v1.FoundSubject{ + { + SubjectId: subjectID, + CaveatExpression: nil, // Explicitly nil since this is a concrete found subject. + }, + }, + } + } + return foundSubjects +} + +func (cl *ConcurrentLookupSubjects) lookupDirectSubjects( + ctx context.Context, + req ValidatedLookupSubjectsRequest, + stream dispatch.LookupSubjectsStream, + _ *core.Relation, + reader datastore.Reader, +) error { + // TODO(jschorr): use type information to skip subject relations that cannot reach the subject type. + + toDispatchByType := datasets.NewSubjectByTypeSet() + foundSubjectsByResourceID := datasets.NewSubjectSetByResourceID() + relationshipsBySubjectONR := mapz.NewMultiMap[tuple.ObjectAndRelation, tuple.Relationship]() + + it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: req.ResourceRelation.Namespace, + OptionalResourceRelation: req.ResourceRelation.Relation, + OptionalResourceIds: req.ResourceIds, + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + return err + } + + for rel, err := range it { + if err != nil { + return err + } + + if rel.Subject.ObjectType == req.SubjectRelation.Namespace && + rel.Subject.Relation == req.SubjectRelation.Relation { + if err := foundSubjectsByResourceID.AddFromRelationship(rel); err != nil { + return fmt.Errorf("failed to call AddFromRelationship in lookupDirectSubjects: %w", err) + } + } + + if rel.Subject.Relation != tuple.Ellipsis { + err := toDispatchByType.AddSubjectOf(rel) + if err != nil { + return err + } + + relationshipsBySubjectONR.Add(rel.Subject, rel) + } + } + + if !foundSubjectsByResourceID.IsEmpty() { + if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: foundSubjectsByResourceID.AsMap(), + Metadata: emptyMetadata, + }); err != nil { + return err + } + } + + return cl.dispatchTo(ctx, req, toDispatchByType, relationshipsBySubjectONR, stream) +} + +func (cl *ConcurrentLookupSubjects) lookupViaComputed( + ctx context.Context, + parentRequest ValidatedLookupSubjectsRequest, + parentStream dispatch.LookupSubjectsStream, + cu *core.ComputedUserset, +) error { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + if err := namespace.CheckNamespaceAndRelation(ctx, parentRequest.ResourceRelation.Namespace, cu.Relation, true, ds); err != nil { + if errors.As(err, &namespace.RelationNotFoundError{}) { + return nil + } + + return err + } + + stream := &dispatch.WrappedDispatchStream[*v1.DispatchLookupSubjectsResponse]{ + Stream: parentStream, + Ctx: ctx, + Processor: func(result *v1.DispatchLookupSubjectsResponse) (*v1.DispatchLookupSubjectsResponse, bool, error) { + return &v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: result.FoundSubjectsByResourceId, + Metadata: addCallToResponseMetadata(result.Metadata), + }, true, nil + }, + } + + return cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{ + ResourceRelation: &core.RelationReference{ + Namespace: parentRequest.ResourceRelation.Namespace, + Relation: cu.Relation, + }, + ResourceIds: parentRequest.ResourceIds, + SubjectRelation: parentRequest.SubjectRelation, + Metadata: &v1.ResolverMeta{ + AtRevision: parentRequest.Revision.String(), + DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + }, + }, stream) +} + +type resourceDispatchTracker struct { + ctx context.Context + cancelDispatch context.CancelFunc + resourceID string + + subjectsSet datasets.SubjectSet // GUARDED_BY(lock) + metadata *v1.ResponseMeta // GUARDED_BY(lock) + + isFirstUpdate bool // GUARDED_BY(lock) + wasCanceled bool // GUARDED_BY(lock) + + lock sync.Mutex +} + +func lookupViaIntersectionTupleToUserset( + ctx context.Context, + cl *ConcurrentLookupSubjects, + parentRequest ValidatedLookupSubjectsRequest, + parentStream dispatch.LookupSubjectsStream, + ttu *core.FunctionedTupleToUserset, +) error { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: parentRequest.ResourceRelation.Namespace, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + OptionalResourceIds: parentRequest.ResourceIds, + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + return err + } + + // TODO(jschorr): Find a means of doing this without dispatching per subject, per resource. Perhaps + // there is a way we can still dispatch to all the subjects at once, and then intersect the results + // afterwards. + resourceDispatchTrackerByResourceID := make(map[string]*resourceDispatchTracker) + + cancelCtx, checkCancel := context.WithCancel(ctx) + defer checkCancel() + + // For each found tuple, dispatch a lookup subjects request and collect its results. + // We need to intersect between *all* the found subjects for each resource ID. + var ttuCaveat *core.CaveatExpression + taskrunner := taskrunner.NewPreloadedTaskRunner(cancelCtx, cl.concurrencyLimit, 1) + for rel, err := range it { + if err != nil { + return err + } + + // If the relationship has a caveat, add it to the overall TTU caveat. Since this is an intersection + // of *all* branches, the caveat will be applied to all found subjects, so this is a safe approach. + if rel.OptionalCaveat != nil { + ttuCaveat = caveatAnd(ttuCaveat, wrapCaveat(rel.OptionalCaveat)) + } + + if err := namespace.CheckNamespaceAndRelation(ctx, rel.Subject.ObjectType, ttu.GetComputedUserset().Relation, false, ds); err != nil { + if !errors.As(err, &namespace.RelationNotFoundError{}) { + return err + } + + continue + } + + // Create a data structure to track the intersection of subjects for the particular resource. If the resource's subject set + // ends up empty anywhere along the way, the dispatches for *that resource* will be canceled early. + resourceID := rel.Resource.ObjectID + dispatchInfoForResource, ok := resourceDispatchTrackerByResourceID[resourceID] + if !ok { + dispatchCtx, cancelDispatch := context.WithCancel(cancelCtx) + dispatchInfoForResource = &resourceDispatchTracker{ + ctx: dispatchCtx, + cancelDispatch: cancelDispatch, + resourceID: resourceID, + subjectsSet: datasets.NewSubjectSet(), + metadata: emptyMetadata, + isFirstUpdate: true, + lock: sync.Mutex{}, + } + resourceDispatchTrackerByResourceID[resourceID] = dispatchInfoForResource + } + + rel := rel + taskrunner.Add(func(ctx context.Context) error { + // Collect all results for this branch of the resource ID. + // TODO(jschorr): once LS has cursoring (and thus, ordering), we can move to not collecting everything up before intersecting + // for this branch of the resource ID. + collectingStream := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](dispatchInfoForResource.ctx) + err := cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{ + ResourceRelation: &core.RelationReference{ + Namespace: rel.Subject.ObjectType, + Relation: ttu.GetComputedUserset().Relation, + }, + ResourceIds: []string{rel.Subject.ObjectID}, + SubjectRelation: parentRequest.SubjectRelation, + Metadata: &v1.ResolverMeta{ + AtRevision: parentRequest.Revision.String(), + DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + }, + }, collectingStream) + if err != nil { + // Check if the dispatches for the resource were canceled, and if so, return nil to stop the task. + dispatchInfoForResource.lock.Lock() + wasCanceled := dispatchInfoForResource.wasCanceled + dispatchInfoForResource.lock.Unlock() + + if wasCanceled { + if errors.Is(err, context.Canceled) { + return nil + } + + errStatus, ok := status.FromError(err) + if ok && errStatus.Code() == codes.Canceled { + return nil + } + } + + return err + } + + // Collect the results into a subject set. + results := datasets.NewSubjectSet() + collectedMetadata := emptyMetadata + for _, result := range collectingStream.Results() { + collectedMetadata = combineResponseMetadata(ctx, collectedMetadata, result.Metadata) + for _, foundSubjects := range result.FoundSubjectsByResourceId { + if err := results.UnionWith(foundSubjects.FoundSubjects); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsIntersection: %w", err) + } + } + } + + dispatchInfoForResource.lock.Lock() + defer dispatchInfoForResource.lock.Unlock() + + dispatchInfoForResource.metadata = combineResponseMetadata(ctx, dispatchInfoForResource.metadata, collectedMetadata) + + // If the first update for the resource, set the subjects set to the results. + if dispatchInfoForResource.isFirstUpdate { + dispatchInfoForResource.isFirstUpdate = false + dispatchInfoForResource.subjectsSet = results + } else { + // Otherwise, intersect the results with the existing subjects set. + err := dispatchInfoForResource.subjectsSet.IntersectionDifference(results) + if err != nil { + return err + } + } + + // If the subjects set is empty, cancel the dispatch for any further results for this resource ID. + if dispatchInfoForResource.subjectsSet.IsEmpty() { + dispatchInfoForResource.wasCanceled = true + dispatchInfoForResource.cancelDispatch() + } + + return nil + }) + } + + // Wait for all dispatched operations to complete. + if err := taskrunner.StartAndWait(); err != nil { + return err + } + + // For each resource ID, intersect the found subjects from each stream. + metadata := emptyMetadata + currentSubjectsByResourceID := map[string]*v1.FoundSubjects{} + + for incomingResourceID, tracker := range resourceDispatchTrackerByResourceID { + currentSubjects := tracker.subjectsSet + currentSubjects = currentSubjects.WithParentCaveatExpression(ttuCaveat) + currentSubjectsByResourceID[incomingResourceID] = currentSubjects.AsFoundSubjects() + + metadata = combineResponseMetadata(ctx, metadata, tracker.metadata) + } + + return parentStream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: currentSubjectsByResourceID, + Metadata: metadata, + }) +} + +func lookupViaTupleToUserset[T relation]( + ctx context.Context, + cl *ConcurrentLookupSubjects, + parentRequest ValidatedLookupSubjectsRequest, + parentStream dispatch.LookupSubjectsStream, + ttu ttu[T], +) error { + toDispatchByTuplesetType := datasets.NewSubjectByTypeSet() + relationshipsBySubjectONR := mapz.NewMultiMap[tuple.ObjectAndRelation, tuple.Relationship]() + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: parentRequest.ResourceRelation.Namespace, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + OptionalResourceIds: parentRequest.ResourceIds, + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + return err + } + + for rel, err := range it { + if err != nil { + return err + } + + // Add the subject to be dispatched. + err := toDispatchByTuplesetType.AddSubjectOf(rel) + if err != nil { + return err + } + + // Add the *rewritten* subject to the relationships multimap for mapping back to the associated + // relationship, as we will be mapping from the computed relation, not the tupleset relation. + relationshipsBySubjectONR.Add(tuple.ONR(rel.Subject.ObjectType, rel.Subject.ObjectID, ttu.GetComputedUserset().Relation), rel) + } + + // Map the found subject types by the computed userset relation, so that we dispatch to it. + toDispatchByComputedRelationType, err := toDispatchByTuplesetType.Map(func(resourceType *core.RelationReference) (*core.RelationReference, error) { + if err := namespace.CheckNamespaceAndRelation(ctx, resourceType.Namespace, ttu.GetComputedUserset().Relation, false, ds); err != nil { + if errors.As(err, &namespace.RelationNotFoundError{}) { + return nil, nil + } + + return nil, err + } + + return &core.RelationReference{ + Namespace: resourceType.Namespace, + Relation: ttu.GetComputedUserset().Relation, + }, nil + }) + if err != nil { + return err + } + + return cl.dispatchTo(ctx, parentRequest, toDispatchByComputedRelationType, relationshipsBySubjectONR, parentStream) +} + +func (cl *ConcurrentLookupSubjects) lookupViaRewrite( + ctx context.Context, + req ValidatedLookupSubjectsRequest, + stream dispatch.LookupSubjectsStream, + usr *core.UsersetRewrite, +) error { + switch rw := usr.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + log.Ctx(ctx).Trace().Msg("union") + return cl.lookupSetOperation(ctx, req, rw.Union, newLookupSubjectsUnion(stream)) + case *core.UsersetRewrite_Intersection: + log.Ctx(ctx).Trace().Msg("intersection") + return cl.lookupSetOperation(ctx, req, rw.Intersection, newLookupSubjectsIntersection(stream)) + case *core.UsersetRewrite_Exclusion: + log.Ctx(ctx).Trace().Msg("exclusion") + return cl.lookupSetOperation(ctx, req, rw.Exclusion, newLookupSubjectsExclusion(stream)) + default: + return fmt.Errorf("unknown kind of rewrite in lookup subjects") + } +} + +func (cl *ConcurrentLookupSubjects) lookupSetOperation( + ctx context.Context, + req ValidatedLookupSubjectsRequest, + so *core.SetOperation, + reducer lookupSubjectsReducer, +) error { + cancelCtx, checkCancel := context.WithCancel(ctx) + defer checkCancel() + + g, subCtx := errgroup.WithContext(cancelCtx) + g.SetLimit(int(cl.concurrencyLimit)) + + for index, childOneof := range so.Child { + stream := reducer.ForIndex(subCtx, index) + + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_XThis: + return errors.New("use of _this is unsupported; please rewrite your schema") + + case *core.SetOperation_Child_ComputedUserset: + g.Go(func() error { + return cl.lookupViaComputed(subCtx, req, stream, child.ComputedUserset) + }) + + case *core.SetOperation_Child_UsersetRewrite: + g.Go(func() error { + return cl.lookupViaRewrite(subCtx, req, stream, child.UsersetRewrite) + }) + + case *core.SetOperation_Child_TupleToUserset: + g.Go(func() error { + return lookupViaTupleToUserset(subCtx, cl, req, stream, child.TupleToUserset) + }) + + case *core.SetOperation_Child_FunctionedTupleToUserset: + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + g.Go(func() error { + return lookupViaTupleToUserset(subCtx, cl, req, stream, child.FunctionedTupleToUserset) + }) + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + g.Go(func() error { + return lookupViaIntersectionTupleToUserset(subCtx, cl, req, stream, child.FunctionedTupleToUserset) + }) + + default: + return spiceerrors.MustBugf("unknown function in lookup subjects: %v", child.FunctionedTupleToUserset.Function) + } + + case *core.SetOperation_Child_XNil: + // Purposely do nothing. + continue + + default: + return spiceerrors.MustBugf("unknown set operation child `%T` in lookup subjects", child) + } + } + + // Wait for all dispatched operations to complete. + if err := g.Wait(); err != nil { + return err + } + + return reducer.CompletedChildOperations(ctx) +} + +func (cl *ConcurrentLookupSubjects) dispatchTo( + ctx context.Context, + parentRequest ValidatedLookupSubjectsRequest, + toDispatchByType *datasets.SubjectByTypeSet, + relationshipsBySubjectONR *mapz.MultiMap[tuple.ObjectAndRelation, tuple.Relationship], + parentStream dispatch.LookupSubjectsStream, +) error { + if toDispatchByType.IsEmpty() { + return nil + } + + cancelCtx, checkCancel := context.WithCancel(ctx) + defer checkCancel() + + g, subCtx := errgroup.WithContext(cancelCtx) + g.SetLimit(int(cl.concurrencyLimit)) + + toDispatchByType.ForEachType(func(resourceType *core.RelationReference, foundSubjects datasets.SubjectSet) { + slice := foundSubjects.AsSlice() + resourceIds := make([]string, 0, len(slice)) + for _, foundSubject := range slice { + resourceIds = append(resourceIds, foundSubject.SubjectId) + } + + stream := &dispatch.WrappedDispatchStream[*v1.DispatchLookupSubjectsResponse]{ + Stream: parentStream, + Ctx: subCtx, + Processor: func(result *v1.DispatchLookupSubjectsResponse) (*v1.DispatchLookupSubjectsResponse, bool, error) { + // For any found subjects, map them through their associated starting resources, to apply any caveats that were + // only those resources' relationships. + // + // For example, given relationships which formed the dispatch: + // - document:firstdoc#viewer@group:group1#member + // - document:firstdoc#viewer@group:group2#member[somecaveat] + // + // And results: + // - group1 => {user:tom, user:sarah} + // - group2 => {user:tom, user:fred} + // + // This will produce: + // - firstdoc => {user:tom, user:sarah, user:fred[somecaveat]} + // + mappedFoundSubjects := make(map[string]*v1.FoundSubjects) + for childResourceID, foundSubjects := range result.FoundSubjectsByResourceId { + subjectKey := tuple.ONR(resourceType.Namespace, childResourceID, resourceType.Relation) + relationships, _ := relationshipsBySubjectONR.Get(subjectKey) + if len(relationships) == 0 { + return nil, false, fmt.Errorf("missing relationships for subject key %v; please report this error", subjectKey) + } + + for _, relationship := range relationships { + existing := mappedFoundSubjects[relationship.Resource.ObjectID] + + // If the relationship has no caveat, simply map the resource ID. + if relationship.OptionalCaveat == nil { + combined, err := combineFoundSubjects(existing, foundSubjects) + if err != nil { + return nil, false, fmt.Errorf("could not combine caveat-less subjects: %w", err) + } + mappedFoundSubjects[relationship.Resource.ObjectID] = combined + continue + } + + // Otherwise, apply the caveat to all found subjects for that resource and map to the resource ID. + foundSubjectSet := datasets.NewSubjectSet() + err := foundSubjectSet.UnionWith(foundSubjects.FoundSubjects) + if err != nil { + return nil, false, fmt.Errorf("could not combine subject sets: %w", err) + } + + combined, err := combineFoundSubjects( + existing, + foundSubjectSet.WithParentCaveatExpression(wrapCaveat(relationship.OptionalCaveat)).AsFoundSubjects(), + ) + if err != nil { + return nil, false, fmt.Errorf("could not combine caveated subjects: %w", err) + } + + mappedFoundSubjects[relationship.Resource.ObjectID] = combined + } + } + + return &v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: mappedFoundSubjects, + Metadata: addCallToResponseMetadata(result.Metadata), + }, true, nil + }, + } + + // Dispatch the found subjects as the resources of the next step. + slicez.ForEachChunk(resourceIds, cl.dispatchChunkSize, func(resourceIdChunk []string) { + g.Go(func() error { + return cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{ + ResourceRelation: resourceType, + ResourceIds: resourceIdChunk, + SubjectRelation: parentRequest.SubjectRelation, + Metadata: &v1.ResolverMeta{ + AtRevision: parentRequest.Revision.String(), + DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + }, + }, stream) + }) + }) + }) + + return g.Wait() +} + +func combineFoundSubjects(existing *v1.FoundSubjects, toAdd *v1.FoundSubjects) (*v1.FoundSubjects, error) { + if existing == nil { + return toAdd, nil + } + + if toAdd == nil { + return nil, fmt.Errorf("toAdd FoundSubject cannot be nil") + } + + return &v1.FoundSubjects{ + FoundSubjects: append(existing.FoundSubjects, toAdd.FoundSubjects...), + }, nil +} + +type lookupSubjectsReducer interface { + ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream + CompletedChildOperations(ctx context.Context) error +} + +// Union +type lookupSubjectsUnion struct { + parentStream dispatch.LookupSubjectsStream + collectors map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse] +} + +func newLookupSubjectsUnion(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsUnion { + return &lookupSubjectsUnion{ + parentStream: parentStream, + collectors: map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{}, + } +} + +func (lsu *lookupSubjectsUnion) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream { + collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx) + lsu.collectors[setOperationIndex] = collector + return collector +} + +func (lsu *lookupSubjectsUnion) CompletedChildOperations(ctx context.Context) error { + foundSubjects := datasets.NewSubjectSetByResourceID() + metadata := emptyMetadata + + for index := 0; index < len(lsu.collectors); index++ { + collector, ok := lsu.collectors[index] + if !ok { + return fmt.Errorf("missing collector for index %d", index) + } + + for _, result := range collector.Results() { + metadata = combineResponseMetadata(ctx, metadata, result.Metadata) + if err := foundSubjects.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsUnion: %w", err) + } + } + } + + if foundSubjects.IsEmpty() { + return nil + } + + return lsu.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: foundSubjects.AsMap(), + Metadata: metadata, + }) +} + +// Intersection +type lookupSubjectsIntersection struct { + parentStream dispatch.LookupSubjectsStream + collectors map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse] +} + +func newLookupSubjectsIntersection(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsIntersection { + return &lookupSubjectsIntersection{ + parentStream: parentStream, + collectors: map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{}, + } +} + +func (lsi *lookupSubjectsIntersection) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream { + collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx) + lsi.collectors[setOperationIndex] = collector + return collector +} + +func (lsi *lookupSubjectsIntersection) CompletedChildOperations(ctx context.Context) error { + var foundSubjects datasets.SubjectSetByResourceID + metadata := emptyMetadata + + for index := 0; index < len(lsi.collectors); index++ { + collector, ok := lsi.collectors[index] + if !ok { + return fmt.Errorf("missing collector for index %d", index) + } + + results := datasets.NewSubjectSetByResourceID() + for _, result := range collector.Results() { + metadata = combineResponseMetadata(ctx, metadata, result.Metadata) + if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsIntersection: %w", err) + } + } + + if index == 0 { + foundSubjects = results + } else { + err := foundSubjects.IntersectionDifference(results) + if err != nil { + return err + } + + if foundSubjects.IsEmpty() { + return nil + } + } + } + + return lsi.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: foundSubjects.AsMap(), + Metadata: metadata, + }) +} + +// Exclusion +type lookupSubjectsExclusion struct { + parentStream dispatch.LookupSubjectsStream + collectors map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse] +} + +func newLookupSubjectsExclusion(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsExclusion { + return &lookupSubjectsExclusion{ + parentStream: parentStream, + collectors: map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{}, + } +} + +func (lse *lookupSubjectsExclusion) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream { + collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx) + lse.collectors[setOperationIndex] = collector + return collector +} + +func (lse *lookupSubjectsExclusion) CompletedChildOperations(ctx context.Context) error { + var foundSubjects datasets.SubjectSetByResourceID + metadata := emptyMetadata + + for index := 0; index < len(lse.collectors); index++ { + collector := lse.collectors[index] + results := datasets.NewSubjectSetByResourceID() + for _, result := range collector.Results() { + metadata = combineResponseMetadata(ctx, metadata, result.Metadata) + if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsExclusion: %w", err) + } + } + + if index == 0 { + foundSubjects = results + } else { + foundSubjects.SubtractAll(results) + if foundSubjects.IsEmpty() { + return nil + } + } + } + + return lse.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: foundSubjects.AsMap(), + Metadata: metadata, + }) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go b/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go new file mode 100644 index 0000000..f04ee6f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go @@ -0,0 +1,334 @@ +package graph + +import ( + "context" + "strconv" + "sync" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph/computed" + "github.com/authzed/spicedb/internal/graph/hints" + "github.com/authzed/spicedb/internal/taskrunner" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// runCheckerAndDispatch runs the dispatch and checker for a lookup resources call, and publishes +// the results to the parent stream. This function is responsible for handling checking the +// results to filter them, and then dispatching those found. +func runCheckerAndDispatch( + ctx context.Context, + parentReq ValidatedLookupResources2Request, + foundResources dispatchableResourcesSubjectMap2, + ci cursorInformation, + parentStream dispatch.LookupResources2Stream, + newSubjectType *core.RelationReference, + filteredSubjectIDs []string, + entrypoint schema.ReachabilityEntrypoint, + lrDispatcher dispatch.LookupResources2, + checkDispatcher dispatch.Check, + caveatTypeSet *caveattypes.TypeSet, + concurrencyLimit uint16, + dispatchChunkSize uint16, +) error { + // Only allow max one dispatcher and one checker to run concurrently. + concurrencyLimit = min(concurrencyLimit, 2) + + currentCheckIndex, err := ci.integerSectionValue() + if err != nil { + return err + } + + rdc := &checkAndDispatchRunner{ + parentRequest: parentReq, + foundResources: foundResources, + ci: ci, + parentStream: parentStream, + newSubjectType: newSubjectType, + filteredSubjectIDs: filteredSubjectIDs, + currentCheckIndex: currentCheckIndex, + entrypoint: entrypoint, + lrDispatcher: lrDispatcher, + checkDispatcher: checkDispatcher, + taskrunner: taskrunner.NewTaskRunner(ctx, concurrencyLimit), + lock: &sync.Mutex{}, + dispatchChunkSize: dispatchChunkSize, + caveatTypeSet: caveatTypeSet, + } + + return rdc.runAndWait() +} + +type checkAndDispatchRunner struct { + parentRequest ValidatedLookupResources2Request + foundResources dispatchableResourcesSubjectMap2 + parentStream dispatch.LookupResources2Stream + newSubjectType *core.RelationReference + entrypoint schema.ReachabilityEntrypoint + lrDispatcher dispatch.LookupResources2 + checkDispatcher dispatch.Check + dispatchChunkSize uint16 + caveatTypeSet *caveattypes.TypeSet + filteredSubjectIDs []string + + currentCheckIndex int + taskrunner *taskrunner.TaskRunner + + lock *sync.Mutex + ci cursorInformation // GUARDED_BY(lock) +} + +func (rdc *checkAndDispatchRunner) runAndWait() error { + // Kick off a check at the current cursor, to filter a portion of the initial results set. + rdc.taskrunner.Schedule(func(ctx context.Context) error { + return rdc.runChecker(ctx, rdc.currentCheckIndex) + }) + + return rdc.taskrunner.Wait() +} + +func (rdc *checkAndDispatchRunner) runChecker(ctx context.Context, startingIndex int) error { + rdc.lock.Lock() + if rdc.ci.limits.hasExhaustedLimit() { + rdc.lock.Unlock() + return nil + } + rdc.lock.Unlock() + + endingIndex := min(startingIndex+int(rdc.dispatchChunkSize), len(rdc.filteredSubjectIDs)) + resourceIDsToCheck := rdc.filteredSubjectIDs[startingIndex:endingIndex] + if len(resourceIDsToCheck) == 0 { + return nil + } + + ctx, span := tracer.Start(ctx, "lr2Check", trace.WithAttributes( + attribute.Int("resource-id-count", len(resourceIDsToCheck)), + )) + defer span.End() + + checkHints := make([]*v1.CheckHint, 0, len(resourceIDsToCheck)) + for _, resourceID := range resourceIDsToCheck { + checkHint, err := hints.HintForEntrypoint( + rdc.entrypoint, + resourceID, + tuple.FromCoreObjectAndRelation(rdc.parentRequest.TerminalSubject), + &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_MEMBER, + }) + if err != nil { + return err + } + checkHints = append(checkHints, checkHint) + } + + // NOTE: we are checking the containing permission here, *not* the target relation, as + // the goal is to shear for the containing permission. + resultsByResourceID, checkMetadata, _, err := computed.ComputeBulkCheck(ctx, rdc.checkDispatcher, rdc.caveatTypeSet, computed.CheckParameters{ + ResourceType: tuple.FromCoreRelationReference(rdc.newSubjectType), + Subject: tuple.FromCoreObjectAndRelation(rdc.parentRequest.TerminalSubject), + CaveatContext: rdc.parentRequest.Context.AsMap(), + AtRevision: rdc.parentRequest.Revision, + MaximumDepth: rdc.parentRequest.Metadata.DepthRemaining - 1, + DebugOption: computed.NoDebugging, + CheckHints: checkHints, + }, resourceIDsToCheck, rdc.dispatchChunkSize) + if err != nil { + return err + } + + adjustedResources := rdc.foundResources.cloneAsMutable() + + // Dispatch any resources that are visible. + resourceIDToDispatch := make([]string, 0, len(resourceIDsToCheck)) + for _, resourceID := range resourceIDsToCheck { + result, ok := resultsByResourceID[resourceID] + if !ok { + continue + } + + switch result.Membership { + case v1.ResourceCheckResult_MEMBER: + fallthrough + + case v1.ResourceCheckResult_CAVEATED_MEMBER: + // Record any additional caveats missing from the check. + adjustedResources.withAdditionalMissingContextForDispatchedResourceID(resourceID, result.MissingExprFields) + resourceIDToDispatch = append(resourceIDToDispatch, resourceID) + + case v1.ResourceCheckResult_NOT_MEMBER: + // Skip. + continue + + default: + return spiceerrors.MustBugf("unexpected result from check: %v", result.Membership) + } + } + + if len(resourceIDToDispatch) > 0 { + // Schedule a dispatch of those resources. + rdc.taskrunner.Schedule(func(ctx context.Context) error { + return rdc.runDispatch(ctx, resourceIDToDispatch, adjustedResources.asReadOnly(), checkMetadata, startingIndex) + }) + } + + // Start the next check chunk (if applicable). + nextIndex := startingIndex + len(resourceIDsToCheck) + if nextIndex < len(rdc.filteredSubjectIDs) { + rdc.taskrunner.Schedule(func(ctx context.Context) error { + return rdc.runChecker(ctx, nextIndex) + }) + } + + return nil +} + +func (rdc *checkAndDispatchRunner) runDispatch( + ctx context.Context, + resourceIDsToDispatch []string, + adjustedResources dispatchableResourcesSubjectMap2, + checkMetadata *v1.ResponseMeta, + startingIndex int, +) error { + rdc.lock.Lock() + if rdc.ci.limits.hasExhaustedLimit() { + rdc.lock.Unlock() + return nil + } + rdc.lock.Unlock() + + ctx, span := tracer.Start(ctx, "lr2Dispatch", trace.WithAttributes( + attribute.Int("resource-id-count", len(resourceIDsToDispatch)), + )) + defer span.End() + + // NOTE: Since we extracted a custom section from the cursor at the beginning of this run, we have to add + // the starting index to the cursor to ensure that the next run starts from the correct place, and we have + // to use the *updated* cursor below on the dispatch. + updatedCi, err := rdc.ci.withOutgoingSection(strconv.Itoa(startingIndex)) + if err != nil { + return err + } + responsePartialCursor := updatedCi.responsePartialCursor() + + // Dispatch to the parent resource type and publish any results found. + isFirstPublishCall := true + + wrappedStream := dispatch.NewHandlingDispatchStream(ctx, func(result *v1.DispatchLookupResources2Response) error { + if err := ctx.Err(); err != nil { + return err + } + + if err := publishResultToParentStream(ctx, result, rdc.ci, responsePartialCursor, adjustedResources, nil, isFirstPublishCall, checkMetadata, rdc.parentStream); err != nil { + return err + } + isFirstPublishCall = false + return nil + }) + + return rdc.lrDispatcher.DispatchLookupResources2(&v1.DispatchLookupResources2Request{ + ResourceRelation: rdc.parentRequest.ResourceRelation, + SubjectRelation: rdc.newSubjectType, + SubjectIds: resourceIDsToDispatch, + TerminalSubject: rdc.parentRequest.TerminalSubject, + Metadata: &v1.ResolverMeta{ + AtRevision: rdc.parentRequest.Revision.String(), + DepthRemaining: rdc.parentRequest.Metadata.DepthRemaining - 1, + }, + OptionalCursor: updatedCi.currentCursor, + OptionalLimit: rdc.ci.limits.currentLimit, + Context: rdc.parentRequest.Context, + }, wrappedStream) +} + +// unfilteredLookupResourcesDispatchStreamForEntrypoint creates a new dispatch stream that wraps +// the parent stream, and publishes the results of the lookup resources call to the parent stream, +// mapped via foundResources. +func unfilteredLookupResourcesDispatchStreamForEntrypoint( + ctx context.Context, + foundResources dispatchableResourcesSubjectMap2, + parentStream dispatch.LookupResources2Stream, + ci cursorInformation, +) dispatch.LookupResources2Stream { + isFirstPublishCall := true + + wrappedStream := dispatch.NewHandlingDispatchStream(ctx, func(result *v1.DispatchLookupResources2Response) error { + select { + case <-ctx.Done(): + return ctx.Err() + + default: + } + + if err := publishResultToParentStream(ctx, result, ci, ci.responsePartialCursor(), foundResources, nil, isFirstPublishCall, emptyMetadata, parentStream); err != nil { + return err + } + isFirstPublishCall = false + return nil + }) + + return wrappedStream +} + +// publishResultToParentStream publishes the result of a lookup resources call to the parent stream, +// mapped via foundResources. +func publishResultToParentStream( + ctx context.Context, + result *v1.DispatchLookupResources2Response, + ci cursorInformation, + responseCursor *v1.Cursor, + foundResources dispatchableResourcesSubjectMap2, + additionalMissingContext []string, + isFirstPublishCall bool, + additionalMetadata *v1.ResponseMeta, + parentStream dispatch.LookupResources2Stream, +) error { + // Map the found resources via the subject+resources used for dispatching, to determine + // if any need to be made conditional due to caveats. + mappedResource, err := foundResources.mapPossibleResource(result.Resource) + if err != nil { + return err + } + + if !ci.limits.prepareForPublishing() { + return nil + } + + // The cursor for the response is that of the parent response + the cursor from the result itself. + afterResponseCursor, err := combineCursors( + responseCursor, + result.AfterResponseCursor, + ) + if err != nil { + return err + } + + metadata := result.Metadata + if isFirstPublishCall { + metadata = addCallToResponseMetadata(metadata) + metadata = combineResponseMetadata(ctx, metadata, additionalMetadata) + } else { + metadata = addAdditionalDepthRequired(metadata) + } + + missingContextParameters := mapz.NewSet(mappedResource.MissingContextParams...) + missingContextParameters.Extend(result.Resource.MissingContextParams) + missingContextParameters.Extend(additionalMissingContext) + + mappedResource.MissingContextParams = missingContextParameters.AsSlice() + + resp := &v1.DispatchLookupResources2Response{ + Resource: mappedResource, + Metadata: metadata, + AfterResponseCursor: afterResponseCursor, + } + + return parentStream.Publish(resp) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go b/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go new file mode 100644 index 0000000..8ab20f4 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go @@ -0,0 +1,243 @@ +package graph + +import ( + "github.com/authzed/spicedb/internal/caveats" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +var ( + caveatOr = caveats.Or + caveatAnd = caveats.And + caveatSub = caveats.Subtract + wrapCaveat = caveats.CaveatAsExpr +) + +// CheckResultsMap defines a type that is a map from resource ID to ResourceCheckResult. +// This must match that defined in the DispatchCheckResponse for the `results_by_resource_id` +// field. +type CheckResultsMap map[string]*v1.ResourceCheckResult + +// NewMembershipSet constructs a new helper set for tracking the membership found for a dispatched +// check request. +func NewMembershipSet() *MembershipSet { + return &MembershipSet{ + hasDeterminedMember: false, + membersByID: map[string]*core.CaveatExpression{}, + } +} + +func membershipSetFromMap(mp map[string]*core.CaveatExpression) *MembershipSet { + ms := NewMembershipSet() + for resourceID, result := range mp { + ms.addMember(resourceID, result) + } + return ms +} + +// MembershipSet is a helper set that trackes the membership results for a dispatched Check +// request, including tracking of the caveats associated with found resource IDs. +type MembershipSet struct { + membersByID map[string]*core.CaveatExpression + hasDeterminedMember bool +} + +// AddDirectMember adds a resource ID that was *directly* found for the dispatched check, with +// optional caveat found on the relationship. +func (ms *MembershipSet) AddDirectMember(resourceID string, caveat *core.ContextualizedCaveat) { + ms.addMember(resourceID, wrapCaveat(caveat)) +} + +// AddMemberViaRelationship adds a resource ID that was found via another relationship, such +// as the result of an arrow operation. The `parentRelationship` is the relationship that was +// followed before the resource itself was resolved. This method will properly apply the caveat(s) +// from both the parent relationship and the resource's result itself, assuming either have a caveat +// associated. +func (ms *MembershipSet) AddMemberViaRelationship( + resourceID string, + resourceCaveatExpression *core.CaveatExpression, + parentRelationship tuple.Relationship, +) { + ms.AddMemberWithParentCaveat(resourceID, resourceCaveatExpression, parentRelationship.OptionalCaveat) +} + +// AddMemberWithParentCaveat adds the given resource ID as a member with the parent caveat +// combined via intersection with the resource's caveat. The parent caveat may be nil. +func (ms *MembershipSet) AddMemberWithParentCaveat( + resourceID string, + resourceCaveatExpression *core.CaveatExpression, + parentCaveat *core.ContextualizedCaveat, +) { + intersection := caveatAnd(wrapCaveat(parentCaveat), resourceCaveatExpression) + ms.addMember(resourceID, intersection) +} + +// AddMemberWithOptionalCaveats adds the given resource ID as a member with the optional caveats combined +// via intersection. +func (ms *MembershipSet) AddMemberWithOptionalCaveats( + resourceID string, + caveats []*core.CaveatExpression, +) { + if len(caveats) == 0 { + ms.addMember(resourceID, nil) + return + } + + intersection := caveats[0] + for _, caveat := range caveats[1:] { + intersection = caveatAnd(intersection, caveat) + } + + ms.addMember(resourceID, intersection) +} + +func (ms *MembershipSet) addMember(resourceID string, caveatExpr *core.CaveatExpression) { + existing, ok := ms.membersByID[resourceID] + if !ok { + ms.hasDeterminedMember = ms.hasDeterminedMember || caveatExpr == nil + ms.membersByID[resourceID] = caveatExpr + return + } + + // If a determined membership result has already been found (i.e. there is no caveat), + // then nothing more to do. + if existing == nil { + return + } + + // If the new caveat expression is nil, then we are adding a determined result. + if caveatExpr == nil { + ms.hasDeterminedMember = true + ms.membersByID[resourceID] = nil + return + } + + // Otherwise, the caveats get unioned together. + ms.membersByID[resourceID] = caveatOr(existing, caveatExpr) +} + +// UnionWith combines the results found in the given map with the members of this set. +// The changes are made in-place. +func (ms *MembershipSet) UnionWith(resultsMap CheckResultsMap) { + for resourceID, details := range resultsMap { + if details.Membership != v1.ResourceCheckResult_NOT_MEMBER { + ms.addMember(resourceID, details.Expression) + } + } +} + +// IntersectWith intersects the results found in the given map with the members of this set. +// The changes are made in-place. +func (ms *MembershipSet) IntersectWith(resultsMap CheckResultsMap) { + for resourceID := range ms.membersByID { + if details, ok := resultsMap[resourceID]; !ok || details.Membership == v1.ResourceCheckResult_NOT_MEMBER { + delete(ms.membersByID, resourceID) + } + } + + ms.hasDeterminedMember = false + for resourceID, details := range resultsMap { + existing, ok := ms.membersByID[resourceID] + if !ok || details.Membership == v1.ResourceCheckResult_NOT_MEMBER { + continue + } + if existing == nil && details.Expression == nil { + ms.hasDeterminedMember = true + continue + } + + ms.membersByID[resourceID] = caveatAnd(existing, details.Expression) + } +} + +// Subtract subtracts the results found in the given map with the members of this set. +// The changes are made in-place. +func (ms *MembershipSet) Subtract(resultsMap CheckResultsMap) { + ms.hasDeterminedMember = false + for resourceID, expression := range ms.membersByID { + if details, ok := resultsMap[resourceID]; ok && details.Membership != v1.ResourceCheckResult_NOT_MEMBER { + // If the incoming member has no caveat, then this removal is absolute. + if details.Expression == nil { + delete(ms.membersByID, resourceID) + continue + } + + // Otherwise, the caveat expression gets combined with an intersection of the inversion + // of the expression. + ms.membersByID[resourceID] = caveatSub(expression, details.Expression) + } else { + if expression == nil { + ms.hasDeterminedMember = true + } + } + } +} + +// HasConcreteResourceID returns whether the resourceID was found in the set +// and has no caveat attached. +func (ms *MembershipSet) HasConcreteResourceID(resourceID string) bool { + if ms == nil { + return false + } + + found, ok := ms.membersByID[resourceID] + return ok && found == nil +} + +// GetResourceID returns a bool indicating whether the resource is found in the set and the +// associated caveat expression, if any. +func (ms *MembershipSet) GetResourceID(resourceID string) (bool, *core.CaveatExpression) { + if ms == nil { + return false, nil + } + + caveat, ok := ms.membersByID[resourceID] + return ok, caveat +} + +// Size returns the number of elements in the membership set. +func (ms *MembershipSet) Size() int { + if ms == nil { + return 0 + } + + return len(ms.membersByID) +} + +// IsEmpty returns true if the set is empty. +func (ms *MembershipSet) IsEmpty() bool { + if ms == nil { + return true + } + + return len(ms.membersByID) == 0 +} + +// HasDeterminedMember returns whether there exists at least one non-caveated member of the set. +func (ms *MembershipSet) HasDeterminedMember() bool { + if ms == nil { + return false + } + + return ms.hasDeterminedMember +} + +// AsCheckResultsMap converts the membership set back into a CheckResultsMap for placement into +// a DispatchCheckResult. +func (ms *MembershipSet) AsCheckResultsMap() CheckResultsMap { + resultsMap := make(CheckResultsMap, len(ms.membersByID)) + for resourceID, caveat := range ms.membersByID { + membership := v1.ResourceCheckResult_MEMBER + if caveat != nil { + membership = v1.ResourceCheckResult_CAVEATED_MEMBER + } + + resultsMap[resourceID] = &v1.ResourceCheckResult{ + Membership: membership, + Expression: caveat, + } + } + + return resultsMap +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go b/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go new file mode 100644 index 0000000..4e41955 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go @@ -0,0 +1,248 @@ +package graph + +import ( + "sort" + "sync" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +type syncONRSet struct { + sync.Mutex + items map[string]struct{} // GUARDED_BY(Mutex) +} + +func (s *syncONRSet) Add(onr *core.ObjectAndRelation) bool { + key := tuple.StringONR(tuple.FromCoreObjectAndRelation(onr)) + s.Lock() + _, existed := s.items[key] + if !existed { + s.items[key] = struct{}{} + } + s.Unlock() + return !existed +} + +func NewSyncONRSet() *syncONRSet { + return &syncONRSet{items: make(map[string]struct{})} +} + +// resourcesSubjectMap2 is a multimap which tracks mappings from found resource IDs +// to the subject IDs (may be more than one) for each, as well as whether the mapping +// is conditional due to the use of a caveat on the relationship which formed the mapping. +type resourcesSubjectMap2 struct { + resourceType *core.RelationReference + resourcesAndSubjects *mapz.MultiMap[string, subjectInfo2] +} + +// subjectInfo2 is the information about a subject contained in a resourcesSubjectMap2. +type subjectInfo2 struct { + subjectID string + missingContextParameters []string +} + +func newResourcesSubjectMap2(resourceType *core.RelationReference) resourcesSubjectMap2 { + return resourcesSubjectMap2{ + resourceType: resourceType, + resourcesAndSubjects: mapz.NewMultiMap[string, subjectInfo2](), + } +} + +func newResourcesSubjectMap2WithCapacity(resourceType *core.RelationReference, capacity uint32) resourcesSubjectMap2 { + return resourcesSubjectMap2{ + resourceType: resourceType, + resourcesAndSubjects: mapz.NewMultiMapWithCap[string, subjectInfo2](capacity), + } +} + +func subjectIDsToResourcesMap2(resourceType *core.RelationReference, subjectIDs []string) resourcesSubjectMap2 { + rsm := newResourcesSubjectMap2(resourceType) + for _, subjectID := range subjectIDs { + rsm.addSubjectIDAsFoundResourceID(subjectID) + } + return rsm +} + +// addRelationship adds the relationship to the resource subject map, recording a mapping from +// the resource of the relationship to the subject, as well as whether the relationship was caveated. +func (rsm resourcesSubjectMap2) addRelationship(rel tuple.Relationship, missingContextParameters []string) error { + spiceerrors.DebugAssert(func() bool { + return rel.Resource.ObjectType == rsm.resourceType.Namespace && rel.Resource.Relation == rsm.resourceType.Relation + }, "invalid relationship for addRelationship. expected: %v, found: %v", rsm.resourceType, rel.Resource) + + spiceerrors.DebugAssert(func() bool { + return len(missingContextParameters) == 0 || rel.OptionalCaveat != nil + }, "missing context parameters must be empty if there is no caveat") + + rsm.resourcesAndSubjects.Add(rel.Resource.ObjectID, subjectInfo2{rel.Subject.ObjectID, missingContextParameters}) + return nil +} + +// withAdditionalMissingContextForDispatchedResourceID adds additional missing context parameters +// to the existing missing context parameters for the dispatched resource ID. +func (rsm resourcesSubjectMap2) withAdditionalMissingContextForDispatchedResourceID( + resourceID string, + additionalMissingContext []string, +) { + if len(additionalMissingContext) == 0 { + return + } + + subjectInfo2s, _ := rsm.resourcesAndSubjects.Get(resourceID) + updatedInfos := make([]subjectInfo2, 0, len(subjectInfo2s)) + for _, info := range subjectInfo2s { + info.missingContextParameters = append(info.missingContextParameters, additionalMissingContext...) + updatedInfos = append(updatedInfos, info) + } + rsm.resourcesAndSubjects.Set(resourceID, updatedInfos) +} + +// addSubjectIDAsFoundResourceID adds a subject ID directly as a found subject for itself as the resource, +// with no associated caveat. +func (rsm resourcesSubjectMap2) addSubjectIDAsFoundResourceID(subjectID string) { + rsm.resourcesAndSubjects.Add(subjectID, subjectInfo2{subjectID, nil}) +} + +// asReadOnly returns a read-only dispatchableResourcesSubjectMap2 for dispatching for the +// resources in this map (if any). +func (rsm resourcesSubjectMap2) asReadOnly() dispatchableResourcesSubjectMap2 { + return dispatchableResourcesSubjectMap2{rsm} +} + +func (rsm resourcesSubjectMap2) len() int { + return rsm.resourcesAndSubjects.Len() +} + +// dispatchableResourcesSubjectMap2 is a read-only, frozen version of the resourcesSubjectMap2 that +// can be used for mapping conditionals once calls have been dispatched. This is read-only due to +// its use by concurrent callers. +type dispatchableResourcesSubjectMap2 struct { + resourcesSubjectMap2 +} + +func (rsm dispatchableResourcesSubjectMap2) len() int { + return rsm.resourcesAndSubjects.Len() +} + +func (rsm dispatchableResourcesSubjectMap2) isEmpty() bool { + return rsm.resourcesAndSubjects.IsEmpty() +} + +func (rsm dispatchableResourcesSubjectMap2) resourceIDs() []string { + return rsm.resourcesAndSubjects.Keys() +} + +// filterSubjectIDsToDispatch returns the set of subject IDs that have not yet been +// dispatched, by adding them to the dispatched set. +func (rsm dispatchableResourcesSubjectMap2) filterSubjectIDsToDispatch(dispatched *syncONRSet, dispatchSubjectType *core.RelationReference) []string { + resourceIDs := rsm.resourceIDs() + filtered := make([]string, 0, len(resourceIDs)) + for _, resourceID := range resourceIDs { + if dispatched.Add(&core.ObjectAndRelation{ + Namespace: dispatchSubjectType.Namespace, + ObjectId: resourceID, + Relation: dispatchSubjectType.Relation, + }) { + filtered = append(filtered, resourceID) + } + } + + return filtered +} + +// cloneAsMutable returns a mutable clone of this dispatchableResourcesSubjectMap2. +func (rsm dispatchableResourcesSubjectMap2) cloneAsMutable() resourcesSubjectMap2 { + return resourcesSubjectMap2{ + resourceType: rsm.resourceType, + resourcesAndSubjects: rsm.resourcesAndSubjects.Clone(), + } +} + +func (rsm dispatchableResourcesSubjectMap2) asPossibleResources() []*v1.PossibleResource { + resources := make([]*v1.PossibleResource, 0, rsm.resourcesAndSubjects.Len()) + + // Sort for stability. + sortedResourceIds := rsm.resourcesAndSubjects.Keys() + sort.Strings(sortedResourceIds) + + for _, resourceID := range sortedResourceIds { + subjectInfo2s, _ := rsm.resourcesAndSubjects.Get(resourceID) + subjectIDs := make([]string, 0, len(subjectInfo2s)) + allCaveated := true + nonCaveatedSubjectIDs := make([]string, 0, len(subjectInfo2s)) + missingContextParameters := mapz.NewSet[string]() + + for _, info := range subjectInfo2s { + subjectIDs = append(subjectIDs, info.subjectID) + if len(info.missingContextParameters) == 0 { + allCaveated = false + nonCaveatedSubjectIDs = append(nonCaveatedSubjectIDs, info.subjectID) + } else { + missingContextParameters.Extend(info.missingContextParameters) + } + } + + // Sort for stability. + sort.Strings(subjectIDs) + + // If all the incoming edges are caveated, then the entire status has to be marked as a check + // is required. Otherwise, if there is at least *one* non-caveated incoming edge, then we can + // return the existing status as a short-circuit for those non-caveated found subjects. + if allCaveated { + resources = append(resources, &v1.PossibleResource{ + ResourceId: resourceID, + ForSubjectIds: subjectIDs, + MissingContextParams: missingContextParameters.AsSlice(), + }) + } else { + resources = append(resources, &v1.PossibleResource{ + ResourceId: resourceID, + ForSubjectIds: nonCaveatedSubjectIDs, + }) + } + } + return resources +} + +func (rsm dispatchableResourcesSubjectMap2) mapPossibleResource(foundResource *v1.PossibleResource) (*v1.PossibleResource, error) { + forSubjectIDs := mapz.NewSet[string]() + nonCaveatedSubjectIDs := mapz.NewSet[string]() + missingContextParameters := mapz.NewSet[string]() + + for _, forSubjectID := range foundResource.ForSubjectIds { + // Map from the incoming subject ID to the subject ID(s) that caused the dispatch. + infos, ok := rsm.resourcesAndSubjects.Get(forSubjectID) + if !ok { + return nil, spiceerrors.MustBugf("missing for subject ID") + } + + for _, info := range infos { + forSubjectIDs.Insert(info.subjectID) + if len(info.missingContextParameters) == 0 { + nonCaveatedSubjectIDs.Insert(info.subjectID) + } else { + missingContextParameters.Extend(info.missingContextParameters) + } + } + } + + // If there are some non-caveated IDs, return those and mark as the parent status. + if nonCaveatedSubjectIDs.Len() > 0 { + return &v1.PossibleResource{ + ResourceId: foundResource.ResourceId, + ForSubjectIds: nonCaveatedSubjectIDs.AsSlice(), + }, nil + } + + // Otherwise, everything is caveated, so return the full set of subject IDs and mark + // as a check is required. + return &v1.PossibleResource{ + ResourceId: foundResource.ResourceId, + ForSubjectIds: forSubjectIDs.AsSlice(), + MissingContextParams: missingContextParameters.AsSlice(), + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/traceid.go b/vendor/github.com/authzed/spicedb/internal/graph/traceid.go new file mode 100644 index 0000000..e275bc5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/traceid.go @@ -0,0 +1,13 @@ +package graph + +import ( + "github.com/google/uuid" +) + +// NewTraceID generates a new trace ID. The trace IDs will only be unique with +// a single dispatch request tree and should not be used for any other purpose. +// This function currently uses the UUID library to generate a new trace ID, +// which means it should not be invoked from performance-critical code paths. +func NewTraceID() string { + return uuid.NewString() +} diff --git a/vendor/github.com/authzed/spicedb/internal/grpchelpers/grpchelpers.go b/vendor/github.com/authzed/spicedb/internal/grpchelpers/grpchelpers.go new file mode 100644 index 0000000..de91a0f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/grpchelpers/grpchelpers.go @@ -0,0 +1,20 @@ +package grpchelpers + +import ( + "context" + + "google.golang.org/grpc" +) + +// DialAndWait creates a new client connection to the target and blocks until the connection is ready. +func DialAndWait(ctx context.Context, target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + // TODO: move to NewClient + opts = append(opts, grpc.WithBlock()) // nolint: staticcheck + return grpc.DialContext(ctx, target, opts...) // nolint: staticcheck +} + +// Dial creates a new client connection to the target. +func Dial(ctx context.Context, target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + // TODO: move to NewClient + return grpc.DialContext(ctx, target, opts...) // nolint: staticcheck +} diff --git a/vendor/github.com/authzed/spicedb/internal/logging/logger.go b/vendor/github.com/authzed/spicedb/internal/logging/logger.go new file mode 100644 index 0000000..8204af9 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/logging/logger.go @@ -0,0 +1,43 @@ +package logging + +import ( + "context" + + "github.com/go-logr/zerologr" + "github.com/rs/zerolog" + logf "sigs.k8s.io/controller-runtime/pkg/log" +) + +var Logger zerolog.Logger + +func init() { + SetGlobalLogger(zerolog.Nop()) + logf.SetLogger(zerologr.New(&Logger)) +} + +func SetGlobalLogger(logger zerolog.Logger) { + Logger = logger + zerolog.DefaultContextLogger = &Logger +} + +func With() zerolog.Context { return Logger.With() } + +func Err(err error) *zerolog.Event { return Logger.Err(err) } + +func Trace() *zerolog.Event { return Logger.Trace() } + +func Debug() *zerolog.Event { return Logger.Debug() } + +func Info() *zerolog.Event { return Logger.Info() } + +func Warn() *zerolog.Event { return Logger.Warn() } + +func Error() *zerolog.Event { return Logger.Error() } + +func Fatal() *zerolog.Event { return Logger.Fatal() } + +func WithLevel(level zerolog.Level) *zerolog.Event { return Logger.WithLevel(level) } + +func Log() *zerolog.Event { return Logger.Log() } + +func Ctx(ctx context.Context) *zerolog.Logger { return zerolog.Ctx(ctx) } diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/chain.go b/vendor/github.com/authzed/spicedb/internal/middleware/chain.go new file mode 100644 index 0000000..de08ffc --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/chain.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "context" + + "google.golang.org/grpc" +) + +// Vendored from grpc-go-middleware +// These were removed in v2, see: https://github.com/grpc-ecosystem/go-grpc-middleware/pull/385 + +// ChainUnaryServer creates a single interceptor out of a chain of many interceptors. +// +// Execution is done in left-to-right order, including passing of context. +// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three +// will see context changes of one and two. +func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { + n := len(interceptors) + + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + chainer := func(currentInter grpc.UnaryServerInterceptor, currentHandler grpc.UnaryHandler) grpc.UnaryHandler { + return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) { + return currentInter(currentCtx, currentReq, info, currentHandler) + } + } + + chainedHandler := handler + for i := n - 1; i >= 0; i-- { + chainedHandler = chainer(interceptors[i], chainedHandler) + } + + return chainedHandler(ctx, req) + } +} + +// ChainStreamServer creates a single interceptor out of a chain of many interceptors. +// +// Execution is done in left-to-right order, including passing of context. +// For example ChainUnaryServer(one, two, three) will execute one before two before three. +// If you want to pass context between interceptors, use WrapServerStream. +func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { + n := len(interceptors) + + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + chainer := func(currentInter grpc.StreamServerInterceptor, currentHandler grpc.StreamHandler) grpc.StreamHandler { + return func(currentSrv interface{}, currentStream grpc.ServerStream) error { + return currentInter(currentSrv, currentStream, info, currentHandler) + } + } + + chainedHandler := handler + for i := n - 1; i >= 0; i-- { + chainedHandler = chainer(interceptors[i], chainedHandler) + } + + return chainedHandler(srv, ss) + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/datastore/datastore.go b/vendor/github.com/authzed/spicedb/internal/middleware/datastore/datastore.go new file mode 100644 index 0000000..8c321b3 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/datastore/datastore.go @@ -0,0 +1,85 @@ +package datastore + +import ( + "context" + + middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" + "google.golang.org/grpc" + + "github.com/authzed/spicedb/pkg/datastore" +) + +type ctxKeyType struct{} + +var datastoreKey ctxKeyType = struct{}{} + +type datastoreHandle struct { + datastore datastore.Datastore +} + +// ContextWithHandle adds a placeholder to a context that will later be +// filled by the datastore +func ContextWithHandle(ctx context.Context) context.Context { + return context.WithValue(ctx, datastoreKey, &datastoreHandle{}) +} + +// FromContext reads the selected datastore out of a context.Context +// and returns nil if it does not exist. +func FromContext(ctx context.Context) datastore.Datastore { + if c := ctx.Value(datastoreKey); c != nil { + handle := c.(*datastoreHandle) + return handle.datastore + } + return nil +} + +// MustFromContext reads the selected datastore out of a context.Context and panics if it does not exist +func MustFromContext(ctx context.Context) datastore.Datastore { + datastore := FromContext(ctx) + if datastore == nil { + panic("datastore middleware did not inject datastore") + } + + return datastore +} + +// SetInContext adds a datastore to the given context +func SetInContext(ctx context.Context, datastore datastore.Datastore) error { + handle := ctx.Value(datastoreKey) + if handle == nil { + return nil + } + handle.(*datastoreHandle).datastore = datastore + return nil +} + +// ContextWithDatastore adds the handle and datastore in one step +func ContextWithDatastore(ctx context.Context, datastore datastore.Datastore) context.Context { + return context.WithValue(ctx, datastoreKey, &datastoreHandle{datastore: datastore}) +} + +// UnaryServerInterceptor returns a new unary server interceptor that adds the +// datastore to the context +func UnaryServerInterceptor(datastore datastore.Datastore) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + newCtx := ContextWithHandle(ctx) + if err := SetInContext(newCtx, datastore); err != nil { + return nil, err + } + + return handler(newCtx, req) + } +} + +// StreamServerInterceptor returns a new stream server interceptor that adds the +// datastore to the context +func StreamServerInterceptor(datastore datastore.Datastore) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + wrapped := middleware.WrapServerStream(stream) + wrapped.WrappedContext = ContextWithHandle(wrapped.WrappedContext) + if err := SetInContext(wrapped.WrappedContext, datastore); err != nil { + return err + } + return handler(srv, wrapped) + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/datastore/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/datastore/doc.go new file mode 100644 index 0000000..a4d0cf0 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/datastore/doc.go @@ -0,0 +1,2 @@ +// Package datastore defines middleware that injects the datastore into the context. +package datastore diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/doc.go new file mode 100644 index 0000000..b11553e --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/doc.go @@ -0,0 +1,2 @@ +// Package middleware defines various custom middlewares. +package middleware diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/doc.go new file mode 100644 index 0000000..2d9aa01 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/doc.go @@ -0,0 +1,2 @@ +// Package handwrittenvalidation defines middleware that runs custom-made validations on incoming requests. +package handwrittenvalidation diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/handwrittenvalidation.go b/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/handwrittenvalidation.go new file mode 100644 index 0000000..2adc4b3 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/handwrittenvalidation.go @@ -0,0 +1,54 @@ +package handwrittenvalidation + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type handwrittenValidator interface { + HandwrittenValidate() error +} + +// UnaryServerInterceptor returns a new unary server interceptor that runs the handwritten validation +// on the incoming request, if any. +func UnaryServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + validator, ok := req.(handwrittenValidator) + if ok { + err := validator.HandwrittenValidate() + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%s", err) + } + } + + return handler(ctx, req) +} + +// StreamServerInterceptor returns a new stream server interceptor that runs the handwritten validation +// on the incoming request messages, if any. +func StreamServerInterceptor(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + wrapper := &recvWrapper{stream} + return handler(srv, wrapper) +} + +type recvWrapper struct { + grpc.ServerStream +} + +func (s *recvWrapper) RecvMsg(m interface{}) error { + if err := s.ServerStream.RecvMsg(m); err != nil { + return err + } + + validator, ok := m.(handwrittenValidator) + if ok { + err := validator.HandwrittenValidate() + if err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/doc.go new file mode 100644 index 0000000..7e3ae0f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/doc.go @@ -0,0 +1,2 @@ +// Package servicespecific defines middleware that injects other middlewares. +package servicespecific diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/servicespecific.go b/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/servicespecific.go new file mode 100644 index 0000000..10fe753 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/servicespecific.go @@ -0,0 +1,39 @@ +package servicespecific + +import ( + "context" + + "google.golang.org/grpc" +) + +// ExtraUnaryInterceptor is an interface for a service which has its own bundled +// unary interceptors that must be run. +type ExtraUnaryInterceptor interface { + UnaryInterceptor() grpc.UnaryServerInterceptor +} + +// ExtraStreamInterceptor is an interface for a service which has its own bundled +// stream interceptors that must be run. +type ExtraStreamInterceptor interface { + StreamInterceptor() grpc.StreamServerInterceptor +} + +// UnaryServerInterceptor returns a new unary server interceptor that runs bundled interceptors. +func UnaryServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if hasExtraInterceptor, ok := info.Server.(ExtraUnaryInterceptor); ok { + interceptor := hasExtraInterceptor.UnaryInterceptor() + return interceptor(ctx, req, info, handler) + } + + return handler(ctx, req) +} + +// StreamServerInterceptor returns a new stream server interceptor that runs bundled interceptors. +func StreamServerInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if hasExtraInterceptor, ok := srv.(ExtraStreamInterceptor); ok { + interceptor := hasExtraInterceptor.StreamInterceptor() + return interceptor(srv, stream, info, handler) + } + + return handler(srv, stream) +} diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/doc.go new file mode 100644 index 0000000..9eceb4d --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/doc.go @@ -0,0 +1,2 @@ +// Package streamtimeout defines middleware that cancels the context after a timeout if no new data has been received. +package streamtimeout diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/streamtimeout.go b/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/streamtimeout.go new file mode 100644 index 0000000..8f09fdb --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/streamtimeout.go @@ -0,0 +1,57 @@ +package streamtimeout + +import ( + "context" + "fmt" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// MustStreamServerInterceptor returns a new stream server interceptor that cancels the context +// after a timeout if no new data has been received. +func MustStreamServerInterceptor(timeout time.Duration) grpc.StreamServerInterceptor { + if timeout <= 0 { + panic("timeout must be >= 0 for streaming timeout interceptor") + } + + return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + ctx := stream.Context() + withCancel, internalCancelFn := context.WithCancelCause(ctx) + timer := time.AfterFunc(timeout, func() { + internalCancelFn(spiceerrors.WithCodeAndDetailsAsError(fmt.Errorf("operation took longer than allowed %v to complete", timeout), codes.DeadlineExceeded)) + }) + wrapper := &sendWrapper{stream, withCancel, timer, timeout} + return handler(srv, wrapper) + } +} + +type sendWrapper struct { + grpc.ServerStream + + ctx context.Context + timer *time.Timer + timeout time.Duration +} + +func (s *sendWrapper) Context() context.Context { + return s.ctx +} + +func (s *sendWrapper) SetTrailer(_ metadata.MD) { + s.timer.Stop() +} + +func (s *sendWrapper) SendMsg(m any) error { + err := s.ServerStream.SendMsg(m) + if err != nil { + s.timer.Stop() + } else { + s.timer.Reset(s.timeout) + } + return err +} diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/doc.go new file mode 100644 index 0000000..c05cacc --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/doc.go @@ -0,0 +1,2 @@ +// Package usagemetrics defines middleware that adds usage data (e.g. dispatch counts) to the response. +package usagemetrics diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/usagemetrics.go b/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/usagemetrics.go new file mode 100644 index 0000000..32f5676 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/usagemetrics.go @@ -0,0 +1,128 @@ +package usagemetrics + +import ( + "context" + "strconv" + "time" + + "github.com/authzed/authzed-go/pkg/responsemeta" + "github.com/authzed/grpcutil" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "google.golang.org/grpc" + + log "github.com/authzed/spicedb/internal/logging" + dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" +) + +var ( + // DispatchedCountLabels are the labels that DispatchedCountHistogram will + // have by default. + DispatchedCountLabels = []string{"method", "cached"} + + // DispatchedCountHistogram is the metric that SpiceDB uses to keep track + // of the number of downstream dispatches that are performed to answer a + // single query. + DispatchedCountHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "spicedb", + Subsystem: "services", + Name: "dispatches", + Help: "Histogram of cluster dispatches performed by the instance.", + Buckets: []float64{1, 5, 10, 25, 50, 100, 250}, + }, DispatchedCountLabels) +) + +type reporter struct{} + +func (r *reporter) ServerReporter(ctx context.Context, callMeta interceptors.CallMeta) (interceptors.Reporter, context.Context) { + _, methodName := grpcutil.SplitMethodName(callMeta.FullMethod()) + ctx = ContextWithHandle(ctx) + return &serverReporter{ctx: ctx, methodName: methodName}, ctx +} + +type serverReporter struct { + interceptors.NoopReporter + ctx context.Context + methodName string +} + +func (r *serverReporter) PostCall(_ error, _ time.Duration) { + responseMeta := FromContext(r.ctx) + if responseMeta == nil { + responseMeta = &dispatch.ResponseMeta{} + } + + err := annotateAndReportForMetadata(r.ctx, r.methodName, responseMeta) + // if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite + // this prevents logging unnecessary error messages + if r.ctx.Err() != nil { + return + } + if err != nil { + log.Ctx(r.ctx).Warn().Err(err).Msg("usagemetrics: could not report metadata") + } +} + +// UnaryServerInterceptor implements a gRPC Middleware for reporting usage metrics +// in both the trailer of the request, as well as to the registered prometheus +// metrics. +func UnaryServerInterceptor() grpc.UnaryServerInterceptor { + return interceptors.UnaryServerInterceptor(&reporter{}) +} + +// StreamServerInterceptor implements a gRPC Middleware for reporting usage metrics +// in both the trailer of the request, as well as to the registered prometheus +// metrics +func StreamServerInterceptor() grpc.StreamServerInterceptor { + return interceptors.StreamServerInterceptor(&reporter{}) +} + +func annotateAndReportForMetadata(ctx context.Context, methodName string, metadata *dispatch.ResponseMeta) error { + DispatchedCountHistogram.WithLabelValues(methodName, "false").Observe(float64(metadata.DispatchCount)) + DispatchedCountHistogram.WithLabelValues(methodName, "true").Observe(float64(metadata.CachedDispatchCount)) + + return responsemeta.SetResponseTrailerMetadata(ctx, map[responsemeta.ResponseMetadataTrailerKey]string{ + responsemeta.DispatchedOperationsCount: strconv.Itoa(int(metadata.DispatchCount)), + responsemeta.CachedOperationsCount: strconv.Itoa(int(metadata.CachedDispatchCount)), + }) +} + +// Create a new type to prevent context collisions +type responseMetaKey string + +var metadataCtxKey responseMetaKey = "dispatched-response-meta" + +type metaHandle struct{ metadata *dispatch.ResponseMeta } + +// SetInContext should be called in a gRPC handler to correctly set the response metadata +// for the dispatched request. +func SetInContext(ctx context.Context, metadata *dispatch.ResponseMeta) { + possibleHandle := ctx.Value(metadataCtxKey) + if possibleHandle == nil { + return + } + + handle := possibleHandle.(*metaHandle) + handle.metadata = metadata +} + +// FromContext returns any metadata that was stored in the context. +// +// This is useful for testing that a handler is properly setting the context. +func FromContext(ctx context.Context) *dispatch.ResponseMeta { + possibleHandle := ctx.Value(metadataCtxKey) + if possibleHandle == nil { + return nil + } + return possibleHandle.(*metaHandle).metadata +} + +// ContextWithHandle creates a new context with a location to store metadata +// returned from a dispatched request. +// +// This should only be called in middleware or testing functions. +func ContextWithHandle(ctx context.Context) context.Context { + var handle metaHandle + return context.WithValue(ctx, metadataCtxKey, &handle) +} diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/aliasing.go b/vendor/github.com/authzed/spicedb/internal/namespace/aliasing.go new file mode 100644 index 0000000..adbaa94 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/namespace/aliasing.go @@ -0,0 +1,82 @@ +package namespace + +import ( + "sort" + + "github.com/authzed/spicedb/pkg/schema" +) + +// computePermissionAliases computes a map of aliases between the various permissions in a +// namespace. A permission is considered an alias if it *directly* refers to another permission +// or relation without any other form of expression. +func computePermissionAliases(typeDefinition *schema.ValidatedDefinition) (map[string]string, error) { + aliases := map[string]string{} + done := map[string]struct{}{} + unresolvedAliases := map[string]string{} + + for _, rel := range typeDefinition.Namespace().Relation { + // Ensure the relation has a rewrite... + if rel.GetUsersetRewrite() == nil { + done[rel.Name] = struct{}{} + continue + } + + // ... with a union ... + union := rel.GetUsersetRewrite().GetUnion() + if union == nil { + done[rel.Name] = struct{}{} + continue + } + + // ... with a single child ... + if len(union.Child) != 1 { + done[rel.Name] = struct{}{} + continue + } + + // ... that is a computed userset. + computedUserset := union.Child[0].GetComputedUserset() + if computedUserset == nil { + done[rel.Name] = struct{}{} + continue + } + + // If the aliased item is a relation, then we've found the alias target. + aliasedPermOrRel := computedUserset.GetRelation() + if !typeDefinition.IsPermission(aliasedPermOrRel) { + done[rel.Name] = struct{}{} + aliases[rel.Name] = aliasedPermOrRel + continue + } + + // Otherwise, add the permission to the working set. + unresolvedAliases[rel.Name] = aliasedPermOrRel + } + + for len(unresolvedAliases) > 0 { + startingCount := len(unresolvedAliases) + for relName, aliasedPermission := range unresolvedAliases { + if _, ok := done[aliasedPermission]; ok { + done[relName] = struct{}{} + + if alias, ok := aliases[aliasedPermission]; ok { + aliases[relName] = alias + } else { + aliases[relName] = aliasedPermission + } + delete(unresolvedAliases, relName) + continue + } + } + if len(unresolvedAliases) == startingCount { + keys := make([]string, 0, len(unresolvedAliases)) + for key := range unresolvedAliases { + keys = append(keys, key) + } + sort.Strings(keys) + return nil, NewPermissionsCycleErr(typeDefinition.Namespace().Name, keys) + } + } + + return aliases, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/annotate.go b/vendor/github.com/authzed/spicedb/internal/namespace/annotate.go new file mode 100644 index 0000000..d85edff --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/namespace/annotate.go @@ -0,0 +1,29 @@ +package namespace + +import "github.com/authzed/spicedb/pkg/schema" + +// AnnotateNamespace annotates the namespace in the type system with computed aliasing and cache key +// metadata for more efficient dispatching. +func AnnotateNamespace(def *schema.ValidatedDefinition) error { + aliases, aerr := computePermissionAliases(def) + if aerr != nil { + return aerr + } + + cacheKeys, cerr := computeCanonicalCacheKeys(def, aliases) + if cerr != nil { + return cerr + } + + for _, rel := range def.Namespace().Relation { + if alias, ok := aliases[rel.Name]; ok { + rel.AliasingRelation = alias + } + + if cacheKey, ok := cacheKeys[rel.Name]; ok { + rel.CanonicalCacheKey = cacheKey + } + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/canonicalization.go b/vendor/github.com/authzed/spicedb/internal/namespace/canonicalization.go new file mode 100644 index 0000000..24fa61e --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/namespace/canonicalization.go @@ -0,0 +1,282 @@ +package namespace + +import ( + "encoding/hex" + "hash/fnv" + + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + + "github.com/dalzilio/rudd" + + "github.com/authzed/spicedb/pkg/graph" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +const computedKeyPrefix = "%" + +// computeCanonicalCacheKeys computes a map from permission name to associated canonicalized +// cache key for each non-aliased permission in the given type system's namespace. +// +// Canonicalization works by taking each permission's userset rewrite expression and transforming +// it into a Binary Decision Diagram (BDD) via the `rudd` library. +// +// Each access of a relation or arrow is assigned a unique integer ID within the *namespace*, +// and the operations (+, -, &) are converted into binary operations. +// +// For example, for the namespace: +// +// definition somenamespace { +// relation first: ... +// relation second: ... +// relation third: ... +// permission someperm = second + (first - third->something) +// } +// +// We begin by assigning a unique integer index to each relation and arrow found for all +// expressions in the namespace: +// +// definition somenamespace { +// relation first: ... +// ^ index 0 +// relation second: ... +// ^ index 1 +// relation third: ... +// ^ index 2 +// permission someperm = second + (first - third->something) +// ^ 1 ^ 0 ^ index 3 +// } +// +// These indexes are then used with the rudd library to build the expression: +// +// someperm => `bdd.Or(bdd.Ithvar(1), bdd.And(bdd.Ithvar(0), bdd.NIthvar(2)))` +// +// The `rudd` library automatically handles associativity, and produces a hash representing the +// canonical representation of the binary expression. These hashes can then be used for caching, +// representing the same *logical* expressions for a permission, even if the relations have +// different names. +func computeCanonicalCacheKeys(typeDef *schema.ValidatedDefinition, aliasMap map[string]string) (map[string]string, error) { + varMap, err := buildBddVarMap(typeDef.Namespace().Relation, aliasMap) + if err != nil { + return nil, err + } + + if varMap.Len() == 0 { + return map[string]string{}, nil + } + + bdd, err := rudd.New(varMap.Len()) + if err != nil { + return nil, err + } + + // For each permission, build a canonicalized cache key based on its expression. + cacheKeys := make(map[string]string, len(typeDef.Namespace().Relation)) + for _, rel := range typeDef.Namespace().Relation { + rewrite := rel.GetUsersetRewrite() + if rewrite == nil { + // If the relation has no rewrite (making it a pure relation), then its canonical + // key is simply the relation's name. + cacheKeys[rel.Name] = rel.Name + continue + } + + hasher := fnv.New64a() + node, err := convertRewriteToBdd(rel, bdd, rewrite, varMap) + if err != nil { + return nil, err + } + + bdd.Print(hasher, node) + cacheKeys[rel.Name] = computedKeyPrefix + hex.EncodeToString(hasher.Sum(nil)) + } + + return cacheKeys, nil +} + +func convertRewriteToBdd(relation *core.Relation, bdd *rudd.BDD, rewrite *core.UsersetRewrite, varMap bddVarMap) (rudd.Node, error) { + switch rw := rewrite.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + return convertToBdd(relation, bdd, rw.Union, bdd.Or, func(childIndex int, varIndex int) rudd.Node { + return bdd.Ithvar(varIndex) + }, varMap) + + case *core.UsersetRewrite_Intersection: + return convertToBdd(relation, bdd, rw.Intersection, bdd.And, func(childIndex int, varIndex int) rudd.Node { + return bdd.Ithvar(varIndex) + }, varMap) + + case *core.UsersetRewrite_Exclusion: + return convertToBdd(relation, bdd, rw.Exclusion, bdd.And, func(childIndex int, varIndex int) rudd.Node { + if childIndex == 0 { + return bdd.Ithvar(varIndex) + } + return bdd.NIthvar(varIndex) + }, varMap) + + default: + return nil, spiceerrors.MustBugf("Unknown rewrite kind %v", rw) + } +} + +type ( + combiner func(n ...rudd.Node) rudd.Node + builder func(childIndex int, varIndex int) rudd.Node +) + +func convertToBdd(relation *core.Relation, bdd *rudd.BDD, so *core.SetOperation, combiner combiner, builder builder, varMap bddVarMap) (rudd.Node, error) { + values := make([]rudd.Node, 0, len(so.Child)) + for index, childOneof := range so.Child { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_XThis: + return nil, spiceerrors.MustBugf("use of _this is disallowed") + + case *core.SetOperation_Child_ComputedUserset: + cuIndex, err := varMap.Get(child.ComputedUserset.Relation) + if err != nil { + return nil, err + } + + values = append(values, builder(index, cuIndex)) + + case *core.SetOperation_Child_UsersetRewrite: + node, err := convertRewriteToBdd(relation, bdd, child.UsersetRewrite, varMap) + if err != nil { + return nil, err + } + + values = append(values, node) + + case *core.SetOperation_Child_TupleToUserset: + arrowIndex, err := varMap.GetArrow(child.TupleToUserset.Tupleset.Relation, child.TupleToUserset.ComputedUserset.Relation) + if err != nil { + return nil, err + } + + values = append(values, builder(index, arrowIndex)) + + case *core.SetOperation_Child_FunctionedTupleToUserset: + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + arrowIndex, err := varMap.GetArrow(child.FunctionedTupleToUserset.Tupleset.Relation, child.FunctionedTupleToUserset.ComputedUserset.Relation) + if err != nil { + return nil, err + } + + values = append(values, builder(index, arrowIndex)) + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + arrowIndex, err := varMap.GetIntersectionArrow(child.FunctionedTupleToUserset.Tupleset.Relation, child.FunctionedTupleToUserset.ComputedUserset.Relation) + if err != nil { + return nil, err + } + + values = append(values, builder(index, arrowIndex)) + + default: + return nil, spiceerrors.MustBugf("unknown function %v", child.FunctionedTupleToUserset.Function) + } + + case *core.SetOperation_Child_XNil: + values = append(values, builder(index, varMap.Nil())) + + default: + return nil, spiceerrors.MustBugf("unknown set operation child %T", child) + } + } + return combiner(values...), nil +} + +type bddVarMap struct { + aliasMap map[string]string + varMap map[string]int +} + +func (bvm bddVarMap) GetArrow(tuplesetName string, relName string) (int, error) { + key := tuplesetName + "->" + relName + index, ok := bvm.varMap[key] + if !ok { + return -1, spiceerrors.MustBugf("missing arrow key %s in varMap", key) + } + return index, nil +} + +func (bvm bddVarMap) GetIntersectionArrow(tuplesetName string, relName string) (int, error) { + key := tuplesetName + "-(all)->" + relName + index, ok := bvm.varMap[key] + if !ok { + return -1, spiceerrors.MustBugf("missing intersection arrow key %s in varMap", key) + } + return index, nil +} + +func (bvm bddVarMap) Nil() int { + return len(bvm.varMap) +} + +func (bvm bddVarMap) Get(relName string) (int, error) { + if alias, ok := bvm.aliasMap[relName]; ok { + return bvm.Get(alias) + } + + index, ok := bvm.varMap[relName] + if !ok { + return -1, spiceerrors.MustBugf("missing key %s in varMap", relName) + } + return index, nil +} + +func (bvm bddVarMap) Len() int { + return len(bvm.varMap) + 1 // +1 for `nil` +} + +func buildBddVarMap(relations []*core.Relation, aliasMap map[string]string) (bddVarMap, error) { + varMap := map[string]int{} + for _, rel := range relations { + if _, ok := aliasMap[rel.Name]; ok { + continue + } + + varMap[rel.Name] = len(varMap) + + rewrite := rel.GetUsersetRewrite() + if rewrite == nil { + continue + } + + _, err := graph.WalkRewrite(rewrite, func(childOneof *core.SetOperation_Child) (interface{}, error) { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_TupleToUserset: + key := child.TupleToUserset.Tupleset.Relation + "->" + child.TupleToUserset.ComputedUserset.Relation + if _, ok := varMap[key]; !ok { + varMap[key] = len(varMap) + } + case *core.SetOperation_Child_FunctionedTupleToUserset: + key := child.FunctionedTupleToUserset.Tupleset.Relation + "->" + child.FunctionedTupleToUserset.ComputedUserset.Relation + + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + // Use the key. + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + key = child.FunctionedTupleToUserset.Tupleset.Relation + "-(all)->" + child.FunctionedTupleToUserset.ComputedUserset.Relation + + default: + return nil, spiceerrors.MustBugf("unknown function %v", child.FunctionedTupleToUserset.Function) + } + + if _, ok := varMap[key]; !ok { + varMap[key] = len(varMap) + } + } + return nil, nil + }) + if err != nil { + return bddVarMap{}, err + } + } + return bddVarMap{ + aliasMap: aliasMap, + varMap: varMap, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/caveats.go b/vendor/github.com/authzed/spicedb/internal/namespace/caveats.go new file mode 100644 index 0000000..5ddfa9d --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/namespace/caveats.go @@ -0,0 +1,69 @@ +package namespace + +import ( + "fmt" + + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema" +) + +// ValidateCaveatDefinition validates the parameters and types within the given caveat +// definition, including usage of the parameters. +func ValidateCaveatDefinition(ts *caveattypes.TypeSet, caveat *core.CaveatDefinition) error { + // Ensure all parameters are used by the caveat expression itself. + parameterTypes, err := caveattypes.DecodeParameterTypes(ts, caveat.ParameterTypes) + if err != nil { + return schema.NewTypeWithSourceError( + fmt.Errorf("could not decode caveat parameters `%s`: %w", caveat.Name, err), + caveat, + caveat.Name, + ) + } + + deserialized, err := caveats.DeserializeCaveatWithTypeSet(ts, caveat.SerializedExpression, parameterTypes) + if err != nil { + return schema.NewTypeWithSourceError( + fmt.Errorf("could not decode caveat `%s`: %w", caveat.Name, err), + caveat, + caveat.Name, + ) + } + + if len(caveat.ParameterTypes) == 0 { + return schema.NewTypeWithSourceError( + fmt.Errorf("caveat `%s` must have at least one parameter defined", caveat.Name), + caveat, + caveat.Name, + ) + } + + referencedNames, err := deserialized.ReferencedParameters(maps.Keys(caveat.ParameterTypes)) + if err != nil { + return err + } + + for paramName, paramType := range caveat.ParameterTypes { + _, err := caveattypes.DecodeParameterType(ts, paramType) + if err != nil { + return schema.NewTypeWithSourceError( + fmt.Errorf("type error for parameter `%s` for caveat `%s`: %w", paramName, caveat.Name, err), + caveat, + paramName, + ) + } + + if !referencedNames.Has(paramName) { + return schema.NewTypeWithSourceError( + NewUnusedCaveatParameterErr(caveat.Name, paramName), + caveat, + paramName, + ) + } + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/doc.go b/vendor/github.com/authzed/spicedb/internal/namespace/doc.go new file mode 100644 index 0000000..1546280 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/namespace/doc.go @@ -0,0 +1,2 @@ +// Package namespace provides functions for dealing with and validating types, relations and caveats. +package namespace diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/errors.go b/vendor/github.com/authzed/spicedb/internal/namespace/errors.go new file mode 100644 index 0000000..abe7fe6 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/namespace/errors.go @@ -0,0 +1,171 @@ +package namespace + +import ( + "fmt" + "strings" + + "github.com/rs/zerolog" + + "github.com/authzed/spicedb/internal/sharederrors" +) + +// NamespaceNotFoundError occurs when a namespace was not found. +type NamespaceNotFoundError struct { + error + namespaceName string +} + +// NotFoundNamespaceName is the name of the namespace not found. +func (err NamespaceNotFoundError) NotFoundNamespaceName() string { + return err.namespaceName +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err NamespaceNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err NamespaceNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + } +} + +// RelationNotFoundError occurs when a relation was not found under a namespace. +type RelationNotFoundError struct { + error + namespaceName string + relationName string +} + +// NamespaceName returns the name of the namespace in which the relation was not found. +func (err RelationNotFoundError) NamespaceName() string { + return err.namespaceName +} + +// NotFoundRelationName returns the name of the relation not found. +func (err RelationNotFoundError) NotFoundRelationName() string { + return err.relationName +} + +func (err RelationNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err RelationNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + "relation_or_permission_name": err.relationName, + } +} + +// DuplicateRelationError occurs when a duplicate relation was found inside a namespace. +type DuplicateRelationError struct { + error + namespaceName string + relationName string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err DuplicateRelationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err DuplicateRelationError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + "relation_or_permission_name": err.relationName, + } +} + +// PermissionsCycleError occurs when a cycle exists within permissions. +type PermissionsCycleError struct { + error + namespaceName string + permissionNames []string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err PermissionsCycleError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName).Str("permissions", strings.Join(err.permissionNames, ", ")) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err PermissionsCycleError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + "permission_names": strings.Join(err.permissionNames, ","), + } +} + +// UnusedCaveatParameterError indicates that a caveat parameter is unused in the caveat expression. +type UnusedCaveatParameterError struct { + error + caveatName string + paramName string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err UnusedCaveatParameterError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("caveat", err.caveatName).Str("param", err.paramName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err UnusedCaveatParameterError) DetailsMetadata() map[string]string { + return map[string]string{ + "caveat_name": err.caveatName, + "parameter_name": err.paramName, + } +} + +// NewNamespaceNotFoundErr constructs a new namespace not found error. +func NewNamespaceNotFoundErr(nsName string) error { + return NamespaceNotFoundError{ + error: fmt.Errorf("object definition `%s` not found", nsName), + namespaceName: nsName, + } +} + +// NewRelationNotFoundErr constructs a new relation not found error. +func NewRelationNotFoundErr(nsName string, relationName string) error { + return RelationNotFoundError{ + error: fmt.Errorf("relation/permission `%s` not found under definition `%s`", relationName, nsName), + namespaceName: nsName, + relationName: relationName, + } +} + +// NewDuplicateRelationError constructs an error indicating that a relation was defined more than once in a namespace. +func NewDuplicateRelationError(nsName string, relationName string) error { + return DuplicateRelationError{ + error: fmt.Errorf("found duplicate relation/permission name `%s` under definition `%s`", relationName, nsName), + namespaceName: nsName, + relationName: relationName, + } +} + +// NewPermissionsCycleErr constructs an error indicating that a cycle exists amongst permissions. +func NewPermissionsCycleErr(nsName string, permissionNames []string) error { + return PermissionsCycleError{ + error: fmt.Errorf("under definition `%s`, there exists a cycle in permissions: %s", nsName, strings.Join(permissionNames, ", ")), + namespaceName: nsName, + permissionNames: permissionNames, + } +} + +// NewUnusedCaveatParameterErr constructs indicating that a parameter was unused in a caveat expression. +func NewUnusedCaveatParameterErr(caveatName string, paramName string) error { + return UnusedCaveatParameterError{ + error: fmt.Errorf("parameter `%s` for caveat `%s` is unused", paramName, caveatName), + caveatName: caveatName, + paramName: paramName, + } +} + +var ( + _ sharederrors.UnknownNamespaceError = NamespaceNotFoundError{} + _ sharederrors.UnknownRelationError = RelationNotFoundError{} +) diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/util.go b/vendor/github.com/authzed/spicedb/internal/namespace/util.go new file mode 100644 index 0000000..497bdfb --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/namespace/util.go @@ -0,0 +1,148 @@ +package namespace + +import ( + "context" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// ReadNamespaceAndRelation checks that the specified namespace and relation exist in the +// datastore. +// +// Returns NamespaceNotFoundError if the namespace cannot be found. +// Returns RelationNotFoundError if the relation was not found in the namespace. +// Returns the direct downstream error for all other unknown error. +func ReadNamespaceAndRelation( + ctx context.Context, + namespace string, + relation string, + ds datastore.Reader, +) (*core.NamespaceDefinition, *core.Relation, error) { + config, _, err := ds.ReadNamespaceByName(ctx, namespace) + if err != nil { + return nil, nil, err + } + + for _, rel := range config.Relation { + if rel.Name == relation { + return config, rel, nil + } + } + + return nil, nil, NewRelationNotFoundErr(namespace, relation) +} + +// TypeAndRelationToCheck is a single check of a namespace+relation pair. +type TypeAndRelationToCheck struct { + // NamespaceName is the namespace name to ensure exists. + NamespaceName string + + // RelationName is the relation name to ensure exists under the namespace. + RelationName string + + // AllowEllipsis, if true, allows for the ellipsis as the RelationName. + AllowEllipsis bool +} + +// CheckNamespaceAndRelations ensures that the given namespace+relation checks all succeed. If any fail, returns an error. +// +// Returns NamespaceNotFoundError if the namespace cannot be found. +// Returns RelationNotFoundError if the relation was not found in the namespace. +// Returns the direct downstream error for all other unknown error. +func CheckNamespaceAndRelations(ctx context.Context, checks []TypeAndRelationToCheck, ds datastore.Reader) error { + nsNames := mapz.NewSet[string]() + for _, toCheck := range checks { + nsNames.Insert(toCheck.NamespaceName) + } + + if nsNames.IsEmpty() { + return nil + } + + namespaces, err := ds.LookupNamespacesWithNames(ctx, nsNames.AsSlice()) + if err != nil { + return err + } + + mappedNamespaces := make(map[string]*core.NamespaceDefinition, len(namespaces)) + for _, namespace := range namespaces { + mappedNamespaces[namespace.Definition.Name] = namespace.Definition + } + + for _, toCheck := range checks { + nsDef, ok := mappedNamespaces[toCheck.NamespaceName] + if !ok { + return NewNamespaceNotFoundErr(toCheck.NamespaceName) + } + + if toCheck.AllowEllipsis && toCheck.RelationName == datastore.Ellipsis { + continue + } + + foundRelation := false + for _, rel := range nsDef.Relation { + if rel.Name == toCheck.RelationName { + foundRelation = true + break + } + } + + if !foundRelation { + return NewRelationNotFoundErr(toCheck.NamespaceName, toCheck.RelationName) + } + } + + return nil +} + +// CheckNamespaceAndRelation checks that the specified namespace and relation exist in the +// datastore. +// +// Returns datastore.NamespaceNotFoundError if the namespace cannot be found. +// Returns RelationNotFoundError if the relation was not found in the namespace. +// Returns the direct downstream error for all other unknown error. +func CheckNamespaceAndRelation( + ctx context.Context, + namespace string, + relation string, + allowEllipsis bool, + ds datastore.Reader, +) error { + config, _, err := ds.ReadNamespaceByName(ctx, namespace) + if err != nil { + return err + } + + if allowEllipsis && relation == datastore.Ellipsis { + return nil + } + + for _, rel := range config.Relation { + if rel.Name == relation { + return nil + } + } + + return NewRelationNotFoundErr(namespace, relation) +} + +// ListReferencedNamespaces returns the names of all namespaces referenced in the +// given namespace definitions. This includes the namespaces themselves, as well as +// any found in type information on relations. +func ListReferencedNamespaces(nsdefs []*core.NamespaceDefinition) []string { + referencedNamespaceNamesSet := mapz.NewSet[string]() + for _, nsdef := range nsdefs { + referencedNamespaceNamesSet.Insert(nsdef.Name) + + for _, relation := range nsdef.Relation { + if relation.GetTypeInformation() != nil { + for _, allowedRel := range relation.GetTypeInformation().AllowedDirectRelations { + referencedNamespaceNamesSet.Insert(allowedRel.GetNamespace()) + } + } + } + } + return referencedNamespaceNamesSet.AsSlice() +} diff --git a/vendor/github.com/authzed/spicedb/internal/relationships/doc.go b/vendor/github.com/authzed/spicedb/internal/relationships/doc.go new file mode 100644 index 0000000..6e1bfc6 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/relationships/doc.go @@ -0,0 +1,2 @@ +// Package relationships contains helper methods to validate relationships that are going to be written. +package relationships diff --git a/vendor/github.com/authzed/spicedb/internal/relationships/errors.go b/vendor/github.com/authzed/spicedb/internal/relationships/errors.go new file mode 100644 index 0000000..3237e0b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/relationships/errors.go @@ -0,0 +1,195 @@ +package relationships + +import ( + "fmt" + "maps" + "sort" + "strings" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/lithammer/fuzzysearch/fuzzy" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// InvalidSubjectTypeError indicates that a write was attempted with a subject type which is not +// allowed on relation. +type InvalidSubjectTypeError struct { + error + relationship tuple.Relationship + relationType *core.AllowedRelation + additionalDetails map[string]string +} + +// NewInvalidSubjectTypeError constructs a new error for attempting to write an invalid subject type. +func NewInvalidSubjectTypeError( + relationship tuple.Relationship, + relationType *core.AllowedRelation, + definition *schema.Definition, +) error { + allowedTypes, err := definition.AllowedDirectRelationsAndWildcards(relationship.Resource.Relation) + if err != nil { + return err + } + + // Special case: if the subject is uncaveated but only a caveated version is allowed, return + // a more descriptive error. + if relationship.OptionalCaveat == nil { + allowedCaveatsForSubject := mapz.NewSet[string]() + + for _, allowedType := range allowedTypes { + if allowedType.RequiredCaveat != nil && + allowedType.RequiredCaveat.CaveatName != "" && + allowedType.Namespace == relationship.Subject.ObjectType && + allowedType.GetRelation() == relationship.Subject.Relation && + (allowedType.RequiredExpiration != nil) == (relationship.OptionalExpiration != nil) { + allowedCaveatsForSubject.Add(allowedType.RequiredCaveat.CaveatName) + } + } + + if !allowedCaveatsForSubject.IsEmpty() { + return InvalidSubjectTypeError{ + error: fmt.Errorf( + "subjects of type `%s` are not allowed on relation `%s#%s` without one of the following caveats: %s", + schema.SourceForAllowedRelation(relationType), + relationship.Resource.ObjectType, + relationship.Resource.Relation, + strings.Join(allowedCaveatsForSubject.AsSlice(), ","), + ), + relationship: relationship, + relationType: relationType, + additionalDetails: map[string]string{ + "allowed_caveats": strings.Join(allowedCaveatsForSubject.AsSlice(), ","), + }, + } + } + } + + allowedTypeStrings := make([]string, 0, len(allowedTypes)) + for _, allowedType := range allowedTypes { + allowedTypeStrings = append(allowedTypeStrings, schema.SourceForAllowedRelation(allowedType)) + } + + matches := fuzzy.RankFind(schema.SourceForAllowedRelation(relationType), allowedTypeStrings) + sort.Sort(matches) + if len(matches) > 0 { + return InvalidSubjectTypeError{ + error: fmt.Errorf( + "subjects of type `%s` are not allowed on relation `%s#%s`; did you mean `%s`?", + schema.SourceForAllowedRelation(relationType), + relationship.Resource.ObjectType, + relationship.Resource.Relation, + matches[0].Target, + ), + relationship: relationship, + relationType: relationType, + additionalDetails: nil, + } + } + + return InvalidSubjectTypeError{ + error: fmt.Errorf( + "subjects of type `%s` are not allowed on relation `%s#%s`", + schema.SourceForAllowedRelation(relationType), + relationship.Resource.ObjectType, + relationship.Resource.Relation, + ), + relationship: relationship, + relationType: relationType, + additionalDetails: nil, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err InvalidSubjectTypeError) GRPCStatus() *status.Status { + details := map[string]string{ + "definition_name": err.relationship.Resource.ObjectType, + "relation_name": err.relationship.Resource.Relation, + "subject_type": schema.SourceForAllowedRelation(err.relationType), + } + + if err.additionalDetails != nil { + maps.Copy(details, err.additionalDetails) + } + + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_INVALID_SUBJECT_TYPE, + details, + ), + ) +} + +// CannotWriteToPermissionError indicates that a write was attempted on a permission. +type CannotWriteToPermissionError struct { + error + rel tuple.Relationship +} + +// NewCannotWriteToPermissionError constructs a new error for attempting to write to a permission. +func NewCannotWriteToPermissionError(rel tuple.Relationship) CannotWriteToPermissionError { + return CannotWriteToPermissionError{ + error: fmt.Errorf( + "cannot write a relationship to permission `%s` under definition `%s`", + rel.Resource.Relation, + rel.Resource.ObjectType, + ), + rel: rel, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err CannotWriteToPermissionError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_CANNOT_UPDATE_PERMISSION, + map[string]string{ + "definition_name": err.rel.Resource.ObjectType, + "permission_name": err.rel.Resource.Relation, + }, + ), + ) +} + +// CaveatNotFoundError indicates that a caveat referenced in a relationship update was not found. +type CaveatNotFoundError struct { + error + relationship tuple.Relationship +} + +// NewCaveatNotFoundError constructs a new caveat not found error. +func NewCaveatNotFoundError(relationship tuple.Relationship) CaveatNotFoundError { + return CaveatNotFoundError{ + error: fmt.Errorf( + "the caveat `%s` was not found for relationship `%s`", + relationship.OptionalCaveat.CaveatName, + tuple.MustString(relationship), + ), + relationship: relationship, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err CaveatNotFoundError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.FailedPrecondition, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNKNOWN_CAVEAT, + map[string]string{ + "caveat_name": err.relationship.OptionalCaveat.CaveatName, + }, + ), + ) +} diff --git a/vendor/github.com/authzed/spicedb/internal/relationships/validation.go b/vendor/github.com/authzed/spicedb/internal/relationships/validation.go new file mode 100644 index 0000000..ff4a6fb --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/relationships/validation.go @@ -0,0 +1,280 @@ +package relationships + +import ( + "context" + + "github.com/samber/lo" + + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/mapz" + ns "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ValidateRelationshipUpdates performs validation on the given relationship updates, ensuring that +// they can be applied against the datastore. +func ValidateRelationshipUpdates( + ctx context.Context, + reader datastore.Reader, + caveatTypeSet *caveattypes.TypeSet, + updates []tuple.RelationshipUpdate, +) error { + rels := lo.Map(updates, func(item tuple.RelationshipUpdate, _ int) tuple.Relationship { + return item.Relationship + }) + + // Load namespaces and caveats. + referencedNamespaceMap, referencedCaveatMap, err := loadNamespacesAndCaveats(ctx, rels, reader) + if err != nil { + return err + } + + // Validate each updates's types. + for _, update := range updates { + option := ValidateRelationshipForCreateOrTouch + if update.Operation == tuple.UpdateOperationDelete { + option = ValidateRelationshipForDeletion + } + + if err := ValidateOneRelationship( + referencedNamespaceMap, + referencedCaveatMap, + caveatTypeSet, + update.Relationship, + option, + ); err != nil { + return err + } + } + + return nil +} + +// ValidateRelationshipsForCreateOrTouch performs validation on the given relationships to be written, ensuring that +// they can be applied against the datastore. +// +// NOTE: This method *cannot* be used for relationships that will be deleted. +func ValidateRelationshipsForCreateOrTouch( + ctx context.Context, + reader datastore.Reader, + caveatTypeSet *caveattypes.TypeSet, + rels ...tuple.Relationship, +) error { + // Load namespaces and caveats. + referencedNamespaceMap, referencedCaveatMap, err := loadNamespacesAndCaveats(ctx, rels, reader) + if err != nil { + return err + } + + // Validate each relationship's types. + for _, rel := range rels { + if err := ValidateOneRelationship( + referencedNamespaceMap, + referencedCaveatMap, + caveatTypeSet, + rel, + ValidateRelationshipForCreateOrTouch, + ); err != nil { + return err + } + } + + return nil +} + +func loadNamespacesAndCaveats(ctx context.Context, rels []tuple.Relationship, reader datastore.Reader) (map[string]*schema.Definition, map[string]*core.CaveatDefinition, error) { + referencedNamespaceNames := mapz.NewSet[string]() + referencedCaveatNamesWithContext := mapz.NewSet[string]() + for _, rel := range rels { + referencedNamespaceNames.Insert(rel.Resource.ObjectType) + referencedNamespaceNames.Insert(rel.Subject.ObjectType) + if hasNonEmptyCaveatContext(rel) { + referencedCaveatNamesWithContext.Insert(rel.OptionalCaveat.CaveatName) + } + } + + var referencedNamespaceMap map[string]*schema.Definition + var referencedCaveatMap map[string]*core.CaveatDefinition + + if !referencedNamespaceNames.IsEmpty() { + foundNamespaces, err := reader.LookupNamespacesWithNames(ctx, referencedNamespaceNames.AsSlice()) + if err != nil { + return nil, nil, err + } + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(reader)) + + referencedNamespaceMap = make(map[string]*schema.Definition, len(foundNamespaces)) + for _, nsDef := range foundNamespaces { + nts, err := schema.NewDefinition(ts, nsDef.Definition) + if err != nil { + return nil, nil, err + } + + referencedNamespaceMap[nsDef.Definition.Name] = nts + } + } + + if !referencedCaveatNamesWithContext.IsEmpty() { + foundCaveats, err := reader.LookupCaveatsWithNames(ctx, referencedCaveatNamesWithContext.AsSlice()) + if err != nil { + return nil, nil, err + } + + referencedCaveatMap = make(map[string]*core.CaveatDefinition, len(foundCaveats)) + for _, caveatDef := range foundCaveats { + referencedCaveatMap[caveatDef.Definition.Name] = caveatDef.Definition + } + } + return referencedNamespaceMap, referencedCaveatMap, nil +} + +// ValidationRelationshipRule is the rule to use for the validation. +type ValidationRelationshipRule int + +const ( + // ValidateRelationshipForCreateOrTouch indicates that the validation should occur for a CREATE or TOUCH operation. + ValidateRelationshipForCreateOrTouch ValidationRelationshipRule = 0 + + // ValidateRelationshipForDeletion indicates that the validation should occur for a DELETE operation. + ValidateRelationshipForDeletion ValidationRelationshipRule = 1 +) + +// ValidateOneRelationship validates a single relationship for CREATE/TOUCH or DELETE. +func ValidateOneRelationship( + namespaceMap map[string]*schema.Definition, + caveatMap map[string]*core.CaveatDefinition, + caveatTypeSet *caveattypes.TypeSet, + rel tuple.Relationship, + rule ValidationRelationshipRule, +) error { + // Validate the IDs of the resource and subject. + if err := tuple.ValidateResourceID(rel.Resource.ObjectID); err != nil { + return err + } + + if err := tuple.ValidateSubjectID(rel.Subject.ObjectID); err != nil { + return err + } + + // Validate the namespace and relation for the resource. + resourceTS, ok := namespaceMap[rel.Resource.ObjectType] + if !ok { + return namespace.NewNamespaceNotFoundErr(rel.Resource.ObjectType) + } + + if !resourceTS.HasRelation(rel.Resource.Relation) { + return namespace.NewRelationNotFoundErr(rel.Resource.ObjectType, rel.Resource.Relation) + } + + // Validate the namespace and relation for the subject. + subjectTS, ok := namespaceMap[rel.Subject.ObjectType] + if !ok { + return namespace.NewNamespaceNotFoundErr(rel.Subject.ObjectType) + } + + if rel.Subject.Relation != tuple.Ellipsis { + if !subjectTS.HasRelation(rel.Subject.Relation) { + return namespace.NewRelationNotFoundErr(rel.Subject.ObjectType, rel.Subject.Relation) + } + } + + // Validate that the relationship is not writing to a permission. + if resourceTS.IsPermission(rel.Resource.Relation) { + return NewCannotWriteToPermissionError(rel) + } + + // Validate the subject against the allowed relation(s). + var caveat *core.AllowedCaveat + if rel.OptionalCaveat != nil { + caveat = ns.AllowedCaveat(rel.OptionalCaveat.CaveatName) + } + + var relationToCheck *core.AllowedRelation + if rel.Subject.ObjectID == tuple.PublicWildcard { + relationToCheck = ns.AllowedPublicNamespaceWithCaveat(rel.Subject.ObjectType, caveat) + } else { + relationToCheck = ns.AllowedRelationWithCaveat( + rel.Subject.ObjectType, + rel.Subject.Relation, + caveat) + } + + if rel.OptionalExpiration != nil { + relationToCheck = ns.WithExpiration(relationToCheck) + } + + switch { + case rule == ValidateRelationshipForCreateOrTouch || caveat != nil: + // For writing or when the caveat was specified, the caveat must be a direct match. + isAllowed, err := resourceTS.HasAllowedRelation( + rel.Resource.Relation, + relationToCheck) + if err != nil { + return err + } + + if isAllowed != schema.AllowedRelationValid { + return NewInvalidSubjectTypeError(rel, relationToCheck, resourceTS) + } + + case rule == ValidateRelationshipForDeletion && caveat == nil: + // For deletion, the caveat *can* be ignored if not specified. + if rel.Subject.ObjectID == tuple.PublicWildcard { + isAllowed, err := resourceTS.IsAllowedPublicNamespace(rel.Resource.Relation, rel.Subject.ObjectType) + if err != nil { + return err + } + + if isAllowed != schema.PublicSubjectAllowed { + return NewInvalidSubjectTypeError(rel, relationToCheck, resourceTS) + } + } else { + isAllowed, err := resourceTS.IsAllowedDirectRelation(rel.Resource.Relation, rel.Subject.ObjectType, rel.Subject.Relation) + if err != nil { + return err + } + + if isAllowed != schema.DirectRelationValid { + return NewInvalidSubjectTypeError(rel, relationToCheck, resourceTS) + } + } + + default: + return spiceerrors.MustBugf("unknown validate rule") + } + + // Validate caveat and its context, if applicable. + if hasNonEmptyCaveatContext(rel) { + caveat, ok := caveatMap[rel.OptionalCaveat.CaveatName] + if !ok { + // Should ideally never happen since the caveat is type checked above, but just in case. + return NewCaveatNotFoundError(rel) + } + + // Verify that the provided context information matches the types of the parameters defined. + _, err := caveats.ConvertContextToParameters( + caveatTypeSet, + rel.OptionalCaveat.Context.AsMap(), + caveat.ParameterTypes, + caveats.ErrorForUnknownParameters, + ) + if err != nil { + return err + } + } + + return nil +} + +func hasNonEmptyCaveatContext(relationship tuple.Relationship) bool { + return relationship.OptionalCaveat != nil && + relationship.OptionalCaveat.CaveatName != "" && + relationship.OptionalCaveat.Context != nil && + len(relationship.OptionalCaveat.Context.GetFields()) > 0 +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go b/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go new file mode 100644 index 0000000..05b3907 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go @@ -0,0 +1,208 @@ +package shared + +import ( + "context" + "errors" + "fmt" + "strconv" + + "github.com/rs/zerolog" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph" + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/sharederrors" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// ErrServiceReadOnly is an extended GRPC error returned when a service is in read-only mode. +var ErrServiceReadOnly = mustMakeStatusReadonly() + +func mustMakeStatusReadonly() error { + status, err := status.New(codes.Unavailable, "service read-only").WithDetails(&errdetails.ErrorInfo{ + Reason: v1.ErrorReason_name[int32(v1.ErrorReason_ERROR_REASON_SERVICE_READ_ONLY)], + Domain: spiceerrors.Domain, + }) + if err != nil { + panic("error constructing shared error type") + } + return status.Err() +} + +// NewSchemaWriteDataValidationError creates a new error representing that a schema write cannot be +// completed due to existing data that would be left unreferenced. +func NewSchemaWriteDataValidationError(message string, args ...any) SchemaWriteDataValidationError { + return SchemaWriteDataValidationError{ + error: fmt.Errorf(message, args...), + } +} + +// SchemaWriteDataValidationError occurs when a schema cannot be applied due to leaving data unreferenced. +type SchemaWriteDataValidationError struct { + error +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err SchemaWriteDataValidationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err SchemaWriteDataValidationError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_SCHEMA_TYPE_ERROR, + map[string]string{}, + ), + ) +} + +// MaxDepthExceededError is an error returned when the maximum depth for dispatching has been exceeded. +type MaxDepthExceededError struct { + *spiceerrors.WithAdditionalDetailsError + + // AllowedMaximumDepth is the configured allowed maximum depth. + AllowedMaximumDepth uint32 +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err MaxDepthExceededError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.ResourceExhausted, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_MAXIMUM_DEPTH_EXCEEDED, + err.AddToDetails(map[string]string{ + "maximum_depth_allowed": strconv.Itoa(int(err.AllowedMaximumDepth)), + }), + ), + ) +} + +// NewMaxDepthExceededError creates a new MaxDepthExceededError. +func NewMaxDepthExceededError(allowedMaximumDepth uint32, isCheckRequest bool) error { + if isCheckRequest { + return MaxDepthExceededError{ + spiceerrors.NewWithAdditionalDetailsError(fmt.Errorf("the check request has exceeded the allowable maximum depth of %d: this usually indicates a recursive or too deep data dependency. Try running zed with --explain to see the dependency. See: https://spicedb.dev/d/debug-max-depth-check", allowedMaximumDepth)), + allowedMaximumDepth, + } + } + + return MaxDepthExceededError{ + spiceerrors.NewWithAdditionalDetailsError(fmt.Errorf("the request has exceeded the allowable maximum depth of %d: this usually indicates a recursive or too deep data dependency. See: https://spicedb.dev/d/debug-max-depth", allowedMaximumDepth)), + allowedMaximumDepth, + } +} + +func AsValidationError(err error) *SchemaWriteDataValidationError { + var validationErr SchemaWriteDataValidationError + if errors.As(err, &validationErr) { + return &validationErr + } + return nil +} + +type ConfigForErrors struct { + MaximumAPIDepth uint32 + DebugTrace *v1.DebugInformation +} + +func RewriteErrorWithoutConfig(ctx context.Context, err error) error { + return rewriteError(ctx, err, nil) +} + +func RewriteError(ctx context.Context, err error, config *ConfigForErrors) error { + rerr := rewriteError(ctx, err, config) + if config != nil && config.DebugTrace != nil { + spiceerrors.WithAdditionalDetails(rerr, spiceerrors.DebugTraceErrorDetailsKey, config.DebugTrace.String()) + } + return rerr +} + +func rewriteError(ctx context.Context, err error, config *ConfigForErrors) error { + // Check if the error can be directly used. + if _, ok := status.FromError(err); ok { + return err + } + + // Otherwise, convert any graph/datastore errors. + var nsNotFoundError sharederrors.UnknownNamespaceError + var relationNotFoundError sharederrors.UnknownRelationError + + var compilerError compiler.BaseCompilerError + var sourceError spiceerrors.WithSourceError + var typeError schema.TypeError + var maxDepthError dispatch.MaxDepthExceededError + + switch { + case errors.As(err, &typeError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_SCHEMA_TYPE_ERROR) + case errors.As(err, &compilerError): + return spiceerrors.WithCodeAndReason(err, codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_SCHEMA_PARSE_ERROR) + case errors.As(err, &sourceError): + return spiceerrors.WithCodeAndReason(err, codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_SCHEMA_PARSE_ERROR) + + case errors.Is(err, cursor.ErrHashMismatch): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_INVALID_CURSOR) + + case errors.As(err, &nsNotFoundError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_DEFINITION) + case errors.As(err, &relationNotFoundError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_RELATION_OR_PERMISSION) + + case errors.As(err, &maxDepthError): + if config == nil { + return spiceerrors.MustBugf("missing config for API error") + } + + _, isCheckRequest := maxDepthError.Request.(*dispatchv1.DispatchCheckRequest) + return NewMaxDepthExceededError(config.MaximumAPIDepth, isCheckRequest) + + case errors.As(err, &datastore.ReadOnlyError{}): + return ErrServiceReadOnly + case errors.As(err, &datastore.InvalidRevisionError{}): + return status.Errorf(codes.OutOfRange, "invalid zedtoken: %s", err) + case errors.As(err, &datastore.CaveatNameNotFoundError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_CAVEAT) + case errors.As(err, &datastore.WatchDisabledError{}): + return status.Errorf(codes.FailedPrecondition, "%s", err) + case errors.As(err, &datastore.CounterAlreadyRegisteredError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_COUNTER_ALREADY_REGISTERED) + case errors.As(err, &datastore.CounterNotRegisteredError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_COUNTER_NOT_REGISTERED) + + case errors.As(err, &graph.RelationMissingTypeInfoError{}): + return status.Errorf(codes.FailedPrecondition, "failed precondition: %s", err) + case errors.As(err, &graph.AlwaysFailError{}): + log.Ctx(ctx).Err(err).Msg("received internal error") + return status.Errorf(codes.Internal, "internal error: %s", err) + case errors.As(err, &graph.UnimplementedError{}): + return status.Errorf(codes.Unimplemented, "%s", err) + case errors.Is(err, context.DeadlineExceeded): + return status.Errorf(codes.DeadlineExceeded, "%s", err) + case errors.Is(err, context.Canceled): + err := context.Cause(ctx) + if err != nil { + if _, ok := status.FromError(err); ok { + return err + } + } + + return status.Errorf(codes.Canceled, "%s", err) + default: + log.Ctx(ctx).Err(err).Msg("received unexpected error") + return err + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go b/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go new file mode 100644 index 0000000..455de0a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go @@ -0,0 +1,52 @@ +package shared + +import ( + "google.golang.org/grpc" + + "github.com/authzed/spicedb/internal/middleware/servicespecific" +) + +// WithUnaryServiceSpecificInterceptor is a helper to add a unary interceptor or interceptor +// chain to a service. +type WithUnaryServiceSpecificInterceptor struct { + Unary grpc.UnaryServerInterceptor +} + +// UnaryInterceptor implements servicespecific.ExtraUnaryInterceptor +func (wussi WithUnaryServiceSpecificInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor { + return wussi.Unary +} + +// WithStreamServiceSpecificInterceptor is a helper to add a stream interceptor or interceptor +// chain to a service. +type WithStreamServiceSpecificInterceptor struct { + Stream grpc.StreamServerInterceptor +} + +// StreamInterceptor implements servicespecific.ExtraStreamInterceptor +func (wsssi WithStreamServiceSpecificInterceptor) StreamInterceptor() grpc.StreamServerInterceptor { + return wsssi.Stream +} + +// WithServiceSpecificInterceptors is a helper to add both a unary and stream interceptor +// or interceptor chain to a service. +type WithServiceSpecificInterceptors struct { + Unary grpc.UnaryServerInterceptor + Stream grpc.StreamServerInterceptor +} + +// UnaryInterceptor implements servicespecific.ExtraUnaryInterceptor +func (wssi WithServiceSpecificInterceptors) UnaryInterceptor() grpc.UnaryServerInterceptor { + return wssi.Unary +} + +// StreamInterceptor implements servicespecific.ExtraStreamInterceptor +func (wssi WithServiceSpecificInterceptors) StreamInterceptor() grpc.StreamServerInterceptor { + return wssi.Stream +} + +var ( + _ servicespecific.ExtraUnaryInterceptor = WithUnaryServiceSpecificInterceptor{} + _ servicespecific.ExtraUnaryInterceptor = WithServiceSpecificInterceptors{} + _ servicespecific.ExtraStreamInterceptor = WithServiceSpecificInterceptors{} +) diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go b/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go new file mode 100644 index 0000000..83accde --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go @@ -0,0 +1,474 @@ +package shared + +import ( + "context" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/namespace" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats" + nsdiff "github.com/authzed/spicedb/pkg/diff/namespace" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ValidatedSchemaChanges is a set of validated schema changes that can be applied to the datastore. +type ValidatedSchemaChanges struct { + compiled *compiler.CompiledSchema + validatedTypeSystems map[string]*schema.ValidatedDefinition + newCaveatDefNames *mapz.Set[string] + newObjectDefNames *mapz.Set[string] + additiveOnly bool +} + +// ValidateSchemaChanges validates the schema found in the compiled schema and returns a +// ValidatedSchemaChanges, if fully validated. +func ValidateSchemaChanges(ctx context.Context, compiled *compiler.CompiledSchema, caveatTypeSet *caveattypes.TypeSet, additiveOnly bool) (*ValidatedSchemaChanges, error) { + // 1) Validate the caveats defined. + newCaveatDefNames := mapz.NewSet[string]() + for _, caveatDef := range compiled.CaveatDefinitions { + if err := namespace.ValidateCaveatDefinition(caveatTypeSet, caveatDef); err != nil { + return nil, err + } + + newCaveatDefNames.Insert(caveatDef.Name) + } + + // 2) Validate the namespaces defined. + newObjectDefNames := mapz.NewSet[string]() + validatedTypeSystems := make(map[string]*schema.ValidatedDefinition, len(compiled.ObjectDefinitions)) + res := schema.ResolverForPredefinedDefinitions(schema.PredefinedElements{ + Definitions: compiled.ObjectDefinitions, + Caveats: compiled.CaveatDefinitions, + }) + ts := schema.NewTypeSystem(res) + + for _, nsdef := range compiled.ObjectDefinitions { + vts, err := ts.GetValidatedDefinition(ctx, nsdef.GetName()) + if err != nil { + return nil, err + } + + validatedTypeSystems[nsdef.Name] = vts + newObjectDefNames.Insert(nsdef.Name) + } + + return &ValidatedSchemaChanges{ + compiled: compiled, + validatedTypeSystems: validatedTypeSystems, + newCaveatDefNames: newCaveatDefNames, + newObjectDefNames: newObjectDefNames, + additiveOnly: additiveOnly, + }, nil +} + +// AppliedSchemaChanges holds information about the applied schema changes. +type AppliedSchemaChanges struct { + // TotalOperationCount holds the total number of "dispatch" operations performed by the schema + // being applied. + TotalOperationCount int + + // NewObjectDefNames contains the names of the newly added object definitions. + NewObjectDefNames []string + + // RemovedObjectDefNames contains the names of the removed object definitions. + RemovedObjectDefNames []string + + // NewCaveatDefNames contains the names of the newly added caveat definitions. + NewCaveatDefNames []string + + // RemovedCaveatDefNames contains the names of the removed caveat definitions. + RemovedCaveatDefNames []string +} + +// ApplySchemaChanges applies schema changes found in the validated changes struct, via the specified +// ReadWriteTransaction. +func ApplySchemaChanges(ctx context.Context, rwt datastore.ReadWriteTransaction, caveatTypeSet *caveattypes.TypeSet, validated *ValidatedSchemaChanges) (*AppliedSchemaChanges, error) { + existingCaveats, err := rwt.ListAllCaveats(ctx) + if err != nil { + return nil, err + } + + existingObjectDefs, err := rwt.ListAllNamespaces(ctx) + if err != nil { + return nil, err + } + + return ApplySchemaChangesOverExisting(ctx, rwt, caveatTypeSet, validated, datastore.DefinitionsOf(existingCaveats), datastore.DefinitionsOf(existingObjectDefs)) +} + +// ApplySchemaChangesOverExisting applies schema changes found in the validated changes struct, against +// existing caveat and object definitions given. +func ApplySchemaChangesOverExisting( + ctx context.Context, + rwt datastore.ReadWriteTransaction, + caveatTypeSet *caveattypes.TypeSet, + validated *ValidatedSchemaChanges, + existingCaveats []*core.CaveatDefinition, + existingObjectDefs []*core.NamespaceDefinition, +) (*AppliedSchemaChanges, error) { + // Build a map of existing caveats to determine those being removed, if any. + existingCaveatDefMap := make(map[string]*core.CaveatDefinition, len(existingCaveats)) + existingCaveatDefNames := mapz.NewSet[string]() + + for _, existingCaveat := range existingCaveats { + existingCaveatDefMap[existingCaveat.Name] = existingCaveat + existingCaveatDefNames.Insert(existingCaveat.Name) + } + + // For each caveat definition, perform a diff and ensure the changes will not result in type errors. + caveatDefsWithChanges := make([]*core.CaveatDefinition, 0, len(validated.compiled.CaveatDefinitions)) + for _, caveatDef := range validated.compiled.CaveatDefinitions { + diff, err := sanityCheckCaveatChanges(ctx, rwt, caveatTypeSet, caveatDef, existingCaveatDefMap) + if err != nil { + return nil, err + } + + if len(diff.Deltas()) > 0 { + caveatDefsWithChanges = append(caveatDefsWithChanges, caveatDef) + } + } + + removedCaveatDefNames := existingCaveatDefNames.Subtract(validated.newCaveatDefNames) + + // Build a map of existing definitions to determine those being removed, if any. + existingObjectDefMap := make(map[string]*core.NamespaceDefinition, len(existingObjectDefs)) + existingObjectDefNames := mapz.NewSet[string]() + for _, existingDef := range existingObjectDefs { + existingObjectDefMap[existingDef.Name] = existingDef + existingObjectDefNames.Insert(existingDef.Name) + } + + // For each definition, perform a diff and ensure the changes will not result in any + // breaking changes. + objectDefsWithChanges := make([]*core.NamespaceDefinition, 0, len(validated.compiled.ObjectDefinitions)) + for _, nsdef := range validated.compiled.ObjectDefinitions { + diff, err := sanityCheckNamespaceChanges(ctx, rwt, nsdef, existingObjectDefMap) + if err != nil { + return nil, err + } + + if len(diff.Deltas()) > 0 { + objectDefsWithChanges = append(objectDefsWithChanges, nsdef) + + vts, ok := validated.validatedTypeSystems[nsdef.Name] + if !ok { + return nil, spiceerrors.MustBugf("validated type system not found for namespace `%s`", nsdef.Name) + } + + if err := namespace.AnnotateNamespace(vts); err != nil { + return nil, err + } + } + } + + log.Ctx(ctx). + Trace(). + Int("objectDefinitions", len(validated.compiled.ObjectDefinitions)). + Int("caveatDefinitions", len(validated.compiled.CaveatDefinitions)). + Int("objectDefsWithChanges", len(objectDefsWithChanges)). + Int("caveatDefsWithChanges", len(caveatDefsWithChanges)). + Msg("validated namespace definitions") + + // Ensure that deleting namespaces will not result in any relationships left without associated + // schema. + removedObjectDefNames := existingObjectDefNames.Subtract(validated.newObjectDefNames) + if !validated.additiveOnly { + if err := removedObjectDefNames.ForEach(func(nsdefName string) error { + return ensureNoRelationshipsExist(ctx, rwt, nsdefName) + }); err != nil { + return nil, err + } + } + + // Write the new/changes caveats. + if len(caveatDefsWithChanges) > 0 { + if err := rwt.WriteCaveats(ctx, caveatDefsWithChanges); err != nil { + return nil, err + } + } + + // Write the new/changed namespaces. + if len(objectDefsWithChanges) > 0 { + if err := rwt.WriteNamespaces(ctx, objectDefsWithChanges...); err != nil { + return nil, err + } + } + + if !validated.additiveOnly { + // Delete the removed namespaces. + if removedObjectDefNames.Len() > 0 { + if err := rwt.DeleteNamespaces(ctx, removedObjectDefNames.AsSlice()...); err != nil { + return nil, err + } + } + + // Delete the removed caveats. + if !removedCaveatDefNames.IsEmpty() { + if err := rwt.DeleteCaveats(ctx, removedCaveatDefNames.AsSlice()); err != nil { + return nil, err + } + } + } + + log.Ctx(ctx).Trace(). + Interface("objectDefinitions", validated.compiled.ObjectDefinitions). + Interface("caveatDefinitions", validated.compiled.CaveatDefinitions). + Object("addedOrChangedObjectDefinitions", validated.newObjectDefNames). + Object("removedObjectDefinitions", removedObjectDefNames). + Object("addedOrChangedCaveatDefinitions", validated.newCaveatDefNames). + Object("removedCaveatDefinitions", removedCaveatDefNames). + Msg("completed schema update") + + return &AppliedSchemaChanges{ + TotalOperationCount: len(validated.compiled.ObjectDefinitions) + len(validated.compiled.CaveatDefinitions) + removedObjectDefNames.Len() + removedCaveatDefNames.Len(), + NewObjectDefNames: validated.newObjectDefNames.Subtract(existingObjectDefNames).AsSlice(), + RemovedObjectDefNames: removedObjectDefNames.AsSlice(), + NewCaveatDefNames: validated.newCaveatDefNames.Subtract(existingCaveatDefNames).AsSlice(), + RemovedCaveatDefNames: removedCaveatDefNames.AsSlice(), + }, nil +} + +// sanityCheckCaveatChanges ensures that a caveat definition being written does not break +// the types of the parameters that may already exist on relationships. +func sanityCheckCaveatChanges( + _ context.Context, + _ datastore.ReadWriteTransaction, + caveatTypeSet *caveattypes.TypeSet, + caveatDef *core.CaveatDefinition, + existingDefs map[string]*core.CaveatDefinition, +) (*caveatdiff.Diff, error) { + // Ensure that the updated namespace does not break the existing tuple data. + existing := existingDefs[caveatDef.Name] + diff, err := caveatdiff.DiffCaveats(existing, caveatDef, caveatTypeSet) + if err != nil { + return nil, err + } + + for _, delta := range diff.Deltas() { + switch delta.Type { + case caveatdiff.RemovedParameter: + return diff, NewSchemaWriteDataValidationError("cannot remove parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name) + + case caveatdiff.ParameterTypeChanged: + return diff, NewSchemaWriteDataValidationError("cannot change the type of parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name) + } + } + + return diff, nil +} + +// ensureNoRelationshipsExist ensures that no relationships exist within the namespace with the given name. +func ensureNoRelationshipsExist(ctx context.Context, rwt datastore.ReadWriteTransaction, namespaceName string) error { + qy, qyErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{OptionalResourceType: namespaceName}, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceOfType), + ) + if err := errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete object definition `%s`, as a relationship exists under it", + namespaceName, + ); err != nil { + return err + } + + qy, qyErr = rwt.ReverseQueryRelationships( + ctx, + datastore.SubjectsFilter{ + SubjectType: namespaceName, + }, + options.WithLimitForReverse(options.LimitOne), + options.WithQueryShapeForReverse(queryshape.FindSubjectOfType), + ) + err := errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete object definition `%s`, as a relationship references it", + namespaceName, + ) + if err != nil { + return err + } + + return nil +} + +// sanityCheckNamespaceChanges ensures that a namespace definition being written does not result +// in breaking changes, such as relationships without associated defined schema object definitions +// and relations. +func sanityCheckNamespaceChanges( + ctx context.Context, + rwt datastore.ReadWriteTransaction, + nsdef *core.NamespaceDefinition, + existingDefs map[string]*core.NamespaceDefinition, +) (*nsdiff.Diff, error) { + // Ensure that the updated namespace does not break the existing tuple data. + existing := existingDefs[nsdef.Name] + diff, err := nsdiff.DiffNamespaces(existing, nsdef) + if err != nil { + return nil, err + } + + for _, delta := range diff.Deltas() { + switch delta.Type { + case nsdiff.RemovedRelation: + // NOTE: We add the subject filters here to ensure the reverse relationship index is used + // by the datastores. As there is no index that has {namespace, relation} directly, but there + // *is* an index that has {subject_namespace, subject_relation, namespace, relation}, we can + // force the datastore to use the reverse index by adding the subject filters. + var previousRelation *core.Relation + for _, relation := range existing.Relation { + if relation.Name == delta.RelationName { + previousRelation = relation + break + } + } + + if previousRelation == nil { + return nil, spiceerrors.MustBugf("relation `%s` not found in existing namespace definition", delta.RelationName) + } + + subjectSelectors := make([]datastore.SubjectsSelector, 0, len(previousRelation.TypeInformation.AllowedDirectRelations)) + for _, allowedType := range previousRelation.TypeInformation.AllowedDirectRelations { + if allowedType.GetRelation() == datastore.Ellipsis { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: allowedType.Namespace, + RelationFilter: datastore.SubjectRelationFilter{ + IncludeEllipsisRelation: true, + }, + }) + } else { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: allowedType.Namespace, + RelationFilter: datastore.SubjectRelationFilter{ + NonEllipsisRelation: allowedType.GetRelation(), + }, + }) + } + } + + qy, qyErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{ + OptionalResourceType: nsdef.Name, + OptionalResourceRelation: delta.RelationName, + OptionalSubjectsSelectors: subjectSelectors, + }, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceOfTypeAndRelation), + ) + + err = errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete relation `%s` in object definition `%s`, as a relationship exists under it", delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + + // Also check for right sides of tuples. + qy, qyErr = rwt.ReverseQueryRelationships( + ctx, + datastore.SubjectsFilter{ + SubjectType: nsdef.Name, + RelationFilter: datastore.SubjectRelationFilter{ + NonEllipsisRelation: delta.RelationName, + }, + }, + options.WithLimitForReverse(options.LimitOne), + options.WithQueryShapeForReverse(queryshape.FindSubjectOfTypeAndRelation), + ) + err = errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete relation `%s` in object definition `%s`, as a relationship references it", delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + + case nsdiff.RelationAllowedTypeRemoved: + var optionalSubjectIds []string + var relationFilter datastore.SubjectRelationFilter + var optionalCaveatNameFilter datastore.CaveatNameFilter + + if delta.AllowedType.GetPublicWildcard() != nil { + optionalSubjectIds = []string{tuple.PublicWildcard} + } else { + relationFilter = datastore.SubjectRelationFilter{ + NonEllipsisRelation: delta.AllowedType.GetRelation(), + } + } + + if delta.AllowedType.GetRequiredCaveat() != nil && delta.AllowedType.GetRequiredCaveat().CaveatName != "" { + optionalCaveatNameFilter = datastore.WithCaveatName(delta.AllowedType.GetRequiredCaveat().CaveatName) + } else { + optionalCaveatNameFilter = datastore.WithNoCaveat() + } + + expirationOption := datastore.ExpirationFilterOptionNoExpiration + if delta.AllowedType.RequiredExpiration != nil { + expirationOption = datastore.ExpirationFilterOptionHasExpiration + } + + qyr, qyrErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{ + OptionalResourceType: nsdef.Name, + OptionalResourceRelation: delta.RelationName, + OptionalSubjectsSelectors: []datastore.SubjectsSelector{ + { + OptionalSubjectType: delta.AllowedType.Namespace, + OptionalSubjectIds: optionalSubjectIds, + RelationFilter: relationFilter, + }, + }, + OptionalCaveatNameFilter: optionalCaveatNameFilter, + OptionalExpirationOption: expirationOption, + }, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceRelationForSubjectRelation), + ) + err = errorIfTupleIteratorReturnsTuples( + ctx, + qyr, + qyrErr, + "cannot remove allowed type `%s` from relation `%s` in object definition `%s`, as a relationship exists with it", + schema.SourceForAllowedRelation(delta.AllowedType), delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + } + } + return diff, nil +} + +// errorIfTupleIteratorReturnsTuples takes a tuple iterator and any error that was generated +// when the original iterator was created, and returns an error if iterator contains any tuples. +func errorIfTupleIteratorReturnsTuples(_ context.Context, qy datastore.RelationshipIterator, qyErr error, message string, args ...interface{}) error { + if qyErr != nil { + return qyErr + } + + for _, err := range qy { + if err != nil { + return err + } + return NewSchemaWriteDataValidationError(message, args...) + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go b/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go new file mode 100644 index 0000000..819452e --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go @@ -0,0 +1,332 @@ +package v1 + +import ( + "context" + "slices" + "sync" + "time" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/jzelinskie/stringz" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph" + "github.com/authzed/spicedb/internal/graph/computed" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/internal/taskrunner" + "github.com/authzed/spicedb/internal/telemetry" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/genutil/slicez" + "github.com/authzed/spicedb/pkg/middleware/consistency" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// bulkChecker contains the logic to allow ExperimentalService/BulkCheckPermission and +// PermissionsService/CheckBulkPermissions to share the same implementation. +type bulkChecker struct { + maxAPIDepth uint32 + maxCaveatContextSize int + maxConcurrency uint16 + caveatTypeSet *caveattypes.TypeSet + + dispatch dispatch.Dispatcher + dispatchChunkSize uint16 +} + +const maxBulkCheckCount = 10000 + +func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) { + telemetry.RecordLogicalChecks(uint64(len(req.Items))) + + atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, err + } + + if len(req.Items) > maxBulkCheckCount { + return nil, NewExceedsMaximumChecksErr(uint64(len(req.Items)), maxBulkCheckCount) + } + + // Compute a hash for each requested item and record its index(es) for the items, to be used for sorting of results. + itemCount, err := genutil.EnsureUInt32(len(req.Items)) + if err != nil { + return nil, err + } + + itemIndexByHash := mapz.NewMultiMapWithCap[string, int](itemCount) + for index, item := range req.Items { + itemHash, err := computeCheckBulkPermissionsItemHash(item) + if err != nil { + return nil, err + } + + itemIndexByHash.Add(itemHash, index) + } + + // Identify checks with same permission+subject over different resources and group them. This is doable because + // the dispatching system already internally supports this kind of batching for performance. + groupedItems, err := groupItems(ctx, groupingParameters{ + atRevision: atRevision, + maxCaveatContextSize: bc.maxCaveatContextSize, + maximumAPIDepth: bc.maxAPIDepth, + withTracing: req.WithTracing, + }, req.Items) + if err != nil { + return nil, err + } + + bulkResponseMutex := sync.Mutex{} + + spiceerrors.DebugAssert(func() bool { + return bc.maxConcurrency > 0 + }, "max concurrency must be greater than 0 in bulk check") + + tr := taskrunner.NewPreloadedTaskRunner(ctx, bc.maxConcurrency, len(groupedItems)) + + respMetadata := &dispatchv1.ResponseMeta{ + DispatchCount: 1, + CachedDispatchCount: 0, + DepthRequired: 1, + DebugInfo: nil, + } + usagemetrics.SetInContext(ctx, respMetadata) + + orderedPairs := make([]*v1.CheckBulkPermissionsPair, len(req.Items)) + + addPair := func(pair *v1.CheckBulkPermissionsPair) error { + pairItemHash, err := computeCheckBulkPermissionsItemHash(pair.Request) + if err != nil { + return err + } + + found, ok := itemIndexByHash.Get(pairItemHash) + if !ok { + return spiceerrors.MustBugf("missing expected item hash") + } + + for _, index := range found { + orderedPairs[index] = pair + } + + return nil + } + + appendResultsForError := func(params *computed.CheckParameters, resourceIDs []string, err error) error { + rewritten := shared.RewriteError(ctx, err, &shared.ConfigForErrors{ + MaximumAPIDepth: bc.maxAPIDepth, + }) + statusResp, ok := status.FromError(rewritten) + if !ok { + // If error is not a gRPC Status, fail the entire bulk check request. + return err + } + + bulkResponseMutex.Lock() + defer bulkResponseMutex.Unlock() + + for _, resourceID := range resourceIDs { + reqItem, err := requestItemFromResourceAndParameters(params, resourceID) + if err != nil { + return err + } + + if err := addPair(&v1.CheckBulkPermissionsPair{ + Request: reqItem, + Response: &v1.CheckBulkPermissionsPair_Error{ + Error: statusResp.Proto(), + }, + }); err != nil { + return err + } + } + + return nil + } + + appendResultsForCheck := func( + params *computed.CheckParameters, + resourceIDs []string, + metadata *dispatchv1.ResponseMeta, + debugInfos []*dispatchv1.DebugInformation, + results map[string]*dispatchv1.ResourceCheckResult, + ) error { + bulkResponseMutex.Lock() + defer bulkResponseMutex.Unlock() + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + schemaText := "" + if len(debugInfos) > 0 { + schema, err := getFullSchema(ctx, ds) + if err != nil { + return err + } + schemaText = schema + } + + for _, resourceID := range resourceIDs { + var debugTrace *v1.DebugInformation + if len(debugInfos) > 0 { + // Find the debug info that matches the resource ID. + var debugInfo *dispatchv1.DebugInformation + for _, di := range debugInfos { + if slices.Contains(di.Check.Request.ResourceIds, resourceID) { + debugInfo = di + break + } + } + + if debugInfo != nil { + // Synthesize a new debug information with a trace "wrapping" the (potentially batched) + // trace. + localResults := make(map[string]*dispatchv1.ResourceCheckResult, 1) + if result, ok := results[resourceID]; ok { + localResults[resourceID] = result + } + wrappedDebugInfo := &dispatchv1.DebugInformation{ + Check: &dispatchv1.CheckDebugTrace{ + Request: &dispatchv1.DispatchCheckRequest{ + ResourceRelation: debugInfo.Check.Request.ResourceRelation, + ResourceIds: []string{resourceID}, + Subject: debugInfo.Check.Request.Subject, + ResultsSetting: debugInfo.Check.Request.ResultsSetting, + Debug: debugInfo.Check.Request.Debug, + }, + ResourceRelationType: debugInfo.Check.ResourceRelationType, + IsCachedResult: false, + SubProblems: []*dispatchv1.CheckDebugTrace{ + debugInfo.Check, + }, + Results: localResults, + Duration: durationpb.New(time.Duration(0)), + TraceId: graph.NewTraceID(), + SourceId: debugInfo.Check.SourceId, + }, + } + + // Convert to debug information. + dt, err := convertCheckDispatchDebugInformationWithSchema(ctx, params.CaveatContext, wrappedDebugInfo, ds, bc.caveatTypeSet, schemaText) + if err != nil { + return err + } + debugTrace = dt + } + } + + reqItem, err := requestItemFromResourceAndParameters(params, resourceID) + if err != nil { + return err + } + + if err := addPair(&v1.CheckBulkPermissionsPair{ + Request: reqItem, + Response: pairItemFromCheckResult(results[resourceID], debugTrace), + }); err != nil { + return err + } + } + + respMetadata.DispatchCount += metadata.DispatchCount + respMetadata.CachedDispatchCount += metadata.CachedDispatchCount + return nil + } + + for _, group := range groupedItems { + group := group + + slicez.ForEachChunk(group.resourceIDs, bc.dispatchChunkSize, func(resourceIDs []string) { + tr.Add(func(ctx context.Context) error { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + // Ensure the check namespaces and relations are valid. + err := namespace.CheckNamespaceAndRelations(ctx, + []namespace.TypeAndRelationToCheck{ + { + NamespaceName: group.params.ResourceType.ObjectType, + RelationName: group.params.ResourceType.Relation, + AllowEllipsis: false, + }, + { + NamespaceName: group.params.Subject.ObjectType, + RelationName: stringz.DefaultEmpty(group.params.Subject.Relation, graph.Ellipsis), + AllowEllipsis: true, + }, + }, ds) + if err != nil { + return appendResultsForError(group.params, resourceIDs, err) + } + + // Call bulk check to compute the check result(s) for the resource ID(s). + rcr, metadata, debugInfos, err := computed.ComputeBulkCheck(ctx, bc.dispatch, bc.caveatTypeSet, *group.params, resourceIDs, bc.dispatchChunkSize) + if err != nil { + return appendResultsForError(group.params, resourceIDs, err) + } + + return appendResultsForCheck(group.params, resourceIDs, metadata, debugInfos, rcr) + }) + }) + } + + // Run the checks in parallel. + if err := tr.StartAndWait(); err != nil { + return nil, err + } + + return &v1.CheckBulkPermissionsResponse{CheckedAt: checkedAt, Pairs: orderedPairs}, nil +} + +func toCheckBulkPermissionsRequest(req *v1.BulkCheckPermissionRequest) *v1.CheckBulkPermissionsRequest { + items := make([]*v1.CheckBulkPermissionsRequestItem, len(req.Items)) + for i, item := range req.Items { + items[i] = &v1.CheckBulkPermissionsRequestItem{ + Resource: item.Resource, + Permission: item.Permission, + Subject: item.Subject, + Context: item.Context, + } + } + + return &v1.CheckBulkPermissionsRequest{Items: items} +} + +func toBulkCheckPermissionResponse(resp *v1.CheckBulkPermissionsResponse) *v1.BulkCheckPermissionResponse { + pairs := make([]*v1.BulkCheckPermissionPair, len(resp.Pairs)) + for i, pair := range resp.Pairs { + pairs[i] = &v1.BulkCheckPermissionPair{} + pairs[i].Request = &v1.BulkCheckPermissionRequestItem{ + Resource: pair.Request.Resource, + Permission: pair.Request.Permission, + Subject: pair.Request.Subject, + Context: pair.Request.Context, + } + + switch t := pair.Response.(type) { + case *v1.CheckBulkPermissionsPair_Item: + pairs[i].Response = &v1.BulkCheckPermissionPair_Item{ + Item: &v1.BulkCheckPermissionResponseItem{ + Permissionship: t.Item.Permissionship, + PartialCaveatInfo: t.Item.PartialCaveatInfo, + }, + } + case *v1.CheckBulkPermissionsPair_Error: + pairs[i].Response = &v1.BulkCheckPermissionPair_Error{ + Error: t.Error, + } + default: + panic("unknown CheckBulkPermissionResponse pair response type") + } + } + + return &v1.BulkCheckPermissionResponse{ + CheckedAt: resp.CheckedAt, + Pairs: pairs, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go b/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go new file mode 100644 index 0000000..712f9ec --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go @@ -0,0 +1,238 @@ +package v1 + +import ( + "cmp" + "context" + "slices" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + cexpr "github.com/authzed/spicedb/internal/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ConvertCheckDispatchDebugInformation converts dispatch debug information found in the response metadata +// into DebugInformation returnable to the API. +func ConvertCheckDispatchDebugInformation( + ctx context.Context, + caveatTypeSet *caveattypes.TypeSet, + caveatContext map[string]any, + debugInfo *dispatch.DebugInformation, + reader datastore.Reader, +) (*v1.DebugInformation, error) { + if debugInfo == nil { + return nil, nil + } + + schema, err := getFullSchema(ctx, reader) + if err != nil { + return nil, err + } + + return convertCheckDispatchDebugInformationWithSchema(ctx, caveatContext, debugInfo, reader, caveatTypeSet, schema) +} + +// getFullSchema returns the full schema from the reader. +func getFullSchema(ctx context.Context, reader datastore.Reader) (string, error) { + caveats, err := reader.ListAllCaveats(ctx) + if err != nil { + return "", err + } + + namespaces, err := reader.ListAllNamespaces(ctx) + if err != nil { + return "", err + } + + defs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, caveat := range caveats { + defs = append(defs, caveat.Definition) + } + for _, ns := range namespaces { + defs = append(defs, ns.Definition) + } + + schema, _, err := generator.GenerateSchema(defs) + if err != nil { + return "", err + } + + return schema, nil +} + +func convertCheckDispatchDebugInformationWithSchema( + ctx context.Context, + caveatContext map[string]any, + debugInfo *dispatch.DebugInformation, + reader datastore.Reader, + caveatTypeSet *caveattypes.TypeSet, + schema string, +) (*v1.DebugInformation, error) { + converted, err := convertCheckTrace(ctx, caveatContext, debugInfo.Check, reader, caveatTypeSet) + if err != nil { + return nil, err + } + + return &v1.DebugInformation{ + Check: converted, + SchemaUsed: strings.TrimSpace(schema), + }, nil +} + +func convertCheckTrace(ctx context.Context, caveatContext map[string]any, ct *dispatch.CheckDebugTrace, reader datastore.Reader, caveatTypeSet *caveattypes.TypeSet) (*v1.CheckDebugTrace, error) { + permissionType := v1.CheckDebugTrace_PERMISSION_TYPE_UNSPECIFIED + if ct.ResourceRelationType == dispatch.CheckDebugTrace_PERMISSION { + permissionType = v1.CheckDebugTrace_PERMISSION_TYPE_PERMISSION + } else if ct.ResourceRelationType == dispatch.CheckDebugTrace_RELATION { + permissionType = v1.CheckDebugTrace_PERMISSION_TYPE_RELATION + } + + subRelation := ct.Request.Subject.Relation + if subRelation == tuple.Ellipsis { + subRelation = "" + } + + permissionship := v1.CheckDebugTrace_PERMISSIONSHIP_NO_PERMISSION + var partialResults []*dispatch.ResourceCheckResult + for _, checkResult := range ct.Results { + if checkResult.Membership == dispatch.ResourceCheckResult_MEMBER { + permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_HAS_PERMISSION + break + } + + if checkResult.Membership == dispatch.ResourceCheckResult_CAVEATED_MEMBER && permissionship != v1.CheckDebugTrace_PERMISSIONSHIP_HAS_PERMISSION { + permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION + partialResults = append(partialResults, checkResult) + } + } + + var caveatEvalInfo *v1.CaveatEvalInfo + + // NOTE: Bulk check gives the *fully resolved* results, rather than the result pre-caveat + // evaluation. In that case, we skip re-evaluating here. + // TODO(jschorr): Add support for evaluating *each* result distinctly. + if permissionship == v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION && len(partialResults) == 1 && + len(partialResults[0].MissingExprFields) == 0 { + partialCheckResult := partialResults[0] + spiceerrors.DebugAssertNotNil(partialCheckResult.Expression, "got nil caveat expression") + + computedResult, err := cexpr.RunSingleCaveatExpression(ctx, caveatTypeSet, partialCheckResult.Expression, caveatContext, reader, cexpr.RunCaveatExpressionWithDebugInformation) + if err != nil { + return nil, err + } + + var partialCaveatInfo *v1.PartialCaveatInfo + caveatResult := v1.CaveatEvalInfo_RESULT_FALSE + if computedResult.Value() { + caveatResult = v1.CaveatEvalInfo_RESULT_TRUE + } else if computedResult.IsPartial() { + caveatResult = v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT + missingNames, _ := computedResult.MissingVarNames() + partialCaveatInfo = &v1.PartialCaveatInfo{ + MissingRequiredContext: missingNames, + } + } + + exprString, contextStruct, err := cexpr.BuildDebugInformation(computedResult) + if err != nil { + return nil, err + } + + caveatName := "" + if partialCheckResult.Expression.GetCaveat() != nil { + caveatName = partialCheckResult.Expression.GetCaveat().CaveatName + } + + caveatEvalInfo = &v1.CaveatEvalInfo{ + Expression: exprString, + Result: caveatResult, + Context: contextStruct, + PartialCaveatInfo: partialCaveatInfo, + CaveatName: caveatName, + } + } + + // If there is more than a single result, mark the overall permissionship + // as unspecified if *all* results needed to be true and at least one is not. + if len(ct.Request.ResourceIds) > 1 && ct.Request.ResultsSetting == dispatch.DispatchCheckRequest_REQUIRE_ALL_RESULTS { + for _, resourceID := range ct.Request.ResourceIds { + if result, ok := ct.Results[resourceID]; !ok || result.Membership != dispatch.ResourceCheckResult_MEMBER { + permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_UNSPECIFIED + break + } + } + } + + if len(ct.SubProblems) > 0 { + subProblems := make([]*v1.CheckDebugTrace, 0, len(ct.SubProblems)) + for _, subProblem := range ct.SubProblems { + converted, err := convertCheckTrace(ctx, caveatContext, subProblem, reader, caveatTypeSet) + if err != nil { + return nil, err + } + + subProblems = append(subProblems, converted) + } + + slices.SortFunc(subProblems, func(a, b *v1.CheckDebugTrace) int { + return cmp.Compare(tuple.V1StringObjectRef(a.Resource), tuple.V1StringObjectRef(a.Resource)) + }) + + return &v1.CheckDebugTrace{ + TraceOperationId: ct.TraceId, + Resource: &v1.ObjectReference{ + ObjectType: ct.Request.ResourceRelation.Namespace, + ObjectId: strings.Join(ct.Request.ResourceIds, ","), + }, + Permission: ct.Request.ResourceRelation.Relation, + PermissionType: permissionType, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: ct.Request.Subject.Namespace, + ObjectId: ct.Request.Subject.ObjectId, + }, + OptionalRelation: subRelation, + }, + CaveatEvaluationInfo: caveatEvalInfo, + Result: permissionship, + Resolution: &v1.CheckDebugTrace_SubProblems_{ + SubProblems: &v1.CheckDebugTrace_SubProblems{ + Traces: subProblems, + }, + }, + Duration: ct.Duration, + Source: ct.SourceId, + }, nil + } + + return &v1.CheckDebugTrace{ + TraceOperationId: ct.TraceId, + Resource: &v1.ObjectReference{ + ObjectType: ct.Request.ResourceRelation.Namespace, + ObjectId: strings.Join(ct.Request.ResourceIds, ","), + }, + Permission: ct.Request.ResourceRelation.Relation, + PermissionType: permissionType, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: ct.Request.Subject.Namespace, + ObjectId: ct.Request.Subject.ObjectId, + }, + OptionalRelation: subRelation, + }, + CaveatEvaluationInfo: caveatEvalInfo, + Result: permissionship, + Resolution: &v1.CheckDebugTrace_WasCachedResult{ + WasCachedResult: ct.IsCachedResult, + }, + Duration: ct.Duration, + Source: ct.SourceId, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go b/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go new file mode 100644 index 0000000..6de6749 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go @@ -0,0 +1,511 @@ +package v1 + +import ( + "fmt" + "strconv" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ExceedsMaximumLimitError occurs when a limit that is too large is given to a call. +type ExceedsMaximumLimitError struct { + error + providedLimit uint64 + maxLimitAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ExceedsMaximumLimitError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("providedLimit", err.providedLimit).Uint64("maxLimitAllowed", err.maxLimitAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ExceedsMaximumLimitError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_EXCEEDS_MAXIMUM_ALLOWABLE_LIMIT, + map[string]string{ + "limit_provided": strconv.FormatUint(err.providedLimit, 10), + "maximum_limit_allowed": strconv.FormatUint(err.maxLimitAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumLimitErr creates a new error representing that the limit specified was too large. +func NewExceedsMaximumLimitErr(providedLimit uint64, maxLimitAllowed uint64) ExceedsMaximumLimitError { + return ExceedsMaximumLimitError{ + error: fmt.Errorf("provided limit %d is greater than maximum allowed of %d", providedLimit, maxLimitAllowed), + providedLimit: providedLimit, + maxLimitAllowed: maxLimitAllowed, + } +} + +// ExceedsMaximumChecksError occurs when too many checks are given to a call. +type ExceedsMaximumChecksError struct { + error + checkCount uint64 + maxCountAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ExceedsMaximumChecksError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("checkCount", err.checkCount).Uint64("maxCountAllowed", err.maxCountAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ExceedsMaximumChecksError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{ + "check_count": strconv.FormatUint(err.checkCount, 10), + "maximum_checks_allowed": strconv.FormatUint(err.maxCountAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumChecksErr creates a new error representing that too many updates were given to a BulkCheckPermissions call. +func NewExceedsMaximumChecksErr(checkCount uint64, maxCountAllowed uint64) ExceedsMaximumChecksError { + return ExceedsMaximumChecksError{ + error: fmt.Errorf("check count of %d is greater than maximum allowed of %d", checkCount, maxCountAllowed), + checkCount: checkCount, + maxCountAllowed: maxCountAllowed, + } +} + +// ExceedsMaximumUpdatesError occurs when too many updates are given to a call. +type ExceedsMaximumUpdatesError struct { + error + updateCount uint64 + maxCountAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ExceedsMaximumUpdatesError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("updateCount", err.updateCount).Uint64("maxCountAllowed", err.maxCountAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ExceedsMaximumUpdatesError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_TOO_MANY_UPDATES_IN_REQUEST, + map[string]string{ + "update_count": strconv.FormatUint(err.updateCount, 10), + "maximum_updates_allowed": strconv.FormatUint(err.maxCountAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumUpdatesErr creates a new error representing that too many updates were given to a WriteRelationships call. +func NewExceedsMaximumUpdatesErr(updateCount uint64, maxCountAllowed uint64) ExceedsMaximumUpdatesError { + return ExceedsMaximumUpdatesError{ + error: fmt.Errorf("update count of %d is greater than maximum allowed of %d", updateCount, maxCountAllowed), + updateCount: updateCount, + maxCountAllowed: maxCountAllowed, + } +} + +// ExceedsMaximumPreconditionsError occurs when too many preconditions are given to a call. +type ExceedsMaximumPreconditionsError struct { + error + preconditionCount uint64 + maxCountAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ExceedsMaximumPreconditionsError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("preconditionCount", err.preconditionCount).Uint64("maxCountAllowed", err.maxCountAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ExceedsMaximumPreconditionsError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_TOO_MANY_PRECONDITIONS_IN_REQUEST, + map[string]string{ + "precondition_count": strconv.FormatUint(err.preconditionCount, 10), + "maximum_updates_allowed": strconv.FormatUint(err.maxCountAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumPreconditionsErr creates a new error representing that too many preconditions were given to a call. +func NewExceedsMaximumPreconditionsErr(preconditionCount uint64, maxCountAllowed uint64) ExceedsMaximumPreconditionsError { + return ExceedsMaximumPreconditionsError{ + error: fmt.Errorf( + "precondition count of %d is greater than maximum allowed of %d", + preconditionCount, + maxCountAllowed), + preconditionCount: preconditionCount, + maxCountAllowed: maxCountAllowed, + } +} + +// PreconditionFailedError occurs when the precondition to a write tuple call does not match. +type PreconditionFailedError struct { + error + precondition *v1.Precondition +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err PreconditionFailedError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Interface("precondition", err.precondition) +} + +// NewPreconditionFailedErr constructs a new precondition failed error. +func NewPreconditionFailedErr(precondition *v1.Precondition) error { + return PreconditionFailedError{ + error: fmt.Errorf("unable to satisfy write precondition `%s`", precondition), + precondition: precondition, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err PreconditionFailedError) GRPCStatus() *status.Status { + metadata := map[string]string{ + "precondition_operation": v1.Precondition_Operation_name[int32(err.precondition.Operation)], + } + + if err.precondition.Filter.ResourceType != "" { + metadata["precondition_resource_type"] = err.precondition.Filter.ResourceType + } + + if err.precondition.Filter.OptionalResourceId != "" { + metadata["precondition_resource_id"] = err.precondition.Filter.OptionalResourceId + } + + if err.precondition.Filter.OptionalResourceIdPrefix != "" { + metadata["precondition_resource_id_prefix"] = err.precondition.Filter.OptionalResourceIdPrefix + } + + if err.precondition.Filter.OptionalRelation != "" { + metadata["precondition_relation"] = err.precondition.Filter.OptionalRelation + } + + if err.precondition.Filter.OptionalSubjectFilter != nil { + metadata["precondition_subject_type"] = err.precondition.Filter.OptionalSubjectFilter.SubjectType + + if err.precondition.Filter.OptionalSubjectFilter.OptionalSubjectId != "" { + metadata["precondition_subject_id"] = err.precondition.Filter.OptionalSubjectFilter.OptionalSubjectId + } + + if err.precondition.Filter.OptionalSubjectFilter.OptionalRelation != nil { + metadata["precondition_subject_relation"] = err.precondition.Filter.OptionalSubjectFilter.OptionalRelation.Relation + } + } + + return spiceerrors.WithCodeAndDetails( + err, + codes.FailedPrecondition, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_WRITE_OR_DELETE_PRECONDITION_FAILURE, + metadata, + ), + ) +} + +// DuplicateRelationErrorshipError indicates that an update was attempted on the same relationship. +type DuplicateRelationErrorshipError struct { + error + update *v1.RelationshipUpdate +} + +// NewDuplicateRelationshipErr constructs a new invalid subject error. +func NewDuplicateRelationshipErr(update *v1.RelationshipUpdate) DuplicateRelationErrorshipError { + return DuplicateRelationErrorshipError{ + error: fmt.Errorf( + "found more than one update with relationship `%s` in this request; a relationship can only be specified in an update once per overall WriteRelationships request", + tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship), + ), + update: update, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err DuplicateRelationErrorshipError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UPDATES_ON_SAME_RELATIONSHIP, + map[string]string{ + "definition_name": err.update.Relationship.Resource.ObjectType, + "relationship": tuple.MustV1StringRelationship(err.update.Relationship), + }, + ), + ) +} + +// ErrMaxRelationshipContextError indicates an attempt to write a relationship that exceeded the maximum +// configured context size. +type ErrMaxRelationshipContextError struct { + error + update *v1.RelationshipUpdate + maxAllowedSize int +} + +// NewMaxRelationshipContextError constructs a new max relationship context error. +func NewMaxRelationshipContextError(update *v1.RelationshipUpdate, maxAllowedSize int) ErrMaxRelationshipContextError { + return ErrMaxRelationshipContextError{ + error: fmt.Errorf( + "provided relationship `%s` exceeded maximum allowed caveat size of %d", + tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship), + maxAllowedSize, + ), + update: update, + maxAllowedSize: maxAllowedSize, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ErrMaxRelationshipContextError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_MAX_RELATIONSHIP_CONTEXT_SIZE, + map[string]string{ + "relationship": tuple.V1StringRelationshipWithoutCaveatOrExpiration(err.update.Relationship), + "max_allowed_size": strconv.Itoa(err.maxAllowedSize), + "context_size": strconv.Itoa(proto.Size(err.update.Relationship)), + }, + ), + ) +} + +// CouldNotTransactionallyDeleteError indicates that a deletion could not occur transactionally. +type CouldNotTransactionallyDeleteError struct { + error + limit uint32 + filter *v1.RelationshipFilter +} + +// NewCouldNotTransactionallyDeleteErr constructs a new could not transactionally deleter error. +func NewCouldNotTransactionallyDeleteErr(filter *v1.RelationshipFilter, limit uint32) CouldNotTransactionallyDeleteError { + return CouldNotTransactionallyDeleteError{ + error: fmt.Errorf( + "found more than %d relationships to be deleted and partial deletion was not requested", + limit, + ), + limit: limit, + filter: filter, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err CouldNotTransactionallyDeleteError) GRPCStatus() *status.Status { + metadata := map[string]string{ + "limit": strconv.Itoa(int(err.limit)), + "filter_resource_type": err.filter.ResourceType, + } + + if err.filter.OptionalResourceId != "" { + metadata["filter_resource_id"] = err.filter.OptionalResourceId + } + + if err.filter.OptionalRelation != "" { + metadata["filter_relation"] = err.filter.OptionalRelation + } + + if err.filter.OptionalSubjectFilter != nil { + metadata["filter_subject_type"] = err.filter.OptionalSubjectFilter.SubjectType + + if err.filter.OptionalSubjectFilter.OptionalSubjectId != "" { + metadata["filter_subject_id"] = err.filter.OptionalSubjectFilter.OptionalSubjectId + } + + if err.filter.OptionalSubjectFilter.OptionalRelation != nil { + metadata["filter_subject_relation"] = err.filter.OptionalSubjectFilter.OptionalRelation.Relation + } + } + + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_TOO_MANY_RELATIONSHIPS_FOR_TRANSACTIONAL_DELETE, + metadata, + ), + ) +} + +// InvalidCursorError indicates that an invalid cursor was found. +type InvalidCursorError struct { + error + reason string +} + +// NewInvalidCursorErr constructs a new invalid cursor error. +func NewInvalidCursorErr(reason string) InvalidCursorError { + return InvalidCursorError{ + error: fmt.Errorf( + "the cursor provided is not valid: %s", + reason, + ), + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err InvalidCursorError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.FailedPrecondition, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_INVALID_CURSOR, + map[string]string{ + "reason": err.reason, + }, + ), + ) +} + +// InvalidFilterError indicates the specified relationship filter was invalid. +type InvalidFilterError struct { + error + + filter string +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err InvalidFilterError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_INVALID_FILTER, + map[string]string{ + "filter": err.filter, + }, + ), + ) +} + +// NewInvalidFilterErr constructs a new invalid filter error. +func NewInvalidFilterErr(reason string, filter string) InvalidFilterError { + return InvalidFilterError{ + error: fmt.Errorf( + "the relationship filter provided is not valid: %s", reason, + ), + filter: filter, + } +} + +// NewEmptyPreconditionErr constructs a new empty precondition error. +func NewEmptyPreconditionErr() EmptyPreconditionError { + return EmptyPreconditionError{ + error: fmt.Errorf( + "one of the specified preconditions is empty", + ), + } +} + +// EmptyPreconditionError indicates an empty precondition was found. +type EmptyPreconditionError struct { + error +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err EmptyPreconditionError) GRPCStatus() *status.Status { + // TODO(jschorr): Put a proper error reason in here. + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{}, + ), + ) +} + +// NewNotAPermissionError constructs a new not a permission error. +func NewNotAPermissionError(relationName string) NotAPermissionError { + return NotAPermissionError{ + error: fmt.Errorf( + "the relation `%s` is not a permission", relationName, + ), + relationName: relationName, + } +} + +// NotAPermissionError indicates that the relation is not a permission. +type NotAPermissionError struct { + error + relationName string +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err NotAPermissionError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNKNOWN_RELATION_OR_PERMISSION, + map[string]string{ + "relationName": err.relationName, + }, + ), + ) +} + +func defaultIfZero[T comparable](value T, defaultValue T) T { + var zero T + if value == zero { + return defaultValue + } + return value +} + +// TransactionMetadataTooLargeError indicates that the metadata for a transaction is too large. +type TransactionMetadataTooLargeError struct { + error + metadataSize int + maxSize int +} + +// NewTransactionMetadataTooLargeErr constructs a new transaction metadata too large error. +func NewTransactionMetadataTooLargeErr(metadataSize int, maxSize int) TransactionMetadataTooLargeError { + return TransactionMetadataTooLargeError{ + error: fmt.Errorf("metadata size of %d is greater than maximum allowed of %d", metadataSize, maxSize), + metadataSize: metadataSize, + maxSize: maxSize, + } +} + +func (err TransactionMetadataTooLargeError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Int("metadataSize", err.metadataSize).Int("maxSize", err.maxSize) +} + +func (err TransactionMetadataTooLargeError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_TRANSACTION_METADATA_TOO_LARGE, + map[string]string{ + "metadata_byte_size": strconv.Itoa(err.metadataSize), + "maximum_allowed_metadata_byte_size": strconv.Itoa(err.maxSize), + }, + ), + ) +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go b/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go new file mode 100644 index 0000000..0e4b4a7 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go @@ -0,0 +1,824 @@ +package v1 + +import ( + "context" + "errors" + "fmt" + "io" + "slices" + "sort" + "strings" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/ccoveille/go-safecast" + "github.com/jzelinskie/stringz" + + "github.com/authzed/spicedb/internal/dispatch" + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/middleware" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/handwrittenvalidation" + "github.com/authzed/spicedb/internal/middleware/streamtimeout" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/relationships" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/internal/services/v1/options" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "github.com/samber/lo" +) + +const ( + defaultExportBatchSizeFallback = 1_000 + maxExportBatchSizeFallback = 10_000 + streamReadTimeoutFallbackSeconds = 600 +) + +// NewExperimentalServer creates a ExperimentalServiceServer instance. +func NewExperimentalServer(dispatch dispatch.Dispatcher, permServerConfig PermissionsServerConfig, opts ...options.ExperimentalServerOptionsOption) v1.ExperimentalServiceServer { + config := options.NewExperimentalServerOptionsWithOptionsAndDefaults(opts...) + if config.DefaultExportBatchSize == 0 { + log. + Warn(). + Uint32("specified", config.DefaultExportBatchSize). + Uint32("fallback", defaultExportBatchSizeFallback). + Msg("experimental server config specified invalid DefaultExportBatchSize, setting to fallback") + config.DefaultExportBatchSize = defaultExportBatchSizeFallback + } + if config.MaxExportBatchSize == 0 { + fallback := permServerConfig.MaxBulkExportRelationshipsLimit + if fallback == 0 { + fallback = maxExportBatchSizeFallback + } + + log. + Warn(). + Uint32("specified", config.MaxExportBatchSize). + Uint32("fallback", fallback). + Msg("experimental server config specified invalid MaxExportBatchSize, setting to fallback") + config.MaxExportBatchSize = fallback + } + if config.StreamReadTimeout == 0 { + log. + Warn(). + Stringer("specified", config.StreamReadTimeout). + Stringer("fallback", streamReadTimeoutFallbackSeconds*time.Second). + Msg("experimental server config specified invalid StreamReadTimeout, setting to fallback") + config.StreamReadTimeout = streamReadTimeoutFallbackSeconds * time.Second + } + + chunkSize := permServerConfig.DispatchChunkSize + if chunkSize == 0 { + log. + Warn(). + Msg("experimental server config specified invalid DispatchChunkSize, defaulting to 100") + chunkSize = 100 + } + + return &experimentalServer{ + WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{ + Unary: middleware.ChainUnaryServer( + grpcvalidate.UnaryServerInterceptor(), + handwrittenvalidation.UnaryServerInterceptor, + usagemetrics.UnaryServerInterceptor(), + ), + Stream: middleware.ChainStreamServer( + grpcvalidate.StreamServerInterceptor(), + handwrittenvalidation.StreamServerInterceptor, + usagemetrics.StreamServerInterceptor(), + streamtimeout.MustStreamServerInterceptor(config.StreamReadTimeout), + ), + }, + maxBatchSize: uint64(config.MaxExportBatchSize), + caveatTypeSet: caveattypes.TypeSetOrDefault(permServerConfig.CaveatTypeSet), + bulkChecker: &bulkChecker{ + maxAPIDepth: permServerConfig.MaximumAPIDepth, + maxCaveatContextSize: permServerConfig.MaxCaveatContextSize, + maxConcurrency: config.BulkCheckMaxConcurrency, + dispatch: dispatch, + dispatchChunkSize: chunkSize, + caveatTypeSet: caveattypes.TypeSetOrDefault(permServerConfig.CaveatTypeSet), + }, + } +} + +type experimentalServer struct { + v1.UnimplementedExperimentalServiceServer + shared.WithServiceSpecificInterceptors + + maxBatchSize uint64 + + bulkChecker *bulkChecker + caveatTypeSet *caveattypes.TypeSet +} + +type bulkLoadAdapter struct { + stream v1.ExperimentalService_BulkImportRelationshipsServer + referencedNamespaceMap map[string]*schema.Definition + referencedCaveatMap map[string]*core.CaveatDefinition + current tuple.Relationship + caveat core.ContextualizedCaveat + caveatTypeSet *caveattypes.TypeSet + + awaitingNamespaces []string + awaitingCaveats []string + + currentBatch []*v1.Relationship + numSent int + err error +} + +func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) { + for a.err == nil && a.numSent == len(a.currentBatch) { + // Load a new batch + batch, err := a.stream.Recv() + if err != nil { + a.err = err + if errors.Is(a.err, io.EOF) { + return nil, nil + } + return nil, a.err + } + + a.currentBatch = batch.Relationships + a.numSent = 0 + + a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats( + a.currentBatch, + a.referencedNamespaceMap, + a.referencedCaveatMap, + ) + } + + if len(a.awaitingNamespaces) > 0 || len(a.awaitingCaveats) > 0 { + // Shut down the stream to give our caller a chance to fill in this information + return nil, nil + } + + a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType + a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId + a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation + a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType + a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId + a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis) + + if a.currentBatch[a.numSent].OptionalCaveat != nil { + a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName + a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context + a.current.OptionalCaveat = &a.caveat + } else { + a.current.OptionalCaveat = nil + } + + if a.currentBatch[a.numSent].OptionalExpiresAt != nil { + t := a.currentBatch[a.numSent].OptionalExpiresAt.AsTime() + a.current.OptionalExpiration = &t + } else { + a.current.OptionalExpiration = nil + } + + a.current.OptionalIntegrity = nil + + if err := relationships.ValidateOneRelationship( + a.referencedNamespaceMap, + a.referencedCaveatMap, + a.caveatTypeSet, + a.current, + relationships.ValidateRelationshipForCreateOrTouch, + ); err != nil { + return nil, err + } + + a.numSent++ + return &a.current, nil +} + +func extractBatchNewReferencedNamespacesAndCaveats( + batch []*v1.Relationship, + existingNamespaces map[string]*schema.Definition, + existingCaveats map[string]*core.CaveatDefinition, +) ([]string, []string) { + newNamespaces := make(map[string]struct{}, 2) + newCaveats := make(map[string]struct{}, 0) + for _, rel := range batch { + if _, ok := existingNamespaces[rel.Resource.ObjectType]; !ok { + newNamespaces[rel.Resource.ObjectType] = struct{}{} + } + if _, ok := existingNamespaces[rel.Subject.Object.ObjectType]; !ok { + newNamespaces[rel.Subject.Object.ObjectType] = struct{}{} + } + if rel.OptionalCaveat != nil { + if _, ok := existingCaveats[rel.OptionalCaveat.CaveatName]; !ok { + newCaveats[rel.OptionalCaveat.CaveatName] = struct{}{} + } + } + } + + return lo.Keys(newNamespaces), lo.Keys(newCaveats) +} + +// TODO: this is now duplicate code with ImportBulkRelationships +func (es *experimentalServer) BulkImportRelationships(stream v1.ExperimentalService_BulkImportRelationshipsServer) error { + ds := datastoremw.MustFromContext(stream.Context()) + + var numWritten uint64 + if _, err := ds.ReadWriteTx(stream.Context(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + loadedNamespaces := make(map[string]*schema.Definition, 2) + loadedCaveats := make(map[string]*core.CaveatDefinition, 0) + + adapter := &bulkLoadAdapter{ + stream: stream, + referencedNamespaceMap: loadedNamespaces, + referencedCaveatMap: loadedCaveats, + current: tuple.Relationship{}, + caveat: core.ContextualizedCaveat{}, + caveatTypeSet: es.caveatTypeSet, + } + resolver := schema.ResolverForDatastoreReader(rwt) + ts := schema.NewTypeSystem(resolver) + + var streamWritten uint64 + var err error + for ; adapter.err == nil && err == nil; streamWritten, err = rwt.BulkLoad(stream.Context(), adapter) { + numWritten += streamWritten + + // The stream has terminated because we're awaiting namespace and/or caveat information + if len(adapter.awaitingNamespaces) > 0 { + nsDefs, err := rwt.LookupNamespacesWithNames(stream.Context(), adapter.awaitingNamespaces) + if err != nil { + return err + } + + for _, nsDef := range nsDefs { + newDef, err := schema.NewDefinition(ts, nsDef.Definition) + if err != nil { + return err + } + + loadedNamespaces[nsDef.Definition.Name] = newDef + } + adapter.awaitingNamespaces = nil + } + + if len(adapter.awaitingCaveats) > 0 { + caveats, err := rwt.LookupCaveatsWithNames(stream.Context(), adapter.awaitingCaveats) + if err != nil { + return err + } + + for _, caveat := range caveats { + loadedCaveats[caveat.Definition.Name] = caveat.Definition + } + adapter.awaitingCaveats = nil + } + } + numWritten += streamWritten + + return err + }, dsoptions.WithDisableRetries(true)); err != nil { + return shared.RewriteErrorWithoutConfig(stream.Context(), err) + } + + usagemetrics.SetInContext(stream.Context(), &dispatchv1.ResponseMeta{ + // One request for the whole load + DispatchCount: 1, + }) + + return stream.SendAndClose(&v1.BulkImportRelationshipsResponse{ + NumLoaded: numWritten, + }) +} + +// TODO: this is now duplicate code with ExportBulkRelationships +func (es *experimentalServer) BulkExportRelationships( + req *v1.BulkExportRelationshipsRequest, + resp grpc.ServerStreamingServer[v1.BulkExportRelationshipsResponse], +) error { + ctx := resp.Context() + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + return BulkExport(ctx, datastoremw.MustFromContext(ctx), es.maxBatchSize, req, atRevision, resp.Send) +} + +// BulkExport implements the BulkExportRelationships API functionality. Given a datastore.Datastore, it will +// export stream via the sender all relationships matched by the incoming request. +// If no cursor is provided, it will fallback to the provided revision. +func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.BulkExportRelationshipsResponse) error) error { + if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { + return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) + } + + atRevision := fallbackRevision + var curNamespace string + var cur dsoptions.Cursor + if req.OptionalCursor != nil { + var err error + atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + } + + reader := ds.SnapshotReader(atRevision) + + namespaces, err := reader.ListAllNamespaces(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Make sure the namespaces are always in a stable order + slices.SortFunc(namespaces, func( + lhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + rhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + ) int { + return strings.Compare(lhs.Definition.Name, rhs.Definition.Name) + }) + + // Skip the namespaces that are already fully returned + for cur != nil && len(namespaces) > 0 && namespaces[0].Definition.Name < curNamespace { + namespaces = namespaces[1:] + } + + limit := batchSize + if req.OptionalLimit > 0 { + limit = uint64(req.OptionalLimit) + } + + // Pre-allocate all of the relationships that we might need in order to + // make export easier and faster for the garbage collector. + relsArray := make([]v1.Relationship, limit) + objArray := make([]v1.ObjectReference, limit) + subArray := make([]v1.SubjectReference, limit) + subObjArray := make([]v1.ObjectReference, limit) + caveatArray := make([]v1.ContextualizedCaveat, limit) + for i := range relsArray { + relsArray[i].Resource = &objArray[i] + relsArray[i].Subject = &subArray[i] + relsArray[i].Subject.Object = &subObjArray[i] + } + + emptyRels := make([]*v1.Relationship, limit) + for _, ns := range namespaces { + rels := emptyRels + + // Reset the cursor between namespaces. + if ns.Definition.Name != curNamespace { + cur = nil + } + + // Skip this namespace if a resource type filter was specified. + if req.OptionalRelationshipFilter != nil && req.OptionalRelationshipFilter.ResourceType != "" { + if ns.Definition.Name != req.OptionalRelationshipFilter.ResourceType { + continue + } + } + + // Setup the filter to use for the relationships. + relationshipFilter := datastore.RelationshipsFilter{OptionalResourceType: ns.Definition.Name} + if req.OptionalRelationshipFilter != nil { + rf, err := datastore.RelationshipsFilterFromPublicFilter(req.OptionalRelationshipFilter) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Overload the namespace name with the one from the request, because each iteration is for a different namespace. + rf.OptionalResourceType = ns.Definition.Name + relationshipFilter = rf + } + + // We want to keep iterating as long as we're sending full batches. + // To bootstrap this loop, we enter the first time with a full rels + // slice of dummy rels that were never sent. + for uint64(len(rels)) == limit { + // Lop off any rels we've already sent + rels = rels[:0] + + relFn := func(rel tuple.Relationship) { + offset := len(rels) + rels = append(rels, &relsArray[offset]) // nozero + + v1Rel := &relsArray[offset] + v1Rel.Resource.ObjectType = rel.RelationshipReference.Resource.ObjectType + v1Rel.Resource.ObjectId = rel.RelationshipReference.Resource.ObjectID + v1Rel.Relation = rel.RelationshipReference.Resource.Relation + v1Rel.Subject.Object.ObjectType = rel.RelationshipReference.Subject.ObjectType + v1Rel.Subject.Object.ObjectId = rel.RelationshipReference.Subject.ObjectID + v1Rel.Subject.OptionalRelation = denormalizeSubjectRelation(rel.RelationshipReference.Subject.Relation) + + if rel.OptionalCaveat != nil { + caveatArray[offset].CaveatName = rel.OptionalCaveat.CaveatName + caveatArray[offset].Context = rel.OptionalCaveat.Context + v1Rel.OptionalCaveat = &caveatArray[offset] + } else { + v1Rel.OptionalCaveat = nil + } + + if rel.OptionalExpiration != nil { + v1Rel.OptionalExpiresAt = timestamppb.New(*rel.OptionalExpiration) + } else { + v1Rel.OptionalExpiresAt = nil + } + } + + cur, err = queryForEach( + ctx, + reader, + relationshipFilter, + relFn, + dsoptions.WithLimit(&limit), + dsoptions.WithAfter(cur), + dsoptions.WithSort(dsoptions.ByResource), + dsoptions.WithQueryShape(queryshape.Varying), + ) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + if len(rels) == 0 { + continue + } + + encoded, err := cursor.Encode(&implv1.DecodedCursor{ + VersionOneof: &implv1.DecodedCursor_V1{ + V1: &implv1.V1Cursor{ + Revision: atRevision.String(), + Sections: []string{ + ns.Definition.Name, + tuple.MustString(*dsoptions.ToRelationship(cur)), + }, + }, + }, + }) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + if err := sender(&v1.BulkExportRelationshipsResponse{ + AfterResultCursor: encoded, + Relationships: rels, + }); err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + } + } + return nil +} + +func (es *experimentalServer) BulkCheckPermission(ctx context.Context, req *v1.BulkCheckPermissionRequest) (*v1.BulkCheckPermissionResponse, error) { + convertedReq := toCheckBulkPermissionsRequest(req) + res, err := es.bulkChecker.checkBulkPermissions(ctx, convertedReq) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return toBulkCheckPermissionResponse(res), nil +} + +func (es *experimentalServer) ExperimentalReflectSchema(ctx context.Context, req *v1.ExperimentalReflectSchemaRequest) (*v1.ExperimentalReflectSchemaResponse, error) { + // Get the current schema. + schema, atRevision, err := loadCurrentSchema(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + filters, err := newexpSchemaFilters(req.OptionalFilters) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + definitions := make([]*v1.ExpDefinition, 0, len(schema.ObjectDefinitions)) + if filters.HasNamespaces() { + for _, ns := range schema.ObjectDefinitions { + def, err := expNamespaceAPIRepr(ns, filters) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + if def != nil { + definitions = append(definitions, def) + } + } + } + + caveats := make([]*v1.ExpCaveat, 0, len(schema.CaveatDefinitions)) + if filters.HasCaveats() { + for _, cd := range schema.CaveatDefinitions { + caveat, err := expCaveatAPIRepr(cd, filters, es.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + if caveat != nil { + caveats = append(caveats, caveat) + } + } + } + + return &v1.ExperimentalReflectSchemaResponse{ + Definitions: definitions, + Caveats: caveats, + ReadAt: zedtoken.MustNewFromRevision(atRevision), + }, nil +} + +func (es *experimentalServer) ExperimentalDiffSchema(ctx context.Context, req *v1.ExperimentalDiffSchemaRequest) (*v1.ExperimentalDiffSchemaResponse, error) { + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, err + } + + diff, existingSchema, comparisonSchema, err := schemaDiff(ctx, req.ComparisonSchema, es.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + resp, err := expConvertDiff(diff, existingSchema, comparisonSchema, atRevision, es.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return resp, nil +} + +func (es *experimentalServer) ExperimentalComputablePermissions(ctx context.Context, req *v1.ExperimentalComputablePermissionsRequest) (*v1.ExperimentalComputablePermissionsResponse, error) { + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) + vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relationName := req.RelationName + if relationName == "" { + relationName = tuple.Ellipsis + } else { + if _, ok := vdef.GetRelation(relationName); !ok { + return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, relationName)) + } + } + + allNamespaces, err := ds.ListAllNamespaces(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + allDefinitions := make([]*core.NamespaceDefinition, 0, len(allNamespaces)) + for _, ns := range allNamespaces { + allDefinitions = append(allDefinitions, ns.Definition) + } + + rg := vdef.Reachability() + rr, err := rg.RelationsEncounteredForSubject(ctx, allDefinitions, &core.RelationReference{ + Namespace: req.DefinitionName, + Relation: relationName, + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations := make([]*v1.ExpRelationReference, 0, len(rr)) + for _, r := range rr { + if r.Namespace == req.DefinitionName && r.Relation == req.RelationName { + continue + } + + if req.OptionalDefinitionNameFilter != "" && !strings.HasPrefix(r.Namespace, req.OptionalDefinitionNameFilter) { + continue + } + + def, err := ts.GetValidatedDefinition(ctx, r.Namespace) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations = append(relations, &v1.ExpRelationReference{ + DefinitionName: r.Namespace, + RelationName: r.Relation, + IsPermission: def.IsPermission(r.Relation), + }) + } + + sort.Slice(relations, func(i, j int) bool { + if relations[i].DefinitionName == relations[j].DefinitionName { + return relations[i].RelationName < relations[j].RelationName + } + return relations[i].DefinitionName < relations[j].DefinitionName + }) + + return &v1.ExperimentalComputablePermissionsResponse{ + Permissions: relations, + ReadAt: revisionReadAt, + }, nil +} + +func (es *experimentalServer) ExperimentalDependentRelations(ctx context.Context, req *v1.ExperimentalDependentRelationsRequest) (*v1.ExperimentalDependentRelationsResponse, error) { + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) + vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + _, ok := vdef.GetRelation(req.PermissionName) + if !ok { + return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, req.PermissionName)) + } + + if !vdef.IsPermission(req.PermissionName) { + return nil, shared.RewriteErrorWithoutConfig(ctx, NewNotAPermissionError(req.PermissionName)) + } + + rg := vdef.Reachability() + rr, err := rg.RelationsEncounteredForResource(ctx, &core.RelationReference{ + Namespace: req.DefinitionName, + Relation: req.PermissionName, + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations := make([]*v1.ExpRelationReference, 0, len(rr)) + for _, r := range rr { + if r.Namespace == req.DefinitionName && r.Relation == req.PermissionName { + continue + } + + ts, err := ts.GetDefinition(ctx, r.Namespace) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations = append(relations, &v1.ExpRelationReference{ + DefinitionName: r.Namespace, + RelationName: r.Relation, + IsPermission: ts.IsPermission(r.Relation), + }) + } + + sort.Slice(relations, func(i, j int) bool { + if relations[i].DefinitionName == relations[j].DefinitionName { + return relations[i].RelationName < relations[j].RelationName + } + + return relations[i].DefinitionName < relations[j].DefinitionName + }) + + return &v1.ExperimentalDependentRelationsResponse{ + Relations: relations, + ReadAt: revisionReadAt, + }, nil +} + +func (es *experimentalServer) ExperimentalRegisterRelationshipCounter(ctx context.Context, req *v1.ExperimentalRegisterRelationshipCounterRequest) (*v1.ExperimentalRegisterRelationshipCounterResponse, error) { + ds := datastoremw.MustFromContext(ctx) + + if req.Name == "" { + return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED)) + } + + _, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, rwt); err != nil { + return err + } + + coreFilter := datastore.CoreFilterFromRelationshipFilter(req.RelationshipFilter) + return rwt.RegisterCounter(ctx, req.Name, coreFilter) + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return &v1.ExperimentalRegisterRelationshipCounterResponse{}, nil +} + +func (es *experimentalServer) ExperimentalUnregisterRelationshipCounter(ctx context.Context, req *v1.ExperimentalUnregisterRelationshipCounterRequest) (*v1.ExperimentalUnregisterRelationshipCounterResponse, error) { + ds := datastoremw.MustFromContext(ctx) + + if req.Name == "" { + return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED)) + } + + _, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.UnregisterCounter(ctx, req.Name) + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return &v1.ExperimentalUnregisterRelationshipCounterResponse{}, nil +} + +func (es *experimentalServer) ExperimentalCountRelationships(ctx context.Context, req *v1.ExperimentalCountRelationshipsRequest) (*v1.ExperimentalCountRelationshipsResponse, error) { + if req.Name == "" { + return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED)) + } + + ds := datastoremw.MustFromContext(ctx) + headRev, err := ds.HeadRevision(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + snapshotReader := ds.SnapshotReader(headRev) + count, err := snapshotReader.CountRelationships(ctx, req.Name) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + uintCount, err := safecast.ToUint64(count) + if err != nil { + return nil, spiceerrors.MustBugf("count should not be negative") + } + + return &v1.ExperimentalCountRelationshipsResponse{ + CounterResult: &v1.ExperimentalCountRelationshipsResponse_ReadCounterValue{ + ReadCounterValue: &v1.ReadCounterValue{ + RelationshipCount: uintCount, + ReadAt: zedtoken.MustNewFromRevision(headRev), + }, + }, + }, nil +} + +func queryForEach( + ctx context.Context, + reader datastore.Reader, + filter datastore.RelationshipsFilter, + fn func(rel tuple.Relationship), + opts ...dsoptions.QueryOptionsOption, +) (dsoptions.Cursor, error) { + iter, err := reader.QueryRelationships(ctx, filter, opts...) + if err != nil { + return nil, err + } + + var cursor dsoptions.Cursor + for rel, err := range iter { + if err != nil { + return nil, err + } + + fn(rel) + cursor = dsoptions.ToCursor(rel) + } + return cursor, nil +} + +func decodeCursor(ds datastore.ReadOnlyDatastore, encoded *v1.Cursor) (datastore.Revision, string, dsoptions.Cursor, error) { + decoded, err := cursor.Decode(encoded) + if err != nil { + return datastore.NoRevision, "", nil, err + } + + if decoded.GetV1() == nil { + return datastore.NoRevision, "", nil, errors.New("malformed cursor: no V1 in OneOf") + } + + if len(decoded.GetV1().Sections) != 2 { + return datastore.NoRevision, "", nil, errors.New("malformed cursor: wrong number of components") + } + + atRevision, err := ds.RevisionFromString(decoded.GetV1().Revision) + if err != nil { + return datastore.NoRevision, "", nil, err + } + + cur, err := tuple.Parse(decoded.GetV1().GetSections()[1]) + if err != nil { + return datastore.NoRevision, "", nil, fmt.Errorf("malformed cursor: invalid encoded relation tuple: %w", err) + } + + // Returns the current namespace and the cursor. + return atRevision, decoded.GetV1().GetSections()[0], dsoptions.ToCursor(cur), nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go b/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go new file mode 100644 index 0000000..8ef6c25 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go @@ -0,0 +1,720 @@ +package v1 + +import ( + "sort" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/diff" + caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats" + nsdiff "github.com/authzed/spicedb/pkg/diff/namespace" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +type expSchemaFilters struct { + filters []*v1.ExpSchemaFilter +} + +func newexpSchemaFilters(filters []*v1.ExpSchemaFilter) (*expSchemaFilters, error) { + for _, filter := range filters { + if filter.OptionalDefinitionNameFilter != "" { + if filter.OptionalCaveatNameFilter != "" { + return nil, NewInvalidFilterErr("cannot filter by both definition and caveat name", filter.String()) + } + } + + if filter.OptionalRelationNameFilter != "" { + if filter.OptionalDefinitionNameFilter == "" { + return nil, NewInvalidFilterErr("relation name match requires definition name match", filter.String()) + } + + if filter.OptionalPermissionNameFilter != "" { + return nil, NewInvalidFilterErr("cannot filter by both relation and permission name", filter.String()) + } + } + + if filter.OptionalPermissionNameFilter != "" { + if filter.OptionalDefinitionNameFilter == "" { + return nil, NewInvalidFilterErr("permission name match requires definition name match", filter.String()) + } + } + } + + return &expSchemaFilters{filters: filters}, nil +} + +func (sf *expSchemaFilters) HasNamespaces() bool { + if len(sf.filters) == 0 { + return true + } + + for _, filter := range sf.filters { + if filter.OptionalDefinitionNameFilter != "" { + return true + } + } + + return false +} + +func (sf *expSchemaFilters) HasCaveats() bool { + if len(sf.filters) == 0 { + return true + } + + for _, filter := range sf.filters { + if filter.OptionalCaveatNameFilter != "" { + return true + } + } + + return false +} + +func (sf *expSchemaFilters) HasNamespace(namespaceName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasDefinitionFilter := false + for _, filter := range sf.filters { + if filter.OptionalDefinitionNameFilter == "" { + continue + } + + hasDefinitionFilter = true + isMatch := strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasDefinitionFilter +} + +func (sf *expSchemaFilters) HasCaveat(caveatName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasCaveatFilter := false + for _, filter := range sf.filters { + if filter.OptionalCaveatNameFilter == "" { + continue + } + + hasCaveatFilter = true + isMatch := strings.HasPrefix(caveatName, filter.OptionalCaveatNameFilter) + if isMatch { + return true + } + } + + return !hasCaveatFilter +} + +func (sf *expSchemaFilters) HasRelation(namespaceName, relationName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasRelationFilter := false + for _, filter := range sf.filters { + if filter.OptionalRelationNameFilter == "" { + continue + } + + hasRelationFilter = true + isMatch := strings.HasPrefix(relationName, filter.OptionalRelationNameFilter) + if !isMatch { + continue + } + + isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasRelationFilter +} + +func (sf *expSchemaFilters) HasPermission(namespaceName, permissionName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasPermissionFilter := false + for _, filter := range sf.filters { + if filter.OptionalPermissionNameFilter == "" { + continue + } + + hasPermissionFilter = true + isMatch := strings.HasPrefix(permissionName, filter.OptionalPermissionNameFilter) + if !isMatch { + continue + } + + isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasPermissionFilter +} + +// expConvertDiff converts a schema diff into an API response. +func expConvertDiff( + diff *diff.SchemaDiff, + existingSchema *diff.DiffableSchema, + comparisonSchema *diff.DiffableSchema, + atRevision datastore.Revision, + caveatTypeSet *caveattypes.TypeSet, +) (*v1.ExperimentalDiffSchemaResponse, error) { + size := len(diff.AddedNamespaces) + len(diff.RemovedNamespaces) + len(diff.AddedCaveats) + len(diff.RemovedCaveats) + len(diff.ChangedNamespaces) + len(diff.ChangedCaveats) + diffs := make([]*v1.ExpSchemaDiff, 0, size) + + // Add/remove namespaces. + for _, ns := range diff.AddedNamespaces { + nsDef, err := expNamespaceAPIReprForName(ns, comparisonSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_DefinitionAdded{ + DefinitionAdded: nsDef, + }, + }) + } + + for _, ns := range diff.RemovedNamespaces { + nsDef, err := expNamespaceAPIReprForName(ns, existingSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_DefinitionRemoved{ + DefinitionRemoved: nsDef, + }, + }) + } + + // Add/remove caveats. + for _, caveat := range diff.AddedCaveats { + caveatDef, err := expCaveatAPIReprForName(caveat, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatAdded{ + CaveatAdded: caveatDef, + }, + }) + } + + for _, caveat := range diff.RemovedCaveats { + caveatDef, err := expCaveatAPIReprForName(caveat, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatRemoved{ + CaveatRemoved: caveatDef, + }, + }) + } + + // Changed namespaces. + for nsName, nsDiff := range diff.ChangedNamespaces { + for _, delta := range nsDiff.Deltas() { + switch delta.Type { + case nsdiff.AddedPermission: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := expPermissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_PermissionAdded{ + PermissionAdded: perm, + }, + }) + + case nsdiff.AddedRelation: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationAdded{ + RelationAdded: rel, + }, + }) + + case nsdiff.ChangedPermissionComment: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := expPermissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_PermissionDocCommentChanged{ + PermissionDocCommentChanged: perm, + }, + }) + + case nsdiff.ChangedPermissionImpl: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := expPermissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_PermissionExprChanged{ + PermissionExprChanged: perm, + }, + }) + + case nsdiff.ChangedRelationComment: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationDocCommentChanged{ + RelationDocCommentChanged: rel, + }, + }) + + case nsdiff.LegacyChangedRelationImpl: + return nil, spiceerrors.MustBugf("legacy relation implementation changes are not supported") + + case nsdiff.NamespaceCommentsChanged: + def, err := expNamespaceAPIReprForName(nsName, comparisonSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_DefinitionDocCommentChanged{ + DefinitionDocCommentChanged: def, + }, + }) + + case nsdiff.RelationAllowedTypeRemoved: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationSubjectTypeRemoved{ + RelationSubjectTypeRemoved: &v1.ExpRelationSubjectTypeChange{ + Relation: rel, + ChangedSubjectType: expTypeAPIRepr(delta.AllowedType), + }, + }, + }) + + case nsdiff.RelationAllowedTypeAdded: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationSubjectTypeAdded{ + RelationSubjectTypeAdded: &v1.ExpRelationSubjectTypeChange{ + Relation: rel, + ChangedSubjectType: expTypeAPIRepr(delta.AllowedType), + }, + }, + }) + + case nsdiff.RemovedPermission: + permission, ok := existingSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := expPermissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_PermissionRemoved{ + PermissionRemoved: perm, + }, + }) + + case nsdiff.RemovedRelation: + relation, ok := existingSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationRemoved{ + RelationRemoved: rel, + }, + }) + + case nsdiff.NamespaceAdded: + return nil, spiceerrors.MustBugf("should be handled above") + + case nsdiff.NamespaceRemoved: + return nil, spiceerrors.MustBugf("should be handled above") + + default: + return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type) + } + } + } + + // Changed caveats. + for caveatName, caveatDiff := range diff.ChangedCaveats { + for _, delta := range caveatDiff.Deltas() { + switch delta.Type { + case caveatdiff.CaveatCommentsChanged: + caveat, err := expCaveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatDocCommentChanged{ + CaveatDocCommentChanged: caveat, + }, + }) + + case caveatdiff.AddedParameter: + paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatParameterAdded{ + CaveatParameterAdded: paramDef, + }, + }) + + case caveatdiff.RemovedParameter: + paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatParameterRemoved{ + CaveatParameterRemoved: paramDef, + }, + }) + + case caveatdiff.ParameterTypeChanged: + previousParamDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatParameterTypeChanged{ + CaveatParameterTypeChanged: &v1.ExpCaveatParameterTypeChange{ + Parameter: paramDef, + PreviousType: previousParamDef.Type, + }, + }, + }) + + case caveatdiff.CaveatExpressionChanged: + caveat, err := expCaveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatExprChanged{ + CaveatExprChanged: caveat, + }, + }) + + case caveatdiff.CaveatAdded: + return nil, spiceerrors.MustBugf("should be handled above") + + case caveatdiff.CaveatRemoved: + return nil, spiceerrors.MustBugf("should be handled above") + + default: + return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type) + } + } + } + + return &v1.ExperimentalDiffSchemaResponse{ + Diffs: diffs, + ReadAt: zedtoken.MustNewFromRevision(atRevision), + }, nil +} + +// expNamespaceAPIReprForName builds an API representation of a namespace. +func expNamespaceAPIReprForName(namespaceName string, schema *diff.DiffableSchema) (*v1.ExpDefinition, error) { + nsDef, ok := schema.GetNamespace(namespaceName) + if !ok { + return nil, spiceerrors.MustBugf("namespace %q not found in schema", namespaceName) + } + + return expNamespaceAPIRepr(nsDef, nil) +} + +func expNamespaceAPIRepr(nsDef *core.NamespaceDefinition, expSchemaFilters *expSchemaFilters) (*v1.ExpDefinition, error) { + if expSchemaFilters != nil && !expSchemaFilters.HasNamespace(nsDef.Name) { + return nil, nil + } + + relations := make([]*v1.ExpRelation, 0, len(nsDef.Relation)) + permissions := make([]*v1.ExpPermission, 0, len(nsDef.Relation)) + + for _, rel := range nsDef.Relation { + if namespace.GetRelationKind(rel) == iv1.RelationMetadata_PERMISSION { + permission, err := expPermissionAPIRepr(rel, nsDef.Name, expSchemaFilters) + if err != nil { + return nil, err + } + + if permission != nil { + permissions = append(permissions, permission) + } + continue + } + + relation, err := expRelationAPIRepr(rel, nsDef.Name, expSchemaFilters) + if err != nil { + return nil, err + } + + if relation != nil { + relations = append(relations, relation) + } + } + + comments := namespace.GetComments(nsDef.Metadata) + return &v1.ExpDefinition{ + Name: nsDef.Name, + Comment: strings.Join(comments, "\n"), + Relations: relations, + Permissions: permissions, + }, nil +} + +// expPermissionAPIRepr builds an API representation of a permission. +func expPermissionAPIRepr(relation *core.Relation, parentDefName string, expSchemaFilters *expSchemaFilters) (*v1.ExpPermission, error) { + if expSchemaFilters != nil && !expSchemaFilters.HasPermission(parentDefName, relation.Name) { + return nil, nil + } + + comments := namespace.GetComments(relation.Metadata) + return &v1.ExpPermission{ + Name: relation.Name, + Comment: strings.Join(comments, "\n"), + ParentDefinitionName: parentDefName, + }, nil +} + +// expRelationAPIRepr builds an API representation of a relation. +func expRelationAPIRepr(relation *core.Relation, parentDefName string, expSchemaFilters *expSchemaFilters) (*v1.ExpRelation, error) { + if expSchemaFilters != nil && !expSchemaFilters.HasRelation(parentDefName, relation.Name) { + return nil, nil + } + + comments := namespace.GetComments(relation.Metadata) + + var subjectTypes []*v1.ExpTypeReference + if relation.TypeInformation != nil { + subjectTypes = make([]*v1.ExpTypeReference, 0, len(relation.TypeInformation.AllowedDirectRelations)) + for _, subjectType := range relation.TypeInformation.AllowedDirectRelations { + typeref := expTypeAPIRepr(subjectType) + subjectTypes = append(subjectTypes, typeref) + } + } + + return &v1.ExpRelation{ + Name: relation.Name, + Comment: strings.Join(comments, "\n"), + ParentDefinitionName: parentDefName, + SubjectTypes: subjectTypes, + }, nil +} + +// expTypeAPIRepr builds an API representation of a type. +func expTypeAPIRepr(subjectType *core.AllowedRelation) *v1.ExpTypeReference { + typeref := &v1.ExpTypeReference{ + SubjectDefinitionName: subjectType.Namespace, + Typeref: &v1.ExpTypeReference_IsTerminalSubject{}, + } + + if subjectType.GetRelation() != tuple.Ellipsis && subjectType.GetRelation() != "" { + typeref.Typeref = &v1.ExpTypeReference_OptionalRelationName{ + OptionalRelationName: subjectType.GetRelation(), + } + } else if subjectType.GetPublicWildcard() != nil { + typeref.Typeref = &v1.ExpTypeReference_IsPublicWildcard{ + IsPublicWildcard: true, + } + } + + if subjectType.GetRequiredCaveat() != nil { + typeref.OptionalCaveatName = subjectType.GetRequiredCaveat().CaveatName + } + + return typeref +} + +// expCaveatAPIReprForName builds an API representation of a caveat. +func expCaveatAPIReprForName(caveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveat, error) { + caveatDef, ok := schema.GetCaveat(caveatName) + if !ok { + return nil, spiceerrors.MustBugf("caveat %q not found in schema", caveatName) + } + + return expCaveatAPIRepr(caveatDef, nil, caveatTypeSet) +} + +// expCaveatAPIRepr builds an API representation of a caveat. +func expCaveatAPIRepr(caveatDef *core.CaveatDefinition, expSchemaFilters *expSchemaFilters, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveat, error) { + if expSchemaFilters != nil && !expSchemaFilters.HasCaveat(caveatDef.Name) { + return nil, nil + } + + parameters := make([]*v1.ExpCaveatParameter, 0, len(caveatDef.ParameterTypes)) + paramNames := maps.Keys(caveatDef.ParameterTypes) + sort.Strings(paramNames) + + for _, paramName := range paramNames { + paramType, ok := caveatDef.ParameterTypes[paramName] + if !ok { + return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, caveatDef.Name) + } + + decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType) + if err != nil { + return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err) + } + + parameters = append(parameters, &v1.ExpCaveatParameter{ + Name: paramName, + Type: decoded.String(), + ParentCaveatName: caveatDef.Name, + }) + } + + parameterTypes, err := caveattypes.DecodeParameterTypes(caveatTypeSet, caveatDef.ParameterTypes) + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat parameters: %v", err) + } + + deserializedExpression, err := caveats.DeserializeCaveatWithTypeSet(caveatTypeSet, caveatDef.SerializedExpression, parameterTypes) + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat expression bytes: %v", err) + } + + exprString, err := deserializedExpression.ExprString() + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat expression: %v", err) + } + + comments := namespace.GetComments(caveatDef.Metadata) + return &v1.ExpCaveat{ + Name: caveatDef.Name, + Comment: strings.Join(comments, "\n"), + Parameters: parameters, + Expression: exprString, + }, nil +} + +// expCaveatAPIParamRepr builds an API representation of a caveat parameter. +func expCaveatAPIParamRepr(paramName, parentCaveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveatParameter, error) { + caveatDef, ok := schema.GetCaveat(parentCaveatName) + if !ok { + return nil, spiceerrors.MustBugf("caveat %q not found in schema", parentCaveatName) + } + + paramType, ok := caveatDef.ParameterTypes[paramName] + if !ok { + return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, parentCaveatName) + } + + decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType) + if err != nil { + return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err) + } + + return &v1.ExpCaveatParameter{ + Name: paramName, + Type: decoded.String(), + ParentCaveatName: parentCaveatName, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go b/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go new file mode 100644 index 0000000..99b681d --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go @@ -0,0 +1,72 @@ +package v1 + +import ( + "context" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/graph/computed" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/tuple" +) + +type groupedCheckParameters struct { + params *computed.CheckParameters + resourceIDs []string +} + +type groupingParameters struct { + atRevision datastore.Revision + maximumAPIDepth uint32 + maxCaveatContextSize int + withTracing bool +} + +// groupItems takes a slice of CheckBulkPermissionsRequestItem and groups them based +// on using the same permission, subject type, subject id, and caveat. +func groupItems(ctx context.Context, params groupingParameters, items []*v1.CheckBulkPermissionsRequestItem) (map[string]*groupedCheckParameters, error) { + res := make(map[string]*groupedCheckParameters) + + for _, item := range items { + hash, err := computeCheckBulkPermissionsItemHashWithoutResourceID(item) + if err != nil { + return nil, err + } + + if _, ok := res[hash]; !ok { + caveatContext, err := GetCaveatContext(ctx, item.Context, params.maxCaveatContextSize) + if err != nil { + return nil, err + } + + res[hash] = &groupedCheckParameters{ + params: checkParametersFromCheckBulkPermissionsRequestItem(item, params, caveatContext), + resourceIDs: []string{item.Resource.ObjectId}, + } + } else { + res[hash].resourceIDs = append(res[hash].resourceIDs, item.Resource.ObjectId) + } + } + + return res, nil +} + +func checkParametersFromCheckBulkPermissionsRequestItem( + bc *v1.CheckBulkPermissionsRequestItem, + params groupingParameters, + caveatContext map[string]any, +) *computed.CheckParameters { + debugOption := computed.NoDebugging + if params.withTracing { + debugOption = computed.BasicDebuggingEnabled + } + + return &computed.CheckParameters{ + ResourceType: tuple.RR(bc.Resource.ObjectType, bc.Permission), + Subject: tuple.ONR(bc.Subject.Object.ObjectType, bc.Subject.Object.ObjectId, normalizeSubjectRelation(bc.Subject)), + CaveatContext: caveatContext, + AtRevision: params.atRevision, + MaximumDepth: params.maximumAPIDepth, + DebugOption: debugOption, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go new file mode 100644 index 0000000..1754669 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go @@ -0,0 +1,110 @@ +package v1 + +import ( + "strconv" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/spicedb/pkg/caveats" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +func computeCheckBulkPermissionsItemHashWithoutResourceID(req *v1.CheckBulkPermissionsRequestItem) (string, error) { + return computeCallHash("v1.checkbulkpermissionsrequestitem", nil, map[string]any{ + "resource-type": req.Resource.ObjectType, + "permission": req.Permission, + "subject-type": req.Subject.Object.ObjectType, + "subject-id": req.Subject.Object.ObjectId, + "subject-relation": req.Subject.OptionalRelation, + "context": req.Context, + }) +} + +func computeCheckBulkPermissionsItemHash(req *v1.CheckBulkPermissionsRequestItem) (string, error) { + return computeCallHash("v1.checkbulkpermissionsrequestitem", nil, map[string]any{ + "resource-type": req.Resource.ObjectType, + "resource-id": req.Resource.ObjectId, + "permission": req.Permission, + "subject-type": req.Subject.Object.ObjectType, + "subject-id": req.Subject.Object.ObjectId, + "subject-relation": req.Subject.OptionalRelation, + "context": req.Context, + }) +} + +func computeReadRelationshipsRequestHash(req *v1.ReadRelationshipsRequest) (string, error) { + osf := req.RelationshipFilter.OptionalSubjectFilter + if osf == nil { + osf = &v1.SubjectFilter{} + } + + srf := "(none)" + if osf.OptionalRelation != nil { + srf = osf.OptionalRelation.Relation + } + + return computeCallHash("v1.readrelationships", req.Consistency, map[string]any{ + "filter-resource-type": req.RelationshipFilter.ResourceType, + "filter-relation": req.RelationshipFilter.OptionalRelation, + "filter-resource-id": req.RelationshipFilter.OptionalResourceId, + "subject-type": osf.SubjectType, + "subject-relation": srf, + "subject-resource-id": osf.OptionalSubjectId, + "limit": req.OptionalLimit, + }) +} + +func computeLRRequestHash(req *v1.LookupResourcesRequest) (string, error) { + return computeCallHash("v1.lookupresources", req.Consistency, map[string]any{ + "resource-type": req.ResourceObjectType, + "permission": req.Permission, + "subject": tuple.V1StringSubjectRef(req.Subject), + "limit": req.OptionalLimit, + "context": req.Context, + }) +} + +func computeCallHash(apiName string, consistency *v1.Consistency, arguments map[string]any) (string, error) { + stringArguments := make(map[string]string, len(arguments)+1) + + if consistency == nil { + consistency = &v1.Consistency{ + Requirement: &v1.Consistency_MinimizeLatency{ + MinimizeLatency: true, + }, + } + } + + consistencyBytes, err := consistency.MarshalVT() + if err != nil { + return "", err + } + + stringArguments["consistency"] = string(consistencyBytes) + + for argName, argValue := range arguments { + if argName == "consistency" { + return "", spiceerrors.MustBugf("cannot specify consistency in the arguments") + } + + switch v := argValue.(type) { + case string: + stringArguments[argName] = v + + case int: + stringArguments[argName] = strconv.Itoa(v) + + case uint32: + stringArguments[argName] = strconv.Itoa(int(v)) + + case *structpb.Struct: + stringArguments[argName] = caveats.StableContextStringForHashing(v) + + default: + return "", spiceerrors.MustBugf("unknown argument type in compute call hash") + } + } + return computeAPICallHash(apiName, stringArguments) +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go new file mode 100644 index 0000000..fad4a40 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go @@ -0,0 +1,52 @@ +//go:build !wasm +// +build !wasm + +package v1 + +import ( + "fmt" + "sort" + + "github.com/cespare/xxhash/v2" + "golang.org/x/exp/maps" +) + +func computeAPICallHash(apiName string, arguments map[string]string) (string, error) { + hasher := xxhash.New() + _, err := hasher.WriteString(apiName) + if err != nil { + return "", err + } + + _, err = hasher.WriteString(":") + if err != nil { + return "", err + } + + keys := maps.Keys(arguments) + sort.Strings(keys) + + for _, key := range keys { + _, err = hasher.WriteString(key) + if err != nil { + return "", err + } + + _, err = hasher.WriteString(":") + if err != nil { + return "", err + } + + _, err = hasher.WriteString(arguments[key]) + if err != nil { + return "", err + } + + _, err = hasher.WriteString(";") + if err != nil { + return "", err + } + } + + return fmt.Sprintf("%x", hasher.Sum(nil)), nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go new file mode 100644 index 0000000..4c75aa0 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go @@ -0,0 +1,50 @@ +package v1 + +import ( + "crypto/sha256" + "fmt" + "sort" + + "golang.org/x/exp/maps" +) + +func computeAPICallHash(apiName string, arguments map[string]string) (string, error) { + h := sha256.New() + + _, err := h.Write([]byte(apiName)) + if err != nil { + return "", err + } + + _, err = h.Write([]byte(":")) + if err != nil { + return "", err + } + + keys := maps.Keys(arguments) + sort.Strings(keys) + + for _, key := range keys { + _, err = h.Write([]byte(key)) + if err != nil { + return "", err + } + + _, err = h.Write([]byte(":")) + if err != nil { + return "", err + } + + _, err = h.Write([]byte(arguments[key])) + if err != nil { + return "", err + } + + _, err = h.Write([]byte(";")) + if err != nil { + return "", err + } + } + + return fmt.Sprintf("%x", h.Sum(nil)), nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go b/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go new file mode 100644 index 0000000..d309c3b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go @@ -0,0 +1,12 @@ +package options + +import "time" + +//go:generate go run github.com/ecordell/optgen -output zz_generated.query_options.go . ExperimentalServerOptions + +type ExperimentalServerOptions struct { + StreamReadTimeout time.Duration `debugmap:"visible" default:"600s"` + DefaultExportBatchSize uint32 `debugmap:"visible" default:"1_000"` + MaxExportBatchSize uint32 `debugmap:"visible" default:"100_000"` + BulkCheckMaxConcurrency uint16 `debugmap:"visible" default:"50"` +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go b/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go new file mode 100644 index 0000000..5b75b5f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go @@ -0,0 +1,93 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package options + +import ( + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" + "time" +) + +type ExperimentalServerOptionsOption func(e *ExperimentalServerOptions) + +// NewExperimentalServerOptionsWithOptions creates a new ExperimentalServerOptions with the passed in options set +func NewExperimentalServerOptionsWithOptions(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions { + e := &ExperimentalServerOptions{} + for _, o := range opts { + o(e) + } + return e +} + +// NewExperimentalServerOptionsWithOptionsAndDefaults creates a new ExperimentalServerOptions with the passed in options set starting from the defaults +func NewExperimentalServerOptionsWithOptionsAndDefaults(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions { + e := &ExperimentalServerOptions{} + defaults.MustSet(e) + for _, o := range opts { + o(e) + } + return e +} + +// ToOption returns a new ExperimentalServerOptionsOption that sets the values from the passed in ExperimentalServerOptions +func (e *ExperimentalServerOptions) ToOption() ExperimentalServerOptionsOption { + return func(to *ExperimentalServerOptions) { + to.StreamReadTimeout = e.StreamReadTimeout + to.DefaultExportBatchSize = e.DefaultExportBatchSize + to.MaxExportBatchSize = e.MaxExportBatchSize + to.BulkCheckMaxConcurrency = e.BulkCheckMaxConcurrency + } +} + +// DebugMap returns a map form of ExperimentalServerOptions for debugging +func (e ExperimentalServerOptions) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["StreamReadTimeout"] = helpers.DebugValue(e.StreamReadTimeout, false) + debugMap["DefaultExportBatchSize"] = helpers.DebugValue(e.DefaultExportBatchSize, false) + debugMap["MaxExportBatchSize"] = helpers.DebugValue(e.MaxExportBatchSize, false) + debugMap["BulkCheckMaxConcurrency"] = helpers.DebugValue(e.BulkCheckMaxConcurrency, false) + return debugMap +} + +// ExperimentalServerOptionsWithOptions configures an existing ExperimentalServerOptions with the passed in options set +func ExperimentalServerOptionsWithOptions(e *ExperimentalServerOptions, opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions { + for _, o := range opts { + o(e) + } + return e +} + +// WithOptions configures the receiver ExperimentalServerOptions with the passed in options set +func (e *ExperimentalServerOptions) WithOptions(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions { + for _, o := range opts { + o(e) + } + return e +} + +// WithStreamReadTimeout returns an option that can set StreamReadTimeout on a ExperimentalServerOptions +func WithStreamReadTimeout(streamReadTimeout time.Duration) ExperimentalServerOptionsOption { + return func(e *ExperimentalServerOptions) { + e.StreamReadTimeout = streamReadTimeout + } +} + +// WithDefaultExportBatchSize returns an option that can set DefaultExportBatchSize on a ExperimentalServerOptions +func WithDefaultExportBatchSize(defaultExportBatchSize uint32) ExperimentalServerOptionsOption { + return func(e *ExperimentalServerOptions) { + e.DefaultExportBatchSize = defaultExportBatchSize + } +} + +// WithMaxExportBatchSize returns an option that can set MaxExportBatchSize on a ExperimentalServerOptions +func WithMaxExportBatchSize(maxExportBatchSize uint32) ExperimentalServerOptionsOption { + return func(e *ExperimentalServerOptions) { + e.MaxExportBatchSize = maxExportBatchSize + } +} + +// WithBulkCheckMaxConcurrency returns an option that can set BulkCheckMaxConcurrency on a ExperimentalServerOptions +func WithBulkCheckMaxConcurrency(bulkCheckMaxConcurrency uint16) ExperimentalServerOptionsOption { + return func(e *ExperimentalServerOptions) { + e.BulkCheckMaxConcurrency = bulkCheckMaxConcurrency + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go b/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go new file mode 100644 index 0000000..da6dd18 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go @@ -0,0 +1,1094 @@ +package v1 + +import ( + "context" + "errors" + "fmt" + "io" + "slices" + "strings" + + "github.com/authzed/authzed-go/pkg/requestmeta" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/jzelinskie/stringz" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + + cexpr "github.com/authzed/spicedb/internal/caveats" + dispatchpkg "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph" + "github.com/authzed/spicedb/internal/graph/computed" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/relationships" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/internal/telemetry" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +func (ps *permissionServer) rewriteError(ctx context.Context, err error) error { + return shared.RewriteError(ctx, err, &shared.ConfigForErrors{ + MaximumAPIDepth: ps.config.MaximumAPIDepth, + }) +} + +func (ps *permissionServer) rewriteErrorWithOptionalDebugTrace(ctx context.Context, err error, debugTrace *v1.DebugInformation) error { + return shared.RewriteError(ctx, err, &shared.ConfigForErrors{ + MaximumAPIDepth: ps.config.MaximumAPIDepth, + DebugTrace: debugTrace, + }) +} + +func (ps *permissionServer) CheckPermission(ctx context.Context, req *v1.CheckPermissionRequest) (*v1.CheckPermissionResponse, error) { + telemetry.RecordLogicalChecks(1) + + atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + if err := namespace.CheckNamespaceAndRelations(ctx, + []namespace.TypeAndRelationToCheck{ + { + NamespaceName: req.Resource.ObjectType, + RelationName: req.Permission, + AllowEllipsis: false, + }, + { + NamespaceName: req.Subject.Object.ObjectType, + RelationName: normalizeSubjectRelation(req.Subject), + AllowEllipsis: true, + }, + }, ds); err != nil { + return nil, ps.rewriteError(ctx, err) + } + + debugOption := computed.NoDebugging + + if md, ok := metadata.FromIncomingContext(ctx); ok { + _, isDebuggingEnabled := md[string(requestmeta.RequestDebugInformation)] + if isDebuggingEnabled { + debugOption = computed.BasicDebuggingEnabled + } + } + + if req.WithTracing { + debugOption = computed.BasicDebuggingEnabled + } + + cr, metadata, err := computed.ComputeCheck(ctx, ps.dispatch, + ps.config.CaveatTypeSet, + computed.CheckParameters{ + ResourceType: tuple.RR(req.Resource.ObjectType, req.Permission), + Subject: tuple.ONR(req.Subject.Object.ObjectType, req.Subject.Object.ObjectId, normalizeSubjectRelation(req.Subject)), + CaveatContext: caveatContext, + AtRevision: atRevision, + MaximumDepth: ps.config.MaximumAPIDepth, + DebugOption: debugOption, + }, + req.Resource.ObjectId, + ps.config.DispatchChunkSize, + ) + usagemetrics.SetInContext(ctx, metadata) + + var debugTrace *v1.DebugInformation + if debugOption != computed.NoDebugging && metadata.DebugInfo != nil { + // Convert the dispatch debug information into API debug information. + converted, cerr := ConvertCheckDispatchDebugInformation(ctx, ps.config.CaveatTypeSet, caveatContext, metadata.DebugInfo, ds) + if cerr != nil { + return nil, ps.rewriteError(ctx, cerr) + } + debugTrace = converted + } + + if err != nil { + // If the error already contains debug information, rewrite it. This can happen if + // a dispatch error occurs and debug was requested. + if dispatchDebugInfo, ok := spiceerrors.GetDetails[*dispatch.DebugInformation](err); ok { + // Convert the dispatch debug information into API debug information. + converted, cerr := ConvertCheckDispatchDebugInformation(ctx, ps.config.CaveatTypeSet, caveatContext, dispatchDebugInfo, ds) + if cerr != nil { + return nil, ps.rewriteError(ctx, cerr) + } + + if converted != nil { + return nil, spiceerrors.AppendDetailsMetadata(err, spiceerrors.DebugTraceErrorDetailsKey, converted.String()) + } + } + + return nil, ps.rewriteErrorWithOptionalDebugTrace(ctx, err, debugTrace) + } + + permissionship, partialCaveat := checkResultToAPITypes(cr) + + return &v1.CheckPermissionResponse{ + CheckedAt: checkedAt, + Permissionship: permissionship, + PartialCaveatInfo: partialCaveat, + DebugTrace: debugTrace, + }, nil +} + +func checkResultToAPITypes(cr *dispatch.ResourceCheckResult) (v1.CheckPermissionResponse_Permissionship, *v1.PartialCaveatInfo) { + var partialCaveat *v1.PartialCaveatInfo + permissionship := v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION + if cr.Membership == dispatch.ResourceCheckResult_MEMBER { + permissionship = v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION + } else if cr.Membership == dispatch.ResourceCheckResult_CAVEATED_MEMBER { + permissionship = v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION + partialCaveat = &v1.PartialCaveatInfo{ + MissingRequiredContext: cr.MissingExprFields, + } + } + return permissionship, partialCaveat +} + +func (ps *permissionServer) CheckBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) { + res, err := ps.bulkChecker.checkBulkPermissions(ctx, req) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + return res, nil +} + +func pairItemFromCheckResult(checkResult *dispatch.ResourceCheckResult, debugTrace *v1.DebugInformation) *v1.CheckBulkPermissionsPair_Item { + permissionship, partialCaveat := checkResultToAPITypes(checkResult) + return &v1.CheckBulkPermissionsPair_Item{ + Item: &v1.CheckBulkPermissionsResponseItem{ + Permissionship: permissionship, + PartialCaveatInfo: partialCaveat, + DebugTrace: debugTrace, + }, + } +} + +func requestItemFromResourceAndParameters(params *computed.CheckParameters, resourceID string) (*v1.CheckBulkPermissionsRequestItem, error) { + item := &v1.CheckBulkPermissionsRequestItem{ + Resource: &v1.ObjectReference{ + ObjectType: params.ResourceType.ObjectType, + ObjectId: resourceID, + }, + Permission: params.ResourceType.Relation, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: params.Subject.ObjectType, + ObjectId: params.Subject.ObjectID, + }, + OptionalRelation: denormalizeSubjectRelation(params.Subject.Relation), + }, + } + if len(params.CaveatContext) > 0 { + var err error + item.Context, err = structpb.NewStruct(params.CaveatContext) + if err != nil { + return nil, fmt.Errorf("caveat context wasn't properly validated: %w", err) + } + } + return item, nil +} + +func (ps *permissionServer) ExpandPermissionTree(ctx context.Context, req *v1.ExpandPermissionTreeRequest) (*v1.ExpandPermissionTreeResponse, error) { + telemetry.RecordLogicalChecks(1) + + atRevision, expandedAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + err = namespace.CheckNamespaceAndRelation(ctx, req.Resource.ObjectType, req.Permission, false, ds) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth)) + if err != nil { + return nil, err + } + + resp, err := ps.dispatch.DispatchExpand(ctx, &dispatch.DispatchExpandRequest{ + Metadata: &dispatch.ResolverMeta{ + AtRevision: atRevision.String(), + DepthRemaining: ps.config.MaximumAPIDepth, + TraversalBloom: bf, + }, + ResourceAndRelation: &core.ObjectAndRelation{ + Namespace: req.Resource.ObjectType, + ObjectId: req.Resource.ObjectId, + Relation: req.Permission, + }, + ExpansionMode: dispatch.DispatchExpandRequest_SHALLOW, + }) + usagemetrics.SetInContext(ctx, resp.Metadata) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + // TODO(jschorr): Change to either using shared interfaces for nodes, or switch the internal + // dispatched expand to return V1 node types. + return &v1.ExpandPermissionTreeResponse{ + TreeRoot: TranslateExpansionTree(resp.TreeNode), + ExpandedAt: expandedAt, + }, nil +} + +// TranslateRelationshipTree translates a V1 PermissionRelationshipTree into a RelationTupleTreeNode. +func TranslateRelationshipTree(tree *v1.PermissionRelationshipTree) *core.RelationTupleTreeNode { + var expanded *core.ObjectAndRelation + if tree.ExpandedObject != nil { + expanded = &core.ObjectAndRelation{ + Namespace: tree.ExpandedObject.ObjectType, + ObjectId: tree.ExpandedObject.ObjectId, + Relation: tree.ExpandedRelation, + } + } + + switch t := tree.TreeType.(type) { + case *v1.PermissionRelationshipTree_Intermediate: + var operation core.SetOperationUserset_Operation + switch t.Intermediate.Operation { + case v1.AlgebraicSubjectSet_OPERATION_EXCLUSION: + operation = core.SetOperationUserset_EXCLUSION + case v1.AlgebraicSubjectSet_OPERATION_INTERSECTION: + operation = core.SetOperationUserset_INTERSECTION + case v1.AlgebraicSubjectSet_OPERATION_UNION: + operation = core.SetOperationUserset_UNION + default: + panic("unknown set operation") + } + + children := []*core.RelationTupleTreeNode{} + for _, child := range t.Intermediate.Children { + children = append(children, TranslateRelationshipTree(child)) + } + + return &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_IntermediateNode{ + IntermediateNode: &core.SetOperationUserset{ + Operation: operation, + ChildNodes: children, + }, + }, + Expanded: expanded, + } + + case *v1.PermissionRelationshipTree_Leaf: + var subjects []*core.DirectSubject + for _, subj := range t.Leaf.Subjects { + subjects = append(subjects, &core.DirectSubject{ + Subject: &core.ObjectAndRelation{ + Namespace: subj.Object.ObjectType, + ObjectId: subj.Object.ObjectId, + Relation: stringz.DefaultEmpty(subj.OptionalRelation, graph.Ellipsis), + }, + }) + } + + return &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_LeafNode{ + LeafNode: &core.DirectSubjects{Subjects: subjects}, + }, + Expanded: expanded, + } + + default: + panic("unknown type of expansion tree node") + } +} + +func TranslateExpansionTree(node *core.RelationTupleTreeNode) *v1.PermissionRelationshipTree { + switch t := node.NodeType.(type) { + case *core.RelationTupleTreeNode_IntermediateNode: + var operation v1.AlgebraicSubjectSet_Operation + switch t.IntermediateNode.Operation { + case core.SetOperationUserset_EXCLUSION: + operation = v1.AlgebraicSubjectSet_OPERATION_EXCLUSION + case core.SetOperationUserset_INTERSECTION: + operation = v1.AlgebraicSubjectSet_OPERATION_INTERSECTION + case core.SetOperationUserset_UNION: + operation = v1.AlgebraicSubjectSet_OPERATION_UNION + default: + panic("unknown set operation") + } + + var children []*v1.PermissionRelationshipTree + for _, child := range node.GetIntermediateNode().ChildNodes { + children = append(children, TranslateExpansionTree(child)) + } + + var objRef *v1.ObjectReference + var objRel string + if node.Expanded != nil { + objRef = &v1.ObjectReference{ + ObjectType: node.Expanded.Namespace, + ObjectId: node.Expanded.ObjectId, + } + objRel = node.Expanded.Relation + } + + return &v1.PermissionRelationshipTree{ + TreeType: &v1.PermissionRelationshipTree_Intermediate{ + Intermediate: &v1.AlgebraicSubjectSet{ + Operation: operation, + Children: children, + }, + }, + ExpandedObject: objRef, + ExpandedRelation: objRel, + } + + case *core.RelationTupleTreeNode_LeafNode: + var subjects []*v1.SubjectReference + for _, found := range t.LeafNode.Subjects { + subjects = append(subjects, &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: found.Subject.Namespace, + ObjectId: found.Subject.ObjectId, + }, + OptionalRelation: denormalizeSubjectRelation(found.Subject.Relation), + }) + } + + if node.Expanded == nil { + return &v1.PermissionRelationshipTree{ + TreeType: &v1.PermissionRelationshipTree_Leaf{ + Leaf: &v1.DirectSubjectSet{ + Subjects: subjects, + }, + }, + } + } + + return &v1.PermissionRelationshipTree{ + TreeType: &v1.PermissionRelationshipTree_Leaf{ + Leaf: &v1.DirectSubjectSet{ + Subjects: subjects, + }, + }, + ExpandedObject: &v1.ObjectReference{ + ObjectType: node.Expanded.Namespace, + ObjectId: node.Expanded.ObjectId, + }, + ExpandedRelation: node.Expanded.Relation, + } + + default: + panic("unknown type of expansion tree node") + } +} + +const lrv2CursorFlag = "lrv2" + +func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp v1.PermissionsService_LookupResourcesServer) error { + // NOTE: LRv2 is the only valid option, and we'll expect that all cursors include that flag. + // This is to preserve backward-compatibility in the meantime. + if req.OptionalCursor != nil { + _, _, err := cursor.GetCursorFlag(req.OptionalCursor, lrv2CursorFlag) + if err != nil { + return ps.rewriteError(resp.Context(), err) + } + } + + if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxLookupResourcesLimit { + return ps.rewriteError(resp.Context(), NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxLookupResourcesLimit))) + } + + ctx := resp.Context() + + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + if err := namespace.CheckNamespaceAndRelations(ctx, + []namespace.TypeAndRelationToCheck{ + { + NamespaceName: req.ResourceObjectType, + RelationName: req.Permission, + AllowEllipsis: false, + }, + { + NamespaceName: req.Subject.Object.ObjectType, + RelationName: normalizeSubjectRelation(req.Subject), + AllowEllipsis: true, + }, + }, ds); err != nil { + return ps.rewriteError(ctx, err) + } + + respMetadata := &dispatch.ResponseMeta{ + DispatchCount: 1, + CachedDispatchCount: 0, + DepthRequired: 1, + DebugInfo: nil, + } + usagemetrics.SetInContext(ctx, respMetadata) + + var currentCursor *dispatch.Cursor + + lrRequestHash, err := computeLRRequestHash(req) + if err != nil { + return ps.rewriteError(ctx, err) + } + + if req.OptionalCursor != nil { + decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash) + if err != nil { + return ps.rewriteError(ctx, err) + } + currentCursor = decodedCursor + } + + alreadyPublishedPermissionedResourceIds := map[string]struct{}{} + var totalCountPublished uint64 + defer func() { + telemetry.RecordLogicalChecks(totalCountPublished) + }() + + stream := dispatchpkg.NewHandlingDispatchStream(ctx, func(result *dispatch.DispatchLookupResources2Response) error { + found := result.Resource + + dispatchpkg.AddResponseMetadata(respMetadata, result.Metadata) + currentCursor = result.AfterResponseCursor + + var partial *v1.PartialCaveatInfo + permissionship := v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION + if len(found.MissingContextParams) > 0 { + permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION + partial = &v1.PartialCaveatInfo{ + MissingRequiredContext: found.MissingContextParams, + } + } else if req.OptionalLimit == 0 { + if _, ok := alreadyPublishedPermissionedResourceIds[found.ResourceId]; ok { + // Skip publishing the duplicate. + return nil + } + + // TODO(jschorr): Investigate something like a Trie here for better memory efficiency. + alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{} + } + + encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, map[string]string{ + lrv2CursorFlag: "1", + }) + if err != nil { + return ps.rewriteError(ctx, err) + } + + err = resp.Send(&v1.LookupResourcesResponse{ + LookedUpAt: revisionReadAt, + ResourceObjectId: found.ResourceId, + Permissionship: permissionship, + PartialCaveatInfo: partial, + AfterResultCursor: encodedCursor, + }) + if err != nil { + return err + } + + totalCountPublished++ + return nil + }) + + bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth)) + if err != nil { + return err + } + + err = ps.dispatch.DispatchLookupResources2( + &dispatch.DispatchLookupResources2Request{ + Metadata: &dispatch.ResolverMeta{ + AtRevision: atRevision.String(), + DepthRemaining: ps.config.MaximumAPIDepth, + TraversalBloom: bf, + }, + ResourceRelation: &core.RelationReference{ + Namespace: req.ResourceObjectType, + Relation: req.Permission, + }, + SubjectRelation: &core.RelationReference{ + Namespace: req.Subject.Object.ObjectType, + Relation: normalizeSubjectRelation(req.Subject), + }, + SubjectIds: []string{req.Subject.Object.ObjectId}, + TerminalSubject: &core.ObjectAndRelation{ + Namespace: req.Subject.Object.ObjectType, + ObjectId: req.Subject.Object.ObjectId, + Relation: normalizeSubjectRelation(req.Subject), + }, + Context: req.Context, + OptionalCursor: currentCursor, + OptionalLimit: req.OptionalLimit, + }, + stream) + if err != nil { + return ps.rewriteError(ctx, err) + } + + return nil +} + +func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, resp v1.PermissionsService_LookupSubjectsServer) error { + ctx := resp.Context() + + if req.OptionalConcreteLimit != 0 { + return ps.rewriteError(ctx, status.Errorf(codes.Unimplemented, "concrete limit is not yet supported")) + } + + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) + if err != nil { + return ps.rewriteError(ctx, err) + } + + if err := namespace.CheckNamespaceAndRelations(ctx, + []namespace.TypeAndRelationToCheck{ + { + NamespaceName: req.Resource.ObjectType, + RelationName: req.Permission, + AllowEllipsis: false, + }, + { + NamespaceName: req.SubjectObjectType, + RelationName: stringz.DefaultEmpty(req.OptionalSubjectRelation, tuple.Ellipsis), + AllowEllipsis: true, + }, + }, ds); err != nil { + return ps.rewriteError(ctx, err) + } + + respMetadata := &dispatch.ResponseMeta{ + DispatchCount: 0, + CachedDispatchCount: 0, + DepthRequired: 0, + DebugInfo: nil, + } + usagemetrics.SetInContext(ctx, respMetadata) + + var totalCountPublished uint64 + defer func() { + telemetry.RecordLogicalChecks(totalCountPublished) + }() + + stream := dispatchpkg.NewHandlingDispatchStream(ctx, func(result *dispatch.DispatchLookupSubjectsResponse) error { + foundSubjects, ok := result.FoundSubjectsByResourceId[req.Resource.ObjectId] + if !ok { + return fmt.Errorf("missing resource ID in returned LS") + } + + for _, foundSubject := range foundSubjects.FoundSubjects { + excludedSubjectIDs := make([]string, 0, len(foundSubject.ExcludedSubjects)) + for _, excludedSubject := range foundSubject.ExcludedSubjects { + excludedSubjectIDs = append(excludedSubjectIDs, excludedSubject.SubjectId) + } + + excludedSubjects := make([]*v1.ResolvedSubject, 0, len(foundSubject.ExcludedSubjects)) + for _, excludedSubject := range foundSubject.ExcludedSubjects { + resolvedExcludedSubject, err := foundSubjectToResolvedSubject(ctx, excludedSubject, caveatContext, ds, ps.config.CaveatTypeSet) + if err != nil { + return err + } + + if resolvedExcludedSubject == nil { + continue + } + + excludedSubjects = append(excludedSubjects, resolvedExcludedSubject) + } + + subject, err := foundSubjectToResolvedSubject(ctx, foundSubject, caveatContext, ds, ps.config.CaveatTypeSet) + if err != nil { + return err + } + if subject == nil { + continue + } + + err = resp.Send(&v1.LookupSubjectsResponse{ + Subject: subject, + ExcludedSubjects: excludedSubjects, + LookedUpAt: revisionReadAt, + SubjectObjectId: foundSubject.SubjectId, // Deprecated + ExcludedSubjectIds: excludedSubjectIDs, // Deprecated + Permissionship: subject.Permissionship, // Deprecated + PartialCaveatInfo: subject.PartialCaveatInfo, // Deprecated + }) + if err != nil { + return err + } + } + + totalCountPublished++ + dispatchpkg.AddResponseMetadata(respMetadata, result.Metadata) + return nil + }) + + bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth)) + if err != nil { + return err + } + + err = ps.dispatch.DispatchLookupSubjects( + &dispatch.DispatchLookupSubjectsRequest{ + Metadata: &dispatch.ResolverMeta{ + AtRevision: atRevision.String(), + DepthRemaining: ps.config.MaximumAPIDepth, + TraversalBloom: bf, + }, + ResourceRelation: &core.RelationReference{ + Namespace: req.Resource.ObjectType, + Relation: req.Permission, + }, + ResourceIds: []string{req.Resource.ObjectId}, + SubjectRelation: &core.RelationReference{ + Namespace: req.SubjectObjectType, + Relation: stringz.DefaultEmpty(req.OptionalSubjectRelation, tuple.Ellipsis), + }, + }, + stream) + if err != nil { + return ps.rewriteError(ctx, err) + } + + return nil +} + +func foundSubjectToResolvedSubject(ctx context.Context, foundSubject *dispatch.FoundSubject, caveatContext map[string]any, ds datastore.CaveatReader, caveatTypeSet *caveattypes.TypeSet) (*v1.ResolvedSubject, error) { + var partialCaveat *v1.PartialCaveatInfo + permissionship := v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION + if foundSubject.GetCaveatExpression() != nil { + permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION + + cr, err := cexpr.RunSingleCaveatExpression(ctx, caveatTypeSet, foundSubject.GetCaveatExpression(), caveatContext, ds, cexpr.RunCaveatExpressionNoDebugging) + if err != nil { + return nil, err + } + + if cr.Value() { + permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION + } else if cr.IsPartial() { + missingFields, _ := cr.MissingVarNames() + partialCaveat = &v1.PartialCaveatInfo{ + MissingRequiredContext: missingFields, + } + } else { + // Skip this found subject. + return nil, nil + } + } + + return &v1.ResolvedSubject{ + SubjectObjectId: foundSubject.SubjectId, + Permissionship: permissionship, + PartialCaveatInfo: partialCaveat, + }, nil +} + +func normalizeSubjectRelation(sub *v1.SubjectReference) string { + if sub.OptionalRelation == "" { + return graph.Ellipsis + } + return sub.OptionalRelation +} + +func denormalizeSubjectRelation(relation string) string { + if relation == graph.Ellipsis { + return "" + } + return relation +} + +func GetCaveatContext(ctx context.Context, caveatCtx *structpb.Struct, maxCaveatContextSize int) (map[string]any, error) { + var caveatContext map[string]any + if caveatCtx != nil { + if size := proto.Size(caveatCtx); maxCaveatContextSize > 0 && size > maxCaveatContextSize { + return nil, shared.RewriteError( + ctx, + status.Errorf( + codes.InvalidArgument, + "request caveat context should have less than %d bytes but had %d", + maxCaveatContextSize, + size, + ), + nil, + ) + } + caveatContext = caveatCtx.AsMap() + } + return caveatContext, nil +} + +type loadBulkAdapter struct { + stream grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse] + referencedNamespaceMap map[string]*schema.Definition + referencedCaveatMap map[string]*core.CaveatDefinition + current tuple.Relationship + caveat core.ContextualizedCaveat + caveatTypeSet *caveattypes.TypeSet + + awaitingNamespaces []string + awaitingCaveats []string + + currentBatch []*v1.Relationship + numSent int + err error +} + +func (a *loadBulkAdapter) Next(_ context.Context) (*tuple.Relationship, error) { + for a.err == nil && a.numSent == len(a.currentBatch) { + // Load a new batch + batch, err := a.stream.Recv() + if err != nil { + a.err = err + if errors.Is(a.err, io.EOF) { + return nil, nil + } + return nil, a.err + } + + a.currentBatch = batch.Relationships + a.numSent = 0 + + a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats( + a.currentBatch, + a.referencedNamespaceMap, + a.referencedCaveatMap, + ) + } + + if len(a.awaitingNamespaces) > 0 || len(a.awaitingCaveats) > 0 { + // Shut down the stream to give our caller a chance to fill in this information + return nil, nil + } + + a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType + a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId + a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation + a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType + a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId + a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis) + + if a.currentBatch[a.numSent].OptionalCaveat != nil { + a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName + a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context + a.current.OptionalCaveat = &a.caveat + } else { + a.current.OptionalCaveat = nil + } + + if a.currentBatch[a.numSent].OptionalExpiresAt != nil { + t := a.currentBatch[a.numSent].OptionalExpiresAt.AsTime() + a.current.OptionalExpiration = &t + } else { + a.current.OptionalExpiration = nil + } + + a.current.OptionalIntegrity = nil + + if err := relationships.ValidateOneRelationship( + a.referencedNamespaceMap, + a.referencedCaveatMap, + a.caveatTypeSet, + a.current, + relationships.ValidateRelationshipForCreateOrTouch, + ); err != nil { + return nil, err + } + + a.numSent++ + return &a.current, nil +} + +func (ps *permissionServer) ImportBulkRelationships(stream grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error { + ds := datastoremw.MustFromContext(stream.Context()) + + var numWritten uint64 + if _, err := ds.ReadWriteTx(stream.Context(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + loadedNamespaces := make(map[string]*schema.Definition, 2) + loadedCaveats := make(map[string]*core.CaveatDefinition, 0) + + adapter := &loadBulkAdapter{ + stream: stream, + referencedNamespaceMap: loadedNamespaces, + referencedCaveatMap: loadedCaveats, + caveat: core.ContextualizedCaveat{}, + caveatTypeSet: ps.config.CaveatTypeSet, + } + resolver := schema.ResolverForDatastoreReader(rwt) + ts := schema.NewTypeSystem(resolver) + + var streamWritten uint64 + var err error + for ; adapter.err == nil && err == nil; streamWritten, err = rwt.BulkLoad(stream.Context(), adapter) { + numWritten += streamWritten + + // The stream has terminated because we're awaiting namespace and/or caveat information + if len(adapter.awaitingNamespaces) > 0 { + nsDefs, err := rwt.LookupNamespacesWithNames(stream.Context(), adapter.awaitingNamespaces) + if err != nil { + return err + } + + for _, nsDef := range nsDefs { + newDef, err := schema.NewDefinition(ts, nsDef.Definition) + if err != nil { + return err + } + + loadedNamespaces[nsDef.Definition.Name] = newDef + } + adapter.awaitingNamespaces = nil + } + + if len(adapter.awaitingCaveats) > 0 { + caveats, err := rwt.LookupCaveatsWithNames(stream.Context(), adapter.awaitingCaveats) + if err != nil { + return err + } + + for _, caveat := range caveats { + loadedCaveats[caveat.Definition.Name] = caveat.Definition + } + adapter.awaitingCaveats = nil + } + } + numWritten += streamWritten + + return err + }, dsoptions.WithDisableRetries(true)); err != nil { + return shared.RewriteErrorWithoutConfig(stream.Context(), err) + } + + usagemetrics.SetInContext(stream.Context(), &dispatch.ResponseMeta{ + // One request for the whole load + DispatchCount: 1, + }) + + return stream.SendAndClose(&v1.ImportBulkRelationshipsResponse{ + NumLoaded: numWritten, + }) +} + +func (ps *permissionServer) ExportBulkRelationships( + req *v1.ExportBulkRelationshipsRequest, + resp grpc.ServerStreamingServer[v1.ExportBulkRelationshipsResponse], +) error { + ctx := resp.Context() + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + return ExportBulk(ctx, datastoremw.MustFromContext(ctx), uint64(ps.config.MaxBulkExportRelationshipsLimit), req, atRevision, resp.Send) +} + +// ExportBulk implements the ExportBulkRelationships API functionality. Given a datastore.Datastore, it will +// export stream via the sender all relationships matched by the incoming request. +// If no cursor is provided, it will fallback to the provided revision. +func ExportBulk(ctx context.Context, ds datastore.Datastore, batchSize uint64, req *v1.ExportBulkRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.ExportBulkRelationshipsResponse) error) error { + if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { + return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) + } + + atRevision := fallbackRevision + var curNamespace string + var cur dsoptions.Cursor + if req.OptionalCursor != nil { + var err error + atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + } + + reader := ds.SnapshotReader(atRevision) + + namespaces, err := reader.ListAllNamespaces(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Make sure the namespaces are always in a stable order + slices.SortFunc(namespaces, func( + lhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + rhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + ) int { + return strings.Compare(lhs.Definition.Name, rhs.Definition.Name) + }) + + // Skip the namespaces that are already fully returned + for cur != nil && len(namespaces) > 0 && namespaces[0].Definition.Name < curNamespace { + namespaces = namespaces[1:] + } + + limit := batchSize + if req.OptionalLimit > 0 { + limit = uint64(req.OptionalLimit) + } + + // Pre-allocate all of the relationships that we might need in order to + // make export easier and faster for the garbage collector. + relsArray := make([]v1.Relationship, limit) + objArray := make([]v1.ObjectReference, limit) + subArray := make([]v1.SubjectReference, limit) + subObjArray := make([]v1.ObjectReference, limit) + caveatArray := make([]v1.ContextualizedCaveat, limit) + for i := range relsArray { + relsArray[i].Resource = &objArray[i] + relsArray[i].Subject = &subArray[i] + relsArray[i].Subject.Object = &subObjArray[i] + } + + emptyRels := make([]*v1.Relationship, limit) + // The number of batches/dispatches for the purpose of usage metrics + var batches uint32 + for _, ns := range namespaces { + rels := emptyRels + + // Reset the cursor between namespaces. + if ns.Definition.Name != curNamespace { + cur = nil + } + + // Skip this namespace if a resource type filter was specified. + if req.OptionalRelationshipFilter != nil && req.OptionalRelationshipFilter.ResourceType != "" { + if ns.Definition.Name != req.OptionalRelationshipFilter.ResourceType { + continue + } + } + + // Setup the filter to use for the relationships. + relationshipFilter := datastore.RelationshipsFilter{OptionalResourceType: ns.Definition.Name} + if req.OptionalRelationshipFilter != nil { + rf, err := datastore.RelationshipsFilterFromPublicFilter(req.OptionalRelationshipFilter) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Overload the namespace name with the one from the request, because each iteration is for a different namespace. + rf.OptionalResourceType = ns.Definition.Name + relationshipFilter = rf + } + + // We want to keep iterating as long as we're sending full batches. + // To bootstrap this loop, we enter the first time with a full rels + // slice of dummy rels that were never sent. + for uint64(len(rels)) == limit { + // Lop off any rels we've already sent + rels = rels[:0] + + relFn := func(rel tuple.Relationship) { + offset := len(rels) + rels = append(rels, &relsArray[offset]) // nozero + + v1Rel := &relsArray[offset] + v1Rel.Resource.ObjectType = rel.RelationshipReference.Resource.ObjectType + v1Rel.Resource.ObjectId = rel.RelationshipReference.Resource.ObjectID + v1Rel.Relation = rel.RelationshipReference.Resource.Relation + v1Rel.Subject.Object.ObjectType = rel.RelationshipReference.Subject.ObjectType + v1Rel.Subject.Object.ObjectId = rel.RelationshipReference.Subject.ObjectID + v1Rel.Subject.OptionalRelation = denormalizeSubjectRelation(rel.RelationshipReference.Subject.Relation) + + if rel.OptionalCaveat != nil { + caveatArray[offset].CaveatName = rel.OptionalCaveat.CaveatName + caveatArray[offset].Context = rel.OptionalCaveat.Context + v1Rel.OptionalCaveat = &caveatArray[offset] + } else { + caveatArray[offset].CaveatName = "" + caveatArray[offset].Context = nil + v1Rel.OptionalCaveat = nil + } + + if rel.OptionalExpiration != nil { + v1Rel.OptionalExpiresAt = timestamppb.New(*rel.OptionalExpiration) + } else { + v1Rel.OptionalExpiresAt = nil + } + } + + cur, err = queryForEach( + ctx, + reader, + relationshipFilter, + relFn, + dsoptions.WithLimit(&limit), + dsoptions.WithAfter(cur), + dsoptions.WithSort(dsoptions.ByResource), + dsoptions.WithQueryShape(queryshape.Varying), + ) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + if len(rels) == 0 { + continue + } + + encoded, err := cursor.Encode(&implv1.DecodedCursor{ + VersionOneof: &implv1.DecodedCursor_V1{ + V1: &implv1.V1Cursor{ + Revision: atRevision.String(), + Sections: []string{ + ns.Definition.Name, + tuple.MustString(*dsoptions.ToRelationship(cur)), + }, + }, + }, + }) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + if err := sender(&v1.ExportBulkRelationshipsResponse{ + AfterResultCursor: encoded, + Relationships: rels, + }); err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + // Increment batches for usagemetrics + batches++ + } + } + + // Record usage metrics + respMetadata := &dispatch.ResponseMeta{ + DispatchCount: batches, + } + usagemetrics.SetInContext(ctx, respMetadata) + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go b/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go new file mode 100644 index 0000000..c34d5d5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go @@ -0,0 +1,54 @@ +package v1 + +import ( + "context" + "fmt" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" +) + +var limitOne uint64 = 1 + +// checkPreconditions checks whether the preconditions are met in the context of a datastore +// read-write transaction, and returns an error if they are not met. +func checkPreconditions( + ctx context.Context, + rwt datastore.ReadWriteTransaction, + preconditions []*v1.Precondition, +) error { + for _, precond := range preconditions { + dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(precond.Filter) + if err != nil { + return fmt.Errorf("error converting filter: %w", err) + } + + iter, err := rwt.QueryRelationships(ctx, dsFilter, options.WithLimit(&limitOne), options.WithQueryShape(queryshape.Varying)) + if err != nil { + return fmt.Errorf("error reading relationships: %w", err) + } + + _, ok, err := datastore.FirstRelationshipIn(iter) + if err != nil { + return fmt.Errorf("error reading relationships from iterator: %w", err) + } + + switch precond.Operation { + case v1.Precondition_OPERATION_MUST_NOT_MATCH: + if ok { + return NewPreconditionFailedErr(precond) + } + case v1.Precondition_OPERATION_MUST_MATCH: + if !ok { + return NewPreconditionFailedErr(precond) + } + default: + return fmt.Errorf("unspecified precondition operation: %s", precond.Operation) + } + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go new file mode 100644 index 0000000..723a8d3 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go @@ -0,0 +1,720 @@ +package v1 + +import ( + "sort" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/diff" + caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats" + nsdiff "github.com/authzed/spicedb/pkg/diff/namespace" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +type schemaFilters struct { + filters []*v1.ReflectionSchemaFilter +} + +func newSchemaFilters(filters []*v1.ReflectionSchemaFilter) (*schemaFilters, error) { + for _, filter := range filters { + if filter.OptionalDefinitionNameFilter != "" { + if filter.OptionalCaveatNameFilter != "" { + return nil, NewInvalidFilterErr("cannot filter by both definition and caveat name", filter.String()) + } + } + + if filter.OptionalRelationNameFilter != "" { + if filter.OptionalDefinitionNameFilter == "" { + return nil, NewInvalidFilterErr("relation name match requires definition name match", filter.String()) + } + + if filter.OptionalPermissionNameFilter != "" { + return nil, NewInvalidFilterErr("cannot filter by both relation and permission name", filter.String()) + } + } + + if filter.OptionalPermissionNameFilter != "" { + if filter.OptionalDefinitionNameFilter == "" { + return nil, NewInvalidFilterErr("permission name match requires definition name match", filter.String()) + } + } + } + + return &schemaFilters{filters: filters}, nil +} + +func (sf *schemaFilters) HasNamespaces() bool { + if len(sf.filters) == 0 { + return true + } + + for _, filter := range sf.filters { + if filter.OptionalDefinitionNameFilter != "" { + return true + } + } + + return false +} + +func (sf *schemaFilters) HasCaveats() bool { + if len(sf.filters) == 0 { + return true + } + + for _, filter := range sf.filters { + if filter.OptionalCaveatNameFilter != "" { + return true + } + } + + return false +} + +func (sf *schemaFilters) HasNamespace(namespaceName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasDefinitionFilter := false + for _, filter := range sf.filters { + if filter.OptionalDefinitionNameFilter == "" { + continue + } + + hasDefinitionFilter = true + isMatch := strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasDefinitionFilter +} + +func (sf *schemaFilters) HasCaveat(caveatName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasCaveatFilter := false + for _, filter := range sf.filters { + if filter.OptionalCaveatNameFilter == "" { + continue + } + + hasCaveatFilter = true + isMatch := strings.HasPrefix(caveatName, filter.OptionalCaveatNameFilter) + if isMatch { + return true + } + } + + return !hasCaveatFilter +} + +func (sf *schemaFilters) HasRelation(namespaceName, relationName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasRelationFilter := false + for _, filter := range sf.filters { + if filter.OptionalRelationNameFilter == "" { + continue + } + + hasRelationFilter = true + isMatch := strings.HasPrefix(relationName, filter.OptionalRelationNameFilter) + if !isMatch { + continue + } + + isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasRelationFilter +} + +func (sf *schemaFilters) HasPermission(namespaceName, permissionName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasPermissionFilter := false + for _, filter := range sf.filters { + if filter.OptionalPermissionNameFilter == "" { + continue + } + + hasPermissionFilter = true + isMatch := strings.HasPrefix(permissionName, filter.OptionalPermissionNameFilter) + if !isMatch { + continue + } + + isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasPermissionFilter +} + +// convertDiff converts a schema diff into an API response. +func convertDiff( + diff *diff.SchemaDiff, + existingSchema *diff.DiffableSchema, + comparisonSchema *diff.DiffableSchema, + atRevision datastore.Revision, + caveatTypeSet *caveattypes.TypeSet, +) (*v1.DiffSchemaResponse, error) { + size := len(diff.AddedNamespaces) + len(diff.RemovedNamespaces) + len(diff.AddedCaveats) + len(diff.RemovedCaveats) + len(diff.ChangedNamespaces) + len(diff.ChangedCaveats) + diffs := make([]*v1.ReflectionSchemaDiff, 0, size) + + // Add/remove namespaces. + for _, ns := range diff.AddedNamespaces { + nsDef, err := namespaceAPIReprForName(ns, comparisonSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_DefinitionAdded{ + DefinitionAdded: nsDef, + }, + }) + } + + for _, ns := range diff.RemovedNamespaces { + nsDef, err := namespaceAPIReprForName(ns, existingSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_DefinitionRemoved{ + DefinitionRemoved: nsDef, + }, + }) + } + + // Add/remove caveats. + for _, caveat := range diff.AddedCaveats { + caveatDef, err := caveatAPIReprForName(caveat, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatAdded{ + CaveatAdded: caveatDef, + }, + }) + } + + for _, caveat := range diff.RemovedCaveats { + caveatDef, err := caveatAPIReprForName(caveat, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatRemoved{ + CaveatRemoved: caveatDef, + }, + }) + } + + // Changed namespaces. + for nsName, nsDiff := range diff.ChangedNamespaces { + for _, delta := range nsDiff.Deltas() { + switch delta.Type { + case nsdiff.AddedPermission: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := permissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_PermissionAdded{ + PermissionAdded: perm, + }, + }) + + case nsdiff.AddedRelation: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationAdded{ + RelationAdded: rel, + }, + }) + + case nsdiff.ChangedPermissionComment: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := permissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_PermissionDocCommentChanged{ + PermissionDocCommentChanged: perm, + }, + }) + + case nsdiff.ChangedPermissionImpl: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := permissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_PermissionExprChanged{ + PermissionExprChanged: perm, + }, + }) + + case nsdiff.ChangedRelationComment: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationDocCommentChanged{ + RelationDocCommentChanged: rel, + }, + }) + + case nsdiff.LegacyChangedRelationImpl: + return nil, spiceerrors.MustBugf("legacy relation implementation changes are not supported") + + case nsdiff.NamespaceCommentsChanged: + def, err := namespaceAPIReprForName(nsName, comparisonSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_DefinitionDocCommentChanged{ + DefinitionDocCommentChanged: def, + }, + }) + + case nsdiff.RelationAllowedTypeRemoved: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationSubjectTypeRemoved{ + RelationSubjectTypeRemoved: &v1.ReflectionRelationSubjectTypeChange{ + Relation: rel, + ChangedSubjectType: typeAPIRepr(delta.AllowedType), + }, + }, + }) + + case nsdiff.RelationAllowedTypeAdded: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationSubjectTypeAdded{ + RelationSubjectTypeAdded: &v1.ReflectionRelationSubjectTypeChange{ + Relation: rel, + ChangedSubjectType: typeAPIRepr(delta.AllowedType), + }, + }, + }) + + case nsdiff.RemovedPermission: + permission, ok := existingSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := permissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_PermissionRemoved{ + PermissionRemoved: perm, + }, + }) + + case nsdiff.RemovedRelation: + relation, ok := existingSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationRemoved{ + RelationRemoved: rel, + }, + }) + + case nsdiff.NamespaceAdded: + return nil, spiceerrors.MustBugf("should be handled above") + + case nsdiff.NamespaceRemoved: + return nil, spiceerrors.MustBugf("should be handled above") + + default: + return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type) + } + } + } + + // Changed caveats. + for caveatName, caveatDiff := range diff.ChangedCaveats { + for _, delta := range caveatDiff.Deltas() { + switch delta.Type { + case caveatdiff.CaveatCommentsChanged: + caveat, err := caveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatDocCommentChanged{ + CaveatDocCommentChanged: caveat, + }, + }) + + case caveatdiff.AddedParameter: + paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatParameterAdded{ + CaveatParameterAdded: paramDef, + }, + }) + + case caveatdiff.RemovedParameter: + paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatParameterRemoved{ + CaveatParameterRemoved: paramDef, + }, + }) + + case caveatdiff.ParameterTypeChanged: + previousParamDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatParameterTypeChanged{ + CaveatParameterTypeChanged: &v1.ReflectionCaveatParameterTypeChange{ + Parameter: paramDef, + PreviousType: previousParamDef.Type, + }, + }, + }) + + case caveatdiff.CaveatExpressionChanged: + caveat, err := caveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatExprChanged{ + CaveatExprChanged: caveat, + }, + }) + + case caveatdiff.CaveatAdded: + return nil, spiceerrors.MustBugf("should be handled above") + + case caveatdiff.CaveatRemoved: + return nil, spiceerrors.MustBugf("should be handled above") + + default: + return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type) + } + } + } + + return &v1.DiffSchemaResponse{ + Diffs: diffs, + ReadAt: zedtoken.MustNewFromRevision(atRevision), + }, nil +} + +// namespaceAPIReprForName builds an API representation of a namespace. +func namespaceAPIReprForName(namespaceName string, schema *diff.DiffableSchema) (*v1.ReflectionDefinition, error) { + nsDef, ok := schema.GetNamespace(namespaceName) + if !ok { + return nil, spiceerrors.MustBugf("namespace %q not found in schema", namespaceName) + } + + return namespaceAPIRepr(nsDef, nil) +} + +func namespaceAPIRepr(nsDef *core.NamespaceDefinition, schemaFilters *schemaFilters) (*v1.ReflectionDefinition, error) { + if schemaFilters != nil && !schemaFilters.HasNamespace(nsDef.Name) { + return nil, nil + } + + relations := make([]*v1.ReflectionRelation, 0, len(nsDef.Relation)) + permissions := make([]*v1.ReflectionPermission, 0, len(nsDef.Relation)) + + for _, rel := range nsDef.Relation { + if namespace.GetRelationKind(rel) == iv1.RelationMetadata_PERMISSION { + permission, err := permissionAPIRepr(rel, nsDef.Name, schemaFilters) + if err != nil { + return nil, err + } + + if permission != nil { + permissions = append(permissions, permission) + } + continue + } + + relation, err := relationAPIRepr(rel, nsDef.Name, schemaFilters) + if err != nil { + return nil, err + } + + if relation != nil { + relations = append(relations, relation) + } + } + + comments := namespace.GetComments(nsDef.Metadata) + return &v1.ReflectionDefinition{ + Name: nsDef.Name, + Comment: strings.Join(comments, "\n"), + Relations: relations, + Permissions: permissions, + }, nil +} + +// permissionAPIRepr builds an API representation of a permission. +func permissionAPIRepr(relation *core.Relation, parentDefName string, schemaFilters *schemaFilters) (*v1.ReflectionPermission, error) { + if schemaFilters != nil && !schemaFilters.HasPermission(parentDefName, relation.Name) { + return nil, nil + } + + comments := namespace.GetComments(relation.Metadata) + return &v1.ReflectionPermission{ + Name: relation.Name, + Comment: strings.Join(comments, "\n"), + ParentDefinitionName: parentDefName, + }, nil +} + +// relationAPIRepresentation builds an API representation of a relation. +func relationAPIRepr(relation *core.Relation, parentDefName string, schemaFilters *schemaFilters) (*v1.ReflectionRelation, error) { + if schemaFilters != nil && !schemaFilters.HasRelation(parentDefName, relation.Name) { + return nil, nil + } + + comments := namespace.GetComments(relation.Metadata) + + var subjectTypes []*v1.ReflectionTypeReference + if relation.TypeInformation != nil { + subjectTypes = make([]*v1.ReflectionTypeReference, 0, len(relation.TypeInformation.AllowedDirectRelations)) + for _, subjectType := range relation.TypeInformation.AllowedDirectRelations { + typeref := typeAPIRepr(subjectType) + subjectTypes = append(subjectTypes, typeref) + } + } + + return &v1.ReflectionRelation{ + Name: relation.Name, + Comment: strings.Join(comments, "\n"), + ParentDefinitionName: parentDefName, + SubjectTypes: subjectTypes, + }, nil +} + +// typeAPIRepr builds an API representation of a type. +func typeAPIRepr(subjectType *core.AllowedRelation) *v1.ReflectionTypeReference { + typeref := &v1.ReflectionTypeReference{ + SubjectDefinitionName: subjectType.Namespace, + Typeref: &v1.ReflectionTypeReference_IsTerminalSubject{}, + } + + if subjectType.GetRelation() != tuple.Ellipsis && subjectType.GetRelation() != "" { + typeref.Typeref = &v1.ReflectionTypeReference_OptionalRelationName{ + OptionalRelationName: subjectType.GetRelation(), + } + } else if subjectType.GetPublicWildcard() != nil { + typeref.Typeref = &v1.ReflectionTypeReference_IsPublicWildcard{ + IsPublicWildcard: true, + } + } + + if subjectType.GetRequiredCaveat() != nil { + typeref.OptionalCaveatName = subjectType.GetRequiredCaveat().CaveatName + } + + return typeref +} + +// caveatAPIReprForName builds an API representation of a caveat. +func caveatAPIReprForName(caveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveat, error) { + caveatDef, ok := schema.GetCaveat(caveatName) + if !ok { + return nil, spiceerrors.MustBugf("caveat %q not found in schema", caveatName) + } + + return caveatAPIRepr(caveatDef, nil, caveatTypeSet) +} + +// caveatAPIRepr builds an API representation of a caveat. +func caveatAPIRepr(caveatDef *core.CaveatDefinition, schemaFilters *schemaFilters, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveat, error) { + if schemaFilters != nil && !schemaFilters.HasCaveat(caveatDef.Name) { + return nil, nil + } + + parameters := make([]*v1.ReflectionCaveatParameter, 0, len(caveatDef.ParameterTypes)) + paramNames := maps.Keys(caveatDef.ParameterTypes) + sort.Strings(paramNames) + + for _, paramName := range paramNames { + paramType, ok := caveatDef.ParameterTypes[paramName] + if !ok { + return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, caveatDef.Name) + } + + decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType) + if err != nil { + return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err) + } + + parameters = append(parameters, &v1.ReflectionCaveatParameter{ + Name: paramName, + Type: decoded.String(), + ParentCaveatName: caveatDef.Name, + }) + } + + parameterTypes, err := caveattypes.DecodeParameterTypes(caveatTypeSet, caveatDef.ParameterTypes) + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat parameters: %v", err) + } + + deserializedReflectionression, err := caveats.DeserializeCaveatWithTypeSet(caveatTypeSet, caveatDef.SerializedExpression, parameterTypes) + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat expression bytes: %v", err) + } + + exprString, err := deserializedReflectionression.ExprString() + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat expression: %v", err) + } + + comments := namespace.GetComments(caveatDef.Metadata) + return &v1.ReflectionCaveat{ + Name: caveatDef.Name, + Comment: strings.Join(comments, "\n"), + Parameters: parameters, + Expression: exprString, + }, nil +} + +// caveatAPIParamRepresentation builds an API representation of a caveat parameter. +func caveatAPIParamRepr(paramName, parentCaveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveatParameter, error) { + caveatDef, ok := schema.GetCaveat(parentCaveatName) + if !ok { + return nil, spiceerrors.MustBugf("caveat %q not found in schema", parentCaveatName) + } + + paramType, ok := caveatDef.ParameterTypes[paramName] + if !ok { + return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, parentCaveatName) + } + + decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType) + if err != nil { + return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err) + } + + return &v1.ReflectionCaveatParameter{ + Name: paramName, + Type: decoded.String(), + ParentCaveatName: parentCaveatName, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go new file mode 100644 index 0000000..a572216 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go @@ -0,0 +1,76 @@ +package v1 + +import ( + "context" + + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/diff" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +func loadCurrentSchema(ctx context.Context) (*diff.DiffableSchema, datastore.Revision, error) { + ds := datastoremw.MustFromContext(ctx) + + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, nil, err + } + + reader := ds.SnapshotReader(atRevision) + + namespacesAndRevs, err := reader.ListAllNamespaces(ctx) + if err != nil { + return nil, atRevision, err + } + + caveatsAndRevs, err := reader.ListAllCaveats(ctx) + if err != nil { + return nil, atRevision, err + } + + namespaces := make([]*core.NamespaceDefinition, 0, len(namespacesAndRevs)) + for _, namespaceAndRev := range namespacesAndRevs { + namespaces = append(namespaces, namespaceAndRev.Definition) + } + + caveats := make([]*core.CaveatDefinition, 0, len(caveatsAndRevs)) + for _, caveatAndRev := range caveatsAndRevs { + caveats = append(caveats, caveatAndRev.Definition) + } + + return &diff.DiffableSchema{ + ObjectDefinitions: namespaces, + CaveatDefinitions: caveats, + }, atRevision, nil +} + +func schemaDiff(ctx context.Context, comparisonSchemaString string, caveatTypeSet *caveattypes.TypeSet) (*diff.SchemaDiff, *diff.DiffableSchema, *diff.DiffableSchema, error) { + existingSchema, _, err := loadCurrentSchema(ctx) + if err != nil { + return nil, nil, nil, err + } + + // Compile the comparison schema. + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: comparisonSchemaString, + }, compiler.AllowUnprefixedObjectType(), compiler.CaveatTypeSet(caveatTypeSet)) + if err != nil { + return nil, nil, nil, err + } + + comparisonSchema := diff.NewDiffableSchemaFromCompiledSchema(compiled) + + diff, err := diff.DiffSchemas(*existingSchema, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, nil, nil, err + } + + // Return the diff. + return diff, existingSchema, &comparisonSchema, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go b/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go new file mode 100644 index 0000000..f0b2138 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go @@ -0,0 +1,576 @@ +package v1 + +import ( + "context" + "fmt" + "time" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "github.com/jzelinskie/stringz" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/middleware" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/handwrittenvalidation" + "github.com/authzed/spicedb/internal/middleware/streamtimeout" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/relationships" + "github.com/authzed/spicedb/internal/services/shared" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/pagination" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/genutil" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/middleware/consistency" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +var writeUpdateCounter = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "spicedb", + Subsystem: "v1", + Name: "write_relationships_updates", + Help: "The update counts for the WriteRelationships calls", + Buckets: []float64{0, 1, 2, 5, 10, 15, 25, 50, 100, 250, 500, 1000}, +}, []string{"kind"}) + +const MaximumTransactionMetadataSize = 65000 // bytes. Limited by the BLOB size used in MySQL driver + +// PermissionsServerConfig is configuration for the permissions server. +type PermissionsServerConfig struct { + // MaxUpdatesPerWrite holds the maximum number of updates allowed per + // WriteRelationships call. + MaxUpdatesPerWrite uint16 + + // MaxPreconditionsCount holds the maximum number of preconditions allowed + // on a WriteRelationships or DeleteRelationships call. + MaxPreconditionsCount uint16 + + // MaximumAPIDepth is the default/starting depth remaining for API calls made + // to the permissions server. + MaximumAPIDepth uint32 + + // DispatchChunkSize is the maximum number of elements to dispach in a dispatch call + DispatchChunkSize uint16 + + // StreamingAPITimeout is the timeout for streaming APIs when no response has been + // recently received. + StreamingAPITimeout time.Duration + + // MaxCaveatContextSize defines the maximum length of the request caveat context in bytes + MaxCaveatContextSize int + + // MaxRelationshipContextSize defines the maximum length of a relationship's context in bytes + MaxRelationshipContextSize int + + // MaxDatastoreReadPageSize defines the maximum number of relationships loaded from the + // datastore in one query. + MaxDatastoreReadPageSize uint64 + + // MaxCheckBulkConcurrency defines the maximum number of concurrent checks that can be + // made in a single CheckBulkPermissions call. + MaxCheckBulkConcurrency uint16 + + // MaxReadRelationshipsLimit defines the maximum number of relationships that can be read + // in a single ReadRelationships call. + MaxReadRelationshipsLimit uint32 + + // MaxDeleteRelationshipsLimit defines the maximum number of relationships that can be deleted + // in a single DeleteRelationships call. + MaxDeleteRelationshipsLimit uint32 + + // MaxLookupResourcesLimit defines the maximum number of resources that can be looked up in a + // single LookupResources call. + MaxLookupResourcesLimit uint32 + + // MaxBulkExportRelationshipsLimit defines the maximum number of relationships that can be + // exported in a single BulkExportRelationships call. + MaxBulkExportRelationshipsLimit uint32 + + // ExpiringRelationshipsEnabled defines whether or not expiring relationships are enabled. + ExpiringRelationshipsEnabled bool + + // CaveatTypeSet is the set of caveat types to use for caveats. If not specified, + // the default type set is used. + CaveatTypeSet *caveattypes.TypeSet +} + +// NewPermissionsServer creates a PermissionsServiceServer instance. +func NewPermissionsServer( + dispatch dispatch.Dispatcher, + config PermissionsServerConfig, +) v1.PermissionsServiceServer { + configWithDefaults := PermissionsServerConfig{ + MaxPreconditionsCount: defaultIfZero(config.MaxPreconditionsCount, 1000), + MaxUpdatesPerWrite: defaultIfZero(config.MaxUpdatesPerWrite, 1000), + MaximumAPIDepth: defaultIfZero(config.MaximumAPIDepth, 50), + StreamingAPITimeout: defaultIfZero(config.StreamingAPITimeout, 30*time.Second), + MaxCaveatContextSize: defaultIfZero(config.MaxCaveatContextSize, 4096), + MaxRelationshipContextSize: defaultIfZero(config.MaxRelationshipContextSize, 25_000), + MaxDatastoreReadPageSize: defaultIfZero(config.MaxDatastoreReadPageSize, 1_000), + MaxReadRelationshipsLimit: defaultIfZero(config.MaxReadRelationshipsLimit, 1_000), + MaxDeleteRelationshipsLimit: defaultIfZero(config.MaxDeleteRelationshipsLimit, 1_000), + MaxLookupResourcesLimit: defaultIfZero(config.MaxLookupResourcesLimit, 1_000), + MaxBulkExportRelationshipsLimit: defaultIfZero(config.MaxBulkExportRelationshipsLimit, 100_000), + DispatchChunkSize: defaultIfZero(config.DispatchChunkSize, 100), + MaxCheckBulkConcurrency: defaultIfZero(config.MaxCheckBulkConcurrency, 50), + CaveatTypeSet: caveattypes.TypeSetOrDefault(config.CaveatTypeSet), + ExpiringRelationshipsEnabled: true, + } + + return &permissionServer{ + dispatch: dispatch, + config: configWithDefaults, + WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{ + Unary: middleware.ChainUnaryServer( + grpcvalidate.UnaryServerInterceptor(), + handwrittenvalidation.UnaryServerInterceptor, + usagemetrics.UnaryServerInterceptor(), + ), + Stream: middleware.ChainStreamServer( + grpcvalidate.StreamServerInterceptor(), + handwrittenvalidation.StreamServerInterceptor, + usagemetrics.StreamServerInterceptor(), + streamtimeout.MustStreamServerInterceptor(configWithDefaults.StreamingAPITimeout), + ), + }, + bulkChecker: &bulkChecker{ + maxAPIDepth: configWithDefaults.MaximumAPIDepth, + maxCaveatContextSize: configWithDefaults.MaxCaveatContextSize, + maxConcurrency: configWithDefaults.MaxCheckBulkConcurrency, + dispatch: dispatch, + dispatchChunkSize: configWithDefaults.DispatchChunkSize, + caveatTypeSet: configWithDefaults.CaveatTypeSet, + }, + } +} + +type permissionServer struct { + v1.UnimplementedPermissionsServiceServer + shared.WithServiceSpecificInterceptors + + dispatch dispatch.Dispatcher + config PermissionsServerConfig + + bulkChecker *bulkChecker +} + +func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest, resp v1.PermissionsService_ReadRelationshipsServer) error { + if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxReadRelationshipsLimit { + return ps.rewriteError(resp.Context(), NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxReadRelationshipsLimit))) + } + + ctx := resp.Context() + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, ds); err != nil { + return ps.rewriteError(ctx, err) + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + DispatchCount: 1, + }) + + limit := uint64(0) + var startCursor options.Cursor + + rrRequestHash, err := computeReadRelationshipsRequestHash(req) + if err != nil { + return ps.rewriteError(ctx, err) + } + + if req.OptionalCursor != nil { + decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, rrRequestHash) + if err != nil { + return ps.rewriteError(ctx, err) + } + + if len(decodedCursor.Sections) != 1 { + return ps.rewriteError(ctx, NewInvalidCursorErr("did not find expected resume relationship")) + } + + parsed, err := tuple.Parse(decodedCursor.Sections[0]) + if err != nil { + return ps.rewriteError(ctx, NewInvalidCursorErr("could not parse resume relationship")) + } + + startCursor = options.ToCursor(parsed) + } + + pageSize := ps.config.MaxDatastoreReadPageSize + if req.OptionalLimit > 0 { + limit = uint64(req.OptionalLimit) + if limit < pageSize { + pageSize = limit + } + } + + dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(req.RelationshipFilter) + if err != nil { + return ps.rewriteError(ctx, fmt.Errorf("error filtering: %w", err)) + } + + it, err := pagination.NewPaginatedIterator( + ctx, + ds, + dsFilter, + pageSize, + options.ByResource, + startCursor, + queryshape.Varying, + ) + if err != nil { + return ps.rewriteError(ctx, err) + } + + response := &v1.ReadRelationshipsResponse{ + ReadAt: revisionReadAt, + Relationship: &v1.Relationship{ + Resource: &v1.ObjectReference{}, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{}, + }, + }, + } + + dispatchCursor := &dispatchv1.Cursor{ + DispatchVersion: 1, + Sections: []string{""}, + } + + var returnedCount uint64 + for rel, err := range it { + if err != nil { + return ps.rewriteError(ctx, fmt.Errorf("error when reading tuples: %w", err)) + } + + if limit > 0 && returnedCount >= limit { + break + } + + dispatchCursor.Sections[0] = tuple.StringWithoutCaveatOrExpiration(rel) + encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, nil) + if err != nil { + return ps.rewriteError(ctx, err) + } + + tuple.CopyToV1Relationship(rel, response.Relationship) + response.AfterResultCursor = encodedCursor + + err = resp.Send(response) + if err != nil { + return ps.rewriteError(ctx, fmt.Errorf("error when streaming tuple: %w", err)) + } + returnedCount++ + } + return nil +} + +func (ps *permissionServer) WriteRelationships(ctx context.Context, req *v1.WriteRelationshipsRequest) (*v1.WriteRelationshipsResponse, error) { + if err := ps.validateTransactionMetadata(req.OptionalTransactionMetadata); err != nil { + return nil, ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx) + + span := trace.SpanFromContext(ctx) + span.AddEvent("validating mutations") + // Ensure that the updates and preconditions are not over the configured limits. + if len(req.Updates) > int(ps.config.MaxUpdatesPerWrite) { + return nil, ps.rewriteError( + ctx, + NewExceedsMaximumUpdatesErr(uint64(len(req.Updates)), uint64(ps.config.MaxUpdatesPerWrite)), + ) + } + + if len(req.OptionalPreconditions) > int(ps.config.MaxPreconditionsCount) { + return nil, ps.rewriteError( + ctx, + NewExceedsMaximumPreconditionsErr(uint64(len(req.OptionalPreconditions)), uint64(ps.config.MaxPreconditionsCount)), + ) + } + + // Check for duplicate updates and create the set of caveat names to load. + updateRelationshipSet := mapz.NewSet[string]() + for _, update := range req.Updates { + // TODO(jschorr): Change to struct-based keys. + tupleStr := tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship) + if !updateRelationshipSet.Add(tupleStr) { + return nil, ps.rewriteError( + ctx, + NewDuplicateRelationshipErr(update), + ) + } + if proto.Size(update.Relationship.OptionalCaveat) > ps.config.MaxRelationshipContextSize { + return nil, ps.rewriteError( + ctx, + NewMaxRelationshipContextError(update, ps.config.MaxRelationshipContextSize), + ) + } + + if !ps.config.ExpiringRelationshipsEnabled && update.Relationship.OptionalExpiresAt != nil { + return nil, ps.rewriteError( + ctx, + fmt.Errorf("support for expiring relationships is not enabled"), + ) + } + } + + // Execute the write operation(s). + span.AddEvent("read write transaction") + relUpdates, err := tuple.UpdatesFromV1RelationshipUpdates(req.Updates) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + span.AddEvent("preconditions") + + // Validate the preconditions. + for _, precond := range req.OptionalPreconditions { + if err := validatePrecondition(ctx, precond, rwt); err != nil { + return err + } + } + + // Validate the updates. + span.AddEvent("validate updates") + err := relationships.ValidateRelationshipUpdates(ctx, rwt, ps.config.CaveatTypeSet, relUpdates) + if err != nil { + return ps.rewriteError(ctx, err) + } + + dispatchCount, err := genutil.EnsureUInt32(len(req.OptionalPreconditions) + 1) + if err != nil { + return ps.rewriteError(ctx, err) + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + // One request per precondition and one request for the actual writes. + DispatchCount: dispatchCount, + }) + + span.AddEvent("preconditions") + if err := checkPreconditions(ctx, rwt, req.OptionalPreconditions); err != nil { + return err + } + + span.AddEvent("write relationships") + return rwt.WriteRelationships(ctx, relUpdates) + }, options.WithMetadata(req.OptionalTransactionMetadata)) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + // Log a metric of the counts of the different kinds of update operations. + updateCountByOperation := make(map[v1.RelationshipUpdate_Operation]int, 0) + for _, update := range req.Updates { + updateCountByOperation[update.Operation]++ + } + + for kind, count := range updateCountByOperation { + writeUpdateCounter.WithLabelValues(v1.RelationshipUpdate_Operation_name[int32(kind)]).Observe(float64(count)) + } + + return &v1.WriteRelationshipsResponse{ + WrittenAt: zedtoken.MustNewFromRevision(revision), + }, nil +} + +func (ps *permissionServer) validateTransactionMetadata(metadata *structpb.Struct) error { + if metadata == nil { + return nil + } + + b, err := metadata.MarshalJSON() + if err != nil { + return err + } + + if len(b) > MaximumTransactionMetadataSize { + return NewTransactionMetadataTooLargeErr(len(b), MaximumTransactionMetadataSize) + } + + return nil +} + +func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.DeleteRelationshipsRequest) (*v1.DeleteRelationshipsResponse, error) { + if err := ps.validateTransactionMetadata(req.OptionalTransactionMetadata); err != nil { + return nil, ps.rewriteError(ctx, err) + } + + if len(req.OptionalPreconditions) > int(ps.config.MaxPreconditionsCount) { + return nil, ps.rewriteError( + ctx, + NewExceedsMaximumPreconditionsErr(uint64(len(req.OptionalPreconditions)), uint64(ps.config.MaxPreconditionsCount)), + ) + } + + if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxDeleteRelationshipsLimit { + return nil, ps.rewriteError(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxDeleteRelationshipsLimit))) + } + + ds := datastoremw.MustFromContext(ctx) + deletionProgress := v1.DeleteRelationshipsResponse_DELETION_PROGRESS_COMPLETE + + var deletedRelationshipCount uint64 + revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, rwt); err != nil { + return err + } + + dispatchCount, err := genutil.EnsureUInt32(len(req.OptionalPreconditions) + 1) + if err != nil { + return ps.rewriteError(ctx, err) + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + // One request per precondition and one request for the actual delete. + DispatchCount: dispatchCount, + }) + + for _, precond := range req.OptionalPreconditions { + if err := validatePrecondition(ctx, precond, rwt); err != nil { + return err + } + } + + if err := checkPreconditions(ctx, rwt, req.OptionalPreconditions); err != nil { + return err + } + + // If a limit was specified but partial deletion is not allowed, we need to check if the + // number of relationships to be deleted exceeds the limit. + if req.OptionalLimit > 0 && !req.OptionalAllowPartialDeletions { + limit := uint64(req.OptionalLimit) + limitPlusOne := limit + 1 + filter, err := datastore.RelationshipsFilterFromPublicFilter(req.RelationshipFilter) + if err != nil { + return ps.rewriteError(ctx, err) + } + + it, err := rwt.QueryRelationships(ctx, filter, options.WithLimit(&limitPlusOne), options.WithQueryShape(queryshape.Varying)) + if err != nil { + return ps.rewriteError(ctx, err) + } + + counter := uint64(0) + for _, err := range it { + if err != nil { + return ps.rewriteError(ctx, err) + } + + if counter == limit { + return ps.rewriteError(ctx, NewCouldNotTransactionallyDeleteErr(req.RelationshipFilter, req.OptionalLimit)) + } + + counter++ + } + } + + // Delete with the specified limit. + if req.OptionalLimit > 0 { + deleteLimit := uint64(req.OptionalLimit) + drc, reachedLimit, err := rwt.DeleteRelationships(ctx, req.RelationshipFilter, options.WithDeleteLimit(&deleteLimit)) + if err != nil { + return err + } + + if reachedLimit { + deletionProgress = v1.DeleteRelationshipsResponse_DELETION_PROGRESS_PARTIAL + } + + deletedRelationshipCount = drc + return nil + } + + // Otherwise, kick off an unlimited deletion. + deletedRelationshipCount, _, err = rwt.DeleteRelationships(ctx, req.RelationshipFilter) + return err + }, options.WithMetadata(req.OptionalTransactionMetadata)) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + return &v1.DeleteRelationshipsResponse{ + DeletedAt: zedtoken.MustNewFromRevision(revision), + DeletionProgress: deletionProgress, + RelationshipsDeletedCount: deletedRelationshipCount, + }, nil +} + +var emptyPrecondition = &v1.Precondition{} + +func validatePrecondition(ctx context.Context, precond *v1.Precondition, reader datastore.Reader) error { + if precond.EqualVT(emptyPrecondition) || precond.Filter == nil { + return NewEmptyPreconditionErr() + } + + return validateRelationshipsFilter(ctx, precond.Filter, reader) +} + +func checkFilterComponent(ctx context.Context, objectType, optionalRelation string, ds datastore.Reader) error { + if objectType == "" { + return nil + } + + relationToTest := stringz.DefaultEmpty(optionalRelation, datastore.Ellipsis) + allowEllipsis := optionalRelation == "" + return namespace.CheckNamespaceAndRelation(ctx, objectType, relationToTest, allowEllipsis, ds) +} + +func validateRelationshipsFilter(ctx context.Context, filter *v1.RelationshipFilter, ds datastore.Reader) error { + // ResourceType is optional, so only check the relation if it is specified. + if filter.ResourceType != "" { + if err := checkFilterComponent(ctx, filter.ResourceType, filter.OptionalRelation, ds); err != nil { + return err + } + } + + // SubjectFilter is optional, so only check if it is specified. + if subjectFilter := filter.OptionalSubjectFilter; subjectFilter != nil { + subjectRelation := "" + if subjectFilter.OptionalRelation != nil { + subjectRelation = subjectFilter.OptionalRelation.Relation + } + if err := checkFilterComponent(ctx, subjectFilter.SubjectType, subjectRelation, ds); err != nil { + return err + } + } + + // Ensure the resource ID and the resource ID prefix are not set at the same time. + if filter.OptionalResourceId != "" && filter.OptionalResourceIdPrefix != "" { + return NewInvalidFilterErr("resource_id and resource_id_prefix cannot be set at the same time", filter.String()) + } + + // Ensure that at least one field is set. + return checkIfFilterIsEmpty(filter) +} + +func checkIfFilterIsEmpty(filter *v1.RelationshipFilter) error { + if filter.ResourceType == "" && + filter.OptionalResourceId == "" && + filter.OptionalResourceIdPrefix == "" && + filter.OptionalRelation == "" && + filter.OptionalSubjectFilter == nil { + return NewInvalidFilterErr("at least one field must be set", filter.String()) + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go b/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go new file mode 100644 index 0000000..14faf3d --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go @@ -0,0 +1,375 @@ +package v1 + +import ( + "context" + "sort" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/middleware" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/services/shared" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +// NewSchemaServer creates a SchemaServiceServer instance. +func NewSchemaServer(caveatTypeSet *caveattypes.TypeSet, additiveOnly bool, expiringRelsEnabled bool) v1.SchemaServiceServer { + cts := caveattypes.TypeSetOrDefault(caveatTypeSet) + return &schemaServer{ + WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{ + Unary: middleware.ChainUnaryServer( + grpcvalidate.UnaryServerInterceptor(), + usagemetrics.UnaryServerInterceptor(), + ), + Stream: middleware.ChainStreamServer( + grpcvalidate.StreamServerInterceptor(), + usagemetrics.StreamServerInterceptor(), + ), + }, + additiveOnly: additiveOnly, + expiringRelsEnabled: expiringRelsEnabled, + caveatTypeSet: cts, + } +} + +type schemaServer struct { + v1.UnimplementedSchemaServiceServer + shared.WithServiceSpecificInterceptors + + caveatTypeSet *caveattypes.TypeSet + additiveOnly bool + expiringRelsEnabled bool +} + +func (ss *schemaServer) rewriteError(ctx context.Context, err error) error { + return shared.RewriteError(ctx, err, nil) +} + +func (ss *schemaServer) ReadSchema(ctx context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) { + // Schema is always read from the head revision. + ds := datastoremw.MustFromContext(ctx) + headRevision, err := ds.HeadRevision(ctx) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + reader := ds.SnapshotReader(headRevision) + + nsDefs, err := reader.ListAllNamespaces(ctx) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + caveatDefs, err := reader.ListAllCaveats(ctx) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + if len(nsDefs) == 0 { + return nil, status.Errorf(codes.NotFound, "No schema has been defined; please call WriteSchema to start") + } + + schemaDefinitions := make([]compiler.SchemaDefinition, 0, len(nsDefs)+len(caveatDefs)) + for _, caveatDef := range caveatDefs { + schemaDefinitions = append(schemaDefinitions, caveatDef.Definition) + } + + for _, nsDef := range nsDefs { + schemaDefinitions = append(schemaDefinitions, nsDef.Definition) + } + + schemaText, _, err := generator.GenerateSchema(schemaDefinitions) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + dispatchCount, err := genutil.EnsureUInt32(len(nsDefs) + len(caveatDefs)) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + DispatchCount: dispatchCount, + }) + + return &v1.ReadSchemaResponse{ + SchemaText: schemaText, + ReadAt: zedtoken.MustNewFromRevision(headRevision), + }, nil +} + +func (ss *schemaServer) WriteSchema(ctx context.Context, in *v1.WriteSchemaRequest) (*v1.WriteSchemaResponse, error) { + log.Ctx(ctx).Trace().Str("schema", in.GetSchema()).Msg("requested Schema to be written") + + ds := datastoremw.MustFromContext(ctx) + + // Compile the schema into the namespace definitions. + opts := make([]compiler.Option, 0, 3) + if !ss.expiringRelsEnabled { + opts = append(opts, compiler.DisallowExpirationFlag()) + } + + opts = append(opts, compiler.CaveatTypeSet(ss.caveatTypeSet)) + + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: in.GetSchema(), + }, compiler.AllowUnprefixedObjectType(), opts...) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + log.Ctx(ctx).Trace().Int("objectDefinitions", len(compiled.ObjectDefinitions)).Int("caveatDefinitions", len(compiled.CaveatDefinitions)).Msg("compiled namespace definitions") + + // Do as much validation as we can before talking to the datastore. + validated, err := shared.ValidateSchemaChanges(ctx, compiled, ss.caveatTypeSet, ss.additiveOnly) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + // Update the schema. + revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + applied, err := shared.ApplySchemaChanges(ctx, rwt, ss.caveatTypeSet, validated) + if err != nil { + return err + } + + dispatchCount, err := genutil.EnsureUInt32(applied.TotalOperationCount) + if err != nil { + return err + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + DispatchCount: dispatchCount, + }) + return nil + }) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + return &v1.WriteSchemaResponse{ + WrittenAt: zedtoken.MustNewFromRevision(revision), + }, nil +} + +func (ss *schemaServer) ReflectSchema(ctx context.Context, req *v1.ReflectSchemaRequest) (*v1.ReflectSchemaResponse, error) { + // Get the current schema. + schema, atRevision, err := loadCurrentSchema(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + filters, err := newSchemaFilters(req.OptionalFilters) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + definitions := make([]*v1.ReflectionDefinition, 0, len(schema.ObjectDefinitions)) + if filters.HasNamespaces() { + for _, ns := range schema.ObjectDefinitions { + def, err := namespaceAPIRepr(ns, filters) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + if def != nil { + definitions = append(definitions, def) + } + } + } + + caveats := make([]*v1.ReflectionCaveat, 0, len(schema.CaveatDefinitions)) + if filters.HasCaveats() { + for _, cd := range schema.CaveatDefinitions { + caveat, err := caveatAPIRepr(cd, filters, ss.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + if caveat != nil { + caveats = append(caveats, caveat) + } + } + } + + return &v1.ReflectSchemaResponse{ + Definitions: definitions, + Caveats: caveats, + ReadAt: zedtoken.MustNewFromRevision(atRevision), + }, nil +} + +func (ss *schemaServer) DiffSchema(ctx context.Context, req *v1.DiffSchemaRequest) (*v1.DiffSchemaResponse, error) { + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, err + } + + diff, existingSchema, comparisonSchema, err := schemaDiff(ctx, req.ComparisonSchema, ss.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + resp, err := convertDiff(diff, existingSchema, comparisonSchema, atRevision, ss.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return resp, nil +} + +func (ss *schemaServer) ComputablePermissions(ctx context.Context, req *v1.ComputablePermissionsRequest) (*v1.ComputablePermissionsResponse, error) { + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) + vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relationName := req.RelationName + if relationName == "" { + relationName = tuple.Ellipsis + } else { + if _, ok := vdef.GetRelation(relationName); !ok { + return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, relationName)) + } + } + + allNamespaces, err := ds.ListAllNamespaces(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + allDefinitions := make([]*core.NamespaceDefinition, 0, len(allNamespaces)) + for _, ns := range allNamespaces { + allDefinitions = append(allDefinitions, ns.Definition) + } + + rg := vdef.Reachability() + rr, err := rg.RelationsEncounteredForSubject(ctx, allDefinitions, &core.RelationReference{ + Namespace: req.DefinitionName, + Relation: relationName, + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations := make([]*v1.ReflectionRelationReference, 0, len(rr)) + for _, r := range rr { + if r.Namespace == req.DefinitionName && r.Relation == req.RelationName { + continue + } + + if req.OptionalDefinitionNameFilter != "" && !strings.HasPrefix(r.Namespace, req.OptionalDefinitionNameFilter) { + continue + } + + ts, err := ts.GetDefinition(ctx, r.Namespace) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations = append(relations, &v1.ReflectionRelationReference{ + DefinitionName: r.Namespace, + RelationName: r.Relation, + IsPermission: ts.IsPermission(r.Relation), + }) + } + + sort.Slice(relations, func(i, j int) bool { + if relations[i].DefinitionName == relations[j].DefinitionName { + return relations[i].RelationName < relations[j].RelationName + } + return relations[i].DefinitionName < relations[j].DefinitionName + }) + + return &v1.ComputablePermissionsResponse{ + Permissions: relations, + ReadAt: revisionReadAt, + }, nil +} + +func (ss *schemaServer) DependentRelations(ctx context.Context, req *v1.DependentRelationsRequest) (*v1.DependentRelationsResponse, error) { + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) + vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + _, ok := vdef.GetRelation(req.PermissionName) + if !ok { + return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, req.PermissionName)) + } + + if !vdef.IsPermission(req.PermissionName) { + return nil, shared.RewriteErrorWithoutConfig(ctx, NewNotAPermissionError(req.PermissionName)) + } + + rg := vdef.Reachability() + rr, err := rg.RelationsEncounteredForResource(ctx, &core.RelationReference{ + Namespace: req.DefinitionName, + Relation: req.PermissionName, + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations := make([]*v1.ReflectionRelationReference, 0, len(rr)) + for _, r := range rr { + if r.Namespace == req.DefinitionName && r.Relation == req.PermissionName { + continue + } + + ts, err := ts.GetDefinition(ctx, r.Namespace) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations = append(relations, &v1.ReflectionRelationReference{ + DefinitionName: r.Namespace, + RelationName: r.Relation, + IsPermission: ts.IsPermission(r.Relation), + }) + } + + sort.Slice(relations, func(i, j int) bool { + if relations[i].DefinitionName == relations[j].DefinitionName { + return relations[i].RelationName < relations[j].RelationName + } + + return relations[i].DefinitionName < relations[j].DefinitionName + }) + + return &v1.DependentRelationsResponse{ + Relations: relations, + ReadAt: revisionReadAt, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go b/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go new file mode 100644 index 0000000..ef13a26 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go @@ -0,0 +1,190 @@ +package v1 + +import ( + "context" + "errors" + "slices" + "time" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/mapz" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +type watchServer struct { + v1.UnimplementedWatchServiceServer + shared.WithStreamServiceSpecificInterceptor + + heartbeatDuration time.Duration +} + +// NewWatchServer creates an instance of the watch server. +func NewWatchServer(heartbeatDuration time.Duration) v1.WatchServiceServer { + s := &watchServer{ + WithStreamServiceSpecificInterceptor: shared.WithStreamServiceSpecificInterceptor{ + Stream: grpcvalidate.StreamServerInterceptor(), + }, + heartbeatDuration: heartbeatDuration, + } + return s +} + +func (ws *watchServer) Watch(req *v1.WatchRequest, stream v1.WatchService_WatchServer) error { + if len(req.GetOptionalUpdateKinds()) == 0 || + slices.Contains(req.GetOptionalUpdateKinds(), v1.WatchKind_WATCH_KIND_UNSPECIFIED) || + slices.Contains(req.GetOptionalUpdateKinds(), v1.WatchKind_WATCH_KIND_INCLUDE_RELATIONSHIP_UPDATES) { + if len(req.GetOptionalObjectTypes()) > 0 && len(req.OptionalRelationshipFilters) > 0 { + return status.Errorf(codes.InvalidArgument, "cannot specify both object types and relationship filters") + } + } + + objectTypes := mapz.NewSet[string](req.GetOptionalObjectTypes()...) + + ctx := stream.Context() + ds := datastoremw.MustFromContext(ctx) + + var afterRevision datastore.Revision + if req.OptionalStartCursor != nil && req.OptionalStartCursor.Token != "" { + decodedRevision, err := zedtoken.DecodeRevision(req.OptionalStartCursor, ds) + if err != nil { + return status.Errorf(codes.InvalidArgument, "failed to decode start revision: %s", err) + } + + afterRevision = decodedRevision + } else { + var err error + afterRevision, err = ds.OptimizedRevision(ctx) + if err != nil { + return status.Errorf(codes.Unavailable, "failed to start watch: %s", err) + } + } + + reader := ds.SnapshotReader(afterRevision) + + filters, err := buildRelationshipFilters(req, stream, reader, ws, ctx) + if err != nil { + return err + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + DispatchCount: 1, + }) + + updates, errchan := ds.Watch(ctx, afterRevision, datastore.WatchOptions{ + Content: convertWatchKindToContent(req.OptionalUpdateKinds), + CheckpointInterval: ws.heartbeatDuration, + }) + for { + select { + case update, ok := <-updates: + if ok { + filteredRelationshipUpdates := filterRelationshipUpdates(objectTypes, filters, update.RelationshipChanges) + if len(filteredRelationshipUpdates) > 0 { + converted, err := tuple.UpdatesToV1RelationshipUpdates(filteredRelationshipUpdates) + if err != nil { + return status.Errorf(codes.Internal, "failed to convert updates: %s", err) + } + + if err := stream.Send(&v1.WatchResponse{ + Updates: converted, + ChangesThrough: zedtoken.MustNewFromRevision(update.Revision), + OptionalTransactionMetadata: update.Metadata, + }); err != nil { + return status.Errorf(codes.Canceled, "watch canceled by user: %s", err) + } + } + if len(update.ChangedDefinitions) > 0 || len(update.DeletedCaveats) > 0 || len(update.DeletedNamespaces) > 0 { + if err := stream.Send(&v1.WatchResponse{ + SchemaUpdated: true, + ChangesThrough: zedtoken.MustNewFromRevision(update.Revision), + OptionalTransactionMetadata: update.Metadata, + }); err != nil { + return status.Errorf(codes.Canceled, "watch canceled by user: %s", err) + } + } + if update.IsCheckpoint { + if err := stream.Send(&v1.WatchResponse{ + IsCheckpoint: update.IsCheckpoint, + ChangesThrough: zedtoken.MustNewFromRevision(update.Revision), + OptionalTransactionMetadata: update.Metadata, + }); err != nil { + return status.Errorf(codes.Canceled, "watch canceled by user: %s", err) + } + } + } + case err := <-errchan: + switch { + case errors.As(err, &datastore.WatchCanceledError{}): + return status.Errorf(codes.Canceled, "watch canceled by user: %s", err) + case errors.As(err, &datastore.WatchDisconnectedError{}): + return status.Errorf(codes.ResourceExhausted, "watch disconnected: %s", err) + default: + return status.Errorf(codes.Internal, "watch error: %s", err) + } + } + } +} + +func buildRelationshipFilters(req *v1.WatchRequest, stream v1.WatchService_WatchServer, reader datastore.Reader, ws *watchServer, ctx context.Context) ([]datastore.RelationshipsFilter, error) { + filters := make([]datastore.RelationshipsFilter, 0, len(req.OptionalRelationshipFilters)) + for _, filter := range req.OptionalRelationshipFilters { + if err := validateRelationshipsFilter(stream.Context(), filter, reader); err != nil { + return nil, ws.rewriteError(ctx, err) + } + + dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to parse relationship filter: %s", err) + } + + filters = append(filters, dsFilter) + } + return filters, nil +} + +func (ws *watchServer) rewriteError(ctx context.Context, err error) error { + return shared.RewriteError(ctx, err, &shared.ConfigForErrors{}) +} + +func filterRelationshipUpdates(objectTypes *mapz.Set[string], filters []datastore.RelationshipsFilter, updates []tuple.RelationshipUpdate) []tuple.RelationshipUpdate { + if objectTypes.IsEmpty() && len(filters) == 0 { + return updates + } + + filtered := make([]tuple.RelationshipUpdate, 0, len(updates)) + for _, update := range updates { + objectType := update.Relationship.Resource.ObjectType + if !objectTypes.IsEmpty() && !objectTypes.Has(objectType) { + continue + } + + if len(filters) > 0 { + // If there are filters, we need to check if the update matches any of them. + matched := false + for _, filter := range filters { + if filter.Test(update.Relationship) { + matched = true + break + } + } + + if !matched { + continue + } + } + + filtered = append(filtered, update) + } + + return filtered +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go b/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go new file mode 100644 index 0000000..08910a1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go @@ -0,0 +1,22 @@ +package v1 + +import ( + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/datastore" +) + +func convertWatchKindToContent(kinds []v1.WatchKind) datastore.WatchContent { + res := datastore.WatchRelationships + for _, kind := range kinds { + switch kind { + case v1.WatchKind_WATCH_KIND_INCLUDE_RELATIONSHIP_UPDATES: + res |= datastore.WatchRelationships + case v1.WatchKind_WATCH_KIND_INCLUDE_SCHEMA_UPDATES: + res |= datastore.WatchSchema + case v1.WatchKind_WATCH_KIND_INCLUDE_CHECKPOINTS: + res |= datastore.WatchCheckpoints + } + } + return res +} diff --git a/vendor/github.com/authzed/spicedb/internal/sharederrors/interfaces.go b/vendor/github.com/authzed/spicedb/internal/sharederrors/interfaces.go new file mode 100644 index 0000000..852c28f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/sharederrors/interfaces.go @@ -0,0 +1,16 @@ +package sharederrors + +// UnknownNamespaceError is an error raised when a namespace was not found. +type UnknownNamespaceError interface { + // NotFoundNamespaceName is the name of the namespace that was not found. + NotFoundNamespaceName() string +} + +// UnknownRelationError is an error raised when a relation was not found. +type UnknownRelationError interface { + // NamespaceName is the name of the namespace under which the relation was not found. + NamespaceName() string + + // NotFoundRelationName is the name of the relation that was not found. + NotFoundRelationName() string +} diff --git a/vendor/github.com/authzed/spicedb/internal/taskrunner/doc.go b/vendor/github.com/authzed/spicedb/internal/taskrunner/doc.go new file mode 100644 index 0000000..a6000de --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/taskrunner/doc.go @@ -0,0 +1,2 @@ +// Package taskrunner contains helper code run concurrent code. +package taskrunner diff --git a/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go b/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go new file mode 100644 index 0000000..a088e35 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go @@ -0,0 +1,153 @@ +package taskrunner + +import ( + "context" + "sync" +) + +// PreloadedTaskRunner is a task runner that invokes a series of preloaded tasks, +// running until the tasks are completed, the context is canceled or an error is +// returned by one of the tasks (which cancels the context). +type PreloadedTaskRunner struct { + // ctx holds the context given to the task runner and annotated with the cancel + // function. + ctx context.Context + cancel func() + + // sem is a chan of length `concurrencyLimit` used to ensure the task runner does + // not exceed the concurrencyLimit with spawned goroutines. + sem chan struct{} + + wg sync.WaitGroup + + lock sync.Mutex + err error // GUARDED_BY(lock) + tasks []TaskFunc // GUARDED_BY(lock) +} + +func NewPreloadedTaskRunner(ctx context.Context, concurrencyLimit uint16, initialCapacity int) *PreloadedTaskRunner { + // Ensure a concurrency level of at least 1. + if concurrencyLimit <= 0 { + concurrencyLimit = 1 + } + + ctxWithCancel, cancel := context.WithCancel(ctx) + return &PreloadedTaskRunner{ + ctx: ctxWithCancel, + cancel: cancel, + sem: make(chan struct{}, concurrencyLimit), + tasks: make([]TaskFunc, 0, initialCapacity), + } +} + +// Add adds the given task function to be run. +func (tr *PreloadedTaskRunner) Add(f TaskFunc) { + tr.tasks = append(tr.tasks, f) + tr.wg.Add(1) +} + +// Start starts running the tasks in the task runner. This does *not* wait for the tasks +// to complete, but rather returns immediately. +func (tr *PreloadedTaskRunner) Start() { + for range tr.tasks { + tr.spawnIfAvailable() + } +} + +// StartAndWait starts running the tasks in the task runner and waits for them to complete. +func (tr *PreloadedTaskRunner) StartAndWait() error { + tr.Start() + tr.wg.Wait() + + tr.lock.Lock() + defer tr.lock.Unlock() + + return tr.err +} + +func (tr *PreloadedTaskRunner) spawnIfAvailable() { + // To spawn a runner, write a struct{} to the sem channel. If the task runner + // is already at the concurrency limit, then this chan write will fail, + // and nothing will be spawned. This also checks if the context has already + // been canceled, in which case nothing needs to be done. + select { + case tr.sem <- struct{}{}: + go tr.runner() + + case <-tr.ctx.Done(): + // If the context was canceled, nothing more to do. + tr.emptyForCancel() + return + + default: + return + } +} + +func (tr *PreloadedTaskRunner) runner() { + for { + select { + case <-tr.ctx.Done(): + // If the context was canceled, nothing more to do. + tr.emptyForCancel() + return + + default: + // Select a task from the list, if any. + task := tr.selectTask() + if task == nil { + return + } + + // Run the task. If an error occurs, store it and cancel any further tasks. + err := task(tr.ctx) + if err != nil { + tr.storeErrorAndCancel(err) + } + tr.wg.Done() + } + } +} + +func (tr *PreloadedTaskRunner) selectTask() TaskFunc { + tr.lock.Lock() + defer tr.lock.Unlock() + + if len(tr.tasks) == 0 { + return nil + } + + task := tr.tasks[0] + tr.tasks[0] = nil // to free the reference once the task completes. + tr.tasks = tr.tasks[1:] + return task +} + +func (tr *PreloadedTaskRunner) storeErrorAndCancel(err error) { + tr.lock.Lock() + defer tr.lock.Unlock() + + if tr.err == nil { + tr.err = err + tr.cancel() + } +} + +func (tr *PreloadedTaskRunner) emptyForCancel() { + tr.lock.Lock() + defer tr.lock.Unlock() + + if tr.err == nil { + tr.err = tr.ctx.Err() + } + + for { + if len(tr.tasks) == 0 { + break + } + + tr.tasks[0] = nil // to free the reference + tr.tasks = tr.tasks[1:] + tr.wg.Done() + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go b/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go new file mode 100644 index 0000000..1b519ed --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go @@ -0,0 +1,168 @@ +package taskrunner + +import ( + "context" + "sync" +) + +// TaskRunner is a helper which runs a series of scheduled tasks against a defined +// limit of goroutines. +type TaskRunner struct { + // ctx holds the context given to the task runner and annotated with the cancel + // function. + ctx context.Context + cancel func() + + // sem is a chan of length `concurrencyLimit` used to ensure the task runner does + // not exceed the concurrencyLimit with spawned goroutines. + sem chan struct{} + + wg sync.WaitGroup + + lock sync.Mutex + tasks []TaskFunc // GUARDED_BY(lock) + + // err holds the error returned by any task, if any. If the context is canceled, + // this err will hold the cancelation error. + err error // GUARDED_BY(lock) +} + +// TaskFunc defines functions representing tasks. +type TaskFunc func(ctx context.Context) error + +// NewTaskRunner creates a new task runner with the given starting context and +// concurrency limit. The TaskRunner will schedule no more goroutines that the +// specified concurrencyLimit. If the given context is canceled, then all tasks +// started after that point will also be canceled and the error returned. If +// a task returns an error, the context provided to all tasks is also canceled. +func NewTaskRunner(ctx context.Context, concurrencyLimit uint16) *TaskRunner { + if concurrencyLimit < 1 { + concurrencyLimit = 1 + } + + ctxWithCancel, cancel := context.WithCancel(ctx) + return &TaskRunner{ + ctx: ctxWithCancel, + cancel: cancel, + sem: make(chan struct{}, concurrencyLimit), + tasks: make([]TaskFunc, 0), + } +} + +// Schedule schedules a task to be run. This is safe to call from within another +// task handler function and immediately returns. +func (tr *TaskRunner) Schedule(f TaskFunc) { + if tr.addTask(f) { + tr.spawnIfAvailable() + } +} + +func (tr *TaskRunner) spawnIfAvailable() { + // To spawn a runner, write a struct{} to the sem channel. If the task runner + // is already at the concurrency limit, then this chan write will fail, + // and nothing will be spawned. This also checks if the context has already + // been canceled, in which case nothing needs to be done. + select { + case tr.sem <- struct{}{}: + go tr.runner() + + case <-tr.ctx.Done(): + return + + default: + return + } +} + +func (tr *TaskRunner) runner() { + for { + select { + case <-tr.ctx.Done(): + // If the context was canceled, mark all the remaining tasks as "Done". + tr.emptyForCancel() + return + + default: + // Select a task from the list, if any. + task := tr.selectTask() + if task == nil { + // If there are no further tasks, then "return" the struct{} by reading + // it from the channel (freeing a slot potentially for another worker + // to be spawned later). + <-tr.sem + return + } + + // Run the task. If an error occurs, store it and cancel any further tasks. + err := task(tr.ctx) + if err != nil { + tr.storeErrorAndCancel(err) + } + tr.wg.Done() + } + } +} + +func (tr *TaskRunner) addTask(f TaskFunc) bool { + tr.lock.Lock() + defer tr.lock.Unlock() + + if tr.err != nil { + return false + } + + tr.wg.Add(1) + tr.tasks = append(tr.tasks, f) + return true +} + +func (tr *TaskRunner) selectTask() TaskFunc { + tr.lock.Lock() + defer tr.lock.Unlock() + + if len(tr.tasks) == 0 { + return nil + } + + task := tr.tasks[0] + tr.tasks = tr.tasks[1:] + return task +} + +func (tr *TaskRunner) storeErrorAndCancel(err error) { + tr.lock.Lock() + defer tr.lock.Unlock() + + if tr.err == nil { + tr.err = err + tr.cancel() + } +} + +func (tr *TaskRunner) emptyForCancel() { + tr.lock.Lock() + defer tr.lock.Unlock() + + if tr.err == nil { + tr.err = tr.ctx.Err() + } + + for { + if len(tr.tasks) == 0 { + break + } + + tr.tasks = tr.tasks[1:] + tr.wg.Done() + } +} + +// Wait waits for all tasks to be completed, or a task to raise an error, +// or the parent context to have been canceled. +func (tr *TaskRunner) Wait() error { + tr.wg.Wait() + + tr.lock.Lock() + defer tr.lock.Unlock() + return tr.err +} diff --git a/vendor/github.com/authzed/spicedb/internal/telemetry/doc.go b/vendor/github.com/authzed/spicedb/internal/telemetry/doc.go new file mode 100644 index 0000000..acf0c0b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/telemetry/doc.go @@ -0,0 +1,6 @@ +// Package telemetry implements a client for reporting telemetry data used to +// prioritize development of SpiceDB. +// +// For more information, see: +// https://github.com/authzed/spicedb/blob/main/TELEMETRY.md +package telemetry diff --git a/vendor/github.com/authzed/spicedb/internal/telemetry/logicalchecks.go b/vendor/github.com/authzed/spicedb/internal/telemetry/logicalchecks.go new file mode 100644 index 0000000..3aa91a4 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/telemetry/logicalchecks.go @@ -0,0 +1,16 @@ +package telemetry + +import "sync/atomic" + +var logicalChecksCountTotal atomic.Uint64 + +// RecordLogicalChecks records the number of logical checks performed by the server. +func RecordLogicalChecks(logicalCheckCount uint64) { + logicalChecksCountTotal.Add(logicalCheckCount) +} + +// loadLogicalChecksCount returns the total number of logical checks performed by the server, +// zeroing out the existing count as well. +func loadLogicalChecksCount() uint64 { + return logicalChecksCountTotal.Swap(0) +} diff --git a/vendor/github.com/authzed/spicedb/internal/telemetry/metrics.go b/vendor/github.com/authzed/spicedb/internal/telemetry/metrics.go new file mode 100644 index 0000000..4323297 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/telemetry/metrics.go @@ -0,0 +1,203 @@ +package telemetry + +import ( + "context" + "fmt" + "os" + "runtime" + "runtime/debug" + "strconv" + "time" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "golang.org/x/sync/errgroup" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/promutil" +) + +func SpiceDBClusterInfoCollector(ctx context.Context, subsystem, dsEngine string, ds datastore.Datastore) (promutil.CollectorFunc, error) { + nodeID, err := os.Hostname() + if err != nil { + return nil, fmt.Errorf("unable to get hostname: %w", err) + } + + dbStats, err := ds.Statistics(ctx) + if err != nil { + return nil, fmt.Errorf("unable to query DB stats: %w", err) + } + + clusterID := dbStats.UniqueID + buildInfo, ok := debug.ReadBuildInfo() + if !ok { + return nil, fmt.Errorf("failed to read BuildInfo") + } + + return func(ch chan<- prometheus.Metric) { + ch <- prometheus.MustNewConstMetric(prometheus.NewDesc( + prometheus.BuildFQName("spicedb", subsystem, "info"), + "Information about the SpiceDB environment.", + nil, + prometheus.Labels{ + "cluster_id": clusterID, + "node_id": nodeID, + "version": cobrautil.VersionWithFallbacks(buildInfo), + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "go": buildInfo.GoVersion, + "vcpu": strconv.Itoa(runtime.NumCPU()), + "ds_engine": dsEngine, + }, + ), prometheus.GaugeValue, 1) + }, nil +} + +// RegisterTelemetryCollector registers a collector for the various pieces of +// data required by SpiceDB telemetry. +func RegisterTelemetryCollector(datastoreEngine string, ds datastore.Datastore) (*prometheus.Registry, error) { + registry := prometheus.NewRegistry() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + infoCollector, err := SpiceDBClusterInfoCollector(ctx, "telemetry", datastoreEngine, ds) + if err != nil { + return nil, fmt.Errorf("unable create info collector: %w", err) + } + + if err := registry.Register(infoCollector); err != nil { + return nil, fmt.Errorf("unable to register telemetry collector: %w", err) + } + + nodeID, err := os.Hostname() + if err != nil { + return nil, fmt.Errorf("unable to get hostname: %w", err) + } + + dbStats, err := ds.Statistics(ctx) + if err != nil { + return nil, fmt.Errorf("unable to query DB stats: %w", err) + } + clusterID := dbStats.UniqueID + + if err := registry.Register(&collector{ + ds: ds, + objectDefsDesc: prometheus.NewDesc( + prometheus.BuildFQName("spicedb", "telemetry", "object_definitions_total"), + "Count of the number of objects defined by the schema.", + nil, + prometheus.Labels{ + "cluster_id": clusterID, + "node_id": nodeID, + }, + ), + relationshipsDesc: prometheus.NewDesc( + prometheus.BuildFQName("spicedb", "telemetry", "relationships_estimate_total"), + "Count of the estimated number of stored relationships.", + nil, + prometheus.Labels{ + "cluster_id": clusterID, + "node_id": nodeID, + }, + ), + dispatchedDesc: prometheus.NewDesc( + prometheus.BuildFQName("spicedb", "telemetry", "dispatches"), + "Histogram of cluster dispatches performed by the instance.", + usagemetrics.DispatchedCountLabels, + prometheus.Labels{ + "cluster_id": clusterID, + "node_id": nodeID, + }, + ), + logicalChecksDec: prometheus.NewDesc( + prometheus.BuildFQName("spicedb", "telemetry", "logical_checks_total"), + "Count of the number of logical checks made.", + usagemetrics.DispatchedCountLabels, + prometheus.Labels{ + "cluster_id": clusterID, + "node_id": nodeID, + }, + ), + }); err != nil { + return nil, fmt.Errorf("unable to register telemetry collector: %w", err) + } + + return registry, nil +} + +type collector struct { + ds datastore.Datastore + objectDefsDesc *prometheus.Desc + relationshipsDesc *prometheus.Desc + dispatchedDesc *prometheus.Desc + logicalChecksDec *prometheus.Desc +} + +var _ prometheus.Collector = &collector{} + +func (c *collector) Describe(ch chan<- *prometheus.Desc) { + ch <- c.objectDefsDesc + ch <- c.relationshipsDesc + ch <- c.dispatchedDesc + ch <- c.logicalChecksDec +} + +func (c *collector) Collect(ch chan<- prometheus.Metric) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + dsStats, err := c.ds.Statistics(ctx) + if err != nil { + log.Warn().Err(err).Msg("unable to collect datastore statistics") + } + + logicalChecksCount := loadLogicalChecksCount() + + ch <- prometheus.MustNewConstMetric(c.objectDefsDesc, prometheus.GaugeValue, float64(len(dsStats.ObjectTypeStatistics))) + ch <- prometheus.MustNewConstMetric(c.relationshipsDesc, prometheus.GaugeValue, float64(dsStats.EstimatedRelationshipCount)) + ch <- prometheus.MustNewConstMetric(c.logicalChecksDec, prometheus.GaugeValue, float64(logicalChecksCount)) + + dispatchedCountMetrics := make(chan prometheus.Metric) + g := errgroup.Group{} + g.Go(func() error { + for metric := range dispatchedCountMetrics { + var m dto.Metric + if err := metric.Write(&m); err != nil { + return fmt.Errorf("error writing metric: %w", err) + } + + buckets := make(map[float64]uint64, len(m.Histogram.Bucket)) + for _, bucket := range m.Histogram.Bucket { + buckets[*bucket.UpperBound] = *bucket.CumulativeCount + } + + dynamicLabels := make([]string, len(usagemetrics.DispatchedCountLabels)) + for i, labelName := range usagemetrics.DispatchedCountLabels { + for _, labelVal := range m.Label { + if *labelVal.Name == labelName { + dynamicLabels[i] = *labelVal.Value + } + } + } + ch <- prometheus.MustNewConstHistogram( + c.dispatchedDesc, + *m.Histogram.SampleCount, + *m.Histogram.SampleSum, + buckets, + dynamicLabels..., + ) + } + return nil + }) + + usagemetrics.DispatchedCountHistogram.Collect(dispatchedCountMetrics) + close(dispatchedCountMetrics) + + if err := g.Wait(); err != nil { + log.Error().Err(err).Msg("error collecting metrics") + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/telemetry/reporter.go b/vendor/github.com/authzed/spicedb/internal/telemetry/reporter.go new file mode 100644 index 0000000..8296b3c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/telemetry/reporter.go @@ -0,0 +1,234 @@ +package telemetry + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "io" + "math/rand" + "net/http" + "net/url" + "time" + + prompb "buf.build/gen/go/prometheus/prometheus/protocolbuffers/go" + "github.com/cenkalti/backoff/v4" + "github.com/gogo/protobuf/proto" + "github.com/golang/snappy" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/expfmt" + "github.com/prometheus/common/model" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/x509util" +) + +const ( + // DefaultEndpoint is the endpoint to which telemetry will report if none + // other is specified. + DefaultEndpoint = "https://telemetry.authzed.com" + + // DefaultInterval is the default amount of time to wait between telemetry + // reports. + DefaultInterval = 1 * time.Hour + + // MaxElapsedTimeBetweenReports is the maximum amount of time that the + // telemetry reporter will attempt to write to the telemetry endpoint + // before terminating the reporter. + MaxElapsedTimeBetweenReports = 168 * time.Hour + + // MinimumAllowedInterval is the minimum amount of time one can request + // between telemetry reports. + MinimumAllowedInterval = 1 * time.Minute +) + +func writeTimeSeries(ctx context.Context, client *http.Client, endpoint string, ts []*prompb.TimeSeries) error { + // Reference upstream client: + // https://github.com/prometheus/prometheus/blob/6555cc68caf8d8f323056e497ae7bb1e32a81667/storage/remote/client.go#L191 + pbBytes, err := proto.Marshal(&prompb.WriteRequest{ + Timeseries: ts, + }) + if err != nil { + return fmt.Errorf("failed to marshal Prometheus remote write protobuf: %w", err) + } + compressedPB := snappy.Encode(nil, pbBytes) + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(compressedPB)) + if err != nil { + return fmt.Errorf("failed to create Prometheus remote write http request: %w", err) + } + + r.Header.Add("X-Prometheus-Remote-Write-Version", "0.1.0") + r.Header.Add("Content-Encoding", "snappy") + r.Header.Set("Content-Type", "application/x-protobuf") + + resp, err := client.Do(r) + if err != nil { + return fmt.Errorf("failed to send Prometheus remote write: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode/100 != 2 { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf( + "unexpected Prometheus remote write response: %d: %s", + resp.StatusCode, + string(body), + ) + } + + return nil +} + +func discoverTimeseries(registry *prometheus.Registry) (allTS []*prompb.TimeSeries, err error) { + metricFams, err := registry.Gather() + if err != nil { + return nil, fmt.Errorf("failed to gather telemetry metrics: %w", err) + } + + defaultTimestamp := model.Time(time.Now().UnixNano() / int64(time.Millisecond)) + sampleVector, err := expfmt.ExtractSamples(&expfmt.DecodeOptions{ + Timestamp: defaultTimestamp, + }, metricFams...) + if err != nil { + return nil, fmt.Errorf("unable to extract sample from metrics families: %w", err) + } + + for _, sample := range sampleVector { + allTS = append(allTS, &prompb.TimeSeries{ + Labels: convertLabels(sample.Metric), + Samples: []*prompb.Sample{{ + Value: float64(sample.Value), + Timestamp: int64(sample.Timestamp), + }}, + }) + } + + return +} + +func discoverAndWriteMetrics( + ctx context.Context, + registry *prometheus.Registry, + client *http.Client, + endpoint string, +) error { + ts, err := discoverTimeseries(registry) + if err != nil { + return err + } + + return writeTimeSeries(ctx, client, endpoint, ts) +} + +type Reporter func(ctx context.Context) error + +// RemoteReporter creates a telemetry reporter with the specified parameters, or errors +// if the configuration was invalid. +func RemoteReporter( + registry *prometheus.Registry, + endpoint string, + caOverridePath string, + interval time.Duration, +) (Reporter, error) { + if _, err := url.Parse(endpoint); err != nil { + return nil, fmt.Errorf("invalid telemetry endpoint: %w", err) + } + if interval < MinimumAllowedInterval { + return nil, fmt.Errorf("invalid telemetry reporting interval: %s < %s", interval, MinimumAllowedInterval) + } + if endpoint == DefaultEndpoint && interval != DefaultInterval { + return nil, fmt.Errorf("cannot change the telemetry reporting interval for the default endpoint") + } + + client := &http.Client{} + if caOverridePath != "" { + pool, err := x509util.CustomCertPool(caOverridePath) + if err != nil { + return nil, fmt.Errorf("invalid custom cert pool path `%s`: %w", caOverridePath, err) + } + + t := &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: pool, + MinVersion: tls.VersionTLS12, + }, + } + + client.Transport = t + } + + return func(ctx context.Context) error { + // nolint:gosec + // G404 use of non cryptographically secure random number generator is not a security concern here, + // as this is only used to smear the startup delay out over 10% of the reporting interval + startupDelay := time.Duration(rand.Int63n(int64(interval.Seconds()/10))) * time.Second + + log.Ctx(ctx).Info(). + Stringer("interval", interval). + Str("endpoint", endpoint). + Stringer("next", startupDelay). + Msg("telemetry reporter scheduled") + + backoffInterval := backoff.NewExponentialBackOff() + backoffInterval.InitialInterval = interval + backoffInterval.MaxInterval = MaxElapsedTimeBetweenReports + backoffInterval.MaxElapsedTime = 0 + + // Must reset the backoff object after changing parameters + backoffInterval.Reset() + + ticker := time.After(startupDelay) + + for { + select { + case <-ticker: + nextPush := backoffInterval.InitialInterval + if err := discoverAndWriteMetrics(ctx, registry, client, endpoint); err != nil { + nextPush = backoffInterval.NextBackOff() + log.Ctx(ctx).Warn(). + Err(err). + Str("endpoint", endpoint). + Stringer("next", nextPush). + Msg("failed to push telemetry metric") + } else { + log.Ctx(ctx).Debug(). + Str("endpoint", endpoint). + Stringer("next", nextPush). + Msg("reported telemetry") + backoffInterval.Reset() + } + if nextPush == backoff.Stop { + return fmt.Errorf( + "exceeded maximum time between successful reports of %s", + MaxElapsedTimeBetweenReports, + ) + } + ticker = time.After(nextPush) + + case <-ctx.Done(): + return nil + } + } + }, nil +} + +func DisabledReporter(ctx context.Context) error { + log.Ctx(ctx).Info().Msg("telemetry disabled") + return nil +} + +func SilentlyDisabledReporter(_ context.Context) error { + return nil +} + +func convertLabels(labels model.Metric) []*prompb.Label { + out := make([]*prompb.Label, 0, len(labels)) + for name, value := range labels { + out = append(out, &prompb.Label{ + Name: string(name), + Value: string(value), + }) + } + return out +} |
