summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/internal
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal')
-rw-r--r--vendor/github.com/authzed/spicedb/internal/caveats/builder.go152
-rw-r--r--vendor/github.com/authzed/spicedb/internal/caveats/debug.go164
-rw-r--r--vendor/github.com/authzed/spicedb/internal/caveats/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/caveats/errors.go107
-rw-r--r--vendor/github.com/authzed/spicedb/internal/caveats/run.go427
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datasets/basesubjectset.go856
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datasets/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datasets/subjectset.go65
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbyresourceid.go117
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbytype.go113
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go352
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go154
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go269
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go49
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/index.go28
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go15
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go42
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go214
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go188
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go17
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go961
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go31
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/url.go19
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go276
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md23
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go156
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go37
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go386
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go597
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go386
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go118
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go232
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go51
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go148
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go79
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go166
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go118
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go125
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go97
-rw-r--r--vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go80
-rw-r--r--vendor/github.com/authzed/spicedb/internal/developmentmembership/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/developmentmembership/foundsubject.go127
-rw-r--r--vendor/github.com/authzed/spicedb/internal/developmentmembership/membership.go167
-rw-r--r--vendor/github.com/authzed/spicedb/internal/developmentmembership/onrset.go87
-rw-r--r--vendor/github.com/authzed/spicedb/internal/developmentmembership/trackingsubjectset.go235
-rw-r--r--vendor/github.com/authzed/spicedb/internal/dispatch/dispatch.go98
-rw-r--r--vendor/github.com/authzed/spicedb/internal/dispatch/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/dispatch/errors.go39
-rw-r--r--vendor/github.com/authzed/spicedb/internal/dispatch/graph/errors.go77
-rw-r--r--vendor/github.com/authzed/spicedb/internal/dispatch/graph/graph.go437
-rw-r--r--vendor/github.com/authzed/spicedb/internal/dispatch/graph/zz_generated.options.go92
-rw-r--r--vendor/github.com/authzed/spicedb/internal/dispatch/stream.go187
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/check.go1354
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go144
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go205
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/context.go33
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/cursors.go542
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/errors.go213
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/expand.go436
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/graph.go89
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go96
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/limits.go80
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go681
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go803
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go334
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/membershipset.go243
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go248
-rw-r--r--vendor/github.com/authzed/spicedb/internal/graph/traceid.go13
-rw-r--r--vendor/github.com/authzed/spicedb/internal/grpchelpers/grpchelpers.go20
-rw-r--r--vendor/github.com/authzed/spicedb/internal/logging/logger.go43
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/chain.go58
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/datastore/datastore.go85
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/datastore/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/handwrittenvalidation.go54
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/servicespecific.go39
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/streamtimeout.go57
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/usagemetrics.go128
-rw-r--r--vendor/github.com/authzed/spicedb/internal/namespace/aliasing.go82
-rw-r--r--vendor/github.com/authzed/spicedb/internal/namespace/annotate.go29
-rw-r--r--vendor/github.com/authzed/spicedb/internal/namespace/canonicalization.go282
-rw-r--r--vendor/github.com/authzed/spicedb/internal/namespace/caveats.go69
-rw-r--r--vendor/github.com/authzed/spicedb/internal/namespace/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/namespace/errors.go171
-rw-r--r--vendor/github.com/authzed/spicedb/internal/namespace/util.go148
-rw-r--r--vendor/github.com/authzed/spicedb/internal/relationships/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/relationships/errors.go195
-rw-r--r--vendor/github.com/authzed/spicedb/internal/relationships/validation.go280
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/shared/errors.go208
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go52
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/shared/schema.go474
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go332
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/debug.go238
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/errors.go511
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go824
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go720
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go72
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/hash.go110
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go52
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go50
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go12
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go93
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go1094
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go54
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go720
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go76
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go576
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/schema.go375
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/watch.go190
-rw-r--r--vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go22
-rw-r--r--vendor/github.com/authzed/spicedb/internal/sharederrors/interfaces.go16
-rw-r--r--vendor/github.com/authzed/spicedb/internal/taskrunner/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go153
-rw-r--r--vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go168
-rw-r--r--vendor/github.com/authzed/spicedb/internal/telemetry/doc.go6
-rw-r--r--vendor/github.com/authzed/spicedb/internal/telemetry/logicalchecks.go16
-rw-r--r--vendor/github.com/authzed/spicedb/internal/telemetry/metrics.go203
-rw-r--r--vendor/github.com/authzed/spicedb/internal/telemetry/reporter.go234
123 files changed, 23894 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/builder.go b/vendor/github.com/authzed/spicedb/internal/caveats/builder.go
new file mode 100644
index 0000000..0c93d39
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/caveats/builder.go
@@ -0,0 +1,152 @@
+package caveats
+
+import (
+ "google.golang.org/protobuf/types/known/structpb"
+
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+)
+
+// CaveatAsExpr wraps a contextualized caveat into a caveat expression.
+func CaveatAsExpr(caveat *core.ContextualizedCaveat) *core.CaveatExpression {
+ if caveat == nil {
+ return nil
+ }
+
+ return &core.CaveatExpression{
+ OperationOrCaveat: &core.CaveatExpression_Caveat{
+ Caveat: caveat,
+ },
+ }
+}
+
+// CaveatForTesting returns a new ContextualizedCaveat for testing, with empty context.
+func CaveatForTesting(name string) *core.ContextualizedCaveat {
+ return &core.ContextualizedCaveat{
+ CaveatName: name,
+ }
+}
+
+// CaveatExprForTesting returns a CaveatExpression referencing a caveat with the given name and
+// empty context.
+func CaveatExprForTesting(name string) *core.CaveatExpression {
+ return &core.CaveatExpression{
+ OperationOrCaveat: &core.CaveatExpression_Caveat{
+ Caveat: CaveatForTesting(name),
+ },
+ }
+}
+
+// MustCaveatExprForTestingWithContext returns a CaveatExpression referencing a caveat with the given name and
+// given context.
+func MustCaveatExprForTestingWithContext(name string, context map[string]any) *core.CaveatExpression {
+ contextStruct, err := structpb.NewStruct(context)
+ if err != nil {
+ panic(err)
+ }
+
+ return &core.CaveatExpression{
+ OperationOrCaveat: &core.CaveatExpression_Caveat{
+ Caveat: &core.ContextualizedCaveat{
+ CaveatName: name,
+ Context: contextStruct,
+ },
+ },
+ }
+}
+
+// ShortcircuitedOr combines two caveat expressions via an `||`. If one of the expressions is nil,
+// then the entire expression is *short-circuited*, and a nil is returned.
+func ShortcircuitedOr(first *core.CaveatExpression, second *core.CaveatExpression) *core.CaveatExpression {
+ if first == nil || second == nil {
+ return nil
+ }
+
+ return Or(first, second)
+}
+
+// Or `||`'s together two caveat expressions. If one expression is nil, the other is returned.
+func Or(first *core.CaveatExpression, second *core.CaveatExpression) *core.CaveatExpression {
+ if first == nil {
+ return second
+ }
+
+ if second == nil {
+ return first
+ }
+
+ if first.EqualVT(second) {
+ return first
+ }
+
+ return &core.CaveatExpression{
+ OperationOrCaveat: &core.CaveatExpression_Operation{
+ Operation: &core.CaveatOperation{
+ Op: core.CaveatOperation_OR,
+ Children: []*core.CaveatExpression{first, second},
+ },
+ },
+ }
+}
+
+// And `&&`'s together two caveat expressions. If one expression is nil, the other is returned.
+func And(first *core.CaveatExpression, second *core.CaveatExpression) *core.CaveatExpression {
+ if first == nil {
+ return second
+ }
+
+ if second == nil {
+ return first
+ }
+
+ if first.EqualVT(second) {
+ return first
+ }
+
+ return &core.CaveatExpression{
+ OperationOrCaveat: &core.CaveatExpression_Operation{
+ Operation: &core.CaveatOperation{
+ Op: core.CaveatOperation_AND,
+ Children: []*core.CaveatExpression{first, second},
+ },
+ },
+ }
+}
+
+// Invert returns the caveat expression with a `!` placed in front of it. If the expression is
+// nil, returns nil.
+func Invert(ce *core.CaveatExpression) *core.CaveatExpression {
+ if ce == nil {
+ return nil
+ }
+
+ return &core.CaveatExpression{
+ OperationOrCaveat: &core.CaveatExpression_Operation{
+ Operation: &core.CaveatOperation{
+ Op: core.CaveatOperation_NOT,
+ Children: []*core.CaveatExpression{ce},
+ },
+ },
+ }
+}
+
+// Subtract returns a caveat expression representing the given expression with the
+// subtracted expression removed (i.e. `caveat && !subtracted`).
+func Subtract(caveat *core.CaveatExpression, subtracted *core.CaveatExpression) *core.CaveatExpression {
+ inversion := Invert(subtracted)
+ if caveat == nil {
+ return inversion
+ }
+
+ if subtracted == nil {
+ return caveat
+ }
+
+ return &core.CaveatExpression{
+ OperationOrCaveat: &core.CaveatExpression_Operation{
+ Operation: &core.CaveatOperation{
+ Op: core.CaveatOperation_AND,
+ Children: []*core.CaveatExpression{caveat, inversion},
+ },
+ },
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/debug.go b/vendor/github.com/authzed/spicedb/internal/caveats/debug.go
new file mode 100644
index 0000000..bfcf62b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/caveats/debug.go
@@ -0,0 +1,164 @@
+package caveats
+
+import (
+ "fmt"
+ "maps"
+ "strconv"
+ "strings"
+
+ "google.golang.org/protobuf/types/known/structpb"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// BuildDebugInformation returns a human-readable string representation of the given
+// ExpressionResult and a Struct representation of the context values used in the expression.
+func BuildDebugInformation(exprResult ExpressionResult) (string, *structpb.Struct, error) {
+ // If a concrete result, return its information directly.
+ if concrete, ok := exprResult.(*caveats.CaveatResult); ok {
+ exprString, err := concrete.ParentCaveat().ExprString()
+ if err != nil {
+ return "", nil, err
+ }
+
+ contextStruct, err := concrete.ContextStruct()
+ if err != nil {
+ return "", nil, err
+ }
+
+ return exprString, contextStruct, nil
+ }
+
+ // Collect parameters which are shared across expressions.
+ syntheticResult, ok := exprResult.(syntheticResult)
+ if !ok {
+ return "", nil, spiceerrors.MustBugf("unknown ExpressionResult type: %T", exprResult)
+ }
+
+ resultsByParam := mapz.NewMultiMap[string, *caveats.CaveatResult]()
+ if err := collectParameterUsage(syntheticResult, resultsByParam); err != nil {
+ return "", nil, err
+ }
+
+ // Build the synthetic debug information.
+ exprString, contextMap, err := buildDebugInformation(syntheticResult, resultsByParam)
+ if err != nil {
+ return "", nil, err
+ }
+
+ // Convert the context map to a struct.
+ contextStruct, err := caveats.ConvertContextToStruct(contextMap)
+ if err != nil {
+ return "", nil, err
+ }
+
+ return exprString, contextStruct, nil
+}
+
+func buildDebugInformation(sr syntheticResult, resultsByParam *mapz.MultiMap[string, *caveats.CaveatResult]) (string, map[string]any, error) {
+ childExprStrings := make([]string, 0, len(sr.exprResultsForDebug))
+ combinedContext := map[string]any{}
+
+ for _, child := range sr.exprResultsForDebug {
+ if _, ok := child.(*caveats.CaveatResult); ok {
+ childExprString, contextMap, err := buildDebugInformationForConcrete(child.(*caveats.CaveatResult), resultsByParam)
+ if err != nil {
+ return "", nil, err
+ }
+
+ childExprStrings = append(childExprStrings, "("+childExprString+")")
+ maps.Copy(combinedContext, contextMap)
+ continue
+ }
+
+ childExprString, contextMap, err := buildDebugInformation(child.(syntheticResult), resultsByParam)
+ if err != nil {
+ return "", nil, err
+ }
+
+ childExprStrings = append(childExprStrings, "("+childExprString+")")
+ maps.Copy(combinedContext, contextMap)
+ }
+
+ var combinedExprString string
+ switch sr.op {
+ case corev1.CaveatOperation_AND:
+ combinedExprString = strings.Join(childExprStrings, " && ")
+
+ case corev1.CaveatOperation_OR:
+ combinedExprString = strings.Join(childExprStrings, " || ")
+
+ case corev1.CaveatOperation_NOT:
+ if len(childExprStrings) != 1 {
+ return "", nil, spiceerrors.MustBugf("NOT operator must have exactly one child")
+ }
+
+ combinedExprString = "!" + childExprStrings[0]
+
+ default:
+ return "", nil, fmt.Errorf("unknown operator: %v", sr.op)
+ }
+
+ return combinedExprString, combinedContext, nil
+}
+
+func buildDebugInformationForConcrete(cr *caveats.CaveatResult, resultsByParam *mapz.MultiMap[string, *caveats.CaveatResult]) (string, map[string]any, error) {
+ // For each parameter used in the context of the caveat, check if it is shared across multiple
+ // caveats. If so, rewrite the parameter to a unique name.
+ existingContextMap := cr.ContextValues()
+ contextMap := make(map[string]any, len(existingContextMap))
+
+ caveat := *cr.ParentCaveat()
+
+ for paramName, paramValue := range existingContextMap {
+ index := mapz.IndexOfValueInMultimap(resultsByParam, paramName, cr)
+ if resultsByParam.CountOf(paramName) > 1 {
+ newName := paramName + "__" + strconv.Itoa(index)
+ if resultsByParam.Has(newName) {
+ return "", nil, fmt.Errorf("failed to generate unique name for parameter: %s", newName)
+ }
+
+ rewritten, err := caveat.RewriteVariable(paramName, newName)
+ if err != nil {
+ return "", nil, err
+ }
+
+ caveat = rewritten
+ contextMap[newName] = paramValue
+ continue
+ }
+
+ contextMap[paramName] = paramValue
+ }
+
+ exprString, err := caveat.ExprString()
+ if err != nil {
+ return "", nil, err
+ }
+
+ return exprString, contextMap, nil
+}
+
+func collectParameterUsage(sr syntheticResult, resultsByParam *mapz.MultiMap[string, *caveats.CaveatResult]) error {
+ for _, exprResult := range sr.exprResultsForDebug {
+ if concrete, ok := exprResult.(*caveats.CaveatResult); ok {
+ for paramName := range concrete.ContextValues() {
+ resultsByParam.Add(paramName, concrete)
+ }
+ } else {
+ cast, ok := exprResult.(syntheticResult)
+ if !ok {
+ return spiceerrors.MustBugf("unknown ExpressionResult type: %T", exprResult)
+ }
+
+ if err := collectParameterUsage(cast, resultsByParam); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/doc.go b/vendor/github.com/authzed/spicedb/internal/caveats/doc.go
new file mode 100644
index 0000000..587d7d8
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/caveats/doc.go
@@ -0,0 +1,2 @@
+// Package caveats contains code to evaluate a caveat with a given context.
+package caveats
diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/errors.go b/vendor/github.com/authzed/spicedb/internal/caveats/errors.go
new file mode 100644
index 0000000..06284f2
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/caveats/errors.go
@@ -0,0 +1,107 @@
+package caveats
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/rs/zerolog"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// EvaluationError is an error in evaluation of a caveat expression.
+type EvaluationError struct {
+ error
+ caveatExpr *core.CaveatExpression
+ evalErr caveats.EvaluationError
+}
+
+// MarshalZerologObject implements zerolog.LogObjectMarshaler
+func (err EvaluationError) MarshalZerologObject(e *zerolog.Event) {
+ e.Err(err.error).Str("caveat_name", err.caveatExpr.GetCaveat().CaveatName).Interface("context", err.caveatExpr.GetCaveat().Context)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err EvaluationError) DetailsMetadata() map[string]string {
+ return spiceerrors.CombineMetadata(err.evalErr, map[string]string{
+ "caveat_name": err.caveatExpr.GetCaveat().CaveatName,
+ })
+}
+
+func (err EvaluationError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_CAVEAT_EVALUATION_ERROR,
+ err.DetailsMetadata(),
+ ),
+ )
+}
+
+func NewEvaluationError(caveatExpr *core.CaveatExpression, err caveats.EvaluationError) EvaluationError {
+ return EvaluationError{
+ fmt.Errorf("evaluation error for caveat %s: %w", caveatExpr.GetCaveat().CaveatName, err), caveatExpr, err,
+ }
+}
+
+// ParameterTypeError is a type error in constructing a parameter from a value.
+type ParameterTypeError struct {
+ error
+ caveatExpr *core.CaveatExpression
+ conversionError *caveats.ParameterConversionError
+}
+
+// MarshalZerologObject implements zerolog.LogObjectMarshaler
+func (err ParameterTypeError) MarshalZerologObject(e *zerolog.Event) {
+ evt := e.Err(err.error).
+ Str("caveat_name", err.caveatExpr.GetCaveat().CaveatName).
+ Interface("context", err.caveatExpr.GetCaveat().Context)
+
+ if err.conversionError != nil {
+ evt.Str("parameter_name", err.conversionError.ParameterName())
+ }
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err ParameterTypeError) DetailsMetadata() map[string]string {
+ if err.conversionError != nil {
+ return spiceerrors.CombineMetadata(err.conversionError, map[string]string{
+ "caveat_name": err.caveatExpr.GetCaveat().CaveatName,
+ })
+ }
+
+ return map[string]string{
+ "caveat_name": err.caveatExpr.GetCaveat().CaveatName,
+ }
+}
+
+func (err ParameterTypeError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_CAVEAT_PARAMETER_TYPE_ERROR,
+ err.DetailsMetadata(),
+ ),
+ )
+}
+
+func NewParameterTypeError(caveatExpr *core.CaveatExpression, err error) ParameterTypeError {
+ conversionError := &caveats.ParameterConversionError{}
+ if !errors.As(err, conversionError) {
+ conversionError = nil
+ }
+
+ return ParameterTypeError{
+ fmt.Errorf("type error for parameters for caveat `%s`: %w", caveatExpr.GetCaveat().CaveatName, err),
+ caveatExpr,
+ conversionError,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/caveats/run.go b/vendor/github.com/authzed/spicedb/internal/caveats/run.go
new file mode 100644
index 0000000..1aed483
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/caveats/run.go
@@ -0,0 +1,427 @@
+package caveats
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "maps"
+
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+var tracer = otel.Tracer("spicedb/internal/caveats/run")
+
+// RunCaveatExpressionDebugOption are the options for running caveat expression evaluation
+// with debugging enabled or disabled.
+type RunCaveatExpressionDebugOption int
+
+const (
+ // RunCaveatExpressionNoDebugging runs the evaluation without debugging enabled.
+ RunCaveatExpressionNoDebugging RunCaveatExpressionDebugOption = 0
+
+ // RunCaveatExpressionWithDebugInformation runs the evaluation with debugging enabled.
+ RunCaveatExpressionWithDebugInformation RunCaveatExpressionDebugOption = 1
+)
+
+// RunSingleCaveatExpression runs a caveat expression over the given context and returns the result.
+// This instantiates its own CaveatRunner, and should therefore only be used in one-off situations.
+func RunSingleCaveatExpression(
+ ctx context.Context,
+ ts *caveattypes.TypeSet,
+ expr *core.CaveatExpression,
+ context map[string]any,
+ reader datastore.CaveatReader,
+ debugOption RunCaveatExpressionDebugOption,
+) (ExpressionResult, error) {
+ runner := NewCaveatRunner(ts)
+ return runner.RunCaveatExpression(ctx, expr, context, reader, debugOption)
+}
+
+// CaveatRunner is a helper for running caveats, providing a cache for deserialized caveats.
+type CaveatRunner struct {
+ caveatTypeSet *caveattypes.TypeSet
+ caveatDefs map[string]*core.CaveatDefinition
+ deserializedCaveats map[string]*caveats.CompiledCaveat
+}
+
+// NewCaveatRunner creates a new CaveatRunner.
+func NewCaveatRunner(ts *caveattypes.TypeSet) *CaveatRunner {
+ return &CaveatRunner{
+ caveatTypeSet: ts,
+ caveatDefs: map[string]*core.CaveatDefinition{},
+ deserializedCaveats: map[string]*caveats.CompiledCaveat{},
+ }
+}
+
+// RunCaveatExpression runs a caveat expression over the given context and returns the result.
+func (cr *CaveatRunner) RunCaveatExpression(
+ ctx context.Context,
+ expr *core.CaveatExpression,
+ context map[string]any,
+ reader datastore.CaveatReader,
+ debugOption RunCaveatExpressionDebugOption,
+) (ExpressionResult, error) {
+ ctx, span := tracer.Start(ctx, "RunCaveatExpression")
+ defer span.End()
+
+ if err := cr.PopulateCaveatDefinitionsForExpr(ctx, expr, reader); err != nil {
+ return nil, err
+ }
+
+ env := caveats.NewEnvironment()
+ return cr.runExpressionWithCaveats(ctx, env, expr, context, debugOption)
+}
+
+// PopulateCaveatDefinitionsForExpr populates the CaveatRunner's cache with the definitions
+// referenced in the given caveat expression.
+func (cr *CaveatRunner) PopulateCaveatDefinitionsForExpr(ctx context.Context, expr *core.CaveatExpression, reader datastore.CaveatReader) error {
+ ctx, span := tracer.Start(ctx, "PopulateCaveatDefinitions")
+ defer span.End()
+
+ // Collect all referenced caveat definitions in the expression.
+ caveatNames := mapz.NewSet[string]()
+ collectCaveatNames(expr, caveatNames)
+
+ span.AddEvent("collected caveat names")
+ span.SetAttributes(attribute.StringSlice("caveat-names", caveatNames.AsSlice()))
+
+ if caveatNames.IsEmpty() {
+ return fmt.Errorf("received empty caveat expression")
+ }
+
+ // Remove any caveats already loaded.
+ for name := range cr.caveatDefs {
+ caveatNames.Delete(name)
+ }
+
+ if caveatNames.IsEmpty() {
+ return nil
+ }
+
+ // Bulk lookup all of the referenced caveat definitions.
+ caveatDefs, err := reader.LookupCaveatsWithNames(ctx, caveatNames.AsSlice())
+ if err != nil {
+ return err
+ }
+ span.AddEvent("looked up caveats")
+
+ for _, cd := range caveatDefs {
+ cr.caveatDefs[cd.Definition.GetName()] = cd.Definition
+ }
+
+ return nil
+}
+
+// get retrieves a caveat definition and its deserialized form. The caveat name must be
+// present in the CaveatRunner's cache.
+func (cr *CaveatRunner) get(caveatDefName string) (*core.CaveatDefinition, *caveats.CompiledCaveat, error) {
+ caveat, ok := cr.caveatDefs[caveatDefName]
+ if !ok {
+ return nil, nil, datastore.NewCaveatNameNotFoundErr(caveatDefName)
+ }
+
+ deserialized, ok := cr.deserializedCaveats[caveatDefName]
+ if ok {
+ return caveat, deserialized, nil
+ }
+
+ parameterTypes, err := caveattypes.DecodeParameterTypes(cr.caveatTypeSet, caveat.ParameterTypes)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ justDeserialized, err := caveats.DeserializeCaveatWithTypeSet(cr.caveatTypeSet, caveat.SerializedExpression, parameterTypes)
+ if err != nil {
+ return caveat, nil, err
+ }
+
+ cr.deserializedCaveats[caveatDefName] = justDeserialized
+ return caveat, justDeserialized, nil
+}
+
+func collectCaveatNames(expr *core.CaveatExpression, caveatNames *mapz.Set[string]) {
+ if expr.GetCaveat() != nil {
+ caveatNames.Add(expr.GetCaveat().CaveatName)
+ return
+ }
+
+ cop := expr.GetOperation()
+ for _, child := range cop.Children {
+ collectCaveatNames(child, caveatNames)
+ }
+}
+
+func (cr *CaveatRunner) runExpressionWithCaveats(
+ ctx context.Context,
+ env *caveats.Environment,
+ expr *core.CaveatExpression,
+ context map[string]any,
+ debugOption RunCaveatExpressionDebugOption,
+) (ExpressionResult, error) {
+ ctx, span := tracer.Start(ctx, "runExpressionWithCaveats")
+ defer span.End()
+
+ if expr.GetCaveat() != nil {
+ span.SetAttributes(attribute.String("caveat-name", expr.GetCaveat().CaveatName))
+
+ caveat, compiled, err := cr.get(expr.GetCaveat().CaveatName)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create a combined context, with the written context taking precedence over that specified.
+ untypedFullContext := maps.Clone(context)
+ if untypedFullContext == nil {
+ untypedFullContext = map[string]any{}
+ }
+
+ relationshipContext := expr.GetCaveat().GetContext().AsMap()
+ maps.Copy(untypedFullContext, relationshipContext)
+
+ // Perform type checking and conversion on the context map.
+ typedParameters, err := caveats.ConvertContextToParameters(
+ cr.caveatTypeSet,
+ untypedFullContext,
+ caveat.ParameterTypes,
+ caveats.SkipUnknownParameters,
+ )
+ if err != nil {
+ return nil, NewParameterTypeError(expr, err)
+ }
+
+ result, err := caveats.EvaluateCaveat(compiled, typedParameters)
+ if err != nil {
+ var evalErr caveats.EvaluationError
+ if errors.As(err, &evalErr) {
+ return nil, NewEvaluationError(expr, evalErr)
+ }
+
+ return nil, err
+ }
+
+ return result, nil
+ }
+
+ cop := expr.GetOperation()
+ span.SetAttributes(attribute.String("caveat-operation", cop.Op.String()))
+
+ var currentResult ExpressionResult = syntheticResult{
+ value: cop.Op == core.CaveatOperation_AND,
+ isPartialResult: false,
+ }
+
+ var exprResultsForDebug []ExpressionResult
+ if debugOption == RunCaveatExpressionWithDebugInformation {
+ exprResultsForDebug = make([]ExpressionResult, 0, len(cop.Children))
+ }
+
+ var missingVarNames *mapz.Set[string]
+ if debugOption == RunCaveatExpressionNoDebugging {
+ missingVarNames = mapz.NewSet[string]()
+ }
+
+ and := func(existing ExpressionResult, found ExpressionResult) (ExpressionResult, error) {
+ if !existing.IsPartial() && !existing.Value() {
+ return syntheticResult{
+ value: false,
+ op: core.CaveatOperation_AND,
+ exprResultsForDebug: exprResultsForDebug,
+ isPartialResult: false,
+ missingVarNames: nil,
+ }, nil
+ }
+
+ if !found.IsPartial() && !found.Value() {
+ return syntheticResult{
+ value: false,
+ op: core.CaveatOperation_AND,
+ exprResultsForDebug: exprResultsForDebug,
+ isPartialResult: false,
+ missingVarNames: nil,
+ }, nil
+ }
+
+ value := existing.Value() && found.Value()
+ if existing.IsPartial() || found.IsPartial() {
+ value = false
+ }
+
+ return syntheticResult{
+ value: value,
+ op: core.CaveatOperation_AND,
+ exprResultsForDebug: exprResultsForDebug,
+ isPartialResult: existing.IsPartial() || found.IsPartial(),
+ missingVarNames: missingVarNames,
+ }, nil
+ }
+
+ or := func(existing ExpressionResult, found ExpressionResult) (ExpressionResult, error) {
+ if !existing.IsPartial() && existing.Value() {
+ return syntheticResult{
+ value: true,
+ op: core.CaveatOperation_OR,
+ exprResultsForDebug: exprResultsForDebug,
+ isPartialResult: false,
+ missingVarNames: nil,
+ }, nil
+ }
+
+ if !found.IsPartial() && found.Value() {
+ return syntheticResult{
+ value: true,
+ op: core.CaveatOperation_OR,
+ exprResultsForDebug: exprResultsForDebug,
+ isPartialResult: false,
+ missingVarNames: nil,
+ }, nil
+ }
+
+ value := existing.Value() || found.Value()
+ if existing.IsPartial() || found.IsPartial() {
+ value = false
+ }
+
+ return syntheticResult{
+ value: value,
+ op: core.CaveatOperation_OR,
+ exprResultsForDebug: exprResultsForDebug,
+ isPartialResult: existing.IsPartial() || found.IsPartial(),
+ missingVarNames: missingVarNames,
+ }, nil
+ }
+
+ invert := func(existing ExpressionResult) (ExpressionResult, error) {
+ value := !existing.Value()
+ if existing.IsPartial() {
+ value = false
+ }
+
+ return syntheticResult{
+ value: value,
+ op: core.CaveatOperation_NOT,
+ exprResultsForDebug: exprResultsForDebug,
+ isPartialResult: existing.IsPartial(),
+ missingVarNames: missingVarNames,
+ }, nil
+ }
+
+ for _, child := range cop.Children {
+ childResult, err := cr.runExpressionWithCaveats(ctx, env, child, context, debugOption)
+ if err != nil {
+ return nil, err
+ }
+
+ if debugOption != RunCaveatExpressionNoDebugging {
+ exprResultsForDebug = append(exprResultsForDebug, childResult)
+ } else if childResult.IsPartial() {
+ missingVars, err := childResult.MissingVarNames()
+ if err != nil {
+ return nil, err
+ }
+
+ missingVarNames.Extend(missingVars)
+ }
+
+ switch cop.Op {
+ case core.CaveatOperation_AND:
+ cr, err := and(currentResult, childResult)
+ if err != nil {
+ return nil, err
+ }
+
+ currentResult = cr
+ if debugOption == RunCaveatExpressionNoDebugging && isFalseResult(currentResult) {
+ return currentResult, nil
+ }
+
+ case core.CaveatOperation_OR:
+ cr, err := or(currentResult, childResult)
+ if err != nil {
+ return nil, err
+ }
+
+ currentResult = cr
+ if debugOption == RunCaveatExpressionNoDebugging && isTrueResult(currentResult) {
+ return currentResult, nil
+ }
+
+ case core.CaveatOperation_NOT:
+ return invert(childResult)
+
+ default:
+ return nil, spiceerrors.MustBugf("unknown caveat operation: %v", cop.Op)
+ }
+ }
+
+ return currentResult, nil
+}
+
// ExpressionResult is the result of a caveat expression being run.
// See also caveats.CaveatResult
type ExpressionResult interface {
	// Value is the resolved value for the expression. For partially applied expressions, this value will be false.
	Value() bool

	// IsPartial returns whether the expression was only partially applied.
	IsPartial() bool

	// MissingVarNames returns the names of the parameters missing from the context.
	// It returns an error if the result is not partial (i.e. IsPartial() is false).
	MissingVarNames() ([]string, error)
}
+
// syntheticResult is an ExpressionResult computed for a caveat operation
// (AND/OR/NOT) over child expression results.
type syntheticResult struct {
	// value is the resolved boolean value; false when the result is partial.
	value bool
	// isPartialResult indicates whether the expression was only partially applied.
	isPartialResult bool

	// op is the caveat operation that produced this result.
	op core.CaveatOperation_Operation
	// exprResultsForDebug holds the child results; populated only when debug
	// information was requested.
	exprResultsForDebug []ExpressionResult
	// missingVarNames holds the collected missing parameter names; may be nil,
	// in which case MissingVarNames derives the names from exprResultsForDebug.
	missingVarNames *mapz.Set[string]
}
+
// Value implements ExpressionResult.
func (sr syntheticResult) Value() bool {
	return sr.value
}
+
// IsPartial implements ExpressionResult.
func (sr syntheticResult) IsPartial() bool {
	return sr.isPartialResult
}
+
+func (sr syntheticResult) MissingVarNames() ([]string, error) {
+ if sr.isPartialResult {
+ if sr.missingVarNames != nil {
+ return sr.missingVarNames.AsSlice(), nil
+ }
+
+ missingVarNames := mapz.NewSet[string]()
+ for _, exprResult := range sr.exprResultsForDebug {
+ if exprResult.IsPartial() {
+ found, err := exprResult.MissingVarNames()
+ if err != nil {
+ return nil, err
+ }
+
+ missingVarNames.Extend(found)
+ }
+ }
+
+ return missingVarNames.AsSlice(), nil
+ }
+
+ return nil, fmt.Errorf("not a partial value")
+}
+
+func isFalseResult(result ExpressionResult) bool {
+ return !result.Value() && !result.IsPartial()
+}
+
+func isTrueResult(result ExpressionResult) bool {
+ return result.Value() && !result.IsPartial()
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/basesubjectset.go b/vendor/github.com/authzed/spicedb/internal/datasets/basesubjectset.go
new file mode 100644
index 0000000..80ab666
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datasets/basesubjectset.go
@@ -0,0 +1,856 @@
+package datasets
+
+import (
+ "golang.org/x/exp/maps"
+
+ "github.com/authzed/spicedb/internal/caveats"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// Aliases for the caveat expression builders, kept short so the set algebra
// below stays readable.
var (
	caveatAnd        = caveats.And
	caveatOr         = caveats.Or
	caveatInvert     = caveats.Invert
	shortcircuitedOr = caveats.ShortcircuitedOr
)
+
// Subject is a subject that can be placed into a BaseSubjectSet. It is defined in a generic
// manner to allow implementations that wrap BaseSubjectSet to add their own additional bookkeeping
// to the base implementation.
type Subject[T any] interface {
	// GetSubjectId returns the ID of the subject. For wildcards, this should be `*`.
	GetSubjectId() string

	// GetCaveatExpression returns the caveat expression for this subject, if it is conditional.
	// A nil return means the subject is unconditional.
	GetCaveatExpression() *core.CaveatExpression

	// GetExcludedSubjects returns the list of subjects excluded. Must only have values
	// for wildcards and must never be nested (exclusions cannot themselves carry
	// exclusions).
	GetExcludedSubjects() []T
}
+
// BaseSubjectSet defines a set that tracks accessible subjects, their exclusions (if wildcards),
// and all conditional expressions applied due to caveats.
//
// It is generic to allow other implementations to define the kind of tracking information
// associated with each subject.
//
// NOTE: Unlike a traditional set, unions between wildcards and a concrete subject will result
// in *both* being present in the set, to maintain the proper set semantics around wildcards.
type BaseSubjectSet[T Subject[T]] struct {
	// constructor builds new instances of T; see the constructor type below.
	constructor constructor[T]

	// concrete holds the non-wildcard subjects, keyed by subject ID.
	concrete map[string]T

	// wildcard boxes the optional `*` subject, allowing the value-receiver
	// methods below to replace it in place via setOrNil/clear.
	wildcard *handle[T]
}
+
+// NewBaseSubjectSet creates a new base subject set for use underneath well-typed implementation.
+//
+// The constructor function returns a new instance of type T for a particular subject ID.
+func NewBaseSubjectSet[T Subject[T]](constructor constructor[T]) BaseSubjectSet[T] {
+ return BaseSubjectSet[T]{
+ constructor: constructor,
+ concrete: map[string]T{},
+ wildcard: newHandle[T](),
+ }
+}
+
// constructor defines a function for constructing a new instance of the Subject type T for
// a subject ID, its (optional) conditional expression, any excluded subjects, and any sources
// for bookkeeping. The sources are those other subjects that were combined to create the current
// subject.
type constructor[T Subject[T]] func(subjectID string, conditionalExpression *core.CaveatExpression, excludedSubjects []T, sources ...T) T
+
+// MustAdd adds the found subject to the set. This is equivalent to a Union operation between the
+// existing set of subjects and a set containing the single subject, but modifies the set
+// *in place*.
+func (bss BaseSubjectSet[T]) MustAdd(foundSubject T) {
+ err := bss.Add(foundSubject)
+ if err != nil {
+ panic(err)
+ }
+}
+
// Add adds the found subject to the set. This is equivalent to a Union operation between the
// existing set of subjects and a set containing the single subject, but modifies the set
// *in place*.
func (bss BaseSubjectSet[T]) Add(foundSubject T) error {
	if foundSubject.GetSubjectId() == tuple.PublicWildcard {
		// Adding a wildcard: union it with any existing wildcard, then union each
		// existing concrete subject into the merged wildcard (which may remove or
		// weaken matching exclusions on the wildcard).
		existing := bss.wildcard.getOrNil()
		updated, err := unionWildcardWithWildcard(existing, foundSubject, bss.constructor)
		if err != nil {
			return err
		}

		bss.wildcard.setOrNil(updated)

		for _, concrete := range bss.concrete {
			updated = unionWildcardWithConcrete(updated, concrete, bss.constructor)
		}
		bss.wildcard.setOrNil(updated)
		return nil
	}

	// Adding a concrete subject: union with any existing entry under the same ID.
	var updatedOrNil *T
	if updated, ok := bss.concrete[foundSubject.GetSubjectId()]; ok {
		updatedOrNil = &updated
	}
	bss.setConcrete(foundSubject.GetSubjectId(), unionConcreteWithConcrete(updatedOrNil, &foundSubject, bss.constructor))

	// Also union the concrete into the wildcard (if any), which may remove or
	// weaken a matching exclusion on the wildcard.
	wildcard := bss.wildcard.getOrNil()
	wildcard = unionWildcardWithConcrete(wildcard, foundSubject, bss.constructor)
	bss.wildcard.setOrNil(wildcard)
	return nil
}
+
+func (bss BaseSubjectSet[T]) setConcrete(subjectID string, subjectOrNil *T) {
+ if subjectOrNil == nil {
+ delete(bss.concrete, subjectID)
+ return
+ }
+
+ subject := *subjectOrNil
+ bss.concrete[subject.GetSubjectId()] = subject
+}
+
// Subtract subtracts the given subject from the set, modifying the set *in place*.
func (bss BaseSubjectSet[T]) Subtract(toRemove T) {
	if toRemove.GetSubjectId() == tuple.PublicWildcard {
		// Subtracting a wildcard removes it from each concrete subject
		// (possibly only conditionally, if the wildcard is caveated).
		for _, concrete := range bss.concrete {
			bss.setConcrete(concrete.GetSubjectId(), subtractWildcardFromConcrete(concrete, toRemove, bss.constructor))
		}

		existing := bss.wildcard.getOrNil()
		updatedWildcard, concretesToAdd := subtractWildcardFromWildcard(existing, toRemove, bss.constructor)
		bss.wildcard.setOrNil(updatedWildcard)

		// Exclusions on the subtracted wildcard can materialize as concrete subjects.
		for _, concrete := range concretesToAdd {
			concrete := concrete // copy the loop variable; setConcrete stores through a pointer
			bss.setConcrete(concrete.GetSubjectId(), &concrete)
		}
		return
	}

	// Subtracting a concrete: remove (or conditionally weaken) the matching
	// concrete entry, and add it as an exclusion on the wildcard, if present.
	if existing, ok := bss.concrete[toRemove.GetSubjectId()]; ok {
		bss.setConcrete(toRemove.GetSubjectId(), subtractConcreteFromConcrete(existing, toRemove, bss.constructor))
	}

	wildcard, ok := bss.wildcard.get()
	if ok {
		bss.wildcard.setOrNil(subtractConcreteFromWildcard(wildcard, toRemove, bss.constructor))
	}
}
+
+// SubtractAll subtracts the other set of subjects from this set of subtracts, modifying this
+// set *in place*.
+func (bss BaseSubjectSet[T]) SubtractAll(other BaseSubjectSet[T]) {
+ for _, otherSubject := range other.AsSlice() {
+ bss.Subtract(otherSubject)
+ }
+}
+
+// MustIntersectionDifference performs an intersection between this set and the other set, modifying
+// this set *in place*.
+func (bss BaseSubjectSet[T]) MustIntersectionDifference(other BaseSubjectSet[T]) {
+ err := bss.IntersectionDifference(other)
+ if err != nil {
+ panic(err)
+ }
+}
+
// IntersectionDifference performs an intersection between this set and the other set, modifying
// this set *in place*.
func (bss BaseSubjectSet[T]) IntersectionDifference(other BaseSubjectSet[T]) error {
	// Intersect the wildcards of the sets, if any.
	existingWildcard := bss.wildcard.getOrNil()
	otherWildcard := other.wildcard.getOrNil()

	intersection, err := intersectWildcardWithWildcard(existingWildcard, otherWildcard, bss.constructor)
	if err != nil {
		return err
	}

	bss.wildcard.setOrNil(intersection)

	// Intersect the concretes of each set, as well as with the wildcards.
	// Results are staged in a new map so lookups below see the pre-intersection state.
	updatedConcretes := make(map[string]T, len(bss.concrete))

	// Each of our concretes survives if it is found in the other set's concretes
	// or is covered by the other set's wildcard.
	for _, concreteSubject := range bss.concrete {
		var otherConcreteOrNil *T
		if otherConcrete, ok := other.concrete[concreteSubject.GetSubjectId()]; ok {
			otherConcreteOrNil = &otherConcrete
		}

		concreteIntersected := intersectConcreteWithConcrete(concreteSubject, otherConcreteOrNil, bss.constructor)
		otherWildcardIntersected, err := intersectConcreteWithWildcard(concreteSubject, otherWildcard, bss.constructor)
		if err != nil {
			return err
		}

		result := unionConcreteWithConcrete(concreteIntersected, otherWildcardIntersected, bss.constructor)
		if result != nil {
			updatedConcretes[concreteSubject.GetSubjectId()] = *result
		}
	}

	// Symmetrically, the other set's concretes may be covered by our wildcard.
	if existingWildcard != nil {
		for _, otherSubject := range other.concrete {
			existingWildcardIntersect, err := intersectConcreteWithWildcard(otherSubject, existingWildcard, bss.constructor)
			if err != nil {
				return err
			}

			if existingUpdated, ok := updatedConcretes[otherSubject.GetSubjectId()]; ok {
				result := unionConcreteWithConcrete(&existingUpdated, existingWildcardIntersect, bss.constructor)
				updatedConcretes[otherSubject.GetSubjectId()] = *result
			} else if existingWildcardIntersect != nil {
				updatedConcretes[otherSubject.GetSubjectId()] = *existingWildcardIntersect
			}
		}
	}

	// Swap the staged results into place.
	clear(bss.concrete)
	maps.Copy(bss.concrete, updatedConcretes)
	return nil
}
+
+// UnionWith adds the given subjects to this set, via a union call.
+func (bss BaseSubjectSet[T]) UnionWith(foundSubjects []T) error {
+ for _, fs := range foundSubjects {
+ err := bss.Add(fs)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// UnionWithSet performs a union operation between this set and the other set, modifying this
+// set *in place*.
+func (bss BaseSubjectSet[T]) UnionWithSet(other BaseSubjectSet[T]) error {
+ return bss.UnionWith(other.AsSlice())
+}
+
+// MustUnionWithSet performs a union operation between this set and the other set, modifying this
+// set *in place*.
+func (bss BaseSubjectSet[T]) MustUnionWithSet(other BaseSubjectSet[T]) {
+ err := bss.UnionWithSet(other)
+ if err != nil {
+ panic(err)
+ }
+}
+
+// Get returns the found subject with the given ID in the set, if any.
+func (bss BaseSubjectSet[T]) Get(id string) (T, bool) {
+ if id == tuple.PublicWildcard {
+ return bss.wildcard.get()
+ }
+
+ found, ok := bss.concrete[id]
+ return found, ok
+}
+
+// IsEmpty returns whether the subject set is empty.
+func (bss BaseSubjectSet[T]) IsEmpty() bool {
+ return bss.wildcard.getOrNil() == nil && len(bss.concrete) == 0
+}
+
+// AsSlice returns the contents of the subject set as a slice of found subjects.
+func (bss BaseSubjectSet[T]) AsSlice() []T {
+ values := maps.Values(bss.concrete)
+ if wildcard, ok := bss.wildcard.get(); ok {
+ values = append(values, wildcard)
+ }
+ return values
+}
+
// Clone returns a clone of this subject set. Note that this is a shallow clone:
// the concrete map and the wildcard handle are copied, but the subjects themselves
// are shared with the original.
// NOTE: Should only be used when performance is not a concern.
func (bss BaseSubjectSet[T]) Clone() BaseSubjectSet[T] {
	return BaseSubjectSet[T]{
		constructor: bss.constructor,
		concrete:    maps.Clone(bss.concrete),
		wildcard:    bss.wildcard.clone(),
	}
}
+
+// UnsafeRemoveExact removes the *exact* matching subject, with no wildcard handling.
+// This should ONLY be used for testing.
+func (bss BaseSubjectSet[T]) UnsafeRemoveExact(foundSubject T) {
+ if foundSubject.GetSubjectId() == tuple.PublicWildcard {
+ bss.wildcard.clear()
+ return
+ }
+
+ delete(bss.concrete, foundSubject.GetSubjectId())
+}
+
// WithParentCaveatExpression returns a copy of the subject set with the parent caveat expression applied
// to all members of this set (the parent caveat is AND-ed onto each member's existing caveat).
func (bss BaseSubjectSet[T]) WithParentCaveatExpression(parentCaveatExpr *core.CaveatExpression) BaseSubjectSet[T] {
	clone := bss.Clone()

	// Apply the parent caveat expression to the wildcard, if any.
	// The wildcard's exclusions are carried over unchanged.
	if wildcard, ok := clone.wildcard.get(); ok {
		constructed := bss.constructor(
			tuple.PublicWildcard,
			caveatAnd(parentCaveatExpr, wildcard.GetCaveatExpression()),
			wildcard.GetExcludedSubjects(),
			wildcard,
		)
		clone.wildcard.setOrNil(&constructed)
	}

	// Apply the parent caveat expression to each concrete.
	// NOTE: only existing keys are overwritten while ranging; no keys are added
	// or removed, so mutating the map during iteration is safe here.
	for subjectID, concrete := range clone.concrete {
		clone.concrete[subjectID] = bss.constructor(
			subjectID,
			caveatAnd(parentCaveatExpr, concrete.GetCaveatExpression()),
			nil,
			concrete,
		)
	}

	return clone
}
+
+// unionWildcardWithWildcard performs a union operation over two wildcards, returning the updated
+// wildcard (if any).
+func unionWildcardWithWildcard[T Subject[T]](existing *T, adding T, constructor constructor[T]) (*T, error) {
+ // If there is no existing wildcard, return the added one.
+ if existing == nil {
+ return &adding, nil
+ }
+
+ // Otherwise, union together the conditionals for the wildcards and *intersect* their exclusion
+ // sets.
+ existingWildcard := *existing
+ expression := shortcircuitedOr(existingWildcard.GetCaveatExpression(), adding.GetCaveatExpression())
+
+ // Exclusion sets are intersected because if an exclusion is missing from one wildcard
+ // but not the other, the missing element will be, by definition, in that other wildcard.
+ //
+ // Examples:
+ //
+ // {*} + {*} => {*}
+ // {* - {user:tom}} + {*} => {*}
+ // {* - {user:tom}} + {* - {user:sarah}} => {*}
+ // {* - {user:tom, user:sarah}} + {* - {user:sarah}} => {* - {user:sarah}}
+ // {*}[c1] + {*} => {*}
+ // {*}[c1] + {*}[c2] => {*}[c1 || c2]
+
+ // NOTE: since we're only using concretes here, it is safe to reuse the BaseSubjectSet itself.
+ exisingConcreteExclusions := NewBaseSubjectSet(constructor)
+ for _, excludedSubject := range existingWildcard.GetExcludedSubjects() {
+ if excludedSubject.GetSubjectId() == tuple.PublicWildcard {
+ return nil, spiceerrors.MustBugf("wildcards are not allowed in exclusions")
+ }
+
+ err := exisingConcreteExclusions.Add(excludedSubject)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ foundConcreteExclusions := NewBaseSubjectSet(constructor)
+ for _, excludedSubject := range adding.GetExcludedSubjects() {
+ if excludedSubject.GetSubjectId() == tuple.PublicWildcard {
+ return nil, spiceerrors.MustBugf("wildcards are not allowed in exclusions")
+ }
+
+ err := foundConcreteExclusions.Add(excludedSubject)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ err := exisingConcreteExclusions.IntersectionDifference(foundConcreteExclusions)
+ if err != nil {
+ return nil, err
+ }
+
+ constructed := constructor(
+ tuple.PublicWildcard,
+ expression,
+ exisingConcreteExclusions.AsSlice(),
+ *existing,
+ adding)
+ return &constructed, nil
+}
+
// unionWildcardWithConcrete performs a union operation between a wildcard and a concrete subject
// being added to the set, returning the updated wildcard (if applicable). The concrete itself
// is tracked separately by the caller; only the wildcard's exclusions are adjusted here.
func unionWildcardWithConcrete[T Subject[T]](existing *T, adding T, constructor constructor[T]) *T {
	// If there is no existing wildcard, nothing more to do.
	if existing == nil {
		return nil
	}

	// If the concrete is in the exclusion set, remove it if not conditional. Otherwise, mark
	// it as conditional.
	//
	// Examples:
	//	{*} | {user:tom} => {*} (and user:tom in the concrete)
	//	{* - {user:tom}} | {user:tom} => {*} (and user:tom in the concrete)
	//	{* - {user:tom}[c1]} | {user:tom}[c2] => {* - {user:tom}[c1 && !c2]} (and user:tom in the concrete)
	existingWildcard := *existing
	updatedExclusions := make([]T, 0, len(existingWildcard.GetExcludedSubjects()))
	for _, existingExclusion := range existingWildcard.GetExcludedSubjects() {
		if existingExclusion.GetSubjectId() == adding.GetSubjectId() {
			// If the conditional on the concrete is empty, then the concrete is always present, so
			// we remove the exclusion entirely.
			if adding.GetCaveatExpression() == nil {
				continue
			}

			// Otherwise, the conditional expression for the new exclusion is the existing expression &&
			// the *inversion* of the concrete's expression, as the exclusion will only apply if the
			// concrete subject is not present and the exclusion's expression is true.
			exclusionConditionalExpression := caveatAnd(
				existingExclusion.GetCaveatExpression(),
				caveatInvert(adding.GetCaveatExpression()),
			)

			updatedExclusions = append(updatedExclusions, constructor(
				adding.GetSubjectId(),
				exclusionConditionalExpression,
				nil,
				existingExclusion,
				adding),
			)
		} else {
			// Unrelated exclusions are carried over unchanged.
			updatedExclusions = append(updatedExclusions, existingExclusion)
		}
	}

	constructed := constructor(
		tuple.PublicWildcard,
		existingWildcard.GetCaveatExpression(),
		updatedExclusions,
		existingWildcard)
	return &constructed
}
+
+// unionConcreteWithConcrete performs a union operation between two concrete subjects and returns
+// the concrete subject produced, if any.
+func unionConcreteWithConcrete[T Subject[T]](existing *T, adding *T, constructor constructor[T]) *T {
+ // Check for union with other concretes.
+ if existing == nil {
+ return adding
+ }
+
+ if adding == nil {
+ return existing
+ }
+
+ existingConcrete := *existing
+ addingConcrete := *adding
+
+ // A union of a concrete subjects has the conditionals of each concrete merged.
+ constructed := constructor(
+ existingConcrete.GetSubjectId(),
+ shortcircuitedOr(
+ existingConcrete.GetCaveatExpression(),
+ addingConcrete.GetCaveatExpression(),
+ ),
+ nil,
+ existingConcrete, addingConcrete)
+ return &constructed
+}
+
// subtractWildcardFromWildcard performs a subtraction operation of wildcard from another, returning
// the updated wildcard (if any), as well as any concrete subjects produced by the subtraction
// operation due to exclusions.
func subtractWildcardFromWildcard[T Subject[T]](existing *T, toRemove T, constructor constructor[T]) (*T, []T) {
	// If there is no existing wildcard, nothing more to do.
	if existing == nil {
		return nil, nil
	}

	// If there is no condition on the wildcard and the new wildcard has no exclusions, then this wildcard goes away.
	// Example: {*} - {*} => {}
	if toRemove.GetCaveatExpression() == nil && len(toRemove.GetExcludedSubjects()) == 0 {
		return nil, nil
	}

	// Otherwise, we construct a new wildcard and return any concrete subjects that might result from this subtraction.
	existingWildcard := *existing
	existingExclusions := exclusionsMapFor(existingWildcard)

	// Calculate the exclusions which turn into concrete subjects.
	// This occurs when a wildcard with exclusions is subtracted from a wildcard
	// (with, or without *matching* exclusions).
	//
	// Example:
	// Given the two wildcards `* - {user:sarah}` and `* - {user:tom, user:amy, user:sarah}`,
	// the resulting concrete subjects are {user:tom, user:amy} because the first set contains
	// `tom` and `amy` (but not `sarah`) and the second set contains all three.
	resultingConcreteSubjects := make([]T, 0, len(toRemove.GetExcludedSubjects()))
	for _, excludedSubject := range toRemove.GetExcludedSubjects() {
		// A subtracted exclusion materializes as a concrete unless the existing wildcard
		// *unconditionally* excludes the same subject (in which case it was never present).
		if existingExclusion, isExistingExclusion := existingExclusions[excludedSubject.GetSubjectId()]; !isExistingExclusion || existingExclusion.GetCaveatExpression() != nil {
			// The conditional expression for the now-concrete subject type is the conditional on the provided exclusion
			// itself.
			//
			// As an example, subtracting the wildcards
			//	{*[caveat1] - {user:tom}}
			//	-
			//	{*[caveat3] - {user:sarah[caveat4]}}
			//
			// the resulting expression to produce a *concrete* `user:sarah` is
			// `caveat1 && caveat3 && caveat4`, because the concrete subject only appears if the first
			// wildcard applies, the *second* wildcard applies and its exclusion applies.
			exclusionConditionalExpression := caveatAnd(
				caveatAnd(
					existingWildcard.GetCaveatExpression(),
					toRemove.GetCaveatExpression(),
				),
				excludedSubject.GetCaveatExpression(),
			)

			// If there is an existing exclusion, then its caveat expression is added as well, but inverted.
			//
			// As an example, subtracting the wildcards
			//	{*[caveat1] - {user:tom[caveat2]}}
			//	-
			//	{*[caveat3] - {user:sarah[caveat4]}}
			//
			// the resulting expression to produce a *concrete* `user:sarah` is
			// `caveat1 && !caveat2 && caveat3 && caveat4`, because the concrete subject only appears
			// if the first wildcard applies, the *second* wildcard applies, the first exclusion
			// does *not* apply (ensuring the concrete is in the first wildcard) and the second exclusion
			// *does* apply (ensuring it is not in the second wildcard).
			if existingExclusion.GetCaveatExpression() != nil {
				exclusionConditionalExpression = caveatAnd(
					caveatAnd(
						caveatAnd(
							existingWildcard.GetCaveatExpression(),
							toRemove.GetCaveatExpression(),
						),
						caveatInvert(existingExclusion.GetCaveatExpression()),
					),
					excludedSubject.GetCaveatExpression(),
				)
			}

			resultingConcreteSubjects = append(resultingConcreteSubjects, constructor(
				excludedSubject.GetSubjectId(),
				exclusionConditionalExpression,
				nil, excludedSubject))
		}
	}

	// Create the combined conditional: the wildcard can only exist when it is present and the other wildcard is not.
	combinedConditionalExpression := caveatAnd(existingWildcard.GetCaveatExpression(), caveatInvert(toRemove.GetCaveatExpression()))
	if combinedConditionalExpression != nil {
		constructed := constructor(
			tuple.PublicWildcard,
			combinedConditionalExpression,
			existingWildcard.GetExcludedSubjects(),
			existingWildcard,
			toRemove)
		return &constructed, resultingConcreteSubjects
	}

	// The subtracted wildcard was unconditional: the existing wildcard is fully removed.
	return nil, resultingConcreteSubjects
}
+
// subtractWildcardFromConcrete subtracts a wildcard from a concrete element, returning the updated
// concrete subject, if any.
func subtractWildcardFromConcrete[T Subject[T]](existingConcrete T, wildcardToRemove T, constructor constructor[T]) *T {
	// Subtraction of a wildcard removes *all* elements of the concrete set, except those that
	// are found in the excluded list. If the wildcard *itself* is conditional, then instead of
	// items being removed, they are made conditional on the inversion of the wildcard's expression,
	// and the exclusion's conditional, if any.
	//
	// Examples:
	//	{user:sarah, user:tom} - {*} => {}
	//	{user:sarah, user:tom} - {*[somecaveat]} => {user:sarah[!somecaveat], user:tom[!somecaveat]}
	//	{user:sarah, user:tom} - {* - {user:tom}} => {user:tom}
	//	{user:sarah, user:tom} - {*[somecaveat] - {user:tom}} => {user:sarah[!somecaveat], user:tom}
	//	{user:sarah, user:tom} - {* - {user:tom[c2]}}[somecaveat] => {user:sarah[!somecaveat], user:tom[c2]}
	//	{user:sarah[c1], user:tom} - {*[somecaveat] - {user:tom}} => {user:sarah[c1 && !somecaveat], user:tom}
	exclusions := exclusionsMapFor(wildcardToRemove)
	exclusion, isExcluded := exclusions[existingConcrete.GetSubjectId()]
	if !isExcluded {
		// If the subject was not excluded within the wildcard, it is either removed directly
		// (in the case where the wildcard is not conditional), or has its condition updated to
		// reflect that it is only present when the condition for the wildcard is *false*.
		if wildcardToRemove.GetCaveatExpression() == nil {
			return nil
		}

		constructed := constructor(
			existingConcrete.GetSubjectId(),
			caveatAnd(existingConcrete.GetCaveatExpression(), caveatInvert(wildcardToRemove.GetCaveatExpression())),
			nil,
			existingConcrete)
		return &constructed
	}

	// If the exclusion is not conditional, then the subject is always present.
	if exclusion.GetCaveatExpression() == nil {
		return &existingConcrete
	}

	// The conditional of the exclusion is that of the exclusion itself OR the caveatInverted case of
	// the wildcard, which would mean the wildcard itself does not apply.
	exclusionConditional := caveatOr(caveatInvert(wildcardToRemove.GetCaveatExpression()), exclusion.GetCaveatExpression())

	constructed := constructor(
		existingConcrete.GetSubjectId(),
		caveatAnd(existingConcrete.GetCaveatExpression(), exclusionConditional),
		nil,
		existingConcrete)
	return &constructed
}
+
+// subtractConcreteFromConcrete subtracts a concrete subject from another concrete subject.
+func subtractConcreteFromConcrete[T Subject[T]](existingConcrete T, toRemove T, constructor constructor[T]) *T {
+ // Subtraction of a concrete type removes the entry from the concrete list
+ // *unless* the subtraction is conditional, in which case the conditional is updated
+ // to remove the element when it is true.
+ //
+ // Examples:
+ // {user:sarah} - {user:tom} => {user:sarah}
+ // {user:tom} - {user:tom} => {}
+ // {user:tom[c1]} - {user:tom} => {user:tom}
+ // {user:tom} - {user:tom[c2]} => {user:tom[!c2]}
+ // {user:tom[c1]} - {user:tom[c2]} => {user:tom[c1 && !c2]}
+ if toRemove.GetCaveatExpression() == nil {
+ return nil
+ }
+
+ // Otherwise, adjust the conditional of the existing item to remove it if it is true.
+ expression := caveatAnd(
+ existingConcrete.GetCaveatExpression(),
+ caveatInvert(
+ toRemove.GetCaveatExpression(),
+ ),
+ )
+
+ constructed := constructor(
+ existingConcrete.GetSubjectId(),
+ expression,
+ nil,
+ existingConcrete, toRemove)
+ return &constructed
+}
+
// subtractConcreteFromWildcard subtracts a concrete element from a wildcard, returning the
// updated wildcard.
func subtractConcreteFromWildcard[T Subject[T]](wildcard T, concreteToRemove T, constructor constructor[T]) *T {
	// Subtracting a concrete type from a wildcard adds the concrete to the exclusions for the wildcard.
	// Examples:
	//	{*} - {user:tom} => {* - {user:tom}}
	//	{*} - {user:tom[c1]} => {* - {user:tom[c1]}}
	//	{* - {user:tom[c1]}} - {user:tom} => {* - {user:tom}}
	//	{* - {user:tom[c1]}} - {user:tom[c2]} => {* - {user:tom[c1 || c2]}}
	updatedExclusions := make([]T, 0, len(wildcard.GetExcludedSubjects())+1)
	wasFound := false
	for _, existingExclusion := range wildcard.GetExcludedSubjects() {
		if existingExclusion.GetSubjectId() == concreteToRemove.GetSubjectId() {
			// Merge with the already-present exclusion for this subject.
			// The conditional expression for the exclusion is a combination on the existing exclusion or
			// the new expression. The caveat is short-circuited here because if either the exclusion or
			// the concrete is non-caveated, then the whole exclusion is non-caveated.
			exclusionConditionalExpression := shortcircuitedOr(
				existingExclusion.GetCaveatExpression(),
				concreteToRemove.GetCaveatExpression(),
			)

			updatedExclusions = append(updatedExclusions, constructor(
				concreteToRemove.GetSubjectId(),
				exclusionConditionalExpression,
				nil,
				existingExclusion,
				concreteToRemove),
			)
			wasFound = true
		} else {
			// Unrelated exclusions are carried over unchanged.
			updatedExclusions = append(updatedExclusions, existingExclusion)
		}
	}

	// If no existing exclusion matched, the concrete becomes a new exclusion.
	if !wasFound {
		updatedExclusions = append(updatedExclusions, concreteToRemove)
	}

	constructed := constructor(
		tuple.PublicWildcard,
		wildcard.GetCaveatExpression(),
		updatedExclusions,
		wildcard)
	return &constructed
}
+
+// intersectConcreteWithConcrete performs intersection between two concrete subjects, returning the
+// resolved concrete subject, if any.
+func intersectConcreteWithConcrete[T Subject[T]](first T, second *T, constructor constructor[T]) *T {
+ // Intersection of concrete subjects is a standard intersection operation, where subjects
+ // must be in both sets, with a combination of the two elements into one for conditionals.
+ // Otherwise, `and` together conditionals.
+ if second == nil {
+ return nil
+ }
+
+ secondConcrete := *second
+ constructed := constructor(
+ first.GetSubjectId(),
+ caveatAnd(first.GetCaveatExpression(), secondConcrete.GetCaveatExpression()),
+ nil,
+ first,
+ secondConcrete)
+
+ return &constructed
+}
+
+// intersectWildcardWithWildcard performs intersection between two wildcards, returning the resolved
+// wildcard subject, if any.
+func intersectWildcardWithWildcard[T Subject[T]](first *T, second *T, constructor constructor[T]) (*T, error) {
+ // If either wildcard does not exist, then no wildcard is placed into the resulting set.
+ if first == nil || second == nil {
+ return nil, nil
+ }
+
+ // If the other wildcard exists, then the intersection between the two wildcards is an && of
+ // their conditionals, and a *union* of their exclusions.
+ firstWildcard := *first
+ secondWildcard := *second
+
+ concreteExclusions := NewBaseSubjectSet(constructor)
+ for _, excludedSubject := range firstWildcard.GetExcludedSubjects() {
+ if excludedSubject.GetSubjectId() == tuple.PublicWildcard {
+ return nil, spiceerrors.MustBugf("wildcards are not allowed in exclusions")
+ }
+
+ err := concreteExclusions.Add(excludedSubject)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ for _, excludedSubject := range secondWildcard.GetExcludedSubjects() {
+ if excludedSubject.GetSubjectId() == tuple.PublicWildcard {
+ return nil, spiceerrors.MustBugf("wildcards are not allowed in exclusions")
+ }
+
+ err := concreteExclusions.Add(excludedSubject)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ constructed := constructor(
+ tuple.PublicWildcard,
+ caveatAnd(firstWildcard.GetCaveatExpression(), secondWildcard.GetCaveatExpression()),
+ concreteExclusions.AsSlice(),
+ firstWildcard,
+ secondWildcard)
+ return &constructed, nil
+}
+
+// intersectConcreteWithWildcard performs intersection between a concrete subject and a wildcard
+// subject, returning the concrete, if any.
+//
+// The wildcard pointer may be nil (no wildcard on that side of the intersection); the concrete
+// subject is always present. The constructor builds any new subject produced.
+func intersectConcreteWithWildcard[T Subject[T]](concrete T, wildcard *T, constructor constructor[T]) (*T, error) {
+	// If no wildcard exists, then the concrete cannot exist (for this branch)
+	if wildcard == nil {
+		return nil, nil
+	}
+
+	wildcardToIntersect := *wildcard
+	exclusionsMap := exclusionsMapFor(wildcardToIntersect)
+	exclusion, isExcluded := exclusionsMap[concrete.GetSubjectId()]
+
+	// Cases:
+	// - The concrete subject is not excluded and the wildcard is not conditional => concrete is kept
+	// - The concrete subject is excluded and the wildcard is not conditional but the exclusion *is* conditional => concrete is made conditional
+	// - The concrete subject is excluded and the wildcard is not conditional => concrete is removed
+	// - The concrete subject is not excluded but the wildcard is conditional => concrete is kept, but made conditional
+	// - The concrete subject is excluded and the wildcard is conditional => concrete is removed, since it is always excluded
+	// - The concrete subject is excluded and the wildcard is conditional and the exclusion is conditional => combined conditional
+	switch {
+	case !isExcluded && wildcardToIntersect.GetCaveatExpression() == nil:
+		// If the concrete is not excluded and the wildcard conditional is empty, then the concrete is always found.
+		// Example: {user:tom} & {*} => {user:tom}
+		return &concrete, nil
+
+	case !isExcluded && wildcardToIntersect.GetCaveatExpression() != nil:
+		// The concrete subject is only included if the wildcard's caveat is true.
+		// Example: {user:tom}[acaveat] & {* - user:tom}[somecaveat] => {user:tom}[acaveat && somecaveat]
+		constructed := constructor(
+			concrete.GetSubjectId(),
+			caveatAnd(concrete.GetCaveatExpression(), wildcardToIntersect.GetCaveatExpression()),
+			nil,
+			concrete,
+			wildcardToIntersect)
+		return &constructed, nil
+
+	case isExcluded && exclusion.GetCaveatExpression() == nil:
+		// If the concrete is excluded and the exclusion is not conditional, then the concrete can never show up,
+		// regardless of whether the wildcard is conditional.
+		// Example: {user:tom} & {* - user:tom}[somecaveat] => {}
+		return nil, nil
+
+	case isExcluded && exclusion.GetCaveatExpression() != nil:
+		// NOTE: whether the wildcard is itself conditional or not is handled within the expression combinators below.
+		// The concrete subject is included if the wildcard's caveat is true and the exclusion's caveat is *false*.
+		// Example: {user:tom}[acaveat] & {* - user:tom[ecaveat]}[wcaveat] => {user:tom[acaveat && wcaveat && !ecaveat]}
+		constructed := constructor(
+			concrete.GetSubjectId(),
+			caveatAnd(
+				concrete.GetCaveatExpression(),
+				caveatAnd(
+					wildcardToIntersect.GetCaveatExpression(),
+					caveatInvert(exclusion.GetCaveatExpression()),
+				)),
+			nil,
+			concrete,
+			wildcardToIntersect,
+			exclusion)
+		return &constructed, nil
+
+	default:
+		// All exclusion/caveat combinations are enumerated above; reaching here indicates a
+		// programming error.
+		return nil, spiceerrors.MustBugf("unhandled case in basesubjectset intersectConcreteWithWildcard: %v & %v", concrete, wildcardToIntersect)
+	}
+}
+
+// handle is a small nilable box around a value of type T.
+type handle[T any] struct {
+	value *T
+}
+
+// newHandle returns an empty handle.
+func newHandle[T any]() *handle[T] {
+	return &handle[T]{}
+}
+
+// getOrNil returns the stored pointer, which is nil when the handle is empty.
+func (h *handle[T]) getOrNil() *T {
+	return h.value
+}
+
+// setOrNil stores the given pointer; a nil pointer empties the handle.
+func (h *handle[T]) setOrNil(value *T) {
+	h.value = value
+}
+
+// get returns the stored value and true, or the zero value of T and false when empty.
+func (h *handle[T]) get() (T, bool) {
+	if h.value == nil {
+		return *new(T), false
+	}
+
+	return *h.value, true
+}
+
+// clear empties the handle.
+func (h *handle[T]) clear() {
+	h.value = nil
+}
+
+// clone returns a shallow copy of the handle (the pointed-to value is shared).
+func (h *handle[T]) clone() *handle[T] {
+	return &handle[T]{value: h.value}
+}
+
+// exclusionsMapFor creates a map of all the exclusions on a wildcard, by subject ID.
+func exclusionsMapFor[T Subject[T]](wildcard T) map[string]T {
+	excludedSubjects := wildcard.GetExcludedSubjects()
+	byID := make(map[string]T, len(excludedSubjects))
+	for _, excluded := range excludedSubjects {
+		byID[excluded.GetSubjectId()] = excluded
+	}
+	return byID
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/doc.go b/vendor/github.com/authzed/spicedb/internal/datasets/doc.go
new file mode 100644
index 0000000..6ff324c
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datasets/doc.go
@@ -0,0 +1,2 @@
+// Package datasets defines operations with sets of subjects.
+package datasets
diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/subjectset.go b/vendor/github.com/authzed/spicedb/internal/datasets/subjectset.go
new file mode 100644
index 0000000..551bfaa
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datasets/subjectset.go
@@ -0,0 +1,65 @@
+package datasets
+
+import (
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+)
+
+// SubjectSet defines a set that tracks accessible subjects.
+//
+// NOTE: Unlike a traditional set, unions between wildcards and a concrete subject will result
+// in *both* being present in the set, to maintain the proper set semantics around wildcards.
+//
+// Construct via NewSubjectSet; the zero value lacks the initialized base set.
+type SubjectSet struct {
+	BaseSubjectSet[*v1.FoundSubject]
+}
+
+// NewSubjectSet creates and returns a new subject set backed by a BaseSubjectSet over
+// v1.FoundSubject values.
+func NewSubjectSet() SubjectSet {
+	return SubjectSet{
+		BaseSubjectSet: NewBaseSubjectSet(subjectSetConstructor),
+	}
+}
+
+// SubtractAll subtracts all subjects found in the other set from this set, in-place.
+func (ss SubjectSet) SubtractAll(other SubjectSet) {
+	ss.BaseSubjectSet.SubtractAll(other.BaseSubjectSet)
+}
+
+// MustIntersectionDifference performs an in-place intersection with the other set, panicking
+// instead of returning an error (Must variant of IntersectionDifference).
+func (ss SubjectSet) MustIntersectionDifference(other SubjectSet) {
+	ss.BaseSubjectSet.MustIntersectionDifference(other.BaseSubjectSet)
+}
+
+// IntersectionDifference performs an in-place intersection with the other set.
+func (ss SubjectSet) IntersectionDifference(other SubjectSet) error {
+	return ss.BaseSubjectSet.IntersectionDifference(other.BaseSubjectSet)
+}
+
+// MustUnionWithSet unions the other set into this set, panicking instead of returning an
+// error (Must variant of UnionWithSet).
+func (ss SubjectSet) MustUnionWithSet(other SubjectSet) {
+	ss.BaseSubjectSet.MustUnionWithSet(other.BaseSubjectSet)
+}
+
+// Clone returns a copy of this subject set.
+func (ss SubjectSet) Clone() SubjectSet {
+	return SubjectSet{ss.BaseSubjectSet.Clone()}
+}
+
+// UnionWithSet unions the other set into this set.
+func (ss SubjectSet) UnionWithSet(other SubjectSet) error {
+	return ss.BaseSubjectSet.UnionWithSet(other.BaseSubjectSet)
+}
+
+// WithParentCaveatExpression returns a copy of the subject set with the parent caveat expression applied
+// to all members of this set.
+func (ss SubjectSet) WithParentCaveatExpression(parentCaveatExpr *core.CaveatExpression) SubjectSet {
+	return SubjectSet{ss.BaseSubjectSet.WithParentCaveatExpression(parentCaveatExpr)}
+}
+
+// AsFoundSubjects converts the set into the FoundSubjects proto for use in dispatch responses.
+func (ss SubjectSet) AsFoundSubjects() *v1.FoundSubjects {
+	return &v1.FoundSubjects{
+		FoundSubjects: ss.AsSlice(),
+	}
+}
+
+// subjectSetConstructor builds a FoundSubject for use by the base subject set. The trailing
+// variadic "source" subjects accepted by the constructor signature are intentionally ignored.
+func subjectSetConstructor(subjectID string, caveatExpression *core.CaveatExpression, excludedSubjects []*v1.FoundSubject, _ ...*v1.FoundSubject) *v1.FoundSubject {
+	return &v1.FoundSubject{
+		SubjectId:        subjectID,
+		CaveatExpression: caveatExpression,
+		ExcludedSubjects: excludedSubjects,
+	}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbyresourceid.go b/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbyresourceid.go
new file mode 100644
index 0000000..5b1ba13
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbyresourceid.go
@@ -0,0 +1,117 @@
+package datasets
+
+import (
+ "fmt"
+
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// NewSubjectSetByResourceID creates and returns a map of subject sets, indexed by resource ID.
+func NewSubjectSetByResourceID() SubjectSetByResourceID {
+	return SubjectSetByResourceID{
+		subjectSetByResourceID: map[string]SubjectSet{},
+	}
+}
+
+// SubjectSetByResourceID defines a helper type which maps from a resource ID to its associated found
+// subjects, in the form of a subject set per resource ID.
+//
+// Construct via NewSubjectSetByResourceID; the zero value holds a nil map and cannot be written to.
+type SubjectSetByResourceID struct {
+	subjectSetByResourceID map[string]SubjectSet
+}
+
+// add inserts the given subject into the set for the given resource ID, creating the set
+// lazily on first use. Returns an error for a nil subject or if the underlying set add fails.
+func (ssr SubjectSetByResourceID) add(resourceID string, subject *v1.FoundSubject) error {
+	if subject == nil {
+		return fmt.Errorf("cannot add a nil subject to SubjectSetByResourceID")
+	}
+
+	// Use a single map access for the lookup-or-create, rather than checking for the key
+	// and then re-indexing the map twice more.
+	subjectSet, ok := ssr.subjectSetByResourceID[resourceID]
+	if !ok {
+		subjectSet = NewSubjectSet()
+		ssr.subjectSetByResourceID[resourceID] = subjectSet
+	}
+	return subjectSet.Add(subject)
+}
+
+// AddFromRelationship adds the subject found in the given relationship to this map, indexed at
+// the resource ID specified in the relationship. The relationship's optional caveat, if present,
+// is carried onto the subject as a caveat expression.
+func (ssr SubjectSetByResourceID) AddFromRelationship(relationship tuple.Relationship) error {
+	return ssr.add(relationship.Resource.ObjectID, &v1.FoundSubject{
+		SubjectId:        relationship.Subject.ObjectID,
+		CaveatExpression: wrapCaveat(relationship.OptionalCaveat),
+	})
+}
+
+// UnionWith unions the map's sets with the other map of sets provided.
+func (ssr SubjectSetByResourceID) UnionWith(other map[string]*v1.FoundSubjects) error {
+	for resourceID, foundSubjects := range other {
+		if foundSubjects == nil {
+			return fmt.Errorf("received nil FoundSubjects in other map of SubjectSetByResourceID's UnionWith for key %s", resourceID)
+		}
+
+		for _, foundSubject := range foundSubjects.FoundSubjects {
+			err := ssr.add(resourceID, foundSubject)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// IntersectionDifference performs an in-place intersection between the two maps' sets.
+func (ssr SubjectSetByResourceID) IntersectionDifference(other SubjectSetByResourceID) error {
+	// Intersect each set present on both sides, dropping entries that become empty.
+	for resourceID, otherSet := range other.subjectSetByResourceID {
+		ourSet, shared := ssr.subjectSetByResourceID[resourceID]
+		if !shared {
+			continue
+		}
+
+		if err := ourSet.IntersectionDifference(otherSet); err != nil {
+			return err
+		}
+
+		if ourSet.IsEmpty() {
+			delete(ssr.subjectSetByResourceID, resourceID)
+		}
+	}
+
+	// Drop any of our entries that have no counterpart on the other side.
+	for resourceID := range ssr.subjectSetByResourceID {
+		if _, shared := other.subjectSetByResourceID[resourceID]; !shared {
+			delete(ssr.subjectSetByResourceID, resourceID)
+		}
+	}
+
+	return nil
+}
+
+// SubtractAll subtracts all sets in the other map from this map's sets.
+func (ssr SubjectSetByResourceID) SubtractAll(other SubjectSetByResourceID) {
+	for resourceID, toSubtract := range other.subjectSetByResourceID {
+		current, found := ssr.subjectSetByResourceID[resourceID]
+		if !found {
+			continue
+		}
+
+		current.SubtractAll(toSubtract)
+
+		// Remove entries emptied by the subtraction.
+		if current.IsEmpty() {
+			delete(ssr.subjectSetByResourceID, resourceID)
+		}
+	}
+}
+
+// IsEmpty returns true if the map is empty.
+func (ssr SubjectSetByResourceID) IsEmpty() bool {
+	return len(ssr.subjectSetByResourceID) == 0
+}
+
+// AsMap converts the map into a map for storage in a proto.
+func (ssr SubjectSetByResourceID) AsMap() map[string]*v1.FoundSubjects {
+	result := make(map[string]*v1.FoundSubjects, len(ssr.subjectSetByResourceID))
+	for resourceID, set := range ssr.subjectSetByResourceID {
+		result[resourceID] = set.AsFoundSubjects()
+	}
+	return result
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbytype.go b/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbytype.go
new file mode 100644
index 0000000..8882a2e
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datasets/subjectsetbytype.go
@@ -0,0 +1,113 @@
+package datasets
+
+import (
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// SubjectByTypeSet is a set of SubjectSet's, grouped by their subject types.
+//
+// Keys are the joined relation reference (tuple.JoinRelRef) of the subject's type and relation.
+type SubjectByTypeSet struct {
+	byType map[string]SubjectSet
+}
+
+// NewSubjectByTypeSet creates and returns a new SubjectByTypeSet.
+func NewSubjectByTypeSet() *SubjectByTypeSet {
+	return &SubjectByTypeSet{
+		byType: map[string]SubjectSet{},
+	}
+}
+
+// AddSubjectOf adds the subject found in the given relationship, along with its caveat.
+// Delegates to AddSubject.
+func (s *SubjectByTypeSet) AddSubjectOf(relationship tuple.Relationship) error {
+	return s.AddSubject(relationship.Subject, relationship.OptionalCaveat)
+}
+
+// AddConcreteSubject adds a non-caveated subject to the set. Delegates to AddSubject with a
+// nil caveat.
+func (s *SubjectByTypeSet) AddConcreteSubject(subject tuple.ObjectAndRelation) error {
+	return s.AddSubject(subject, nil)
+}
+
+// AddSubject adds the specified subject to the set, grouping it under the subject's
+// type#relation key and carrying the optional caveat onto the stored subject.
+func (s *SubjectByTypeSet) AddSubject(subject tuple.ObjectAndRelation, caveat *core.ContextualizedCaveat) error {
+	key := tuple.JoinRelRef(subject.ObjectType, subject.Relation)
+
+	// Use a single map access for the lookup-or-create, rather than checking for the key
+	// and then re-indexing the map twice more.
+	set, ok := s.byType[key]
+	if !ok {
+		set = NewSubjectSet()
+		s.byType[key] = set
+	}
+
+	return set.Add(&v1.FoundSubject{
+		SubjectId:        subject.ObjectID,
+		CaveatExpression: wrapCaveat(caveat),
+	})
+}
+
+// ForEachType invokes the handler for each type of ObjectAndRelation found in the set, along
+// with all IDs of objects of that type.
+func (s *SubjectByTypeSet) ForEachType(handler func(rr *core.RelationReference, subjects SubjectSet)) {
+	for key, subjectSet := range s.byType {
+		namespace, relation := tuple.MustSplitRelRef(key)
+		rr := &core.RelationReference{
+			Namespace: namespace,
+			Relation:  relation,
+		}
+		handler(rr, subjectSet)
+	}
+}
+
+// Map runs the mapper function over each type of object in the set, returning a new SubjectByTypeSet with
+// the object type replaced by that returned by the mapper function. Returning a nil reference from the
+// mapper drops that entry; multiple source types mapped onto the same target type are unioned.
+func (s *SubjectByTypeSet) Map(mapper func(rr *core.RelationReference) (*core.RelationReference, error)) (*SubjectByTypeSet, error) {
+	mapped := NewSubjectByTypeSet()
+	for originalKey, subjectSet := range s.byType {
+		namespace, relation := tuple.MustSplitRelRef(originalKey)
+		mappedType, err := mapper(&core.RelationReference{
+			Namespace: namespace,
+			Relation:  relation,
+		})
+		if err != nil {
+			return nil, err
+		}
+		if mappedType == nil {
+			continue
+		}
+
+		// Named distinctly from originalKey (the original code shadowed `key` here) to make
+		// clear this is the *target* bucket for the mapped type.
+		mappedKey := tuple.JoinRelRef(mappedType.Namespace, mappedType.Relation)
+		if existing, collision := mapped.byType[mappedKey]; collision {
+			combined := subjectSet.Clone()
+			if err := combined.UnionWithSet(existing); err != nil {
+				return nil, err
+			}
+			mapped.byType[mappedKey] = combined
+		} else {
+			mapped.byType[mappedKey] = subjectSet
+		}
+	}
+	return mapped, nil
+}
+
+// IsEmpty returns true if the set is empty.
+func (s *SubjectByTypeSet) IsEmpty() bool {
+	return len(s.byType) == 0
+}
+
+// Len returns the number of keys in the set.
+func (s *SubjectByTypeSet) Len() int {
+	return len(s.byType)
+}
+
+// SubjectSetForType returns the subject set associated with the given subject type, if any.
+// The boolean reports whether a set was found for the type.
+func (s *SubjectByTypeSet) SubjectSetForType(rr *core.RelationReference) (SubjectSet, bool) {
+	found, ok := s.byType[tuple.JoinRelRef(rr.Namespace, rr.Relation)]
+	return found, ok
+}
+
+// wrapCaveat wraps the given contextualized caveat into a caveat expression leaf.
+// A nil caveat yields a nil expression.
+func wrapCaveat(caveat *core.ContextualizedCaveat) *core.CaveatExpression {
+	if caveat == nil {
+		return nil
+	}
+
+	leaf := &core.CaveatExpression_Caveat{Caveat: caveat}
+	return &core.CaveatExpression{OperationOrCaveat: leaf}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go
new file mode 100644
index 0000000..291abb5
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/changes.go
@@ -0,0 +1,352 @@
+package common
+
+import (
+ "context"
+ "sort"
+
+ "golang.org/x/exp/maps"
+ "google.golang.org/protobuf/types/known/structpb"
+
+ "github.com/ccoveille/go-safecast"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+const (
+	// nsPrefix and caveatPrefix namespace the keys of changeRecord.definitionsChanged, so a
+	// namespace and a caveat sharing the same name cannot collide in that map.
+	nsPrefix     = "n$"
+	caveatPrefix = "c$"
+)
+
+// Changes represents a set of datastore mutations that are kept self-consistent
+// across one or more transaction revisions.
+type Changes[R datastore.Revision, K comparable] struct {
+	// records holds the accumulated changes per revision, keyed by keyFunc(revision).
+	records map[K]changeRecord[R]
+	// keyFunc maps a revision to its comparable map key.
+	keyFunc func(R) K
+	// content selects which kinds of changes (relationships and/or schema) are tracked.
+	content datastore.WatchContent
+	// maxByteSize caps the total tracked size in bytes; 0 disables the limit (see adjustByteSize).
+	maxByteSize uint64
+	// currentByteSize is the running total of tracked bytes.
+	currentByteSize int64
+}
+
+// changeRecord accumulates all changes observed for a single revision.
+type changeRecord[R datastore.Revision] struct {
+	rev R
+	// relTouches and relDeletes are keyed by the relationship's caveat-free string form.
+	relTouches map[string]tuple.Relationship
+	relDeletes map[string]tuple.Relationship
+	// definitionsChanged is keyed with nsPrefix/caveatPrefix to separate the two kinds.
+	definitionsChanged map[string]datastore.SchemaDefinition
+	namespacesDeleted  map[string]struct{}
+	caveatsDeleted     map[string]struct{}
+	// metadata is optional caller-supplied metadata for the revision.
+	metadata map[string]any
+}
+
+// NewChanges creates a new Changes object for change tracking and de-duplication.
+func NewChanges[R datastore.Revision, K comparable](keyFunc func(R) K, content datastore.WatchContent, maxByteSize uint64) *Changes[R, K] {
+	// currentByteSize starts at its zero value.
+	return &Changes[R, K]{
+		records:     make(map[K]changeRecord[R]),
+		keyFunc:     keyFunc,
+		content:     content,
+		maxByteSize: maxByteSize,
+	}
+}
+
+// IsEmpty reports whether the change set contains no tracked revisions.
+func (ch *Changes[R, K]) IsEmpty() bool {
+	return len(ch.records) == 0
+}
+
+// AddRelationshipChange adds a specific change to the complete list of tracked changes.
+//
+// TOUCH and DELETE at the same revision are de-duplicated: a TOUCH cancels a pending DELETE
+// of the same relationship, and a DELETE is ignored if the relationship was already TOUCHed.
+// Returns an error if the tracked byte size would exceed the configured maximum.
+func (ch *Changes[R, K]) AddRelationshipChange(
+	ctx context.Context,
+	rev R,
+	rel tuple.Relationship,
+	op tuple.UpdateOperation,
+) error {
+	// Relationship changes are only tracked when relationships are being watched.
+	if ch.content&datastore.WatchRelationships != datastore.WatchRelationships {
+		return nil
+	}
+
+	record, err := ch.recordForRevision(rev)
+	if err != nil {
+		return err
+	}
+
+	// Key without caveat/expiration so a TOUCH and DELETE of the same relationship collide.
+	key := tuple.StringWithoutCaveatOrExpiration(rel)
+
+	switch op {
+	case tuple.UpdateOperationTouch:
+		// If there was a delete for the same tuple at the same revision, drop it
+		existing, ok := record.relDeletes[key]
+		if ok {
+			delete(record.relDeletes, key)
+			if err := ch.adjustByteSize(existing, -1); err != nil {
+				return err
+			}
+		}
+
+		// NOTE(review): a repeated TOUCH for the same key overwrites the stored entry but
+		// still increments the tracked byte size — confirm whether this is intended.
+		record.relTouches[key] = rel
+		if err := ch.adjustByteSize(rel, 1); err != nil {
+			return err
+		}
+
+	case tuple.UpdateOperationDelete:
+		// A delete following a touch in the same revision is subsumed by the touch.
+		_, alreadyTouched := record.relTouches[key]
+		if !alreadyTouched {
+			record.relDeletes[key] = rel
+			if err := ch.adjustByteSize(rel, 1); err != nil {
+				return err
+			}
+		}
+
+	default:
+		return spiceerrors.MustBugf("unknown change operation")
+	}
+
+	return nil
+}
+
+// sized is implemented by items that can report their serialized size in bytes
+// (presumably vtprotobuf-generated messages — SizeVT is that generator's method name).
+type sized interface {
+	SizeVT() int
+}
+
+// adjustByteSize adds (delta=+1) or releases (delta=-1) the item's serialized size from the
+// running total, enforcing the configured maximum. A maxByteSize of 0 disables tracking.
+func (ch *Changes[R, K]) adjustByteSize(item sized, delta int) error {
+	if ch.maxByteSize == 0 {
+		return nil
+	}
+
+	size := item.SizeVT() * delta
+	ch.currentByteSize += int64(size)
+	// Going negative means more bytes were released than were ever tracked.
+	if ch.currentByteSize < 0 {
+		return spiceerrors.MustBugf("byte size underflow")
+	}
+
+	currentByteSize, err := safecast.ToUint64(ch.currentByteSize)
+	if err != nil {
+		return spiceerrors.MustBugf("could not cast currentByteSize to uint64: %v", err)
+	}
+
+	if currentByteSize > ch.maxByteSize {
+		return datastore.NewMaximumChangesSizeExceededError(ch.maxByteSize)
+	}
+
+	return nil
+}
+
+// SetRevisionMetadata sets the metadata for the given revision. An empty metadata map is a
+// no-op; setting non-empty metadata twice for the same revision is treated as a bug.
+func (ch *Changes[R, K]) SetRevisionMetadata(ctx context.Context, rev R, metadata map[string]any) error {
+	if len(metadata) == 0 {
+		return nil
+	}
+
+	record, err := ch.recordForRevision(rev)
+	if err != nil {
+		return err
+	}
+
+	if len(record.metadata) > 0 {
+		return spiceerrors.MustBugf("metadata already set for revision")
+	}
+
+	// record.metadata is the same map held by the stored record, so this copy mutates the
+	// tracked record in place.
+	maps.Copy(record.metadata, metadata)
+	return nil
+}
+
+// recordForRevision returns the change record for the given revision, creating and storing
+// an empty one on first use. The error return is always nil today but kept for the callers.
+func (ch *Changes[R, K]) recordForRevision(rev R) (changeRecord[R], error) {
+	key := ch.keyFunc(rev)
+	if existing, ok := ch.records[key]; ok {
+		return existing, nil
+	}
+
+	// Keyed fields (rather than a positional literal) keep this resilient to field reordering.
+	created := changeRecord[R]{
+		rev:                rev,
+		relTouches:         make(map[string]tuple.Relationship),
+		relDeletes:         make(map[string]tuple.Relationship),
+		definitionsChanged: make(map[string]datastore.SchemaDefinition),
+		namespacesDeleted:  make(map[string]struct{}),
+		caveatsDeleted:     make(map[string]struct{}),
+		metadata:           make(map[string]any),
+	}
+	ch.records[key] = created
+	return created, nil
+}
+
+// AddDeletedNamespace adds a change indicating that the namespace with the name was deleted.
+func (ch *Changes[R, K]) AddDeletedNamespace(
+	_ context.Context,
+	rev R,
+	namespaceName string,
+) error {
+	if ch.content&datastore.WatchSchema != datastore.WatchSchema {
+		return nil
+	}
+
+	record, err := ch.recordForRevision(rev)
+	if err != nil {
+		return err
+	}
+
+	// if a delete happens in the same transaction as a change, we assume it was a change in the first place
+	// because that's how namespace changes are implemented in the MVCC
+	if _, ok := record.definitionsChanged[nsPrefix+namespaceName]; ok {
+		return nil
+	}
+
+	// NOTE: a delete of definitionsChanged here was dead code — the early return above fires
+	// whenever the key is present, so the delete could only ever be a no-op; it was removed.
+	record.namespacesDeleted[namespaceName] = struct{}{}
+	return nil
+}
+
+// AddDeletedCaveat adds a change indicating that the caveat with the name was deleted.
+func (ch *Changes[R, K]) AddDeletedCaveat(
+	_ context.Context,
+	rev R,
+	caveatName string,
+) error {
+	if ch.content&datastore.WatchSchema != datastore.WatchSchema {
+		return nil
+	}
+
+	record, err := ch.recordForRevision(rev)
+	if err != nil {
+		return err
+	}
+
+	// if a delete happens in the same transaction as a change, we assume it was a change in the first place
+	// because that's how namespace changes are implemented in the MVCC
+	if _, ok := record.definitionsChanged[caveatPrefix+caveatName]; ok {
+		return nil
+	}
+
+	// NOTE: a delete of definitionsChanged here was dead code — the early return above fires
+	// whenever the key is present, so the delete could only ever be a no-op; it was removed.
+	record.caveatsDeleted[caveatName] = struct{}{}
+	return nil
+}
+
+// AddChangedDefinition adds a change indicating that the schema definition (namespace or caveat)
+// was changed to the definition given. Replacing an existing recorded change first releases the
+// old definition's tracked byte size.
+func (ch *Changes[R, K]) AddChangedDefinition(
+	ctx context.Context,
+	rev R,
+	def datastore.SchemaDefinition,
+) error {
+	if ch.content&datastore.WatchSchema != datastore.WatchSchema {
+		return nil
+	}
+
+	record, err := ch.recordForRevision(rev)
+	if err != nil {
+		return err
+	}
+
+	switch t := def.(type) {
+	case *core.NamespaceDefinition:
+		// A change supersedes any delete of the same namespace in this revision.
+		delete(record.namespacesDeleted, t.Name)
+
+		if existing, ok := record.definitionsChanged[nsPrefix+t.Name]; ok {
+			if err := ch.adjustByteSize(existing, -1); err != nil {
+				return err
+			}
+		}
+
+		record.definitionsChanged[nsPrefix+t.Name] = t
+
+		if err := ch.adjustByteSize(t, 1); err != nil {
+			return err
+		}
+
+	case *core.CaveatDefinition:
+		delete(record.caveatsDeleted, t.Name)
+
+		// BUGFIX: caveat changes are stored under caveatPrefix (below), but the existing-entry
+		// lookup used nsPrefix, so a replaced caveat's byte size was never released and the
+		// running total over-counted toward maxByteSize.
+		if existing, ok := record.definitionsChanged[caveatPrefix+t.Name]; ok {
+			if err := ch.adjustByteSize(existing, -1); err != nil {
+				return err
+			}
+		}
+
+		record.definitionsChanged[caveatPrefix+t.Name] = t
+
+		if err := ch.adjustByteSize(t, 1); err != nil {
+			return err
+		}
+
+	default:
+		log.Ctx(ctx).Fatal().Msg("unknown schema definition kind")
+	}
+
+	return nil
+}
+
+// AsRevisionChanges returns the list of changes processed so far as a datastore watch
+// compatible, ordered, changelist. The tracked changes are left in place.
+func (ch *Changes[R, K]) AsRevisionChanges(lessThanFunc func(lhs, rhs K) bool) ([]datastore.RevisionChanges, error) {
+	return ch.revisionChanges(lessThanFunc, *new(R), false)
+}
+
+// FilterAndRemoveRevisionChanges filters a list of changes processed up to the bound revision from the changes list, removing them
+// and returning the filtered changes.
+func (ch *Changes[R, K]) FilterAndRemoveRevisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R) ([]datastore.RevisionChanges, error) {
+	changes, err := ch.revisionChanges(lessThanFunc, boundRev, true)
+	if err != nil {
+		return nil, err
+	}
+
+	// Only drop the records once conversion has succeeded.
+	ch.removeAllChangesBefore(boundRev)
+	return changes, nil
+}
+
+// revisionChanges converts the tracked records into an ordered list of RevisionChanges.
+// When withBound is true, only records whose revision is strictly below boundRev are included.
+func (ch *Changes[R, K]) revisionChanges(lessThanFunc func(lhs, rhs K) bool, boundRev R, withBound bool) ([]datastore.RevisionChanges, error) {
+	if ch.IsEmpty() {
+		return nil, nil
+	}
+
+	// Collect the keys of the records to include, honoring the bound when requested.
+	revisionsWithChanges := make([]K, 0, len(ch.records))
+	for rk, cr := range ch.records {
+		if !withBound || boundRev.GreaterThan(cr.rev) {
+			revisionsWithChanges = append(revisionsWithChanges, rk)
+		}
+	}
+
+	if len(revisionsWithChanges) == 0 {
+		return nil, nil
+	}
+
+	// Order revisions using the caller-provided comparison over keys.
+	sort.Slice(revisionsWithChanges, func(i int, j int) bool {
+		return lessThanFunc(revisionsWithChanges[i], revisionsWithChanges[j])
+	})
+
+	changes := make([]datastore.RevisionChanges, len(revisionsWithChanges))
+	for i, k := range revisionsWithChanges {
+		revisionChangeRecord := ch.records[k]
+		changes[i].Revision = revisionChangeRecord.rev
+		// NOTE: map iteration order is random, so the ordering of entries *within* a single
+		// revision is unspecified.
+		for _, rel := range revisionChangeRecord.relTouches {
+			changes[i].RelationshipChanges = append(changes[i].RelationshipChanges, tuple.Touch(rel))
+		}
+		for _, rel := range revisionChangeRecord.relDeletes {
+			changes[i].RelationshipChanges = append(changes[i].RelationshipChanges, tuple.Delete(rel))
+		}
+		changes[i].ChangedDefinitions = maps.Values(revisionChangeRecord.definitionsChanged)
+		changes[i].DeletedNamespaces = maps.Keys(revisionChangeRecord.namespacesDeleted)
+		changes[i].DeletedCaveats = maps.Keys(revisionChangeRecord.caveatsDeleted)
+
+		if len(revisionChangeRecord.metadata) > 0 {
+			metadata, err := structpb.NewStruct(revisionChangeRecord.metadata)
+			if err != nil {
+				return nil, spiceerrors.MustBugf("failed to convert metadata to structpb: %v", err)
+			}
+
+			changes[i].Metadata = metadata
+		}
+	}
+
+	return changes, nil
+}
+
+// removeAllChangesBefore drops every record whose revision is strictly below boundRev.
+// Deleting during range iteration is safe in Go.
+func (ch *Changes[R, K]) removeAllChangesBefore(boundRev R) {
+	for key, record := range ch.records {
+		if boundRev.GreaterThan(record.rev) {
+			delete(ch.records, key)
+		}
+	}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go
new file mode 100644
index 0000000..af0b229
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/errors.go
@@ -0,0 +1,154 @@
+package common
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "regexp"
+ "strings"
+
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// SerializationError is returned when there's been a serialization
+// error while performing a datastore operation
+type SerializationError struct {
+	error
+}
+
+// GRPCStatus maps the error to the Aborted gRPC code with a serialization-failure reason.
+func (err SerializationError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.Aborted,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_SERIALIZATION_FAILURE,
+			map[string]string{},
+		),
+	)
+}
+
+// Unwrap returns the wrapped error, letting errors.Is/errors.As see through this type.
+func (err SerializationError) Unwrap() error {
+	return err.error
+}
+
+// NewSerializationError creates a new SerializationError
+func NewSerializationError(err error) error {
+	return SerializationError{err}
+}
+
+// ReadOnlyTransactionError is returned when an otherwise read-write
+// transaction fails on writes with an error indicating that the datastore
+// is currently in a read-only mode.
+type ReadOnlyTransactionError struct {
+	error
+}
+
+// GRPCStatus maps the error to the Aborted gRPC code with a service-read-only reason.
+func (err ReadOnlyTransactionError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.Aborted,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_SERVICE_READ_ONLY,
+			map[string]string{},
+		),
+	)
+}
+
+// NewReadOnlyTransactionError creates a new ReadOnlyTransactionError wrapping the given error.
+func NewReadOnlyTransactionError(err error) error {
+	return ReadOnlyTransactionError{
+		fmt.Errorf("could not perform write operation, as the datastore is currently in read-only mode: %w. This may indicate that the datastore has been put into maintenance mode", err),
+	}
+}
+
+// CreateRelationshipExistsError is an error returned when attempting to CREATE an already-existing
+// relationship.
+type CreateRelationshipExistsError struct {
+	error
+
+	// Relationship is the relationship that caused the error. May be nil, depending on the datastore.
+	Relationship *tuple.Relationship
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error. When the offending
+// relationship is known, its components are attached as error-detail metadata.
+func (err CreateRelationshipExistsError) GRPCStatus() *status.Status {
+	if err.Relationship == nil {
+		return spiceerrors.WithCodeAndDetails(
+			err,
+			codes.AlreadyExists,
+			spiceerrors.ForReason(
+				v1.ErrorReason_ERROR_REASON_ATTEMPT_TO_RECREATE_RELATIONSHIP,
+				map[string]string{},
+			),
+		)
+	}
+
+	relationship := tuple.ToV1Relationship(*err.Relationship)
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.AlreadyExists,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_ATTEMPT_TO_RECREATE_RELATIONSHIP,
+			map[string]string{
+				"relationship":       tuple.V1StringRelationshipWithoutCaveatOrExpiration(relationship),
+				"resource_type":      relationship.Resource.ObjectType,
+				"resource_object_id": relationship.Resource.ObjectId,
+				"resource_relation":  relationship.Relation,
+				"subject_type":       relationship.Subject.Object.ObjectType,
+				"subject_object_id":  relationship.Subject.Object.ObjectId,
+				"subject_relation":   relationship.Subject.OptionalRelation,
+			},
+		),
+	)
+}
+
+// NewCreateRelationshipExistsError creates a new CreateRelationshipExistsError.
+func NewCreateRelationshipExistsError(relationship *tuple.Relationship) error {
+ msg := "could not CREATE one or more relationships, as they already existed. If this is persistent, please switch to TOUCH operations or specify a precondition"
+ if relationship != nil {
+ msg = fmt.Sprintf("could not CREATE relationship `%s`, as it already existed. If this is persistent, please switch to TOUCH operations or specify a precondition", tuple.StringWithoutCaveatOrExpiration(*relationship))
+ }
+
+ return CreateRelationshipExistsError{
+ errors.New(msg),
+ relationship,
+ }
+}
+
+var (
+	// Raw string literals avoid the triple-escaping required in interpreted strings;
+	// the compiled patterns are unchanged.
+	portMatchRegex  = regexp.MustCompile(`invalid port "(.+)" after host`)
+	parseMatchRegex = regexp.MustCompile(`parse "(.+)":`)
+)
+
+// RedactAndLogSensitiveConnString elides the given error, logging it only at trace
+// level (after being redacted). The returned error carries only baseErr plus a hint to
+// re-run at trace level; the connection URL and URL-parse fragments are scrubbed first.
+func RedactAndLogSensitiveConnString(ctx context.Context, baseErr string, err error, pgURL string) error {
+	if err == nil {
+		return errors.New(baseErr)
+	}
+
+	// See: https://github.com/jackc/pgx/issues/1271
+	filtered := err.Error()
+	filtered = strings.ReplaceAll(filtered, pgURL, "(redacted)")
+	filtered = portMatchRegex.ReplaceAllString(filtered, "(redacted)")
+	filtered = parseMatchRegex.ReplaceAllString(filtered, "(redacted)")
+	log.Ctx(ctx).Trace().Msg(baseErr + ": " + filtered)
+	return fmt.Errorf("%s. To view details of this error (that may contain sensitive information), please run with --log-level=trace", baseErr)
+}
+
+// RevisionUnavailableError is returned when a revision is not available on a replica.
+type RevisionUnavailableError struct {
+	error
+}
+
+// NewRevisionUnavailableError creates a new RevisionUnavailableError wrapping the given error.
+func NewRevisionUnavailableError(err error) error {
+	return RevisionUnavailableError{err}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go
new file mode 100644
index 0000000..5788134
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/gc.go
@@ -0,0 +1,269 @@
+package common
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/rs/zerolog"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// Prometheus collectors for garbage-collection observability. They are
// created here but only added to the default registry when RegisterGCMetrics
// is called.
var (
	gcDurationHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_duration_seconds",
		Help:      "The duration of datastore garbage collection.",
		Buckets:   []float64{0.01, 0.1, 0.5, 1, 5, 10, 25, 60, 120},
	})

	gcRelationshipsCounter = prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_relationships_total",
		Help:      "The number of stale relationships deleted by the datastore garbage collection.",
	})

	gcExpiredRelationshipsCounter = prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_expired_relationships_total",
		Help:      "The number of expired relationships deleted by the datastore garbage collection.",
	})

	gcTransactionsCounter = prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_transactions_total",
		Help:      "The number of stale transactions deleted by the datastore garbage collection.",
	})

	gcNamespacesCounter = prometheus.NewCounter(prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_namespaces_total",
		Help:      "The number of stale namespaces deleted by the datastore garbage collection.",
	})

	// The opts are kept in a named variable, presumably so an identically
	// configured counter can be built for injection via the failureCounter
	// parameter of startGarbageCollectorWithMaxElapsedTime — TODO confirm.
	gcFailureCounterConfig = prometheus.CounterOpts{
		Namespace: "spicedb",
		Subsystem: "datastore",
		Name:      "gc_failure_total",
		Help:      "The number of failed runs of the datastore garbage collection.",
	}
	gcFailureCounter = prometheus.NewCounter(gcFailureCounterConfig)
)
+
+// RegisterGCMetrics registers garbage collection metrics to the default
+// registry.
+func RegisterGCMetrics() error {
+ for _, metric := range []prometheus.Collector{
+ gcDurationHistogram,
+ gcRelationshipsCounter,
+ gcTransactionsCounter,
+ gcNamespacesCounter,
+ gcFailureCounter,
+ } {
+ if err := prometheus.Register(metric); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
// GarbageCollector represents any datastore that supports external garbage
// collection.
type GarbageCollector interface {
	// HasGCRun returns true if a garbage collection run has been completed.
	HasGCRun() bool

	// MarkGCCompleted marks that a garbage collection run has been completed.
	MarkGCCompleted()

	// ResetGCCompleted resets the state of the garbage collection run.
	ResetGCCompleted()

	// LockForGCRun attempts to acquire a lock for garbage collection. This lock
	// is typically done at the datastore level, to ensure that no other nodes are
	// running garbage collection at the same time. Returns false (without error)
	// when another node already holds the lock.
	LockForGCRun(ctx context.Context) (bool, error)

	// UnlockAfterGCRun releases the lock after a garbage collection run.
	// NOTE: this method does not take a context, as the context used for the
	// rest of the GC run can be canceled/timed out and the unlock will still need to happen.
	UnlockAfterGCRun() error

	// ReadyState returns the current state of the datastore.
	ReadyState(context.Context) (datastore.ReadyState, error)

	// Now returns the current time from the datastore (used instead of the
	// local clock to avoid skew between nodes).
	Now(context.Context) (time.Time, error)

	// TxIDBefore returns the highest transaction ID before the provided time.
	TxIDBefore(context.Context, time.Time) (datastore.Revision, error)

	// DeleteBeforeTx deletes all data before the provided transaction ID.
	DeleteBeforeTx(ctx context.Context, txID datastore.Revision) (DeletionCounts, error)

	// DeleteExpiredRels deletes all relationships that have expired, returning
	// the number deleted.
	DeleteExpiredRels(ctx context.Context) (int64, error)
}
+
// DeletionCounts tracks the amount of deletions that occurred when calling
// DeleteBeforeTx.
type DeletionCounts struct {
	Relationships int64
	Transactions  int64
	Namespaces    int64
}

// MarshalZerologObject writes the counts as structured fields onto a zerolog
// event, satisfying zerolog's object-marshaling interface.
func (g DeletionCounts) MarshalZerologObject(e *zerolog.Event) {
	e.
		Int64("relationships", g.Relationships).
		Int64("transactions", g.Transactions).
		Int64("namespaces", g.Namespaces)
}
+
// MaxGCInterval is the cap on the exponential backoff interval between
// garbage collection attempts after repeated failures.
var MaxGCInterval = 60 * time.Minute
+
// StartGarbageCollector loops forever until the context is canceled and
// performs garbage collection on the provided interval.
//
// window controls how far in the past data must fall to be collected; timeout
// bounds each individual GC run. Failed runs are retried with exponential
// backoff capped by MaxGCInterval.
func StartGarbageCollector(ctx context.Context, gc GarbageCollector, interval, window, timeout time.Duration) error {
	// A zero maxElapsedTime keeps retrying indefinitely (per backoff/v4 semantics).
	return startGarbageCollectorWithMaxElapsedTime(ctx, gc, interval, window, 0, timeout, gcFailureCounter)
}
+
// startGarbageCollectorWithMaxElapsedTime implements StartGarbageCollector,
// additionally exposing maxElapsedTime and failureCounter (presumably as test
// seams — only StartGarbageCollector calls it here; TODO confirm). A zero
// maxElapsedTime lets the backoff retry indefinitely.
func startGarbageCollectorWithMaxElapsedTime(ctx context.Context, gc GarbageCollector, interval, window, maxElapsedTime, timeout time.Duration, failureCounter prometheus.Counter) error {
	// On failure, the wait before the next run grows exponentially from
	// interval up to the larger of MaxGCInterval and interval.
	backoffInterval := backoff.NewExponentialBackOff()
	backoffInterval.InitialInterval = interval
	backoffInterval.MaxInterval = max(MaxGCInterval, interval)
	backoffInterval.MaxElapsedTime = maxElapsedTime
	backoffInterval.Reset()

	nextInterval := interval

	log.Ctx(ctx).Info().
		Dur("interval", nextInterval).
		Msg("datastore garbage collection worker started")

	for {
		select {
		case <-ctx.Done():
			log.Ctx(ctx).Info().
				Msg("shutting down datastore garbage collection worker")
			return ctx.Err()

		case <-time.After(nextInterval):
			log.Ctx(ctx).Info().
				Dur("interval", nextInterval).
				Dur("window", window).
				Dur("timeout", timeout).
				Msg("running garbage collection worker")

			err := RunGarbageCollection(gc, window, timeout)
			if err != nil {
				// Failure: count it and back off before the next attempt.
				failureCounter.Inc()
				nextInterval = backoffInterval.NextBackOff()
				log.Ctx(ctx).Warn().Err(err).
					Dur("next-attempt-in", nextInterval).
					Msg("error attempting to perform garbage collection")
				continue
			}

			// Success: restore the configured cadence.
			backoffInterval.Reset()
			nextInterval = interval

			log.Ctx(ctx).Debug().
				Dur("next-run-in", interval).
				Msg("datastore garbage collection scheduled for next run")
		}
	}
}
+
// RunGarbageCollection runs garbage collection for the datastore.
//
// A single run: checks readiness, takes the datastore-level GC lock, computes
// a transaction watermark `window` in the past, deletes data older than that
// watermark, then deletes expired relationships. Deletion metrics are
// recorded even when a deletion step fails partway.
func RunGarbageCollection(gc GarbageCollector, window, timeout time.Duration) error {
	// GC runs detached from any caller context; only timeout bounds it.
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	ctx, span := tracer.Start(ctx, "RunGarbageCollection")
	defer span.End()

	// Before attempting anything, check if the datastore is ready.
	startTime := time.Now()
	ready, err := gc.ReadyState(ctx)
	if err != nil {
		return err
	}
	if !ready.IsReady {
		// Not an error: this cycle is simply skipped until the datastore is ready.
		log.Ctx(ctx).Warn().
			Msgf("datastore wasn't ready when attempting garbage collection: %s", ready.Message)
		return nil
	}

	ok, err := gc.LockForGCRun(ctx)
	if err != nil {
		return fmt.Errorf("error locking for gc run: %w", err)
	}

	if !ok {
		// Another node holds the GC lock; skip this cycle.
		log.Info().
			Msg("datastore garbage collection already in progress on another node")
		return nil
	}

	defer func() {
		// UnlockAfterGCRun deliberately takes no context so the unlock still
		// happens after ctx has been canceled or timed out.
		err := gc.UnlockAfterGCRun()
		if err != nil {
			log.Error().
				Err(err).
				Msg("error unlocking after gc run")
		}
	}()

	// Use the datastore's clock to avoid skew with the local node.
	now, err := gc.Now(ctx)
	if err != nil {
		return fmt.Errorf("error retrieving now: %w", err)
	}

	watermark, err := gc.TxIDBefore(ctx, now.Add(-1*window))
	if err != nil {
		return fmt.Errorf("error retrieving watermark: %w", err)
	}

	// Errors from the two deletion steps are intentionally checked only after
	// the metrics below have been recorded.
	collected, err := gc.DeleteBeforeTx(ctx, watermark)

	expiredRelationshipsCount, eerr := gc.DeleteExpiredRels(ctx)

	// even if an error happened, garbage would have been collected. This makes sure these are reflected even if the
	// worker eventually fails or times out.
	gcRelationshipsCounter.Add(float64(collected.Relationships))
	gcTransactionsCounter.Add(float64(collected.Transactions))
	gcNamespacesCounter.Add(float64(collected.Namespaces))
	gcExpiredRelationshipsCounter.Add(float64(expiredRelationshipsCount))
	collectionDuration := time.Since(startTime)
	gcDurationHistogram.Observe(collectionDuration.Seconds())

	if err != nil {
		return fmt.Errorf("error deleting in gc: %w", err)
	}

	if eerr != nil {
		return fmt.Errorf("error deleting expired relationships in gc: %w", eerr)
	}

	log.Ctx(ctx).Info().
		Stringer("highestTxID", watermark).
		Dur("duration", collectionDuration).
		Time("nowTime", now).
		Interface("collected", collected).
		Int64("expiredRelationships", expiredRelationshipsCount).
		Msg("datastore garbage collection completed successfully")

	gc.MarkGCCompleted()
	return nil
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go
new file mode 100644
index 0000000..8f34134
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/helpers.go
@@ -0,0 +1,49 @@
+package common
+
+import (
+ "context"
+ "fmt"
+
+ "google.golang.org/protobuf/types/known/structpb"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// WriteRelationships is a convenience method to perform the same update operation on a set of relationships
+func WriteRelationships(ctx context.Context, ds datastore.Datastore, op tuple.UpdateOperation, rels ...tuple.Relationship) (datastore.Revision, error) {
+ updates := make([]tuple.RelationshipUpdate, 0, len(rels))
+ for _, rel := range rels {
+ ru := tuple.RelationshipUpdate{
+ Operation: op,
+ Relationship: rel,
+ }
+ updates = append(updates, ru)
+ }
+ return UpdateRelationshipsInDatastore(ctx, ds, updates...)
+}
+
// UpdateRelationshipsInDatastore is a convenience method to perform multiple relation update operations on a Datastore
//
// All updates are applied within a single ReadWriteTx call.
func UpdateRelationshipsInDatastore(ctx context.Context, ds datastore.Datastore, updates ...tuple.RelationshipUpdate) (datastore.Revision, error) {
	return ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
		return rwt.WriteRelationships(ctx, updates)
	})
}
+
+// ContextualizedCaveatFrom convenience method that handles creation of a contextualized caveat
+// given the possibility of arguments with zero-values.
+func ContextualizedCaveatFrom(name string, context map[string]any) (*core.ContextualizedCaveat, error) {
+ var caveat *core.ContextualizedCaveat
+ if name != "" {
+ strct, err := structpb.NewStruct(context)
+ if err != nil {
+ return nil, fmt.Errorf("malformed caveat context: %w", err)
+ }
+ caveat = &core.ContextualizedCaveat{
+ CaveatName: name,
+ Context: strct,
+ }
+ }
+ return caveat, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go
new file mode 100644
index 0000000..1eb64d1
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/index.go
@@ -0,0 +1,28 @@
+package common
+
import (
	"slices"

	"github.com/authzed/spicedb/pkg/datastore/queryshape"
)
+
// IndexDefinition is a definition of an index for a datastore. It couples a
// physical index (Name, ColumnsSQL) with the query shapes it serves.
type IndexDefinition struct {
	// Name is the unique name for the index.
	Name string

	// ColumnsSQL is the SQL fragment of the columns over which this index will apply.
	ColumnsSQL string

	// Shapes are those query shapes for which this index should be used.
	Shapes []queryshape.Shape

	// IsDeprecated is true if this index is deprecated and should not be used.
	IsDeprecated bool
}
+
+// matchesShape returns true if the index matches the given shape.
+func (id IndexDefinition) matchesShape(shape queryshape.Shape) bool {
+ for _, s := range id.Shapes {
+ if s == shape {
+ return true
+ }
+ }
+ return false
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go
new file mode 100644
index 0000000..6e84549
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/logging.go
@@ -0,0 +1,15 @@
+package common
+
+import (
+ "context"
+
+ log "github.com/authzed/spicedb/internal/logging"
+)
+
+// LogOnError executes the function and logs the error.
+// Useful to avoid silently ignoring errors in defer statements
+func LogOnError(ctx context.Context, f func() error) {
+ if err := f(); err != nil {
+ log.Ctx(ctx).Err(err).Msg("datastore error")
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go
new file mode 100644
index 0000000..304f62c
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/migrations.go
@@ -0,0 +1,42 @@
+package common
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// MigrationValidator checks a datastore's reported migration revision against
// the revisions this binary considers ready.
type MigrationValidator struct {
	// additionalAllowedMigrations are non-head revisions that are still
	// treated as ready.
	additionalAllowedMigrations []string

	// headMigration is the expected (most recent) migration revision.
	headMigration string
}

// NewMigrationValidator constructs a MigrationValidator for the given head
// migration revision and any additionally allowed revisions.
func NewMigrationValidator(headMigration string, additionalAllowedMigrations []string) *MigrationValidator {
	return &MigrationValidator{
		additionalAllowedMigrations: additionalAllowedMigrations,
		headMigration:               headMigration,
	}
}
+
+// MigrationReadyState returns the readiness of the datastore for the given version.
+func (mv *MigrationValidator) MigrationReadyState(version string) datastore.ReadyState {
+ if version == mv.headMigration {
+ return datastore.ReadyState{IsReady: true}
+ }
+ if slices.Contains(mv.additionalAllowedMigrations, version) {
+ return datastore.ReadyState{IsReady: true}
+ }
+ var msgBuilder strings.Builder
+ msgBuilder.WriteString(fmt.Sprintf("datastore is not migrated: currently at revision %q, but requires %q", version, mv.headMigration))
+
+ if len(mv.additionalAllowedMigrations) > 0 {
+ msgBuilder.WriteString(fmt.Sprintf(" (additional allowed migrations: %v)", mv.additionalAllowedMigrations))
+ }
+ msgBuilder.WriteString(". Please run \"spicedb datastore migrate\".")
+ return datastore.ReadyState{
+ Message: msgBuilder.String(),
+ IsReady: false,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go
new file mode 100644
index 0000000..dee0ad5
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/relationships.go
@@ -0,0 +1,214 @@
+package common
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "time"
+
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// errUnableToQueryRels is the wrapping format applied to all relationship
// query errors in this file.
const errUnableToQueryRels = "unable to query relationships: %w"

// Querier is an interface for querying the database. R is the concrete row
// type produced by the underlying driver.
type Querier[R Rows] interface {
	// QueryFunc executes sql with args and invokes f with the resulting rows.
	QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error
}

// Rows is a common interface for database rows reading.
type Rows interface {
	Scan(dest ...any) error
	Next() bool
	Err() error
}

// closeRowsWithError is satisfied by row types whose Close returns an error.
type closeRowsWithError interface {
	Rows
	Close() error
}

// closeRows is satisfied by row types whose Close returns nothing.
type closeRows interface {
	Rows
	Close()
}
+
// runExplainIfNecessary runs the builder's query under EXPLAIN and hands the
// output to the test-only SQLExplainCallbackForTest hook, along with the
// index names the schema expects for the query's shape. It is a no-op when
// the hook is unset (the production path).
func runExplainIfNecessary[R Rows](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R], explainable datastore.Explainable) error {
	if builder.SQLExplainCallbackForTest == nil {
		return nil
	}

	// Determine the expected index names via the schema.
	expectedIndexes := builder.Schema.expectedIndexesForShape(builder.queryShape)

	// Run any pre-explain statements.
	for _, statement := range explainable.PreExplainStatements() {
		if err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error {
			// Executed for effect only; advance once to drive execution.
			rows.Next()
			return nil
		}, statement); err != nil {
			return fmt.Errorf(errUnableToQueryRels, err)
		}
	}

	// Run the query with EXPLAIN ANALYZE.
	sqlString, args, err := builder.SelectSQL()
	if err != nil {
		return fmt.Errorf(errUnableToQueryRels, err)
	}

	explainSQL, explainArgs, err := explainable.BuildExplainQuery(sqlString, args)
	if err != nil {
		return fmt.Errorf(errUnableToQueryRels, err)
	}

	err = tx.QueryFunc(ctx, func(ctx context.Context, rows R) error {
		// The explain output may span multiple rows; concatenate them all.
		explainString := ""
		for rows.Next() {
			var explain string
			if err := rows.Scan(&explain); err != nil {
				return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err))
			}
			explainString += explain + "\n"
		}
		if explainString == "" {
			return fmt.Errorf("received empty explain")
		}

		return builder.SQLExplainCallbackForTest(ctx, sqlString, args, builder.queryShape, explainString, expectedIndexes)
	}, explainSQL, explainArgs...)
	if err != nil {
		return fmt.Errorf(errUnableToQueryRels, err)
	}

	return nil
}
+
// QueryRelationships queries relationships for the given query and transaction.
//
// The result is a lazy iterator: the SQL is only issued when the caller
// begins iterating, and iteration stops early if the caller's yield returns
// false. C is the driver-specific map type used to scan the caveat context
// column. Query errors that occur after the iterator is returned are
// surfaced through the iterator's error value.
func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R], explainable datastore.Explainable) (datastore.RelationshipIterator, error) {
	span := trace.SpanFromContext(ctx)
	sqlString, args, err := builder.SelectSQL()
	if err != nil {
		return nil, fmt.Errorf(errUnableToQueryRels, err)
	}

	if err := runExplainIfNecessary(ctx, builder, tx, explainable); err != nil {
		return nil, err
	}

	// Scan targets: the same variables are reused for every row; their values
	// are copied into a fresh tuple.Relationship before being yielded.
	var resourceObjectType string
	var resourceObjectID string
	var resourceRelation string
	var subjectObjectType string
	var subjectObjectID string
	var subjectRelation string
	var caveatName sql.NullString
	var caveatCtx C
	var expiration *time.Time

	var integrityKeyID string
	var integrityHash []byte
	var timestamp time.Time

	span.AddEvent("Selecting columns")
	colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, &timestamp)
	if err != nil {
		return nil, fmt.Errorf(errUnableToQueryRels, err)
	}

	span.AddEvent("Returning iterator", trace.WithAttributes(attribute.Int("column-count", len(colsToSelect))))
	return func(yield func(tuple.Relationship, error) bool) {
		span.AddEvent("Issuing query to database")
		err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error {
			span.AddEvent("Query issued to database")

			// Ensure rows are closed regardless of which Close signature the
			// driver's row type exposes.
			var r Rows = rows
			if crwe, ok := r.(closeRowsWithError); ok {
				defer LogOnError(ctx, crwe.Close)
			} else if cr, ok := r.(closeRows); ok {
				defer cr.Close()
			}

			relCount := 0
			for rows.Next() {
				if relCount == 0 {
					span.AddEvent("First row returned")
				}

				if err := rows.Scan(colsToSelect...); err != nil {
					return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err))
				}

				if relCount == 0 {
					span.AddEvent("First row scanned")
				}

				// Caveats are only materialized when they were selected.
				var caveat *corev1.ContextualizedCaveat
				if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone {
					if caveatName.Valid {
						var err error
						caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx)
						if err != nil {
							return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err))
						}
					}
				}

				// An empty key ID indicates no integrity data for this row.
				var integrity *corev1.RelationshipIntegrity
				if integrityKeyID != "" {
					integrity = &corev1.RelationshipIntegrity{
						KeyId:    integrityKeyID,
						Hash:     integrityHash,
						HashedAt: timestamppb.New(timestamp),
					}
				}

				if expiration != nil {
					// Ensure the expiration is always read in UTC, since some datastores (like CRDB)
					// will normalize to local time.
					t := expiration.UTC()
					expiration = &t
				}

				relCount++
				if !yield(tuple.Relationship{
					RelationshipReference: tuple.RelationshipReference{
						Resource: tuple.ObjectAndRelation{
							ObjectType: resourceObjectType,
							ObjectID:   resourceObjectID,
							Relation:   resourceRelation,
						},
						Subject: tuple.ObjectAndRelation{
							ObjectType: subjectObjectType,
							ObjectID:   subjectObjectID,
							Relation:   subjectRelation,
						},
					},
					OptionalCaveat:     caveat,
					OptionalExpiration: expiration,
					OptionalIntegrity:  integrity,
				}, nil) {
					// Caller stopped iterating; not an error.
					return nil
				}
			}

			span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount)))
			if err := rows.Err(); err != nil {
				return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err))
			}

			return nil
		}, sqlString, args...)
		if err != nil {
			// Deliver the query error through the iterator itself.
			if !yield(tuple.Relationship{}, err) {
				return
			}
		}
	}, nil
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go
new file mode 100644
index 0000000..6e44d0b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/schema.go
@@ -0,0 +1,188 @@
+package common
+
+import (
+ sq "github.com/Masterminds/squirrel"
+
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
// Column counts for each logical group of relationship-table columns; the
// column names themselves live on SchemaInformation.
const (
	relationshipStandardColumnCount   = 6 // ColNamespace, ColObjectID, ColRelation, ColUsersetNamespace, ColUsersetObjectID, ColUsersetRelation
	relationshipCaveatColumnCount     = 2 // ColCaveatName, ColCaveatContext
	relationshipExpirationColumnCount = 1 // ColExpiration
	relationshipIntegrityColumnCount  = 3 // ColIntegrityKeyID, ColIntegrityHash, ColIntegrityTimestamp
)
+
// SchemaInformation holds the schema information from the SQL datastore implementation.
//
//go:generate go run github.com/ecordell/optgen -output zz_generated.schema_options.go . SchemaInformation
type SchemaInformation struct {
	// RelationshipTableName is the name of the table holding relationships.
	RelationshipTableName string `debugmap:"visible"`

	// Resource-side and subject (userset)-side column names.
	ColNamespace        string `debugmap:"visible"`
	ColObjectID         string `debugmap:"visible"`
	ColRelation         string `debugmap:"visible"`
	ColUsersetNamespace string `debugmap:"visible"`
	ColUsersetObjectID  string `debugmap:"visible"`
	ColUsersetRelation  string `debugmap:"visible"`

	// Caveat column names.
	ColCaveatName    string `debugmap:"visible"`
	ColCaveatContext string `debugmap:"visible"`

	// ColExpiration is the column holding the relationship expiration.
	ColExpiration string `debugmap:"visible"`

	// Integrity column names; required only when IntegrityEnabled is set.
	ColIntegrityKeyID     string `debugmap:"visible"`
	ColIntegrityHash      string `debugmap:"visible"`
	ColIntegrityTimestamp string `debugmap:"visible"`

	// Indexes are the indexes to use for this schema.
	Indexes []IndexDefinition `debugmap:"visible"`

	// PaginationFilterType is the type of pagination filter to use for this schema.
	PaginationFilterType PaginationFilterType `debugmap:"visible"`

	// PlaceholderFormat is the format of placeholders to use for this schema.
	PlaceholderFormat sq.PlaceholderFormat `debugmap:"visible"`

	// NowFunction is the function to use to get the current time in the datastore.
	NowFunction string `debugmap:"visible"`

	// ColumnOptimization is the optimization to use for columns in the schema, if any.
	ColumnOptimization ColumnOptimizationOption `debugmap:"visible"`

	// IntegrityEnabled is a flag to indicate if the schema has integrity columns.
	IntegrityEnabled bool `debugmap:"visible"`

	// ExpirationDisabled is a flag to indicate whether expiration support is disabled.
	ExpirationDisabled bool `debugmap:"visible"`

	// SortByResourceColumnOrder is the order of the resource columns in the schema to use
	// when sorting by resource. If unspecified, the default will be used.
	SortByResourceColumnOrder []string `debugmap:"visible"`

	// SortBySubjectColumnOrder is the order of the subject columns in the schema to use
	// when sorting by subject. If unspecified, the default will be used.
	SortBySubjectColumnOrder []string `debugmap:"visible"`
}
+
+// expectedIndexesForShape returns the expected index names for a given query shape.
+func (si SchemaInformation) expectedIndexesForShape(shape queryshape.Shape) options.SQLIndexInformation {
+ expectedIndexes := options.SQLIndexInformation{}
+ for _, index := range si.Indexes {
+ if index.matchesShape(shape) {
+ expectedIndexes.ExpectedIndexNames = append(expectedIndexes.ExpectedIndexNames, index.Name)
+ }
+ }
+ return expectedIndexes
+}
+
// debugValidate runs mustValidate under spiceerrors.DebugAssert; a schema
// field that fails validation panics inside the assertion closure.
func (si SchemaInformation) debugValidate() {
	spiceerrors.DebugAssert(func() bool {
		si.mustValidate()
		return true
	}, "SchemaInformation failed to validate")
}
+
+func (si SchemaInformation) sortByResourceColumnOrderColumns() []string {
+ if len(si.SortByResourceColumnOrder) > 0 {
+ return si.SortByResourceColumnOrder
+ }
+
+ return []string{
+ si.ColNamespace,
+ si.ColObjectID,
+ si.ColRelation,
+ si.ColUsersetNamespace,
+ si.ColUsersetObjectID,
+ si.ColUsersetRelation,
+ }
+}
+
+func (si SchemaInformation) sortBySubjectColumnOrderColumns() []string {
+ if len(si.SortBySubjectColumnOrder) > 0 {
+ return si.SortBySubjectColumnOrder
+ }
+
+ return []string{
+ si.ColUsersetNamespace,
+ si.ColUsersetObjectID,
+ si.ColUsersetRelation,
+ si.ColNamespace,
+ si.ColObjectID,
+ si.ColRelation,
+ }
+}
+
+func (si SchemaInformation) mustValidate() {
+ if si.RelationshipTableName == "" {
+ panic("RelationshipTableName is required")
+ }
+
+ if si.ColNamespace == "" {
+ panic("ColNamespace is required")
+ }
+
+ if si.ColObjectID == "" {
+ panic("ColObjectID is required")
+ }
+
+ if si.ColRelation == "" {
+ panic("ColRelation is required")
+ }
+
+ if si.ColUsersetNamespace == "" {
+ panic("ColUsersetNamespace is required")
+ }
+
+ if si.ColUsersetObjectID == "" {
+ panic("ColUsersetObjectID is required")
+ }
+
+ if si.ColUsersetRelation == "" {
+ panic("ColUsersetRelation is required")
+ }
+
+ if si.ColCaveatName == "" {
+ panic("ColCaveatName is required")
+ }
+
+ if si.ColCaveatContext == "" {
+ panic("ColCaveatContext is required")
+ }
+
+ if si.ColExpiration == "" {
+ panic("ColExpiration is required")
+ }
+
+ if si.IntegrityEnabled {
+ if si.ColIntegrityKeyID == "" {
+ panic("ColIntegrityKeyID is required")
+ }
+
+ if si.ColIntegrityHash == "" {
+ panic("ColIntegrityHash is required")
+ }
+
+ if si.ColIntegrityTimestamp == "" {
+ panic("ColIntegrityTimestamp is required")
+ }
+ }
+
+ if si.NowFunction == "" {
+ panic("NowFunction is required")
+ }
+
+ if si.ColumnOptimization == ColumnOptimizationOptionUnknown {
+ panic("ColumnOptimization is required")
+ }
+
+ if si.PaginationFilterType == PaginationFilterTypeUnknown {
+ panic("PaginationFilterType is required")
+ }
+
+ if si.PlaceholderFormat == nil {
+ panic("PlaceholderFormat is required")
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go
new file mode 100644
index 0000000..4972700
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sliceiter.go
@@ -0,0 +1,17 @@
+package common
+
+import (
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// NewSliceRelationshipIterator creates a datastore.RelationshipIterator instance from a materialized slice of tuples.
+func NewSliceRelationshipIterator(rels []tuple.Relationship) datastore.RelationshipIterator {
+ return func(yield func(tuple.Relationship, error) bool) {
+ for _, rel := range rels {
+ if !yield(rel, nil) {
+ break
+ }
+ }
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go
new file mode 100644
index 0000000..ba9c4f6
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sql.go
@@ -0,0 +1,961 @@
+package common
+
+import (
+ "context"
+ "fmt"
+ "maps"
+ "math"
+ "strings"
+ "time"
+
+ sq "github.com/Masterminds/squirrel"
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/jzelinskie/stringz"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+var (
+ // CaveatNameKey is a tracing attribute representing a caveat name
+ CaveatNameKey = attribute.Key("authzed.com/spicedb/sql/caveatName")
+
+ // ObjNamespaceNameKey is a tracing attribute representing the resource
+ // object type.
+ ObjNamespaceNameKey = attribute.Key("authzed.com/spicedb/sql/objNamespaceName")
+
+ // ObjRelationNameKey is a tracing attribute representing the resource
+ // relation.
+ ObjRelationNameKey = attribute.Key("authzed.com/spicedb/sql/objRelationName")
+
+ // ObjIDKey is a tracing attribute representing the resource object ID.
+ ObjIDKey = attribute.Key("authzed.com/spicedb/sql/objId")
+
+ // SubNamespaceNameKey is a tracing attribute representing the subject object
+ // type.
+ SubNamespaceNameKey = attribute.Key("authzed.com/spicedb/sql/subNamespaceName")
+
+ // SubRelationNameKey is a tracing attribute representing the subject
+ // relation.
+ SubRelationNameKey = attribute.Key("authzed.com/spicedb/sql/subRelationName")
+
+ // SubObjectIDKey is a tracing attribute representing the the subject object
+ // ID.
+ SubObjectIDKey = attribute.Key("authzed.com/spicedb/sql/subObjectId")
+
+ tracer = otel.Tracer("spicedb/internal/datastore/common")
+)
+
// PaginationFilterType is an enumerator for pagination filter types.
type PaginationFilterType uint8

const (
	// PaginationFilterTypeUnknown is the zero value; debugValidate panics if it is left unset.
	PaginationFilterTypeUnknown PaginationFilterType = iota

	// TupleComparison uses a comparison with a compound key,
	// e.g. (namespace, object_id, relation) > ('ns', '123', 'viewer')
	// which is not compatible with all datastores.
	//
	// NOTE: previously declared as the untyped constant `= 1`, which made the enum
	// members plain ints rather than PaginationFilterType; the iota form keeps the
	// same values (1 and 2) while making the constants properly typed.
	TupleComparison

	// ExpandedLogicComparison comparison uses a nested tree of ANDs and ORs to properly
	// filter out already received relationships. Useful for databases that do not support
	// tuple comparison, or do not execute it efficiently.
	ExpandedLogicComparison
)
+
+// ColumnOptimizationOption is an enumerator for column optimization options.
+type ColumnOptimizationOption int
+
+const (
+ ColumnOptimizationOptionUnknown ColumnOptimizationOption = iota
+
+ // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns.
+ ColumnOptimizationOptionNone
+
+ // ColumnOptimizationOptionStaticValues is an option that optimizes columns for static values.
+ ColumnOptimizationOptionStaticValues
+)
+
// columnTracker records the filtering state of a single column.
type columnTracker struct {
	// SingleValue holds the column's sole static filter value, or nil once the column
	// has been filtered with more than one distinct value.
	SingleValue *string
}

// columnTrackerMap maps column names to their tracked filtering state.
type columnTrackerMap map[string]columnTracker

// hasStaticValue returns whether the named column has been filtered to exactly one static value.
func (ctm columnTrackerMap) hasStaticValue(columnName string) bool {
	tracker, present := ctm[columnName]
	return present && tracker.SingleValue != nil
}
+
+// SchemaQueryFilterer wraps a SchemaInformation and SelectBuilder to give an opinionated
+// way to build query objects.
+type SchemaQueryFilterer struct {
+ schema SchemaInformation
+ queryBuilder sq.SelectBuilder
+ filteringColumnTracker columnTrackerMap
+ filterMaximumIDCount uint16
+ isCustomQuery bool
+ extraFields []string
+ fromSuffix string
+ fromTable string
+ indexingHint IndexingHint
+}
+
+// IndexingHint is an interface that can be implemented to provide a hint for the SQL query.
+type IndexingHint interface {
+ // SQLPrefix returns the SQL prefix to be used for the indexing hint, if any.
+ SQLPrefix() (string, error)
+
+ // FromTable returns the table name to be used for the indexing hint, if any.
+ FromTable(existingTableName string) (string, error)
+
+ // FromSQLSuffix returns the suffix to be used for the indexing hint, if any.
+ FromSQLSuffix() (string, error)
+}
+
// NewSchemaQueryFiltererForRelationshipsSelect creates a new SchemaQueryFilterer object for selecting
// relationships. This method will automatically filter the columns retrieved from the database, only
// selecting the columns that are not already specified with a single static value in the query.
func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer {
	// Panics (in debug builds) if any required schema field is unset.
	schema.debugValidate()

	// A zero maximum ID count is almost certainly a misconfiguration; fall back to a
	// safe default rather than allowing unbounded IN clauses.
	if filterMaximumIDCount == 0 {
		filterMaximumIDCount = 100
		log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100")
	}

	// Start from an empty SELECT; columns are chosen later (see RelationshipsQueryBuilder.SelectSQL).
	queryBuilder := sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFormat).Select()
	return SchemaQueryFilterer{
		schema:                 schema,
		queryBuilder:           queryBuilder,
		filteringColumnTracker: map[string]columnTracker{},
		filterMaximumIDCount:   filterMaximumIDCount,
		isCustomQuery:          false,
		extraFields:            extraFields,
		fromTable:              "",
	}
}
+
// NewSchemaQueryFiltererWithStartingQuery creates a new SchemaQueryFilterer object for selecting
// relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect,
// this method will not auto-filter the columns retrieved from the database.
func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer {
	// Panics (in debug builds) if any required schema field is unset.
	schema.debugValidate()

	// A zero maximum ID count is almost certainly a misconfiguration; fall back to a
	// safe default rather than allowing unbounded IN clauses.
	if filterMaximumIDCount == 0 {
		filterMaximumIDCount = 100
		log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100")
	}

	// isCustomQuery is set so that ExecuteQuery rejects this filterer; custom queries
	// must be executed by the caller via UnderlyingQueryBuilder.
	return SchemaQueryFilterer{
		schema:                 schema,
		queryBuilder:           startingQuery,
		filteringColumnTracker: map[string]columnTracker{},
		filterMaximumIDCount:   filterMaximumIDCount,
		isCustomQuery:          true,
		extraFields:            nil,
		fromTable:              "",
	}
}
+
+// WithAdditionalFilter returns the SchemaQueryFilterer with an additional filter applied to the query.
+func (sqf SchemaQueryFilterer) WithAdditionalFilter(filter func(original sq.SelectBuilder) sq.SelectBuilder) SchemaQueryFilterer {
+ sqf.queryBuilder = filter(sqf.queryBuilder)
+ return sqf
+}
+
+// WithFromTable returns the SchemaQueryFilterer with a custom FROM table.
+func (sqf SchemaQueryFilterer) WithFromTable(fromTable string) SchemaQueryFilterer {
+ sqf.fromTable = fromTable
+ return sqf
+}
+
+// WithFromSuffix returns the SchemaQueryFilterer with a suffix added to the FROM clause.
+func (sqf SchemaQueryFilterer) WithFromSuffix(fromSuffix string) SchemaQueryFilterer {
+ sqf.fromSuffix = fromSuffix
+ return sqf
+}
+
+// WithIndexingHint returns the SchemaQueryFilterer with an indexing hint applied to the query.
+func (sqf SchemaQueryFilterer) WithIndexingHint(indexingHint IndexingHint) SchemaQueryFilterer {
+ sqf.indexingHint = indexingHint
+ return sqf
+}
+
// UnderlyingQueryBuilder returns the underlying squirrel builder with the expiration filter
// applied (unless expiration is disabled in the schema). May only be called on filterers
// created via NewSchemaQueryFiltererWithStartingQuery.
func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder {
	spiceerrors.DebugAssert(func() bool {
		return sqf.isCustomQuery
	}, "UnderlyingQueryBuilder should only be called on custom queries")
	return sqf.queryBuilderWithMaybeExpirationFilter(false)
}
+
// queryBuilderWithMaybeExpirationFilter returns the query builder with the expiration filter applied, when necessary.
// Note that this adds the clause to the existing builder.
func (sqf SchemaQueryFilterer) queryBuilderWithMaybeExpirationFilter(skipExpiration bool) sq.SelectBuilder {
	if sqf.schema.ExpirationDisabled || skipExpiration {
		return sqf.queryBuilder
	}

	// Filter out any expired relationships: a NULL expiration means "never expires",
	// otherwise the expiration must be in the datastore's future (per its NOW function).
	return sqf.queryBuilder.Where(sq.Or{
		sq.Eq{sqf.schema.ColExpiration: nil},
		sq.Expr(sqf.schema.ColExpiration + " > " + sqf.schema.NowFunction + "()"),
	})
}
+
// TupleOrder returns the SchemaQueryFilterer with an ORDER BY matching the given sort
// order. Any other sort order (including Unsorted) leaves the query unchanged.
func (sqf SchemaQueryFilterer) TupleOrder(order options.SortOrder) SchemaQueryFilterer {
	switch order {
	case options.ByResource:
		sqf.queryBuilder = sqf.queryBuilder.OrderBy(sqf.schema.sortByResourceColumnOrderColumns()...)

	case options.BySubject:
		sqf.queryBuilder = sqf.queryBuilder.OrderBy(sqf.schema.sortBySubjectColumnOrderColumns()...)
	}

	return sqf
}
+
+type nameAndValue struct {
+ name string
+ value string
+}
+
// columnsAndValuesForSort maps the sort order's column list to (column name, cursor value)
// pairs, pulling each value from the corresponding field of the cursor relationship. The
// returned pairs preserve the schema's sort-column ordering and are used to build the
// pagination ("after cursor") clause.
func columnsAndValuesForSort(
	order options.SortOrder,
	schema SchemaInformation,
	cursor options.Cursor,
) ([]nameAndValue, error) {
	var columnNames []string

	switch order {
	case options.ByResource:
		columnNames = schema.sortByResourceColumnOrderColumns()

	case options.BySubject:
		columnNames = schema.sortBySubjectColumnOrderColumns()

	default:
		return nil, spiceerrors.MustBugf("invalid sort order %q", order)
	}

	// Translate each schema column name into the matching cursor field value.
	nameAndValues := make([]nameAndValue, 0, len(columnNames))
	for _, columnName := range columnNames {
		switch columnName {
		case schema.ColNamespace:
			nameAndValues = append(nameAndValues, nameAndValue{
				name:  columnName,
				value: cursor.Resource.ObjectType,
			})

		case schema.ColObjectID:
			nameAndValues = append(nameAndValues, nameAndValue{
				name:  columnName,
				value: cursor.Resource.ObjectID,
			})

		case schema.ColRelation:
			nameAndValues = append(nameAndValues, nameAndValue{
				name:  columnName,
				value: cursor.Resource.Relation,
			})

		case schema.ColUsersetNamespace:
			nameAndValues = append(nameAndValues, nameAndValue{
				name:  columnName,
				value: cursor.Subject.ObjectType,
			})

		case schema.ColUsersetObjectID:
			nameAndValues = append(nameAndValues, nameAndValue{
				name:  columnName,
				value: cursor.Subject.ObjectID,
			})

		case schema.ColUsersetRelation:
			nameAndValues = append(nameAndValues, nameAndValue{
				name:  columnName,
				value: cursor.Subject.Relation,
			})

		default:
			// A sort column the schema does not recognize indicates a bug in the
			// sort-column definitions, not a user error.
			return nil, spiceerrors.MustBugf("invalid column name %q", columnName)
		}
	}

	return nameAndValues, nil
}
+
+func (sqf SchemaQueryFilterer) MustAfter(cursor options.Cursor, order options.SortOrder) SchemaQueryFilterer {
+ updated, err := sqf.After(cursor, order)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
// After returns the SchemaQueryFilterer with a clause that skips all relationships at or
// before the given cursor, under the given sort order. The clause shape depends on the
// schema's PaginationFilterType.
func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOrder) (SchemaQueryFilterer, error) {
	spiceerrors.DebugAssertNotNil(cursor, "cursor cannot be nil")

	// NOTE: The ordering of these columns can affect query performance, be aware when changing.
	columnsAndValues, err := columnsAndValuesForSort(order, sqf.schema, cursor)
	if err != nil {
		return sqf, err
	}

	switch sqf.schema.PaginationFilterType {
	case TupleComparison:
		// Row-value comparison: (c1, c2, ...) > (?, ?, ...).
		// For performance reasons, remove any column names that have static values in the query.
		columnNames := make([]string, 0, len(columnsAndValues))
		valueSlots := make([]any, 0, len(columnsAndValues))
		comparisonSlotCount := 0

		for _, cav := range columnsAndValues {
			if !sqf.filteringColumnTracker.hasStaticValue(cav.name) {
				columnNames = append(columnNames, cav.name)
				valueSlots = append(valueSlots, cav.value)
				comparisonSlotCount++
			}
		}

		// If every sort column is static, all rows compare equal to the cursor and no
		// clause is needed.
		if comparisonSlotCount > 0 {
			// strings.Repeat produces ",?,?,..."; [1:] drops the leading comma.
			comparisonTuple := "(" + strings.Join(columnNames, ",") + ") > (" + strings.Repeat(",?", comparisonSlotCount)[1:] + ")"
			sqf.queryBuilder = sqf.queryBuilder.Where(
				comparisonTuple,
				valueSlots...,
			)
		}

	case ExpandedLogicComparison:
		// Lexicographic expansion: OR over (c1 > v1), (c1 = v1 AND c2 > v2), etc., for
		// datastores without efficient row-value comparison.
		// For performance reasons, remove any column names that have static values in the query.
		orClause := sq.Or{}

		for index, cav := range columnsAndValues {
			if !sqf.filteringColumnTracker.hasStaticValue(cav.name) {
				andClause := sq.And{}
				for _, previous := range columnsAndValues[0:index] {
					if !sqf.filteringColumnTracker.hasStaticValue(previous.name) {
						andClause = append(andClause, sq.Eq{previous.name: previous.value})
					}
				}

				andClause = append(andClause, sq.Gt{cav.name: cav.value})
				orClause = append(orClause, andClause)
			}
		}

		if len(orClause) > 0 {
			sqf.queryBuilder = sqf.queryBuilder.Where(orClause)
		}
	}

	return sqf, nil
}
+
+// FilterToResourceType returns a new SchemaQueryFilterer that is limited to resources of the
+// specified type.
+func (sqf SchemaQueryFilterer) FilterToResourceType(resourceType string) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColNamespace: resourceType})
+ sqf.recordColumnValue(sqf.schema.ColNamespace, resourceType)
+ return sqf
+}
+
+func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string) {
+ existing, ok := sqf.filteringColumnTracker[colName]
+ if ok {
+ if existing.SingleValue != nil && *existing.SingleValue != colValue {
+ sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil}
+ }
+ } else {
+ sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: &colValue}
+ }
+}
+
// recordVaryingColumnValue marks the given column as having more than one possible value,
// ensuring the column-optimization path never treats it as static.
func (sqf SchemaQueryFilterer) recordVaryingColumnValue(colName string) {
	sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil}
}
+
+// FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the
+// specified ID.
+func (sqf SchemaQueryFilterer) FilterToResourceID(objectID string) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColObjectID: objectID})
+ sqf.recordColumnValue(sqf.schema.ColObjectID, objectID)
+ return sqf
+}
+
+func (sqf SchemaQueryFilterer) MustFilterToResourceIDs(resourceIds []string) SchemaQueryFilterer {
+ updated, err := sqf.FilterToResourceIDs(resourceIds)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
// FilterWithResourceIDPrefix returns new SchemaQueryFilterer that is limited to resources whose ID
// starts with the specified prefix.
func (sqf SchemaQueryFilterer) FilterWithResourceIDPrefix(prefix string) (SchemaQueryFilterer, error) {
	if strings.Contains(prefix, "%") {
		return sqf, spiceerrors.MustBugf("prefix cannot contain the percent sign")
	}
	if prefix == "" {
		return sqf, spiceerrors.MustBugf("prefix cannot be empty")
	}

	// Escape the remaining LIKE metacharacters so the prefix matches literally:
	// backslash first (the escape character itself), then underscore ("any single char").
	// The percent sign was rejected above.
	prefix = strings.ReplaceAll(prefix, `\`, `\\`)
	prefix = strings.ReplaceAll(prefix, "_", `\_`)

	sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.ColObjectID: prefix + "%"})

	// NOTE: we do *not* record the use of the resource ID column here, because it is not used
	// statically and thus is necessary for sorting operations.
	return sqf, nil
}
+
+func (sqf SchemaQueryFilterer) MustFilterWithResourceIDPrefix(prefix string) SchemaQueryFilterer {
+ updated, err := sqf.FilterWithResourceIDPrefix(prefix)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
// FilterToResourceIDs returns a new SchemaQueryFilterer that is limited to resources with any of the
// specified IDs.
//
// NOTE(review): assumes resourceIds is non-empty — all callers in this file guard on
// len > 0; an empty slice would render "IN (?)" with zero bound arguments. Confirm before
// adding new call sites.
func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (SchemaQueryFilterer, error) {
	spiceerrors.DebugAssert(func() bool {
		return len(resourceIds) <= int(sqf.filterMaximumIDCount)
	}, "cannot have more than %d resource IDs in a single filter", sqf.filterMaximumIDCount)

	// Hand-build the IN clause so the placeholder count exactly matches the ID count.
	var builder strings.Builder
	builder.WriteString(sqf.schema.ColObjectID)
	builder.WriteString(" IN (")
	args := make([]any, 0, len(resourceIds))

	for _, resourceID := range resourceIds {
		if len(resourceID) == 0 {
			return sqf, spiceerrors.MustBugf("got empty resource ID")
		}

		args = append(args, resourceID)
		// Each ID is recorded; with more than one ID the column is demoted to varying.
		sqf.recordColumnValue(sqf.schema.ColObjectID, resourceID)
	}

	// One "?" for the first ID, then ",?" for each additional ID.
	builder.WriteString("?")
	if len(resourceIds) > 1 {
		builder.WriteString(strings.Repeat(",?", len(resourceIds)-1))
	}
	builder.WriteString(")")

	sqf.queryBuilder = sqf.queryBuilder.Where(builder.String(), args...)
	return sqf, nil
}
+
+// FilterToRelation returns a new SchemaQueryFilterer that is limited to resources with the
+// specified relation.
+func (sqf SchemaQueryFilterer) FilterToRelation(relation string) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColRelation: relation})
+ sqf.recordColumnValue(sqf.schema.ColRelation, relation)
+ return sqf
+}
+
+// MustFilterWithRelationshipsFilter returns a new SchemaQueryFilterer that is limited to resources with
+// resources that match the specified filter.
+func (sqf SchemaQueryFilterer) MustFilterWithRelationshipsFilter(filter datastore.RelationshipsFilter) SchemaQueryFilterer {
+ updated, err := sqf.FilterWithRelationshipsFilter(filter)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
// FilterWithRelationshipsFilter returns a new SchemaQueryFilterer limited to relationships
// matching every populated field of the given filter. Returns an error on contradictory or
// malformed filter combinations.
func (sqf SchemaQueryFilterer) FilterWithRelationshipsFilter(filter datastore.RelationshipsFilter) (SchemaQueryFilterer, error) {
	csqf := sqf

	if filter.OptionalResourceType != "" {
		csqf = csqf.FilterToResourceType(filter.OptionalResourceType)
	}

	if filter.OptionalResourceRelation != "" {
		csqf = csqf.FilterToRelation(filter.OptionalResourceRelation)
	}

	// Exact IDs and an ID prefix are mutually exclusive.
	if len(filter.OptionalResourceIds) > 0 && filter.OptionalResourceIDPrefix != "" {
		return csqf, spiceerrors.MustBugf("cannot filter by both resource IDs and ID prefix")
	}

	if len(filter.OptionalResourceIds) > 0 {
		usqf, err := csqf.FilterToResourceIDs(filter.OptionalResourceIds)
		if err != nil {
			return csqf, err
		}
		csqf = usqf
	}

	if len(filter.OptionalResourceIDPrefix) > 0 {
		usqf, err := csqf.FilterWithResourceIDPrefix(filter.OptionalResourceIDPrefix)
		if err != nil {
			return csqf, err
		}
		csqf = usqf
	}

	if len(filter.OptionalSubjectsSelectors) > 0 {
		usqf, err := csqf.FilterWithSubjectsSelectors(filter.OptionalSubjectsSelectors...)
		if err != nil {
			return csqf, err
		}
		csqf = usqf
	}

	switch filter.OptionalCaveatNameFilter.Option {
	case datastore.CaveatFilterOptionHasMatchingCaveat:
		spiceerrors.DebugAssert(func() bool {
			return filter.OptionalCaveatNameFilter.CaveatName != ""
		}, "caveat name must be set when using HasMatchingCaveat")
		csqf = csqf.FilterWithCaveatName(filter.OptionalCaveatNameFilter.CaveatName)

	case datastore.CaveatFilterOptionNoCaveat:
		csqf = csqf.FilterWithNoCaveat()

	case datastore.CaveatFilterOptionNone:
		// No action needed, as this is the default behavior.

	default:
		return csqf, spiceerrors.MustBugf("unknown caveat filter option: %v", filter.OptionalCaveatNameFilter.Option)
	}

	// Expiration filter: HasExpiration requires a non-NULL expiration column (and a schema
	// with expiration support); NoExpiration requires a NULL one.
	if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionHasExpiration {
		csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.ColExpiration: nil})
		spiceerrors.DebugAssert(func() bool { return !sqf.schema.ExpirationDisabled }, "expiration filter requested but schema does not support expiration")
	} else if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionNoExpiration {
		csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.ColExpiration: nil})
	}

	return csqf, nil
}
+
+// MustFilterWithSubjectsSelectors returns a new SchemaQueryFilterer that is limited to resources with
+// subjects that match the specified selector(s).
+func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) SchemaQueryFilterer {
+ usqf, err := sqf.FilterWithSubjectsSelectors(selectors...)
+ if err != nil {
+ panic(err)
+ }
+ return usqf
+}
+
// FilterWithSubjectsSelectors returns a new SchemaQueryFilterer that is limited to resources with
// subjects that match the specified selector(s). Multiple selectors are OR'ed together; the
// conditions within a selector are AND'ed.
func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) {
	selectorsOrClause := sq.Or{}

	// If there is more than a single filter, record all the subjects as varying, as the subjects returned
	// can differ for each branch.
	// TODO(jschorr): Optimize this further where applicable.
	if len(selectors) > 1 {
		sqf.recordVaryingColumnValue(sqf.schema.ColUsersetNamespace)
		sqf.recordVaryingColumnValue(sqf.schema.ColUsersetObjectID)
		sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation)
	}

	for _, selector := range selectors {
		selectorClause := sq.And{}

		if len(selector.OptionalSubjectType) > 0 {
			selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetNamespace: selector.OptionalSubjectType})
			sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, selector.OptionalSubjectType)
		}

		if len(selector.OptionalSubjectIds) > 0 {
			spiceerrors.DebugAssert(func() bool {
				return len(selector.OptionalSubjectIds) <= int(sqf.filterMaximumIDCount)
			}, "cannot have more than %d subject IDs in a single filter", sqf.filterMaximumIDCount)

			// Hand-build the IN clause so the placeholder count exactly matches the ID count
			// (same approach as FilterToResourceIDs).
			var builder strings.Builder
			builder.WriteString(sqf.schema.ColUsersetObjectID)
			builder.WriteString(" IN (")
			args := make([]any, 0, len(selector.OptionalSubjectIds))

			for _, subjectID := range selector.OptionalSubjectIds {
				if len(subjectID) == 0 {
					return sqf, spiceerrors.MustBugf("got empty subject ID")
				}

				args = append(args, subjectID)
				sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, subjectID)
			}

			builder.WriteString("?")
			if len(selector.OptionalSubjectIds) > 1 {
				builder.WriteString(strings.Repeat(",?", len(selector.OptionalSubjectIds)-1))
			}

			builder.WriteString(")")
			selectorClause = append(selectorClause, sq.Expr(builder.String(), args...))
		}

		if !selector.RelationFilter.IsEmpty() {
			if selector.RelationFilter.OnlyNonEllipsisRelations {
				// Any relation except "..." — an inequality, so the column stays varying.
				selectorClause = append(selectorClause, sq.NotEq{sqf.schema.ColUsersetRelation: datastore.Ellipsis})
				sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation)
			} else {
				// Collect the allowed relations: the ellipsis and/or one named relation.
				relations := make([]string, 0, 2)
				if selector.RelationFilter.IncludeEllipsisRelation {
					relations = append(relations, datastore.Ellipsis)
				}

				if selector.RelationFilter.NonEllipsisRelation != "" {
					relations = append(relations, selector.RelationFilter.NonEllipsisRelation)
				}

				if len(relations) == 1 {
					relName := relations[0]
					selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetRelation: relName})
					sqf.recordColumnValue(sqf.schema.ColUsersetRelation, relName)
				} else {
					orClause := sq.Or{}
					for _, relationName := range relations {
						dsRelationName := stringz.DefaultEmpty(relationName, datastore.Ellipsis)
						orClause = append(orClause, sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName})
						sqf.recordColumnValue(sqf.schema.ColUsersetRelation, dsRelationName)
					}

					selectorClause = append(selectorClause, orClause)
				}
			}
		}

		selectorsOrClause = append(selectorsOrClause, selectorClause)
	}

	sqf.queryBuilder = sqf.queryBuilder.Where(selectorsOrClause)
	return sqf, nil
}
+
+// FilterToSubjectFilter returns a new SchemaQueryFilterer that is limited to resources with
+// subjects that match the specified filter.
+func (sqf SchemaQueryFilterer) FilterToSubjectFilter(filter *v1.SubjectFilter) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetNamespace: filter.SubjectType})
+ sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, filter.SubjectType)
+
+ if filter.OptionalSubjectId != "" {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetObjectID: filter.OptionalSubjectId})
+ sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, filter.OptionalSubjectId)
+ }
+
+ if filter.OptionalRelation != nil {
+ dsRelationName := stringz.DefaultEmpty(filter.OptionalRelation.Relation, datastore.Ellipsis)
+
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName})
+ sqf.recordColumnValue(sqf.schema.ColUsersetRelation, datastore.Ellipsis)
+ }
+
+ return sqf
+}
+
+// FilterWithCaveatName returns a new SchemaQueryFilterer that is limited to resources with the
+// specified caveat name.
+func (sqf SchemaQueryFilterer) FilterWithCaveatName(caveatName string) SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColCaveatName: caveatName})
+ sqf.recordColumnValue(sqf.schema.ColCaveatName, caveatName)
+ return sqf
+}
+
+// FilterWithNoCaveat returns a new SchemaQueryFilterer that is limited to resources with no caveat.
+func (sqf SchemaQueryFilterer) FilterWithNoCaveat() SchemaQueryFilterer {
+ sqf.queryBuilder = sqf.queryBuilder.Where(
+ sq.Or{
+ sq.Eq{sqf.schema.ColCaveatName: nil},
+ sq.Eq{sqf.schema.ColCaveatName: ""},
+ })
+ sqf.recordVaryingColumnValue(sqf.schema.ColCaveatName)
+ return sqf
+}
+
// limit returns a new SchemaQueryFilterer which is limited to the specified number of results.
func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer {
	sqf.queryBuilder = sqf.queryBuilder.Limit(limit)
	return sqf
}
+
+// QueryRelationshipsExecutor is a relationships query runner shared by SQL implementations of the datastore.
+type QueryRelationshipsExecutor struct {
+ Executor ExecuteReadRelsQueryFunc
+}
+
+// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query.
+type ExecuteReadRelsQueryFunc func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error)
+
// ExecuteQuery executes the query. It applies the requested sort order, cursor, limit and
// FROM clause (including any indexing hints) to the filterer, then delegates the rendered
// query to the configured Executor.
func (exc QueryRelationshipsExecutor) ExecuteQuery(
	ctx context.Context,
	query SchemaQueryFilterer,
	opts ...options.QueryOptionsOption,
) (datastore.RelationshipIterator, error) {
	// Custom queries control their own SELECT/FROM and must be run by the caller.
	if query.isCustomQuery {
		return nil, spiceerrors.MustBugf("ExecuteQuery should not be called on custom queries")
	}

	queryOpts := options.NewQueryOptionsWithOptions(opts...)

	// Add sort order.
	query = query.TupleOrder(queryOpts.Sort)

	// Add cursor. A cursor is meaningless without a defined sort order.
	if queryOpts.After != nil {
		if queryOpts.Sort == options.Unsorted {
			return nil, datastore.ErrCursorsWithoutSorting
		}

		q, err := query.After(queryOpts.After, queryOpts.Sort)
		if err != nil {
			return nil, err
		}
		query = q
	}

	// Add limit.
	var limit uint64
	// NOTE: we use a uint here because it lines up with the
	// assignments in this function, but we set it to MaxInt64
	// because that's the biggest value that postgres and friends
	// treat as valid.
	limit = math.MaxInt64
	if queryOpts.Limit != nil {
		limit = *queryOpts.Limit
	}

	// MaxInt64 acts as "no limit"; only emit a LIMIT clause for smaller values.
	if limit < math.MaxInt64 {
		query = query.limit(limit)
	}

	// Add FROM clause, preferring any explicitly overridden table name.
	from := query.schema.RelationshipTableName
	if query.fromTable != "" {
		from = query.fromTable
	}

	// Add index hints, if any. A hint may contribute a statement prefix, replace the
	// FROM table name, and/or append a suffix, depending on the target datastore.
	if query.indexingHint != nil {
		// Check for a SQL prefix (pg_hint_plan).
		sqlPrefix, err := query.indexingHint.SQLPrefix()
		if err != nil {
			return nil, fmt.Errorf("error getting SQL prefix for indexing hint: %w", err)
		}

		if sqlPrefix != "" {
			query.queryBuilder = query.queryBuilder.Prefix(sqlPrefix)
		}

		// Check for an adjusting FROM table name (CRDB).
		fromTableName, err := query.indexingHint.FromTable(from)
		if err != nil {
			return nil, fmt.Errorf("error getting FROM table name for indexing hint: %w", err)
		}
		from = fromTableName

		// Check for a SQL suffix (MySQL, Spanner).
		fromSuffix, err := query.indexingHint.FromSQLSuffix()
		if err != nil {
			return nil, fmt.Errorf("error getting SQL suffix for indexing hint: %w", err)
		}

		if fromSuffix != "" {
			from += " " + fromSuffix
		}
	}

	if query.fromSuffix != "" {
		from += " " + query.fromSuffix
	}

	query.queryBuilder = query.queryBuilder.From(from)

	// Package everything the executor needs to render the final SELECT (see SelectSQL).
	builder := RelationshipsQueryBuilder{
		Schema:                    query.schema,
		SkipCaveats:               queryOpts.SkipCaveats,
		SkipExpiration:            queryOpts.SkipExpiration,
		SQLCheckAssertionForTest:  queryOpts.SQLCheckAssertionForTest,
		SQLExplainCallbackForTest: queryOpts.SQLExplainCallbackForTest,
		filteringValues:           query.filteringColumnTracker,
		queryShape:                queryOpts.QueryShape,
		baseQueryBuilder:          query,
	}

	return exc.Executor(ctx, builder)
}
+
+// RelationshipsQueryBuilder is a builder for producing the SQL and arguments necessary for reading
+// relationships.
+type RelationshipsQueryBuilder struct {
+ Schema SchemaInformation
+ SkipCaveats bool
+ SkipExpiration bool
+
+ filteringValues columnTrackerMap
+ baseQueryBuilder SchemaQueryFilterer
+ SQLCheckAssertionForTest options.SQLCheckAssertionForTest
+ SQLExplainCallbackForTest options.SQLExplainCallbackForTest
+ queryShape queryshape.Shape
+}
+
+// withCaveats returns true if caveats should be included in the query.
+func (b RelationshipsQueryBuilder) withCaveats() bool {
+ return !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone
+}
+
+// withExpiration returns true if expiration should be included in the query.
+func (b RelationshipsQueryBuilder) withExpiration() bool {
+ return !b.SkipExpiration && !b.Schema.ExpirationDisabled
+}
+
// integrityEnabled returns true if integrity columns (key ID, hash, timestamp) should be
// included in the query.
func (b RelationshipsQueryBuilder) integrityEnabled() bool {
	return b.Schema.IntegrityEnabled
}
+
+// columnCount returns the number of columns that will be selected in the query.
+func (b RelationshipsQueryBuilder) columnCount() int {
+ columnCount := relationshipStandardColumnCount
+ if b.withCaveats() {
+ columnCount += relationshipCaveatColumnCount
+ }
+ if b.withExpiration() {
+ columnCount += relationshipExpirationColumnCount
+ }
+ if b.integrityEnabled() {
+ columnCount += relationshipIntegrityColumnCount
+ }
+ return columnCount
+}
+
// SelectSQL returns the SQL and arguments necessary for reading relationships. Columns with
// a single static filter value are omitted from the SELECT (the executor reconstructs them
// via ColumnsToSelect), unless column optimization is disabled.
func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) {
	// Set the column names to select.
	columnNamesToSelect := make([]string, 0, b.columnCount())

	// Core relationship columns: only selected when not statically known (see checkColumn).
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColNamespace)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColObjectID)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColRelation)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetNamespace)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetObjectID)
	columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetRelation)

	if b.withCaveats() {
		columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext)
	}

	if b.withExpiration() {
		columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration)
	}

	if b.integrityEnabled() {
		columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp)
	}

	// If every column is static, select a constant so the statement remains valid SQL.
	if len(columnNamesToSelect) == 0 {
		columnNamesToSelect = append(columnNamesToSelect, "1")
	}

	sqlBuilder := b.baseQueryBuilder.queryBuilderWithMaybeExpirationFilter(b.SkipExpiration)
	sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...)

	sql, args, err := sqlBuilder.ToSql()
	if err != nil {
		return "", nil, err
	}

	if b.SQLCheckAssertionForTest != nil {
		b.SQLCheckAssertionForTest(sql)
	}

	return sql, args, nil
}
+
+// FilteringValuesForTesting returns the filtering values. For test use only.
+func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]columnTracker {
+ return maps.Clone(b.filteringValues)
+}
+
+func (b RelationshipsQueryBuilder) checkColumn(columns []string, colName string) []string {
+ if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone {
+ return append(columns, colName)
+ }
+
+ if !b.filteringValues.hasStaticValue(colName) {
+ return append(columns, colName)
+ }
+
+ return columns
+}
+
+func (b RelationshipsQueryBuilder) staticValueOrAddColumnForSelect(colsToSelect []any, colName string, field *string) []any {
+ if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone {
+ // If column optimization is disabled, always add the column to the list of columns to select.
+ colsToSelect = append(colsToSelect, field)
+ return colsToSelect
+ }
+
+ // If the value is static, set the field to it and return.
+ if found, ok := b.filteringValues[colName]; ok && found.SingleValue != nil {
+ *field = *found.SingleValue
+ return colsToSelect
+ }
+
+ // Otherwise, add the column to the list of columns to select, as the value is not static.
+ colsToSelect = append(colsToSelect, field)
+ return colsToSelect
+}
+
+// ColumnsToSelect returns the columns to select for a given query. The columns provided are
+// the references to the slots in which the values for each relationship will be placed.
+func ColumnsToSelect[CN any, CC any, EC any](
+ b RelationshipsQueryBuilder,
+ resourceObjectType *string,
+ resourceObjectID *string,
+ resourceRelation *string,
+ subjectObjectType *string,
+ subjectObjectID *string,
+ subjectRelation *string,
+ caveatName *CN,
+ caveatCtx *CC,
+ expiration EC,
+
+ integrityKeyID *string,
+ integrityHash *[]byte,
+ timestamp *time.Time,
+) ([]any, error) {
+ colsToSelect := make([]any, 0, b.columnCount())
+
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColNamespace, resourceObjectType)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColObjectID, resourceObjectID)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColRelation, resourceRelation)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetNamespace, subjectObjectType)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetObjectID, subjectObjectID)
+ colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetRelation, subjectRelation)
+
+ if b.withCaveats() {
+ colsToSelect = append(colsToSelect, caveatName, caveatCtx)
+ }
+
+ if b.withExpiration() {
+ colsToSelect = append(colsToSelect, expiration)
+ }
+
+ if b.Schema.IntegrityEnabled {
+ colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp)
+ }
+
+ if len(colsToSelect) == 0 {
+ var unused int64
+ colsToSelect = append(colsToSelect, &unused)
+ }
+
+ return colsToSelect, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go
new file mode 100644
index 0000000..fa23efc
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/sqlerrors.go
@@ -0,0 +1,31 @@
+package common
+
+import (
+ "context"
+ "errors"
+ "strings"
+)
+
+// IsCancellationError determines if an error returned by pgx has been caused by context cancellation.
+func IsCancellationError(err error) bool {
+ if errors.Is(err, context.Canceled) ||
+ errors.Is(err, context.DeadlineExceeded) ||
+ err.Error() == "conn closed" { // conns are sometimes closed async upon cancellation
+ return true
+ }
+ return false
+}
+
+// IsResettableError returns whether the given error is a resettable error.
+func IsResettableError(err error) bool {
+ // detect when an error is likely due to a node taken out of service
+ if strings.Contains(err.Error(), "broken pipe") ||
+ strings.Contains(err.Error(), "unexpected EOF") ||
+ strings.Contains(err.Error(), "conn closed") ||
+ strings.Contains(err.Error(), "connection refused") ||
+ strings.Contains(err.Error(), "connection reset by peer") {
+ return true
+ }
+
+ return false
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go
new file mode 100644
index 0000000..be665ed
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/url.go
@@ -0,0 +1,19 @@
+package common
+
+import (
+ "errors"
+ "net/url"
+)
+
// MetricsIDFromURL extracts the metrics ID from a given datastore URL: the
// host (including port, if any) concatenated with the path.
func MetricsIDFromURL(dsURL string) (string, error) {
	if dsURL == "" {
		return "", errors.New("datastore URL is empty")
	}

	parsed, err := url.Parse(dsURL)
	if err != nil {
		// NOTE(review): the underlying parse error is deliberately not wrapped
		// or returned — presumably to avoid echoing credentials embedded in the
		// DSN back to the caller; confirm before changing.
		return "", errors.New("could not parse datastore URL")
	}

	return parsed.Host + parsed.Path, nil
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go b/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go
new file mode 100644
index 0000000..2caa57a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/common/zz_generated.schema_options.go
@@ -0,0 +1,276 @@
// Code generated by github.com/ecordell/optgen. DO NOT EDIT.
// NOTE(review): this file is generated from the SchemaInformation struct in
// schema.go; regenerate via optgen rather than editing by hand.
package common

import (
	squirrel "github.com/Masterminds/squirrel"
	defaults "github.com/creasty/defaults"
	helpers "github.com/ecordell/optgen/helpers"
)

// SchemaInformationOption mutates a SchemaInformation as part of its functional-options construction.
type SchemaInformationOption func(s *SchemaInformation)

// NewSchemaInformationWithOptions creates a new SchemaInformation with the passed in options set
func NewSchemaInformationWithOptions(opts ...SchemaInformationOption) *SchemaInformation {
	s := &SchemaInformation{}
	for _, o := range opts {
		o(s)
	}
	return s
}

// NewSchemaInformationWithOptionsAndDefaults creates a new SchemaInformation with the passed in options set starting from the defaults
func NewSchemaInformationWithOptionsAndDefaults(opts ...SchemaInformationOption) *SchemaInformation {
	s := &SchemaInformation{}
	defaults.MustSet(s)
	for _, o := range opts {
		o(s)
	}
	return s
}

// ToOption returns a new SchemaInformationOption that sets the values from the passed in SchemaInformation
func (s *SchemaInformation) ToOption() SchemaInformationOption {
	return func(to *SchemaInformation) {
		to.RelationshipTableName = s.RelationshipTableName
		to.ColNamespace = s.ColNamespace
		to.ColObjectID = s.ColObjectID
		to.ColRelation = s.ColRelation
		to.ColUsersetNamespace = s.ColUsersetNamespace
		to.ColUsersetObjectID = s.ColUsersetObjectID
		to.ColUsersetRelation = s.ColUsersetRelation
		to.ColCaveatName = s.ColCaveatName
		to.ColCaveatContext = s.ColCaveatContext
		to.ColExpiration = s.ColExpiration
		to.ColIntegrityKeyID = s.ColIntegrityKeyID
		to.ColIntegrityHash = s.ColIntegrityHash
		to.ColIntegrityTimestamp = s.ColIntegrityTimestamp
		to.Indexes = s.Indexes
		to.PaginationFilterType = s.PaginationFilterType
		to.PlaceholderFormat = s.PlaceholderFormat
		to.NowFunction = s.NowFunction
		to.ColumnOptimization = s.ColumnOptimization
		to.IntegrityEnabled = s.IntegrityEnabled
		to.ExpirationDisabled = s.ExpirationDisabled
		to.SortByResourceColumnOrder = s.SortByResourceColumnOrder
		to.SortBySubjectColumnOrder = s.SortBySubjectColumnOrder
	}
}

// DebugMap returns a map form of SchemaInformation for debugging
func (s SchemaInformation) DebugMap() map[string]any {
	debugMap := map[string]any{}
	debugMap["RelationshipTableName"] = helpers.DebugValue(s.RelationshipTableName, false)
	debugMap["ColNamespace"] = helpers.DebugValue(s.ColNamespace, false)
	debugMap["ColObjectID"] = helpers.DebugValue(s.ColObjectID, false)
	debugMap["ColRelation"] = helpers.DebugValue(s.ColRelation, false)
	debugMap["ColUsersetNamespace"] = helpers.DebugValue(s.ColUsersetNamespace, false)
	debugMap["ColUsersetObjectID"] = helpers.DebugValue(s.ColUsersetObjectID, false)
	debugMap["ColUsersetRelation"] = helpers.DebugValue(s.ColUsersetRelation, false)
	debugMap["ColCaveatName"] = helpers.DebugValue(s.ColCaveatName, false)
	debugMap["ColCaveatContext"] = helpers.DebugValue(s.ColCaveatContext, false)
	debugMap["ColExpiration"] = helpers.DebugValue(s.ColExpiration, false)
	debugMap["ColIntegrityKeyID"] = helpers.DebugValue(s.ColIntegrityKeyID, false)
	debugMap["ColIntegrityHash"] = helpers.DebugValue(s.ColIntegrityHash, false)
	debugMap["ColIntegrityTimestamp"] = helpers.DebugValue(s.ColIntegrityTimestamp, false)
	debugMap["Indexes"] = helpers.DebugValue(s.Indexes, false)
	debugMap["PaginationFilterType"] = helpers.DebugValue(s.PaginationFilterType, false)
	debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false)
	debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false)
	debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false)
	debugMap["IntegrityEnabled"] = helpers.DebugValue(s.IntegrityEnabled, false)
	debugMap["ExpirationDisabled"] = helpers.DebugValue(s.ExpirationDisabled, false)
	debugMap["SortByResourceColumnOrder"] = helpers.DebugValue(s.SortByResourceColumnOrder, false)
	debugMap["SortBySubjectColumnOrder"] = helpers.DebugValue(s.SortBySubjectColumnOrder, false)
	return debugMap
}

// SchemaInformationWithOptions configures an existing SchemaInformation with the passed in options set
func SchemaInformationWithOptions(s *SchemaInformation, opts ...SchemaInformationOption) *SchemaInformation {
	for _, o := range opts {
		o(s)
	}
	return s
}

// WithOptions configures the receiver SchemaInformation with the passed in options set
func (s *SchemaInformation) WithOptions(opts ...SchemaInformationOption) *SchemaInformation {
	for _, o := range opts {
		o(s)
	}
	return s
}

// WithRelationshipTableName returns an option that can set RelationshipTableName on a SchemaInformation
func WithRelationshipTableName(relationshipTableName string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.RelationshipTableName = relationshipTableName
	}
}

// WithColNamespace returns an option that can set ColNamespace on a SchemaInformation
func WithColNamespace(colNamespace string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColNamespace = colNamespace
	}
}

// WithColObjectID returns an option that can set ColObjectID on a SchemaInformation
func WithColObjectID(colObjectID string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColObjectID = colObjectID
	}
}

// WithColRelation returns an option that can set ColRelation on a SchemaInformation
func WithColRelation(colRelation string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColRelation = colRelation
	}
}

// WithColUsersetNamespace returns an option that can set ColUsersetNamespace on a SchemaInformation
func WithColUsersetNamespace(colUsersetNamespace string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColUsersetNamespace = colUsersetNamespace
	}
}

// WithColUsersetObjectID returns an option that can set ColUsersetObjectID on a SchemaInformation
func WithColUsersetObjectID(colUsersetObjectID string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColUsersetObjectID = colUsersetObjectID
	}
}

// WithColUsersetRelation returns an option that can set ColUsersetRelation on a SchemaInformation
func WithColUsersetRelation(colUsersetRelation string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColUsersetRelation = colUsersetRelation
	}
}

// WithColCaveatName returns an option that can set ColCaveatName on a SchemaInformation
func WithColCaveatName(colCaveatName string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColCaveatName = colCaveatName
	}
}

// WithColCaveatContext returns an option that can set ColCaveatContext on a SchemaInformation
func WithColCaveatContext(colCaveatContext string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColCaveatContext = colCaveatContext
	}
}

// WithColExpiration returns an option that can set ColExpiration on a SchemaInformation
func WithColExpiration(colExpiration string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColExpiration = colExpiration
	}
}

// WithColIntegrityKeyID returns an option that can set ColIntegrityKeyID on a SchemaInformation
func WithColIntegrityKeyID(colIntegrityKeyID string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColIntegrityKeyID = colIntegrityKeyID
	}
}

// WithColIntegrityHash returns an option that can set ColIntegrityHash on a SchemaInformation
func WithColIntegrityHash(colIntegrityHash string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColIntegrityHash = colIntegrityHash
	}
}

// WithColIntegrityTimestamp returns an option that can set ColIntegrityTimestamp on a SchemaInformation
func WithColIntegrityTimestamp(colIntegrityTimestamp string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColIntegrityTimestamp = colIntegrityTimestamp
	}
}

// WithIndexes returns an option that can append Indexess to SchemaInformation.Indexes
func WithIndexes(indexes IndexDefinition) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.Indexes = append(s.Indexes, indexes)
	}
}

// SetIndexes returns an option that can set Indexes on a SchemaInformation
func SetIndexes(indexes []IndexDefinition) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.Indexes = indexes
	}
}

// WithPaginationFilterType returns an option that can set PaginationFilterType on a SchemaInformation
func WithPaginationFilterType(paginationFilterType PaginationFilterType) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.PaginationFilterType = paginationFilterType
	}
}

// WithPlaceholderFormat returns an option that can set PlaceholderFormat on a SchemaInformation
func WithPlaceholderFormat(placeholderFormat squirrel.PlaceholderFormat) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.PlaceholderFormat = placeholderFormat
	}
}

// WithNowFunction returns an option that can set NowFunction on a SchemaInformation
func WithNowFunction(nowFunction string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.NowFunction = nowFunction
	}
}

// WithColumnOptimization returns an option that can set ColumnOptimization on a SchemaInformation
func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ColumnOptimization = columnOptimization
	}
}

// WithIntegrityEnabled returns an option that can set IntegrityEnabled on a SchemaInformation
func WithIntegrityEnabled(integrityEnabled bool) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.IntegrityEnabled = integrityEnabled
	}
}

// WithExpirationDisabled returns an option that can set ExpirationDisabled on a SchemaInformation
func WithExpirationDisabled(expirationDisabled bool) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.ExpirationDisabled = expirationDisabled
	}
}

// WithSortByResourceColumnOrder returns an option that can append SortByResourceColumnOrders to SchemaInformation.SortByResourceColumnOrder
func WithSortByResourceColumnOrder(sortByResourceColumnOrder string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.SortByResourceColumnOrder = append(s.SortByResourceColumnOrder, sortByResourceColumnOrder)
	}
}

// SetSortByResourceColumnOrder returns an option that can set SortByResourceColumnOrder on a SchemaInformation
func SetSortByResourceColumnOrder(sortByResourceColumnOrder []string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.SortByResourceColumnOrder = sortByResourceColumnOrder
	}
}

// WithSortBySubjectColumnOrder returns an option that can append SortBySubjectColumnOrders to SchemaInformation.SortBySubjectColumnOrder
func WithSortBySubjectColumnOrder(sortBySubjectColumnOrder string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.SortBySubjectColumnOrder = append(s.SortBySubjectColumnOrder, sortBySubjectColumnOrder)
	}
}

// SetSortBySubjectColumnOrder returns an option that can set SortBySubjectColumnOrder on a SchemaInformation
func SetSortBySubjectColumnOrder(sortBySubjectColumnOrder []string) SchemaInformationOption {
	return func(s *SchemaInformation) {
		s.SortBySubjectColumnOrder = sortBySubjectColumnOrder
	}
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md
new file mode 100644
index 0000000..de32e34
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/README.md
@@ -0,0 +1,23 @@
+# MemDB Datastore Implementation
+
+The `memdb` datastore implementation is based on Hashicorp's [go-memdb library](https://github.com/hashicorp/go-memdb).
+Its implementation most closely mimics that of `spanner` or `crdb`, where there is a single immutable datastore that supports querying at any point in time.
+The `memdb` datastore is used for validating and rapidly iterating on concepts from consumers of other datastores.
+It is 100% compliant with the datastore acceptance test suite and it should be possible to use it in place of any other datastore for development purposes.
+Differences between the `memdb` datastore and other implementations that manifest themselves as differences visible to the caller should be reported as bugs.
+
+**The memdb datastore can NOT be used in a production setting!**
+
+## Implementation Caveats
+
+### No Garbage Collection
+
+This implementation of the datastore has no garbage collection, meaning that memory usage will grow monotonically with mutations.
+
+### No Durable Storage
+
+The `memdb` datastore, as its name implies, stores information entirely in memory, and therefore will lose all data when the host process terminates.
+
+### Cannot be used for multi-node dispatch
+
+If you attempt to run SpiceDB with multi-node dispatch enabled using the memory datastore, each independent node will get a separate copy of the datastore, and you will end up very confused.
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go
new file mode 100644
index 0000000..2b4baca
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/caveat.go
@@ -0,0 +1,156 @@
+package memdb
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/hashicorp/go-memdb"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+)
+
const tableCaveats = "caveats"

// caveat is the storage representation of a caveat definition: the name, the
// marshalled core.CaveatDefinition bytes, and the revision at which it was
// last written.
type caveat struct {
	name       string
	definition []byte // VT-marshalled core.CaveatDefinition
	revision   datastore.Revision
}

// Unwrap unmarshals the stored bytes back into a core.CaveatDefinition.
func (c *caveat) Unwrap() (*core.CaveatDefinition, error) {
	definition := core.CaveatDefinition{}
	err := definition.UnmarshalVT(c.definition)
	return &definition, err
}
+
+func (r *memdbReader) ReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
+ r.mustLock()
+ defer r.Unlock()
+
+ tx, err := r.txSource()
+ if err != nil {
+ return nil, datastore.NoRevision, err
+ }
+ return r.readUnwrappedCaveatByName(tx, name)
+}
+
+func (r *memdbReader) readCaveatByName(tx *memdb.Txn, name string) (*caveat, datastore.Revision, error) {
+ found, err := tx.First(tableCaveats, indexID, name)
+ if err != nil {
+ return nil, datastore.NoRevision, err
+ }
+ if found == nil {
+ return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name)
+ }
+ cvt := found.(*caveat)
+ return cvt, cvt.revision, nil
+}
+
+func (r *memdbReader) readUnwrappedCaveatByName(tx *memdb.Txn, name string) (*core.CaveatDefinition, datastore.Revision, error) {
+ c, rev, err := r.readCaveatByName(tx, name)
+ if err != nil {
+ return nil, datastore.NoRevision, err
+ }
+ unwrapped, err := c.Unwrap()
+ if err != nil {
+ return nil, datastore.NoRevision, err
+ }
+ return unwrapped, rev, nil
+}
+
+func (r *memdbReader) ListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) {
+ r.mustLock()
+ defer r.Unlock()
+
+ tx, err := r.txSource()
+ if err != nil {
+ return nil, err
+ }
+
+ var caveats []datastore.RevisionedCaveat
+ it, err := tx.LowerBound(tableCaveats, indexID)
+ if err != nil {
+ return nil, err
+ }
+
+ for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+ rawCaveat := foundRaw.(*caveat)
+ definition, err := rawCaveat.Unwrap()
+ if err != nil {
+ return nil, err
+ }
+ caveats = append(caveats, datastore.RevisionedCaveat{
+ Definition: definition,
+ LastWrittenRevision: rawCaveat.revision,
+ })
+ }
+
+ return caveats, nil
+}
+
+func (r *memdbReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
+ allCaveats, err := r.ListAllCaveats(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ allowedCaveatNames := mapz.NewSet[string]()
+ allowedCaveatNames.Extend(caveatNames)
+
+ toReturn := make([]datastore.RevisionedCaveat, 0, len(caveatNames))
+ for _, caveat := range allCaveats {
+ if allowedCaveatNames.Has(caveat.Definition.Name) {
+ toReturn = append(toReturn, caveat)
+ }
+ }
+ return toReturn, nil
+}
+
+func (rwt *memdbReadWriteTx) WriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error {
+ rwt.mustLock()
+ defer rwt.Unlock()
+ tx, err := rwt.txSource()
+ if err != nil {
+ return err
+ }
+ return rwt.writeCaveat(tx, caveats)
+}
+
+func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.CaveatDefinition) error {
+ caveatNames := mapz.NewSet[string]()
+ for _, coreCaveat := range caveats {
+ if !caveatNames.Add(coreCaveat.Name) {
+ return fmt.Errorf("duplicate caveat %s", coreCaveat.Name)
+ }
+ marshalled, err := coreCaveat.MarshalVT()
+ if err != nil {
+ return err
+ }
+ c := caveat{
+ name: coreCaveat.Name,
+ definition: marshalled,
+ revision: rwt.newRevision,
+ }
+ if err := tx.Insert(tableCaveats, &c); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (rwt *memdbReadWriteTx) DeleteCaveats(_ context.Context, names []string) error {
+ rwt.mustLock()
+ defer rwt.Unlock()
+ tx, err := rwt.txSource()
+ if err != nil {
+ return err
+ }
+ for _, name := range names {
+ if err := tx.Delete(tableCaveats, caveat{name: name}); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go
new file mode 100644
index 0000000..0ef4b8b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/errors.go
@@ -0,0 +1,37 @@
+package memdb
+
+import (
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
// SerializationMaxRetriesReachedError occurs when a write request has reached its maximum number
// of retries due to serialization errors.
type SerializationMaxRetriesReachedError struct {
	error // embedded base error; its Error() message is forwarded unchanged
}
+
+// NewSerializationMaxRetriesReachedErr constructs a new max retries reached error.
+func NewSerializationMaxRetriesReachedErr(baseErr error) error {
+ return SerializationMaxRetriesReachedError{
+ error: baseErr,
+ }
+}
+
// GRPCStatus implements retrieving the gRPC status for the error.
// The error is surfaced as codes.DeadlineExceeded with a detail explaining the
// memdb datastore's limited write throughput.
func (err SerializationMaxRetriesReachedError) GRPCStatus() *status.Status {
	return spiceerrors.WithCodeAndDetails(
		err,
		codes.DeadlineExceeded,
		spiceerrors.ForReason(
			v1.ErrorReason_ERROR_REASON_UNSPECIFIED,
			map[string]string{
				"details": "too many updates were made to the in-memory datastore at once; this datastore has limited write throughput capability",
			},
		),
	)
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go
new file mode 100644
index 0000000..61eba84
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/memdb.go
@@ -0,0 +1,386 @@
+package memdb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/authzed/spicedb/internal/datastore/common"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+
+ "github.com/google/uuid"
+ "github.com/hashicorp/go-memdb"
+
+ "github.com/authzed/spicedb/internal/datastore/revisions"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
+)
+
const (
	// Engine is the datastore engine name for the memdb datastore.
	Engine = "memory"
	// defaultWatchBufferLength is used when the caller passes a watch buffer length of 0.
	defaultWatchBufferLength = 128
	// numAttempts is the maximum number of retries for serialization conflicts in ReadWriteTx.
	numAttempts = 10
)

var (
	// ErrMemDBIsClosed is returned for operations attempted after Close.
	ErrMemDBIsClosed = errors.New("datastore is closed")
	// ErrSerialization is returned when a write conflicts with the single active write transaction.
	ErrSerialization = errors.New("serialization error")
)

// DisableGC is a convenient constant for setting the garbage collection
// interval high enough that it will never run.
const DisableGC = time.Duration(math.MaxInt64)
+
// NewMemdbDatastore creates a new Datastore compliant datastore backed by memdb.
//
// If the watchBufferLength value of 0 is set then a default value of 128 will be used.
func NewMemdbDatastore(
	watchBufferLength uint16,
	revisionQuantization,
	gcWindow time.Duration,
) (datastore.Datastore, error) {
	// Quantized revisions must fall within the GC window to remain readable.
	if revisionQuantization > gcWindow {
		return nil, errors.New("gc window must be larger than quantization interval")
	}

	// Clamp quantization to a minimum of one nanosecond.
	if revisionQuantization <= 1 {
		revisionQuantization = 1
	}

	db, err := memdb.NewMemDB(schema)
	if err != nil {
		return nil, err
	}

	if watchBufferLength == 0 {
		watchBufferLength = defaultWatchBufferLength
	}

	uniqueID := uuid.NewString()
	return &memdbDatastore{
		CommonDecoder: revisions.CommonDecoder{
			Kind: revisions.Timestamp,
		},
		db: db,
		// Seed the revision history with a snapshot at "now" so the datastore
		// is immediately readable.
		revisions: []snapshot{
			{
				revision: nowRevision(),
				db:       db,
			},
		},

		negativeGCWindow:        gcWindow.Nanoseconds() * -1,
		quantizationPeriod:      revisionQuantization.Nanoseconds(),
		watchBufferLength:       watchBufferLength,
		watchBufferWriteTimeout: 100 * time.Millisecond,
		uniqueID:                uniqueID,
	}, nil
}
+
// memdbDatastore is an in-memory datastore.Datastore implementation backed by
// hashicorp/go-memdb. It keeps a slice of immutable snapshots (one per write)
// and permits at most one active write transaction at a time.
type memdbDatastore struct {
	sync.RWMutex
	revisions.CommonDecoder

	// NOTE: call checkNotClosed before using
	db             *memdb.MemDB // GUARDED_BY(RWMutex)
	revisions      []snapshot   // GUARDED_BY(RWMutex)
	activeWriteTxn *memdb.Txn   // GUARDED_BY(RWMutex)

	negativeGCWindow        int64 // GC window in nanoseconds, negated
	quantizationPeriod      int64 // revision quantization in nanoseconds
	watchBufferLength       uint16
	watchBufferWriteTimeout time.Duration
	uniqueID                string
}
+
// snapshot pairs an immutable view of the database with the revision at which
// it was taken.
type snapshot struct {
	revision revisions.TimestampRevision
	db       *memdb.MemDB
}
+
// MetricsID returns the fixed identifier under which this datastore reports metrics.
func (mdb *memdbDatastore) MetricsID() (string, error) {
	return "memdb", nil
}
+
// SnapshotReader returns a point-in-time reader for the given revision. Any
// error (closed datastore, not ready, invalid revision) is deferred into the
// returned reader rather than returned directly.
func (mdb *memdbDatastore) SnapshotReader(dr datastore.Revision) datastore.Reader {
	mdb.RLock()
	defer mdb.RUnlock()

	if err := mdb.checkNotClosed(); err != nil {
		return &memdbReader{nil, nil, err, time.Now()}
	}

	if len(mdb.revisions) == 0 {
		return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is not ready"), time.Now()}
	}

	if err := mdb.checkRevisionLocalCallerMustLock(dr); err != nil {
		return &memdbReader{nil, nil, err, time.Now()}
	}

	// Binary search for the first snapshot at or after the requested revision.
	revIndex := sort.Search(len(mdb.revisions), func(i int) bool {
		return mdb.revisions[i].revision.GreaterThan(dr) || mdb.revisions[i].revision.Equal(dr)
	})

	// handle the case when there is no revision snapshot newer than the requested revision
	if revIndex == len(mdb.revisions) {
		revIndex = len(mdb.revisions) - 1
	}

	rev := mdb.revisions[revIndex]
	if rev.db == nil {
		return &memdbReader{nil, nil, fmt.Errorf("memdb datastore is already closed"), time.Now()}
	}

	// Read-only transactions over an immutable snapshot need no locking, so a
	// no-op locker is used.
	roTxn := rev.db.Txn(false)
	txSrc := func() (*memdb.Txn, error) {
		return roTxn, nil
	}

	return &memdbReader{noopTryLocker{}, txSrc, nil, time.Now()}
}
+
// SupportsIntegrity indicates that this datastore can store and return
// relationship integrity information.
func (mdb *memdbDatastore) SupportsIntegrity() bool {
	return true
}
+
+func (mdb *memdbDatastore) ReadWriteTx(
+ ctx context.Context,
+ f datastore.TxUserFunc,
+ opts ...options.RWTOptionsOption,
+) (datastore.Revision, error) {
+ config := options.NewRWTOptionsWithOptions(opts...)
+ txNumAttempts := numAttempts
+ if config.DisableRetries {
+ txNumAttempts = 1
+ }
+
+ for i := 0; i < txNumAttempts; i++ {
+ var tx *memdb.Txn
+ createTxOnce := sync.Once{}
+ txSrc := func() (*memdb.Txn, error) {
+ var err error
+ createTxOnce.Do(func() {
+ mdb.Lock()
+ defer mdb.Unlock()
+
+ if mdb.activeWriteTxn != nil {
+ err = ErrSerialization
+ return
+ }
+
+ if err = mdb.checkNotClosed(); err != nil {
+ return
+ }
+
+ tx = mdb.db.Txn(true)
+ tx.TrackChanges()
+ mdb.activeWriteTxn = tx
+ })
+
+ return tx, err
+ }
+
+ newRevision := mdb.newRevisionID()
+ rwt := &memdbReadWriteTx{memdbReader{&sync.Mutex{}, txSrc, nil, time.Now()}, newRevision}
+ if err := f(ctx, rwt); err != nil {
+ mdb.Lock()
+ if tx != nil {
+ tx.Abort()
+ mdb.activeWriteTxn = nil
+ }
+
+ // If the error was a serialization error, retry the transaction
+ if errors.Is(err, ErrSerialization) {
+ mdb.Unlock()
+
+ // If we don't sleep here, we run out of retries instantaneously
+ time.Sleep(1 * time.Millisecond)
+ continue
+ }
+ defer mdb.Unlock()
+
+ // We *must* return the inner error unmodified in case it's not an error type
+ // that supports unwrapping (e.g. gRPC errors)
+ return datastore.NoRevision, err
+ }
+
+ mdb.Lock()
+ defer mdb.Unlock()
+
+ tracked := common.NewChanges(revisions.TimestampIDKeyFunc, datastore.WatchRelationships|datastore.WatchSchema, 0)
+ if tx != nil {
+ if config.Metadata != nil && len(config.Metadata.GetFields()) > 0 {
+ if err := tracked.SetRevisionMetadata(ctx, newRevision, config.Metadata.AsMap()); err != nil {
+ return datastore.NoRevision, err
+ }
+ }
+
+ for _, change := range tx.Changes() {
+ switch change.Table {
+ case tableRelationship:
+ if change.After != nil {
+ rt, err := change.After.(*relationship).Relationship()
+ if err != nil {
+ return datastore.NoRevision, err
+ }
+
+ if err := tracked.AddRelationshipChange(ctx, newRevision, rt, tuple.UpdateOperationTouch); err != nil {
+ return datastore.NoRevision, err
+ }
+ } else if change.After == nil && change.Before != nil {
+ rt, err := change.Before.(*relationship).Relationship()
+ if err != nil {
+ return datastore.NoRevision, err
+ }
+
+ if err := tracked.AddRelationshipChange(ctx, newRevision, rt, tuple.UpdateOperationDelete); err != nil {
+ return datastore.NoRevision, err
+ }
+ } else {
+ return datastore.NoRevision, spiceerrors.MustBugf("unexpected relationship change")
+ }
+ case tableNamespace:
+ if change.After != nil {
+ loaded := &corev1.NamespaceDefinition{}
+ if err := loaded.UnmarshalVT(change.After.(*namespace).configBytes); err != nil {
+ return datastore.NoRevision, err
+ }
+
+ err := tracked.AddChangedDefinition(ctx, newRevision, loaded)
+ if err != nil {
+ return datastore.NoRevision, err
+ }
+ } else if change.After == nil && change.Before != nil {
+ err := tracked.AddDeletedNamespace(ctx, newRevision, change.Before.(*namespace).name)
+ if err != nil {
+ return datastore.NoRevision, err
+ }
+ } else {
+ return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change")
+ }
+ case tableCaveats:
+ if change.After != nil {
+ loaded := &corev1.CaveatDefinition{}
+ if err := loaded.UnmarshalVT(change.After.(*caveat).definition); err != nil {
+ return datastore.NoRevision, err
+ }
+
+ err := tracked.AddChangedDefinition(ctx, newRevision, loaded)
+ if err != nil {
+ return datastore.NoRevision, err
+ }
+ } else if change.After == nil && change.Before != nil {
+ err := tracked.AddDeletedCaveat(ctx, newRevision, change.Before.(*caveat).name)
+ if err != nil {
+ return datastore.NoRevision, err
+ }
+ } else {
+ return datastore.NoRevision, spiceerrors.MustBugf("unexpected namespace change")
+ }
+ }
+ }
+
+ var rc datastore.RevisionChanges
+ changes, err := tracked.AsRevisionChanges(revisions.TimestampIDKeyLessThanFunc)
+ if err != nil {
+ return datastore.NoRevision, err
+ }
+
+ if len(changes) > 1 {
+ return datastore.NoRevision, spiceerrors.MustBugf("unexpected MemDB transaction with multiple revision changes")
+ } else if len(changes) == 1 {
+ rc = changes[0]
+ }
+
+ change := &changelog{
+ revisionNanos: newRevision.TimestampNanoSec(),
+ changes: rc,
+ }
+ if err := tx.Insert(tableChangelog, change); err != nil {
+ return datastore.NoRevision, fmt.Errorf("error writing changelog: %w", err)
+ }
+
+ tx.Commit()
+ }
+ mdb.activeWriteTxn = nil
+
+ if err := mdb.checkNotClosed(); err != nil {
+ return datastore.NoRevision, err
+ }
+
+ // Create a snapshot and add it to the revisions slice
+ snap := mdb.db.Snapshot()
+ mdb.revisions = append(mdb.revisions, snapshot{newRevision, snap})
+ return newRevision, nil
+ }
+
+ return datastore.NoRevision, NewSerializationMaxRetriesReachedErr(errors.New("serialization max retries exceeded; please reduce your parallel writes"))
+}
+
// ReadyState reports whether the datastore holds its initial revision and is
// therefore usable.
func (mdb *memdbDatastore) ReadyState(_ context.Context) (datastore.ReadyState, error) {
	mdb.RLock()
	defer mdb.RUnlock()

	return datastore.ReadyState{
		// NOTE(review): Message is populated even when IsReady is true;
		// presumably callers only surface it when IsReady is false — confirm.
		Message: "missing expected initial revision",
		IsReady: len(mdb.revisions) > 0,
	}, nil
}
+
// OfflineFeatures returns the feature set of the memdb datastore without
// needing a live connection: watch and integrity data are supported, while
// continuous checkpointing and immediate watch emission are not.
func (mdb *memdbDatastore) OfflineFeatures() (*datastore.Features, error) {
	return &datastore.Features{
		Watch: datastore.Feature{
			Status: datastore.FeatureSupported,
		},
		IntegrityData: datastore.Feature{
			Status: datastore.FeatureSupported,
		},
		ContinuousCheckpointing: datastore.Feature{
			Status: datastore.FeatureUnsupported,
		},
		WatchEmitsImmediately: datastore.Feature{
			Status: datastore.FeatureUnsupported,
		},
	}, nil
}
+
// Features returns the features supported by this datastore; for memdb these
// are identical to the offline feature set.
func (mdb *memdbDatastore) Features(_ context.Context) (*datastore.Features, error) {
	return mdb.OfflineFeatures()
}
+
// Close marks the datastore as closed. The last database value is kept in the
// revisions slice, which appears intended to let readers holding existing
// snapshots continue to work after Close.
func (mdb *memdbDatastore) Close() error {
	mdb.Lock()
	defer mdb.Unlock()

	if db := mdb.db; db != nil {
		// Retain a single snapshot at the current time for outstanding readers.
		mdb.revisions = []snapshot{
			{
				revision: nowRevision(),
				db:       db,
			},
		}
	} else {
		mdb.revisions = []snapshot{}
	}

	// A nil db causes checkNotClosed to fail all subsequent operations.
	mdb.db = nil

	return nil
}
+
// checkNotClosed returns ErrMemDBIsClosed once Close has been called.
// This code assumes that the RWMutex has been acquired.
func (mdb *memdbDatastore) checkNotClosed() error {
	if mdb.db == nil {
		return ErrMemDBIsClosed
	}
	return nil
}
+
+var _ datastore.Datastore = &memdbDatastore{}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go
new file mode 100644
index 0000000..fdd224a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readonly.go
@@ -0,0 +1,597 @@
+package memdb
+
+import (
+ "context"
+ "fmt"
+ "slices"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/hashicorp/go-memdb"
+
+ "github.com/authzed/spicedb/internal/datastore/common"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// txFactory lazily produces the memdb transaction a reader operates against.
+type txFactory func() (*memdb.Txn, error)
+
+// memdbReader implements datastore.Reader over a go-memdb transaction.
+// The embedded TryLocker guards against concurrent use of a single reader;
+// now is the wall-clock time used to filter out expired relationships.
+type memdbReader struct {
+	TryLocker
+	txSource txFactory
+	initErr  error
+	now      time.Time
+}
+
+// CountRelationships counts the relationships matching the filter registered
+// under the counter named `name`, by fully iterating a relationship query.
+// Returns a CounterNotRegisteredErr if no counter with that name exists.
+func (r *memdbReader) CountRelationships(ctx context.Context, name string) (int, error) {
+	counters, err := r.LookupCounters(ctx)
+	if err != nil {
+		return 0, err
+	}
+
+	// Find the registered filter for the requested counter name.
+	var found *core.RelationshipFilter
+	for _, counter := range counters {
+		if counter.Name == name {
+			found = counter.Filter
+			break
+		}
+	}
+
+	if found == nil {
+		return 0, datastore.NewCounterNotRegisteredErr(name)
+	}
+
+	coreFilter, err := datastore.RelationshipsFilterFromCoreFilter(found)
+	if err != nil {
+		return 0, err
+	}
+
+	iter, err := r.QueryRelationships(ctx, coreFilter)
+	if err != nil {
+		return 0, err
+	}
+
+	// Count by exhausting the iterator; memdb has no native COUNT.
+	count := 0
+	for _, err := range iter {
+		if err != nil {
+			return 0, err
+		}
+
+		count++
+	}
+	return count, nil
+}
+
+// LookupCounters returns all registered relationship counters, decoding each
+// stored filter from its serialized protobuf form.
+func (r *memdbReader) LookupCounters(ctx context.Context) ([]datastore.RelationshipCounter, error) {
+	if r.initErr != nil {
+		return nil, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	var counters []datastore.RelationshipCounter
+
+	// LowerBound with no value iterates the whole counters table in id order.
+	it, err := tx.LowerBound(tableCounters, indexID)
+	if err != nil {
+		return nil, err
+	}
+
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		found := foundRaw.(*counter)
+
+		loaded := &core.RelationshipFilter{}
+		if err := loaded.UnmarshalVT(found.filterBytes); err != nil {
+			return nil, err
+		}
+
+		counters = append(counters, datastore.RelationshipCounter{
+			Name:               found.name,
+			Filter:             loaded,
+			Count:              found.count,
+			ComputedAtRevision: found.updated,
+		})
+	}
+
+	return counters, nil
+}
+
+// QueryRelationships reads relationships starting from the resource side.
+// An index-backed iterator is chosen from the filter's resource type/relation,
+// then remaining filter criteria (and any cursor) are applied in-memory via a
+// FilterIterator. Cursors require an explicit sort order.
+func (r *memdbReader) QueryRelationships(
+	_ context.Context,
+	filter datastore.RelationshipsFilter,
+	opts ...options.QueryOptionsOption,
+) (datastore.RelationshipIterator, error) {
+	if r.initErr != nil {
+		return nil, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	queryOpts := options.NewQueryOptionsWithOptions(opts...)
+
+	bestIterator, err := iteratorForFilter(tx, filter)
+	if err != nil {
+		return nil, err
+	}
+
+	if queryOpts.After != nil && queryOpts.Sort == options.Unsorted {
+		return nil, datastore.ErrCursorsWithoutSorting
+	}
+
+	matchingRelationshipsFilterFunc := filterFuncForFilters(
+		filter.OptionalResourceType,
+		filter.OptionalResourceIds,
+		filter.OptionalResourceIDPrefix,
+		filter.OptionalResourceRelation,
+		filter.OptionalSubjectsSelectors,
+		filter.OptionalCaveatNameFilter,
+		filter.OptionalExpirationOption,
+		makeCursorFilterFn(queryOpts.After, queryOpts.Sort),
+	)
+	filteredIterator := memdb.NewFilterIterator(bestIterator, matchingRelationshipsFilterFunc)
+
+	switch queryOpts.Sort {
+	// Index order already matches ByResource, so Unsorted shares that path.
+	case options.Unsorted:
+		fallthrough
+
+	case options.ByResource:
+		iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration)
+		return iter, nil
+
+	case options.BySubject:
+		return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration)
+
+	default:
+		return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort)
+	}
+}
+
+// ReverseQueryRelationships reads relationships starting from the subject.
+// The subject-namespace index seeds the scan; the optional resource relation
+// filter and the subjects filter are then applied in-memory.
+func (r *memdbReader) ReverseQueryRelationships(
+	_ context.Context,
+	subjectsFilter datastore.SubjectsFilter,
+	opts ...options.ReverseQueryOptionsOption,
+) (datastore.RelationshipIterator, error) {
+	if r.initErr != nil {
+		return nil, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	queryOpts := options.NewReverseQueryOptionsWithOptions(opts...)
+
+	iterator, err := tx.Get(
+		tableRelationship,
+		indexSubjectNamespace,
+		subjectsFilter.SubjectType,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	// Narrow by resource type/relation only when the caller asked for it.
+	filterObjectType, filterRelation := "", ""
+	if queryOpts.ResRelation != nil {
+		filterObjectType = queryOpts.ResRelation.Namespace
+		filterRelation = queryOpts.ResRelation.Relation
+	}
+
+	matchingRelationshipsFilterFunc := filterFuncForFilters(
+		filterObjectType,
+		nil,
+		"",
+		filterRelation,
+		[]datastore.SubjectsSelector{subjectsFilter.AsSelector()},
+		datastore.CaveatNameFilter{},
+		datastore.ExpirationFilterOptionNone,
+		makeCursorFilterFn(queryOpts.AfterForReverse, queryOpts.SortForReverse),
+	)
+	filteredIterator := memdb.NewFilterIterator(iterator, matchingRelationshipsFilterFunc)
+
+	switch queryOpts.SortForReverse {
+	case options.Unsorted:
+		fallthrough
+
+	case options.ByResource:
+		iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false)
+		return iter, nil
+
+	case options.BySubject:
+		return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false)
+
+	default:
+		return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.SortForReverse)
+	}
+}
+
+// ReadNamespaceByName reads a namespace definition and version and returns it,
+// and the revision at which it was created or last written, if found.
+// Returns NamespaceNotFoundErr when no namespace with that name exists.
+func (r *memdbReader) ReadNamespaceByName(_ context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) {
+	if r.initErr != nil {
+		return nil, datastore.NoRevision, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, datastore.NoRevision, err
+	}
+
+	foundRaw, err := tx.First(tableNamespace, indexID, nsName)
+	if err != nil {
+		return nil, datastore.NoRevision, err
+	}
+
+	if foundRaw == nil {
+		return nil, datastore.NoRevision, datastore.NewNamespaceNotFoundErr(nsName)
+	}
+
+	found := foundRaw.(*namespace)
+
+	// Definitions are stored serialized; decode on every read.
+	loaded := &core.NamespaceDefinition{}
+	if err := loaded.UnmarshalVT(found.configBytes); err != nil {
+		return nil, datastore.NoRevision, err
+	}
+
+	return loaded, found.updated, nil
+}
+
+// ListAllNamespaces lists all namespaces defined, each paired with the
+// revision at which it was last written.
+func (r *memdbReader) ListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) {
+	if r.initErr != nil {
+		return nil, r.initErr
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	var nsDefs []datastore.RevisionedNamespace
+
+	// Scan the entire namespace table in id order.
+	it, err := tx.LowerBound(tableNamespace, indexID)
+	if err != nil {
+		return nil, err
+	}
+
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		found := foundRaw.(*namespace)
+
+		loaded := &core.NamespaceDefinition{}
+		if err := loaded.UnmarshalVT(found.configBytes); err != nil {
+			return nil, err
+		}
+
+		nsDefs = append(nsDefs, datastore.RevisionedNamespace{
+			Definition:          loaded,
+			LastWrittenRevision: found.updated,
+		})
+	}
+
+	return nsDefs, nil
+}
+
+// LookupNamespacesWithNames returns the revisioned definitions for the given
+// namespace names. Names that do not exist are silently omitted; an empty
+// input yields nil without touching the store.
+func (r *memdbReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) {
+	if r.initErr != nil {
+		return nil, r.initErr
+	}
+
+	if len(nsNames) == 0 {
+		return nil, nil
+	}
+
+	r.mustLock()
+	defer r.Unlock()
+
+	tx, err := r.txSource()
+	if err != nil {
+		return nil, err
+	}
+
+	it, err := tx.LowerBound(tableNamespace, indexID)
+	if err != nil {
+		return nil, err
+	}
+
+	// Set of requested names for O(1) membership checks during the scan.
+	nsNameMap := make(map[string]struct{}, len(nsNames))
+	for _, nsName := range nsNames {
+		nsNameMap[nsName] = struct{}{}
+	}
+
+	nsDefs := make([]datastore.RevisionedNamespace, 0, len(nsNames))
+
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		found := foundRaw.(*namespace)
+
+		loaded := &core.NamespaceDefinition{}
+		if err := loaded.UnmarshalVT(found.configBytes); err != nil {
+			return nil, err
+		}
+
+		if _, ok := nsNameMap[loaded.Name]; ok {
+			nsDefs = append(nsDefs, datastore.RevisionedNamespace{
+				Definition:          loaded,
+				LastWrittenRevision: found.updated,
+			})
+		}
+	}
+
+	return nsDefs, nil
+}
+
+// mustLock acquires the reader's TryLocker, panicking if the lock is already
+// held — concurrent use of a single reader/transaction is a caller bug.
+func (r *memdbReader) mustLock() {
+	if !r.TryLock() {
+		panic("detected concurrent use of ReadWriteTransaction")
+	}
+}
+
+// iteratorForFilter picks the most selective memdb index available for the
+// filter: namespace+relation when both are set, namespace when only the
+// resource type is set, otherwise a full scan via an empty prefix match.
+func iteratorForFilter(txn *memdb.Txn, filter datastore.RelationshipsFilter) (memdb.ResultIterator, error) {
+	// "_prefix" is a specialized index suffix used by github.com/hashicorp/go-memdb to match on
+	// a prefix of a string.
+	// See: https://github.com/hashicorp/go-memdb/blob/9940d4a14258e3b887bfb4bc6ebc28f65461a01c/txn.go#L531
+	index := indexNamespace + "_prefix"
+
+	var args []any
+	if filter.OptionalResourceType != "" {
+		args = append(args, filter.OptionalResourceType)
+		index = indexNamespace
+	} else {
+		// Empty-string prefix matches every namespace (full scan).
+		args = append(args, "")
+	}
+
+	if filter.OptionalResourceType != "" && filter.OptionalResourceRelation != "" {
+		args = append(args, filter.OptionalResourceRelation)
+		index = indexNamespaceAndRelation
+	}
+
+	if len(args) == 0 {
+		return nil, spiceerrors.MustBugf("cannot specify an empty filter")
+	}
+
+	iter, err := txn.Get(tableRelationship, index, args...)
+	if err != nil {
+		return nil, fmt.Errorf("unable to get iterator for filter: %w", err)
+	}
+
+	return iter, err
+}
+
+// filterFuncForFilters builds a memdb.FilterFunc applying all of the given
+// relationship filter criteria. Per go-memdb FilterIterator semantics, the
+// returned func returns true to EXCLUDE a candidate relationship. A candidate
+// surviving all criteria is finally handed to cursorFilter, which excludes
+// entries at or before the pagination cursor.
+func filterFuncForFilters(
+	optionalResourceType string,
+	optionalResourceIds []string,
+	optionalResourceIDPrefix string,
+	optionalRelation string,
+	optionalSubjectsSelectors []datastore.SubjectsSelector,
+	optionalCaveatFilter datastore.CaveatNameFilter,
+	optionalExpirationFilter datastore.ExpirationFilterOption,
+	cursorFilter func(*relationship) bool,
+) memdb.FilterFunc {
+	return func(tupleRaw interface{}) bool {
+		tuple := tupleRaw.(*relationship)
+
+		// Resource-side, caveat, and expiration criteria; any mismatch excludes.
+		switch {
+		case optionalResourceType != "" && optionalResourceType != tuple.namespace:
+			return true
+		case len(optionalResourceIds) > 0 && !slices.Contains(optionalResourceIds, tuple.resourceID):
+			return true
+		case optionalResourceIDPrefix != "" && !strings.HasPrefix(tuple.resourceID, optionalResourceIDPrefix):
+			return true
+		case optionalRelation != "" && optionalRelation != tuple.relation:
+			return true
+		case optionalCaveatFilter.Option == datastore.CaveatFilterOptionHasMatchingCaveat && (tuple.caveat == nil || tuple.caveat.caveatName != optionalCaveatFilter.CaveatName):
+			return true
+		case optionalCaveatFilter.Option == datastore.CaveatFilterOptionNoCaveat && (tuple.caveat != nil && tuple.caveat.caveatName != ""):
+			return true
+		case optionalExpirationFilter == datastore.ExpirationFilterOptionHasExpiration && tuple.expiration == nil:
+			return true
+		case optionalExpirationFilter == datastore.ExpirationFilterOptionNoExpiration && tuple.expiration != nil:
+			return true
+		}
+
+		// applySubjectSelector reports whether the tuple's subject matches one
+		// selector (true = matches, i.e. inverse polarity of the outer func).
+		applySubjectSelector := func(selector datastore.SubjectsSelector) bool {
+			switch {
+			case len(selector.OptionalSubjectType) > 0 && selector.OptionalSubjectType != tuple.subjectNamespace:
+				return false
+			case len(selector.OptionalSubjectIds) > 0 && !slices.Contains(selector.OptionalSubjectIds, tuple.subjectObjectID):
+				return false
+			}
+
+			if selector.RelationFilter.OnlyNonEllipsisRelations {
+				return tuple.subjectRelation != datastore.Ellipsis
+			}
+
+			// Build the allowed subject-relation set; empty means "any".
+			relations := make([]string, 0, 2)
+			if selector.RelationFilter.IncludeEllipsisRelation {
+				relations = append(relations, datastore.Ellipsis)
+			}
+
+			if selector.RelationFilter.NonEllipsisRelation != "" {
+				relations = append(relations, selector.RelationFilter.NonEllipsisRelation)
+			}
+
+			return len(relations) == 0 || slices.Contains(relations, tuple.subjectRelation)
+		}
+
+		// Selectors are OR-ed: at least one must match.
+		if len(optionalSubjectsSelectors) > 0 {
+			hasMatchingSelector := false
+			for _, selector := range optionalSubjectsSelectors {
+				if applySubjectSelector(selector) {
+					hasMatchingSelector = true
+					break
+				}
+			}
+
+			if !hasMatchingSelector {
+				return true
+			}
+		}
+
+		return cursorFilter(tuple)
+	}
+}
+
+// makeCursorFilterFn returns a cursor predicate that excludes (returns true
+// for) every relationship at or before `after` in the given sort order,
+// comparing the primary sort key first and the secondary key on ties. With no
+// cursor it returns noopCursorFilter, which excludes nothing.
+func makeCursorFilterFn(after options.Cursor, order options.SortOrder) func(tpl *relationship) bool {
+	if after != nil {
+		switch order {
+		case options.ByResource:
+			return func(tpl *relationship) bool {
+				return less(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) ||
+					(eq(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) &&
+						(less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) ||
+							eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject)))
+			}
+		case options.BySubject:
+			return func(tpl *relationship) bool {
+				return less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) ||
+					(eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) &&
+						(less(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource) ||
+							eq(tpl.namespace, tpl.resourceID, tpl.relation, after.Resource)))
+			}
+		}
+	}
+	return noopCursorFilter
+}
+
+// newSubjectSortedIterator materializes all remaining results in memory,
+// drops relationships already expired relative to `now`, sorts by subject
+// (resource as tie-breaker), applies the optional limit, and returns a slice-
+// backed iterator. skipCaveats/skipExpiration assert that no caveated or
+// expiring relationships appear — encountering one is treated as a bug.
+func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) (datastore.RelationshipIterator, error) {
+	results := make([]tuple.Relationship, 0)
+
+	// Coalesce all of the results into memory
+	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
+		rt, err := foundRaw.(*relationship).Relationship()
+		if err != nil {
+			return nil, err
+		}
+
+		if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) {
+			continue
+		}
+
+		if skipCaveats && rt.OptionalCaveat != nil {
+			return nil, spiceerrors.MustBugf("unexpected caveat in result for relationship: %v", rt)
+		}
+
+		if skipExpiration && rt.OptionalExpiration != nil {
+			return nil, spiceerrors.MustBugf("unexpected expiration in result for relationship: %v", rt)
+		}
+
+		results = append(results, rt)
+	}
+
+	// Sort them by subject
+	sort.Slice(results, func(i, j int) bool {
+		lhsRes := results[i].Resource
+		lhsSub := results[i].Subject
+		rhsRes := results[j].Resource
+		rhsSub := results[j].Subject
+		return less(lhsSub.ObjectType, lhsSub.ObjectID, lhsSub.Relation, rhsSub) ||
+			(eq(lhsSub.ObjectType, lhsSub.ObjectID, lhsSub.Relation, rhsSub) &&
+				(less(lhsRes.ObjectType, lhsRes.ObjectID, lhsRes.Relation, rhsRes)))
+	})
+
+	// Limit them if requested
+	if limit != nil && uint64(len(results)) > *limit {
+		results = results[0:*limit]
+	}
+
+	return common.NewSliceRelationshipIterator(results), nil
+}
+
+// noopCursorFilter excludes nothing; used when no pagination cursor is set.
+func noopCursorFilter(_ *relationship) bool {
+	return false
+}
+
+// less reports whether the (namespace, objectID, relation) triple on the left
+// orders strictly before rhs, comparing lexicographically field by field.
+func less(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelation) bool {
+	return lhsNamespace < rhs.ObjectType ||
+		(lhsNamespace == rhs.ObjectType && lhsObjectID < rhs.ObjectID) ||
+		(lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation < rhs.Relation)
+}
+
+// eq reports whether the (namespace, objectID, relation) triple equals rhs.
+func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelation) bool {
+	return lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation == rhs.Relation
+}
+
+// newMemdbTupleIterator wraps a memdb iterator as a streaming (range-over-func)
+// RelationshipIterator. It stops at `limit` yielded results, silently drops
+// relationships expired relative to `now`, and aborts with an error if a
+// caveat/expiration appears despite skipCaveats/skipExpiration being set.
+func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) datastore.RelationshipIterator {
+	var count uint64
+	return func(yield func(tuple.Relationship, error) bool) {
+		for {
+			foundRaw := it.Next()
+			if foundRaw == nil {
+				return
+			}
+
+			// Limit counts successfully-yielded relationships only.
+			if limit != nil && count >= *limit {
+				return
+			}
+
+			rt, err := foundRaw.(*relationship).Relationship()
+			if err != nil {
+				// Surface conversion errors but keep iterating if the
+				// consumer wants more.
+				if !yield(tuple.Relationship{}, err) {
+					return
+				}
+				continue
+			}
+
+			if skipCaveats && rt.OptionalCaveat != nil {
+				yield(rt, fmt.Errorf("unexpected caveat in result for relationship: %v", rt))
+				return
+			}
+
+			if skipExpiration && rt.OptionalExpiration != nil {
+				yield(rt, fmt.Errorf("unexpected expiration in result for relationship: %v", rt))
+				return
+			}
+
+			if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) {
+				continue
+			}
+
+			if !yield(rt, err) {
+				return
+			}
+			count++
+		}
+	}
+}
+
+// Compile-time interface conformance check.
+var _ datastore.Reader = &memdbReader{}
+
+// TryLocker is the minimal lock surface needed to detect concurrent use of a
+// reader/transaction: a non-blocking acquire plus release.
+type TryLocker interface {
+	TryLock() bool
+	Unlock()
+}
+
+// noopTryLocker always grants the lock; used where concurrency detection is
+// not required.
+type noopTryLocker struct{}
+
+func (ntl noopTryLocker) TryLock() bool {
+	return true
+}
+
+func (ntl noopTryLocker) Unlock() {}
+
+var _ TryLocker = noopTryLocker{}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go
new file mode 100644
index 0000000..8929e84
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/readwrite.go
@@ -0,0 +1,386 @@
+package memdb
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/hashicorp/go-memdb"
+ "github.com/jzelinskie/stringz"
+
+ "github.com/authzed/spicedb/internal/datastore/common"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// memdbReadWriteTx adds write operations on top of memdbReader; newRevision
+// is the revision stamped onto definitions written by this transaction.
+type memdbReadWriteTx struct {
+	memdbReader
+	newRevision datastore.Revision
+}
+
+// WriteRelationships applies the given relationship mutations under the
+// transaction's concurrency-detection lock.
+func (rwt *memdbReadWriteTx) WriteRelationships(_ context.Context, mutations []tuple.RelationshipUpdate) error {
+	rwt.mustLock()
+	defer rwt.Unlock()
+
+	tx, err := rwt.txSource()
+	if err != nil {
+		return err
+	}
+
+	return rwt.write(tx, mutations...)
+}
+
+// toIntegrity converts a mutation's optional integrity proto into the internal
+// relationshipIntegrity representation; nil when the mutation carries none.
+func (rwt *memdbReadWriteTx) toIntegrity(mutation tuple.RelationshipUpdate) *relationshipIntegrity {
+	var ri *relationshipIntegrity
+	if mutation.Relationship.OptionalIntegrity != nil {
+		ri = &relationshipIntegrity{
+			keyID:     mutation.Relationship.OptionalIntegrity.KeyId,
+			hash:      mutation.Relationship.OptionalIntegrity.Hash,
+			timestamp: mutation.Relationship.OptionalIntegrity.HashedAt.AsTime(),
+		}
+	}
+	return ri
+}
+
+// write applies CREATE/TOUCH/DELETE mutations to the relationship table.
+// CREATE fails if the relationship already exists; TOUCH upserts (and is a
+// no-op when the stored relationship is string-identical); DELETE removes the
+// row if present.
+// Caller must already hold the concurrent access lock!
+func (rwt *memdbReadWriteTx) write(tx *memdb.Txn, mutations ...tuple.RelationshipUpdate) error {
+	// Apply the mutations
+	for _, mutation := range mutations {
+		rel := &relationship{
+			mutation.Relationship.Resource.ObjectType,
+			mutation.Relationship.Resource.ObjectID,
+			mutation.Relationship.Resource.Relation,
+			mutation.Relationship.Subject.ObjectType,
+			mutation.Relationship.Subject.ObjectID,
+			mutation.Relationship.Subject.Relation,
+			rwt.toCaveatReference(mutation),
+			rwt.toIntegrity(mutation),
+			mutation.Relationship.OptionalExpiration,
+		}
+
+		// Look up any existing row under the full six-part primary key.
+		found, err := tx.First(
+			tableRelationship,
+			indexID,
+			rel.namespace,
+			rel.resourceID,
+			rel.relation,
+			rel.subjectNamespace,
+			rel.subjectObjectID,
+			rel.subjectRelation,
+		)
+		if err != nil {
+			return fmt.Errorf("error loading existing relationship: %w", err)
+		}
+
+		var existing *relationship
+		if found != nil {
+			existing = found.(*relationship)
+		}
+
+		switch mutation.Operation {
+		case tuple.UpdateOperationCreate:
+			if existing != nil {
+				rt, err := existing.Relationship()
+				if err != nil {
+					return err
+				}
+				return common.NewCreateRelationshipExistsError(&rt)
+			}
+			if err := tx.Insert(tableRelationship, rel); err != nil {
+				return fmt.Errorf("error inserting relationship: %w", err)
+			}
+
+		case tuple.UpdateOperationTouch:
+			if existing != nil {
+				rt, err := existing.Relationship()
+				if err != nil {
+					return err
+				}
+				// Skip the write when nothing would change.
+				if tuple.MustString(rt) == tuple.MustString(mutation.Relationship) {
+					continue
+				}
+			}
+
+			if err := tx.Insert(tableRelationship, rel); err != nil {
+				return fmt.Errorf("error inserting relationship: %w", err)
+			}
+
+		case tuple.UpdateOperationDelete:
+			if existing != nil {
+				if err := tx.Delete(tableRelationship, existing); err != nil {
+					return fmt.Errorf("error deleting relationship: %w", err)
+				}
+			}
+		default:
+			return spiceerrors.MustBugf("unknown tuple mutation operation type: %v", mutation.Operation)
+		}
+	}
+
+	return nil
+}
+
+// toCaveatReference converts a mutation's optional caveat proto into the
+// internal contextualizedCaveat representation; nil when no caveat is set.
+func (rwt *memdbReadWriteTx) toCaveatReference(mutation tuple.RelationshipUpdate) *contextualizedCaveat {
+	var cr *contextualizedCaveat
+	if mutation.Relationship.OptionalCaveat != nil {
+		cr = &contextualizedCaveat{
+			caveatName: mutation.Relationship.OptionalCaveat.CaveatName,
+			context:    mutation.Relationship.OptionalCaveat.Context.AsMap(),
+		}
+	}
+	return cr
+}
+
+// DeleteRelationships deletes all relationships matching the public filter,
+// optionally bounded by a delete limit. Returns the number deleted and
+// whether the limit was reached.
+func (rwt *memdbReadWriteTx) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, opts ...options.DeleteOptionsOption) (uint64, bool, error) {
+	rwt.mustLock()
+	defer rwt.Unlock()
+
+	tx, err := rwt.txSource()
+	if err != nil {
+		return 0, false, err
+	}
+
+	// A zero/unset limit means "delete everything that matches".
+	delOpts := options.NewDeleteOptionsWithOptionsAndDefaults(opts...)
+	var delLimit uint64
+	if delOpts.DeleteLimit != nil && *delOpts.DeleteLimit > 0 {
+		delLimit = *delOpts.DeleteLimit
+	}
+
+	return rwt.deleteWithLock(tx, filter, delLimit)
+}
+
+// deleteWithLock collects matching relationships (up to limit, 0 = unbounded)
+// into DELETE mutations and applies them via write. Returns the count deleted
+// and whether the limit stopped the scan.
+// caller must already hold the concurrent access lock
+func (rwt *memdbReadWriteTx) deleteWithLock(tx *memdb.Txn, filter *v1.RelationshipFilter, limit uint64) (uint64, bool, error) {
+	// Create an iterator to find the relevant tuples
+	dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter)
+	if err != nil {
+		return 0, false, err
+	}
+
+	bestIter, err := iteratorForFilter(tx, dsFilter)
+	if err != nil {
+		return 0, false, err
+	}
+	filteredIter := memdb.NewFilterIterator(bestIter, relationshipFilterFilterFunc(filter))
+
+	// Collect the tuples into a slice of mutations for the changelog
+	var mutations []tuple.RelationshipUpdate
+	var counter uint64
+
+	metLimit := false
+	for row := filteredIter.Next(); row != nil; row = filteredIter.Next() {
+		rt, err := row.(*relationship).Relationship()
+		if err != nil {
+			return 0, false, err
+		}
+		mutations = append(mutations, tuple.Delete(rt))
+		counter++
+
+		if limit > 0 && counter == limit {
+			metLimit = true
+			break
+		}
+	}
+
+	return counter, metLimit, rwt.write(tx, mutations...)
+}
+
+// RegisterCounter registers a named relationship counter with the given
+// filter (stored serialized), starting at count 0 and no computed revision.
+// Fails if a counter with the same name already exists.
+func (rwt *memdbReadWriteTx) RegisterCounter(ctx context.Context, name string, filter *core.RelationshipFilter) error {
+	rwt.mustLock()
+	defer rwt.Unlock()
+
+	tx, err := rwt.txSource()
+	if err != nil {
+		return err
+	}
+
+	foundRaw, err := tx.First(tableCounters, indexID, name)
+	if err != nil {
+		return err
+	}
+
+	if foundRaw != nil {
+		return datastore.NewCounterAlreadyRegisteredErr(name, filter)
+	}
+
+	filterBytes, err := filter.MarshalVT()
+	if err != nil {
+		return err
+	}
+
+	// Insert the counter
+	counter := &counter{
+		name,
+		filterBytes,
+		0,
+		datastore.NoRevision,
+	}
+
+	return tx.Insert(tableCounters, counter)
+}
+
+// UnregisterCounter removes the named counter, returning a
+// CounterNotRegisteredErr when it does not exist.
+func (rwt *memdbReadWriteTx) UnregisterCounter(ctx context.Context, name string) error {
+	rwt.mustLock()
+	defer rwt.Unlock()
+
+	tx, err := rwt.txSource()
+	if err != nil {
+		return err
+	}
+
+	// Check if the counter exists
+	foundRaw, err := tx.First(tableCounters, indexID, name)
+	if err != nil {
+		return err
+	}
+
+	if foundRaw == nil {
+		return datastore.NewCounterNotRegisteredErr(name)
+	}
+
+	return tx.Delete(tableCounters, foundRaw)
+}
+
+// StoreCounterValue records a computed count and the revision at which it was
+// computed for the named counter; the counter must already be registered.
+func (rwt *memdbReadWriteTx) StoreCounterValue(ctx context.Context, name string, value int, computedAtRevision datastore.Revision) error {
+	rwt.mustLock()
+	defer rwt.Unlock()
+
+	tx, err := rwt.txSource()
+	if err != nil {
+		return err
+	}
+
+	// Check if the counter exists
+	foundRaw, err := tx.First(tableCounters, indexID, name)
+	if err != nil {
+		return err
+	}
+
+	if foundRaw == nil {
+		return datastore.NewCounterNotRegisteredErr(name)
+	}
+
+	// NOTE(review): this mutates the stored object in place before re-inserting;
+	// safe here presumably because the row is owned by this transaction.
+	counter := foundRaw.(*counter)
+	counter.count = value
+	counter.updated = computedAtRevision
+
+	return tx.Insert(tableCounters, counter)
+}
+
+// WriteNamespaces upserts the given namespace definitions, serialized and
+// stamped with this transaction's revision.
+func (rwt *memdbReadWriteTx) WriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error {
+	rwt.mustLock()
+	defer rwt.Unlock()
+
+	tx, err := rwt.txSource()
+	if err != nil {
+		return err
+	}
+
+	for _, newConfig := range newConfigs {
+		serialized, err := newConfig.MarshalVT()
+		if err != nil {
+			return err
+		}
+
+		newConfigEntry := &namespace{newConfig.Name, serialized, rwt.newRevision}
+
+		// memdb Insert acts as an upsert on the primary key (name).
+		err = tx.Insert(tableNamespace, newConfigEntry)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// DeleteNamespaces removes each named namespace definition and cascades the
+// deletion to all relationships whose resource type is that namespace.
+// Fails fast on the first missing namespace.
+func (rwt *memdbReadWriteTx) DeleteNamespaces(_ context.Context, nsNames ...string) error {
+	rwt.mustLock()
+	defer rwt.Unlock()
+
+	tx, err := rwt.txSource()
+	if err != nil {
+		return err
+	}
+
+	for _, nsName := range nsNames {
+		foundRaw, err := tx.First(tableNamespace, indexID, nsName)
+		if err != nil {
+			return err
+		}
+
+		if foundRaw == nil {
+			return fmt.Errorf("namespace not found")
+		}
+
+		if err := tx.Delete(tableNamespace, foundRaw); err != nil {
+			return err
+		}
+
+		// Delete the relationships from the namespace
+		if _, _, err := rwt.deleteWithLock(tx, &v1.RelationshipFilter{
+			ResourceType: nsName,
+		}, 0); err != nil {
+			return fmt.Errorf("unable to delete relationships from deleted namespace: %w", err)
+		}
+	}
+
+	return nil
+}
+
+// BulkLoad streams relationships from the source and writes each one as a
+// CREATE, reusing a single-element updates slice to avoid per-row allocation.
+// Returns the number copied; a source error ends the loop and is returned.
+func (rwt *memdbReadWriteTx) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) {
+	var numCopied uint64
+	var next *tuple.Relationship
+	var err error
+
+	updates := []tuple.RelationshipUpdate{{
+		Operation: tuple.UpdateOperationCreate,
+	}}
+
+	for next, err = iter.Next(ctx); next != nil && err == nil; next, err = iter.Next(ctx) {
+		updates[0].Relationship = *next
+		if err := rwt.WriteRelationships(ctx, updates); err != nil {
+			return 0, err
+		}
+		numCopied++
+	}
+
+	return numCopied, err
+}
+
+// relationshipFilterFilterFunc adapts a public v1.RelationshipFilter to a
+// go-memdb FilterIterator predicate: it returns true to EXCLUDE a candidate
+// that fails any resource- or subject-side criterion.
+func relationshipFilterFilterFunc(filter *v1.RelationshipFilter) func(interface{}) bool {
+	return func(tupleRaw interface{}) bool {
+		tuple := tupleRaw.(*relationship)
+
+		// If it doesn't match one of the resource filters, filter it.
+		switch {
+		case filter.ResourceType != "" && filter.ResourceType != tuple.namespace:
+			return true
+		case filter.OptionalResourceId != "" && filter.OptionalResourceId != tuple.resourceID:
+			return true
+		case filter.OptionalResourceIdPrefix != "" && !strings.HasPrefix(tuple.resourceID, filter.OptionalResourceIdPrefix):
+			return true
+		case filter.OptionalRelation != "" && filter.OptionalRelation != tuple.relation:
+			return true
+		}
+
+		// If it doesn't match one of the subject filters, filter it.
+		if subjectFilter := filter.OptionalSubjectFilter; subjectFilter != nil {
+			switch {
+			case subjectFilter.SubjectType != tuple.subjectNamespace:
+				return true
+			case subjectFilter.OptionalSubjectId != "" && subjectFilter.OptionalSubjectId != tuple.subjectObjectID:
+				return true
+			// An empty subject relation in the filter means ellipsis ("...").
+			case subjectFilter.OptionalRelation != nil &&
+				stringz.DefaultEmpty(subjectFilter.OptionalRelation.Relation, datastore.Ellipsis) != tuple.subjectRelation:
+				return true
+			}
+		}
+
+		return false
+	}
+}
+
+var _ datastore.ReadWriteTransaction = &memdbReadWriteTx{}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go
new file mode 100644
index 0000000..be79771
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/revisions.go
@@ -0,0 +1,118 @@
+package memdb
+
+import (
+ "context"
+ "time"
+
+ "github.com/authzed/spicedb/internal/datastore/revisions"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
+// ParseRevisionString parses a timestamp-format revision string.
+var ParseRevisionString = revisions.RevisionParser(revisions.Timestamp)
+
+// nowRevision returns a revision for the current UTC wall-clock time.
+func nowRevision() revisions.TimestampRevision {
+	return revisions.NewForTime(time.Now().UTC())
+}
+
+// newRevisionID mints a strictly-newer revision than the current head,
+// bumping by one nanosecond when the clock has not advanced (see NOTE below
+// regarding coarse clock precision on macOS).
+func (mdb *memdbDatastore) newRevisionID() revisions.TimestampRevision {
+	mdb.Lock()
+	defer mdb.Unlock()
+
+	existing := mdb.revisions[len(mdb.revisions)-1].revision
+	created := nowRevision()
+
+	// NOTE: The time.Now().UTC() only appears to have *microsecond* level
+	// precision on macOS Monterey in Go 1.19.1. This means that HeadRevision
+	// and the result of a ReadWriteTx could return the *same* transaction ID
+	// if both are executed in sequence without any other forms of delay on
+	// macOS. We therefore check if the created transaction ID matches that
+	// previously created and, if not, add to it.
+	//
+	// See: https://github.com/golang/go/issues/22037 which appeared to fix
+	// this in Go 1.9.2, but there appears to have been a reversion with either
+	// the new version of macOS or Go.
+	if created.Equal(existing) {
+		return revisions.NewForTimestamp(created.TimestampNanoSec() + 1)
+	}
+
+	return created
+}
+
+// HeadRevision returns the most recent revision, failing if the datastore
+// has been closed.
+func (mdb *memdbDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) {
+	mdb.RLock()
+	defer mdb.RUnlock()
+	if err := mdb.checkNotClosed(); err != nil {
+		return nil, err
+	}
+
+	return mdb.headRevisionNoLock(), nil
+}
+
+// SquashRevisionsForTesting collapses the revision history to a single
+// snapshot of the current database at the current time. Test-only helper;
+// note it does not take the datastore lock — callers must ensure exclusivity.
+func (mdb *memdbDatastore) SquashRevisionsForTesting() {
+	mdb.revisions = []snapshot{
+		{
+			revision: nowRevision(),
+			db:       mdb.db,
+		},
+	}
+}
+
+// headRevisionNoLock returns the latest revision; caller must hold the lock.
+func (mdb *memdbDatastore) headRevisionNoLock() revisions.TimestampRevision {
+	return mdb.revisions[len(mdb.revisions)-1].revision
+}
+
+// OptimizedRevision returns the current time rounded down to the datastore's
+// quantization period, allowing repeated calls within a window to share a
+// revision (and therefore share caches).
+func (mdb *memdbDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) {
+	mdb.RLock()
+	defer mdb.RUnlock()
+	if err := mdb.checkNotClosed(); err != nil {
+		return nil, err
+	}
+
+	now := nowRevision()
+	return revisions.NewForTimestamp(now.TimestampNanoSec() - now.TimestampNanoSec()%mdb.quantizationPeriod), nil
+}
+
+// CheckRevision validates that the given revision is usable: not stale
+// (outside the GC window) and not beyond the head revision.
+func (mdb *memdbDatastore) CheckRevision(_ context.Context, dr datastore.Revision) error {
+	mdb.RLock()
+	defer mdb.RUnlock()
+	if err := mdb.checkNotClosed(); err != nil {
+		return err
+	}
+
+	return mdb.checkRevisionLocalCallerMustLock(dr)
+}
+
+// checkRevisionLocalCallerMustLock performs the actual revision validation;
+// the caller must hold at least a read lock.
+func (mdb *memdbDatastore) checkRevisionLocalCallerMustLock(dr datastore.Revision) error {
+	now := nowRevision()
+
+	// Ensure the revision has not fallen outside of the GC window. If it has, it is considered
+	// invalid.
+	if mdb.revisionOutsideGCWindow(now, dr) {
+		return datastore.NewInvalidRevisionErr(dr, datastore.RevisionStale)
+	}
+
+	// If the revision <= now and later than the GC window, it is assumed to be valid, even if
+	// HEAD revision is behind it.
+	if dr.GreaterThan(now) {
+		// If the revision is in the "future", then check to ensure that it is <= of HEAD to handle
+		// the microsecond granularity on macos (see comment above in newRevisionID)
+		headRevision := mdb.headRevisionNoLock()
+		if dr.LessThan(headRevision) || dr.Equal(headRevision) {
+			return nil
+		}
+
+		return datastore.NewInvalidRevisionErr(dr, datastore.CouldNotDetermineRevision)
+	}
+
+	return nil
+}
+
+// revisionOutsideGCWindow reports whether the revision predates the GC window
+// ending at `now`. negativeGCWindow is stored as a negative nanosecond delta,
+// so adding it to now yields the oldest acceptable timestamp.
+func (mdb *memdbDatastore) revisionOutsideGCWindow(now revisions.TimestampRevision, revisionRaw datastore.Revision) bool {
+	// make an exception for head revision - it will be acceptable even if outside GC Window
+	if revisionRaw.Equal(mdb.headRevisionNoLock()) {
+		return false
+	}
+
+	oldest := revisions.NewForTimestamp(now.TimestampNanoSec() + mdb.negativeGCWindow)
+	return revisionRaw.LessThan(oldest)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go
new file mode 100644
index 0000000..7905d48
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/schema.go
@@ -0,0 +1,232 @@
+package memdb
+
+import (
+ "time"
+
+ "github.com/hashicorp/go-memdb"
+ "github.com/rs/zerolog"
+ "google.golang.org/protobuf/types/known/structpb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// Table and index names for the in-memory database schema defined below.
const (
	tableNamespace = "namespace"

	tableRelationship = "relationship"
	// indexID is the unique primary index name, reused by several tables.
	indexID = "id"
	indexNamespace = "namespace"
	indexNamespaceAndRelation = "namespaceAndRelation"
	indexSubjectNamespace = "subjectNamespace"

	tableCounters = "counters"

	tableChangelog = "changelog"
	indexRevision = "id"
)
+
// namespace is the in-memory row type for the namespace table. Field names
// are referenced by reflection in the memdb index schema ("name"), so they
// must not be renamed without updating the schema.
type namespace struct {
	name string // namespace definition name; unique primary key
	configBytes []byte // serialized namespace definition
	updated datastore.Revision // revision at which this row was written
}

// MarshalZerologObject emits the namespace's revision and name for logging.
func (ns namespace) MarshalZerologObject(e *zerolog.Event) {
	e.Stringer("rev", ns.updated).Str("name", ns.name)
}
+
// counter is the in-memory row type for the counters table. The "name" field
// is referenced by reflection in the memdb index schema.
type counter struct {
	name string // counter name; unique primary key
	filterBytes []byte // serialized relationship filter the counter applies to
	count int // last computed count value
	updated datastore.Revision // revision at which the count was last computed
}
+
// relationship is the in-memory row type for the relationship table. Field
// names are referenced by reflection in the memdb index schema below, so
// renaming any field requires updating the corresponding indexers.
type relationship struct {
	namespace string
	resourceID string
	relation string
	subjectNamespace string
	subjectObjectID string
	subjectRelation string
	caveat *contextualizedCaveat // optional; nil when the relationship has no caveat
	integrity *relationshipIntegrity // optional; nil when no integrity metadata is attached
	expiration *time.Time // optional; nil when the relationship does not expire
}

// relationshipIntegrity holds integrity-hash metadata for a relationship.
type relationshipIntegrity struct {
	keyID string // identifier of the key used to produce the hash
	hash []byte // hash bytes
	timestamp time.Time // when the hash was computed
}
+
+func (ri relationshipIntegrity) MarshalZerologObject(e *zerolog.Event) {
+ e.Str("keyID", ri.keyID).Bytes("hash", ri.hash).Time("timestamp", ri.timestamp)
+}
+
+func (ri relationshipIntegrity) RelationshipIntegrity() *core.RelationshipIntegrity {
+ return &core.RelationshipIntegrity{
+ KeyId: ri.keyID,
+ Hash: ri.hash,
+ HashedAt: timestamppb.New(ri.timestamp),
+ }
+}
+
+type contextualizedCaveat struct {
+ caveatName string
+ context map[string]any
+}
+
+func (cr *contextualizedCaveat) ContextualizedCaveat() (*core.ContextualizedCaveat, error) {
+ if cr == nil {
+ return nil, nil
+ }
+ v, err := structpb.NewStruct(cr.context)
+ if err != nil {
+ return nil, err
+ }
+ return &core.ContextualizedCaveat{
+ CaveatName: cr.caveatName,
+ Context: v,
+ }, nil
+}
+
+func (r relationship) String() string {
+ caveat := ""
+ if r.caveat != nil {
+ caveat = "[" + r.caveat.caveatName + "]"
+ }
+
+ expiration := ""
+ if r.expiration != nil {
+ expiration = "[expiration:" + r.expiration.Format(time.RFC3339Nano) + "]"
+ }
+
+ return r.namespace + ":" + r.resourceID + "#" + r.relation + "@" + r.subjectNamespace + ":" + r.subjectObjectID + "#" + r.subjectRelation + caveat + expiration
+}
+
+func (r relationship) MarshalZerologObject(e *zerolog.Event) {
+ e.Str("rel", r.String())
+}
+
// Relationship converts the in-memory row into a tuple.Relationship,
// including its optional caveat, integrity metadata, and expiration.
// Returns an error only if the caveat context cannot be converted to proto.
func (r relationship) Relationship() (tuple.Relationship, error) {
	// ContextualizedCaveat handles a nil caveat by returning (nil, nil).
	cr, err := r.caveat.ContextualizedCaveat()
	if err != nil {
		return tuple.Relationship{}, err
	}

	var ig *core.RelationshipIntegrity
	if r.integrity != nil {
		ig = r.integrity.RelationshipIntegrity()
	}

	return tuple.Relationship{
		RelationshipReference: tuple.RelationshipReference{
			Resource: tuple.ObjectAndRelation{
				ObjectType: r.namespace,
				ObjectID: r.resourceID,
				Relation: r.relation,
			},
			Subject: tuple.ObjectAndRelation{
				ObjectType: r.subjectNamespace,
				ObjectID: r.subjectObjectID,
				Relation: r.subjectRelation,
			},
		},
		OptionalCaveat: cr,
		OptionalIntegrity: ig,
		OptionalExpiration: r.expiration,
	}, nil
}
+
// changelog is the in-memory row type for the changelog table; one row is
// written per committed revision. "revisionNanos" is the unique primary key
// (referenced by reflection in the schema below).
type changelog struct {
	revisionNanos int64 // revision timestamp in nanoseconds
	changes datastore.RevisionChanges // the changes committed at that revision
}
+
// schema defines the go-memdb tables and indexes backing the datastore.
// Field names in the indexers are resolved by reflection against the row
// structs above; keep them in sync.
// NOTE(review): tableCaveats is referenced here but declared elsewhere in the
// package (presumably the caveats file) — confirm.
var schema = &memdb.DBSchema{
	Tables: map[string]*memdb.TableSchema{
		tableNamespace: {
			Name: tableNamespace,
			Indexes: map[string]*memdb.IndexSchema{
				// Unique by namespace name.
				indexID: {
					Name: indexID,
					Unique: true,
					Indexer: &memdb.StringFieldIndex{Field: "name"},
				},
			},
		},
		tableChangelog: {
			Name: tableChangelog,
			Indexes: map[string]*memdb.IndexSchema{
				// Unique by revision timestamp; used for LowerBound scans in Watch.
				indexRevision: {
					Name: indexRevision,
					Unique: true,
					Indexer: &memdb.IntFieldIndex{Field: "revisionNanos"},
				},
			},
		},
		tableRelationship: {
			Name: tableRelationship,
			Indexes: map[string]*memdb.IndexSchema{
				// Unique compound key over the full relationship tuple.
				indexID: {
					Name: indexID,
					Unique: true,
					Indexer: &memdb.CompoundIndex{
						Indexes: []memdb.Indexer{
							&memdb.StringFieldIndex{Field: "namespace"},
							&memdb.StringFieldIndex{Field: "resourceID"},
							&memdb.StringFieldIndex{Field: "relation"},
							&memdb.StringFieldIndex{Field: "subjectNamespace"},
							&memdb.StringFieldIndex{Field: "subjectObjectID"},
							&memdb.StringFieldIndex{Field: "subjectRelation"},
						},
					},
				},
				// Non-unique lookup by resource namespace.
				indexNamespace: {
					Name: indexNamespace,
					Unique: false,
					Indexer: &memdb.StringFieldIndex{Field: "namespace"},
				},
				// Non-unique lookup by resource namespace + relation.
				indexNamespaceAndRelation: {
					Name: indexNamespaceAndRelation,
					Unique: false,
					Indexer: &memdb.CompoundIndex{
						Indexes: []memdb.Indexer{
							&memdb.StringFieldIndex{Field: "namespace"},
							&memdb.StringFieldIndex{Field: "relation"},
						},
					},
				},
				// Non-unique lookup by subject namespace.
				indexSubjectNamespace: {
					Name: indexSubjectNamespace,
					Unique: false,
					Indexer: &memdb.StringFieldIndex{Field: "subjectNamespace"},
				},
			},
		},
		tableCaveats: {
			Name: tableCaveats,
			Indexes: map[string]*memdb.IndexSchema{
				// Unique by caveat name.
				indexID: {
					Name: indexID,
					Unique: true,
					Indexer: &memdb.StringFieldIndex{Field: "name"},
				},
			},
		},
		tableCounters: {
			Name: tableCounters,
			Indexes: map[string]*memdb.IndexSchema{
				// Unique by counter name.
				indexID: {
					Name: indexID,
					Unique: true,
					Indexer: &memdb.StringFieldIndex{Field: "name"},
				},
			},
		},
	},
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go
new file mode 100644
index 0000000..33665a1
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/stats.go
@@ -0,0 +1,51 @@
+package memdb
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// Statistics returns datastore statistics: the datastore's unique ID, the
// relationship count (computed exactly by scanning the table), and per-object
// -type stats derived from all namespaces at the current head revision.
func (mdb *memdbDatastore) Statistics(ctx context.Context) (datastore.Stats, error) {
	head, err := mdb.HeadRevision(ctx)
	if err != nil {
		return datastore.Stats{}, fmt.Errorf("unable to compute head revision: %w", err)
	}

	count, err := mdb.countRelationships(ctx)
	if err != nil {
		return datastore.Stats{}, fmt.Errorf("unable to count relationships: %w", err)
	}

	objTypes, err := mdb.SnapshotReader(head).ListAllNamespaces(ctx)
	if err != nil {
		return datastore.Stats{}, fmt.Errorf("unable to list object types: %w", err)
	}

	return datastore.Stats{
		UniqueID: mdb.uniqueID,
		EstimatedRelationshipCount: count,
		ObjectTypeStatistics: datastore.ComputeObjectTypeStats(objTypes),
	}, nil
}
+
+func (mdb *memdbDatastore) countRelationships(_ context.Context) (uint64, error) {
+ mdb.RLock()
+ defer mdb.RUnlock()
+
+ txn := mdb.db.Txn(false)
+ defer txn.Abort()
+
+ it, err := txn.LowerBound(tableRelationship, indexID)
+ if err != nil {
+ return 0, err
+ }
+
+ var count uint64
+ for row := it.Next(); row != nil; row = it.Next() {
+ count++
+ }
+
+ return count, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go
new file mode 100644
index 0000000..eaa4812
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/memdb/watch.go
@@ -0,0 +1,148 @@
+package memdb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/hashicorp/go-memdb"
+
+ "github.com/authzed/spicedb/internal/datastore/revisions"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// errWatchError is the format string used to wrap errors surfaced by Watch.
const errWatchError = "watch error: %w"
+
// Watch streams revision changes occurring after the given revision over the
// returned updates channel. Errors (including disconnection and cancelation)
// are sent on the returned error channel, after which both channels close.
// The EmitImmediately emission strategy is rejected up front.
func (mdb *memdbDatastore) Watch(ctx context.Context, ar datastore.Revision, options datastore.WatchOptions) (<-chan datastore.RevisionChanges, <-chan error) {
	// Fall back to the datastore-level buffer settings when not provided.
	watchBufferLength := options.WatchBufferLength
	if watchBufferLength == 0 {
		watchBufferLength = mdb.watchBufferLength
	}

	updates := make(chan datastore.RevisionChanges, watchBufferLength)
	errs := make(chan error, 1)

	if options.EmissionStrategy == datastore.EmitImmediatelyStrategy {
		close(updates)
		errs <- errors.New("emit immediately strategy is unsupported in MemDB")
		return updates, errs
	}

	watchBufferWriteTimeout := options.WatchBufferWriteTimeout
	if watchBufferWriteTimeout == 0 {
		watchBufferWriteTimeout = mdb.watchBufferWriteTimeout
	}

	// sendChange attempts a non-blocking send first, then a timed send; a
	// timeout means the consumer is too slow and the watch is disconnected.
	sendChange := func(change datastore.RevisionChanges) bool {
		select {
		case updates <- change:
			return true

		default:
			// If we cannot immediately write, setup the timer and try again.
		}

		timer := time.NewTimer(watchBufferWriteTimeout)
		defer timer.Stop()

		select {
		case updates <- change:
			return true

		case <-timer.C:
			errs <- datastore.NewWatchDisconnectedErr()
			return false
		}
	}

	go func() {
		defer close(updates)
		defer close(errs)

		// NOTE(review): unchecked assertion — a non-timestamp revision would
		// panic inside this goroutine; confirm callers always pass one.
		currentTxn := ar.(revisions.TimestampRevision).TimestampNanoSec()

		for {
			var stagedUpdates []datastore.RevisionChanges
			var watchChan <-chan struct{}
			var err error
			stagedUpdates, currentTxn, watchChan, err = mdb.loadChanges(ctx, currentTxn, options)
			if err != nil {
				errs <- err
				return
			}

			// Write the staged updates to the channel
			for _, changeToWrite := range stagedUpdates {
				if !sendChange(changeToWrite) {
					return
				}
			}

			// Wait for new changes
			ws := memdb.NewWatchSet()
			ws.Add(watchChan)

			err = ws.WatchCtx(ctx)
			if err != nil {
				switch {
				case errors.Is(err, context.Canceled):
					errs <- datastore.NewWatchCanceledErr()
				default:
					errs <- fmt.Errorf(errWatchError, err)
				}
				return
			}
		}
	}()

	return updates, errs
}
+
+func (mdb *memdbDatastore) loadChanges(_ context.Context, currentTxn int64, options datastore.WatchOptions) ([]datastore.RevisionChanges, int64, <-chan struct{}, error) {
+ mdb.RLock()
+ defer mdb.RUnlock()
+
+ if err := mdb.checkNotClosed(); err != nil {
+ return nil, 0, nil, err
+ }
+
+ loadNewTxn := mdb.db.Txn(false)
+ defer loadNewTxn.Abort()
+
+ it, err := loadNewTxn.LowerBound(tableChangelog, indexRevision, currentTxn+1)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf(errWatchError, err)
+ }
+
+ var changes []datastore.RevisionChanges
+ lastRevision := currentTxn
+ for changeRaw := it.Next(); changeRaw != nil; changeRaw = it.Next() {
+ change := changeRaw.(*changelog)
+
+ if options.Content&datastore.WatchRelationships == datastore.WatchRelationships && len(change.changes.RelationshipChanges) > 0 {
+ changes = append(changes, change.changes)
+ }
+
+ if options.Content&datastore.WatchSchema == datastore.WatchSchema &&
+ len(change.changes.ChangedDefinitions) > 0 || len(change.changes.DeletedCaveats) > 0 || len(change.changes.DeletedNamespaces) > 0 {
+ changes = append(changes, change.changes)
+ }
+
+ if options.Content&datastore.WatchCheckpoints == datastore.WatchCheckpoints && change.revisionNanos > lastRevision {
+ changes = append(changes, datastore.RevisionChanges{
+ Revision: revisions.NewForTimestamp(change.revisionNanos),
+ IsCheckpoint: true,
+ })
+ }
+
+ lastRevision = change.revisionNanos
+ }
+
+ watchChan, _, err := loadNewTxn.LastWatch(tableChangelog, indexRevision)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf(errWatchError, err)
+ }
+
+ return changes, lastRevision, watchChan, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go
new file mode 100644
index 0000000..7092728
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/commonrevision.go
@@ -0,0 +1,79 @@
+package revisions
+
+import (
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
// RevisionKind is an enum of the different kinds of revisions that can be used.
type RevisionKind string

const (
	// Timestamp is a revision that is a timestamp.
	Timestamp RevisionKind = "timestamp"

	// TransactionID is a revision that is a transaction ID.
	// NOTE(review): declared without an explicit RevisionKind type (untyped
	// string constant); works in the switches below, but confirm intentional.
	TransactionID = "txid"

	// HybridLogicalClock is a revision that is a hybrid logical clock.
	HybridLogicalClock = "hlc"
)
+
+// ParsingFunc is a function that can parse a string into a revision.
+type ParsingFunc func(revisionStr string) (rev datastore.Revision, err error)
+
+// RevisionParser returns a ParsingFunc for the given RevisionKind.
+func RevisionParser(kind RevisionKind) ParsingFunc {
+ switch kind {
+ case TransactionID:
+ return parseTransactionIDRevisionString
+
+ case Timestamp:
+ return parseTimestampRevisionString
+
+ case HybridLogicalClock:
+ return parseHLCRevisionString
+
+ default:
+ return func(revisionStr string) (rev datastore.Revision, err error) {
+ return nil, spiceerrors.MustBugf("unknown revision kind: %v", kind)
+ }
+ }
+}
+
+// CommonDecoder is a revision decoder that can decode revisions of a given kind.
+type CommonDecoder struct {
+ Kind RevisionKind
+}
+
+func (cd CommonDecoder) RevisionFromString(s string) (datastore.Revision, error) {
+ switch cd.Kind {
+ case TransactionID:
+ return parseTransactionIDRevisionString(s)
+
+ case Timestamp:
+ return parseTimestampRevisionString(s)
+
+ case HybridLogicalClock:
+ return parseHLCRevisionString(s)
+
+ default:
+ return nil, spiceerrors.MustBugf("unknown revision kind in decoder: %v", cd.Kind)
+ }
+}
+
// WithInexactFloat64 is an interface that can be implemented by a revision to
// provide an inexact float64 representation of the revision.
type WithInexactFloat64 interface {
	// InexactFloat64 returns a float64 that is an inexact representation of the
	// revision.
	InexactFloat64() float64
}

// WithTimestampRevision is an interface that can be implemented by a revision to
// provide a timestamp.
type WithTimestampRevision interface {
	datastore.Revision
	// TimestampNanoSec returns the revision's wall-clock component in nanoseconds.
	TimestampNanoSec() int64
	// ConstructForTimestamp builds a revision of the same concrete type for the
	// given nanosecond timestamp.
	ConstructForTimestamp(timestampNanoSec int64) WithTimestampRevision
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go
new file mode 100644
index 0000000..e4f7fc6
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/hlcrevision.go
@@ -0,0 +1,166 @@
+package revisions
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/shopspring/decimal"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
// zeroHLC is the substitute value used when comparing against datastore.NoRevision.
var zeroHLC = HLCRevision{}

// NOTE: This *must* match the length defined in CRDB or the implementation below will break.
const logicalClockLength = 10

// logicalClockOffset is added to the stored logical clock so String() can
// reconstruct the zero-padded fractional digits.
// NOTE(review): math.Pow10(logicalClockLength + 1) is 1e11, which exceeds
// math.MaxUint32; the result of an out-of-range float64→uint32 conversion is
// implementation-defined per the Go spec — confirm the resulting offset value
// is the intended one.
var logicalClockOffset = uint32(math.Pow10(logicalClockLength + 1))

// HLCRevision is a revision that is a hybrid logical clock, stored as two integers.
// The first integer is the timestamp in nanoseconds, and the second integer is the
// logical clock defined as 11 digits, with the first digit being ignored to ensure
// precision of the given logical clock.
type HLCRevision struct {
	time int64 // wall-clock component, nanoseconds
	logicalclock uint32 // logical clock component, stored offset by logicalClockOffset
}
+
// parseHLCRevisionString parses a string into a hybrid logical clock revision.
// Accepted forms are "<nanos>" (logical clock of zero) and "<nanos>.<logical>",
// where the fractional part is the logical clock with trailing zeros possibly
// stripped by decimal formatting.
func parseHLCRevisionString(revisionStr string) (datastore.Revision, error) {
	pieces := strings.Split(revisionStr, ".")
	if len(pieces) == 1 {
		// If there is no decimal point, assume the revision is a timestamp.
		timestamp, err := strconv.ParseInt(pieces[0], 10, 64)
		if err != nil {
			return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr)
		}
		return HLCRevision{timestamp, logicalClockOffset}, nil
	}

	if len(pieces) != 2 {
		return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr)
	}

	timestamp, err := strconv.ParseInt(pieces[0], 10, 64)
	if err != nil {
		return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr)
	}

	if len(pieces[1]) > logicalClockLength {
		return datastore.NoRevision, spiceerrors.MustBugf("invalid revision string due to unexpected logical clock size (%d): %q", len(pieces[1]), revisionStr)
	}

	// Right-pad back to the full logical clock width; this restores trailing
	// zeros stripped from the fractional part by decimal formatting.
	paddedLogicalClockStr := pieces[1] + strings.Repeat("0", logicalClockLength-len(pieces[1]))
	logicalclock, err := strconv.ParseUint(paddedLogicalClockStr, 10, 64)
	if err != nil {
		return datastore.NoRevision, fmt.Errorf("invalid revision string: %q", revisionStr)
	}

	// Guard the narrowing cast below.
	if logicalclock > math.MaxUint32 {
		return datastore.NoRevision, spiceerrors.MustBugf("received logical lock that exceeds MaxUint32 (%d > %d): revision %q", logicalclock, math.MaxUint32, revisionStr)
	}

	uintLogicalClock, err := safecast.ToUint32(logicalclock)
	if err != nil {
		return datastore.NoRevision, spiceerrors.MustBugf("could not cast logicalclock to uint32: %v", err)
	}

	// NOTE(review): the addition of logicalClockOffset can wrap uint32 for
	// large logical clocks — presumably bounded by CRDB's int32 logical clock;
	// confirm.
	return HLCRevision{timestamp, uintLogicalClock + logicalClockOffset}, nil
}
+
+// HLCRevisionFromString parses a string into a hybrid logical clock revision.
+func HLCRevisionFromString(revisionStr string) (HLCRevision, error) {
+ rev, err := parseHLCRevisionString(revisionStr)
+ if err != nil {
+ return zeroHLC, err
+ }
+
+ return rev.(HLCRevision), nil
+}
+
+// NewForHLC creates a new revision for the given hybrid logical clock.
+func NewForHLC(decimal decimal.Decimal) (HLCRevision, error) {
+ rev, err := HLCRevisionFromString(decimal.String())
+ if err != nil {
+ return zeroHLC, fmt.Errorf("invalid HLC decimal: %v (%s) => %w", decimal, decimal.String(), err)
+ }
+
+ return rev, nil
+}
+
+// NewHLCForTime creates a new revision for the given time.
+func NewHLCForTime(time time.Time) HLCRevision {
+ return HLCRevision{time.UnixNano(), logicalClockOffset}
+}
+
+func (hlc HLCRevision) ByteSortable() bool {
+ return true
+}
+
+func (hlc HLCRevision) Equal(rhs datastore.Revision) bool {
+ if rhs == datastore.NoRevision {
+ rhs = zeroHLC
+ }
+
+ rhsHLC := rhs.(HLCRevision)
+ return hlc.time == rhsHLC.time && hlc.logicalclock == rhsHLC.logicalclock
+}
+
+func (hlc HLCRevision) GreaterThan(rhs datastore.Revision) bool {
+ if rhs == datastore.NoRevision {
+ rhs = zeroHLC
+ }
+
+ rhsHLC := rhs.(HLCRevision)
+ return hlc.time > rhsHLC.time || (hlc.time == rhsHLC.time && hlc.logicalclock > rhsHLC.logicalclock)
+}
+
+func (hlc HLCRevision) LessThan(rhs datastore.Revision) bool {
+ if rhs == datastore.NoRevision {
+ rhs = zeroHLC
+ }
+
+ rhsHLC := rhs.(HLCRevision)
+ return hlc.time < rhsHLC.time || (hlc.time == rhsHLC.time && hlc.logicalclock < rhsHLC.logicalclock)
+}
+
// String renders the revision as "<nanos>.<logical>", zero-padding the logical
// clock back to logicalClockLength digits after subtracting the storage offset.
func (hlc HLCRevision) String() string {
	logicalClockString := strconv.FormatInt(int64(hlc.logicalclock)-int64(logicalClockOffset), 10)
	return strconv.FormatInt(hlc.time, 10) + "." + strings.Repeat("0", logicalClockLength-len(logicalClockString)) + logicalClockString
}

// TimestampNanoSec returns the wall-clock component in nanoseconds.
func (hlc HLCRevision) TimestampNanoSec() int64 {
	return hlc.time
}

// InexactFloat64 returns a lossy float64 form: nanoseconds plus the logical
// clock as a fractional part.
func (hlc HLCRevision) InexactFloat64() float64 {
	return float64(hlc.time) + float64(hlc.logicalclock-logicalClockOffset)/math.Pow10(logicalClockLength)
}

// ConstructForTimestamp returns an HLC revision for the given timestamp with
// a zero logical clock.
func (hlc HLCRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevision {
	return HLCRevision{timestamp, logicalClockOffset}
}

// AsDecimal converts the revision into its decimal representation via String().
func (hlc HLCRevision) AsDecimal() (decimal.Decimal, error) {
	return decimal.NewFromString(hlc.String())
}

// Compile-time interface conformance checks.
var (
	_ datastore.Revision = HLCRevision{}
	_ WithTimestampRevision = HLCRevision{}
)
+
+// HLCKeyFunc is used to convert a simple HLC for use in maps.
+func HLCKeyFunc(r HLCRevision) HLCRevision {
+ return r
+}
+
+// HLCKeyLessThanFunc is used to compare keys created by the HLCKeyFunc.
+func HLCKeyLessThanFunc(lhs, rhs HLCRevision) bool {
+ return lhs.time < rhs.time || (lhs.time == rhs.time && lhs.logicalclock < rhs.logicalclock)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go
new file mode 100644
index 0000000..3a5a919
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/optimized.go
@@ -0,0 +1,118 @@
+package revisions
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "sync"
+ "time"
+
+ "github.com/benbjohnson/clock"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/trace"
+ "golang.org/x/sync/singleflight"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
+var tracer = otel.Tracer("spicedb/internal/datastore/common/revisions")
+
+// OptimizedRevisionFunction instructs the datastore to compute its own current
+// optimized revision given the specific quantization, and return for how long
+// it will remain valid.
+type OptimizedRevisionFunction func(context.Context) (rev datastore.Revision, validFor time.Duration, err error)
+
+// NewCachedOptimizedRevisions returns a CachedOptimizedRevisions for the given configuration
+func NewCachedOptimizedRevisions(maxRevisionStaleness time.Duration) *CachedOptimizedRevisions {
+ return &CachedOptimizedRevisions{
+ maxRevisionStaleness: maxRevisionStaleness,
+ clockFn: clock.New(),
+ }
+}
+
// SetOptimizedRevisionFunc must be called after construction, and is the method
// by which one specializes this helper for a specific datastore.
// NOTE(review): the field is written without holding the embedded lock;
// presumably it is set exactly once before concurrent use — confirm.
func (cor *CachedOptimizedRevisions) SetOptimizedRevisionFunc(revisionFunc OptimizedRevisionFunction) {
	cor.optimizedFunc = revisionFunc
}
+
// OptimizedRevision returns a cached candidate revision if one is still valid
// (within a randomized staleness allowance), and otherwise computes a fresh
// one via the configured optimizedFunc, deduplicating concurrent computations
// through a singleflight group.
func (cor *CachedOptimizedRevisions) OptimizedRevision(ctx context.Context) (datastore.Revision, error) {
	span := trace.SpanFromContext(ctx)
	localNow := cor.clockFn.Now()

	// Subtract a random amount of time from now, to let barely expired candidates get selected
	adjustedNow := localNow
	if cor.maxRevisionStaleness > 0 {
		// nolint:gosec
		// G404 use of non cryptographically secure random number generator is not a security concern here,
		// as we are using it to introduce randomness to the accepted staleness of a revision and reduce the odds of
		// a thundering herd to the datastore
		adjustedNow = localNow.Add(-1 * time.Duration(rand.Int63n(cor.maxRevisionStaleness.Nanoseconds())) * time.Nanosecond)
	}

	// Fast path: return any candidate still valid at the adjusted time.
	cor.RLock()
	for _, candidate := range cor.candidates {
		if candidate.validThrough.After(adjustedNow) {
			cor.RUnlock()
			log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", candidate.validThrough).Msg("returning cached revision")
			span.AddEvent("returning cached revision")
			return candidate.revision, nil
		}
	}
	cor.RUnlock()

	// Slow path: compute a new revision, collapsing concurrent callers into a
	// single computation via the singleflight group.
	newQuantizedRevision, err, _ := cor.updateGroup.Do("", func() (interface{}, error) {
		log.Ctx(ctx).Debug().Time("now", localNow).Msg("computing new revision")
		span.AddEvent("computing new revision")

		optimized, validFor, err := cor.optimizedFunc(ctx)
		if err != nil {
			return nil, fmt.Errorf("unable to compute optimized revision: %w", err)
		}

		rvt := localNow.Add(validFor)

		// Prune the candidates that have definitely expired
		cor.Lock()
		var numToDrop uint
		for _, candidate := range cor.candidates {
			if candidate.validThrough.Add(cor.maxRevisionStaleness).Before(localNow) {
				numToDrop++
			} else {
				break
			}
		}

		cor.candidates = cor.candidates[numToDrop:]
		cor.candidates = append(cor.candidates, validRevision{optimized, rvt})
		cor.Unlock()

		log.Ctx(ctx).Debug().Time("now", localNow).Time("valid", rvt).Stringer("validFor", validFor).Msg("setting valid through")
		return optimized, nil
	})
	if err != nil {
		return datastore.NoRevision, err
	}
	return newQuantizedRevision.(datastore.Revision), err
}
+
// CachedOptimizedRevisions does caching and deduplication for requests for optimized revisions.
type CachedOptimizedRevisions struct {
	sync.RWMutex

	maxRevisionStaleness time.Duration // extra window in which expired candidates may still be served
	optimizedFunc OptimizedRevisionFunction // computes a fresh revision; set via SetOptimizedRevisionFunc
	clockFn clock.Clock // clock source, injectable for testing

	// these values are read and set by multiple consumers
	candidates []validRevision // GUARDED_BY(RWMutex)

	// the updategroup consolidates concurrent requests to the database into 1
	updateGroup singleflight.Group
}

// validRevision pairs a revision with the time through which it may be served.
type validRevision struct {
	revision datastore.Revision
	validThrough time.Time
}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go
new file mode 100644
index 0000000..ef793c8
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/remoteclock.go
@@ -0,0 +1,125 @@
+package revisions
+
+import (
+ "context"
+ "time"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
// RemoteNowFunction queries the datastore to get a current revision.
type RemoteNowFunction func(context.Context) (datastore.Revision, error)

// RemoteClockRevisions handles revision calculation for datastores that provide
// their own clocks.
type RemoteClockRevisions struct {
	*CachedOptimizedRevisions

	gcWindowNanos int64 // GC window size, nanoseconds
	nowFunc RemoteNowFunction // set via SetNowFunc
	followerReadDelayNanos int64 // delay subtracted to allow follower reads
	quantizationNanos int64 // quantization period, nanoseconds; 0 disables quantization
}
+
+// NewRemoteClockRevisions returns a RemoteClockRevisions for the given configuration
+func NewRemoteClockRevisions(gcWindow, maxRevisionStaleness, followerReadDelay, quantization time.Duration) *RemoteClockRevisions {
+ // Ensure the max revision staleness never exceeds the GC window.
+ if maxRevisionStaleness > gcWindow {
+ log.Warn().
+ Dur("maxRevisionStaleness", maxRevisionStaleness).
+ Dur("gcWindow", gcWindow).
+ Msg("the configured maximum revision staleness exceeds the configured gc window, so capping to gcWindow")
+ maxRevisionStaleness = gcWindow - 1
+ }
+
+ revisions := &RemoteClockRevisions{
+ CachedOptimizedRevisions: NewCachedOptimizedRevisions(
+ maxRevisionStaleness,
+ ),
+ gcWindowNanos: gcWindow.Nanoseconds(),
+ followerReadDelayNanos: followerReadDelay.Nanoseconds(),
+ quantizationNanos: quantization.Nanoseconds(),
+ }
+
+ revisions.SetOptimizedRevisionFunc(revisions.optimizedRevisionFunc)
+
+ return revisions
+}
+
// optimizedRevisionFunc computes the current optimized revision from the
// remote clock: it applies the follower-read delay, rounds the result down to
// the quantization boundary, and reports how long that quantized revision
// remains valid (until the next boundary).
func (rcr *RemoteClockRevisions) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, error) {
	nowRev, err := rcr.nowFunc(ctx)
	if err != nil {
		return datastore.NoRevision, 0, err
	}

	if nowRev == datastore.NoRevision {
		return datastore.NoRevision, 0, datastore.NewInvalidRevisionErr(nowRev, datastore.CouldNotDetermineRevision)
	}

	nowTS, ok := nowRev.(WithTimestampRevision)
	if !ok {
		return datastore.NoRevision, 0, spiceerrors.MustBugf("expected with-timestamp revision, got %T", nowRev)
	}

	// Shift "now" back by the follower-read delay, then round down to the
	// last quantization boundary.
	delayedNow := nowTS.TimestampNanoSec() - rcr.followerReadDelayNanos
	quantized := delayedNow
	validForNanos := int64(0)
	if rcr.quantizationNanos > 0 {
		afterLastQuantization := delayedNow % rcr.quantizationNanos
		quantized -= afterLastQuantization
		validForNanos = rcr.quantizationNanos - afterLastQuantization
	}
	log.Ctx(ctx).Debug().
		Time("quantized", time.Unix(0, quantized)).
		Int64("readSkew", rcr.followerReadDelayNanos).
		Int64("totalSkew", nowTS.TimestampNanoSec()-quantized).
		Msg("revision skews")

	return nowTS.ConstructForTimestamp(quantized), time.Duration(validForNanos) * time.Nanosecond, nil
}
+
// SetNowFunc sets the function used to determine the head revision.
// Must be called before CheckRevision or the optimized revision computation
// is used, as both call through rcr.nowFunc.
func (rcr *RemoteClockRevisions) SetNowFunc(nowFunc RemoteNowFunction) {
	rcr.nowFunc = nowFunc
}
+
+func (rcr *RemoteClockRevisions) CheckRevision(ctx context.Context, dsRevision datastore.Revision) error {
+ if dsRevision == datastore.NoRevision {
+ return datastore.NewInvalidRevisionErr(dsRevision, datastore.CouldNotDetermineRevision)
+ }
+
+ revision := dsRevision.(WithTimestampRevision)
+
+ ctx, span := tracer.Start(ctx, "CheckRevision")
+ defer span.End()
+
+ // Make sure the system time indicated is within the software GC window
+ now, err := rcr.nowFunc(ctx)
+ if err != nil {
+ return err
+ }
+
+ nowTS, ok := now.(WithTimestampRevision)
+ if !ok {
+ return spiceerrors.MustBugf("expected HLC revision, got %T", now)
+ }
+
+ nowNanos := nowTS.TimestampNanoSec()
+ revisionNanos := revision.TimestampNanoSec()
+
+ isStale := revisionNanos < (nowNanos - rcr.gcWindowNanos)
+ if isStale {
+ log.Ctx(ctx).Debug().Stringer("now", now).Stringer("revision", revision).Msg("stale revision")
+ return datastore.NewInvalidRevisionErr(revision, datastore.RevisionStale)
+ }
+
+ isUnknown := revisionNanos > nowNanos
+ if isUnknown {
+ log.Ctx(ctx).Debug().Stringer("now", now).Stringer("revision", revision).Msg("unknown revision")
+ return datastore.NewInvalidRevisionErr(revision, datastore.CouldNotDetermineRevision)
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go
new file mode 100644
index 0000000..fc2a250
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/timestamprevision.go
@@ -0,0 +1,97 @@
+package revisions
+
+import (
+ "fmt"
+ "strconv"
+ "time"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// TimestampRevision is a revision that is a timestamp.
type TimestampRevision int64

// zeroTimestampRevision stands in for datastore.NoRevision in comparisons.
var zeroTimestampRevision = TimestampRevision(0)

// NewForTime creates a new revision for the given time.
func NewForTime(t time.Time) TimestampRevision {
	return TimestampRevision(t.UnixNano())
}

// NewForTimestamp creates a new revision for the given timestamp.
func NewForTimestamp(timestampNanosec int64) TimestampRevision {
	return TimestampRevision(timestampNanosec)
}
+
+// parseTimestampRevisionString parses a string into a timestamp revision.
+func parseTimestampRevisionString(revisionStr string) (rev datastore.Revision, err error) {
+ parsed, err := strconv.ParseInt(revisionStr, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid integer revision: %w", err)
+ }
+
+ return TimestampRevision(parsed), nil
+}
+
// ByteSortable indicates that timestamp revisions sort correctly as bytes.
func (ir TimestampRevision) ByteSortable() bool {
	return true
}

// Equal returns whether rhs is the same timestamp; NoRevision compares as zero.
func (ir TimestampRevision) Equal(rhs datastore.Revision) bool {
	if rhs == datastore.NoRevision {
		rhs = zeroTimestampRevision
	}

	return int64(ir) == int64(rhs.(TimestampRevision))
}

// GreaterThan returns whether this revision is strictly after rhs.
func (ir TimestampRevision) GreaterThan(rhs datastore.Revision) bool {
	if rhs == datastore.NoRevision {
		rhs = zeroTimestampRevision
	}

	return int64(ir) > int64(rhs.(TimestampRevision))
}

// LessThan returns whether this revision is strictly before rhs.
func (ir TimestampRevision) LessThan(rhs datastore.Revision) bool {
	if rhs == datastore.NoRevision {
		rhs = zeroTimestampRevision
	}

	return int64(ir) < int64(rhs.(TimestampRevision))
}

// TimestampNanoSec returns the revision as a nanosecond timestamp.
func (ir TimestampRevision) TimestampNanoSec() int64 {
	return int64(ir)
}

// String renders the revision as its base-10 nanosecond value.
func (ir TimestampRevision) String() string {
	return strconv.FormatInt(int64(ir), 10)
}

// Time returns the revision as a time.Time.
func (ir TimestampRevision) Time() time.Time {
	return time.Unix(0, int64(ir))
}

// WithInexactFloat64 returns a float64 representation of the revision.
// NOTE(review): this method name does NOT satisfy the WithInexactFloat64
// interface, which requires a method named InexactFloat64 — confirm whether
// the naming is intentional or a latent mismatch.
func (ir TimestampRevision) WithInexactFloat64() float64 {
	return float64(ir)
}

// ConstructForTimestamp builds a TimestampRevision for the given nanoseconds.
func (ir TimestampRevision) ConstructForTimestamp(timestamp int64) WithTimestampRevision {
	return TimestampRevision(timestamp)
}

// Compile-time interface conformance checks.
var (
	_ datastore.Revision = TimestampRevision(0)
	_ WithTimestampRevision = TimestampRevision(0)
)
+
+// TimestampIDKeyFunc is used to create keys for timestamps.
+func TimestampIDKeyFunc(r TimestampRevision) int64 {
+ return int64(r)
+}
+
+// TimestampIDKeyLessThanFunc is used to create keys for timestamps.
+func TimestampIDKeyLessThanFunc(l, r int64) bool {
+ return l < r
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go
new file mode 100644
index 0000000..31d837f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/datastore/revisions/txidrevision.go
@@ -0,0 +1,80 @@
+package revisions
+
+import (
+ "fmt"
+ "strconv"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// TransactionIDRevision is a revision that is a transaction ID.
type TransactionIDRevision uint64

// zeroTransactionIDRevision stands in for datastore.NoRevision in comparisons.
var zeroTransactionIDRevision = TransactionIDRevision(0)

// NewForTransactionID creates a new revision for the given transaction ID.
func NewForTransactionID(transactionID uint64) TransactionIDRevision {
	return TransactionIDRevision(transactionID)
}
+
+// parseTransactionIDRevisionString parses a string into a transaction ID revision.
+func parseTransactionIDRevisionString(revisionStr string) (rev datastore.Revision, err error) {
+ parsed, err := strconv.ParseUint(revisionStr, 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("invalid integer revision: %w", err)
+ }
+
+ return TransactionIDRevision(parsed), nil
+}
+
+func (ir TransactionIDRevision) ByteSortable() bool {
+ return true
+}
+
+func (ir TransactionIDRevision) Equal(rhs datastore.Revision) bool {
+ if rhs == datastore.NoRevision {
+ rhs = zeroTransactionIDRevision
+ }
+
+ return uint64(ir) == uint64(rhs.(TransactionIDRevision))
+}
+
+func (ir TransactionIDRevision) GreaterThan(rhs datastore.Revision) bool {
+ if rhs == datastore.NoRevision {
+ rhs = zeroTransactionIDRevision
+ }
+
+ return uint64(ir) > uint64(rhs.(TransactionIDRevision))
+}
+
+func (ir TransactionIDRevision) LessThan(rhs datastore.Revision) bool {
+ if rhs == datastore.NoRevision {
+ rhs = zeroTransactionIDRevision
+ }
+
+ return uint64(ir) < uint64(rhs.(TransactionIDRevision))
+}
+
+func (ir TransactionIDRevision) TransactionID() uint64 {
+ return uint64(ir)
+}
+
+func (ir TransactionIDRevision) String() string {
+ return strconv.FormatUint(uint64(ir), 10)
+}
+
+func (ir TransactionIDRevision) WithInexactFloat64() float64 {
+ return float64(ir)
+}
+
// Compile-time assertion that TransactionIDRevision satisfies the revision interface.
var _ datastore.Revision = TransactionIDRevision(0)

// TransactionIDKeyFunc is used to create keys for transaction IDs.
func TransactionIDKeyFunc(r TransactionIDRevision) uint64 {
	return uint64(r)
}
+
// TransactionIDKeyLessThanFunc reports whether the transaction ID key l
// orders strictly before the key r.
func TransactionIDKeyLessThanFunc(l, r uint64) bool {
	return r > l
}
diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/doc.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/doc.go
new file mode 100644
index 0000000..325981d
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/doc.go
@@ -0,0 +1,2 @@
+// Package developmentmembership defines operations with sets. To be used in tests only.
+package developmentmembership
diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/foundsubject.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/foundsubject.go
new file mode 100644
index 0000000..bf93d56
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/foundsubject.go
@@ -0,0 +1,127 @@
+package developmentmembership
+
+import (
+ "sort"
+ "strings"
+
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// NewFoundSubject creates a new FoundSubject for a subject and a set of its resources.
+func NewFoundSubject(subject *core.DirectSubject, resources ...tuple.ObjectAndRelation) FoundSubject {
+ return FoundSubject{tuple.FromCoreObjectAndRelation(subject.Subject), nil, subject.CaveatExpression, NewONRSet(resources...)}
+}
+
// FoundSubject contains a single found subject and all the relationships in which that subject
// is a member which were found via the ONRs expansion.
type FoundSubject struct {
	// subject is the subject found.
	subject tuple.ObjectAndRelation

	// excludedSubjects are any subjects excluded. Only should be set if subject is a wildcard.
	excludedSubjects []FoundSubject

	// caveatExpression is the conditional expression on the found subject, if any;
	// nil means the subject is unconditionally present.
	caveatExpression *core.CaveatExpression

	// resources are the resources under which the subject lives that informed the locating
	// of this subject for the root ONR.
	resources ONRSet
}
+
// GetSubjectId is named to match the Subject interface for the BaseSubjectSet.
//
//nolint:all
func (fs FoundSubject) GetSubjectId() string {
	return fs.subject.ObjectID
}

// GetCaveatExpression returns the conditional expression on the found subject,
// or nil if the subject is unconditional.
func (fs FoundSubject) GetCaveatExpression() *core.CaveatExpression {
	return fs.caveatExpression
}

// GetExcludedSubjects returns the subjects excluded from this subject, if any;
// only populated when the subject is a wildcard.
func (fs FoundSubject) GetExcludedSubjects() []FoundSubject {
	return fs.excludedSubjects
}

// Subject returns the Subject of the FoundSubject.
func (fs FoundSubject) Subject() tuple.ObjectAndRelation {
	return fs.subject
}
+
+// WildcardType returns the object type for the wildcard subject, if this is a wildcard subject.
+func (fs FoundSubject) WildcardType() (string, bool) {
+ if fs.subject.ObjectID == tuple.PublicWildcard {
+ return fs.subject.ObjectType, true
+ }
+
+ return "", false
+}
+
+// ExcludedSubjectsFromWildcard returns those subjects excluded from the wildcard subject.
+// If not a wildcard subject, returns false.
+func (fs FoundSubject) ExcludedSubjectsFromWildcard() ([]FoundSubject, bool) {
+ if fs.subject.ObjectID == tuple.PublicWildcard {
+ return fs.excludedSubjects, true
+ }
+
+ return nil, false
+}
+
+func (fs FoundSubject) excludedSubjectStrings() []string {
+ excludedStrings := make([]string, 0, len(fs.excludedSubjects))
+ for _, excludedSubject := range fs.excludedSubjects {
+ excludedSubjectString := tuple.StringONR(excludedSubject.subject)
+ if excludedSubject.GetCaveatExpression() != nil {
+ excludedSubjectString += "[...]"
+ }
+ excludedStrings = append(excludedStrings, excludedSubjectString)
+ }
+
+ sort.Strings(excludedStrings)
+ return excludedStrings
+}
+
+// ToValidationString returns the FoundSubject in a format that is consumable by the validationfile
+// package.
+func (fs FoundSubject) ToValidationString() string {
+ onrString := tuple.StringONR(fs.Subject())
+ validationString := onrString
+ if fs.caveatExpression != nil {
+ validationString = validationString + "[...]"
+ }
+
+ excluded, isWildcard := fs.ExcludedSubjectsFromWildcard()
+ if isWildcard && len(excluded) > 0 {
+ validationString = validationString + " - {" + strings.Join(fs.excludedSubjectStrings(), ", ") + "}"
+ }
+
+ return validationString
+}
+
+func (fs FoundSubject) String() string {
+ return fs.ToValidationString()
+}
+
// ParentResources returns all the resources in which the subject was found as per the expand.
func (fs FoundSubject) ParentResources() []tuple.ObjectAndRelation {
	return fs.resources.AsSlice()
}

// FoundSubjects contains the subjects found for a specific ONR.
type FoundSubjects struct {
	// subjects is the set of subjects found, tracked by subject type and ID.
	subjects *TrackingSubjectSet
}

// ListFound returns a slice of all the FoundSubject's.
func (fs FoundSubjects) ListFound() []FoundSubject {
	return fs.subjects.ToSlice()
}

// LookupSubject returns the FoundSubject for a matching subject, if any.
func (fs FoundSubjects) LookupSubject(subject tuple.ObjectAndRelation) (FoundSubject, bool) {
	return fs.subjects.Get(subject)
}
diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/membership.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/membership.go
new file mode 100644
index 0000000..caa5e8f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/membership.go
@@ -0,0 +1,167 @@
+package developmentmembership
+
+import (
+ "fmt"
+
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// Set represents the set of membership for one or more ONRs, based on expansion
// trees.
type Set struct {
	// objectsAndRelations is a map from an ONR (as a string) to the subjects found for that ONR.
	objectsAndRelations map[string]FoundSubjects
}

// SubjectsByONR returns a map from ONR (as a string) to the FoundSubjects for that ONR.
func (ms *Set) SubjectsByONR() map[string]FoundSubjects {
	return ms.objectsAndRelations
}

// NewMembershipSet constructs a new membership set.
//
// NOTE: This is designed solely for the developer API and should *not* be used in any performance
// sensitive code.
func NewMembershipSet() *Set {
	return &Set{
		objectsAndRelations: map[string]FoundSubjects{},
	}
}
+
+// AddExpansion adds the expansion of an ONR to the membership set. Returns false if the ONR was already added.
+//
+// NOTE: The expansion tree *should* be the fully recursive expansion.
+func (ms *Set) AddExpansion(onr tuple.ObjectAndRelation, expansion *core.RelationTupleTreeNode) (FoundSubjects, bool, error) {
+ onrString := tuple.StringONR(onr)
+ existing, ok := ms.objectsAndRelations[onrString]
+ if ok {
+ return existing, false, nil
+ }
+
+ tss, err := populateFoundSubjects(onr, expansion)
+ if err != nil {
+ return FoundSubjects{}, false, err
+ }
+
+ fs := tss.ToFoundSubjects()
+ ms.objectsAndRelations[onrString] = fs
+ return fs, true, nil
+}
+
// AccessibleExpansionSubjects returns a TrackingSubjectSet representing the set of accessible subjects in the expansion.
// The tree's own Expanded ONR is used as the root resource for the walk.
func AccessibleExpansionSubjects(treeNode *core.RelationTupleTreeNode) (*TrackingSubjectSet, error) {
	return populateFoundSubjects(tuple.FromCoreObjectAndRelation(treeNode.Expanded), treeNode)
}
+
// populateFoundSubjects recursively walks a fully-expanded relation tuple tree
// and builds the TrackingSubjectSet of subjects found beneath it.
//
// rootONR is the resource under expansion; when the node carries its own
// Expanded ONR, that overrides rootONR as the resource attributed to leaf
// subjects found at this node.
func populateFoundSubjects(rootONR tuple.ObjectAndRelation, treeNode *core.RelationTupleTreeNode) (*TrackingSubjectSet, error) {
	// Prefer the node's own expanded ONR as the resource, falling back to the
	// caller-provided root.
	resource := rootONR
	if treeNode.Expanded != nil {
		resource = tuple.FromCoreObjectAndRelation(treeNode.Expanded)
	}

	switch typed := treeNode.NodeType.(type) {
	case *core.RelationTupleTreeNode_IntermediateNode:
		switch typed.IntermediateNode.Operation {
		case core.SetOperationUserset_UNION:
			// Union: merge the subjects of every child set.
			toReturn := NewTrackingSubjectSet()
			for _, child := range typed.IntermediateNode.ChildNodes {
				// NOTE(review): the union branch recurses with the local
				// `resource`, while the intersection and exclusion branches
				// below recurse with `rootONR` — confirm this asymmetry is
				// intentional.
				tss, err := populateFoundSubjects(resource, child)
				if err != nil {
					return nil, err
				}

				err = toReturn.AddFrom(tss)
				if err != nil {
					return nil, err
				}
			}

			toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression)
			return toReturn, nil

		case core.SetOperationUserset_INTERSECTION:
			// Intersection: start from the first child's subjects and narrow
			// by each subsequent child.
			if len(typed.IntermediateNode.ChildNodes) == 0 {
				return nil, fmt.Errorf("found intersection with no children")
			}

			firstChildSet, err := populateFoundSubjects(rootONR, typed.IntermediateNode.ChildNodes[0])
			if err != nil {
				return nil, err
			}

			toReturn := NewTrackingSubjectSet()
			err = toReturn.AddFrom(firstChildSet)
			if err != nil {
				return nil, err
			}

			for _, child := range typed.IntermediateNode.ChildNodes[1:] {
				childSet, err := populateFoundSubjects(rootONR, child)
				if err != nil {
					return nil, err
				}

				updated, err := toReturn.Intersect(childSet)
				if err != nil {
					return nil, err
				}

				toReturn = updated
			}

			toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression)
			return toReturn, nil

		case core.SetOperationUserset_EXCLUSION:
			// Exclusion: start from the first child's subjects and subtract
			// each subsequent child.
			if len(typed.IntermediateNode.ChildNodes) == 0 {
				return nil, fmt.Errorf("found exclusion with no children")
			}

			firstChildSet, err := populateFoundSubjects(rootONR, typed.IntermediateNode.ChildNodes[0])
			if err != nil {
				return nil, err
			}

			toReturn := NewTrackingSubjectSet()
			err = toReturn.AddFrom(firstChildSet)
			if err != nil {
				return nil, err
			}

			for _, child := range typed.IntermediateNode.ChildNodes[1:] {
				childSet, err := populateFoundSubjects(rootONR, child)
				if err != nil {
					return nil, err
				}
				toReturn = toReturn.Exclude(childSet)
			}

			toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression)
			return toReturn, nil

		default:
			return nil, spiceerrors.MustBugf("unknown expand operation")
		}

	case *core.RelationTupleTreeNode_LeafNode:
		// Leaf: record each direct subject, attributing the current resource.
		toReturn := NewTrackingSubjectSet()
		for _, subject := range typed.LeafNode.Subjects {
			fs := NewFoundSubject(subject)
			err := toReturn.Add(fs)
			if err != nil {
				return nil, err
			}

			// fs.resources wraps a pointer-backed set (see ONRSet), so this
			// mutation is visible to the copy stored in toReturn above.
			fs.resources.Add(resource)
		}

		toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression)
		return toReturn, nil

	default:
		return nil, spiceerrors.MustBugf("unknown TreeNode type")
	}
}
diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/onrset.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/onrset.go
new file mode 100644
index 0000000..ad7fcfd
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/onrset.go
@@ -0,0 +1,87 @@
+package developmentmembership
+
+import (
+ "github.com/ccoveille/go-safecast"
+
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// TODO(jschorr): Replace with the generic set over tuple.ObjectAndRelation

// ONRSet is a set of ObjectAndRelation's.
type ONRSet struct {
	// onrs is the pointer-backed underlying set; ONRSet values therefore share
	// state when copied.
	onrs *mapz.Set[tuple.ObjectAndRelation]
}

// NewONRSet creates a new set containing the given ONRs, if any.
func NewONRSet(onrs ...tuple.ObjectAndRelation) ONRSet {
	created := ONRSet{
		onrs: mapz.NewSet[tuple.ObjectAndRelation](),
	}
	created.Update(onrs)
	return created
}
+
+// Length returns the size of the set.
+func (ons ONRSet) Length() uint64 {
+ // This is the length of a set so we should never fall out of bounds.
+ length, _ := safecast.ToUint64(ons.onrs.Len())
+ return length
+}
+
+// IsEmpty returns whether the set is empty.
+func (ons ONRSet) IsEmpty() bool {
+ return ons.onrs.IsEmpty()
+}
+
+// Has returns true if the set contains the given ONR.
+func (ons ONRSet) Has(onr tuple.ObjectAndRelation) bool {
+ return ons.onrs.Has(onr)
+}
+
+// Add adds the given ONR to the set. Returns true if the object was not in the set before this
+// call and false otherwise.
+func (ons ONRSet) Add(onr tuple.ObjectAndRelation) bool {
+ return ons.onrs.Add(onr)
+}
+
+// Update updates the set by adding the given ONRs to it.
+func (ons ONRSet) Update(onrs []tuple.ObjectAndRelation) {
+ for _, onr := range onrs {
+ ons.Add(onr)
+ }
+}
+
+// UpdateFrom updates the set by adding the ONRs found in the other set to it.
+func (ons ONRSet) UpdateFrom(otherSet ONRSet) {
+ if otherSet.onrs == nil {
+ return
+ }
+ ons.onrs.Merge(otherSet.onrs)
+}
+
+// Intersect returns an intersection between this ONR set and the other set provided.
+func (ons ONRSet) Intersect(otherSet ONRSet) ONRSet {
+ return ONRSet{ons.onrs.Intersect(otherSet.onrs)}
+}
+
+// Subtract returns a subtraction from this ONR set of the other set provided.
+func (ons ONRSet) Subtract(otherSet ONRSet) ONRSet {
+ return ONRSet{ons.onrs.Subtract(otherSet.onrs)}
+}
+
+// Union returns a copy of this ONR set with the other set's elements added in.
+func (ons ONRSet) Union(otherSet ONRSet) ONRSet {
+ return ONRSet{ons.onrs.Union(otherSet.onrs)}
+}
+
+// AsSlice returns the ONRs found in the set as a slice.
+func (ons ONRSet) AsSlice() []tuple.ObjectAndRelation {
+ slice := make([]tuple.ObjectAndRelation, 0, ons.Length())
+ _ = ons.onrs.ForEach(func(onr tuple.ObjectAndRelation) error {
+ slice = append(slice, onr)
+ return nil
+ })
+ return slice
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/developmentmembership/trackingsubjectset.go b/vendor/github.com/authzed/spicedb/internal/developmentmembership/trackingsubjectset.go
new file mode 100644
index 0000000..00f8836
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/developmentmembership/trackingsubjectset.go
@@ -0,0 +1,235 @@
+package developmentmembership
+
+import (
+ "github.com/authzed/spicedb/internal/datasets"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// TrackingSubjectSet defines a set that tracks accessible subjects and their associated
// relationships.
//
// NOTE: This is designed solely for the developer API and testing and should *not* be used in any
// performance sensitive code.
type TrackingSubjectSet struct {
	// setByType holds one BaseSubjectSet per subject type+relation pair.
	setByType map[tuple.RelationReference]datasets.BaseSubjectSet[FoundSubject]
}

// NewTrackingSubjectSet creates a new, empty TrackingSubjectSet.
func NewTrackingSubjectSet() *TrackingSubjectSet {
	tss := &TrackingSubjectSet{
		setByType: map[tuple.RelationReference]datasets.BaseSubjectSet[FoundSubject]{},
	}
	return tss
}
+
+// MustNewTrackingSubjectSetWith creates a new TrackingSubjectSet, and adds the specified
+// subjects to it.
+func MustNewTrackingSubjectSetWith(subjects ...FoundSubject) *TrackingSubjectSet {
+ tss := NewTrackingSubjectSet()
+ for _, subject := range subjects {
+ err := tss.Add(subject)
+ if err != nil {
+ panic(err)
+ }
+ }
+ return tss
+}
+
+// AddFrom adds the subjects found in the other set to this set.
+func (tss *TrackingSubjectSet) AddFrom(otherSet *TrackingSubjectSet) error {
+ for key, oss := range otherSet.setByType {
+ err := tss.getSetForKey(key).UnionWithSet(oss)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// MustAddFrom adds the subjects found in the other set to this set.
+func (tss *TrackingSubjectSet) MustAddFrom(otherSet *TrackingSubjectSet) {
+ err := tss.AddFrom(otherSet)
+ if err != nil {
+ panic(err)
+ }
+}
+
+// RemoveFrom removes any subjects found in the other set from this set.
+func (tss *TrackingSubjectSet) RemoveFrom(otherSet *TrackingSubjectSet) {
+ for key, oss := range otherSet.setByType {
+ tss.getSetForKey(key).SubtractAll(oss)
+ }
+}
+
+// MustAdd adds the given subjects to this set.
+func (tss *TrackingSubjectSet) MustAdd(subjectsAndResources ...FoundSubject) {
+ err := tss.Add(subjectsAndResources...)
+ if err != nil {
+ panic(err)
+ }
+}
+
+// Add adds the given subjects to this set.
+func (tss *TrackingSubjectSet) Add(subjectsAndResources ...FoundSubject) error {
+ for _, fs := range subjectsAndResources {
+ err := tss.getSet(fs).Add(fs)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// getSetForKey returns the BaseSubjectSet for the given subject type+relation,
// lazily creating and memoizing one on first use.
func (tss *TrackingSubjectSet) getSetForKey(key tuple.RelationReference) datasets.BaseSubjectSet[FoundSubject] {
	if existing, ok := tss.setByType[key]; ok {
		return existing
	}

	// The constructor closure tells the BaseSubjectSet how to build a combined
	// FoundSubject for this type: rebuild the ONR from the key, carry over the
	// caveat expression and exclusions, and union the parent resources of all
	// source subjects.
	created := datasets.NewBaseSubjectSet(
		func(subjectID string, caveatExpression *core.CaveatExpression, excludedSubjects []FoundSubject, sources ...FoundSubject) FoundSubject {
			fs := NewFoundSubject(&core.DirectSubject{
				Subject: &core.ObjectAndRelation{
					Namespace: key.ObjectType,
					ObjectId: subjectID,
					Relation: key.Relation,
				},
				CaveatExpression: caveatExpression,
			})
			fs.excludedSubjects = excludedSubjects
			fs.caveatExpression = caveatExpression
			for _, source := range sources {
				fs.resources.UpdateFrom(source.resources)
			}
			return fs
		},
	)
	tss.setByType[key] = created
	return created
}

// getSet returns the BaseSubjectSet for the subject's type+relation.
func (tss *TrackingSubjectSet) getSet(fs FoundSubject) datasets.BaseSubjectSet[FoundSubject] {
	return tss.getSetForKey(fs.subject.RelationReference())
}
+
+// Get returns the found subject in the set, if any.
+func (tss *TrackingSubjectSet) Get(subject tuple.ObjectAndRelation) (FoundSubject, bool) {
+ set, ok := tss.setByType[subject.RelationReference()]
+ if !ok {
+ return FoundSubject{}, false
+ }
+
+ return set.Get(subject.ObjectID)
+}
+
+// Contains returns true if the set contains the given subject.
+func (tss *TrackingSubjectSet) Contains(subject tuple.ObjectAndRelation) bool {
+ _, ok := tss.Get(subject)
+ return ok
+}
+
+// Exclude returns a new set that contains the items in this set minus those in the other set.
+func (tss *TrackingSubjectSet) Exclude(otherSet *TrackingSubjectSet) *TrackingSubjectSet {
+ newSet := NewTrackingSubjectSet()
+
+ for key, bss := range tss.setByType {
+ cloned := bss.Clone()
+ if oss, ok := otherSet.setByType[key]; ok {
+ cloned.SubtractAll(oss)
+ }
+
+ newSet.setByType[key] = cloned
+ }
+
+ return newSet
+}
+
+// MustIntersect returns a new set that contains the items in this set *and* the other set. Note that
+// if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found
+// on the other side of the intersection.
+func (tss *TrackingSubjectSet) MustIntersect(otherSet *TrackingSubjectSet) *TrackingSubjectSet {
+ updated, err := tss.Intersect(otherSet)
+ if err != nil {
+ panic(err)
+ }
+ return updated
+}
+
+// Intersect returns a new set that contains the items in this set *and* the other set. Note that
+// if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found
+// on the other side of the intersection.
+func (tss *TrackingSubjectSet) Intersect(otherSet *TrackingSubjectSet) (*TrackingSubjectSet, error) {
+ newSet := NewTrackingSubjectSet()
+
+ for key, bss := range tss.setByType {
+ if oss, ok := otherSet.setByType[key]; ok {
+ cloned := bss.Clone()
+ err := cloned.IntersectionDifference(oss)
+ if err != nil {
+ return nil, err
+ }
+
+ newSet.setByType[key] = cloned
+ }
+ }
+
+ return newSet, nil
+}
+
+// ApplyParentCaveatExpression applies the given parent caveat expression (if any) to each subject set.
+func (tss *TrackingSubjectSet) ApplyParentCaveatExpression(parentCaveatExpr *core.CaveatExpression) {
+ if parentCaveatExpr == nil {
+ return
+ }
+
+ for key, bss := range tss.setByType {
+ tss.setByType[key] = bss.WithParentCaveatExpression(parentCaveatExpr)
+ }
+}
+
+// removeExact removes the given subject(s) from the set. If the subject is a wildcard, only
+// the exact matching wildcard will be removed.
+func (tss *TrackingSubjectSet) removeExact(subjects ...tuple.ObjectAndRelation) {
+ for _, subject := range subjects {
+ if set, ok := tss.setByType[subject.RelationReference()]; ok {
+ set.UnsafeRemoveExact(FoundSubject{
+ subject: subject,
+ })
+ }
+ }
+}
+
+func (tss *TrackingSubjectSet) getSubjects() []string {
+ var subjects []string
+ for _, subjectSet := range tss.setByType {
+ for _, foundSubject := range subjectSet.AsSlice() {
+ subjects = append(subjects, tuple.StringONR(foundSubject.subject))
+ }
+ }
+ return subjects
+}
+
+// ToSlice returns a slice of all subjects found in the set.
+func (tss *TrackingSubjectSet) ToSlice() []FoundSubject {
+ subjects := []FoundSubject{}
+ for _, bss := range tss.setByType {
+ subjects = append(subjects, bss.AsSlice()...)
+ }
+
+ return subjects
+}
+
+// ToFoundSubjects returns the set as a FoundSubjects struct.
+func (tss *TrackingSubjectSet) ToFoundSubjects() FoundSubjects {
+ return FoundSubjects{tss}
+}
+
+// IsEmpty returns true if the tracking subject set is empty.
+func (tss *TrackingSubjectSet) IsEmpty() bool {
+ for _, bss := range tss.setByType {
+ if !bss.IsEmpty() {
+ return false
+ }
+ }
+ return true
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/dispatch.go b/vendor/github.com/authzed/spicedb/internal/dispatch/dispatch.go
new file mode 100644
index 0000000..95a231a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/dispatch/dispatch.go
@@ -0,0 +1,98 @@
+package dispatch
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/rs/zerolog"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+)
+
// ReadyState represents the ready state of the dispatcher.
type ReadyState struct {
	// Message is a human-readable status message for the current state.
	Message string

	// IsReady indicates whether the datastore is ready.
	IsReady bool
}

// Dispatcher interface describes a method for passing subchecks off to additional machines.
// It composes the per-operation dispatch interfaces plus lifecycle methods.
type Dispatcher interface {
	Check
	Expand
	LookupSubjects
	LookupResources2

	// Close closes the dispatcher.
	Close() error

	// ReadyState returns true when dispatcher is able to respond to requests
	ReadyState() ReadyState
}
+
// Check interface describes just the methods required to dispatch check requests.
type Check interface {
	// DispatchCheck submits a single check request and returns its result.
	DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error)
}

// Expand interface describes just the methods required to dispatch expand requests.
type Expand interface {
	// DispatchExpand submits a single expand request and returns its result.
	// If an error is returned, DispatchExpandResponse will still contain Metadata.
	DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error)
}

// LookupResources2Stream is an alias for the stream to which found resources will be written.
type LookupResources2Stream = Stream[*v1.DispatchLookupResources2Response]

// LookupResources2 interface describes just the methods required to dispatch lookup resources requests.
type LookupResources2 interface {
	// DispatchLookupResources2 submits a single lookup resources request, writing its results to the specified stream.
	DispatchLookupResources2(
		req *v1.DispatchLookupResources2Request,
		stream LookupResources2Stream,
	) error
}

// LookupSubjectsStream is an alias for the stream to which found subjects will be written.
type LookupSubjectsStream = Stream[*v1.DispatchLookupSubjectsResponse]

// LookupSubjects interface describes just the methods required to dispatch lookup subjects requests.
type LookupSubjects interface {
	// DispatchLookupSubjects submits a single lookup subjects request, writing its results to the specified stream.
	DispatchLookupSubjects(
		req *v1.DispatchLookupSubjectsRequest,
		stream LookupSubjectsStream,
	) error
}

// DispatchableRequest is an interface for requests that can be dispatched and logged.
type DispatchableRequest interface {
	zerolog.LogObjectMarshaler

	// GetMetadata returns the resolver metadata attached to the request, if any.
	GetMetadata() *v1.ResolverMeta
}
+
+// CheckDepth returns ErrMaxDepth if there is insufficient depth remaining to dispatch.
+func CheckDepth(ctx context.Context, req DispatchableRequest) error {
+ metadata := req.GetMetadata()
+ if metadata == nil {
+ log.Ctx(ctx).Warn().Object("request", req).Msg("request missing metadata")
+ return fmt.Errorf("request missing metadata")
+ }
+
+ if metadata.DepthRemaining == 0 {
+ return NewMaxDepthExceededError(req)
+ }
+
+ return nil
+}
+
+// AddResponseMetadata adds the metadata found in the incoming metadata to the existing
+// metadata, *modifying it in place*.
+func AddResponseMetadata(existing *v1.ResponseMeta, incoming *v1.ResponseMeta) {
+ existing.DispatchCount += incoming.DispatchCount
+ existing.CachedDispatchCount += incoming.CachedDispatchCount
+ existing.DepthRequired = max(existing.DepthRequired, incoming.DepthRequired)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/doc.go b/vendor/github.com/authzed/spicedb/internal/dispatch/doc.go
new file mode 100644
index 0000000..8b88bb0
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/dispatch/doc.go
@@ -0,0 +1,2 @@
+// Package dispatch contains logic to dispatch requests locally or to other nodes.
+package dispatch
diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/errors.go b/vendor/github.com/authzed/spicedb/internal/dispatch/errors.go
new file mode 100644
index 0000000..17cec3f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/dispatch/errors.go
@@ -0,0 +1,39 @@
+package dispatch
+
+import (
+ "fmt"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
// MaxDepthExceededError is an error returned when the maximum depth for dispatching has been exceeded.
// The underlying error is embedded so the type satisfies the error interface.
type MaxDepthExceededError struct {
	error

	// Request is the request that exceeded the maximum depth.
	Request DispatchableRequest
}

// NewMaxDepthExceededError creates a new MaxDepthExceededError wrapping the given request.
func NewMaxDepthExceededError(req DispatchableRequest) error {
	return MaxDepthExceededError{
		fmt.Errorf("max depth exceeded: this usually indicates a recursive or too deep data dependency. See: https://spicedb.dev/d/debug-max-depth"),
		req,
	}
}
+
// GRPCStatus implements retrieving the gRPC status for the error,
// reporting ResourceExhausted with the maximum-depth-exceeded error reason.
func (err MaxDepthExceededError) GRPCStatus() *status.Status {
	return spiceerrors.WithCodeAndDetails(
		err,
		codes.ResourceExhausted,
		spiceerrors.ForReason(
			v1.ErrorReason_ERROR_REASON_MAXIMUM_DEPTH_EXCEEDED,
			map[string]string{},
		),
	)
}
diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/graph/errors.go b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/errors.go
new file mode 100644
index 0000000..ecaf59a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/errors.go
@@ -0,0 +1,77 @@
+package graph
+
+import (
+ "fmt"
+
+ "github.com/rs/zerolog"
+)
+
+// NamespaceNotFoundError occurs when a namespace was not found.
+type NamespaceNotFoundError struct {
+ error
+ namespaceName string
+}
+
+// NotFoundNamespaceName returns the name of the namespace that was not found.
+func (err NamespaceNotFoundError) NotFoundNamespaceName() string {
+ return err.namespaceName
+}
+
+// MarshalZerologObject implements zerolog.LogObjectMarshaler
+func (err NamespaceNotFoundError) MarshalZerologObject(e *zerolog.Event) {
+ e.Err(err.error).Str("namespace", err.namespaceName)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err NamespaceNotFoundError) DetailsMetadata() map[string]string {
+ return map[string]string{
+ "definition_name": err.namespaceName,
+ }
+}
+
+// NewNamespaceNotFoundErr constructs a new namespace not found error.
+func NewNamespaceNotFoundErr(nsName string) error {
+ return NamespaceNotFoundError{
+ error: fmt.Errorf("object definition `%s` not found", nsName),
+ namespaceName: nsName,
+ }
+}
+
+// RelationNotFoundError occurs when a relation was not found under a namespace.
+type RelationNotFoundError struct {
+ error
+ namespaceName string
+ relationName string
+}
+
+// NamespaceName returns the name of the namespace in which the relation was not found.
+func (err RelationNotFoundError) NamespaceName() string {
+ return err.namespaceName
+}
+
+// NotFoundRelationName returns the name of the relation not found.
+func (err RelationNotFoundError) NotFoundRelationName() string {
+ return err.relationName
+}
+
+// MarshalZerologObject implements zerolog.LogObjectMarshaler
+func (err RelationNotFoundError) MarshalZerologObject(e *zerolog.Event) {
+ e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err RelationNotFoundError) DetailsMetadata() map[string]string {
+ return map[string]string{
+ "definition_name": err.namespaceName,
+ "relation_or_permission_name": err.relationName,
+ }
+}
+
+// NewRelationNotFoundErr constructs a new relation not found error.
+func NewRelationNotFoundErr(nsName string, relationName string) error {
+ return RelationNotFoundError{
+ error: fmt.Errorf("relation/permission `%s` not found under definition `%s`", relationName, nsName),
+ namespaceName: nsName,
+ relationName: relationName,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/graph/graph.go b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/graph.go
new file mode 100644
index 0000000..232b1e7
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/graph.go
@@ -0,0 +1,437 @@
+package graph
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/rs/zerolog"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/graph"
+ log "github.com/authzed/spicedb/internal/logging"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/middleware/nodeid"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+const errDispatch = "error dispatching request: %w"
+
+var tracer = otel.Tracer("spicedb/internal/dispatch/local")
+
+// ConcurrencyLimits defines per-dispatch-type concurrency limits.
+//
+//go:generate go run github.com/ecordell/optgen -output zz_generated.options.go . ConcurrencyLimits
+type ConcurrencyLimits struct {
+ Check uint16 `debugmap:"visible"`
+ ReachableResources uint16 `debugmap:"visible"`
+ LookupResources uint16 `debugmap:"visible"`
+ LookupSubjects uint16 `debugmap:"visible"`
+}
+
+const defaultConcurrencyLimit = 50
+
+// WithOverallDefaultLimit sets the overall default limit for any unspecified limits
+// and returns a new struct.
+func (cl ConcurrencyLimits) WithOverallDefaultLimit(overallDefaultLimit uint16) ConcurrencyLimits {
+ return limitsOrDefaults(cl, overallDefaultLimit)
+}
+
+func (cl ConcurrencyLimits) MarshalZerologObject(e *zerolog.Event) {
+ e.Uint16("concurrency-limit-check-permission", cl.Check)
+ e.Uint16("concurrency-limit-lookup-resources", cl.LookupResources)
+ e.Uint16("concurrency-limit-lookup-subjects", cl.LookupSubjects)
+ e.Uint16("concurrency-limit-reachable-resources", cl.ReachableResources)
+}
+
+func limitsOrDefaults(limits ConcurrencyLimits, overallDefaultLimit uint16) ConcurrencyLimits {
+ limits.Check = limitOrDefault(limits.Check, overallDefaultLimit)
+ limits.LookupResources = limitOrDefault(limits.LookupResources, overallDefaultLimit)
+ limits.LookupSubjects = limitOrDefault(limits.LookupSubjects, overallDefaultLimit)
+ limits.ReachableResources = limitOrDefault(limits.ReachableResources, overallDefaultLimit)
+ return limits
+}
+
+func limitOrDefault(limit uint16, defaultLimit uint16) uint16 {
+ if limit <= 0 {
+ return defaultLimit
+ }
+ return limit
+}
+
+// SharedConcurrencyLimits returns a ConcurrencyLimits struct with the limit
+// set to that provided for each operation.
+func SharedConcurrencyLimits(concurrencyLimit uint16) ConcurrencyLimits {
+ return ConcurrencyLimits{
+ Check: concurrencyLimit,
+ ReachableResources: concurrencyLimit,
+ LookupResources: concurrencyLimit,
+ LookupSubjects: concurrencyLimit,
+ }
+}
+
+// NewLocalOnlyDispatcher creates a dispatcher that consults with the graph to formulate a response.
+func NewLocalOnlyDispatcher(typeSet *caveattypes.TypeSet, concurrencyLimit uint16, dispatchChunkSize uint16) dispatch.Dispatcher {
+ return NewLocalOnlyDispatcherWithLimits(typeSet, SharedConcurrencyLimits(concurrencyLimit), dispatchChunkSize)
+}
+
+// NewLocalOnlyDispatcherWithLimits creates a dispatcher that consults with the graph to formulate a response
+// and has the defined concurrency limits per dispatch type.
+func NewLocalOnlyDispatcherWithLimits(typeSet *caveattypes.TypeSet, concurrencyLimits ConcurrencyLimits, dispatchChunkSize uint16) dispatch.Dispatcher {
+ d := &localDispatcher{}
+
+ concurrencyLimits = limitsOrDefaults(concurrencyLimits, defaultConcurrencyLimit)
+ chunkSize := dispatchChunkSize
+ if chunkSize == 0 {
+ chunkSize = 100
+ log.Warn().Msgf("LocalOnlyDispatcher: dispatchChunkSize not set, defaulting to %d", chunkSize)
+ }
+
+ d.checker = graph.NewConcurrentChecker(d, concurrencyLimits.Check, chunkSize)
+ d.expander = graph.NewConcurrentExpander(d)
+ d.lookupSubjectsHandler = graph.NewConcurrentLookupSubjects(d, concurrencyLimits.LookupSubjects, chunkSize)
+ d.lookupResourcesHandler2 = graph.NewCursoredLookupResources2(d, d, typeSet, concurrencyLimits.LookupResources, chunkSize)
+
+ return d
+}
+
+// NewDispatcher creates a dispatcher that consults with the graph and redispatches subproblems to
+// the provided redispatcher.
+func NewDispatcher(redispatcher dispatch.Dispatcher, typeSet *caveattypes.TypeSet, concurrencyLimits ConcurrencyLimits, dispatchChunkSize uint16) dispatch.Dispatcher {
+ concurrencyLimits = limitsOrDefaults(concurrencyLimits, defaultConcurrencyLimit)
+ chunkSize := dispatchChunkSize
+ if chunkSize == 0 {
+ chunkSize = 100
+ log.Warn().Msgf("Dispatcher: dispatchChunkSize not set, defaulting to %d", chunkSize)
+ }
+
+ checker := graph.NewConcurrentChecker(redispatcher, concurrencyLimits.Check, chunkSize)
+ expander := graph.NewConcurrentExpander(redispatcher)
+ lookupSubjectsHandler := graph.NewConcurrentLookupSubjects(redispatcher, concurrencyLimits.LookupSubjects, chunkSize)
+ lookupResourcesHandler2 := graph.NewCursoredLookupResources2(redispatcher, redispatcher, typeSet, concurrencyLimits.LookupResources, chunkSize)
+
+ return &localDispatcher{
+ checker: checker,
+ expander: expander,
+ lookupSubjectsHandler: lookupSubjectsHandler,
+ lookupResourcesHandler2: lookupResourcesHandler2,
+ }
+}
+
+type localDispatcher struct {
+ checker *graph.ConcurrentChecker
+ expander *graph.ConcurrentExpander
+ lookupSubjectsHandler *graph.ConcurrentLookupSubjects
+ lookupResourcesHandler2 *graph.CursoredLookupResources2
+}
+
+func (ld *localDispatcher) loadNamespace(ctx context.Context, nsName string, revision datastore.Revision) (*core.NamespaceDefinition, error) {
+ ds := datastoremw.MustFromContext(ctx).SnapshotReader(revision)
+
+ // Load namespace and relation from the datastore
+ ns, _, err := ds.ReadNamespaceByName(ctx, nsName)
+ if err != nil {
+ return nil, rewriteNamespaceError(err)
+ }
+
+ return ns, err
+}
+
+func (ld *localDispatcher) parseRevision(ctx context.Context, s string) (datastore.Revision, error) {
+ ds := datastoremw.MustFromContext(ctx)
+ return ds.RevisionFromString(s)
+}
+
+func (ld *localDispatcher) lookupRelation(_ context.Context, ns *core.NamespaceDefinition, relationName string) (*core.Relation, error) {
+ var relation *core.Relation
+ for _, candidate := range ns.Relation {
+ if candidate.Name == relationName {
+ relation = candidate
+ break
+ }
+ }
+
+ if relation == nil {
+ return nil, NewRelationNotFoundErr(ns.Name, relationName)
+ }
+
+ return relation, nil
+}
+
+// DispatchCheck implements dispatch.Check interface
+func (ld *localDispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) {
+ resourceType := tuple.StringCoreRR(req.ResourceRelation)
+ spanName := "DispatchCheck → " + resourceType + "@" + req.Subject.Namespace + "#" + req.Subject.Relation
+
+ nodeID, err := nodeid.FromContext(ctx)
+ if err != nil {
+ log.Err(err).Msg("failed to get node ID")
+ }
+
+ ctx, span := tracer.Start(ctx, spanName, trace.WithAttributes(
+ attribute.String("resource-type", resourceType),
+ attribute.StringSlice("resource-ids", req.ResourceIds),
+ attribute.String("subject", tuple.StringCoreONR(req.Subject)),
+ attribute.String("node-id", nodeID),
+ ))
+ defer span.End()
+
+ if err := dispatch.CheckDepth(ctx, req); err != nil {
+ if req.Debug != v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING {
+ return &v1.DispatchCheckResponse{
+ Metadata: &v1.ResponseMeta{
+ DispatchCount: 0,
+ },
+ }, rewriteError(ctx, err)
+ }
+
+ // NOTE: we return debug information here to ensure tooling can see the cycle.
+ nodeID, nerr := nodeid.FromContext(ctx)
+ if nerr != nil {
+ log.Err(nerr).Msg("failed to get nodeID from context")
+ }
+
+ return &v1.DispatchCheckResponse{
+ Metadata: &v1.ResponseMeta{
+ DispatchCount: 0,
+ DebugInfo: &v1.DebugInformation{
+ Check: &v1.CheckDebugTrace{
+ Request: req,
+ SourceId: nodeID,
+ },
+ },
+ },
+ }, rewriteError(ctx, err)
+ }
+
+ revision, err := ld.parseRevision(ctx, req.Metadata.AtRevision)
+ if err != nil {
+ return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err)
+ }
+
+ ns, err := ld.loadNamespace(ctx, req.ResourceRelation.Namespace, revision)
+ if err != nil {
+ return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err)
+ }
+
+ relation, err := ld.lookupRelation(ctx, ns, req.ResourceRelation.Relation)
+ if err != nil {
+ return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err)
+ }
+
+ // If the relation is aliasing another one and the subject does not have the same type as
+ // resource, load the aliased relation and dispatch to it. We cannot use the alias if the
+ // resource and subject types are the same because a check on the *exact same* resource and
+ // subject must pass, and we don't know how many intermediate steps may hit that case.
+ if relation.AliasingRelation != "" && req.ResourceRelation.Namespace != req.Subject.Namespace {
+ relation, err := ld.lookupRelation(ctx, ns, relation.AliasingRelation)
+ if err != nil {
+ return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, rewriteError(ctx, err)
+ }
+
+ // Rewrite the request over the aliased relation.
+ validatedReq := graph.ValidatedCheckRequest{
+ DispatchCheckRequest: &v1.DispatchCheckRequest{
+ ResourceRelation: &core.RelationReference{
+ Namespace: req.ResourceRelation.Namespace,
+ Relation: relation.Name,
+ },
+ ResourceIds: req.ResourceIds,
+ Subject: req.Subject,
+ Metadata: req.Metadata,
+ Debug: req.Debug,
+ CheckHints: req.CheckHints,
+ },
+ Revision: revision,
+ OriginalRelationName: req.ResourceRelation.Relation,
+ }
+
+ resp, err := ld.checker.Check(ctx, validatedReq, relation)
+ return resp, rewriteError(ctx, err)
+ }
+
+ resp, err := ld.checker.Check(ctx, graph.ValidatedCheckRequest{
+ DispatchCheckRequest: req,
+ Revision: revision,
+ }, relation)
+ return resp, rewriteError(ctx, err)
+}
+
+// DispatchExpand implements dispatch.Expand interface
+func (ld *localDispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) {
+ nodeID, err := nodeid.FromContext(ctx)
+ if err != nil {
+ log.Err(err).Msg("failed to get node ID")
+ }
+
+ ctx, span := tracer.Start(ctx, "DispatchExpand", trace.WithAttributes(
+ attribute.String("start", tuple.StringCoreONR(req.ResourceAndRelation)),
+ attribute.String("node-id", nodeID),
+ ))
+ defer span.End()
+
+ if err := dispatch.CheckDepth(ctx, req); err != nil {
+ return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err
+ }
+
+ revision, err := ld.parseRevision(ctx, req.Metadata.AtRevision)
+ if err != nil {
+ return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err
+ }
+
+ ns, err := ld.loadNamespace(ctx, req.ResourceAndRelation.Namespace, revision)
+ if err != nil {
+ return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err
+ }
+
+ relation, err := ld.lookupRelation(ctx, ns, req.ResourceAndRelation.Relation)
+ if err != nil {
+ return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err
+ }
+
+ return ld.expander.Expand(ctx, graph.ValidatedExpandRequest{
+ DispatchExpandRequest: req,
+ Revision: revision,
+ }, relation)
+}
+
+func (ld *localDispatcher) DispatchLookupResources2(
+ req *v1.DispatchLookupResources2Request,
+ stream dispatch.LookupResources2Stream,
+) error {
+ nodeID, err := nodeid.FromContext(stream.Context())
+ if err != nil {
+ log.Err(err).Msg("failed to get node ID")
+ }
+
+ ctx, span := tracer.Start(stream.Context(), "DispatchLookupResources2", trace.WithAttributes(
+ attribute.String("resource-type", tuple.StringCoreRR(req.ResourceRelation)),
+ attribute.String("subject-type", tuple.StringCoreRR(req.SubjectRelation)),
+ attribute.StringSlice("subject-ids", req.SubjectIds),
+ attribute.String("terminal-subject", tuple.StringCoreONR(req.TerminalSubject)),
+ attribute.String("node-id", nodeID),
+ ))
+ defer span.End()
+
+ if err := dispatch.CheckDepth(ctx, req); err != nil {
+ return err
+ }
+
+ revision, err := ld.parseRevision(ctx, req.Metadata.AtRevision)
+ if err != nil {
+ return err
+ }
+
+ return ld.lookupResourcesHandler2.LookupResources2(
+ graph.ValidatedLookupResources2Request{
+ DispatchLookupResources2Request: req,
+ Revision: revision,
+ },
+ dispatch.StreamWithContext(ctx, stream),
+ )
+}
+
+// DispatchLookupSubjects implements dispatch.LookupSubjects interface
+func (ld *localDispatcher) DispatchLookupSubjects(
+ req *v1.DispatchLookupSubjectsRequest,
+ stream dispatch.LookupSubjectsStream,
+) error {
+ nodeID, err := nodeid.FromContext(stream.Context())
+ if err != nil {
+ log.Err(err).Msg("failed to get node ID")
+ }
+
+ resourceType := tuple.StringCoreRR(req.ResourceRelation)
+ subjectRelation := tuple.StringCoreRR(req.SubjectRelation)
+ spanName := "DispatchLookupSubjects → " + resourceType + "@" + subjectRelation
+
+ ctx, span := tracer.Start(stream.Context(), spanName, trace.WithAttributes(
+ attribute.String("resource-type", resourceType),
+ attribute.String("subject-type", subjectRelation),
+ attribute.StringSlice("resource-ids", req.ResourceIds),
+ attribute.String("node-id", nodeID),
+ ))
+ defer span.End()
+
+ if err := dispatch.CheckDepth(ctx, req); err != nil {
+ return err
+ }
+
+ revision, err := ld.parseRevision(ctx, req.Metadata.AtRevision)
+ if err != nil {
+ return err
+ }
+
+ return ld.lookupSubjectsHandler.LookupSubjects(
+ graph.ValidatedLookupSubjectsRequest{
+ DispatchLookupSubjectsRequest: req,
+ Revision: revision,
+ },
+ dispatch.StreamWithContext(ctx, stream),
+ )
+}
+
+func (ld *localDispatcher) Close() error {
+ return nil
+}
+
+func (ld *localDispatcher) ReadyState() dispatch.ReadyState {
+ return dispatch.ReadyState{IsReady: true}
+}
+
+func rewriteNamespaceError(original error) error {
+ nsNotFound := datastore.NamespaceNotFoundError{}
+
+ switch {
+ case errors.As(original, &nsNotFound):
+ return NewNamespaceNotFoundErr(nsNotFound.NotFoundNamespaceName())
+ case errors.As(original, &NamespaceNotFoundError{}):
+ fallthrough
+ case errors.As(original, &RelationNotFoundError{}):
+ return original
+ default:
+ return fmt.Errorf(errDispatch, original)
+ }
+}
+
+// rewriteError transforms graph errors into a gRPC Status
+func rewriteError(ctx context.Context, err error) error {
+ if err == nil {
+ return nil
+ }
+
+ // Check if the error can be directly used.
+ if st, ok := status.FromError(err); ok {
+ return st.Err()
+ }
+
+ switch {
+ case errors.Is(err, context.DeadlineExceeded):
+ return status.Errorf(codes.DeadlineExceeded, "%s", err)
+ case errors.Is(err, context.Canceled):
+ err := context.Cause(ctx)
+ if err != nil {
+ if _, ok := status.FromError(err); ok {
+ return err
+ }
+ }
+
+ return status.Errorf(codes.Canceled, "%s", err)
+ default:
+ log.Ctx(ctx).Err(err).Msg("received unexpected graph error")
+ return err
+ }
+}
+
+var emptyMetadata = &v1.ResponseMeta{
+ DispatchCount: 0,
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/graph/zz_generated.options.go b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/zz_generated.options.go
new file mode 100644
index 0000000..9a0a7fc
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/dispatch/graph/zz_generated.options.go
@@ -0,0 +1,92 @@
+// Code generated by github.com/ecordell/optgen. DO NOT EDIT.
+package graph
+
+import (
+ defaults "github.com/creasty/defaults"
+ helpers "github.com/ecordell/optgen/helpers"
+)
+
+type ConcurrencyLimitsOption func(c *ConcurrencyLimits)
+
+// NewConcurrencyLimitsWithOptions creates a new ConcurrencyLimits with the passed in options set
+func NewConcurrencyLimitsWithOptions(opts ...ConcurrencyLimitsOption) *ConcurrencyLimits {
+ c := &ConcurrencyLimits{}
+ for _, o := range opts {
+ o(c)
+ }
+ return c
+}
+
+// NewConcurrencyLimitsWithOptionsAndDefaults creates a new ConcurrencyLimits with the passed in options set starting from the defaults
+func NewConcurrencyLimitsWithOptionsAndDefaults(opts ...ConcurrencyLimitsOption) *ConcurrencyLimits {
+ c := &ConcurrencyLimits{}
+ defaults.MustSet(c)
+ for _, o := range opts {
+ o(c)
+ }
+ return c
+}
+
+// ToOption returns a new ConcurrencyLimitsOption that sets the values from the passed in ConcurrencyLimits
+func (c *ConcurrencyLimits) ToOption() ConcurrencyLimitsOption {
+ return func(to *ConcurrencyLimits) {
+ to.Check = c.Check
+ to.ReachableResources = c.ReachableResources
+ to.LookupResources = c.LookupResources
+ to.LookupSubjects = c.LookupSubjects
+ }
+}
+
+// DebugMap returns a map form of ConcurrencyLimits for debugging
+func (c ConcurrencyLimits) DebugMap() map[string]any {
+ debugMap := map[string]any{}
+ debugMap["Check"] = helpers.DebugValue(c.Check, false)
+ debugMap["ReachableResources"] = helpers.DebugValue(c.ReachableResources, false)
+ debugMap["LookupResources"] = helpers.DebugValue(c.LookupResources, false)
+ debugMap["LookupSubjects"] = helpers.DebugValue(c.LookupSubjects, false)
+ return debugMap
+}
+
+// ConcurrencyLimitsWithOptions configures an existing ConcurrencyLimits with the passed in options set
+func ConcurrencyLimitsWithOptions(c *ConcurrencyLimits, opts ...ConcurrencyLimitsOption) *ConcurrencyLimits {
+ for _, o := range opts {
+ o(c)
+ }
+ return c
+}
+
+// WithOptions configures the receiver ConcurrencyLimits with the passed in options set
+func (c *ConcurrencyLimits) WithOptions(opts ...ConcurrencyLimitsOption) *ConcurrencyLimits {
+ for _, o := range opts {
+ o(c)
+ }
+ return c
+}
+
+// WithCheck returns an option that can set Check on a ConcurrencyLimits
+func WithCheck(check uint16) ConcurrencyLimitsOption {
+ return func(c *ConcurrencyLimits) {
+ c.Check = check
+ }
+}
+
+// WithReachableResources returns an option that can set ReachableResources on a ConcurrencyLimits
+func WithReachableResources(reachableResources uint16) ConcurrencyLimitsOption {
+ return func(c *ConcurrencyLimits) {
+ c.ReachableResources = reachableResources
+ }
+}
+
+// WithLookupResources returns an option that can set LookupResources on a ConcurrencyLimits
+func WithLookupResources(lookupResources uint16) ConcurrencyLimitsOption {
+ return func(c *ConcurrencyLimits) {
+ c.LookupResources = lookupResources
+ }
+}
+
+// WithLookupSubjects returns an option that can set LookupSubjects on a ConcurrencyLimits
+func WithLookupSubjects(lookupSubjects uint16) ConcurrencyLimitsOption {
+ return func(c *ConcurrencyLimits) {
+ c.LookupSubjects = lookupSubjects
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go b/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go
new file mode 100644
index 0000000..1d6636c
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/dispatch/stream.go
@@ -0,0 +1,187 @@
+package dispatch
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+
+ grpc "google.golang.org/grpc"
+)
+
+// Stream defines the interface generically matching a streaming dispatch response.
+type Stream[T any] interface {
+ // Publish publishes the result to the stream.
+ Publish(T) error
+
+ // Context returns the context for the stream.
+ Context() context.Context
+}
+
+type grpcStream[T any] interface {
+ grpc.ServerStream
+ Send(T) error
+}
+
+// WrapGRPCStream wraps a gRPC result stream with a concurrent-safe dispatch stream. This is
+// necessary because gRPC response streams are *not concurrent safe*.
+// See: https://groups.google.com/g/grpc-io/c/aI6L6M4fzQ0?pli=1
+func WrapGRPCStream[R any, S grpcStream[R]](grpcStream S) Stream[R] {
+ return &concurrentSafeStream[R]{
+ grpcStream: grpcStream,
+ mu: sync.Mutex{},
+ }
+}
+
+type concurrentSafeStream[T any] struct {
+ grpcStream grpcStream[T] // GUARDED_BY(mu)
+ mu sync.Mutex
+}
+
+func (s *concurrentSafeStream[T]) Context() context.Context {
+ return s.grpcStream.Context()
+}
+
+func (s *concurrentSafeStream[T]) Publish(result T) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.grpcStream.Send(result)
+}
+
+// NewCollectingDispatchStream creates a new CollectingDispatchStream.
+func NewCollectingDispatchStream[T any](ctx context.Context) *CollectingDispatchStream[T] {
+ return &CollectingDispatchStream[T]{
+ ctx: ctx,
+ results: nil,
+ mu: sync.Mutex{},
+ }
+}
+
+// CollectingDispatchStream is a dispatch stream that collects results in memory.
+type CollectingDispatchStream[T any] struct {
+ ctx context.Context
+ results []T // GUARDED_BY(mu)
+ mu sync.Mutex
+}
+
+func (s *CollectingDispatchStream[T]) Context() context.Context {
+ return s.ctx
+}
+
+func (s *CollectingDispatchStream[T]) Results() []T {
+ return s.results
+}
+
+func (s *CollectingDispatchStream[T]) Publish(result T) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.results = append(s.results, result)
+ return nil
+}
+
+// WrappedDispatchStream is a dispatch stream that wraps another dispatch stream, and performs
+// an operation on each result before puppeting back up to the parent stream.
+type WrappedDispatchStream[T any] struct {
+ Stream Stream[T]
+ Ctx context.Context
+ Processor func(result T) (T, bool, error)
+}
+
+func (s *WrappedDispatchStream[T]) Publish(result T) error {
+ if s.Processor == nil {
+ return s.Stream.Publish(result)
+ }
+
+ processed, ok, err := s.Processor(result)
+ if err != nil {
+ return err
+ }
+ if !ok {
+ return nil
+ }
+
+ return s.Stream.Publish(processed)
+}
+
+func (s *WrappedDispatchStream[T]) Context() context.Context {
+ return s.Ctx
+}
+
+// StreamWithContext returns the given dispatch stream, wrapped to return the given context.
+func StreamWithContext[T any](context context.Context, stream Stream[T]) Stream[T] {
+ return &WrappedDispatchStream[T]{
+ Stream: stream,
+ Ctx: context,
+ Processor: nil,
+ }
+}
+
+// HandlingDispatchStream is a dispatch stream that executes a handler for each item published.
+// It uses an internal mutex to ensure it is thread safe.
+type HandlingDispatchStream[T any] struct {
+ ctx context.Context
+ processor func(result T) error // GUARDED_BY(mu)
+ mu sync.Mutex
+}
+
+// NewHandlingDispatchStream returns a new handling dispatch stream.
+func NewHandlingDispatchStream[T any](ctx context.Context, processor func(result T) error) Stream[T] {
+ return &HandlingDispatchStream[T]{
+ ctx: ctx,
+ processor: processor,
+ mu: sync.Mutex{},
+ }
+}
+
+func (s *HandlingDispatchStream[T]) Publish(result T) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.processor == nil {
+ return nil
+ }
+
+ return s.processor(result)
+}
+
+func (s *HandlingDispatchStream[T]) Context() context.Context {
+ return s.ctx
+}
+
+// CountingDispatchStream is a dispatch stream that counts the number of items published.
+// It uses an internal atomic int to ensure it is thread safe.
+type CountingDispatchStream[T any] struct {
+ Stream Stream[T]
+ count *atomic.Uint64
+}
+
+func NewCountingDispatchStream[T any](wrapped Stream[T]) *CountingDispatchStream[T] {
+ return &CountingDispatchStream[T]{
+ Stream: wrapped,
+ count: &atomic.Uint64{},
+ }
+}
+
+func (s *CountingDispatchStream[T]) PublishedCount() uint64 {
+ return s.count.Load()
+}
+
+func (s *CountingDispatchStream[T]) Publish(result T) error {
+ err := s.Stream.Publish(result)
+ if err != nil {
+ return err
+ }
+
+ s.count.Add(1)
+ return nil
+}
+
+func (s *CountingDispatchStream[T]) Context() context.Context {
+ return s.Stream.Context()
+}
+
+// Ensure the streams implement the interface.
+var (
+ _ Stream[any] = &CollectingDispatchStream[any]{}
+ _ Stream[any] = &WrappedDispatchStream[any]{}
+ _ Stream[any] = &CountingDispatchStream[any]{}
+)
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/check.go b/vendor/github.com/authzed/spicedb/internal/graph/check.go
new file mode 100644
index 0000000..65bcb50
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/check.go
@@ -0,0 +1,1354 @@
+package graph
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/samber/lo"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/trace"
+ "google.golang.org/protobuf/types/known/durationpb"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/graph/hints"
+ log "github.com/authzed/spicedb/internal/logging"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/namespace"
+ "github.com/authzed/spicedb/internal/taskrunner"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/middleware/nodeid"
+ nspkg "github.com/authzed/spicedb/pkg/namespace"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+var tracer = otel.Tracer("spicedb/internal/graph/check")
+
+var dispatchChunkCountHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{
+ Name: "spicedb_check_dispatch_chunk_count",
+ Help: "number of chunks when dispatching in check",
+ Buckets: []float64{1, 2, 3, 5, 10, 25, 100, 250},
+})
+
+var directDispatchQueryHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{
+ Name: "spicedb_check_direct_dispatch_query_count",
+ Help: "number of queries made per direct dispatch",
+ Buckets: []float64{1, 2},
+})
+
+const noOriginalRelation = ""
+
+func init() {
+ prometheus.MustRegister(directDispatchQueryHistogram)
+ prometheus.MustRegister(dispatchChunkCountHistogram)
+}
+
+// NewConcurrentChecker creates an instance of ConcurrentChecker.
+func NewConcurrentChecker(d dispatch.Check, concurrencyLimit uint16, dispatchChunkSize uint16) *ConcurrentChecker {
+ return &ConcurrentChecker{d, concurrencyLimit, dispatchChunkSize}
+}
+
+// ConcurrentChecker exposes a method to perform Check requests, and delegates subproblems to the
+// provided dispatch.Check instance.
+type ConcurrentChecker struct {
+ d dispatch.Check
+ concurrencyLimit uint16
+ dispatchChunkSize uint16
+}
+
+// ValidatedCheckRequest represents a request after it has been validated and parsed for internal
+// consumption.
+type ValidatedCheckRequest struct {
+ *v1.DispatchCheckRequest
+ Revision datastore.Revision
+
+ // OriginalRelationName is the original relation/permission name that was used in the request,
+ // before being changed due to aliasing.
+ OriginalRelationName string
+}
+
+// currentRequestContext holds context information for the current request being
+// processed.
+type currentRequestContext struct {
+ // parentReq is the parent request being processed.
+ parentReq ValidatedCheckRequest
+
+ // filteredResourceIDs are those resource IDs to be checked after filtering for
+ // any resource IDs found directly matching the incoming subject.
+ //
+ // For example, a check of resources `user:{tom,sarah,fred}` and subject `user:sarah` will
+ // result in this slice containing `tom` and `fred`, but not `sarah`, as she was found as a
+ // match.
+ //
+ // This check and filter occurs via the filterForFoundMemberResource function in the
+ // checkInternal function before the rest of the checking logic is run. This slice should never
+ // be empty.
+ filteredResourceIDs []string
+
+ // resultsSetting is the results setting to use for this request and all subsequent
+ // requests.
+ resultsSetting v1.DispatchCheckRequest_ResultsSetting
+
+ // dispatchChunkSize is the maximum number of resource IDs that can be specified in each dispatch.
+ dispatchChunkSize uint16
+}
+
+// Check performs a check request with the provided request and context
+func (cc *ConcurrentChecker) Check(ctx context.Context, req ValidatedCheckRequest, relation *core.Relation) (*v1.DispatchCheckResponse, error) {
+ var startTime *time.Time
+ if req.Debug != v1.DispatchCheckRequest_NO_DEBUG {
+ now := time.Now()
+ startTime = &now
+ }
+
+ resolved := cc.checkInternal(ctx, req, relation)
+ resolved.Resp.Metadata = addCallToResponseMetadata(resolved.Resp.Metadata)
+ if req.Debug == v1.DispatchCheckRequest_NO_DEBUG {
+ return resolved.Resp, resolved.Err
+ }
+
+ nodeID, err := nodeid.FromContext(ctx)
+ if err != nil {
+ // NOTE: we ignore this error here as if the node ID is missing, the debug
+ // trace is still valid.
+ log.Err(err).Msg("failed to get node ID")
+ }
+
+ // Add debug information if requested.
+ debugInfo := resolved.Resp.Metadata.DebugInfo
+ if debugInfo == nil {
+ debugInfo = &v1.DebugInformation{
+ Check: &v1.CheckDebugTrace{
+ TraceId: NewTraceID(),
+ SourceId: nodeID,
+ },
+ }
+ } else if debugInfo.Check != nil && debugInfo.Check.SourceId == "" {
+ debugInfo.Check.SourceId = nodeID
+ }
+
+ // Remove the traversal bloom from the debug request to save some data over the
+ // wire.
+ clonedRequest := req.DispatchCheckRequest.CloneVT()
+ clonedRequest.Metadata.TraversalBloom = nil
+
+ debugInfo.Check.Request = clonedRequest
+ debugInfo.Check.Duration = durationpb.New(time.Since(*startTime))
+
+ if nspkg.GetRelationKind(relation) == iv1.RelationMetadata_PERMISSION {
+ debugInfo.Check.ResourceRelationType = v1.CheckDebugTrace_PERMISSION
+ } else if nspkg.GetRelationKind(relation) == iv1.RelationMetadata_RELATION {
+ debugInfo.Check.ResourceRelationType = v1.CheckDebugTrace_RELATION
+ }
+
+ // Build the results for the debug trace.
+ results := make(map[string]*v1.ResourceCheckResult, len(req.DispatchCheckRequest.ResourceIds))
+ for _, resourceID := range req.DispatchCheckRequest.ResourceIds {
+ if found, ok := resolved.Resp.ResultsByResourceId[resourceID]; ok {
+ results[resourceID] = found
+ }
+ }
+ debugInfo.Check.Results = results
+
+ // If there is existing debug information in the error, then place it as the subproblem of the current
+ // debug information.
+ if existingDebugInfo, ok := spiceerrors.GetDetails[*v1.DebugInformation](resolved.Err); ok {
+ debugInfo.Check.SubProblems = []*v1.CheckDebugTrace{existingDebugInfo.Check}
+ }
+
+ resolved.Resp.Metadata.DebugInfo = debugInfo
+
+ // If there is an error and it is already a gRPC error, add the debug information
+ // into the details portion of the payload. This allows the client to see the debug
+ // information, as gRPC will only return the error.
+ updatedErr := spiceerrors.WithReplacedDetails(resolved.Err, debugInfo)
+ return resolved.Resp, updatedErr
+}
+
+// checkInternal is the core of a dispatched check: it validates the incoming
+// request, filters the resource IDs (exact subject matches and check hints),
+// and then delegates to either the direct-relationship check or the
+// userset-rewrite walk, depending on the relation's definition.
+func (cc *ConcurrentChecker) checkInternal(ctx context.Context, req ValidatedCheckRequest, relation *core.Relation) CheckResult {
+	spiceerrors.DebugAssert(func() bool {
+		return relation.GetUsersetRewrite() != nil || relation.GetTypeInformation() != nil
+	}, "found relation without type information")
+
+	// Ensure that we have at least one resource ID for which to execute the check.
+	if len(req.ResourceIds) == 0 {
+		return checkResultError(
+			spiceerrors.MustBugf("empty resource IDs given to dispatched check"),
+			emptyMetadata,
+		)
+	}
+
+	// Ensure that we are not performing a check for a wildcard as the subject.
+	if req.Subject.ObjectId == tuple.PublicWildcard {
+		return checkResultError(NewWildcardNotAllowedErr("cannot perform check on wildcard subject", "subject.object_id"), emptyMetadata)
+	}
+
+	// Deduplicate any incoming resource IDs.
+	resourceIds := lo.Uniq(req.ResourceIds)
+
+	// Filter the incoming resource IDs for any which match the subject directly. For example, if we receive
+	// a check for resource `user:{tom, fred, sarah}#...` and a subject of `user:sarah#...`, then we know
+	// that `user:sarah#...` is a valid "member" of the resource, as it matches exactly.
+	//
+	// If the filtering results in no further resource IDs to check, or a result is found and a single
+	// result is allowed, we terminate early.
+	membershipSet, filteredResourcesIds := filterForFoundMemberResource(req.ResourceRelation, resourceIds, req.Subject)
+	if membershipSet.HasDeterminedMember() && req.DispatchCheckRequest.ResultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT {
+		return checkResultsForMembership(membershipSet, emptyMetadata)
+	}
+
+	// Filter for check hints, if any. Resource IDs whose results are already known
+	// via a hint are removed here and re-added from the hints in combineWithCheckHints.
+	if len(req.CheckHints) > 0 {
+		subject := tuple.FromCoreObjectAndRelation(req.Subject)
+		filteredResourcesIdsSet := mapz.NewSet(filteredResourcesIds...)
+		for _, checkHint := range req.CheckHints {
+			resourceID, ok := hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.ResourceRelation.Relation, subject)
+			if ok {
+				filteredResourcesIdsSet.Delete(resourceID)
+				continue
+			}
+
+			// Also match hints phrased against the original (pre-rewrite) relation name, if any.
+			if req.OriginalRelationName != "" {
+				resourceID, ok = hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.OriginalRelationName, subject)
+				if ok {
+					filteredResourcesIdsSet.Delete(resourceID)
+				}
+			}
+		}
+		filteredResourcesIds = filteredResourcesIdsSet.AsSlice()
+	}
+
+	if len(filteredResourcesIds) == 0 {
+		return combineWithCheckHints(combineResultWithFoundResources(noMembers(), membershipSet), req)
+	}
+
+	// NOTE: We can always allow a single result if we're only trying to find the results for a
+	// single resource ID. This "reset" allows for short circuiting of downstream dispatched calls.
+	resultsSetting := req.ResultsSetting
+	if len(filteredResourcesIds) == 1 {
+		resultsSetting = v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT
+	}
+
+	crc := currentRequestContext{
+		parentReq:           req,
+		filteredResourceIDs: filteredResourcesIds,
+		resultsSetting:      resultsSetting,
+		dispatchChunkSize:   cc.dispatchChunkSize,
+	}
+
+	// When trace debugging is enabled, dispatch one resource at a time so that
+	// each trace entry corresponds to a single resource.
+	if req.Debug == v1.DispatchCheckRequest_ENABLE_TRACE_DEBUGGING {
+		crc.dispatchChunkSize = 1
+	}
+
+	if relation.UsersetRewrite == nil {
+		return combineWithCheckHints(combineResultWithFoundResources(cc.checkDirect(ctx, crc, relation), membershipSet), req)
+	}
+
+	return combineWithCheckHints(combineResultWithFoundResources(cc.checkUsersetRewrite(ctx, crc, relation.UsersetRewrite), membershipSet), req)
+}
+
+// combineWithComputedHints merges precomputed per-resource results (hints) into
+// the given check result. It is a bug (MustBugf) for a hint to target a resource
+// ID that already has a computed result.
+func combineWithComputedHints(result CheckResult, hints map[string]*v1.ResourceCheckResult) CheckResult {
+	if len(hints) == 0 {
+		return result
+	}
+
+	for resourceID, hint := range hints {
+		// A hinted resource must not collide with a computed result; hinted IDs are
+		// filtered out before dispatch, so a collision indicates a logic error.
+		if _, ok := result.Resp.ResultsByResourceId[resourceID]; ok {
+			return checkResultError(
+				spiceerrors.MustBugf("check hint for resource ID %q, which already exists", resourceID),
+				emptyMetadata,
+			)
+		}
+
+		// Lazily initialize the results map, as noMembers() responses have a nil map.
+		if result.Resp.ResultsByResourceId == nil {
+			result.Resp.ResultsByResourceId = make(map[string]*v1.ResourceCheckResult)
+		}
+		result.Resp.ResultsByResourceId[resourceID] = hint
+	}
+
+	return result
+}
+
+// combineWithCheckHints folds the request's check hints back into the result:
+// any hint matching this request's resource relation (or its original relation
+// name, for rewritten relations) contributes its stored result for that
+// resource ID. Colliding with an already-computed result is a bug.
+func combineWithCheckHints(result CheckResult, req ValidatedCheckRequest) CheckResult {
+	if len(req.CheckHints) == 0 {
+		return result
+	}
+
+	subject := tuple.FromCoreObjectAndRelation(req.Subject)
+	for _, checkHint := range req.CheckHints {
+		resourceID, ok := hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.ResourceRelation.Relation, subject)
+		if !ok {
+			// Fall back to matching against the original relation name, if present.
+			if req.OriginalRelationName != "" {
+				resourceID, ok = hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.OriginalRelationName, subject)
+			}
+
+			if !ok {
+				continue
+			}
+		}
+
+		// Lazily initialize the results map, as noMembers() responses have a nil map.
+		if result.Resp.ResultsByResourceId == nil {
+			result.Resp.ResultsByResourceId = make(map[string]*v1.ResourceCheckResult)
+		}
+
+		if _, ok := result.Resp.ResultsByResourceId[resourceID]; ok {
+			return checkResultError(
+				spiceerrors.MustBugf("check hint for resource ID %q, which already exists", resourceID),
+				emptyMetadata,
+			)
+		}
+
+		result.Resp.ResultsByResourceId[resourceID] = checkHint.Result
+	}
+
+	return result
+}
+
+// checkDirect checks the requested subject against the *direct* relationships of
+// the relation: it first queries for the subject itself and/or its wildcard form
+// (when allowed by the relation's type information), then queries for any
+// non-terminal (non-`...`) subjects and redispatches checks over them.
+// Caveat/expiration-skipping query options are applied when the type information
+// proves no allowed subject type can carry them.
+func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequestContext, relation *core.Relation) CheckResult {
+	ctx, span := tracer.Start(ctx, "checkDirect")
+	defer span.End()
+
+	// Build a filter for finding the direct relationships for the check. There are three
+	// classes of relationships to be found:
+	// 1) the target subject itself, if allowed on this relation
+	// 2) the wildcard form of the target subject, if a wildcard is allowed on this relation
+	// 3) Otherwise, any non-terminal (non-`...`) subjects, if allowed on this relation, to be
+	//    redispatched outward
+	totalNonTerminals := 0
+	totalDirectSubjects := 0
+	totalWildcardSubjects := 0
+
+	// Rename the span based on which class of subjects dominated this check.
+	defer func() {
+		if totalNonTerminals > 0 {
+			span.SetName("non terminal")
+		} else if totalDirectSubjects > 0 {
+			span.SetName("terminal")
+		} else {
+			span.SetName("wildcard subject")
+		}
+	}()
+	log.Ctx(ctx).Trace().Object("direct", crc.parentReq).Send()
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision)
+
+	// Counters used below to decide whether caveats/expiration can be skipped in
+	// the datastore queries (a pure optimization).
+	directSubjectsAndWildcardsWithoutCaveats := 0
+	directSubjectsAndWildcardsWithoutExpiration := 0
+	nonTerminalsWithoutCaveats := 0
+	nonTerminalsWithoutExpiration := 0
+
+	for _, allowedDirectRelation := range relation.GetTypeInformation().GetAllowedDirectRelations() {
+		// If the namespace of the allowed direct relation matches the subject type, there are two
+		// cases to optimize:
+		// 1) Finding the target subject itself, as a direct lookup
+		// 2) Finding a wildcard for the subject type+relation
+		if allowedDirectRelation.GetNamespace() == crc.parentReq.Subject.Namespace {
+			if allowedDirectRelation.GetPublicWildcard() != nil {
+				totalWildcardSubjects++
+			} else if allowedDirectRelation.GetRelation() == crc.parentReq.Subject.Relation {
+				totalDirectSubjects++
+			}
+
+			if allowedDirectRelation.RequiredCaveat == nil {
+				directSubjectsAndWildcardsWithoutCaveats++
+			}
+
+			if allowedDirectRelation.RequiredExpiration == nil {
+				directSubjectsAndWildcardsWithoutExpiration++
+			}
+		}
+
+		// If the relation found is not an ellipsis, then this is a nested relation that
+		// might need to be followed, so indicate that such relationships should be returned
+		//
+		// TODO(jschorr): Use type information to *further* optimize this query around which nested
+		// relations can reach the target subject type.
+		if allowedDirectRelation.GetRelation() != tuple.Ellipsis {
+			totalNonTerminals++
+			if allowedDirectRelation.RequiredCaveat == nil {
+				nonTerminalsWithoutCaveats++
+			}
+			if allowedDirectRelation.RequiredExpiration == nil {
+				nonTerminalsWithoutExpiration++
+			}
+		}
+	}
+
+	nonTerminalsCanHaveCaveats := totalNonTerminals != nonTerminalsWithoutCaveats
+	nonTerminalsCanHaveExpiration := totalNonTerminals != nonTerminalsWithoutExpiration
+	hasNonTerminals := totalNonTerminals > 0
+
+	foundResources := NewMembershipSet()
+
+	// If the direct subject or a wildcard form can be found, issue a query for just that
+	// subject.
+	var queryCount float64
+	defer func() {
+		directDispatchQueryHistogram.Observe(queryCount)
+	}()
+
+	hasDirectSubject := totalDirectSubjects > 0
+	hasWildcardSubject := totalWildcardSubjects > 0
+	if hasDirectSubject || hasWildcardSubject {
+		directSubjectOrWildcardCanHaveCaveats := directSubjectsAndWildcardsWithoutCaveats != (totalDirectSubjects + totalWildcardSubjects)
+		directSubjectOrWildcardCanHaveExpiration := directSubjectsAndWildcardsWithoutExpiration != (totalDirectSubjects + totalWildcardSubjects)
+
+		subjectSelectors := []datastore.SubjectsSelector{}
+
+		if hasDirectSubject {
+			// Selector matching the exact subject (type, ID and relation).
+			subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{
+				OptionalSubjectType: crc.parentReq.Subject.Namespace,
+				OptionalSubjectIds:  []string{crc.parentReq.Subject.ObjectId},
+				RelationFilter:      datastore.SubjectRelationFilter{}.WithRelation(crc.parentReq.Subject.Relation),
+			})
+		}
+
+		if hasWildcardSubject {
+			// Selector matching the wildcard (`subjecttype:*`) form of the subject.
+			subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{
+				OptionalSubjectType: crc.parentReq.Subject.Namespace,
+				OptionalSubjectIds:  []string{tuple.PublicWildcard},
+				RelationFilter:      datastore.SubjectRelationFilter{}.WithEllipsisRelation(),
+			})
+		}
+
+		filter := datastore.RelationshipsFilter{
+			OptionalResourceType:      crc.parentReq.ResourceRelation.Namespace,
+			OptionalResourceIds:       crc.filteredResourceIDs,
+			OptionalResourceRelation:  crc.parentReq.ResourceRelation.Relation,
+			OptionalSubjectsSelectors: subjectSelectors,
+		}
+
+		it, err := ds.QueryRelationships(ctx, filter,
+			options.WithSkipCaveats(!directSubjectOrWildcardCanHaveCaveats),
+			options.WithSkipExpiration(!directSubjectOrWildcardCanHaveExpiration),
+			options.WithQueryShape(queryshape.CheckPermissionSelectDirectSubjects),
+		)
+		if err != nil {
+			return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+		}
+		queryCount += 1.0
+
+		// Find the matching subject(s).
+		for rel, err := range it {
+			if err != nil {
+				return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+			}
+
+			// If the subject of the relationship matches the target subject, then we've found
+			// a result.
+			foundResources.AddDirectMember(rel.Resource.ObjectID, rel.OptionalCaveat)
+			if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT && foundResources.HasDeterminedMember() {
+				return checkResultsForMembership(foundResources, emptyMetadata)
+			}
+		}
+	}
+
+	// Filter down the resource IDs for further dispatch based on whether they exist as found
+	// subjects in the existing membership set.
+	furtherFilteredResourceIDs := make([]string, 0, len(crc.filteredResourceIDs)-foundResources.Size())
+	for _, resourceID := range crc.filteredResourceIDs {
+		if foundResources.HasConcreteResourceID(resourceID) {
+			continue
+		}
+
+		furtherFilteredResourceIDs = append(furtherFilteredResourceIDs, resourceID)
+	}
+
+	// If there are no possible non-terminals, then the check is completed.
+	if !hasNonTerminals || len(furtherFilteredResourceIDs) == 0 {
+		return checkResultsForMembership(foundResources, emptyMetadata)
+	}
+
+	// Otherwise, for any remaining resource IDs, query for redispatch.
+	filter := datastore.RelationshipsFilter{
+		OptionalResourceType:     crc.parentReq.ResourceRelation.Namespace,
+		OptionalResourceIds:      furtherFilteredResourceIDs,
+		OptionalResourceRelation: crc.parentReq.ResourceRelation.Relation,
+		OptionalSubjectsSelectors: []datastore.SubjectsSelector{
+			{
+				RelationFilter: datastore.SubjectRelationFilter{}.WithOnlyNonEllipsisRelations(),
+			},
+		},
+	}
+
+	it, err := ds.QueryRelationships(ctx, filter,
+		options.WithSkipCaveats(!nonTerminalsCanHaveCaveats),
+		options.WithSkipExpiration(!nonTerminalsCanHaveExpiration),
+		options.WithQueryShape(queryshape.CheckPermissionSelectIndirectSubjects),
+	)
+	if err != nil {
+		return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+	}
+	queryCount += 1.0
+
+	// Build the set of subjects over which to dispatch, along with metadata for
+	// mapping over caveats (if any).
+	checksToDispatch := newCheckDispatchSet()
+	for rel, err := range it {
+		if err != nil {
+			return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+		}
+		checksToDispatch.addForRelationship(rel)
+	}
+
+	// Dispatch and map to the associated resource ID(s).
+	toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize)
+	result := union(ctx, crc, toDispatch, func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) CheckResult {
+		// If there are caveats on any of the incoming relationships for the subjects to dispatch, then we must require all
+		// results to be found, as we need to ensure that all caveats are used for building the final expression.
+		resultsSetting := crc.resultsSetting
+		if dd.hasIncomingCaveats {
+			resultsSetting = v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS
+		}
+
+		childResult := cc.dispatch(ctx, crc, ValidatedCheckRequest{
+			&v1.DispatchCheckRequest{
+				ResourceRelation: dd.resourceType.ToCoreRR(),
+				ResourceIds:      dd.resourceIds,
+				Subject:          crc.parentReq.Subject,
+				ResultsSetting:   resultsSetting,
+
+				Metadata:   decrementDepth(crc.parentReq.Metadata),
+				Debug:      crc.parentReq.Debug,
+				CheckHints: crc.parentReq.CheckHints,
+			},
+			crc.parentReq.Revision,
+			noOriginalRelation,
+		})
+
+		if childResult.Err != nil {
+			return childResult
+		}
+
+		// Map the child's found subjects back onto the parent's resource IDs.
+		return mapFoundResources(childResult, dd.resourceType, checksToDispatch)
+	}, cc.concurrencyLimit)
+
+	return combineResultWithFoundResources(result, foundResources)
+}
+
+// mapFoundResources translates resource IDs found by a dispatched (child) check
+// back into the parent's resource IDs, using the subject-to-resource mappings
+// collected in checksToDispatch, and re-applies any caveats that sat on the
+// original relationships.
+func mapFoundResources(result CheckResult, resourceType tuple.RelationReference, checksToDispatch *checkDispatchSet) CheckResult {
+	// Map any resources found to the parent resource IDs.
+	membershipSet := NewMembershipSet()
+	for foundResourceID, result := range result.Resp.ResultsByResourceId {
+		resourceIDAndCaveats := checksToDispatch.mappingsForSubject(resourceType.ObjectType, foundResourceID, resourceType.Relation)
+
+		// Every dispatched subject originated from at least one relationship, so a
+		// found resource with no mapping indicates a logic error.
+		spiceerrors.DebugAssert(func() bool {
+			return len(resourceIDAndCaveats) > 0
+		}, "found resource ID without associated caveats")
+
+		for _, riac := range resourceIDAndCaveats {
+			membershipSet.AddMemberWithParentCaveat(riac.resourceID, result.Expression, riac.caveat)
+		}
+	}
+
+	if membershipSet.IsEmpty() {
+		return noMembersWithMetadata(result.Resp.Metadata)
+	}
+
+	return checkResultsForMembership(membershipSet, result.Resp.Metadata)
+}
+
+// checkUsersetRewrite evaluates a userset rewrite (union, intersection or
+// exclusion) by running its children via the matching set-combination operator.
+// Span names "+", "&" and "-" mirror the schema operators in traces.
+func (cc *ConcurrentChecker) checkUsersetRewrite(ctx context.Context, crc currentRequestContext, rewrite *core.UsersetRewrite) CheckResult {
+	switch rw := rewrite.RewriteOperation.(type) {
+	case *core.UsersetRewrite_Union:
+		// Only open a span when there is real fan-out; a single child adds no value.
+		if len(rw.Union.Child) > 1 {
+			var span trace.Span
+			ctx, span = tracer.Start(ctx, "+")
+			defer span.End()
+		}
+		return union(ctx, crc, rw.Union.Child, cc.runSetOperation, cc.concurrencyLimit)
+	case *core.UsersetRewrite_Intersection:
+		ctx, span := tracer.Start(ctx, "&")
+		defer span.End()
+		return all(ctx, crc, rw.Intersection.Child, cc.runSetOperation, cc.concurrencyLimit)
+	case *core.UsersetRewrite_Exclusion:
+		ctx, span := tracer.Start(ctx, "-")
+		defer span.End()
+		return difference(ctx, crc, rw.Exclusion.Child, cc.runSetOperation, cc.concurrencyLimit)
+	default:
+		return checkResultError(spiceerrors.MustBugf("unknown userset rewrite operator"), emptyMetadata)
+	}
+}
+
+// dispatch forwards a validated check request to the underlying dispatcher and
+// wraps the response/error pair as a CheckResult. The currentRequestContext is
+// intentionally unused here; it is part of the shared handler signature.
+func (cc *ConcurrentChecker) dispatch(ctx context.Context, _ currentRequestContext, req ValidatedCheckRequest) CheckResult {
+	log.Ctx(ctx).Trace().Object("dispatch", req).Send()
+	result, err := cc.d.DispatchCheck(ctx, req.DispatchCheckRequest)
+	return CheckResult{result, err}
+}
+
+// runSetOperation evaluates a single child of a set operation (union,
+// intersection or exclusion), routing to the appropriate handler based on the
+// child's type. `_this` is rejected outright as unsupported legacy schema.
+func (cc *ConcurrentChecker) runSetOperation(ctx context.Context, crc currentRequestContext, childOneof *core.SetOperation_Child) CheckResult {
+	switch child := childOneof.ChildType.(type) {
+	case *core.SetOperation_Child_XThis:
+		return checkResultError(spiceerrors.MustBugf("use of _this is unsupported; please rewrite your schema"), emptyMetadata)
+	case *core.SetOperation_Child_ComputedUserset:
+		return cc.checkComputedUserset(ctx, crc, child.ComputedUserset, nil, nil)
+	case *core.SetOperation_Child_UsersetRewrite:
+		return cc.checkUsersetRewrite(ctx, crc, child.UsersetRewrite)
+	case *core.SetOperation_Child_TupleToUserset:
+		return checkTupleToUserset(ctx, cc, crc, child.TupleToUserset)
+	case *core.SetOperation_Child_FunctionedTupleToUserset:
+		// Functioned TTUs select their combination semantics via the function:
+		// `any` behaves like a standard arrow; `all` requires every subject to match.
+		switch child.FunctionedTupleToUserset.Function {
+		case core.FunctionedTupleToUserset_FUNCTION_ANY:
+			return checkTupleToUserset(ctx, cc, crc, child.FunctionedTupleToUserset)
+
+		case core.FunctionedTupleToUserset_FUNCTION_ALL:
+			return checkIntersectionTupleToUserset(ctx, cc, crc, child.FunctionedTupleToUserset)
+
+		default:
+			return checkResultError(spiceerrors.MustBugf("unknown userset function `%s`", child.FunctionedTupleToUserset.Function), emptyMetadata)
+		}
+
+	case *core.SetOperation_Child_XNil:
+		return noMembers()
+	default:
+		return checkResultError(spiceerrors.MustBugf("unknown set operation child `%T` in check", child), emptyMetadata)
+	}
+}
+
+// checkComputedUserset dispatches a check over a computed userset. For
+// TUPLE_USERSET_OBJECT (TTU/arrow), rr and resourceIds describe the tupleset
+// subjects to walk; for TUPLE_OBJECT they must be nil and the parent request's
+// resource relation/IDs are used instead.
+func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc currentRequestContext, cu *core.ComputedUserset, rr *tuple.RelationReference, resourceIds []string) CheckResult {
+	ctx, span := tracer.Start(ctx, cu.Relation)
+	defer span.End()
+
+	var startNamespace string
+	var targetResourceIds []string
+	if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT {
+		if rr == nil || len(resourceIds) == 0 {
+			return checkResultError(spiceerrors.MustBugf("computed userset for tupleset without tuples"), emptyMetadata)
+		}
+
+		startNamespace = rr.ObjectType
+		targetResourceIds = resourceIds
+	} else if cu.Object == core.ComputedUserset_TUPLE_OBJECT {
+		if rr != nil {
+			return checkResultError(spiceerrors.MustBugf("computed userset for tupleset with wrong object type"), emptyMetadata)
+		}
+
+		startNamespace = crc.parentReq.ResourceRelation.Namespace
+		targetResourceIds = crc.filteredResourceIDs
+	}
+
+	targetRR := &core.RelationReference{
+		Namespace: startNamespace,
+		Relation:  cu.Relation,
+	}
+
+	// If we will be dispatching to the goal's ONR, then we know that the ONR is a member.
+	membershipSet, updatedTargetResourceIds := filterForFoundMemberResource(targetRR, targetResourceIds, crc.parentReq.Subject)
+	if (membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT) || len(updatedTargetResourceIds) == 0 {
+		return checkResultsForMembership(membershipSet, emptyMetadata)
+	}
+
+	// Check if the target relation exists. If not, return nothing. This is only necessary
+	// for TTU-based computed usersets, as directly computed ones reference relations within
+	// the same namespace as the caller, and thus must be fully typed checked.
+	if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT {
+		ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision)
+		err := namespace.CheckNamespaceAndRelation(ctx, targetRR.Namespace, targetRR.Relation, true, ds)
+		if err != nil {
+			// A missing relation on the arrow's target type is not an error: the
+			// arrow simply contributes no members.
+			if errors.As(err, &namespace.RelationNotFoundError{}) {
+				return noMembers()
+			}
+
+			return checkResultError(err, emptyMetadata)
+		}
+	}
+
+	result := cc.dispatch(ctx, crc, ValidatedCheckRequest{
+		&v1.DispatchCheckRequest{
+			ResourceRelation: targetRR,
+			ResourceIds:      updatedTargetResourceIds,
+			Subject:          crc.parentReq.Subject,
+			ResultsSetting:   crc.resultsSetting,
+			Metadata:         decrementDepth(crc.parentReq.Metadata),
+			Debug:            crc.parentReq.Debug,
+			CheckHints:       crc.parentReq.CheckHints,
+		},
+		crc.parentReq.Revision,
+		noOriginalRelation,
+	})
+	return combineResultWithFoundResources(result, membershipSet)
+}
+
+// Traits describes capabilities of a relation's allowed subject types that
+// affect how datastore queries may be optimized (see TraitsForArrowRelation).
+type Traits struct {
+	// HasCaveats is true if any allowed subject type requires a caveat.
+	HasCaveats bool
+
+	// HasExpiration is true if any allowed subject type requires expiration.
+	HasExpiration bool
+}
+
+// TraitsForArrowRelation returns traits such as HasCaveats and HasExpiration if *any* of the subject
+// types of the given relation support caveats or expiration. Returns an error if the namespace
+// cannot be read or the relation is missing/lacks type information.
+func TraitsForArrowRelation(ctx context.Context, reader datastore.Reader, namespaceName string, relationName string) (Traits, error) {
+	// TODO(jschorr): Change to use the type system once we wire it through Check dispatch.
+	nsDef, _, err := reader.ReadNamespaceByName(ctx, namespaceName)
+	if err != nil {
+		return Traits{}, err
+	}
+
+	// Locate the named relation within the namespace definition.
+	var relation *core.Relation
+	for _, rel := range nsDef.Relation {
+		if rel.Name == relationName {
+			relation = rel
+			break
+		}
+	}
+
+	if relation == nil || relation.TypeInformation == nil {
+		return Traits{}, fmt.Errorf("relation %q not found", relationName)
+	}
+
+	hasCaveats := false
+	hasExpiration := false
+
+	// A single allowed subject type requiring a caveat/expiration is enough to
+	// set the corresponding trait.
+	for _, allowedDirectRelation := range relation.TypeInformation.GetAllowedDirectRelations() {
+		if allowedDirectRelation.RequiredCaveat != nil {
+			hasCaveats = true
+		}
+
+		if allowedDirectRelation.RequiredExpiration != nil {
+			hasExpiration = true
+		}
+	}
+
+	return Traits{
+		HasCaveats:    hasCaveats,
+		HasExpiration: hasExpiration,
+	}, nil
+}
+
+// queryOptionsForArrowRelation builds the datastore query options for walking an
+// arrow's tupleset relation: it always sets the AllSubjectsForResources query
+// shape and adds skip-caveats/skip-expiration options when the relation's
+// traits prove they cannot occur.
+func queryOptionsForArrowRelation(ctx context.Context, ds datastore.Reader, namespaceName string, relationName string) ([]options.QueryOptionsOption, error) {
+	opts := make([]options.QueryOptionsOption, 0, 3)
+	opts = append(opts, options.WithQueryShape(queryshape.AllSubjectsForResources))
+
+	traits, err := TraitsForArrowRelation(ctx, ds, namespaceName, relationName)
+	if err != nil {
+		return nil, err
+	}
+
+	if !traits.HasCaveats {
+		opts = append(opts, options.WithSkipCaveats(true))
+	}
+
+	if !traits.HasExpiration {
+		opts = append(opts, options.WithSkipExpiration(true))
+	}
+
+	return opts, nil
+}
+
+// filterForFoundMemberResource checks whether the subject itself appears among
+// the resource IDs (only possible when the resource relation equals the
+// subject's type and relation). If found, it returns a membership set with that
+// ID as a direct member plus the remaining IDs; otherwise a nil set and the
+// IDs unchanged. At most one match is removed, as the IDs are deduplicated.
+func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) (*MembershipSet, []string) {
+	if resourceRelation.Namespace != subject.Namespace || resourceRelation.Relation != subject.Relation {
+		return nil, resourceIds
+	}
+
+	for index, resourceID := range resourceIds {
+		if subject.ObjectId == resourceID {
+			membershipSet := NewMembershipSet()
+			membershipSet.AddDirectMember(resourceID, nil)
+			return membershipSet, removeIndexFromSlice(resourceIds, index)
+		}
+	}
+
+	return nil, resourceIds
+}
+
+// removeIndexFromSlice returns a new slice with the element at index removed;
+// the input slice is not mutated. index must be in range.
+func removeIndexFromSlice[T any](s []T, index int) []T {
+	cpy := make([]T, 0, len(s)-1)
+	cpy = append(cpy, s[:index]...)
+	return append(cpy, s[index+1:]...)
+}
+
+// relation abstracts the tupleset portion shared by the TTU message types.
+type relation interface {
+	GetRelation() string
+}
+
+// ttu abstracts over the (functioned and plain) tuple-to-userset message types
+// so checkTupleToUserset can handle both generically.
+type ttu[T relation] interface {
+	GetComputedUserset() *core.ComputedUserset
+	GetTupleset() T
+}
+
+// checkResultWithType pairs a check result with the relation type of the
+// dispatched subjects that produced it (used by intersection arrows).
+type checkResultWithType struct {
+	CheckResult
+
+	relationType tuple.RelationReference
+}
+
+// checkIntersectionTupleToUserset evaluates an `all(...)` arrow: a resource is a
+// member only if *every* subject found on its tupleset relation grants the
+// computed userset permission (possibly caveated). Results are collected per
+// dispatched subject type and then intersected per resource.
+func checkIntersectionTupleToUserset(
+	ctx context.Context,
+	cc *ConcurrentChecker,
+	crc currentRequestContext,
+	ttu *core.FunctionedTupleToUserset,
+) CheckResult {
+	// TODO(jschorr): use check hints here
+	ctx, span := tracer.Start(ctx, ttu.GetTupleset().GetRelation()+"-(all)->"+ttu.GetComputedUserset().Relation)
+	defer span.End()
+
+	// Query for the subjects over which to walk the TTU.
+	log.Ctx(ctx).Trace().Object("intersectionttu", crc.parentReq).Send()
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision)
+	queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation())
+	if err != nil {
+		return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+	}
+
+	it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{
+		OptionalResourceType:     crc.parentReq.ResourceRelation.Namespace,
+		OptionalResourceIds:      crc.filteredResourceIDs,
+		OptionalResourceRelation: ttu.GetTupleset().GetRelation(),
+	}, queryOpts...)
+	if err != nil {
+		return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+	}
+
+	// Record every subject per resource ID; the intersection below requires all
+	// of a resource's subjects to have membership.
+	checksToDispatch := newCheckDispatchSet()
+	subjectsByResourceID := mapz.NewMultiMap[string, tuple.ObjectAndRelation]()
+	for rel, err := range it {
+		if err != nil {
+			return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+		}
+
+		checksToDispatch.addForRelationship(rel)
+		subjectsByResourceID.Add(rel.Resource.ObjectID, rel.Subject)
+	}
+
+	// Convert the subjects into batched requests.
+	toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize)
+	if len(toDispatch) == 0 {
+		return noMembers()
+	}
+
+	// Run the dispatch for all the chunks. Unlike a standard TTU, we do *not* perform mapping here,
+	// as we need to access the results on a per subject basis. Instead, we keep each result and map
+	// by the relation type of the dispatched subject.
+	// NOTE: REQUIRE_ALL_RESULTS is forced, since the intersection needs the full
+	// membership picture for each dispatched subject.
+	chunkResults, err := run(
+		ctx,
+		currentRequestContext{
+			parentReq:           crc.parentReq,
+			filteredResourceIDs: crc.filteredResourceIDs,
+			resultsSetting:      v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS,
+			dispatchChunkSize:   crc.dispatchChunkSize,
+		},
+		toDispatch,
+		func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) checkResultWithType {
+			resourceType := dd.resourceType
+			childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), &resourceType, dd.resourceIds)
+			return checkResultWithType{
+				CheckResult:  childResult,
+				relationType: dd.resourceType,
+			}
+		},
+		cc.concurrencyLimit,
+	)
+	if err != nil {
+		return checkResultError(err, emptyMetadata)
+	}
+
+	// Create a membership set per-subject-type, representing the membership for each of the dispatched subjects.
+	resultsByDispatchedSubject := map[tuple.RelationReference]*MembershipSet{}
+	combinedMetadata := emptyMetadata
+	for _, result := range chunkResults {
+		if result.Err != nil {
+			return checkResultError(result.Err, emptyMetadata)
+		}
+
+		if _, ok := resultsByDispatchedSubject[result.relationType]; !ok {
+			resultsByDispatchedSubject[result.relationType] = NewMembershipSet()
+		}
+
+		resultsByDispatchedSubject[result.relationType].UnionWith(result.Resp.ResultsByResourceId)
+		combinedMetadata = combineResponseMetadata(ctx, combinedMetadata, result.Resp.Metadata)
+	}
+
+	// For each resource ID, check that there exist some sort of permission for *each* subject. If not, then the
+	// intersection for that resource fails. If all subjects have some sort of permission, then the resource ID is
+	// a member, perhaps caveated.
+	resourcesFound := NewMembershipSet()
+	for _, resourceID := range subjectsByResourceID.Keys() {
+		subjects, _ := subjectsByResourceID.Get(resourceID)
+		if len(subjects) == 0 {
+			return checkResultError(spiceerrors.MustBugf("no subjects found for resource ID %s", resourceID), emptyMetadata)
+		}
+
+		hasAllSubjects := true
+		caveats := make([]*core.CaveatExpression, 0, len(subjects))
+
+		// Check each of the subjects found for the resource ID and ensure that membership (at least caveated)
+		// was found for each. If any are not found, then the resource ID is not a member.
+		// We also collect up the caveats for each subject, as they will be added to the final result.
+		for _, subject := range subjects {
+			subjectTypeKey := subject.RelationReference()
+			results, ok := resultsByDispatchedSubject[subjectTypeKey]
+			if !ok {
+				hasAllSubjects = false
+				break
+			}
+
+			hasMembership, caveat := results.GetResourceID(subject.ObjectID)
+			if !hasMembership {
+				hasAllSubjects = false
+				break
+			}
+
+			if caveat != nil {
+				caveats = append(caveats, caveat)
+			}
+
+			// Add any caveats on the subject from the starting relationship(s) as well.
+			resourceIDAndCaveats := checksToDispatch.mappingsForSubject(subject.ObjectType, subject.ObjectID, subject.Relation)
+			for _, riac := range resourceIDAndCaveats {
+				if riac.caveat != nil {
+					caveats = append(caveats, wrapCaveat(riac.caveat))
+				}
+			}
+		}
+
+		if !hasAllSubjects {
+			continue
+		}
+
+		// Add the member to the membership set, with the caveats for each (if any).
+		resourcesFound.AddMemberWithOptionalCaveats(resourceID, caveats)
+	}
+
+	return checkResultsForMembership(resourcesFound, combinedMetadata)
+}
+
+// checkTupleToUserset evaluates a standard (`any`) arrow: resources are members
+// if *any* subject on their tupleset relation grants the computed userset
+// permission. Check hints matching the arrow are consumed up front and merged
+// back into the final result.
+func checkTupleToUserset[T relation](
+	ctx context.Context,
+	cc *ConcurrentChecker,
+	crc currentRequestContext,
+	ttu ttu[T],
+) CheckResult {
+	filteredResourceIDs := crc.filteredResourceIDs
+	hintsToReturn := make(map[string]*v1.ResourceCheckResult, len(crc.parentReq.CheckHints))
+	if len(crc.parentReq.CheckHints) > 0 {
+		filteredResourcesIdsSet := mapz.NewSet(crc.filteredResourceIDs...)
+
+		// Remove any resource IDs already answered by a hint; their stored results
+		// are re-added via combineWithComputedHints below.
+		for _, checkHint := range crc.parentReq.CheckHints {
+			resourceID, ok := hints.AsCheckHintForArrow(
+				checkHint,
+				crc.parentReq.ResourceRelation.Namespace,
+				ttu.GetTupleset().GetRelation(),
+				ttu.GetComputedUserset().Relation,
+				tuple.FromCoreObjectAndRelation(crc.parentReq.Subject),
+			)
+			if !ok {
+				continue
+			}
+
+			filteredResourcesIdsSet.Delete(resourceID)
+			hintsToReturn[resourceID] = checkHint.Result
+		}
+
+		filteredResourceIDs = filteredResourcesIdsSet.AsSlice()
+	}
+
+	if len(filteredResourceIDs) == 0 {
+		return combineWithComputedHints(noMembers(), hintsToReturn)
+	}
+
+	ctx, span := tracer.Start(ctx, ttu.GetTupleset().GetRelation()+"->"+ttu.GetComputedUserset().Relation)
+	defer span.End()
+
+	log.Ctx(ctx).Trace().Object("ttu", crc.parentReq).Send()
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision)
+
+	queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation())
+	if err != nil {
+		return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+	}
+
+	// Query for the tupleset relationships over which to walk the arrow.
+	it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{
+		OptionalResourceType:     crc.parentReq.ResourceRelation.Namespace,
+		OptionalResourceIds:      filteredResourceIDs,
+		OptionalResourceRelation: ttu.GetTupleset().GetRelation(),
+	}, queryOpts...)
+	if err != nil {
+		return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+	}
+
+	checksToDispatch := newCheckDispatchSet()
+	for rel, err := range it {
+		if err != nil {
+			return checkResultError(NewCheckFailureErr(err), emptyMetadata)
+		}
+		checksToDispatch.addForRelationship(rel)
+	}
+
+	// Dispatch over the found subjects; any single match suffices (union), and
+	// child results are mapped back to the parent's resource IDs.
+	toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize)
+	return combineWithComputedHints(union(
+		ctx,
+		crc,
+		toDispatch,
+		func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) CheckResult {
+			resourceType := dd.resourceType
+			childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), &resourceType, dd.resourceIds)
+			if childResult.Err != nil {
+				return childResult
+			}
+
+			return mapFoundResources(childResult, dd.resourceType, checksToDispatch)
+		},
+		cc.concurrencyLimit,
+	), hintsToReturn)
+}
+
+// withDistinctMetadata clones the result and rebuilds its metadata so that
+// a parent operation with a single child still reports its own distinct
+// metadata rather than aliasing the child's.
+func withDistinctMetadata(ctx context.Context, result CheckResult) CheckResult {
+	// NOTE: This is necessary to ensure unique debug information on the request and that debug
+	// information from the child metadata is *not* copied over.
+	clonedResp := result.Resp.CloneVT()
+	clonedResp.Metadata = combineResponseMetadata(ctx, emptyMetadata, clonedResp.Metadata)
+	return CheckResult{
+		Resp: clonedResp,
+		Err:  result.Err,
+	}
+}
+
+// run runs all the children in parallel and returns the full set of results.
+// Unlike union/all/difference it performs no combination: callers receive every
+// child's result (in completion order, not input order). Returns ctx.Err() if
+// the context is canceled before all children finish.
+func run[T any, R withError](
+	ctx context.Context,
+	crc currentRequestContext,
+	children []T,
+	handler func(ctx context.Context, crc currentRequestContext, child T) R,
+	concurrencyLimit uint16,
+) ([]R, error) {
+	if len(children) == 0 {
+		return nil, nil
+	}
+
+	// Single child: run inline, no goroutine fan-out needed.
+	if len(children) == 1 {
+		return []R{handler(ctx, crc, children[0])}, nil
+	}
+
+	resultChan := make(chan R, len(children))
+	childCtx, cancelFn := context.WithCancel(ctx)
+	dispatchAllAsync(childCtx, crc, children, handler, resultChan, concurrencyLimit)
+	defer cancelFn()
+
+	results := make([]R, 0, len(children))
+	for i := 0; i < len(children); i++ {
+		select {
+		case result := <-resultChan:
+			results = append(results, result)
+
+		case <-ctx.Done():
+			log.Ctx(ctx).Trace().Msg("anyCanceled")
+			return nil, ctx.Err()
+		}
+	}
+
+	return results, nil
+}
+
+// union returns whether any one of the lazy checks pass, and is used for union.
+// Results are unioned into a single membership set; it short-circuits as soon
+// as a determined member is found when only a single result is required.
+func union[T any](
+	ctx context.Context,
+	crc currentRequestContext,
+	children []T,
+	handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult,
+	concurrencyLimit uint16,
+) CheckResult {
+	if len(children) == 0 {
+		return noMembers()
+	}
+
+	// Single child: run inline, but ensure distinct metadata for this level.
+	if len(children) == 1 {
+		return withDistinctMetadata(ctx, handler(ctx, crc, children[0]))
+	}
+
+	resultChan := make(chan CheckResult, len(children))
+	childCtx, cancelFn := context.WithCancel(ctx)
+	dispatchAllAsync(childCtx, crc, children, handler, resultChan, concurrencyLimit)
+	// Canceling on return stops any still-running children once a short-circuit
+	// or error occurs.
+	defer cancelFn()
+
+	responseMetadata := emptyMetadata
+	membershipSet := NewMembershipSet()
+
+	for i := 0; i < len(children); i++ {
+		select {
+		case result := <-resultChan:
+			log.Ctx(ctx).Trace().Object("anyResult", result.Resp).Send()
+			responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata)
+			if result.Err != nil {
+				return checkResultError(result.Err, responseMetadata)
+			}
+
+			membershipSet.UnionWith(result.Resp.ResultsByResourceId)
+			// Short-circuit: one determined member suffices under ALLOW_SINGLE_RESULT.
+			if membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT {
+				return checkResultsForMembership(membershipSet, responseMetadata)
+			}
+
+		case <-ctx.Done():
+			log.Ctx(ctx).Trace().Msg("anyCanceled")
+			return checkResultError(context.Canceled, responseMetadata)
+		}
+	}
+
+	return checkResultsForMembership(membershipSet, responseMetadata)
+}
+
+// all returns whether all of the lazy checks pass, and is used for intersection.
+// Children are forced to REQUIRE_ALL_RESULTS so each child reports complete
+// membership; results are intersected and the walk short-circuits to empty as
+// soon as the running intersection becomes empty.
+func all[T any](
+	ctx context.Context,
+	crc currentRequestContext,
+	children []T,
+	handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult,
+	concurrencyLimit uint16,
+) CheckResult {
+	if len(children) == 0 {
+		return noMembers()
+	}
+
+	// Single child: run inline, but ensure distinct metadata for this level.
+	if len(children) == 1 {
+		return withDistinctMetadata(ctx, handler(ctx, crc, children[0]))
+	}
+
+	responseMetadata := emptyMetadata
+
+	resultChan := make(chan CheckResult, len(children))
+	childCtx, cancelFn := context.WithCancel(ctx)
+	// Override the results setting: an intersection needs each child's full
+	// result set, not just the first match.
+	dispatchAllAsync(childCtx, currentRequestContext{
+		parentReq:           crc.parentReq,
+		filteredResourceIDs: crc.filteredResourceIDs,
+		resultsSetting:      v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS,
+		dispatchChunkSize:   crc.dispatchChunkSize,
+	}, children, handler, resultChan, concurrencyLimit)
+	defer cancelFn()
+
+	// nil until the first child result arrives; the first result seeds the set
+	// and subsequent results intersect with it.
+	var membershipSet *MembershipSet
+	for i := 0; i < len(children); i++ {
+		select {
+		case result := <-resultChan:
+			responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata)
+			if result.Err != nil {
+				return checkResultError(result.Err, responseMetadata)
+			}
+
+			if membershipSet == nil {
+				membershipSet = NewMembershipSet()
+				membershipSet.UnionWith(result.Resp.ResultsByResourceId)
+			} else {
+				membershipSet.IntersectWith(result.Resp.ResultsByResourceId)
+			}
+
+			// Empty intersection can never grow back; stop early.
+			if membershipSet.IsEmpty() {
+				return noMembersWithMetadata(responseMetadata)
+			}
+		case <-ctx.Done():
+			return checkResultError(context.Canceled, responseMetadata)
+		}
+	}
+
+	return checkResultsForMembership(membershipSet, responseMetadata)
+}
+
// difference returns whether the first lazy check passes and none of the subsequent checks pass.
func difference[T any](
	ctx context.Context,
	crc currentRequestContext,
	children []T,
	handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult,
	concurrencyLimit uint16,
) CheckResult {
	if len(children) == 0 {
		return noMembers()
	}

	// A difference with nothing to subtract should have been simplified away upstream.
	if len(children) == 1 {
		return checkResultError(spiceerrors.MustBugf("difference requires more than a single child"), emptyMetadata)
	}

	childCtx, cancelFn := context.WithCancel(ctx)
	baseChan := make(chan CheckResult, 1)
	othersChan := make(chan CheckResult, len(children)-1)

	// The base (minuend) child runs on its own goroutine. All children run with
	// REQUIRE_ALL_RESULTS, since subtraction needs complete result sets.
	go func() {
		result := handler(childCtx, currentRequestContext{
			parentReq:           crc.parentReq,
			filteredResourceIDs: crc.filteredResourceIDs,
			resultsSetting:      v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS,
			dispatchChunkSize:   crc.dispatchChunkSize,
		}, children[0])
		baseChan <- result
	}()

	// NOTE(review): concurrencyLimit-1 assumes concurrencyLimit >= 1; a zero value
	// would wrap around (uint16) — confirm callers guarantee a non-zero limit.
	dispatchAllAsync(childCtx, currentRequestContext{
		parentReq:           crc.parentReq,
		filteredResourceIDs: crc.filteredResourceIDs,
		resultsSetting:      v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS,
		dispatchChunkSize:   crc.dispatchChunkSize,
	}, children[1:], handler, othersChan, concurrencyLimit-1)
	defer cancelFn()

	responseMetadata := emptyMetadata
	membershipSet := NewMembershipSet()

	// Wait for the base set to return.
	select {
	case base := <-baseChan:
		// Metadata is merged before the error check so dispatch counts are kept.
		responseMetadata = combineResponseMetadata(ctx, responseMetadata, base.Resp.Metadata)

		if base.Err != nil {
			return checkResultError(base.Err, responseMetadata)
		}

		membershipSet.UnionWith(base.Resp.ResultsByResourceId)
		// An empty base means nothing can remain after subtraction.
		if membershipSet.IsEmpty() {
			return noMembersWithMetadata(responseMetadata)
		}

	case <-ctx.Done():
		return checkResultError(context.Canceled, responseMetadata)
	}

	// Subtract the remaining sets.
	for i := 1; i < len(children); i++ {
		select {
		case sub := <-othersChan:
			responseMetadata = combineResponseMetadata(ctx, responseMetadata, sub.Resp.Metadata)

			if sub.Err != nil {
				return checkResultError(sub.Err, responseMetadata)
			}

			membershipSet.Subtract(sub.Resp.ResultsByResourceId)
			// Short-circuit once everything has been subtracted away.
			if membershipSet.IsEmpty() {
				return noMembersWithMetadata(responseMetadata)
			}

		case <-ctx.Done():
			return checkResultError(context.Canceled, responseMetadata)
		}
	}

	return checkResultsForMembership(membershipSet, responseMetadata)
}
+
// withError is the constraint for result types that can report a failure,
// allowing dispatchAllAsync to surface a task's error to the task runner.
type withError interface {
	ResultError() error
}
+
+func dispatchAllAsync[T any, R withError](
+ ctx context.Context,
+ crc currentRequestContext,
+ children []T,
+ handler func(ctx context.Context, crc currentRequestContext, child T) R,
+ resultChan chan<- R,
+ concurrencyLimit uint16,
+) {
+ tr := taskrunner.NewPreloadedTaskRunner(ctx, concurrencyLimit, len(children))
+ for _, currentChild := range children {
+ currentChild := currentChild
+ tr.Add(func(ctx context.Context) error {
+ result := handler(ctx, crc, currentChild)
+ resultChan <- result
+ return result.ResultError()
+ })
+ }
+
+ tr.Start()
+}
+
+func noMembers() CheckResult {
+ return CheckResult{
+ &v1.DispatchCheckResponse{
+ Metadata: emptyMetadata,
+ },
+ nil,
+ }
+}
+
+func noMembersWithMetadata(metadata *v1.ResponseMeta) CheckResult {
+ return CheckResult{
+ &v1.DispatchCheckResponse{
+ Metadata: metadata,
+ },
+ nil,
+ }
+}
+
+func checkResultsForMembership(foundMembership *MembershipSet, subProblemMetadata *v1.ResponseMeta) CheckResult {
+ return CheckResult{
+ &v1.DispatchCheckResponse{
+ Metadata: ensureMetadata(subProblemMetadata),
+ ResultsByResourceId: foundMembership.AsCheckResultsMap(),
+ },
+ nil,
+ }
+}
+
+func checkResultError(err error, subProblemMetadata *v1.ResponseMeta) CheckResult {
+ return CheckResult{
+ &v1.DispatchCheckResponse{
+ Metadata: ensureMetadata(subProblemMetadata),
+ },
+ err,
+ }
+}
+
+func combineResultWithFoundResources(result CheckResult, foundResources *MembershipSet) CheckResult {
+ if result.Err != nil {
+ return result
+ }
+
+ if foundResources.IsEmpty() {
+ return result
+ }
+
+ foundResources.UnionWith(result.Resp.ResultsByResourceId)
+ return CheckResult{
+ Resp: &v1.DispatchCheckResponse{
+ ResultsByResourceId: foundResources.AsCheckResultsMap(),
+ Metadata: result.Resp.Metadata,
+ },
+ Err: result.Err,
+ }
+}
+
// combineResponseMetadata merges two response metadata values: dispatch counts
// are summed, required depth takes the maximum, and (when either side carries
// debug information) both sides' debug traces are nested under a fresh parent trace.
func combineResponseMetadata(ctx context.Context, existing *v1.ResponseMeta, responseMetadata *v1.ResponseMeta) *v1.ResponseMeta {
	combined := &v1.ResponseMeta{
		DispatchCount:       existing.DispatchCount + responseMetadata.DispatchCount,
		DepthRequired:       max(existing.DepthRequired, responseMetadata.DepthRequired),
		CachedDispatchCount: existing.CachedDispatchCount + responseMetadata.CachedDispatchCount,
	}

	// Fast path: no debug information on either side.
	if existing.DebugInfo == nil && responseMetadata.DebugInfo == nil {
		return combined
	}

	// Best effort: a missing node ID is logged but does not fail the merge.
	nodeID, err := nodeid.FromContext(ctx)
	if err != nil {
		log.Err(err).Msg("failed to get nodeID from context")
	}

	debugInfo := &v1.DebugInformation{
		Check: &v1.CheckDebugTrace{
			TraceId:  NewTraceID(),
			SourceId: nodeID,
		},
	}

	// A trace with a Request is treated as a real subproblem and nested whole; a
	// trace without one is only a container, so its children are hoisted instead.
	if existing.DebugInfo != nil {
		if existing.DebugInfo.Check.Request != nil {
			debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, existing.DebugInfo.Check)
		} else {
			debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, existing.DebugInfo.Check.SubProblems...)
		}
	}

	if responseMetadata.DebugInfo != nil {
		if responseMetadata.DebugInfo.Check.Request != nil {
			debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, responseMetadata.DebugInfo.Check)
		} else {
			debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, responseMetadata.DebugInfo.Check.SubProblems...)
		}
	}

	combined.DebugInfo = debugInfo
	return combined
}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go b/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go
new file mode 100644
index 0000000..ed3f3cb
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go
@@ -0,0 +1,144 @@
+package graph
+
+import (
+ "sort"
+
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/genutil/slicez"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// checkDispatchSet is the set of subjects over which check will need to dispatch
// as subproblems in order to answer the parent problem.
type checkDispatchSet struct {
	// bySubjectType is a map from the type of subject to the set of subjects of that type
	// over which to dispatch. The inner map's boolean value records whether *any*
	// relationship seen for that subject ID carried a named caveat.
	bySubjectType map[tuple.RelationReference]map[string]bool

	// bySubject is a map from the subject to the set of resources for which the subject
	// has a relationship, along with the caveats that apply to that relationship.
	bySubject *mapz.MultiMap[tuple.ObjectAndRelation, resourceIDAndCaveat]
}
+
// checkDispatchChunk is a chunk of subjects over which to dispatch a check operation.
//
// NOTE: despite the `resource*` field names, these fields hold the *subjects*
// of the parent problem, which become the resources of the dispatched subproblem.
type checkDispatchChunk struct {
	// resourceType is the type of the subjects in this chunk.
	resourceType tuple.RelationReference

	// resourceIds is the set of subjects in this chunk.
	resourceIds []string

	// hasIncomingCaveats is true if any of the subjects in this chunk have incoming caveats.
	// This is used to determine whether the check operation should be dispatched requiring
	// all results.
	hasIncomingCaveats bool
}
+
// subjectIDAndHasCaveat is a tuple of a subject ID and whether it has a caveat,
// used for sorting and chunking subjects in dispatchChunks.
type subjectIDAndHasCaveat struct {
	// objectID is the ID of the subject.
	objectID string

	// hasIncomingCaveats is true if any relationship for the subject carries a named caveat.
	hasIncomingCaveats bool
}
+
// resourceIDAndCaveat is a tuple of a resource ID and a caveat.
type resourceIDAndCaveat struct {
	// resourceID is the ID of the resource.
	resourceID string

	// caveat is the caveat that applies to the relationship between the subject and the resource.
	// May be nil when the relationship is not caveated.
	caveat *core.ContextualizedCaveat
}
+
+// newCheckDispatchSet creates and returns a new checkDispatchSet.
+func newCheckDispatchSet() *checkDispatchSet {
+ return &checkDispatchSet{
+ bySubjectType: map[tuple.RelationReference]map[string]bool{},
+ bySubject: mapz.NewMultiMap[tuple.ObjectAndRelation, resourceIDAndCaveat](),
+ }
+}
+
+// Add adds the specified ObjectAndRelation to the set.
+func (s *checkDispatchSet) addForRelationship(rel tuple.Relationship) {
+ // Add an entry for the subject pointing to the resource ID and caveat for the subject.
+ riac := resourceIDAndCaveat{
+ resourceID: rel.Resource.ObjectID,
+ caveat: rel.OptionalCaveat,
+ }
+ s.bySubject.Add(rel.Subject, riac)
+
+ // Add the subject ID to the map of subjects for the type of subject.
+ siac := subjectIDAndHasCaveat{
+ objectID: rel.Subject.ObjectID,
+ hasIncomingCaveats: rel.OptionalCaveat != nil && rel.OptionalCaveat.CaveatName != "",
+ }
+
+ subjectIDsForType, ok := s.bySubjectType[rel.Subject.RelationReference()]
+ if !ok {
+ subjectIDsForType = make(map[string]bool)
+ s.bySubjectType[rel.Subject.RelationReference()] = subjectIDsForType
+ }
+
+ // If a caveat exists for the subject ID in any branch, the whole branch is considered caveated.
+ subjectIDsForType[rel.Subject.ObjectID] = siac.hasIncomingCaveats || subjectIDsForType[rel.Subject.ObjectID]
+}
+
// dispatchChunks returns the chunks of subjects over which to dispatch, grouped
// by subject type, with at most dispatchChunkSize subject IDs per chunk.
func (s *checkDispatchSet) dispatchChunks(dispatchChunkSize uint16) []checkDispatchChunk {
	// Start with an estimate of one chunk per type, plus one for the remainder.
	expectedNumberOfChunks := len(s.bySubjectType) + 1
	toDispatch := make([]checkDispatchChunk, 0, expectedNumberOfChunks)

	// For each type of subject, create chunks of the IDs over which to dispatch.
	for subjectType, subjectIDsAndHasCaveats := range s.bySubjectType {
		entries := make([]subjectIDAndHasCaveat, 0, len(subjectIDsAndHasCaveats))
		for objectID, hasIncomingCaveats := range subjectIDsAndHasCaveats {
			entries = append(entries, subjectIDAndHasCaveat{objectID: objectID, hasIncomingCaveats: hasIncomingCaveats})
		}

		// Sort the list of subject IDs by whether they have caveats and then the ID itself.
		// Caveated subjects sort first, so they cluster into the fewest possible chunks.
		sort.Slice(entries, func(i, j int) bool {
			iHasCaveat := entries[i].hasIncomingCaveats
			jHasCaveat := entries[j].hasIncomingCaveats
			if iHasCaveat == jHasCaveat {
				return entries[i].objectID < entries[j].objectID
			}
			return iHasCaveat && !jHasCaveat
		})

		// chunkCount is float64 because it is observed directly into the histogram below.
		chunkCount := 0.0
		slicez.ForEachChunk(entries, dispatchChunkSize, func(subjectIdChunk []subjectIDAndHasCaveat) {
			chunkCount++

			subjectIDsToDispatch := make([]string, 0, len(subjectIdChunk))
			hasIncomingCaveats := false
			for _, entry := range subjectIdChunk {
				subjectIDsToDispatch = append(subjectIDsToDispatch, entry.objectID)
				// A chunk is marked caveated if *any* of its subjects is caveated.
				hasIncomingCaveats = hasIncomingCaveats || entry.hasIncomingCaveats
			}

			toDispatch = append(toDispatch, checkDispatchChunk{
				resourceType:       subjectType,
				resourceIds:        subjectIDsToDispatch,
				hasIncomingCaveats: hasIncomingCaveats,
			})
		})
		dispatchChunkCountHistogram.Observe(chunkCount)
	}

	return toDispatch
}
+
+// mappingsForSubject returns the mappings that apply to the relationship between the specified
+// subject and any of its resources. The returned caveats include the resource ID of the resource
+// that the subject has a relationship with.
+func (s *checkDispatchSet) mappingsForSubject(subjectType string, subjectObjectID string, subjectRelation string) []resourceIDAndCaveat {
+ results, ok := s.bySubject.Get(tuple.ONR(subjectType, subjectObjectID, subjectRelation))
+ spiceerrors.DebugAssert(func() bool { return ok }, "no caveats found for subject %s:%s:%s", subjectType, subjectObjectID, subjectRelation)
+ return results
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go b/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go
new file mode 100644
index 0000000..0bf20b5
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go
@@ -0,0 +1,205 @@
+package computed
+
+import (
+ "context"
+
+ cexpr "github.com/authzed/spicedb/internal/caveats"
+ "github.com/authzed/spicedb/internal/dispatch"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/genutil/slicez"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// DebugOption defines the various debug level options for Checks.
type DebugOption int

const (
	// NoDebugging indicates that no debug information should be retained
	// while performing the Check. (Fixed: previous comment inverted the meaning.)
	NoDebugging DebugOption = 0

	// BasicDebuggingEnabled indicates that basic debug information, such
	// as which steps were taken, should be retained while performing the
	// Check and returned to the caller.
	//
	// NOTE: This has a minor performance impact.
	BasicDebuggingEnabled DebugOption = 1

	// TraceDebuggingEnabled indicates that the Check is being issued for
	// tracing the exact calls made for debugging, which means that not only
	// should debug information be recorded and returned, but that optimizations
	// such as batching should be disabled.
	//
	// WARNING: This has a fairly significant performance impact and should only
	// be used in tooling!
	TraceDebuggingEnabled DebugOption = 2
)
+
// CheckParameters are the parameters for the ComputeCheck call. *All* are required.
type CheckParameters struct {
	ResourceType  tuple.RelationReference  // type and relation of the resources being checked
	Subject       tuple.ObjectAndRelation  // subject whose membership is being checked
	CaveatContext map[string]any           // context passed to caveat expression evaluation
	AtRevision    datastore.Revision       // datastore revision at which to perform the check
	MaximumDepth  uint32                   // maximum dispatch depth allowed for the check
	DebugOption   DebugOption              // level of debug information to collect
	CheckHints    []*v1.CheckHint          // pre-computed results forwarded to dispatch
}
+
+// ComputeCheck computes a check result for the given resource and subject, computing any
+// caveat expressions found.
+func ComputeCheck(
+ ctx context.Context,
+ d dispatch.Check,
+ ts *caveattypes.TypeSet,
+ params CheckParameters,
+ resourceID string,
+ dispatchChunkSize uint16,
+) (*v1.ResourceCheckResult, *v1.ResponseMeta, error) {
+ resultsMap, meta, di, err := computeCheck(ctx, d, ts, params, []string{resourceID}, dispatchChunkSize)
+ if err != nil {
+ return nil, meta, err
+ }
+
+ spiceerrors.DebugAssert(func() bool {
+ return (len(di) == 0 && meta.DebugInfo == nil) || (len(di) == 1 && meta.DebugInfo != nil)
+ }, "mismatch in debug information returned from computeCheck")
+
+ return resultsMap[resourceID], meta, err
+}
+
+// ComputeBulkCheck computes a check result for the given resources and subject, computing any
+// caveat expressions found.
+func ComputeBulkCheck(
+ ctx context.Context,
+ d dispatch.Check,
+ ts *caveattypes.TypeSet,
+ params CheckParameters,
+ resourceIDs []string,
+ dispatchChunkSize uint16,
+) (map[string]*v1.ResourceCheckResult, *v1.ResponseMeta, []*v1.DebugInformation, error) {
+ return computeCheck(ctx, d, ts, params, resourceIDs, dispatchChunkSize)
+}
+
// computeCheck dispatches checks for the given resource IDs in chunks of
// dispatchChunkSize, then evaluates any returned caveat expressions, producing
// a per-resource result map, combined metadata, and any collected debug info.
func computeCheck(ctx context.Context,
	d dispatch.Check,
	ts *caveattypes.TypeSet,
	params CheckParameters,
	resourceIDs []string,
	dispatchChunkSize uint16,
) (map[string]*v1.ResourceCheckResult, *v1.ResponseMeta, []*v1.DebugInformation, error) {
	// Translate the caller's debug option into the dispatch wire enum.
	debugging := v1.DispatchCheckRequest_NO_DEBUG
	if params.DebugOption == BasicDebuggingEnabled {
		debugging = v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING
	} else if params.DebugOption == TraceDebuggingEnabled {
		debugging = v1.DispatchCheckRequest_ENABLE_TRACE_DEBUGGING
	}

	// A single-resource check may stop at the first determined result.
	setting := v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS
	if len(resourceIDs) == 1 {
		setting = v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT
	}

	// Ensure that the number of resources IDs given to each dispatch call is not in excess of the maximum.
	results := make(map[string]*v1.ResourceCheckResult, len(resourceIDs))
	metadata := &v1.ResponseMeta{}

	bf, err := v1.NewTraversalBloomFilter(uint(params.MaximumDepth))
	if err != nil {
		return nil, nil, nil, spiceerrors.MustBugf("failed to create new traversal bloom filter")
	}

	caveatRunner := cexpr.NewCaveatRunner(ts)

	// TODO(jschorr): Should we make this run in parallel via the preloadedTaskRunner?
	debugInfo := make([]*v1.DebugInformation, 0)
	_, err = slicez.ForEachChunkUntil(resourceIDs, dispatchChunkSize, func(resourceIDsToCheck []string) (bool, error) {
		checkResult, err := d.DispatchCheck(ctx, &v1.DispatchCheckRequest{
			ResourceRelation: params.ResourceType.ToCoreRR(),
			ResourceIds:      resourceIDsToCheck,
			ResultsSetting:   setting,
			Subject:          params.Subject.ToCoreONR(),
			Metadata: &v1.ResolverMeta{
				AtRevision:     params.AtRevision.String(),
				DepthRemaining: params.MaximumDepth,
				TraversalBloom: bf,
			},
			Debug:      debugging,
			CheckHints: params.CheckHints,
		})

		// Metadata is read before the error check below. NOTE(review): this assumes
		// DispatchCheck always returns a non-nil response carrying metadata even when
		// err is non-nil — confirm against the dispatcher implementations.
		if checkResult.Metadata.DebugInfo != nil {
			debugInfo = append(debugInfo, checkResult.Metadata.DebugInfo)
		}

		// For a single resource the chunk metadata is taken verbatim; otherwise the
		// per-chunk counts are accumulated (and debug info is kept only in the slice).
		if len(resourceIDs) == 1 {
			metadata = checkResult.Metadata
		} else {
			metadata = &v1.ResponseMeta{
				DispatchCount:       metadata.DispatchCount + checkResult.Metadata.DispatchCount,
				DepthRequired:       max(metadata.DepthRequired, checkResult.Metadata.DepthRequired),
				CachedDispatchCount: metadata.CachedDispatchCount + checkResult.Metadata.CachedDispatchCount,
				DebugInfo:           nil,
			}
		}

		if err != nil {
			return false, err
		}

		// Resolve caveated memberships for each resource in the chunk.
		for _, resourceID := range resourceIDsToCheck {
			computed, err := computeCaveatedCheckResult(ctx, caveatRunner, params, resourceID, checkResult)
			if err != nil {
				return false, err
			}
			results[resourceID] = computed
		}

		return true, nil
	})
	return results, metadata, debugInfo, err
}
+
+func computeCaveatedCheckResult(ctx context.Context, runner *cexpr.CaveatRunner, params CheckParameters, resourceID string, checkResult *v1.DispatchCheckResponse) (*v1.ResourceCheckResult, error) {
+ result, ok := checkResult.ResultsByResourceId[resourceID]
+ if !ok {
+ return &v1.ResourceCheckResult{
+ Membership: v1.ResourceCheckResult_NOT_MEMBER,
+ }, nil
+ }
+
+ if result.Membership == v1.ResourceCheckResult_MEMBER {
+ return result, nil
+ }
+
+ ds := datastoremw.MustFromContext(ctx)
+ reader := ds.SnapshotReader(params.AtRevision)
+
+ caveatResult, err := runner.RunCaveatExpression(ctx, result.Expression, params.CaveatContext, reader, cexpr.RunCaveatExpressionNoDebugging)
+ if err != nil {
+ return nil, err
+ }
+
+ if caveatResult.IsPartial() {
+ missingFields, _ := caveatResult.MissingVarNames()
+ return &v1.ResourceCheckResult{
+ Membership: v1.ResourceCheckResult_CAVEATED_MEMBER,
+ Expression: result.Expression,
+ MissingExprFields: missingFields,
+ }, nil
+ }
+
+ if caveatResult.Value() {
+ return &v1.ResourceCheckResult{
+ Membership: v1.ResourceCheckResult_MEMBER,
+ }, nil
+ }
+
+ return &v1.ResourceCheckResult{
+ Membership: v1.ResourceCheckResult_NOT_MEMBER,
+ }, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/context.go b/vendor/github.com/authzed/spicedb/internal/graph/context.go
new file mode 100644
index 0000000..1485fa0
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/context.go
@@ -0,0 +1,33 @@
+package graph
+
+import (
+ "context"
+
+ "go.opentelemetry.io/otel/trace"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/pkg/middleware/requestid"
+)
+
+// branchContext returns a context disconnected from the parent context, but populated with the datastore.
+// Also returns a function for canceling the newly created context, without canceling the parent context.
+func branchContext(ctx context.Context) (context.Context, func(cancelErr error)) {
+ // Add tracing to the context.
+ span := trace.SpanFromContext(ctx)
+ detachedContext := trace.ContextWithSpan(context.Background(), span)
+
+ // Add datastore to the context.
+ ds := datastoremw.FromContext(ctx)
+ detachedContext = datastoremw.ContextWithDatastore(detachedContext, ds)
+
+ // Add logging to the context.
+ loggerFromContext := log.Ctx(ctx)
+ if loggerFromContext != nil {
+ detachedContext = loggerFromContext.WithContext(detachedContext)
+ }
+
+ detachedContext = requestid.PropagateIfExists(ctx, detachedContext)
+
+ return context.WithCancelCause(detachedContext)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/cursors.go b/vendor/github.com/authzed/spicedb/internal/graph/cursors.go
new file mode 100644
index 0000000..ad3d705
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/cursors.go
@@ -0,0 +1,542 @@
+package graph
+
+import (
+ "context"
+ "errors"
+ "strconv"
+ "sync"
+
+ "github.com/ccoveille/go-safecast"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/taskrunner"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// cursorInformation is a struct which holds information about the current incoming cursor (if any)
// and the sections to be added to the *outgoing* partial cursor.
type cursorInformation struct {
	// currentCursor is the current incoming cursor. This may be nil.
	currentCursor *v1.Cursor

	// outgoingCursorSections are the sections to be added to the outgoing *partial* cursor.
	// It is the responsibility of the *caller* to append together the incoming cursors to form
	// the final cursor.
	//
	// A `section` is a portion of the cursor, representing a section of code that was
	// executed to produce the section of the cursor.
	outgoingCursorSections []string

	// limits is the limits tracker for the call over which the cursor is being used.
	// It is shared (by pointer) across cursorInformation values derived from one another.
	limits *limitTracker

	// dispatchCursorVersion is the version of the dispatch to be stored in the cursor.
	dispatchCursorVersion uint32
}
+
+// newCursorInformation constructs a new cursorInformation struct from the incoming cursor (which
+// may be nil)
+func newCursorInformation(incomingCursor *v1.Cursor, limits *limitTracker, dispatchCursorVersion uint32) (cursorInformation, error) {
+ if incomingCursor != nil && incomingCursor.DispatchVersion != dispatchCursorVersion {
+ return cursorInformation{}, NewInvalidCursorErr(dispatchCursorVersion, incomingCursor)
+ }
+
+ if dispatchCursorVersion == 0 {
+ return cursorInformation{}, spiceerrors.MustBugf("invalid dispatch cursor version")
+ }
+
+ return cursorInformation{
+ currentCursor: incomingCursor,
+ outgoingCursorSections: nil,
+ limits: limits,
+ dispatchCursorVersion: dispatchCursorVersion,
+ }, nil
+}
+
+// responsePartialCursor is the *partial* cursor to return in a response.
+func (ci cursorInformation) responsePartialCursor() *v1.Cursor {
+ return &v1.Cursor{
+ DispatchVersion: ci.dispatchCursorVersion,
+ Sections: ci.outgoingCursorSections,
+ }
+}
+
+// withClonedLimits returns the cursor, but with its limits tracker cloned.
+func (ci cursorInformation) withClonedLimits() cursorInformation {
+ return cursorInformation{
+ currentCursor: ci.currentCursor,
+ outgoingCursorSections: ci.outgoingCursorSections,
+ limits: ci.limits.clone(),
+ dispatchCursorVersion: ci.dispatchCursorVersion,
+ }
+}
+
+// headSectionValue returns the string value found at the head of the incoming cursor.
+// If the incoming cursor is empty, returns empty.
+func (ci cursorInformation) headSectionValue() (string, bool) {
+ if ci.currentCursor == nil || len(ci.currentCursor.Sections) < 1 {
+ return "", false
+ }
+
+ return ci.currentCursor.Sections[0], true
+}
+
+// integerSectionValue returns the *integer* found at the head of the incoming cursor.
+// If the incoming cursor is empty, returns 0. If the incoming cursor does not start with an
+// int value, fails with an error.
+func (ci cursorInformation) integerSectionValue() (int, error) {
+ valueStr, hasValue := ci.headSectionValue()
+ if !hasValue {
+ return 0, nil
+ }
+
+ if valueStr == "" {
+ return 0, nil
+ }
+
+ return strconv.Atoi(valueStr)
+}
+
+// withOutgoingSection returns cursorInformation updated with the given optional
+// value appended to the outgoingCursorSections for the current cursor. If the current
+// cursor already begins with any values, those values are replaced.
+func (ci cursorInformation) withOutgoingSection(value string) (cursorInformation, error) {
+ ocs := make([]string, 0, len(ci.outgoingCursorSections)+1)
+ ocs = append(ocs, ci.outgoingCursorSections...)
+ ocs = append(ocs, value)
+
+ if ci.currentCursor != nil && len(ci.currentCursor.Sections) > 0 {
+ // If the cursor already has values, replace them with those specified.
+ return cursorInformation{
+ currentCursor: &v1.Cursor{
+ DispatchVersion: ci.dispatchCursorVersion,
+ Sections: ci.currentCursor.Sections[1:],
+ },
+ outgoingCursorSections: ocs,
+ limits: ci.limits,
+ dispatchCursorVersion: ci.dispatchCursorVersion,
+ }, nil
+ }
+
+ return cursorInformation{
+ currentCursor: nil,
+ outgoingCursorSections: ocs,
+ limits: ci.limits,
+ dispatchCursorVersion: ci.dispatchCursorVersion,
+ }, nil
+}
+
+func (ci cursorInformation) clearIncoming() cursorInformation {
+ return cursorInformation{
+ currentCursor: nil,
+ outgoingCursorSections: ci.outgoingCursorSections,
+ limits: ci.limits,
+ dispatchCursorVersion: ci.dispatchCursorVersion,
+ }
+}
+
// cursorHandler is a continuation invoked with updated cursor information.
type cursorHandler func(c cursorInformation) error
+
// itemAndPostCursor represents an item and the cursor to be used for all items after it.
type itemAndPostCursor[T any] struct {
	item   T              // the item to be processed
	cursor options.Cursor // datastore cursor to resume from for items after this one
}
+
// withDatastoreCursorInCursor executes the given lookup function to retrieve items from the datastore,
// and then executes the handler on each of the produced items *in parallel*, streaming the results
// in the correct order to the parent stream.
func withDatastoreCursorInCursor[T any, Q any](
	ctx context.Context,
	ci cursorInformation,
	parentStream dispatch.Stream[Q],
	concurrencyLimit uint16,
	lookup func(queryCursor options.Cursor) ([]itemAndPostCursor[T], error),
	handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error,
) error {
	// Retrieve the *datastore* cursor, if one is found at the head of the incoming cursor.
	var datastoreCursor options.Cursor
	datastoreCursorString, _ := ci.headSectionValue()
	if datastoreCursorString != "" {
		// The section is a serialized relationship; MustParse panics on malformed input.
		datastoreCursor = options.ToCursor(tuple.MustParse(datastoreCursorString))
	}

	// Nothing left to produce under the current limits.
	if ci.limits.hasExhaustedLimit() {
		return nil
	}

	// Execute the lookup to call the database and find items for processing.
	itemsToBeProcessed, err := lookup(datastoreCursor)
	if err != nil {
		return err
	}

	if len(itemsToBeProcessed) == 0 {
		return nil
	}

	itemsToRun := make([]T, 0, len(itemsToBeProcessed))
	for _, itemAndCursor := range itemsToBeProcessed {
		itemsToRun = append(itemsToRun, itemAndCursor.item)
	}

	getItemCursor := func(taskIndex int) (cursorInformation, error) {
		// Create an updated cursor referencing the current item's cursor, so that any items returned know to resume from this point.
		cursorRel := options.ToRelationship(itemsToBeProcessed[taskIndex].cursor)
		cursorSection := ""
		if cursorRel != nil {
			cursorSection = tuple.StringWithoutCaveatOrExpiration(*cursorRel)
		}

		currentCursor, err := ci.withOutgoingSection(cursorSection)
		if err != nil {
			return currentCursor, err
		}

		// If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top
		// of the cursor.
		if taskIndex > 0 {
			currentCursor = currentCursor.clearIncoming()
		}

		return currentCursor, nil
	}

	// Delegate the parallel execution and ordered streaming to the shared helper.
	return withInternalParallelizedStreamingIterableInCursor(
		ctx,
		ci,
		itemsToRun,
		parentStream,
		concurrencyLimit,
		getItemCursor,
		handler,
	)
}
+
// afterResponseCursor mints the cursor to return for results following the given offset.
type afterResponseCursor func(nextOffset int) *v1.Cursor
+
// withSubsetInCursor executes the given handler with the offset index found at the beginning of the
// cursor. If the offset is not found, executes with 0. The handler is given the current offset as
// well as a callback to mint the cursor with the next offset.
func withSubsetInCursor(
	ci cursorInformation,
	handler func(currentOffset int, nextCursorWith afterResponseCursor) error,
	next cursorHandler,
) error {
	// Nothing left to produce under the current limits.
	if ci.limits.hasExhaustedLimit() {
		return nil
	}

	afterIndex, err := ci.integerSectionValue()
	if err != nil {
		return err
	}

	// An index of -1 (written below) marks the subset as already completed, so the
	// handler runs only for non-negative offsets.
	if afterIndex >= 0 {
		var foundCerr error
		err = handler(afterIndex, func(nextOffset int) *v1.Cursor {
			// The minting callback cannot return an error directly, so any cursor
			// construction error is captured and surfaced after the handler returns.
			cursor, cerr := ci.withOutgoingSection(strconv.Itoa(nextOffset))
			foundCerr = cerr
			if cerr != nil {
				return nil
			}

			return cursor.responsePartialCursor()
		})
		if err != nil {
			return err
		}
		if foundCerr != nil {
			return foundCerr
		}
	}

	if ci.limits.hasExhaustedLimit() {
		return nil
	}

	// -1 means that the handler has been completed.
	uci, err := ci.withOutgoingSection("-1")
	if err != nil {
		return err
	}
	return next(uci)
}
+
+// combineCursors combines the given cursors into one resulting cursor.
+func combineCursors(cursor *v1.Cursor, toAdd *v1.Cursor) (*v1.Cursor, error) {
+ if toAdd == nil || len(toAdd.Sections) == 0 {
+ return nil, spiceerrors.MustBugf("supplied toAdd cursor was nil or empty")
+ }
+
+ if cursor == nil || len(cursor.Sections) == 0 {
+ return toAdd, nil
+ }
+
+ sections := make([]string, 0, len(cursor.Sections)+len(toAdd.Sections))
+ sections = append(sections, cursor.Sections...)
+ sections = append(sections, toAdd.Sections...)
+
+ return &v1.Cursor{
+ DispatchVersion: toAdd.DispatchVersion,
+ Sections: sections,
+ }, nil
+}
+
// withParallelizedStreamingIterableInCursor executes the given handler for each item in the items list, skipping any
// items marked as completed at the head of the cursor and injecting a cursor representing the current
// item.
//
// For example, if items contains 3 items, and the cursor returned was within the handler for item
// index #1, then item index #0 will be skipped on subsequent invocation.
//
// The next index is executed in parallel with the current index, with its results stored in a CollectingStream
// until the next iteration.
func withParallelizedStreamingIterableInCursor[T any, Q any](
	ctx context.Context,
	ci cursorInformation,
	items []T,
	parentStream dispatch.Stream[Q],
	concurrencyLimit uint16,
	handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error,
) error {
	// Check the cursor for a starting index, before which any items will be skipped.
	startingIndex, err := ci.integerSectionValue()
	if err != nil {
		return err
	}

	// A cursor index outside [0, len(items)] can only come from a corrupted cursor.
	if startingIndex < 0 || startingIndex > len(items) {
		return spiceerrors.MustBugf("invalid cursor in withParallelizedStreamingIterableInCursor: found starting index %d for items %v", startingIndex, items)
	}

	itemsToRun := items[startingIndex:]
	if len(itemsToRun) == 0 {
		return nil
	}

	getItemCursor := func(taskIndex int) (cursorInformation, error) {
		// Create an updated cursor referencing the current item's index, so that any items returned know to resume from this point.
		// taskIndex is relative to itemsToRun, so the starting index is added back.
		currentCursor, err := ci.withOutgoingSection(strconv.Itoa(taskIndex + startingIndex))
		if err != nil {
			return currentCursor, err
		}

		// If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top
		// of the cursor.
		if taskIndex > 0 {
			currentCursor = currentCursor.clearIncoming()
		}

		return currentCursor, nil
	}

	// Delegate the parallel execution and ordered streaming to the shared helper.
	return withInternalParallelizedStreamingIterableInCursor(
		ctx,
		ci,
		itemsToRun,
		parentStream,
		concurrencyLimit,
		getItemCursor,
		handler,
	)
}
+
+// withInternalParallelizedStreamingIterableInCursor schedules one task per item on a
+// bounded task runner, giving each task its own cursor (via getItemCursor) and an
+// ordered, limit-aware stream wrapper so that results are published to the parent
+// stream in item order. Returns the first task error, if any.
+func withInternalParallelizedStreamingIterableInCursor[T any, Q any](
+	ctx context.Context,
+	ci cursorInformation,
+	itemsToRun []T,
+	parentStream dispatch.Stream[Q],
+	concurrencyLimit uint16,
+	getItemCursor func(taskIndex int) (cursorInformation, error),
+	handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error,
+) error {
+	// Queue up each iteration's worth of items to be run by the task runner.
+	tr := taskrunner.NewPreloadedTaskRunner(ctx, concurrencyLimit, len(itemsToRun))
+	stream, err := newParallelLimitedIndexedStream(ctx, ci, parentStream, len(itemsToRun))
+	if err != nil {
+		return err
+	}
+
+	// Schedule a task to be invoked for each item to be run.
+	for taskIndex, item := range itemsToRun {
+		taskIndex := taskIndex
+		item := item
+		tr.Add(func(ctx context.Context) error {
+			// Skip any remaining work if other tasks have already fulfilled the limit.
+			stream.lock.Lock()
+			if ci.limits.hasExhaustedLimit() {
+				stream.lock.Unlock()
+				return nil
+			}
+			stream.lock.Unlock()
+
+			ici, err := getItemCursor(taskIndex)
+			if err != nil {
+				return err
+			}
+
+			// Invoke the handler with the current item's index in the outgoing cursor, indicating that
+			// subsequent invocations should jump right to this item.
+			ictx, istream, icursor := stream.forTaskIndex(ctx, taskIndex, ici)
+
+			err = handler(ictx, icursor, item, istream)
+			if err != nil {
+				// If the branch was canceled explicitly by *this* streaming iterable because other branches have fulfilled
+				// the configured limit, then we can safely ignore this error.
+				if errors.Is(context.Cause(ictx), stream.errCanceledBecauseFulfilled) {
+					return nil
+				}
+				return err
+			}
+
+			return stream.completedTaskIndex(taskIndex)
+		})
+	}
+
+	// Return the runner's error directly rather than checking and re-returning nil.
+	return tr.StartAndWait()
+}
+
+// parallelLimitedIndexedStream is a specialization of a dispatch.Stream that collects results from multiple
+// tasks running in parallel, and emits them in the order of the tasks. The first task's results are directly
+// emitted to the parent stream, while subsequent tasks' results are emitted in the defined order of the tasks
+// to ensure cursors and limits work as expected.
+type parallelLimitedIndexedStream[Q any] struct {
+	// lock guards the GUARDED_BY fields below.
+	lock sync.Mutex
+
+	ctx          context.Context
+	ci           cursorInformation
+	parentStream dispatch.Stream[Q]
+
+	// streamCount is the expected number of parallel tasks/streams.
+	streamCount int
+	// toPublishTaskIndex is the index of the next task whose results may be published, in order.
+	toPublishTaskIndex int
+	// countingStream wraps the parent stream for task #0 so directly-published results can be counted.
+	countingStream *dispatch.CountingDispatchStream[Q] // GUARDED_BY(lock)
+	// childStreams collects results for tasks after the first, keyed by task index.
+	childStreams map[int]*dispatch.CollectingDispatchStream[Q] // GUARDED_BY(lock)
+	// childContextCancels holds the cancel function for each dispatched branch.
+	childContextCancels map[int]func(cause error) // GUARDED_BY(lock)
+	// completedTaskIndexes records tasks that have finished but may not yet be published.
+	completedTaskIndexes map[int]bool // GUARDED_BY(lock)
+	// errCanceledBecauseFulfilled is the per-instance sentinel cause used when canceling branches.
+	errCanceledBecauseFulfilled error
+}
+
+// newParallelLimitedIndexedStream constructs an ordered, limit-aware result collector
+// for the given number of parallel tasks. streamCount must be positive.
+func newParallelLimitedIndexedStream[Q any](
+	ctx context.Context,
+	ci cursorInformation,
+	parentStream dispatch.Stream[Q],
+	streamCount int,
+) (*parallelLimitedIndexedStream[Q], error) {
+	if streamCount <= 0 {
+		return nil, spiceerrors.MustBugf("got invalid stream count")
+	}
+
+	collector := &parallelLimitedIndexedStream[Q]{
+		ctx:                  ctx,
+		ci:                   ci,
+		parentStream:         parentStream,
+		streamCount:          streamCount,
+		toPublishTaskIndex:   0,
+		countingStream:       nil,
+		childStreams:         map[int]*dispatch.CollectingDispatchStream[Q]{},
+		childContextCancels:  map[int]func(cause error){},
+		completedTaskIndexes: map[int]bool{},
+
+		// NOTE: we mint a new error here to ensure that we only skip cancelations from this very instance.
+		errCanceledBecauseFulfilled: errors.New("canceled because other branches fulfilled limit"),
+	}
+	return collector, nil
+}
+
+// forTaskIndex returns a new context, stream and cursor for invoking the task at the specific index and publishing its results.
+func (ls *parallelLimitedIndexedStream[Q]) forTaskIndex(ctx context.Context, index int, currentCursor cursorInformation) (context.Context, dispatch.Stream[Q], cursorInformation) {
+	ls.lock.Lock()
+	defer ls.lock.Unlock()
+
+	// Create a new cursor with cloned limits, because each child task which executes (in parallel) will need its own
+	// limit tracking. The overall limit on the original cursor is managed in completedTaskIndex.
+	childCI := currentCursor.withClonedLimits()
+	childContext, cancelDispatch := branchContext(ctx)
+
+	// Record the cancel function so cancelRemainingDispatches can stop this branch later.
+	ls.childContextCancels[index] = cancelDispatch
+
+	// If executing for the first index, it can stream directly to the parent stream, but we need to count the number
+	// of items streamed to adjust the overall limits.
+	if index == 0 {
+		countingStream := dispatch.NewCountingDispatchStream(ls.parentStream)
+		ls.countingStream = countingStream
+		return childContext, countingStream, childCI
+	}
+
+	// Otherwise, create a child stream with an adjusted limits on the cursor. We have to clone the cursor's
+	// limits here to ensure that the child's publishing doesn't affect the first branch.
+	childStream := dispatch.NewCollectingDispatchStream[Q](childContext)
+	ls.childStreams[index] = childStream
+
+	return childContext, childStream, childCI
+}
+
+// cancelRemainingDispatches cancels the context of every dispatched branch, signaling
+// that no additional results are necessary because the limit has been fulfilled.
+func (ls *parallelLimitedIndexedStream[Q]) cancelRemainingDispatches() {
+	for _, cancelBranch := range ls.childContextCancels {
+		cancelBranch(ls.errCanceledBecauseFulfilled)
+	}
+}
+
+// completedTaskIndex indicates that the task at the specific index has completed successfully and that its collected
+// results should be published to the parent stream, so long as all previous tasks have been completed and published as well.
+func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error {
+	ls.lock.Lock()
+	defer ls.lock.Unlock()
+
+	// Mark the task as completed, but not yet published.
+	ls.completedTaskIndexes[index] = true
+
+	// If the overall limit has been reached, nothing more to do.
+	if ls.ci.limits.hasExhaustedLimit() {
+		ls.cancelRemainingDispatches()
+		return nil
+	}
+
+	// Otherwise, publish any results from previous completed tasks up, and including, this task. This loop ensures
+	// that the collected results for each task are published to the parent stream in the correct order.
+	for {
+		if !ls.completedTaskIndexes[ls.toPublishTaskIndex] {
+			return nil
+		}
+
+		if ls.toPublishTaskIndex == 0 {
+			// Remove the already emitted data from the overall limits.
+			publishedCount, err := safecast.ToUint32(ls.countingStream.PublishedCount())
+			if err != nil {
+				return spiceerrors.MustBugf("cannot cast published count to uint32: %v", err)
+			}
+			if err := ls.ci.limits.markAlreadyPublished(publishedCount); err != nil {
+				return err
+			}
+
+			if ls.ci.limits.hasExhaustedLimit() {
+				ls.cancelRemainingDispatches()
+			}
+		} else {
+			// Publish, to the parent stream, the results produced by the task and stored in the child stream.
+			childStream := ls.childStreams[ls.toPublishTaskIndex]
+			for _, result := range childStream.Results() {
+				if !ls.ci.limits.prepareForPublishing() {
+					ls.cancelRemainingDispatches()
+					return nil
+				}
+
+				err := ls.parentStream.Publish(result)
+				if err != nil {
+					return err
+				}
+			}
+			// Use delete rather than assigning nil so the published child stream's map
+			// entry (and its collected results) can actually be released.
+			delete(ls.childStreams, ls.toPublishTaskIndex)
+		}
+
+		ls.toPublishTaskIndex++
+	}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/doc.go b/vendor/github.com/authzed/spicedb/internal/graph/doc.go
new file mode 100644
index 0000000..904b216
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/doc.go
@@ -0,0 +1,2 @@
+// Package graph contains the code to traverse a relationship graph to solve requests like Checks, Expansions and Lookups.
+package graph
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/errors.go b/vendor/github.com/authzed/spicedb/internal/graph/errors.go
new file mode 100644
index 0000000..31577d9
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/errors.go
@@ -0,0 +1,213 @@
+package graph
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/rs/zerolog"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/internal/sharederrors"
+ dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// CheckFailureError occurs when check failed in some manner. Note this should not apply to
+// namespaces and relations not being found.
+type CheckFailureError struct {
+	error
+}
+
+// Unwrap returns the underlying wrapped error.
+func (e CheckFailureError) Unwrap() error {
+	return e.error
+}
+
+// NewCheckFailureErr constructs a new check failed error.
+func NewCheckFailureErr(baseErr error) error {
+	return CheckFailureError{
+		error: fmt.Errorf("error performing check: %w", baseErr),
+	}
+}
+
+// ExpansionFailureError occurs when expansion failed in some manner. Note this should not apply to
+// namespaces and relations not being found.
+type ExpansionFailureError struct {
+	error
+}
+
+// Unwrap returns the underlying wrapped error.
+func (e ExpansionFailureError) Unwrap() error {
+	return e.error
+}
+
+// NewExpansionFailureErr constructs a new expansion failed error.
+func NewExpansionFailureErr(baseErr error) error {
+	return ExpansionFailureError{
+		error: fmt.Errorf("error performing expand: %w", baseErr),
+	}
+}
+
+// AlwaysFailError is returned when an internal error leads to an operation
+// guaranteed to fail.
+type AlwaysFailError struct {
+	error
+}
+
+// NewAlwaysFailErr constructs a new always fail error.
+func NewAlwaysFailErr() error {
+	return AlwaysFailError{
+		error: errors.New("always fail"),
+	}
+}
+
+// RelationNotFoundError occurs when a relation was not found under a namespace.
+type RelationNotFoundError struct {
+	error
+	namespaceName string
+	relationName  string
+}
+
+// NamespaceName returns the name of the namespace in which the relation was not found.
+func (err RelationNotFoundError) NamespaceName() string {
+	return err.namespaceName
+}
+
+// NotFoundRelationName returns the name of the relation not found.
+func (err RelationNotFoundError) NotFoundRelationName() string {
+	return err.relationName
+}
+
+// MarshalZerologObject implements zerolog.LogObjectMarshaler to log the error with its context fields.
+func (err RelationNotFoundError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err RelationNotFoundError) DetailsMetadata() map[string]string {
+	return map[string]string{
+		"definition_name":             err.namespaceName,
+		"relation_or_permission_name": err.relationName,
+	}
+}
+
+// NewRelationNotFoundErr constructs a new relation not found error.
+func NewRelationNotFoundErr(nsName string, relationName string) error {
+	return RelationNotFoundError{
+		error:         fmt.Errorf("relation/permission `%s` not found under definition `%s`", relationName, nsName),
+		namespaceName: nsName,
+		relationName:  relationName,
+	}
+}
+
+// Compile-time check that RelationNotFoundError implements the shared unknown-relation interface.
+var _ sharederrors.UnknownRelationError = RelationNotFoundError{}
+
+// RelationMissingTypeInfoError defines an error for when type information is missing from a relation
+// during a lookup.
+type RelationMissingTypeInfoError struct {
+	error
+	namespaceName string
+	relationName  string
+}
+
+// NamespaceName returns the name of the namespace in which the relation was found.
+func (err RelationMissingTypeInfoError) NamespaceName() string {
+	return err.namespaceName
+}
+
+// RelationName returns the name of the relation missing type information.
+func (err RelationMissingTypeInfoError) RelationName() string {
+	return err.relationName
+}
+
+// MarshalZerologObject implements zerolog.LogObjectMarshaler to log the error with its context fields.
+func (err RelationMissingTypeInfoError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err RelationMissingTypeInfoError) DetailsMetadata() map[string]string {
+	return map[string]string{
+		"definition_name": err.namespaceName,
+		"relation_name":   err.relationName,
+	}
+}
+
+// NewRelationMissingTypeInfoErr constructs a new error indicating a relation is missing type information.
+func NewRelationMissingTypeInfoErr(nsName string, relationName string) error {
+	return RelationMissingTypeInfoError{
+		error:         fmt.Errorf("relation/permission `%s` under definition `%s` is missing type information", relationName, nsName),
+		namespaceName: nsName,
+		relationName:  relationName,
+	}
+}
+
+// WildcardNotAllowedError occurs when a request sent has an invalid wildcard argument.
+type WildcardNotAllowedError struct {
+	error
+
+	fieldName string
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err WildcardNotAllowedError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.InvalidArgument,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_WILDCARD_NOT_ALLOWED,
+			map[string]string{
+				"field": err.fieldName,
+			},
+		),
+	)
+}
+
+// NewWildcardNotAllowedErr constructs an error indicating that a wildcard was not allowed.
+func NewWildcardNotAllowedErr(message string, fieldName string) error {
+	return WildcardNotAllowedError{
+		error:     fmt.Errorf("invalid argument: %s", message),
+		fieldName: fieldName,
+	}
+}
+
+// UnimplementedError is returned when some functionality is not yet supported.
+type UnimplementedError struct {
+	error
+}
+
+// NewUnimplementedErr constructs a new unimplemented error.
+func NewUnimplementedErr(baseErr error) error {
+	return UnimplementedError{
+		error: baseErr,
+	}
+}
+
+// Unwrap returns the underlying wrapped error.
+func (e UnimplementedError) Unwrap() error {
+	return e.error
+}
+
+// InvalidCursorError is returned when a cursor is no longer valid.
+type InvalidCursorError struct {
+	error
+}
+
+// NewInvalidCursorErr constructs a new invalid cursor error.
+func NewInvalidCursorErr(dispatchCursorVersion uint32, cursor *dispatch.Cursor) error {
+	return InvalidCursorError{
+		error: fmt.Errorf("the supplied cursor is no longer valid: found version %d, expected version %d", cursor.DispatchVersion, dispatchCursorVersion),
+	}
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err InvalidCursorError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.InvalidArgument,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_INVALID_CURSOR,
+			map[string]string{
+				"details": "cursor was used against an incompatible version of SpiceDB",
+			},
+		),
+	)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/expand.go b/vendor/github.com/authzed/spicedb/internal/graph/expand.go
new file mode 100644
index 0000000..9418bec
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/expand.go
@@ -0,0 +1,436 @@
+package graph
+
+import (
+ "context"
+ "errors"
+
+ "github.com/authzed/spicedb/internal/caveats"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ log "github.com/authzed/spicedb/internal/logging"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/namespace"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// NewConcurrentExpander creates an instance of ConcurrentExpander
+func NewConcurrentExpander(d dispatch.Expand) *ConcurrentExpander {
+	return &ConcurrentExpander{d: d}
+}
+
+// ConcurrentExpander exposes a method to perform Expand requests, and delegates subproblems to the
+// provided dispatch.Expand instance.
+type ConcurrentExpander struct {
+	// d is the delegate used to dispatch expansion subproblems.
+	d dispatch.Expand
+}
+
+// ValidatedExpandRequest represents a request after it has been validated and parsed for internal
+// consumption.
+type ValidatedExpandRequest struct {
+	*v1.DispatchExpandRequest
+	// Revision is the datastore revision at which the expansion is evaluated.
+	Revision datastore.Revision
+}
+
+// Expand performs an expand request with the provided request and context.
+func (ce *ConcurrentExpander) Expand(ctx context.Context, req ValidatedExpandRequest, relation *core.Relation) (*v1.DispatchExpandResponse, error) {
+	log.Ctx(ctx).Trace().Object("expand", req).Send()
+
+	// A relation without a userset rewrite expands its direct subjects; otherwise the
+	// rewrite tree drives the expansion.
+	var rootFunc ReduceableExpandFunc
+	if relation.UsersetRewrite != nil {
+		rootFunc = ce.expandUsersetRewrite(ctx, req, relation.UsersetRewrite)
+	} else {
+		rootFunc = ce.expandDirect(ctx, req)
+	}
+
+	completed := expandOne(ctx, rootFunc)
+	completed.Resp.Metadata = addCallToResponseMetadata(completed.Resp.Metadata)
+	return completed.Resp, completed.Err
+}
+
+// expandDirect returns an expansion function over the direct relationships of the
+// requested resource+relation, recursively expanding any subjects that are themselves
+// usersets unless shallow expansion was requested.
+func (ce *ConcurrentExpander) expandDirect(
+	ctx context.Context,
+	req ValidatedExpandRequest,
+) ReduceableExpandFunc {
+	log.Ctx(ctx).Trace().Object("direct", req).Send()
+	return func(ctx context.Context, resultChan chan<- ExpandResult) {
+		ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision)
+		it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{
+			OptionalResourceType:     req.ResourceAndRelation.Namespace,
+			OptionalResourceIds:      []string{req.ResourceAndRelation.ObjectId},
+			OptionalResourceRelation: req.ResourceAndRelation.Relation,
+		}, options.WithQueryShape(queryshape.AllSubjectsForResources))
+		if err != nil {
+			resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata)
+			return
+		}
+
+		// Partition found subjects: terminal subjects (ellipsis relation) are leaves, while
+		// non-terminal subjects reference other relations and may need recursive expansion.
+		var foundNonTerminalUsersets []*core.DirectSubject
+		var foundTerminalUsersets []*core.DirectSubject
+		for rel, err := range it {
+			if err != nil {
+				resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata)
+				return
+			}
+
+			ds := &core.DirectSubject{
+				Subject:          rel.Subject.ToCoreONR(),
+				CaveatExpression: caveats.CaveatAsExpr(rel.OptionalCaveat),
+			}
+
+			if rel.Subject.Relation == Ellipsis {
+				foundTerminalUsersets = append(foundTerminalUsersets, ds)
+			} else {
+				foundNonTerminalUsersets = append(foundNonTerminalUsersets, ds)
+			}
+		}
+
+		// If only shallow expansion was required, or there are no non-terminal subjects found,
+		// nothing more to do.
+		if req.ExpansionMode == v1.DispatchExpandRequest_SHALLOW || len(foundNonTerminalUsersets) == 0 {
+			resultChan <- expandResult(
+				&core.RelationTupleTreeNode{
+					NodeType: &core.RelationTupleTreeNode_LeafNode{
+						LeafNode: &core.DirectSubjects{
+							Subjects: append(foundTerminalUsersets, foundNonTerminalUsersets...),
+						},
+					},
+					Expanded: req.ResourceAndRelation,
+				},
+				emptyMetadata,
+			)
+			return
+		}
+
+		// Otherwise, recursively issue expansion and collect the results from that, plus the
+		// found terminals together.
+		var requestsToDispatch []ReduceableExpandFunc
+		for _, nonTerminalUser := range foundNonTerminalUsersets {
+			toDispatch := ce.dispatch(ValidatedExpandRequest{
+				&v1.DispatchExpandRequest{
+					ResourceAndRelation: nonTerminalUser.Subject,
+					Metadata:            decrementDepth(req.Metadata),
+					ExpansionMode:       req.ExpansionMode,
+				},
+				req.Revision,
+			})
+
+			requestsToDispatch = append(requestsToDispatch, decorateWithCaveatIfNecessary(toDispatch, nonTerminalUser.CaveatExpression))
+		}
+
+		// Union the recursive expansions, then attach the directly-found subjects as an
+		// additional leaf child of the resulting union node.
+		result := expandAny(ctx, req.ResourceAndRelation, requestsToDispatch)
+		if result.Err != nil {
+			resultChan <- result
+			return
+		}
+
+		unionNode := result.Resp.TreeNode.GetIntermediateNode()
+		unionNode.ChildNodes = append(unionNode.ChildNodes, &core.RelationTupleTreeNode{
+			NodeType: &core.RelationTupleTreeNode_LeafNode{
+				LeafNode: &core.DirectSubjects{
+					Subjects: append(foundTerminalUsersets, foundNonTerminalUsersets...),
+				},
+			},
+			Expanded: req.ResourceAndRelation,
+		})
+		resultChan <- result
+	}
+}
+
+// decorateWithCaveatIfNecessary wraps the given expansion so that its resulting tree
+// node carries the supplied caveat expression. A nil expression returns the function
+// unmodified.
+//
+// TODO(jschorr): This will generate a lot of function closures, so we should change Expand to avoid them
+// like we did in Check.
+func decorateWithCaveatIfNecessary(toDispatch ReduceableExpandFunc, caveatExpr *core.CaveatExpression) ReduceableExpandFunc {
+	// Nothing to decorate without a caveat expression.
+	if caveatExpr == nil {
+		return toDispatch
+	}
+
+	return func(ctx context.Context, resultChan chan<- ExpandResult) {
+		expanded := expandOne(ctx, toDispatch)
+		if expanded.Err == nil {
+			// Attach the caveat to the successfully expanded node.
+			expanded.Resp.TreeNode.CaveatExpression = caveatExpr
+		}
+		resultChan <- expanded
+	}
+}
+
+// expandUsersetRewrite builds the expansion function corresponding to the rewrite
+// operation (union, intersection or exclusion) defined on the relation.
+func (ce *ConcurrentExpander) expandUsersetRewrite(ctx context.Context, req ValidatedExpandRequest, usr *core.UsersetRewrite) ReduceableExpandFunc {
+	switch rw := usr.RewriteOperation.(type) {
+	case *core.UsersetRewrite_Union:
+		log.Ctx(ctx).Trace().Msg("union")
+		return ce.expandSetOperation(ctx, req, rw.Union, expandAny)
+	case *core.UsersetRewrite_Intersection:
+		log.Ctx(ctx).Trace().Msg("intersection")
+		return ce.expandSetOperation(ctx, req, rw.Intersection, expandAll)
+	case *core.UsersetRewrite_Exclusion:
+		log.Ctx(ctx).Trace().Msg("exclusion")
+		return ce.expandSetOperation(ctx, req, rw.Exclusion, expandDifference)
+	default:
+		// Unknown rewrite operations always fail the expansion.
+		return alwaysFailExpand
+	}
+}
+
+// expandSetOperation builds an expansion function for each child of the set operation
+// and combines their results with the provided reducer (union/intersection/exclusion).
+func (ce *ConcurrentExpander) expandSetOperation(ctx context.Context, req ValidatedExpandRequest, so *core.SetOperation, reducer ExpandReducer) ReduceableExpandFunc {
+	var requests []ReduceableExpandFunc
+	for _, childOneof := range so.Child {
+		switch child := childOneof.ChildType.(type) {
+		case *core.SetOperation_Child_XThis:
+			// _this is a legacy construct and is no longer supported.
+			return expandError(errors.New("use of _this is unsupported; please rewrite your schema"))
+		case *core.SetOperation_Child_ComputedUserset:
+			requests = append(requests, ce.expandComputedUserset(ctx, req, child.ComputedUserset, nil))
+		case *core.SetOperation_Child_UsersetRewrite:
+			requests = append(requests, ce.expandUsersetRewrite(ctx, req, child.UsersetRewrite))
+		case *core.SetOperation_Child_TupleToUserset:
+			requests = append(requests, expandTupleToUserset(ctx, ce, req, child.TupleToUserset, expandAny))
+		case *core.SetOperation_Child_FunctionedTupleToUserset:
+			// Functioned TTUs choose their reducer from the declared function.
+			switch child.FunctionedTupleToUserset.Function {
+			case core.FunctionedTupleToUserset_FUNCTION_ANY:
+				requests = append(requests, expandTupleToUserset(ctx, ce, req, child.FunctionedTupleToUserset, expandAny))
+
+			case core.FunctionedTupleToUserset_FUNCTION_ALL:
+				requests = append(requests, expandTupleToUserset(ctx, ce, req, child.FunctionedTupleToUserset, expandAll))
+
+			default:
+				return expandError(spiceerrors.MustBugf("unknown function `%s` in expand", child.FunctionedTupleToUserset.Function))
+			}
+		case *core.SetOperation_Child_XNil:
+			// nil children contribute an empty expansion.
+			requests = append(requests, emptyExpansion(req.ResourceAndRelation))
+		default:
+			return expandError(spiceerrors.MustBugf("unknown set operation child `%T` in expand", child))
+		}
+	}
+	return func(ctx context.Context, resultChan chan<- ExpandResult) {
+		resultChan <- reducer(ctx, req.ResourceAndRelation, requests)
+	}
+}
+
+// dispatch redispatches the expand request to the delegate and forwards the
+// response (or error) over the result channel.
+func (ce *ConcurrentExpander) dispatch(req ValidatedExpandRequest) ReduceableExpandFunc {
+	return func(ctx context.Context, resultChan chan<- ExpandResult) {
+		log.Ctx(ctx).Trace().Object("dispatchExpand", req).Send()
+		resp, err := ce.d.DispatchExpand(ctx, req.DispatchExpandRequest)
+		resultChan <- ExpandResult{resp, err}
+	}
+}
+
+// expandComputedUserset builds the expansion for a computed userset, resolving the
+// starting object either from the relationship (for tuple-to-userset walks) or from
+// the request's resource, and dispatching to the target relation if it exists.
+func (ce *ConcurrentExpander) expandComputedUserset(ctx context.Context, req ValidatedExpandRequest, cu *core.ComputedUserset, rel *tuple.Relationship) ReduceableExpandFunc {
+	log.Ctx(ctx).Trace().Str("relation", cu.Relation).Msg("computed userset")
+
+	var start tuple.ObjectAndRelation
+	switch cu.Object {
+	case core.ComputedUserset_TUPLE_USERSET_OBJECT:
+		// A tupleset-subject walk requires a concrete relationship.
+		if rel == nil {
+			return expandError(spiceerrors.MustBugf("computed userset for tupleset without tuple"))
+		}
+		start = rel.Subject
+	case core.ComputedUserset_TUPLE_OBJECT:
+		if rel == nil {
+			start = tuple.FromCoreObjectAndRelation(req.ResourceAndRelation)
+		} else {
+			start = rel.Resource
+		}
+	}
+
+	// If the target relation does not exist, the expansion is empty.
+	reader := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision)
+	if err := namespace.CheckNamespaceAndRelation(ctx, start.ObjectType, cu.Relation, true, reader); err != nil {
+		if errors.As(err, &namespace.RelationNotFoundError{}) {
+			return emptyExpansion(req.ResourceAndRelation)
+		}
+		return expandError(err)
+	}
+
+	return ce.dispatch(ValidatedExpandRequest{
+		&v1.DispatchExpandRequest{
+			ResourceAndRelation: &core.ObjectAndRelation{
+				Namespace: start.ObjectType,
+				ObjectId:  start.ObjectID,
+				Relation:  cu.Relation,
+			},
+			Metadata:      decrementDepth(req.Metadata),
+			ExpansionMode: req.ExpansionMode,
+		},
+		req.Revision,
+	})
+}
+
+// expandFunc reduces a set of expansion requests rooted at start into a single result.
+type expandFunc func(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult
+
+// expandTupleToUserset walks the tupleset relation for the requested resource and
+// expands the computed userset over each relationship found, reducing the results
+// with the given expandFunc.
+func expandTupleToUserset[T relation](
+	_ context.Context,
+	ce *ConcurrentExpander,
+	req ValidatedExpandRequest,
+	ttu ttu[T],
+	expandFunc expandFunc,
+) ReduceableExpandFunc {
+	return func(ctx context.Context, resultChan chan<- ExpandResult) {
+		ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision)
+		it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{
+			OptionalResourceType:     req.ResourceAndRelation.Namespace,
+			OptionalResourceIds:      []string{req.ResourceAndRelation.ObjectId},
+			OptionalResourceRelation: ttu.GetTupleset().GetRelation(),
+		}, options.WithQueryShape(queryshape.AllSubjectsForResources))
+		if err != nil {
+			resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata)
+			return
+		}
+
+		// Dispatch a computed-userset expansion for every relationship found by the tupleset walk.
+		var requestsToDispatch []ReduceableExpandFunc
+		for rel, err := range it {
+			if err != nil {
+				resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata)
+				return
+			}
+
+			toDispatch := ce.expandComputedUserset(ctx, req, ttu.GetComputedUserset(), &rel)
+			requestsToDispatch = append(requestsToDispatch, decorateWithCaveatIfNecessary(toDispatch, caveats.CaveatAsExpr(rel.OptionalCaveat)))
+		}
+
+		resultChan <- expandFunc(ctx, req.ResourceAndRelation, requestsToDispatch)
+	}
+}
+
+// setResult assembles the given children under an intermediate node with the given
+// set operation, expanded from start, and wraps it into a successful ExpandResult.
+func setResult(
+	op core.SetOperationUserset_Operation,
+	start *core.ObjectAndRelation,
+	children []*core.RelationTupleTreeNode,
+	metadata *v1.ResponseMeta,
+) ExpandResult {
+	node := &core.RelationTupleTreeNode{
+		NodeType: &core.RelationTupleTreeNode_IntermediateNode{
+			IntermediateNode: &core.SetOperationUserset{
+				Operation:  op,
+				ChildNodes: children,
+			},
+		},
+		Expanded: start,
+	}
+	return expandResult(node, metadata)
+}
+
+// expandSetOperation runs all child expansion requests concurrently, collects their
+// results in request order (combining response metadata), and reduces them into a
+// single intermediate node carrying the given set operation. The first error or
+// context cancelation aborts the collection.
+func expandSetOperation(
+	ctx context.Context,
+	start *core.ObjectAndRelation,
+	requests []ReduceableExpandFunc,
+	op core.SetOperationUserset_Operation,
+) ExpandResult {
+	children := make([]*core.RelationTupleTreeNode, 0, len(requests))
+
+	// No children: return an empty intermediate node immediately.
+	if len(requests) == 0 {
+		return setResult(op, start, children, emptyMetadata)
+	}
+
+	// Cancel any still-running children once this function returns.
+	childCtx, cancelFn := context.WithCancel(ctx)
+	defer cancelFn()
+
+	// Each request gets a buffered channel so its goroutine never blocks on send.
+	resultChans := make([]chan ExpandResult, 0, len(requests))
+	for _, req := range requests {
+		resultChan := make(chan ExpandResult, 1)
+		resultChans = append(resultChans, resultChan)
+		go req(childCtx, resultChan)
+	}
+
+	responseMetadata := emptyMetadata
+	for _, resultChan := range resultChans {
+		select {
+		case result := <-resultChan:
+			responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata)
+			if result.Err != nil {
+				return expandResultError(result.Err, responseMetadata)
+			}
+			children = append(children, result.Resp.TreeNode)
+		case <-ctx.Done():
+			return expandResultError(context.Canceled, responseMetadata)
+		}
+	}
+
+	return setResult(op, start, children, responseMetadata)
+}
+
+// emptyExpansion returns an expansion yielding an empty leaf node rooted at start.
+func emptyExpansion(start *core.ObjectAndRelation) ReduceableExpandFunc {
+	return func(ctx context.Context, resultChan chan<- ExpandResult) {
+		leaf := &core.RelationTupleTreeNode{
+			NodeType: &core.RelationTupleTreeNode_LeafNode{
+				LeafNode: &core.DirectSubjects{},
+			},
+			Expanded: start,
+		}
+		resultChan <- expandResult(leaf, emptyMetadata)
+	}
+}
+
+// expandError returns an expansion that immediately yields the given error.
+func expandError(err error) ReduceableExpandFunc {
+	return func(ctx context.Context, resultChan chan<- ExpandResult) {
+		resultChan <- expandResultError(err, emptyMetadata)
+	}
+}
+
+// expandAll returns a tree with all of the children and an intersection node type.
+func expandAll(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult {
+	return expandSetOperation(ctx, start, requests, core.SetOperationUserset_INTERSECTION)
+}
+
+// expandAny returns a tree with all of the children and a union node type.
+func expandAny(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult {
+	return expandSetOperation(ctx, start, requests, core.SetOperationUserset_UNION)
+}
+
+// expandDifference returns a tree with all of the children and an exclusion node type.
+func expandDifference(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult {
+	return expandSetOperation(ctx, start, requests, core.SetOperationUserset_EXCLUSION)
+}
+
+// expandOne runs the given expansion and waits for exactly one result, or for the
+// context to be canceled, whichever happens first. Errors are carried inside the
+// returned ExpandResult.
+func expandOne(ctx context.Context, request ReduceableExpandFunc) ExpandResult {
+	resultChan := make(chan ExpandResult, 1)
+	go request(ctx, resultChan)
+
+	select {
+	case result := <-resultChan:
+		// NOTE: the previous code checked result.Err and returned the identical value
+		// on both branches; the redundant branch has been removed.
+		return result
+	case <-ctx.Done():
+		return expandResultError(context.Canceled, emptyMetadata)
+	}
+}
+
+var errAlwaysFailExpand = errors.New("always fail")
+
+// alwaysFailExpand is a ReduceableExpandFunc that unconditionally yields an error.
+func alwaysFailExpand(_ context.Context, resultChan chan<- ExpandResult) {
+	resultChan <- expandResultError(errAlwaysFailExpand, emptyMetadata)
+}
+
+// expandResult wraps a tree node and sub-problem metadata into a successful ExpandResult.
+func expandResult(treeNode *core.RelationTupleTreeNode, subProblemMetadata *v1.ResponseMeta) ExpandResult {
+	resp := &v1.DispatchExpandResponse{
+		Metadata: ensureMetadata(subProblemMetadata),
+		TreeNode: treeNode,
+	}
+	return ExpandResult{resp, nil}
+}
+
+// expandResultError wraps an error and sub-problem metadata into a failed ExpandResult.
+func expandResultError(err error, subProblemMetadata *v1.ResponseMeta) ExpandResult {
+	resp := &v1.DispatchExpandResponse{
+		Metadata: ensureMetadata(subProblemMetadata),
+	}
+	return ExpandResult{resp, err}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/graph.go b/vendor/github.com/authzed/spicedb/internal/graph/graph.go
new file mode 100644
index 0000000..2a44189
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/graph.go
@@ -0,0 +1,89 @@
+package graph
+
+import (
+ "context"
+
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+)
+
+// Ellipsis relation is used to signify a semantic-free relationship.
+const Ellipsis = "..."
+
+// CheckResult is the data that is returned by a single check or sub-check.
+type CheckResult struct {
+	// Resp is the response from the dispatched check, if any.
+	Resp *v1.DispatchCheckResponse
+	// Err is the error encountered while checking, if any.
+	Err error
+}
+
+// ResultError returns the error found on the check result, if any.
+func (cr CheckResult) ResultError() error {
+	return cr.Err
+}
+
+// ExpandResult is the data that is returned by a single expand or sub-expand.
+type ExpandResult struct {
+	// Resp is the response from the dispatched expand, if any.
+	Resp *v1.DispatchExpandResponse
+	// Err is the error encountered while expanding, if any.
+	Err error
+}
+
+// ResultError returns the error found on the expand result, if any.
+func (er ExpandResult) ResultError() error {
+	return er.Err
+}
+
+// ReduceableExpandFunc is a function that can be bound to an execution context.
+type ReduceableExpandFunc func(ctx context.Context, resultChan chan<- ExpandResult)
+
+// AlwaysFailExpand is a ReduceableExpandFunc which will always fail when reduced.
+func AlwaysFailExpand(_ context.Context, resultChan chan<- ExpandResult) {
+	resultChan <- expandResultError(NewAlwaysFailErr(), emptyMetadata)
+}
+
+// ExpandReducer is a type for the functions Any and All which combine expand
+// results.
+type ExpandReducer func(
+	ctx context.Context,
+	start *core.ObjectAndRelation,
+	requests []ReduceableExpandFunc,
+) ExpandResult
+
+// decrementDepth returns a copy of the given resolver metadata with one fewer
+// unit of remaining depth, for passing to sub-dispatches.
+// NOTE(review): assumes DepthRemaining > 0; presumably callers validate depth
+// before dispatching — confirm at call sites.
+func decrementDepth(md *v1.ResolverMeta) *v1.ResolverMeta {
+	return &v1.ResolverMeta{
+		AtRevision:     md.AtRevision,
+		DepthRemaining: md.DepthRemaining - 1,
+		TraversalBloom: md.TraversalBloom,
+	}
+}
+
+// emptyMetadata is a shared, empty response metadata instance.
+// NOTE(review): shared pointer — callers must treat it as immutable.
+var emptyMetadata = &v1.ResponseMeta{}
+
+// ensureMetadata returns a fresh copy of the given sub-problem metadata,
+// substituting empty metadata when nil is given. Only the counter and debug
+// fields are carried over.
+func ensureMetadata(subProblemMetadata *v1.ResponseMeta) *v1.ResponseMeta {
+	if subProblemMetadata == nil {
+		subProblemMetadata = emptyMetadata
+	}
+
+	return &v1.ResponseMeta{
+		DispatchCount:       subProblemMetadata.DispatchCount,
+		DepthRequired:       subProblemMetadata.DepthRequired,
+		CachedDispatchCount: subProblemMetadata.CachedDispatchCount,
+		DebugInfo:           subProblemMetadata.DebugInfo,
+	}
+}
+
+// addCallToResponseMetadata returns a copy of the metadata with the dispatch
+// count and required depth each incremented to account for the current call.
+func addCallToResponseMetadata(metadata *v1.ResponseMeta) *v1.ResponseMeta {
+	// + 1 for the current call.
+	return &v1.ResponseMeta{
+		DispatchCount:       metadata.DispatchCount + 1,
+		DepthRequired:       metadata.DepthRequired + 1,
+		CachedDispatchCount: metadata.CachedDispatchCount,
+		DebugInfo:           metadata.DebugInfo,
+	}
+}
+
+// addAdditionalDepthRequired returns a copy of the metadata with only the
+// required depth incremented; no additional dispatch is recorded.
+func addAdditionalDepthRequired(metadata *v1.ResponseMeta) *v1.ResponseMeta {
+	return &v1.ResponseMeta{
+		DispatchCount:       metadata.DispatchCount,
+		DepthRequired:       metadata.DepthRequired + 1,
+		CachedDispatchCount: metadata.CachedDispatchCount,
+		DebugInfo:           metadata.DebugInfo,
+	}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go b/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go
new file mode 100644
index 0000000..485a3bf
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go
@@ -0,0 +1,96 @@
+package hints
+
+import (
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// CheckHintForComputedUserset creates a CheckHint for a relation and a subject.
+// The hint records a precomputed check result for resource#relation@subject.
+func CheckHintForComputedUserset(resourceType string, resourceID string, relation string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) *v1.CheckHint {
+	return &v1.CheckHint{
+		Resource: &core.ObjectAndRelation{
+			Namespace: resourceType,
+			ObjectId:  resourceID,
+			Relation:  relation,
+		},
+		Subject: subject.ToCoreONR(),
+		Result:  result,
+	}
+}
+
+// CheckHintForArrow creates a CheckHint for an arrow and a subject. The hint is
+// keyed on the tupleset relation; a non-empty TtuComputedUsersetRelation marks
+// it as an arrow (TTU) hint rather than a computed-userset hint.
+func CheckHintForArrow(resourceType string, resourceID string, tuplesetRelation string, computedUsersetRelation string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) *v1.CheckHint {
+	return &v1.CheckHint{
+		Resource: &core.ObjectAndRelation{
+			Namespace: resourceType,
+			ObjectId:  resourceID,
+			Relation:  tuplesetRelation,
+		},
+		TtuComputedUsersetRelation: computedUsersetRelation,
+		Subject:                    subject.ToCoreONR(),
+		Result:                     result,
+	}
+}
+
+// AsCheckHintForComputedUserset returns the resourceID if the checkHint is for the given relation and subject.
+func AsCheckHintForComputedUserset(checkHint *v1.CheckHint, resourceType string, relationName string, subject tuple.ObjectAndRelation) (string, bool) {
+	// A non-empty TTU relation marks the hint as an arrow hint, which cannot
+	// match a computed-userset lookup.
+	if checkHint.TtuComputedUsersetRelation != "" {
+		return "", false
+	}
+
+	if checkHint.Resource.Namespace == resourceType && checkHint.Resource.Relation == relationName && checkHint.Subject.EqualVT(subject.ToCoreONR()) {
+		return checkHint.Resource.ObjectId, true
+	}
+
+	return "", false
+}
+
+// AsCheckHintForArrow returns the resourceID if the checkHint is for the given arrow and subject.
+func AsCheckHintForArrow(checkHint *v1.CheckHint, resourceType string, tuplesetRelation string, computedUsersetRelation string, subject tuple.ObjectAndRelation) (string, bool) {
+	// The hint must record the same computed userset relation as the arrow.
+	if checkHint.TtuComputedUsersetRelation != computedUsersetRelation {
+		return "", false
+	}
+
+	if checkHint.Resource.Namespace == resourceType && checkHint.Resource.Relation == tuplesetRelation && checkHint.Subject.EqualVT(subject.ToCoreONR()) {
+		return checkHint.Resource.ObjectId, true
+	}
+
+	return "", false
+}
+
+// HintForEntrypoint returns a CheckHint for the given reachability graph entrypoint and associated subject and result.
+func HintForEntrypoint(re schema.ReachabilityEntrypoint, resourceID string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) (*v1.CheckHint, error) {
+	switch re.EntrypointKind() {
+	case core.ReachabilityEntrypoint_RELATION_ENTRYPOINT:
+		// Relation entrypoints have no associated hint; calling this for one is a programmer error.
+		return nil, spiceerrors.MustBugf("cannot call CheckHintForResource for kind %v", re.EntrypointKind())
+
+	case core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT:
+		// Arrow entrypoint: hint on the tupleset relation, recording the computed userset relation.
+		namespace := re.TargetNamespace()
+		tuplesetRelation, err := re.TuplesetRelation()
+		if err != nil {
+			return nil, err
+		}
+
+		computedUsersetRelation, err := re.ComputedUsersetRelation()
+		if err != nil {
+			return nil, err
+		}
+
+		return CheckHintForArrow(namespace, resourceID, tuplesetRelation, computedUsersetRelation, subject, result), nil
+
+	case core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT:
+		// Computed userset entrypoint: hint directly on the computed relation.
+		namespace := re.TargetNamespace()
+		relation, err := re.ComputedUsersetRelation()
+		if err != nil {
+			return nil, err
+		}
+
+		return CheckHintForComputedUserset(namespace, resourceID, relation, subject, result), nil
+
+	default:
+		return nil, spiceerrors.MustBugf("unknown relation entrypoint kind")
+	}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/limits.go b/vendor/github.com/authzed/spicedb/internal/graph/limits.go
new file mode 100644
index 0000000..6b2e2bd
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/limits.go
@@ -0,0 +1,80 @@
+package graph
+
+import (
+ "fmt"
+
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// ErrLimitReached is returned when the requested limit has been reached and no
+// further results may be published.
+var ErrLimitReached = fmt.Errorf("limit has been reached")
+
+// limitTracker is a helper struct for tracking the limit requested by a caller and decrementing
+// that limit as results are published.
+// NOTE(review): fields are unsynchronized — presumably each branch operates on
+// its own clone(); confirm concurrent callers never share one instance.
+type limitTracker struct {
+	// hasLimit indicates whether any limit was requested at all.
+	hasLimit bool
+	// currentLimit is the number of results still allowed to be published.
+	currentLimit uint32
+}
+
+// newLimitTracker creates a new limit tracker, returning the tracker.
+// An optionalLimit of zero means "no limit".
+func newLimitTracker(optionalLimit uint32) *limitTracker {
+	return &limitTracker{
+		currentLimit: optionalLimit,
+		hasLimit:     optionalLimit > 0,
+	}
+}
+
+// clone creates a copy of the limitTracker, inheriting the current limit.
+// The copy decrements independently of the original.
+func (lt *limitTracker) clone() *limitTracker {
+	return &limitTracker{
+		currentLimit: lt.currentLimit,
+		hasLimit:     lt.hasLimit,
+	}
+}
+
+// prepareForPublishing consumes one slot from the remaining limit, reporting
+// whether the caller is allowed to publish the corresponding element.
+//
+// Example usage:
+//
+//	okay := limits.prepareForPublishing()
+//	if okay { ... publish ... }
+func (lt *limitTracker) prepareForPublishing() bool {
+	switch {
+	case !lt.hasLimit:
+		// Without a defined limit, publishing is always allowed.
+		return true
+	case lt.currentLimit == 0:
+		// The limit has been fully consumed: nothing further may be published.
+		return false
+	default:
+		// Consume one slot from the remaining limit.
+		lt.currentLimit--
+		return true
+	}
+}
+
+// markAlreadyPublished marks that the given count of results has already been
+// published, decrementing the remaining limit accordingly. If the count is
+// greater than the remaining limit, a spiceerror (bug) is returned.
+func (lt *limitTracker) markAlreadyPublished(count uint32) error {
+	if !lt.hasLimit {
+		return nil
+	}
+
+	if count > lt.currentLimit {
+		return spiceerrors.MustBugf("given published count of %d exceeds the remaining limit of %d", count, lt.currentLimit)
+	}
+
+	lt.currentLimit -= count
+	// The previous `if lt.currentLimit == 0 { return nil }` branch was dead
+	// code: both paths returned nil.
+	return nil
+}
+
+// hasExhaustedLimit returns true if the limit has been reached and all items allowable have been
+// published. Always false when no limit was requested.
+func (lt *limitTracker) hasExhaustedLimit() bool {
+	return lt.hasLimit && lt.currentLimit == 0
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go b/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go
new file mode 100644
index 0000000..57acb49
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go
@@ -0,0 +1,681 @@
+package graph
+
+import (
+ "context"
+ "slices"
+ "sort"
+
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+
+ "github.com/authzed/spicedb/internal/caveats"
+ "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/graph/computed"
+ "github.com/authzed/spicedb/internal/graph/hints"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// dispatchVersion defines the "version" of this dispatcher. Must be incremented
+// anytime an incompatible change is made to the dispatcher itself or its cursor
+// production, so that cursors produced by older versions are rejected.
+const dispatchVersion = 1
+
+// NewCursoredLookupResources2 creates a new CursoredLookupResources2 dispatcher
+// using the given sub-dispatchers, caveat type set, and tuning parameters.
+func NewCursoredLookupResources2(dl dispatch.LookupResources2, dc dispatch.Check, caveatTypeSet *caveattypes.TypeSet, concurrencyLimit uint16, dispatchChunkSize uint16) *CursoredLookupResources2 {
+	// Keyed fields (rather than positional initialization) keep the literal
+	// correct if fields are ever added or reordered (go vet `composites`).
+	return &CursoredLookupResources2{
+		dl:                dl,
+		dc:                dc,
+		caveatTypeSet:     caveatTypeSet,
+		concurrencyLimit:  concurrencyLimit,
+		dispatchChunkSize: dispatchChunkSize,
+	}
+}
+
+// CursoredLookupResources2 implements a cursored dispatch for lookup-resources.
+type CursoredLookupResources2 struct {
+	// dl is the dispatcher used for recursive LookupResources2 calls.
+	dl dispatch.LookupResources2
+	// dc is the dispatcher used to issue filtering checks.
+	dc dispatch.Check
+	// caveatTypeSet is the set of caveat types used when running caveat expressions.
+	caveatTypeSet *caveattypes.TypeSet
+	// concurrencyLimit bounds parallel work within a single request.
+	concurrencyLimit uint16
+	// dispatchChunkSize is the maximum number of results sent to a single downstream dispatch.
+	dispatchChunkSize uint16
+}
+
+// ValidatedLookupResources2Request represents a lookup resources request after
+// it has been validated and parsed for internal consumption.
+type ValidatedLookupResources2Request struct {
+	*v1.DispatchLookupResources2Request
+	// Revision is the datastore revision at which the request is run.
+	Revision datastore.Revision
+}
+
+// LookupResources2 streams, with cursor support, the resources of the request's
+// resource type+relation reachable from the given subject IDs, first yielding
+// any same-type subjects directly and then dispatching over reachability
+// entrypoints.
+func (crr *CursoredLookupResources2) LookupResources2(
+	req ValidatedLookupResources2Request,
+	stream dispatch.LookupResources2Stream,
+) error {
+	ctx, span := tracer.Start(stream.Context(), "lookupResources2")
+	defer span.End()
+
+	if req.TerminalSubject == nil {
+		return spiceerrors.MustBugf("no terminal subject given to lookup resources dispatch")
+	}
+
+	// Wildcard subjects cannot be looked up directly.
+	if slices.Contains(req.SubjectIds, tuple.PublicWildcard) {
+		return NewWildcardNotAllowedErr("cannot perform lookup resources on wildcard", "subject_id")
+	}
+
+	if len(req.SubjectIds) == 0 {
+		return spiceerrors.MustBugf("no subjects ids given to lookup resources dispatch")
+	}
+
+	// Sort for stability.
+	if len(req.SubjectIds) > 1 {
+		sort.Strings(req.SubjectIds)
+	}
+
+	limits := newLimitTracker(req.OptionalLimit)
+	ci, err := newCursorInformation(req.OptionalCursor, limits, dispatchVersion)
+	if err != nil {
+		return err
+	}
+
+	return withSubsetInCursor(ci,
+		func(currentOffset int, nextCursorWith afterResponseCursor) error {
+			// If the resource type matches the subject type, yield directly as a one-to-one result
+			// for each subjectID.
+			if req.SubjectRelation.Namespace == req.ResourceRelation.Namespace &&
+				req.SubjectRelation.Relation == req.ResourceRelation.Relation {
+				for index, subjectID := range req.SubjectIds {
+					// Skip subjects already yielded per the incoming cursor.
+					if index < currentOffset {
+						continue
+					}
+
+					if !ci.limits.prepareForPublishing() {
+						return nil
+					}
+
+					err := stream.Publish(&v1.DispatchLookupResources2Response{
+						Resource: &v1.PossibleResource{
+							ResourceId:    subjectID,
+							ForSubjectIds: []string{subjectID},
+						},
+						Metadata:            emptyMetadata,
+						AfterResponseCursor: nextCursorWith(index + 1),
+					})
+					if err != nil {
+						return err
+					}
+				}
+			}
+			return nil
+		}, func(ci cursorInformation) error {
+			// Once done checking for the matching subject type, yield by dispatching over entrypoints.
+			return crr.afterSameType(ctx, ci, req, stream)
+		})
+}
+
+// afterSameType continues the lookup after any same-type direct results have
+// been published, dispatching in parallel over each reachability entrypoint
+// from the subject type to the target resource type+relation.
+func (crr *CursoredLookupResources2) afterSameType(
+	ctx context.Context,
+	ci cursorInformation,
+	req ValidatedLookupResources2Request,
+	parentStream dispatch.LookupResources2Stream,
+) error {
+	reachabilityForString := req.ResourceRelation.Namespace + "#" + req.ResourceRelation.Relation
+	ctx, span := tracer.Start(ctx, "reachability: "+reachabilityForString)
+	defer span.End()
+
+	// Shared across entrypoints to avoid re-dispatching the same subjects.
+	dispatched := NewSyncONRSet()
+
+	// Load the type system and reachability graph to find the entrypoints for the reachability.
+	ds := datastoremw.MustFromContext(ctx)
+	reader := ds.SnapshotReader(req.Revision)
+	ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(reader))
+	vdef, err := ts.GetValidatedDefinition(ctx, req.ResourceRelation.Namespace)
+	if err != nil {
+		return err
+	}
+
+	rg := vdef.Reachability()
+	entrypoints, err := rg.FirstEntrypointsForSubjectToResource(ctx, &core.RelationReference{
+		Namespace: req.SubjectRelation.Namespace,
+		Relation:  req.SubjectRelation.Relation,
+	}, req.ResourceRelation)
+	if err != nil {
+		return err
+	}
+
+	// For each entrypoint, load the necessary data and re-dispatch if a subproblem was found.
+	return withParallelizedStreamingIterableInCursor(ctx, ci, entrypoints, parentStream, crr.concurrencyLimit,
+		func(ctx context.Context, ci cursorInformation, entrypoint schema.ReachabilityEntrypoint, stream dispatch.LookupResources2Stream) error {
+			// NOTE(review): `ds` here shadows the datastore variable above; a
+			// name such as `debugString` would aid readability.
+			ds, err := entrypoint.DebugString()
+			spiceerrors.DebugAssert(func() bool {
+				return err == nil
+			}, "Error in entrypoint.DebugString()")
+			ctx, span := tracer.Start(ctx, "entrypoint: "+ds, trace.WithAttributes())
+			defer span.End()
+
+			switch entrypoint.EntrypointKind() {
+			case core.ReachabilityEntrypoint_RELATION_ENTRYPOINT:
+				return crr.lookupRelationEntrypoint(ctx, ci, entrypoint, rg, ts, reader, req, stream, dispatched)
+
+			case core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT:
+				// Rewrite the subjects under the containing relation/permission
+				// and redispatch or report directly; no datastore query needed.
+				containingRelation := entrypoint.ContainingRelationOrPermission()
+				rewrittenSubjectRelation := &core.RelationReference{
+					Namespace: containingRelation.Namespace,
+					Relation:  containingRelation.Relation,
+				}
+
+				rsm := subjectIDsToResourcesMap2(rewrittenSubjectRelation, req.SubjectIds)
+				drsm := rsm.asReadOnly()
+
+				return crr.redispatchOrReport(
+					ctx,
+					ci,
+					rewrittenSubjectRelation,
+					drsm,
+					rg,
+					entrypoint,
+					stream,
+					req,
+					dispatched,
+				)
+
+			case core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT:
+				return crr.lookupTTUEntrypoint(ctx, ci, entrypoint, rg, ts, reader, req, stream, dispatched)
+
+			default:
+				return spiceerrors.MustBugf("Unknown kind of entrypoint: %v", entrypoint.EntrypointKind())
+			}
+		})
+}
+
+// lookupRelationEntrypoint handles a direct relation entrypoint by querying the
+// datastore in reverse for the request's subjects (and, where allowed, the
+// public wildcard) on the entrypoint relation, then redispatching or reporting
+// the found resources.
+func (crr *CursoredLookupResources2) lookupRelationEntrypoint(
+	ctx context.Context,
+	ci cursorInformation,
+	entrypoint schema.ReachabilityEntrypoint,
+	rg *schema.DefinitionReachability,
+	ts *schema.TypeSystem,
+	reader datastore.Reader,
+	req ValidatedLookupResources2Request,
+	stream dispatch.LookupResources2Stream,
+	dispatched *syncONRSet,
+) error {
+	relationReference, err := entrypoint.DirectRelation()
+	if err != nil {
+		return err
+	}
+
+	relDefinition, err := ts.GetValidatedDefinition(ctx, relationReference.Namespace)
+	if err != nil {
+		return err
+	}
+
+	// Build the list of subjects to lookup based on the type information available.
+	isDirectAllowed, err := relDefinition.IsAllowedDirectRelation(
+		relationReference.Relation,
+		req.SubjectRelation.Namespace,
+		req.SubjectRelation.Relation,
+	)
+	if err != nil {
+		return err
+	}
+
+	// The +1 leaves room for the possible wildcard subject added below.
+	subjectIds := make([]string, 0, len(req.SubjectIds)+1)
+	if isDirectAllowed == schema.DirectRelationValid {
+		subjectIds = append(subjectIds, req.SubjectIds...)
+	}
+
+	// An ellipsis subject may also match via the public wildcard, when allowed.
+	if req.SubjectRelation.Relation == tuple.Ellipsis {
+		isWildcardAllowed, err := relDefinition.IsAllowedPublicNamespace(relationReference.Relation, req.SubjectRelation.Namespace)
+		if err != nil {
+			return err
+		}
+
+		if isWildcardAllowed == schema.PublicSubjectAllowed {
+			subjectIds = append(subjectIds, "*")
+		}
+	}
+
+	// Lookup the subjects and then redispatch/report results.
+	relationFilter := datastore.SubjectRelationFilter{
+		NonEllipsisRelation: req.SubjectRelation.Relation,
+	}
+
+	if req.SubjectRelation.Relation == tuple.Ellipsis {
+		relationFilter = datastore.SubjectRelationFilter{
+			IncludeEllipsisRelation: true,
+		}
+	}
+
+	subjectsFilter := datastore.SubjectsFilter{
+		SubjectType:        req.SubjectRelation.Namespace,
+		OptionalSubjectIds: subjectIds,
+		RelationFilter:     relationFilter,
+	}
+
+	return crr.redispatchOrReportOverDatabaseQuery(
+		ctx,
+		redispatchOverDatabaseConfig2{
+			ci:                 ci,
+			ts:                 ts,
+			reader:             reader,
+			subjectsFilter:     subjectsFilter,
+			sourceResourceType: relationReference,
+			foundResourceType:  relationReference,
+			entrypoint:         entrypoint,
+			rg:                 rg,
+			concurrencyLimit:   crr.concurrencyLimit,
+			parentStream:       stream,
+			parentRequest:      req,
+			dispatched:         dispatched,
+		},
+	)
+}
+
+// redispatchOverDatabaseConfig2 holds the parameters for
+// redispatchOrReportOverDatabaseQuery.
+type redispatchOverDatabaseConfig2 struct {
+	// ci is the current cursor information.
+	ci cursorInformation
+
+	// ts is the type system for the current schema.
+	ts *schema.TypeSystem
+
+	// Direct reader for reverse ReverseQueryRelationships
+	reader datastore.Reader
+
+	// subjectsFilter restricts the reverse query to the subjects being resolved.
+	subjectsFilter datastore.SubjectsFilter
+	// sourceResourceType is the resource type+relation queried in the datastore.
+	sourceResourceType *core.RelationReference
+	// foundResourceType is the resource type+relation reported for found results.
+	foundResourceType *core.RelationReference
+
+	// entrypoint is the reachability entrypoint being processed.
+	entrypoint schema.ReachabilityEntrypoint
+	// rg is the reachability graph for the target resource definition.
+	rg *schema.DefinitionReachability
+
+	// concurrencyLimit bounds parallel handling of result chunks.
+	concurrencyLimit uint16
+	parentStream     dispatch.LookupResources2Stream
+	parentRequest    ValidatedLookupResources2Request
+	dispatched       *syncONRSet
+}
+
+// redispatchOrReportOverDatabaseQuery runs a reverse relationship query against
+// the datastore for the configured subjects filter, evaluates any caveats,
+// chunks the results to the dispatch chunk size, and redispatches or reports
+// each chunk with cursor support.
+func (crr *CursoredLookupResources2) redispatchOrReportOverDatabaseQuery(
+	ctx context.Context,
+	config redispatchOverDatabaseConfig2,
+) error {
+	ctx, span := tracer.Start(ctx, "datastorequery", trace.WithAttributes(
+		attribute.String("source-resource-type-namespace", config.sourceResourceType.Namespace),
+		attribute.String("source-resource-type-relation", config.sourceResourceType.Relation),
+		attribute.String("subjects-filter-subject-type", config.subjectsFilter.SubjectType),
+		attribute.Int("subjects-filter-subject-ids-count", len(config.subjectsFilter.OptionalSubjectIds)),
+	))
+	defer span.End()
+
+	return withDatastoreCursorInCursor(ctx, config.ci, config.parentStream, config.concurrencyLimit,
+		// Find the target resources for the subject.
+		func(queryCursor options.Cursor) ([]itemAndPostCursor[dispatchableResourcesSubjectMap2], error) {
+			it, err := config.reader.ReverseQueryRelationships(
+				ctx,
+				config.subjectsFilter,
+				options.WithResRelation(&options.ResourceRelation{
+					Namespace: config.sourceResourceType.Namespace,
+					Relation:  config.sourceResourceType.Relation,
+				}),
+				options.WithSortForReverse(options.BySubject),
+				options.WithAfterForReverse(queryCursor),
+				options.WithQueryShapeForReverse(queryshape.MatchingResourcesForSubject),
+			)
+			if err != nil {
+				return nil, err
+			}
+
+			// Chunk based on the FilterMaximumIDCount, to ensure we never send more than that amount of
+			// results to a downstream dispatch.
+			rsm := newResourcesSubjectMap2WithCapacity(config.sourceResourceType, uint32(crr.dispatchChunkSize))
+			toBeHandled := make([]itemAndPostCursor[dispatchableResourcesSubjectMap2], 0)
+			currentCursor := queryCursor
+			caveatRunner := caveats.NewCaveatRunner(crr.caveatTypeSet)
+
+			for rel, err := range it {
+				if err != nil {
+					return nil, err
+				}
+
+				var missingContextParameters []string
+
+				// If a caveat exists on the relationship, run it and filter the results, marking those that have missing context.
+				if rel.OptionalCaveat != nil && rel.OptionalCaveat.CaveatName != "" {
+					caveatExpr := caveats.CaveatAsExpr(rel.OptionalCaveat)
+					runResult, err := caveatRunner.RunCaveatExpression(ctx, caveatExpr, config.parentRequest.Context.AsMap(), config.reader, caveats.RunCaveatExpressionNoDebugging)
+					if err != nil {
+						return nil, err
+					}
+
+					// If a partial result is returned, collect the missing context parameters.
+					if runResult.IsPartial() {
+						missingNames, err := runResult.MissingVarNames()
+						if err != nil {
+							return nil, err
+						}
+
+						missingContextParameters = missingNames
+					} else if !runResult.Value() {
+						// If the run result shows the caveat does not apply, skip. This shears the tree of results early.
+						continue
+					}
+				}
+
+				if err := rsm.addRelationship(rel, missingContextParameters); err != nil {
+					return nil, err
+				}
+
+				// When a chunk fills, hand it off with the cursor that produced
+				// it and start a fresh chunk from the current relationship.
+				if rsm.len() == int(crr.dispatchChunkSize) {
+					toBeHandled = append(toBeHandled, itemAndPostCursor[dispatchableResourcesSubjectMap2]{
+						item:   rsm.asReadOnly(),
+						cursor: currentCursor,
+					})
+					rsm = newResourcesSubjectMap2WithCapacity(config.sourceResourceType, uint32(crr.dispatchChunkSize))
+					currentCursor = options.ToCursor(rel)
+				}
+			}
+
+			// Hand off any trailing partial chunk.
+			if rsm.len() > 0 {
+				toBeHandled = append(toBeHandled, itemAndPostCursor[dispatchableResourcesSubjectMap2]{
+					item:   rsm.asReadOnly(),
+					cursor: currentCursor,
+				})
+			}
+
+			return toBeHandled, nil
+		},
+
+		// Redispatch or report the results.
+		func(
+			ctx context.Context,
+			ci cursorInformation,
+			drsm dispatchableResourcesSubjectMap2,
+			currentStream dispatch.LookupResources2Stream,
+		) error {
+			return crr.redispatchOrReport(
+				ctx,
+				ci,
+				config.foundResourceType,
+				drsm,
+				config.rg,
+				config.entrypoint,
+				currentStream,
+				config.parentRequest,
+				config.dispatched,
+			)
+		},
+	)
+}
+
+// lookupTTUEntrypoint handles a tupleset-to-userset (arrow) entrypoint: it
+// searches the tupleset relation for the resolved subjects and then
+// redispatches or reports over the containing relation.
+func (crr *CursoredLookupResources2) lookupTTUEntrypoint(ctx context.Context,
+	ci cursorInformation,
+	entrypoint schema.ReachabilityEntrypoint,
+	rg *schema.DefinitionReachability,
+	ts *schema.TypeSystem,
+	reader datastore.Reader,
+	req ValidatedLookupResources2Request,
+	stream dispatch.LookupResources2Stream,
+	dispatched *syncONRSet,
+) error {
+	containingRelation := entrypoint.ContainingRelationOrPermission()
+
+	ttuDef, err := ts.GetValidatedDefinition(ctx, containingRelation.Namespace)
+	if err != nil {
+		return err
+	}
+
+	tuplesetRelation, err := entrypoint.TuplesetRelation()
+	if err != nil {
+		return err
+	}
+
+	// Determine whether this TTU should be followed, which will be the case if the subject relation's namespace
+	// is allowed in any form on the relation; since arrows ignore the subject's relation (if any), we check
+	// for the subject namespace as a whole.
+	allowedRelations, err := ttuDef.GetAllowedDirectNamespaceSubjectRelations(tuplesetRelation, req.SubjectRelation.Namespace)
+	if err != nil {
+		return err
+	}
+
+	if allowedRelations == nil {
+		return nil
+	}
+
+	// Search for the resolved subjects in the tupleset of the TTU.
+	subjectsFilter := datastore.SubjectsFilter{
+		SubjectType:        req.SubjectRelation.Namespace,
+		OptionalSubjectIds: req.SubjectIds,
+	}
+
+	// Optimization: if there is a single allowed relation, pass it as a subject relation filter to make things faster
+	// on querying.
+	if allowedRelations.Len() == 1 {
+		allowedRelationName := allowedRelations.AsSlice()[0]
+		subjectsFilter.RelationFilter = datastore.SubjectRelationFilter{}.WithRelation(allowedRelationName)
+	}
+
+	tuplesetRelationReference := &core.RelationReference{
+		Namespace: containingRelation.Namespace,
+		Relation:  tuplesetRelation,
+	}
+
+	return crr.redispatchOrReportOverDatabaseQuery(
+		ctx,
+		redispatchOverDatabaseConfig2{
+			ci:                 ci,
+			ts:                 ts,
+			reader:             reader,
+			subjectsFilter:     subjectsFilter,
+			sourceResourceType: tuplesetRelationReference,
+			foundResourceType:  containingRelation,
+			entrypoint:         entrypoint,
+			rg:                 rg,
+			// Fix: pass the configured concurrency limit; previously this field
+			// was left at its zero value here — unlike the sibling
+			// lookupRelationEntrypoint, which passes crr.concurrencyLimit to
+			// the same redispatchOrReportOverDatabaseQuery call.
+			concurrencyLimit: crr.concurrencyLimit,
+			parentStream:     stream,
+			parentRequest:    req,
+			dispatched:       dispatched,
+		},
+	)
+}
+
+// possibleResourceAndIndex pairs a possible resource with its index within the
+// current offsetted results slice, used for computing the response cursor.
+type possibleResourceAndIndex struct {
+	resource *v1.PossibleResource
+	index    int
+}
+
+// redispatchOrReport checks if further redispatching is necessary for the found resource
+// type. If not, and the found resource type+relation matches the target resource type+relation,
+// the resource is reported to the parent stream.
+func (crr *CursoredLookupResources2) redispatchOrReport(
+	ctx context.Context,
+	ci cursorInformation,
+	foundResourceType *core.RelationReference,
+	foundResources dispatchableResourcesSubjectMap2,
+	rg *schema.DefinitionReachability,
+	entrypoint schema.ReachabilityEntrypoint,
+	parentStream dispatch.LookupResources2Stream,
+	parentRequest ValidatedLookupResources2Request,
+	dispatched *syncONRSet,
+) error {
+	if foundResources.isEmpty() {
+		// Nothing more to do.
+		return nil
+	}
+
+	ctx, span := tracer.Start(ctx, "redispatchOrReport", trace.WithAttributes(
+		attribute.Int("found-resources-count", foundResources.len()),
+	))
+	defer span.End()
+
+	// Check for entrypoints for the new found resource type.
+	hasResourceEntrypoints, err := rg.HasOptimizedEntrypointsForSubjectToResource(ctx, foundResourceType, parentRequest.ResourceRelation)
+	if err != nil {
+		return err
+	}
+
+	return withSubsetInCursor(ci,
+		func(currentOffset int, nextCursorWith afterResponseCursor) error {
+			if !hasResourceEntrypoints {
+				// If the found resource matches the target resource type and relation, potentially yield the resource.
+				if foundResourceType.Namespace == parentRequest.ResourceRelation.Namespace && foundResourceType.Relation == parentRequest.ResourceRelation.Relation {
+					resources := foundResources.asPossibleResources()
+					if len(resources) == 0 {
+						return nil
+					}
+
+					if currentOffset >= len(resources) {
+						return nil
+					}
+
+					offsetted := resources[currentOffset:]
+					if len(offsetted) == 0 {
+						return nil
+					}
+
+					// Start with all offsetted resources; the check below may
+					// replace this with a filtered subset.
+					filtered := make([]possibleResourceAndIndex, 0, len(offsetted))
+					for index, resource := range offsetted {
+						filtered = append(filtered, possibleResourceAndIndex{
+							resource: resource,
+							index:    index,
+						})
+					}
+
+					// Metadata attached to the *first* published result only;
+					// reset to empty after the first publish below.
+					metadata := emptyMetadata
+
+					// If the entrypoint is not a direct result, issue a check to further filter the results on the intersection or exclusion.
+					if !entrypoint.IsDirectResult() {
+						resourceIDs := make([]string, 0, len(offsetted))
+						checkHints := make([]*v1.CheckHint, 0, len(offsetted))
+						for _, resource := range offsetted {
+							resourceIDs = append(resourceIDs, resource.ResourceId)
+
+							checkHint, err := hints.HintForEntrypoint(
+								entrypoint,
+								resource.ResourceId,
+								tuple.FromCoreObjectAndRelation(parentRequest.TerminalSubject),
+								&v1.ResourceCheckResult{
+									Membership: v1.ResourceCheckResult_MEMBER,
+								})
+							if err != nil {
+								return err
+							}
+							checkHints = append(checkHints, checkHint)
+						}
+
+						resultsByResourceID, checkMetadata, _, err := computed.ComputeBulkCheck(ctx, crr.dc, crr.caveatTypeSet, computed.CheckParameters{
+							ResourceType:  tuple.FromCoreRelationReference(parentRequest.ResourceRelation),
+							Subject:       tuple.FromCoreObjectAndRelation(parentRequest.TerminalSubject),
+							CaveatContext: parentRequest.Context.AsMap(),
+							AtRevision:    parentRequest.Revision,
+							MaximumDepth:  parentRequest.Metadata.DepthRemaining - 1,
+							DebugOption:   computed.NoDebugging,
+							CheckHints:    checkHints,
+						}, resourceIDs, crr.dispatchChunkSize)
+						if err != nil {
+							return err
+						}
+
+						metadata = addCallToResponseMetadata(checkMetadata)
+
+						// Rebuild the filtered list from the check results.
+						filtered = make([]possibleResourceAndIndex, 0, len(offsetted))
+						for index, resource := range offsetted {
+							result, ok := resultsByResourceID[resource.ResourceId]
+							if !ok {
+								continue
+							}
+
+							switch result.Membership {
+							case v1.ResourceCheckResult_MEMBER:
+								filtered = append(filtered, possibleResourceAndIndex{
+									resource: resource,
+									index:    index,
+								})
+
+							case v1.ResourceCheckResult_CAVEATED_MEMBER:
+								// Union the check's missing fields with those already on the resource.
+								missingContextParams := mapz.NewSet(result.MissingExprFields...)
+								missingContextParams.Extend(resource.MissingContextParams)
+
+								filtered = append(filtered, possibleResourceAndIndex{
+									resource: &v1.PossibleResource{
+										ResourceId:           resource.ResourceId,
+										ForSubjectIds:        resource.ForSubjectIds,
+										MissingContextParams: missingContextParams.AsSlice(),
+									},
+									index: index,
+								})
+
+							case v1.ResourceCheckResult_NOT_MEMBER:
+								// Skip.
+
+							default:
+								return spiceerrors.MustBugf("unexpected result from check: %v", result.Membership)
+							}
+						}
+					}
+
+					for _, resourceAndIndex := range filtered {
+						if !ci.limits.prepareForPublishing() {
+							return nil
+						}
+
+						err := parentStream.Publish(&v1.DispatchLookupResources2Response{
+							Resource:            resourceAndIndex.resource,
+							Metadata:            metadata,
+							AfterResponseCursor: nextCursorWith(currentOffset + resourceAndIndex.index + 1),
+						})
+						if err != nil {
+							return err
+						}
+
+						// Only the first published result carries the metadata.
+						metadata = emptyMetadata
+					}
+					return nil
+				}
+			}
+			return nil
+		}, func(ci cursorInformation) error {
+			if !hasResourceEntrypoints {
+				return nil
+			}
+
+			// The new subject type for dispatching was the found type of the *resource*.
+			newSubjectType := foundResourceType
+
+			// To avoid duplicate work, remove any subjects already dispatched.
+			filteredSubjectIDs := foundResources.filterSubjectIDsToDispatch(dispatched, newSubjectType)
+			if len(filteredSubjectIDs) == 0 {
+				return nil
+			}
+
+			// If the entrypoint is a direct result then we can simply dispatch directly and map
+			// all found results, as no further filtering will be needed.
+			if entrypoint.IsDirectResult() {
+				stream := unfilteredLookupResourcesDispatchStreamForEntrypoint(ctx, foundResources, parentStream, ci)
+				return crr.dl.DispatchLookupResources2(&v1.DispatchLookupResources2Request{
+					ResourceRelation: parentRequest.ResourceRelation,
+					SubjectRelation:  newSubjectType,
+					SubjectIds:       filteredSubjectIDs,
+					TerminalSubject:  parentRequest.TerminalSubject,
+					// NOTE(review): unlike decrementDepth, TraversalBloom is not
+					// carried forward here — confirm whether that is intentional.
+					Metadata: &v1.ResolverMeta{
+						AtRevision:     parentRequest.Revision.String(),
+						DepthRemaining: parentRequest.Metadata.DepthRemaining - 1,
+					},
+					OptionalCursor: ci.currentCursor,
+					OptionalLimit:  parentRequest.OptionalLimit,
+					Context:        parentRequest.Context,
+				}, stream)
+			}
+
+			// Otherwise, we need to filter results by batch checking along the way before dispatching.
+			return runCheckerAndDispatch(
+				ctx,
+				parentRequest,
+				foundResources,
+				ci,
+				parentStream,
+				newSubjectType,
+				filteredSubjectIDs,
+				entrypoint,
+				crr.dl,
+				crr.dc,
+				crr.caveatTypeSet,
+				crr.concurrencyLimit,
+				crr.dispatchChunkSize,
+			)
+		})
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go b/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go
new file mode 100644
index 0000000..7560847
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go
@@ -0,0 +1,803 @@
+package graph
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+
+ "golang.org/x/sync/errgroup"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ "github.com/authzed/spicedb/internal/datasets"
+ "github.com/authzed/spicedb/internal/dispatch"
+ log "github.com/authzed/spicedb/internal/logging"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/namespace"
+ "github.com/authzed/spicedb/internal/taskrunner"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/genutil/slicez"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// ValidatedLookupSubjectsRequest represents a request after it has been validated and parsed for internal
+// consumption.
+type ValidatedLookupSubjectsRequest struct {
+	*v1.DispatchLookupSubjectsRequest
+	// Revision is the datastore revision at which the lookup is evaluated.
+	Revision datastore.Revision
+}
+
+// NewConcurrentLookupSubjects creates an instance of ConcurrentLookupSubjects.
+// d is the dispatcher used for recursive sub-lookups, concurrencyLimit bounds the
+// number of concurrent branch/goroutine dispatches, and dispatchChunkSize bounds
+// the number of resource IDs sent per dispatched request.
+func NewConcurrentLookupSubjects(d dispatch.LookupSubjects, concurrencyLimit uint16, dispatchChunkSize uint16) *ConcurrentLookupSubjects {
+	return &ConcurrentLookupSubjects{d, concurrencyLimit, dispatchChunkSize}
+}
+
+// ConcurrentLookupSubjects performs the LookupSubjects dispatch operation, fanning out
+// sub-lookups concurrently over the relation's rewrite structure and relationships.
+type ConcurrentLookupSubjects struct {
+	d                 dispatch.LookupSubjects // dispatcher for recursive sub-requests
+	concurrencyLimit  uint16                  // max concurrent branch dispatches
+	dispatchChunkSize uint16                  // max resource IDs per dispatched chunk
+}
+
+// LookupSubjects resolves all subjects of the requested subject type reachable from the
+// given resource IDs, publishing results (possibly incrementally) to the given stream.
+func (cl *ConcurrentLookupSubjects) LookupSubjects(
+	req ValidatedLookupSubjectsRequest,
+	stream dispatch.LookupSubjectsStream,
+) error {
+	ctx := stream.Context()
+
+	if len(req.ResourceIds) == 0 {
+		return fmt.Errorf("no resources ids given to lookupsubjects dispatch")
+	}
+
+	// If the resource type matches the subject type, yield directly.
+	// NOTE: execution deliberately continues after publishing, since additional
+	// subjects may still be reachable via the relation's relationships or rewrite.
+	if req.SubjectRelation.Namespace == req.ResourceRelation.Namespace &&
+		req.SubjectRelation.Relation == req.ResourceRelation.Relation {
+		if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{
+			FoundSubjectsByResourceId: subjectsForConcreteIds(req.ResourceIds),
+			Metadata:                  emptyMetadata,
+		}); err != nil {
+			return err
+		}
+	}
+
+	ds := datastoremw.MustFromContext(ctx)
+	reader := ds.SnapshotReader(req.Revision)
+	_, relation, err := namespace.ReadNamespaceAndRelation(
+		ctx,
+		req.ResourceRelation.Namespace,
+		req.ResourceRelation.Relation,
+		reader)
+	if err != nil {
+		return err
+	}
+
+	if relation.UsersetRewrite == nil {
+		// Direct lookup of subjects.
+		return cl.lookupDirectSubjects(ctx, req, stream, relation, reader)
+	}
+
+	// Otherwise, walk the userset rewrite (union/intersection/exclusion).
+	return cl.lookupViaRewrite(ctx, req, stream, relation.UsersetRewrite)
+}
+
+// subjectsForConcreteIds maps each given subject ID to a FoundSubjects entry containing
+// exactly that subject, with no caveat. Used when resources are themselves the subjects.
+func subjectsForConcreteIds(subjectIds []string) map[string]*v1.FoundSubjects {
+	foundSubjects := make(map[string]*v1.FoundSubjects, len(subjectIds))
+	for _, subjectID := range subjectIds {
+		foundSubjects[subjectID] = &v1.FoundSubjects{
+			FoundSubjects: []*v1.FoundSubject{
+				{
+					SubjectId:        subjectID,
+					CaveatExpression: nil, // Explicitly nil since this is a concrete found subject.
+				},
+			},
+		}
+	}
+	return foundSubjects
+}
+
+// lookupDirectSubjects scans the relationships of the resource relation directly:
+// subjects matching the requested subject type/relation are published immediately,
+// while subjects with their own (non-ellipsis) relation are collected and recursively
+// dispatched via dispatchTo.
+func (cl *ConcurrentLookupSubjects) lookupDirectSubjects(
+	ctx context.Context,
+	req ValidatedLookupSubjectsRequest,
+	stream dispatch.LookupSubjectsStream,
+	_ *core.Relation,
+	reader datastore.Reader,
+) error {
+	// TODO(jschorr): use type information to skip subject relations that cannot reach the subject type.
+
+	toDispatchByType := datasets.NewSubjectByTypeSet()
+	foundSubjectsByResourceID := datasets.NewSubjectSetByResourceID()
+	relationshipsBySubjectONR := mapz.NewMultiMap[tuple.ObjectAndRelation, tuple.Relationship]()
+
+	it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{
+		OptionalResourceType:     req.ResourceRelation.Namespace,
+		OptionalResourceRelation: req.ResourceRelation.Relation,
+		OptionalResourceIds:      req.ResourceIds,
+	}, options.WithQueryShape(queryshape.AllSubjectsForResources))
+	if err != nil {
+		return err
+	}
+
+	for rel, err := range it {
+		if err != nil {
+			return err
+		}
+
+		// A relationship whose subject matches the requested subject type/relation is itself a result.
+		if rel.Subject.ObjectType == req.SubjectRelation.Namespace &&
+			rel.Subject.Relation == req.SubjectRelation.Relation {
+			if err := foundSubjectsByResourceID.AddFromRelationship(rel); err != nil {
+				return fmt.Errorf("failed to call AddFromRelationship in lookupDirectSubjects: %w", err)
+			}
+		}
+
+		// Non-terminal subjects (those with their own relation) require further dispatch;
+		// record the originating relationship so results can be mapped back (and caveated).
+		if rel.Subject.Relation != tuple.Ellipsis {
+			err := toDispatchByType.AddSubjectOf(rel)
+			if err != nil {
+				return err
+			}
+
+			relationshipsBySubjectONR.Add(rel.Subject, rel)
+		}
+	}
+
+	if !foundSubjectsByResourceID.IsEmpty() {
+		if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{
+			FoundSubjectsByResourceId: foundSubjectsByResourceID.AsMap(),
+			Metadata:                  emptyMetadata,
+		}); err != nil {
+			return err
+		}
+	}
+
+	return cl.dispatchTo(ctx, req, toDispatchByType, relationshipsBySubjectONR, stream)
+}
+
+// lookupViaComputed handles a computed_userset rewrite: the lookup is re-dispatched
+// against the computed relation on the same resource namespace. If the computed
+// relation does not exist on the namespace, the branch yields no results.
+func (cl *ConcurrentLookupSubjects) lookupViaComputed(
+	ctx context.Context,
+	parentRequest ValidatedLookupSubjectsRequest,
+	parentStream dispatch.LookupSubjectsStream,
+	cu *core.ComputedUserset,
+) error {
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision)
+	if err := namespace.CheckNamespaceAndRelation(ctx, parentRequest.ResourceRelation.Namespace, cu.Relation, true, ds); err != nil {
+		if errors.As(err, &namespace.RelationNotFoundError{}) {
+			// Missing relation: treat as an empty branch rather than an error.
+			return nil
+		}
+
+		return err
+	}
+
+	// Wrap the parent stream so dispatch metadata (call counts) is accumulated on each result.
+	stream := &dispatch.WrappedDispatchStream[*v1.DispatchLookupSubjectsResponse]{
+		Stream: parentStream,
+		Ctx:    ctx,
+		Processor: func(result *v1.DispatchLookupSubjectsResponse) (*v1.DispatchLookupSubjectsResponse, bool, error) {
+			return &v1.DispatchLookupSubjectsResponse{
+				FoundSubjectsByResourceId: result.FoundSubjectsByResourceId,
+				Metadata:                  addCallToResponseMetadata(result.Metadata),
+			}, true, nil
+		},
+	}
+
+	return cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{
+		ResourceRelation: &core.RelationReference{
+			Namespace: parentRequest.ResourceRelation.Namespace,
+			Relation:  cu.Relation,
+		},
+		ResourceIds:     parentRequest.ResourceIds,
+		SubjectRelation: parentRequest.SubjectRelation,
+		Metadata: &v1.ResolverMeta{
+			AtRevision:     parentRequest.Revision.String(),
+			DepthRemaining: parentRequest.Metadata.DepthRemaining - 1,
+		},
+	}, stream)
+}
+
+// resourceDispatchTracker tracks the running intersection of subjects found for a single
+// resource ID across the branches of an intersection tuple-to-userset. Once the
+// intersection becomes empty, cancelDispatch is invoked to stop further work for
+// that resource.
+type resourceDispatchTracker struct {
+	ctx            context.Context    // context governing dispatches for this resource
+	cancelDispatch context.CancelFunc // cancels outstanding dispatches for this resource
+	resourceID     string
+
+	subjectsSet datasets.SubjectSet // GUARDED_BY(lock)
+	metadata    *v1.ResponseMeta    // GUARDED_BY(lock)
+
+	isFirstUpdate bool // GUARDED_BY(lock)
+	wasCanceled   bool // GUARDED_BY(lock)
+
+	lock sync.Mutex
+}
+
+// lookupViaIntersectionTupleToUserset handles a FunctionedTupleToUserset with
+// FUNCTION_ALL semantics: a subject is returned for a resource only if it is found
+// via *every* tupleset relationship of that resource (set intersection), rather than
+// via any one of them. One sub-lookup is dispatched per tupleset relationship, and
+// per-resource results are intersected as the sub-lookups complete.
+func lookupViaIntersectionTupleToUserset(
+	ctx context.Context,
+	cl *ConcurrentLookupSubjects,
+	parentRequest ValidatedLookupSubjectsRequest,
+	parentStream dispatch.LookupSubjectsStream,
+	ttu *core.FunctionedTupleToUserset,
+) error {
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision)
+	it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{
+		OptionalResourceType:     parentRequest.ResourceRelation.Namespace,
+		OptionalResourceRelation: ttu.GetTupleset().GetRelation(),
+		OptionalResourceIds:      parentRequest.ResourceIds,
+	}, options.WithQueryShape(queryshape.AllSubjectsForResources))
+	if err != nil {
+		return err
+	}
+
+	// TODO(jschorr): Find a means of doing this without dispatching per subject, per resource. Perhaps
+	// there is a way we can still dispatch to all the subjects at once, and then intersect the results
+	// afterwards.
+	resourceDispatchTrackerByResourceID := make(map[string]*resourceDispatchTracker)
+
+	cancelCtx, checkCancel := context.WithCancel(ctx)
+	defer checkCancel()
+
+	// For each found tuple, dispatch a lookup subjects request and collect its results.
+	// We need to intersect between *all* the found subjects for each resource ID.
+	var ttuCaveat *core.CaveatExpression
+	taskrunner := taskrunner.NewPreloadedTaskRunner(cancelCtx, cl.concurrencyLimit, 1)
+	for rel, err := range it {
+		if err != nil {
+			return err
+		}
+
+		// If the relationship has a caveat, add it to the overall TTU caveat. Since this is an intersection
+		// of *all* branches, the caveat will be applied to all found subjects, so this is a safe approach.
+		if rel.OptionalCaveat != nil {
+			ttuCaveat = caveatAnd(ttuCaveat, wrapCaveat(rel.OptionalCaveat))
+		}
+
+		// Skip subjects whose type does not define the computed relation.
+		if err := namespace.CheckNamespaceAndRelation(ctx, rel.Subject.ObjectType, ttu.GetComputedUserset().Relation, false, ds); err != nil {
+			if !errors.As(err, &namespace.RelationNotFoundError{}) {
+				return err
+			}
+
+			continue
+		}
+
+		// Create a data structure to track the intersection of subjects for the particular resource. If the resource's subject set
+		// ends up empty anywhere along the way, the dispatches for *that resource* will be canceled early.
+		resourceID := rel.Resource.ObjectID
+		dispatchInfoForResource, ok := resourceDispatchTrackerByResourceID[resourceID]
+		if !ok {
+			dispatchCtx, cancelDispatch := context.WithCancel(cancelCtx)
+			dispatchInfoForResource = &resourceDispatchTracker{
+				ctx:            dispatchCtx,
+				cancelDispatch: cancelDispatch,
+				resourceID:     resourceID,
+				subjectsSet:    datasets.NewSubjectSet(),
+				metadata:       emptyMetadata,
+				isFirstUpdate:  true,
+				lock:           sync.Mutex{},
+			}
+			resourceDispatchTrackerByResourceID[resourceID] = dispatchInfoForResource
+		}
+
+		// Capture the loop variable for the closure below (pre-Go 1.22 loop semantics).
+		rel := rel
+		taskrunner.Add(func(ctx context.Context) error {
+			// Collect all results for this branch of the resource ID.
+			// TODO(jschorr): once LS has cursoring (and thus, ordering), we can move to not collecting everything up before intersecting
+			// for this branch of the resource ID.
+			collectingStream := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](dispatchInfoForResource.ctx)
+			err := cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{
+				ResourceRelation: &core.RelationReference{
+					Namespace: rel.Subject.ObjectType,
+					Relation:  ttu.GetComputedUserset().Relation,
+				},
+				ResourceIds:     []string{rel.Subject.ObjectID},
+				SubjectRelation: parentRequest.SubjectRelation,
+				Metadata: &v1.ResolverMeta{
+					AtRevision:     parentRequest.Revision.String(),
+					DepthRemaining: parentRequest.Metadata.DepthRemaining - 1,
+				},
+			}, collectingStream)
+			if err != nil {
+				// Check if the dispatches for the resource were canceled, and if so, return nil to stop the task.
+				dispatchInfoForResource.lock.Lock()
+				wasCanceled := dispatchInfoForResource.wasCanceled
+				dispatchInfoForResource.lock.Unlock()
+
+				if wasCanceled {
+					// Cancellation may surface as a plain context error or a gRPC Canceled status,
+					// depending on whether the dispatch crossed a network boundary.
+					if errors.Is(err, context.Canceled) {
+						return nil
+					}
+
+					errStatus, ok := status.FromError(err)
+					if ok && errStatus.Code() == codes.Canceled {
+						return nil
+					}
+				}
+
+				return err
+			}
+
+			// Collect the results into a subject set.
+			results := datasets.NewSubjectSet()
+			collectedMetadata := emptyMetadata
+			for _, result := range collectingStream.Results() {
+				collectedMetadata = combineResponseMetadata(ctx, collectedMetadata, result.Metadata)
+				for _, foundSubjects := range result.FoundSubjectsByResourceId {
+					if err := results.UnionWith(foundSubjects.FoundSubjects); err != nil {
+						return fmt.Errorf("failed to UnionWith under lookupSubjectsIntersection: %w", err)
+					}
+				}
+			}
+
+			dispatchInfoForResource.lock.Lock()
+			defer dispatchInfoForResource.lock.Unlock()
+
+			dispatchInfoForResource.metadata = combineResponseMetadata(ctx, dispatchInfoForResource.metadata, collectedMetadata)
+
+			// If the first update for the resource, set the subjects set to the results.
+			if dispatchInfoForResource.isFirstUpdate {
+				dispatchInfoForResource.isFirstUpdate = false
+				dispatchInfoForResource.subjectsSet = results
+			} else {
+				// Otherwise, intersect the results with the existing subjects set.
+				err := dispatchInfoForResource.subjectsSet.IntersectionDifference(results)
+				if err != nil {
+					return err
+				}
+			}
+
+			// If the subjects set is empty, cancel the dispatch for any further results for this resource ID.
+			if dispatchInfoForResource.subjectsSet.IsEmpty() {
+				dispatchInfoForResource.wasCanceled = true
+				dispatchInfoForResource.cancelDispatch()
+			}
+
+			return nil
+		})
+	}
+
+	// Wait for all dispatched operations to complete.
+	if err := taskrunner.StartAndWait(); err != nil {
+		return err
+	}
+
+	// For each resource ID, intersect the found subjects from each stream.
+	metadata := emptyMetadata
+	currentSubjectsByResourceID := map[string]*v1.FoundSubjects{}
+
+	for incomingResourceID, tracker := range resourceDispatchTrackerByResourceID {
+		currentSubjects := tracker.subjectsSet
+		// Apply the caveats collected from the tupleset relationships to all results.
+		currentSubjects = currentSubjects.WithParentCaveatExpression(ttuCaveat)
+		currentSubjectsByResourceID[incomingResourceID] = currentSubjects.AsFoundSubjects()
+
+		metadata = combineResponseMetadata(ctx, metadata, tracker.metadata)
+	}
+
+	return parentStream.Publish(&v1.DispatchLookupSubjectsResponse{
+		FoundSubjectsByResourceId: currentSubjectsByResourceID,
+		Metadata:                  metadata,
+	})
+}
+
+// lookupViaTupleToUserset handles a (non-intersection) tuple-to-userset rewrite: the
+// tupleset relation's relationships are read, their subjects are re-typed onto the
+// computed userset relation, and the lookup is dispatched against those subjects.
+func lookupViaTupleToUserset[T relation](
+	ctx context.Context,
+	cl *ConcurrentLookupSubjects,
+	parentRequest ValidatedLookupSubjectsRequest,
+	parentStream dispatch.LookupSubjectsStream,
+	ttu ttu[T],
+) error {
+	toDispatchByTuplesetType := datasets.NewSubjectByTypeSet()
+	relationshipsBySubjectONR := mapz.NewMultiMap[tuple.ObjectAndRelation, tuple.Relationship]()
+
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision)
+	it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{
+		OptionalResourceType:     parentRequest.ResourceRelation.Namespace,
+		OptionalResourceRelation: ttu.GetTupleset().GetRelation(),
+		OptionalResourceIds:      parentRequest.ResourceIds,
+	}, options.WithQueryShape(queryshape.AllSubjectsForResources))
+	if err != nil {
+		return err
+	}
+
+	for rel, err := range it {
+		if err != nil {
+			return err
+		}
+
+		// Add the subject to be dispatched.
+		err := toDispatchByTuplesetType.AddSubjectOf(rel)
+		if err != nil {
+			return err
+		}
+
+		// Add the *rewritten* subject to the relationships multimap for mapping back to the associated
+		// relationship, as we will be mapping from the computed relation, not the tupleset relation.
+		relationshipsBySubjectONR.Add(tuple.ONR(rel.Subject.ObjectType, rel.Subject.ObjectID, ttu.GetComputedUserset().Relation), rel)
+	}
+
+	// Map the found subject types by the computed userset relation, so that we dispatch to it.
+	toDispatchByComputedRelationType, err := toDispatchByTuplesetType.Map(func(resourceType *core.RelationReference) (*core.RelationReference, error) {
+		// Subject types lacking the computed relation are dropped (mapped to nil) rather than erroring.
+		if err := namespace.CheckNamespaceAndRelation(ctx, resourceType.Namespace, ttu.GetComputedUserset().Relation, false, ds); err != nil {
+			if errors.As(err, &namespace.RelationNotFoundError{}) {
+				return nil, nil
+			}
+
+			return nil, err
+		}
+
+		return &core.RelationReference{
+			Namespace: resourceType.Namespace,
+			Relation:  ttu.GetComputedUserset().Relation,
+		}, nil
+	})
+	if err != nil {
+		return err
+	}
+
+	return cl.dispatchTo(ctx, parentRequest, toDispatchByComputedRelationType, relationshipsBySubjectONR, parentStream)
+}
+
+// lookupViaRewrite dispatches the lookup over a userset rewrite, selecting the reducer
+// (union, intersection or exclusion) matching the rewrite's set operation.
+func (cl *ConcurrentLookupSubjects) lookupViaRewrite(
+	ctx context.Context,
+	req ValidatedLookupSubjectsRequest,
+	stream dispatch.LookupSubjectsStream,
+	usr *core.UsersetRewrite,
+) error {
+	switch rw := usr.RewriteOperation.(type) {
+	case *core.UsersetRewrite_Union:
+		log.Ctx(ctx).Trace().Msg("union")
+		return cl.lookupSetOperation(ctx, req, rw.Union, newLookupSubjectsUnion(stream))
+	case *core.UsersetRewrite_Intersection:
+		log.Ctx(ctx).Trace().Msg("intersection")
+		return cl.lookupSetOperation(ctx, req, rw.Intersection, newLookupSubjectsIntersection(stream))
+	case *core.UsersetRewrite_Exclusion:
+		log.Ctx(ctx).Trace().Msg("exclusion")
+		return cl.lookupSetOperation(ctx, req, rw.Exclusion, newLookupSubjectsExclusion(stream))
+	default:
+		return fmt.Errorf("unknown kind of rewrite in lookup subjects")
+	}
+}
+
+// lookupSetOperation runs each child of the set operation concurrently (bounded by the
+// configured concurrency limit), each publishing into a per-child stream obtained from
+// the reducer; once all children complete, the reducer combines and publishes the result.
+func (cl *ConcurrentLookupSubjects) lookupSetOperation(
+	ctx context.Context,
+	req ValidatedLookupSubjectsRequest,
+	so *core.SetOperation,
+	reducer lookupSubjectsReducer,
+) error {
+	cancelCtx, checkCancel := context.WithCancel(ctx)
+	defer checkCancel()
+
+	g, subCtx := errgroup.WithContext(cancelCtx)
+	g.SetLimit(int(cl.concurrencyLimit))
+
+	for index, childOneof := range so.Child {
+		stream := reducer.ForIndex(subCtx, index)
+
+		switch child := childOneof.ChildType.(type) {
+		case *core.SetOperation_Child_XThis:
+			return errors.New("use of _this is unsupported; please rewrite your schema")
+
+		case *core.SetOperation_Child_ComputedUserset:
+			g.Go(func() error {
+				return cl.lookupViaComputed(subCtx, req, stream, child.ComputedUserset)
+			})
+
+		case *core.SetOperation_Child_UsersetRewrite:
+			g.Go(func() error {
+				return cl.lookupViaRewrite(subCtx, req, stream, child.UsersetRewrite)
+			})
+
+		case *core.SetOperation_Child_TupleToUserset:
+			g.Go(func() error {
+				return lookupViaTupleToUserset(subCtx, cl, req, stream, child.TupleToUserset)
+			})
+
+		case *core.SetOperation_Child_FunctionedTupleToUserset:
+			// Functioned TTUs select their behavior based on the function: ANY behaves like
+			// a plain TTU, while ALL requires intersection across the tupleset branches.
+			switch child.FunctionedTupleToUserset.Function {
+			case core.FunctionedTupleToUserset_FUNCTION_ANY:
+				g.Go(func() error {
+					return lookupViaTupleToUserset(subCtx, cl, req, stream, child.FunctionedTupleToUserset)
+				})
+
+			case core.FunctionedTupleToUserset_FUNCTION_ALL:
+				g.Go(func() error {
+					return lookupViaIntersectionTupleToUserset(subCtx, cl, req, stream, child.FunctionedTupleToUserset)
+				})
+
+			default:
+				return spiceerrors.MustBugf("unknown function in lookup subjects: %v", child.FunctionedTupleToUserset.Function)
+			}
+
+		case *core.SetOperation_Child_XNil:
+			// Purposely do nothing.
+			continue
+
+		default:
+			return spiceerrors.MustBugf("unknown set operation child `%T` in lookup subjects", child)
+		}
+	}
+
+	// Wait for all dispatched operations to complete.
+	if err := g.Wait(); err != nil {
+		return err
+	}
+
+	return reducer.CompletedChildOperations(ctx)
+}
+
+// dispatchTo recursively dispatches LookupSubjects for each collected subject type, using
+// the found subjects as the next step's resources. Results are mapped back through the
+// originating relationships (applying any relationship caveats) before being published
+// to the parent stream.
+func (cl *ConcurrentLookupSubjects) dispatchTo(
+	ctx context.Context,
+	parentRequest ValidatedLookupSubjectsRequest,
+	toDispatchByType *datasets.SubjectByTypeSet,
+	relationshipsBySubjectONR *mapz.MultiMap[tuple.ObjectAndRelation, tuple.Relationship],
+	parentStream dispatch.LookupSubjectsStream,
+) error {
+	if toDispatchByType.IsEmpty() {
+		return nil
+	}
+
+	cancelCtx, checkCancel := context.WithCancel(ctx)
+	defer checkCancel()
+
+	g, subCtx := errgroup.WithContext(cancelCtx)
+	g.SetLimit(int(cl.concurrencyLimit))
+
+	toDispatchByType.ForEachType(func(resourceType *core.RelationReference, foundSubjects datasets.SubjectSet) {
+		// The found subjects of this type become the resource IDs of the sub-request.
+		slice := foundSubjects.AsSlice()
+		resourceIds := make([]string, 0, len(slice))
+		for _, foundSubject := range slice {
+			resourceIds = append(resourceIds, foundSubject.SubjectId)
+		}
+
+		stream := &dispatch.WrappedDispatchStream[*v1.DispatchLookupSubjectsResponse]{
+			Stream: parentStream,
+			Ctx:    subCtx,
+			Processor: func(result *v1.DispatchLookupSubjectsResponse) (*v1.DispatchLookupSubjectsResponse, bool, error) {
+				// For any found subjects, map them through their associated starting resources, to apply any caveats that were
+				// only those resources' relationships.
+				//
+				// For example, given relationships which formed the dispatch:
+				// - document:firstdoc#viewer@group:group1#member
+				// - document:firstdoc#viewer@group:group2#member[somecaveat]
+				//
+				// And results:
+				// - group1 => {user:tom, user:sarah}
+				// - group2 => {user:tom, user:fred}
+				//
+				// This will produce:
+				// - firstdoc => {user:tom, user:sarah, user:fred[somecaveat]}
+				//
+				mappedFoundSubjects := make(map[string]*v1.FoundSubjects)
+				for childResourceID, foundSubjects := range result.FoundSubjectsByResourceId {
+					subjectKey := tuple.ONR(resourceType.Namespace, childResourceID, resourceType.Relation)
+					relationships, _ := relationshipsBySubjectONR.Get(subjectKey)
+					if len(relationships) == 0 {
+						return nil, false, fmt.Errorf("missing relationships for subject key %v; please report this error", subjectKey)
+					}
+
+					for _, relationship := range relationships {
+						existing := mappedFoundSubjects[relationship.Resource.ObjectID]
+
+						// If the relationship has no caveat, simply map the resource ID.
+						if relationship.OptionalCaveat == nil {
+							combined, err := combineFoundSubjects(existing, foundSubjects)
+							if err != nil {
+								return nil, false, fmt.Errorf("could not combine caveat-less subjects: %w", err)
+							}
+							mappedFoundSubjects[relationship.Resource.ObjectID] = combined
+							continue
+						}
+
+						// Otherwise, apply the caveat to all found subjects for that resource and map to the resource ID.
+						foundSubjectSet := datasets.NewSubjectSet()
+						err := foundSubjectSet.UnionWith(foundSubjects.FoundSubjects)
+						if err != nil {
+							return nil, false, fmt.Errorf("could not combine subject sets: %w", err)
+						}
+
+						combined, err := combineFoundSubjects(
+							existing,
+							foundSubjectSet.WithParentCaveatExpression(wrapCaveat(relationship.OptionalCaveat)).AsFoundSubjects(),
+						)
+						if err != nil {
+							return nil, false, fmt.Errorf("could not combine caveated subjects: %w", err)
+						}
+
+						mappedFoundSubjects[relationship.Resource.ObjectID] = combined
+					}
+				}
+
+				return &v1.DispatchLookupSubjectsResponse{
+					FoundSubjectsByResourceId: mappedFoundSubjects,
+					Metadata:                  addCallToResponseMetadata(result.Metadata),
+				}, true, nil
+			},
+		}
+
+		// Dispatch the found subjects as the resources of the next step.
+		slicez.ForEachChunk(resourceIds, cl.dispatchChunkSize, func(resourceIdChunk []string) {
+			g.Go(func() error {
+				return cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{
+					ResourceRelation: resourceType,
+					ResourceIds:      resourceIdChunk,
+					SubjectRelation:  parentRequest.SubjectRelation,
+					Metadata: &v1.ResolverMeta{
+						AtRevision:     parentRequest.Revision.String(),
+						DepthRemaining: parentRequest.Metadata.DepthRemaining - 1,
+					},
+				}, stream)
+			})
+		})
+	})
+
+	return g.Wait()
+}
+
+// combineFoundSubjects concatenates the subjects of toAdd onto those of existing.
+// A nil existing returns toAdd as-is; a nil toAdd is an error since callers always
+// supply newly found subjects.
+func combineFoundSubjects(existing *v1.FoundSubjects, toAdd *v1.FoundSubjects) (*v1.FoundSubjects, error) {
+	if existing == nil {
+		return toAdd, nil
+	}
+
+	if toAdd == nil {
+		return nil, fmt.Errorf("toAdd FoundSubject cannot be nil")
+	}
+
+	return &v1.FoundSubjects{
+		FoundSubjects: append(existing.FoundSubjects, toAdd.FoundSubjects...),
+	}, nil
+}
+
+// lookupSubjectsReducer combines the results of a set operation's child branches
+// (union, intersection or exclusion) into the final response for the parent stream.
+type lookupSubjectsReducer interface {
+	// ForIndex returns the stream into which the child at the given set-operation
+	// index should publish its results.
+	ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream
+	// CompletedChildOperations is invoked after all children have finished, to
+	// reduce the collected results and publish them to the parent stream.
+	CompletedChildOperations(ctx context.Context) error
+}
+
+// Union
+//
+// lookupSubjectsUnion implements lookupSubjectsReducer by collecting each child
+// branch's results and publishing their union to the parent stream.
+type lookupSubjectsUnion struct {
+	parentStream dispatch.LookupSubjectsStream
+	collectors   map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]
+}
+
+// newLookupSubjectsUnion creates a union reducer publishing to the given parent stream.
+func newLookupSubjectsUnion(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsUnion {
+	return &lookupSubjectsUnion{
+		parentStream: parentStream,
+		collectors:   map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{},
+	}
+}
+
+// ForIndex returns a collecting stream for the child at the given index.
+func (lsu *lookupSubjectsUnion) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream {
+	collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx)
+	lsu.collectors[setOperationIndex] = collector
+	return collector
+}
+
+// CompletedChildOperations unions all collected child results and publishes them,
+// or publishes nothing if no subjects were found.
+func (lsu *lookupSubjectsUnion) CompletedChildOperations(ctx context.Context) error {
+	foundSubjects := datasets.NewSubjectSetByResourceID()
+	metadata := emptyMetadata
+
+	for index := 0; index < len(lsu.collectors); index++ {
+		collector, ok := lsu.collectors[index]
+		if !ok {
+			return fmt.Errorf("missing collector for index %d", index)
+		}
+
+		for _, result := range collector.Results() {
+			metadata = combineResponseMetadata(ctx, metadata, result.Metadata)
+			if err := foundSubjects.UnionWith(result.FoundSubjectsByResourceId); err != nil {
+				return fmt.Errorf("failed to UnionWith under lookupSubjectsUnion: %w", err)
+			}
+		}
+	}
+
+	if foundSubjects.IsEmpty() {
+		return nil
+	}
+
+	return lsu.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{
+		FoundSubjectsByResourceId: foundSubjects.AsMap(),
+		Metadata:                  metadata,
+	})
+}
+
+// Intersection
+//
+// lookupSubjectsIntersection implements lookupSubjectsReducer by collecting each child
+// branch's results and publishing their intersection to the parent stream.
+type lookupSubjectsIntersection struct {
+	parentStream dispatch.LookupSubjectsStream
+	collectors   map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]
+}
+
+// newLookupSubjectsIntersection creates an intersection reducer publishing to the given parent stream.
+func newLookupSubjectsIntersection(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsIntersection {
+	return &lookupSubjectsIntersection{
+		parentStream: parentStream,
+		collectors:   map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{},
+	}
+}
+
+// ForIndex returns a collecting stream for the child at the given index.
+func (lsi *lookupSubjectsIntersection) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream {
+	collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx)
+	lsi.collectors[setOperationIndex] = collector
+	return collector
+}
+
+// CompletedChildOperations folds the children's results together via
+// IntersectionDifference (in child order), short-circuiting as soon as the
+// running intersection becomes empty.
+func (lsi *lookupSubjectsIntersection) CompletedChildOperations(ctx context.Context) error {
+	var foundSubjects datasets.SubjectSetByResourceID
+	metadata := emptyMetadata
+
+	for index := 0; index < len(lsi.collectors); index++ {
+		collector, ok := lsi.collectors[index]
+		if !ok {
+			return fmt.Errorf("missing collector for index %d", index)
+		}
+
+		results := datasets.NewSubjectSetByResourceID()
+		for _, result := range collector.Results() {
+			metadata = combineResponseMetadata(ctx, metadata, result.Metadata)
+			if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil {
+				return fmt.Errorf("failed to UnionWith under lookupSubjectsIntersection: %w", err)
+			}
+		}
+
+		// The first child seeds the running set; subsequent children are intersected into it.
+		if index == 0 {
+			foundSubjects = results
+		} else {
+			err := foundSubjects.IntersectionDifference(results)
+			if err != nil {
+				return err
+			}
+
+			if foundSubjects.IsEmpty() {
+				return nil
+			}
+		}
+	}
+
+	return lsi.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{
+		FoundSubjectsByResourceId: foundSubjects.AsMap(),
+		Metadata:                  metadata,
+	})
+}
+
+// Exclusion
+//
+// lookupSubjectsExclusion implements lookupSubjectsReducer by subtracting every
+// subsequent child branch's results from the first child's results.
+type lookupSubjectsExclusion struct {
+	parentStream dispatch.LookupSubjectsStream
+	collectors   map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]
+}
+
+// newLookupSubjectsExclusion creates an exclusion reducer publishing to the given parent stream.
+func newLookupSubjectsExclusion(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsExclusion {
+	return &lookupSubjectsExclusion{
+		parentStream: parentStream,
+		collectors:   map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{},
+	}
+}
+
+// ForIndex returns a collecting stream for the child at the given index.
+func (lse *lookupSubjectsExclusion) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream {
+	collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx)
+	lse.collectors[setOperationIndex] = collector
+	return collector
+}
+
+// CompletedChildOperations subtracts each subsequent child's subjects from the first
+// child's subjects, short-circuiting once the running set becomes empty.
+func (lse *lookupSubjectsExclusion) CompletedChildOperations(ctx context.Context) error {
+	var foundSubjects datasets.SubjectSetByResourceID
+	metadata := emptyMetadata
+
+	for index := 0; index < len(lse.collectors); index++ {
+		collector := lse.collectors[index]
+		results := datasets.NewSubjectSetByResourceID()
+		for _, result := range collector.Results() {
+			metadata = combineResponseMetadata(ctx, metadata, result.Metadata)
+			if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil {
+				return fmt.Errorf("failed to UnionWith under lookupSubjectsExclusion: %w", err)
+			}
+		}
+
+		// The first child seeds the base set; subsequent children are subtracted from it.
+		if index == 0 {
+			foundSubjects = results
+		} else {
+			foundSubjects.SubtractAll(results)
+			if foundSubjects.IsEmpty() {
+				return nil
+			}
+		}
+	}
+
+	return lse.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{
+		FoundSubjectsByResourceId: foundSubjects.AsMap(),
+		Metadata:                  metadata,
+	})
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go b/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go
new file mode 100644
index 0000000..f04ee6f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go
@@ -0,0 +1,334 @@
+package graph
+
+import (
+ "context"
+ "strconv"
+ "sync"
+
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/graph/computed"
+ "github.com/authzed/spicedb/internal/graph/hints"
+ "github.com/authzed/spicedb/internal/taskrunner"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// runCheckerAndDispatch runs the dispatch and checker for a lookup resources call, and publishes
+// the results to the parent stream. This function is responsible for checking the found
+// resources to filter them, and then dispatching those that remain.
+func runCheckerAndDispatch(
+ ctx context.Context,
+ parentReq ValidatedLookupResources2Request,
+ foundResources dispatchableResourcesSubjectMap2,
+ ci cursorInformation,
+ parentStream dispatch.LookupResources2Stream,
+ newSubjectType *core.RelationReference,
+ filteredSubjectIDs []string,
+ entrypoint schema.ReachabilityEntrypoint,
+ lrDispatcher dispatch.LookupResources2,
+ checkDispatcher dispatch.Check,
+ caveatTypeSet *caveattypes.TypeSet,
+ concurrencyLimit uint16,
+ dispatchChunkSize uint16,
+) error {
+ // Only allow max one dispatcher and one checker to run concurrently.
+ concurrencyLimit = min(concurrencyLimit, 2)
+
+ // The current check index is carried in a custom cursor section so that a
+ // resumed request continues checking from where the previous one stopped.
+ currentCheckIndex, err := ci.integerSectionValue()
+ if err != nil {
+ return err
+ }
+
+ rdc := &checkAndDispatchRunner{
+ parentRequest: parentReq,
+ foundResources: foundResources,
+ ci: ci,
+ parentStream: parentStream,
+ newSubjectType: newSubjectType,
+ filteredSubjectIDs: filteredSubjectIDs,
+ currentCheckIndex: currentCheckIndex,
+ entrypoint: entrypoint,
+ lrDispatcher: lrDispatcher,
+ checkDispatcher: checkDispatcher,
+ taskrunner: taskrunner.NewTaskRunner(ctx, concurrencyLimit),
+ lock: &sync.Mutex{},
+ dispatchChunkSize: dispatchChunkSize,
+ caveatTypeSet: caveatTypeSet,
+ }
+
+ return rdc.runAndWait()
+}
+
+// checkAndDispatchRunner holds the state shared between the checker and dispatcher
+// tasks scheduled on the task runner for a single lookup-resources call.
+type checkAndDispatchRunner struct {
+ parentRequest ValidatedLookupResources2Request
+ foundResources dispatchableResourcesSubjectMap2
+ parentStream dispatch.LookupResources2Stream
+ newSubjectType *core.RelationReference
+ entrypoint schema.ReachabilityEntrypoint
+ lrDispatcher dispatch.LookupResources2
+ checkDispatcher dispatch.Check
+ dispatchChunkSize uint16
+ caveatTypeSet *caveattypes.TypeSet
+ filteredSubjectIDs []string
+
+ // currentCheckIndex is the offset into filteredSubjectIDs at which checking resumes.
+ currentCheckIndex int
+ taskrunner *taskrunner.TaskRunner
+
+ lock *sync.Mutex
+ ci cursorInformation // GUARDED_BY(lock)
+}
+
+// runAndWait kicks off the first check chunk at the current cursor index and blocks
+// until all scheduled checker/dispatcher tasks complete (or one returns an error).
+func (rdc *checkAndDispatchRunner) runAndWait() error {
+ // Kick off a check at the current cursor, to filter a portion of the initial results set.
+ rdc.taskrunner.Schedule(func(ctx context.Context) error {
+ return rdc.runChecker(ctx, rdc.currentCheckIndex)
+ })
+
+ return rdc.taskrunner.Wait()
+}
+
+// runChecker checks a chunk (up to dispatchChunkSize) of the filtered subject IDs,
+// starting at startingIndex, against the containing permission. Members (including
+// caveated members) are scheduled for dispatch, and the next chunk's check is
+// scheduled if more subject IDs remain.
+func (rdc *checkAndDispatchRunner) runChecker(ctx context.Context, startingIndex int) error {
+ // Bail out early if the parent's publishing limit has already been exhausted.
+ rdc.lock.Lock()
+ if rdc.ci.limits.hasExhaustedLimit() {
+ rdc.lock.Unlock()
+ return nil
+ }
+ rdc.lock.Unlock()
+
+ endingIndex := min(startingIndex+int(rdc.dispatchChunkSize), len(rdc.filteredSubjectIDs))
+ resourceIDsToCheck := rdc.filteredSubjectIDs[startingIndex:endingIndex]
+ if len(resourceIDsToCheck) == 0 {
+ return nil
+ }
+
+ ctx, span := tracer.Start(ctx, "lr2Check", trace.WithAttributes(
+ attribute.Int("resource-id-count", len(resourceIDsToCheck)),
+ ))
+ defer span.End()
+
+ // Build a hint per resource so the check can reuse the fact that these
+ // resources were reached via the current entrypoint.
+ checkHints := make([]*v1.CheckHint, 0, len(resourceIDsToCheck))
+ for _, resourceID := range resourceIDsToCheck {
+ checkHint, err := hints.HintForEntrypoint(
+ rdc.entrypoint,
+ resourceID,
+ tuple.FromCoreObjectAndRelation(rdc.parentRequest.TerminalSubject),
+ &v1.ResourceCheckResult{
+ Membership: v1.ResourceCheckResult_MEMBER,
+ })
+ if err != nil {
+ return err
+ }
+ checkHints = append(checkHints, checkHint)
+ }
+
+ // NOTE: we are checking the containing permission here, *not* the target relation, as
+ // the goal is to shear for the containing permission.
+ resultsByResourceID, checkMetadata, _, err := computed.ComputeBulkCheck(ctx, rdc.checkDispatcher, rdc.caveatTypeSet, computed.CheckParameters{
+ ResourceType: tuple.FromCoreRelationReference(rdc.newSubjectType),
+ Subject: tuple.FromCoreObjectAndRelation(rdc.parentRequest.TerminalSubject),
+ CaveatContext: rdc.parentRequest.Context.AsMap(),
+ AtRevision: rdc.parentRequest.Revision,
+ MaximumDepth: rdc.parentRequest.Metadata.DepthRemaining - 1,
+ DebugOption: computed.NoDebugging,
+ CheckHints: checkHints,
+ }, resourceIDsToCheck, rdc.dispatchChunkSize)
+ if err != nil {
+ return err
+ }
+
+ // Clone the found-resources map so per-result missing caveat context can be recorded.
+ adjustedResources := rdc.foundResources.cloneAsMutable()
+
+ // Dispatch any resources that are visible.
+ resourceIDToDispatch := make([]string, 0, len(resourceIDsToCheck))
+ for _, resourceID := range resourceIDsToCheck {
+ result, ok := resultsByResourceID[resourceID]
+ if !ok {
+ continue
+ }
+
+ switch result.Membership {
+ case v1.ResourceCheckResult_MEMBER:
+ fallthrough
+
+ case v1.ResourceCheckResult_CAVEATED_MEMBER:
+ // Record any additional caveats missing from the check.
+ adjustedResources.withAdditionalMissingContextForDispatchedResourceID(resourceID, result.MissingExprFields)
+ resourceIDToDispatch = append(resourceIDToDispatch, resourceID)
+
+ case v1.ResourceCheckResult_NOT_MEMBER:
+ // Skip.
+ continue
+
+ default:
+ return spiceerrors.MustBugf("unexpected result from check: %v", result.Membership)
+ }
+ }
+
+ if len(resourceIDToDispatch) > 0 {
+ // Schedule a dispatch of those resources.
+ rdc.taskrunner.Schedule(func(ctx context.Context) error {
+ return rdc.runDispatch(ctx, resourceIDToDispatch, adjustedResources.asReadOnly(), checkMetadata, startingIndex)
+ })
+ }
+
+ // Start the next check chunk (if applicable).
+ nextIndex := startingIndex + len(resourceIDsToCheck)
+ if nextIndex < len(rdc.filteredSubjectIDs) {
+ rdc.taskrunner.Schedule(func(ctx context.Context) error {
+ return rdc.runChecker(ctx, nextIndex)
+ })
+ }
+
+ return nil
+}
+
+// runDispatch dispatches a lookup for the given (already checked) resource IDs as
+// subjects of the parent resource type, publishing any results to the parent stream.
+// startingIndex is the chunk offset and is encoded into the outgoing cursor.
+func (rdc *checkAndDispatchRunner) runDispatch(
+ ctx context.Context,
+ resourceIDsToDispatch []string,
+ adjustedResources dispatchableResourcesSubjectMap2,
+ checkMetadata *v1.ResponseMeta,
+ startingIndex int,
+) error {
+ // Bail out early if the parent's publishing limit has already been exhausted.
+ rdc.lock.Lock()
+ if rdc.ci.limits.hasExhaustedLimit() {
+ rdc.lock.Unlock()
+ return nil
+ }
+ rdc.lock.Unlock()
+
+ ctx, span := tracer.Start(ctx, "lr2Dispatch", trace.WithAttributes(
+ attribute.Int("resource-id-count", len(resourceIDsToDispatch)),
+ ))
+ defer span.End()
+
+ // NOTE: Since we extracted a custom section from the cursor at the beginning of this run, we have to add
+ // the starting index to the cursor to ensure that the next run starts from the correct place, and we have
+ // to use the *updated* cursor below on the dispatch.
+ updatedCi, err := rdc.ci.withOutgoingSection(strconv.Itoa(startingIndex))
+ if err != nil {
+ return err
+ }
+ responsePartialCursor := updatedCi.responsePartialCursor()
+
+ // Dispatch to the parent resource type and publish any results found.
+ // The check metadata is attached only to the first published result (see
+ // publishResultToParentStream), so it is not counted more than once.
+ isFirstPublishCall := true
+
+ wrappedStream := dispatch.NewHandlingDispatchStream(ctx, func(result *v1.DispatchLookupResources2Response) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+
+ if err := publishResultToParentStream(ctx, result, rdc.ci, responsePartialCursor, adjustedResources, nil, isFirstPublishCall, checkMetadata, rdc.parentStream); err != nil {
+ return err
+ }
+ isFirstPublishCall = false
+ return nil
+ })
+
+ return rdc.lrDispatcher.DispatchLookupResources2(&v1.DispatchLookupResources2Request{
+ ResourceRelation: rdc.parentRequest.ResourceRelation,
+ SubjectRelation: rdc.newSubjectType,
+ SubjectIds: resourceIDsToDispatch,
+ TerminalSubject: rdc.parentRequest.TerminalSubject,
+ Metadata: &v1.ResolverMeta{
+ AtRevision: rdc.parentRequest.Revision.String(),
+ DepthRemaining: rdc.parentRequest.Metadata.DepthRemaining - 1,
+ },
+ OptionalCursor: updatedCi.currentCursor,
+ OptionalLimit: rdc.ci.limits.currentLimit,
+ Context: rdc.parentRequest.Context,
+ }, wrappedStream)
+}
+
+// unfilteredLookupResourcesDispatchStreamForEntrypoint creates a new dispatch stream that wraps
+// the parent stream, and publishes the results of the lookup resources call to the parent stream,
+// mapped via foundResources. emptyMetadata is passed as the additional metadata since no
+// pre-filtering check was performed for this entrypoint.
+func unfilteredLookupResourcesDispatchStreamForEntrypoint(
+ ctx context.Context,
+ foundResources dispatchableResourcesSubjectMap2,
+ parentStream dispatch.LookupResources2Stream,
+ ci cursorInformation,
+) dispatch.LookupResources2Stream {
+ isFirstPublishCall := true
+
+ wrappedStream := dispatch.NewHandlingDispatchStream(ctx, func(result *v1.DispatchLookupResources2Response) error {
+ // Non-blocking cancellation check before publishing each result.
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+
+ default:
+ }
+
+ if err := publishResultToParentStream(ctx, result, ci, ci.responsePartialCursor(), foundResources, nil, isFirstPublishCall, emptyMetadata, parentStream); err != nil {
+ return err
+ }
+ isFirstPublishCall = false
+ return nil
+ })
+
+ return wrappedStream
+}
+
+// publishResultToParentStream publishes the result of a lookup resources call to the parent stream,
+// mapped via foundResources. additionalMetadata is combined into the response only on the
+// first publish call; later calls only add the additional depth required.
+func publishResultToParentStream(
+ ctx context.Context,
+ result *v1.DispatchLookupResources2Response,
+ ci cursorInformation,
+ responseCursor *v1.Cursor,
+ foundResources dispatchableResourcesSubjectMap2,
+ additionalMissingContext []string,
+ isFirstPublishCall bool,
+ additionalMetadata *v1.ResponseMeta,
+ parentStream dispatch.LookupResources2Stream,
+) error {
+ // Map the found resources via the subject+resources used for dispatching, to determine
+ // if any need to be made conditional due to caveats.
+ mappedResource, err := foundResources.mapPossibleResource(result.Resource)
+ if err != nil {
+ return err
+ }
+
+ // If false, the publish limit has been reached and the result is dropped.
+ if !ci.limits.prepareForPublishing() {
+ return nil
+ }
+
+ // The cursor for the response is that of the parent response + the cursor from the result itself.
+ afterResponseCursor, err := combineCursors(
+ responseCursor,
+ result.AfterResponseCursor,
+ )
+ if err != nil {
+ return err
+ }
+
+ metadata := result.Metadata
+ if isFirstPublishCall {
+ metadata = addCallToResponseMetadata(metadata)
+ metadata = combineResponseMetadata(ctx, metadata, additionalMetadata)
+ } else {
+ metadata = addAdditionalDepthRequired(metadata)
+ }
+
+ // Union the missing context parameters from the mapping, the dispatched result,
+ // and any caller-supplied additional missing context.
+ missingContextParameters := mapz.NewSet(mappedResource.MissingContextParams...)
+ missingContextParameters.Extend(result.Resource.MissingContextParams)
+ missingContextParameters.Extend(additionalMissingContext)
+
+ mappedResource.MissingContextParams = missingContextParameters.AsSlice()
+
+ resp := &v1.DispatchLookupResources2Response{
+ Resource: mappedResource,
+ Metadata: metadata,
+ AfterResponseCursor: afterResponseCursor,
+ }
+
+ return parentStream.Publish(resp)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go b/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go
new file mode 100644
index 0000000..8ab20f4
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go
@@ -0,0 +1,243 @@
+package graph
+
+import (
+ "github.com/authzed/spicedb/internal/caveats"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// Local aliases for the caveat expression combinators, for brevity below.
+var (
+ caveatOr = caveats.Or
+ caveatAnd = caveats.And
+ caveatSub = caveats.Subtract
+ wrapCaveat = caveats.CaveatAsExpr
+)
+
+// CheckResultsMap defines a type that is a map from resource ID to ResourceCheckResult.
+// This must match that defined in the DispatchCheckResponse for the `results_by_resource_id`
+// field.
+type CheckResultsMap map[string]*v1.ResourceCheckResult
+
+// NewMembershipSet constructs a new helper set for tracking the membership found for a dispatched
+// check request.
+func NewMembershipSet() *MembershipSet {
+ return &MembershipSet{
+ hasDeterminedMember: false,
+ membersByID: map[string]*core.CaveatExpression{},
+ }
+}
+
+// membershipSetFromMap builds a MembershipSet from a map of resource ID to caveat
+// expression, where a nil expression means an unconditional member.
+func membershipSetFromMap(mp map[string]*core.CaveatExpression) *MembershipSet {
+ ms := NewMembershipSet()
+ for resourceID, result := range mp {
+ ms.addMember(resourceID, result)
+ }
+ return ms
+}
+
+// MembershipSet is a helper set that tracks the membership results for a dispatched Check
+// request, including tracking of the caveats associated with found resource IDs.
+type MembershipSet struct {
+ // membersByID maps resource ID to its caveat expression; nil means unconditional.
+ membersByID map[string]*core.CaveatExpression
+ // hasDeterminedMember is true if at least one member is unconditional (non-caveated).
+ hasDeterminedMember bool
+}
+
+// AddDirectMember adds a resource ID that was *directly* found for the dispatched check, with
+// optional caveat found on the relationship.
+func (ms *MembershipSet) AddDirectMember(resourceID string, caveat *core.ContextualizedCaveat) {
+ ms.addMember(resourceID, wrapCaveat(caveat))
+}
+
+// AddMemberViaRelationship adds a resource ID that was found via another relationship, such
+// as the result of an arrow operation. The `parentRelationship` is the relationship that was
+// followed before the resource itself was resolved. This method will properly apply the caveat(s)
+// from both the parent relationship and the resource's result itself, assuming either have a caveat
+// associated.
+func (ms *MembershipSet) AddMemberViaRelationship(
+ resourceID string,
+ resourceCaveatExpression *core.CaveatExpression,
+ parentRelationship tuple.Relationship,
+) {
+ ms.AddMemberWithParentCaveat(resourceID, resourceCaveatExpression, parentRelationship.OptionalCaveat)
+}
+
+// AddMemberWithParentCaveat adds the given resource ID as a member with the parent caveat
+// combined via intersection with the resource's caveat. The parent caveat may be nil.
+func (ms *MembershipSet) AddMemberWithParentCaveat(
+ resourceID string,
+ resourceCaveatExpression *core.CaveatExpression,
+ parentCaveat *core.ContextualizedCaveat,
+) {
+ // Both must hold for membership, so the expressions are AND-ed together.
+ intersection := caveatAnd(wrapCaveat(parentCaveat), resourceCaveatExpression)
+ ms.addMember(resourceID, intersection)
+}
+
+// AddMemberWithOptionalCaveats adds the given resource ID as a member with the optional caveats combined
+// via intersection.
+func (ms *MembershipSet) AddMemberWithOptionalCaveats(
+ resourceID string,
+ caveats []*core.CaveatExpression,
+) {
+ // No caveats means the member is unconditional (nil expression).
+ if len(caveats) == 0 {
+ ms.addMember(resourceID, nil)
+ return
+ }
+
+ intersection := caveats[0]
+ for _, caveat := range caveats[1:] {
+ intersection = caveatAnd(intersection, caveat)
+ }
+
+ ms.addMember(resourceID, intersection)
+}
+
+// addMember records the resource ID with the given caveat expression (nil meaning
+// unconditional membership), unioning caveat expressions when the resource is
+// already present in the set.
+func (ms *MembershipSet) addMember(resourceID string, caveatExpr *core.CaveatExpression) {
+ existing, ok := ms.membersByID[resourceID]
+ if !ok {
+ ms.hasDeterminedMember = ms.hasDeterminedMember || caveatExpr == nil
+ ms.membersByID[resourceID] = caveatExpr
+ return
+ }
+
+ // If a determined membership result has already been found (i.e. there is no caveat),
+ // then nothing more to do.
+ if existing == nil {
+ return
+ }
+
+ // If the new caveat expression is nil, then we are adding a determined result.
+ if caveatExpr == nil {
+ ms.hasDeterminedMember = true
+ ms.membersByID[resourceID] = nil
+ return
+ }
+
+ // Otherwise, the caveats get unioned together.
+ ms.membersByID[resourceID] = caveatOr(existing, caveatExpr)
+}
+
+// UnionWith combines the results found in the given map with the members of this set.
+// The changes are made in-place.
+func (ms *MembershipSet) UnionWith(resultsMap CheckResultsMap) {
+ for resourceID, details := range resultsMap {
+ if details.Membership != v1.ResourceCheckResult_NOT_MEMBER {
+ ms.addMember(resourceID, details.Expression)
+ }
+ }
+}
+
+// IntersectWith intersects the results found in the given map with the members of this set.
+// The changes are made in-place.
+func (ms *MembershipSet) IntersectWith(resultsMap CheckResultsMap) {
+ // First drop any members not present (or explicitly NOT_MEMBER) in the incoming map.
+ for resourceID := range ms.membersByID {
+ if details, ok := resultsMap[resourceID]; !ok || details.Membership == v1.ResourceCheckResult_NOT_MEMBER {
+ delete(ms.membersByID, resourceID)
+ }
+ }
+
+ // Recompute hasDeterminedMember: a surviving member is determined only when it is
+ // unconditional on *both* sides; otherwise the caveats are AND-ed together.
+ ms.hasDeterminedMember = false
+ for resourceID, details := range resultsMap {
+ existing, ok := ms.membersByID[resourceID]
+ if !ok || details.Membership == v1.ResourceCheckResult_NOT_MEMBER {
+ continue
+ }
+ if existing == nil && details.Expression == nil {
+ ms.hasDeterminedMember = true
+ continue
+ }
+
+ ms.membersByID[resourceID] = caveatAnd(existing, details.Expression)
+ }
+}
+
+// Subtract subtracts the results found in the given map with the members of this set.
+// The changes are made in-place. hasDeterminedMember is recomputed from the members
+// that survive subtraction. (Deleting map entries during range is safe in Go.)
+func (ms *MembershipSet) Subtract(resultsMap CheckResultsMap) {
+ ms.hasDeterminedMember = false
+ for resourceID, expression := range ms.membersByID {
+ if details, ok := resultsMap[resourceID]; ok && details.Membership != v1.ResourceCheckResult_NOT_MEMBER {
+ // If the incoming member has no caveat, then this removal is absolute.
+ if details.Expression == nil {
+ delete(ms.membersByID, resourceID)
+ continue
+ }
+
+ // Otherwise, the caveat expression gets combined with an intersection of the inversion
+ // of the expression.
+ ms.membersByID[resourceID] = caveatSub(expression, details.Expression)
+ } else {
+ // Untouched by the subtraction: an unconditional member keeps the set determined.
+ if expression == nil {
+ ms.hasDeterminedMember = true
+ }
+ }
+ }
+}
+
+// HasConcreteResourceID returns whether the resourceID was found in the set
+// and has no caveat attached. Safe to call on a nil receiver.
+func (ms *MembershipSet) HasConcreteResourceID(resourceID string) bool {
+ if ms == nil {
+ return false
+ }
+
+ found, ok := ms.membersByID[resourceID]
+ return ok && found == nil
+}
+
+// GetResourceID returns a bool indicating whether the resource is found in the set and the
+// associated caveat expression, if any. Safe to call on a nil receiver.
+func (ms *MembershipSet) GetResourceID(resourceID string) (bool, *core.CaveatExpression) {
+ if ms == nil {
+ return false, nil
+ }
+
+ caveat, ok := ms.membersByID[resourceID]
+ return ok, caveat
+}
+
+// Size returns the number of elements in the membership set. Safe to call on a nil receiver.
+func (ms *MembershipSet) Size() int {
+ if ms == nil {
+ return 0
+ }
+
+ return len(ms.membersByID)
+}
+
+// IsEmpty returns true if the set is empty. Safe to call on a nil receiver.
+func (ms *MembershipSet) IsEmpty() bool {
+ if ms == nil {
+ return true
+ }
+
+ return len(ms.membersByID) == 0
+}
+
+// HasDeterminedMember returns whether there exists at least one non-caveated member of the set.
+// Safe to call on a nil receiver.
+func (ms *MembershipSet) HasDeterminedMember() bool {
+ if ms == nil {
+ return false
+ }
+
+ return ms.hasDeterminedMember
+}
+
+// AsCheckResultsMap converts the membership set back into a CheckResultsMap for placement into
+// a DispatchCheckResult. Members with a caveat expression are marked CAVEATED_MEMBER;
+// those with a nil expression are marked MEMBER.
+func (ms *MembershipSet) AsCheckResultsMap() CheckResultsMap {
+ resultsMap := make(CheckResultsMap, len(ms.membersByID))
+ for resourceID, caveat := range ms.membersByID {
+ membership := v1.ResourceCheckResult_MEMBER
+ if caveat != nil {
+ membership = v1.ResourceCheckResult_CAVEATED_MEMBER
+ }
+
+ resultsMap[resourceID] = &v1.ResourceCheckResult{
+ Membership: membership,
+ Expression: caveat,
+ }
+ }
+
+ return resultsMap
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go b/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go
new file mode 100644
index 0000000..4e41955
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go
@@ -0,0 +1,248 @@
+package graph
+
+import (
+ "sort"
+ "sync"
+
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// syncONRSet is a mutex-guarded set of ObjectAndRelation keys, used to deduplicate
+// work across concurrent callers.
+type syncONRSet struct {
+ sync.Mutex
+ items map[string]struct{} // GUARDED_BY(Mutex)
+}
+
+// Add inserts the ONR into the set, returning true if it was not already present.
+// Safe for concurrent use.
+func (s *syncONRSet) Add(onr *core.ObjectAndRelation) bool {
+ key := tuple.StringONR(tuple.FromCoreObjectAndRelation(onr))
+ s.Lock()
+ _, existed := s.items[key]
+ if !existed {
+ s.items[key] = struct{}{}
+ }
+ s.Unlock()
+ return !existed
+}
+
+// NewSyncONRSet constructs an empty syncONRSet.
+func NewSyncONRSet() *syncONRSet {
+ return &syncONRSet{items: make(map[string]struct{})}
+}
+
+// resourcesSubjectMap2 is a multimap which tracks mappings from found resource IDs
+// to the subject IDs (may be more than one) for each, as well as whether the mapping
+// is conditional due to the use of a caveat on the relationship which formed the mapping.
+// The MultiMap is held by pointer, so value receivers still mutate shared state.
+type resourcesSubjectMap2 struct {
+ resourceType *core.RelationReference
+ resourcesAndSubjects *mapz.MultiMap[string, subjectInfo2]
+}
+
+// subjectInfo2 is the information about a subject contained in a resourcesSubjectMap2.
+// A non-empty missingContextParameters slice indicates a caveated (conditional) mapping.
+type subjectInfo2 struct {
+ subjectID string
+ missingContextParameters []string
+}
+
+// newResourcesSubjectMap2 constructs an empty map for the given resource type.
+func newResourcesSubjectMap2(resourceType *core.RelationReference) resourcesSubjectMap2 {
+ return resourcesSubjectMap2{
+ resourceType: resourceType,
+ resourcesAndSubjects: mapz.NewMultiMap[string, subjectInfo2](),
+ }
+}
+
+// newResourcesSubjectMap2WithCapacity constructs an empty map with a pre-sized backing multimap.
+func newResourcesSubjectMap2WithCapacity(resourceType *core.RelationReference, capacity uint32) resourcesSubjectMap2 {
+ return resourcesSubjectMap2{
+ resourceType: resourceType,
+ resourcesAndSubjects: mapz.NewMultiMapWithCap[string, subjectInfo2](capacity),
+ }
+}
+
+// subjectIDsToResourcesMap2 builds a map in which each subject ID maps to itself as a
+// found resource, with no caveats.
+func subjectIDsToResourcesMap2(resourceType *core.RelationReference, subjectIDs []string) resourcesSubjectMap2 {
+ rsm := newResourcesSubjectMap2(resourceType)
+ for _, subjectID := range subjectIDs {
+ rsm.addSubjectIDAsFoundResourceID(subjectID)
+ }
+ return rsm
+}
+
+// addRelationship adds the relationship to the resource subject map, recording a mapping from
+// the resource of the relationship to the subject, as well as whether the relationship was caveated.
+func (rsm resourcesSubjectMap2) addRelationship(rel tuple.Relationship, missingContextParameters []string) error {
+ spiceerrors.DebugAssert(func() bool {
+ return rel.Resource.ObjectType == rsm.resourceType.Namespace && rel.Resource.Relation == rsm.resourceType.Relation
+ }, "invalid relationship for addRelationship. expected: %v, found: %v", rsm.resourceType, rel.Resource)
+
+ spiceerrors.DebugAssert(func() bool {
+ return len(missingContextParameters) == 0 || rel.OptionalCaveat != nil
+ }, "missing context parameters must be empty if there is no caveat")
+
+ rsm.resourcesAndSubjects.Add(rel.Resource.ObjectID, subjectInfo2{rel.Subject.ObjectID, missingContextParameters})
+ return nil
+}
+
+// withAdditionalMissingContextForDispatchedResourceID adds additional missing context parameters
+// to the existing missing context parameters for the dispatched resource ID. Mutates the
+// shared multimap in place (the value receiver holds a pointer to it).
+func (rsm resourcesSubjectMap2) withAdditionalMissingContextForDispatchedResourceID(
+ resourceID string,
+ additionalMissingContext []string,
+) {
+ if len(additionalMissingContext) == 0 {
+ return
+ }
+
+ // Rebuild each subject entry with the extra missing context appended.
+ subjectInfo2s, _ := rsm.resourcesAndSubjects.Get(resourceID)
+ updatedInfos := make([]subjectInfo2, 0, len(subjectInfo2s))
+ for _, info := range subjectInfo2s {
+ info.missingContextParameters = append(info.missingContextParameters, additionalMissingContext...)
+ updatedInfos = append(updatedInfos, info)
+ }
+ rsm.resourcesAndSubjects.Set(resourceID, updatedInfos)
+}
+
+// addSubjectIDAsFoundResourceID adds a subject ID directly as a found subject for itself as the resource,
+// with no associated caveat.
+func (rsm resourcesSubjectMap2) addSubjectIDAsFoundResourceID(subjectID string) {
+ rsm.resourcesAndSubjects.Add(subjectID, subjectInfo2{subjectID, nil})
+}
+
+// asReadOnly returns a read-only dispatchableResourcesSubjectMap2 for dispatching for the
+// resources in this map (if any).
+func (rsm resourcesSubjectMap2) asReadOnly() dispatchableResourcesSubjectMap2 {
+ return dispatchableResourcesSubjectMap2{rsm}
+}
+
+// len returns the number of resource IDs in the map.
+func (rsm resourcesSubjectMap2) len() int {
+ return rsm.resourcesAndSubjects.Len()
+}
+
+// dispatchableResourcesSubjectMap2 is a read-only, frozen version of the resourcesSubjectMap2 that
+// can be used for mapping conditionals once calls have been dispatched. This is read-only due to
+// its use by concurrent callers.
+type dispatchableResourcesSubjectMap2 struct {
+ resourcesSubjectMap2
+}
+
+// len returns the number of resource IDs in the map.
+func (rsm dispatchableResourcesSubjectMap2) len() int {
+ return rsm.resourcesAndSubjects.Len()
+}
+
+// isEmpty returns true if the map holds no resource IDs.
+func (rsm dispatchableResourcesSubjectMap2) isEmpty() bool {
+ return rsm.resourcesAndSubjects.IsEmpty()
+}
+
+// resourceIDs returns the resource IDs in the map, in unspecified order.
+func (rsm dispatchableResourcesSubjectMap2) resourceIDs() []string {
+ return rsm.resourcesAndSubjects.Keys()
+}
+
+// filterSubjectIDsToDispatch returns the set of subject IDs that have not yet been
+// dispatched, by adding them to the dispatched set. IDs already present in the set
+// (i.e. already dispatched elsewhere) are filtered out.
+func (rsm dispatchableResourcesSubjectMap2) filterSubjectIDsToDispatch(dispatched *syncONRSet, dispatchSubjectType *core.RelationReference) []string {
+ resourceIDs := rsm.resourceIDs()
+ filtered := make([]string, 0, len(resourceIDs))
+ for _, resourceID := range resourceIDs {
+ // Add returns true only on first insertion, deduplicating across concurrent dispatches.
+ if dispatched.Add(&core.ObjectAndRelation{
+ Namespace: dispatchSubjectType.Namespace,
+ ObjectId: resourceID,
+ Relation: dispatchSubjectType.Relation,
+ }) {
+ filtered = append(filtered, resourceID)
+ }
+ }
+
+ return filtered
+}
+
+// cloneAsMutable returns a mutable clone of this dispatchableResourcesSubjectMap2.
+// The backing multimap is cloned, so mutations do not affect the original.
+func (rsm dispatchableResourcesSubjectMap2) cloneAsMutable() resourcesSubjectMap2 {
+ return resourcesSubjectMap2{
+ resourceType: rsm.resourceType,
+ resourcesAndSubjects: rsm.resourcesAndSubjects.Clone(),
+ }
+}
+
+// asPossibleResources converts the map into the PossibleResource list for a dispatch
+// response, sorted by resource ID for stability. A resource is marked as requiring a
+// check (via MissingContextParams) only when *all* of its incoming edges are caveated.
+func (rsm dispatchableResourcesSubjectMap2) asPossibleResources() []*v1.PossibleResource {
+ resources := make([]*v1.PossibleResource, 0, rsm.resourcesAndSubjects.Len())
+
+ // Sort for stability.
+ sortedResourceIds := rsm.resourcesAndSubjects.Keys()
+ sort.Strings(sortedResourceIds)
+
+ for _, resourceID := range sortedResourceIds {
+ subjectInfo2s, _ := rsm.resourcesAndSubjects.Get(resourceID)
+ subjectIDs := make([]string, 0, len(subjectInfo2s))
+ allCaveated := true
+ nonCaveatedSubjectIDs := make([]string, 0, len(subjectInfo2s))
+ missingContextParameters := mapz.NewSet[string]()
+
+ for _, info := range subjectInfo2s {
+ subjectIDs = append(subjectIDs, info.subjectID)
+ if len(info.missingContextParameters) == 0 {
+ allCaveated = false
+ nonCaveatedSubjectIDs = append(nonCaveatedSubjectIDs, info.subjectID)
+ } else {
+ missingContextParameters.Extend(info.missingContextParameters)
+ }
+ }
+
+ // Sort for stability.
+ sort.Strings(subjectIDs)
+
+ // If all the incoming edges are caveated, then the entire status has to be marked as a check
+ // is required. Otherwise, if there is at least *one* non-caveated incoming edge, then we can
+ // return the existing status as a short-circuit for those non-caveated found subjects.
+ if allCaveated {
+ resources = append(resources, &v1.PossibleResource{
+ ResourceId: resourceID,
+ ForSubjectIds: subjectIDs,
+ MissingContextParams: missingContextParameters.AsSlice(),
+ })
+ } else {
+ resources = append(resources, &v1.PossibleResource{
+ ResourceId: resourceID,
+ ForSubjectIds: nonCaveatedSubjectIDs,
+ })
+ }
+ }
+ return resources
+}
+
+// mapPossibleResource maps a PossibleResource returned by a child dispatch back to
+// the subject IDs in this map that caused the dispatch. If at least one mapped edge
+// is non-caveated, only those subjects are returned (no check needed); otherwise all
+// subjects are returned along with the union of the missing context parameters.
+func (rsm dispatchableResourcesSubjectMap2) mapPossibleResource(foundResource *v1.PossibleResource) (*v1.PossibleResource, error) {
+ forSubjectIDs := mapz.NewSet[string]()
+ nonCaveatedSubjectIDs := mapz.NewSet[string]()
+ missingContextParameters := mapz.NewSet[string]()
+
+ for _, forSubjectID := range foundResource.ForSubjectIds {
+ // Map from the incoming subject ID to the subject ID(s) that caused the dispatch.
+ infos, ok := rsm.resourcesAndSubjects.Get(forSubjectID)
+ if !ok {
+ return nil, spiceerrors.MustBugf("missing for subject ID")
+ }
+
+ for _, info := range infos {
+ forSubjectIDs.Insert(info.subjectID)
+ if len(info.missingContextParameters) == 0 {
+ nonCaveatedSubjectIDs.Insert(info.subjectID)
+ } else {
+ missingContextParameters.Extend(info.missingContextParameters)
+ }
+ }
+ }
+
+ // If there are some non-caveated IDs, return those and mark as the parent status.
+ if nonCaveatedSubjectIDs.Len() > 0 {
+ return &v1.PossibleResource{
+ ResourceId: foundResource.ResourceId,
+ ForSubjectIds: nonCaveatedSubjectIDs.AsSlice(),
+ }, nil
+ }
+
+ // Otherwise, everything is caveated, so return the full set of subject IDs and mark
+ // as a check is required.
+ return &v1.PossibleResource{
+ ResourceId: foundResource.ResourceId,
+ ForSubjectIds: forSubjectIDs.AsSlice(),
+ MissingContextParams: missingContextParameters.AsSlice(),
+ }, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/traceid.go b/vendor/github.com/authzed/spicedb/internal/graph/traceid.go
new file mode 100644
index 0000000..e275bc5
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/graph/traceid.go
@@ -0,0 +1,13 @@
+package graph
+
+import (
+ "github.com/google/uuid"
+)
+
+// NewTraceID generates a new trace ID. The trace IDs will only be unique within
+// a single dispatch request tree and should not be used for any other purpose.
+// This function currently uses the UUID library to generate a new trace ID,
+// which means it should not be invoked from performance-critical code paths.
+func NewTraceID() string {
+ return uuid.NewString()
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/grpchelpers/grpchelpers.go b/vendor/github.com/authzed/spicedb/internal/grpchelpers/grpchelpers.go
new file mode 100644
index 0000000..de91a0f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/grpchelpers/grpchelpers.go
@@ -0,0 +1,20 @@
+package grpchelpers
+
+import (
+ "context"
+
+ "google.golang.org/grpc"
+)
+
+// DialAndWait creates a new client connection to the target and blocks until the connection is ready.
+func DialAndWait(ctx context.Context, target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
+ // TODO: move to NewClient
+ // NOTE: DialContext and WithBlock are deprecated upstream; the nolint comments
+ // suppress staticcheck until the migration to grpc.NewClient happens.
+ opts = append(opts, grpc.WithBlock()) // nolint: staticcheck
+ return grpc.DialContext(ctx, target, opts...) // nolint: staticcheck
+}
+
+// Dial creates a new client connection to the target without waiting for readiness.
+func Dial(ctx context.Context, target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
+ // TODO: move to NewClient
+ return grpc.DialContext(ctx, target, opts...) // nolint: staticcheck
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/logging/logger.go b/vendor/github.com/authzed/spicedb/internal/logging/logger.go
new file mode 100644
index 0000000..8204af9
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/logging/logger.go
@@ -0,0 +1,43 @@
+package logging
+
+import (
+ "context"
+
+ "github.com/go-logr/zerologr"
+ "github.com/rs/zerolog"
+ logf "sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+// Logger is the process-global zerolog logger. It defaults to a no-op logger
+// until SetGlobalLogger is called with a real one.
+var Logger zerolog.Logger
+
+func init() {
+ SetGlobalLogger(zerolog.Nop())
+ // Route controller-runtime logging through the same global logger.
+ logf.SetLogger(zerologr.New(&Logger))
+}
+
+// SetGlobalLogger replaces the package-level logger and installs it as zerolog's
+// default context logger (used by zerolog.Ctx when no logger is in the context).
+func SetGlobalLogger(logger zerolog.Logger) {
+ Logger = logger
+ zerolog.DefaultContextLogger = &Logger
+}
+
+// The helpers below delegate to the global Logger, mirroring zerolog's API.
+
+func With() zerolog.Context { return Logger.With() }
+
+func Err(err error) *zerolog.Event { return Logger.Err(err) }
+
+func Trace() *zerolog.Event { return Logger.Trace() }
+
+func Debug() *zerolog.Event { return Logger.Debug() }
+
+func Info() *zerolog.Event { return Logger.Info() }
+
+func Warn() *zerolog.Event { return Logger.Warn() }
+
+func Error() *zerolog.Event { return Logger.Error() }
+
+func Fatal() *zerolog.Event { return Logger.Fatal() }
+
+func WithLevel(level zerolog.Level) *zerolog.Event { return Logger.WithLevel(level) }
+
+func Log() *zerolog.Event { return Logger.Log() }
+
+// Ctx returns the logger attached to ctx, falling back to the default context logger.
+func Ctx(ctx context.Context) *zerolog.Logger { return zerolog.Ctx(ctx) }
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/chain.go b/vendor/github.com/authzed/spicedb/internal/middleware/chain.go
new file mode 100644
index 0000000..de08ffc
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/chain.go
@@ -0,0 +1,58 @@
+package middleware
+
+import (
+ "context"
+
+ "google.golang.org/grpc"
+)
+
+// Vendored from grpc-go-middleware
+// These were removed in v2, see: https://github.com/grpc-ecosystem/go-grpc-middleware/pull/385
+
+// ChainUnaryServer creates a single interceptor out of a chain of many interceptors.
+//
+// Execution is done in left-to-right order, including passing of context.
+// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three
+// will see context changes of one and two.
+func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
+ n := len(interceptors)
+
+ return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ chainer := func(currentInter grpc.UnaryServerInterceptor, currentHandler grpc.UnaryHandler) grpc.UnaryHandler {
+ return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) {
+ return currentInter(currentCtx, currentReq, info, currentHandler)
+ }
+ }
+
+ chainedHandler := handler
+ for i := n - 1; i >= 0; i-- {
+ chainedHandler = chainer(interceptors[i], chainedHandler)
+ }
+
+ return chainedHandler(ctx, req)
+ }
+}
+
+// ChainStreamServer creates a single interceptor out of a chain of many interceptors.
+//
+// Execution is done in left-to-right order, including passing of context.
+// For example ChainUnaryServer(one, two, three) will execute one before two before three.
+// If you want to pass context between interceptors, use WrapServerStream.
+func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
+ n := len(interceptors)
+
+ return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ chainer := func(currentInter grpc.StreamServerInterceptor, currentHandler grpc.StreamHandler) grpc.StreamHandler {
+ return func(currentSrv interface{}, currentStream grpc.ServerStream) error {
+ return currentInter(currentSrv, currentStream, info, currentHandler)
+ }
+ }
+
+ chainedHandler := handler
+ for i := n - 1; i >= 0; i-- {
+ chainedHandler = chainer(interceptors[i], chainedHandler)
+ }
+
+ return chainedHandler(srv, ss)
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/datastore/datastore.go b/vendor/github.com/authzed/spicedb/internal/middleware/datastore/datastore.go
new file mode 100644
index 0000000..8c321b3
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/datastore/datastore.go
@@ -0,0 +1,85 @@
+package datastore
+
+import (
+ "context"
+
+ middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
+ "google.golang.org/grpc"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
// ctxKeyType is an unexported key type to avoid collisions with context
// values set by other packages.
type ctxKeyType struct{}

// datastoreKey is the context key under which the datastore handle is stored.
var datastoreKey ctxKeyType = struct{}{}

// datastoreHandle is a mutable cell stored in the context so the datastore
// can be filled in after the context has been created (see SetInContext).
type datastoreHandle struct {
	datastore datastore.Datastore
}
+
+// ContextWithHandle adds a placeholder to a context that will later be
+// filled by the datastore
+func ContextWithHandle(ctx context.Context) context.Context {
+ return context.WithValue(ctx, datastoreKey, &datastoreHandle{})
+}
+
+// FromContext reads the selected datastore out of a context.Context
+// and returns nil if it does not exist.
+func FromContext(ctx context.Context) datastore.Datastore {
+ if c := ctx.Value(datastoreKey); c != nil {
+ handle := c.(*datastoreHandle)
+ return handle.datastore
+ }
+ return nil
+}
+
+// MustFromContext reads the selected datastore out of a context.Context and panics if it does not exist
+func MustFromContext(ctx context.Context) datastore.Datastore {
+ datastore := FromContext(ctx)
+ if datastore == nil {
+ panic("datastore middleware did not inject datastore")
+ }
+
+ return datastore
+}
+
+// SetInContext adds a datastore to the given context
+func SetInContext(ctx context.Context, datastore datastore.Datastore) error {
+ handle := ctx.Value(datastoreKey)
+ if handle == nil {
+ return nil
+ }
+ handle.(*datastoreHandle).datastore = datastore
+ return nil
+}
+
+// ContextWithDatastore adds the handle and datastore in one step
+func ContextWithDatastore(ctx context.Context, datastore datastore.Datastore) context.Context {
+ return context.WithValue(ctx, datastoreKey, &datastoreHandle{datastore: datastore})
+}
+
+// UnaryServerInterceptor returns a new unary server interceptor that adds the
+// datastore to the context
+func UnaryServerInterceptor(datastore datastore.Datastore) grpc.UnaryServerInterceptor {
+ return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ newCtx := ContextWithHandle(ctx)
+ if err := SetInContext(newCtx, datastore); err != nil {
+ return nil, err
+ }
+
+ return handler(newCtx, req)
+ }
+}
+
+// StreamServerInterceptor returns a new stream server interceptor that adds the
+// datastore to the context
+func StreamServerInterceptor(datastore datastore.Datastore) grpc.StreamServerInterceptor {
+ return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ wrapped := middleware.WrapServerStream(stream)
+ wrapped.WrappedContext = ContextWithHandle(wrapped.WrappedContext)
+ if err := SetInContext(wrapped.WrappedContext, datastore); err != nil {
+ return err
+ }
+ return handler(srv, wrapped)
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/datastore/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/datastore/doc.go
new file mode 100644
index 0000000..a4d0cf0
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/datastore/doc.go
@@ -0,0 +1,2 @@
+// Package datastore defines middleware that injects the datastore into the context.
+package datastore
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/doc.go
new file mode 100644
index 0000000..b11553e
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/doc.go
@@ -0,0 +1,2 @@
+// Package middleware defines various custom middlewares.
+package middleware
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/doc.go
new file mode 100644
index 0000000..2d9aa01
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/doc.go
@@ -0,0 +1,2 @@
+// Package handwrittenvalidation defines middleware that runs custom-made validations on incoming requests.
+package handwrittenvalidation
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/handwrittenvalidation.go b/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/handwrittenvalidation.go
new file mode 100644
index 0000000..2adc4b3
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/handwrittenvalidation.go
@@ -0,0 +1,54 @@
+package handwrittenvalidation
+
+import (
+ "context"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+type handwrittenValidator interface {
+ HandwrittenValidate() error
+}
+
// UnaryServerInterceptor is a unary server interceptor (not a constructor for
// one) that runs the handwritten validation on the incoming request, if the
// request implements handwrittenValidator.
func UnaryServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	validator, ok := req.(handwrittenValidator)
	if ok {
		err := validator.HandwrittenValidate()
		if err != nil {
			// Validation failures surface as InvalidArgument to the caller.
			return nil, status.Errorf(codes.InvalidArgument, "%s", err)
		}
	}

	return handler(ctx, req)
}
+
// StreamServerInterceptor is a stream server interceptor (not a constructor
// for one) that runs the handwritten validation on each incoming request
// message, if any, by wrapping the stream's RecvMsg.
func StreamServerInterceptor(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	wrapper := &recvWrapper{stream}
	return handler(srv, wrapper)
}
+
+type recvWrapper struct {
+ grpc.ServerStream
+}
+
+func (s *recvWrapper) RecvMsg(m interface{}) error {
+ if err := s.ServerStream.RecvMsg(m); err != nil {
+ return err
+ }
+
+ validator, ok := m.(handwrittenValidator)
+ if ok {
+ err := validator.HandwrittenValidate()
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/doc.go
new file mode 100644
index 0000000..7e3ae0f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/doc.go
@@ -0,0 +1,2 @@
+// Package servicespecific defines middleware that injects other middlewares.
+package servicespecific
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/servicespecific.go b/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/servicespecific.go
new file mode 100644
index 0000000..10fe753
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/servicespecific/servicespecific.go
@@ -0,0 +1,39 @@
+package servicespecific
+
+import (
+ "context"
+
+ "google.golang.org/grpc"
+)
+
// ExtraUnaryInterceptor is an interface for a service which has its own bundled
// unary interceptors that must be run. It is checked against info.Server by
// UnaryServerInterceptor.
type ExtraUnaryInterceptor interface {
	UnaryInterceptor() grpc.UnaryServerInterceptor
}

// ExtraStreamInterceptor is an interface for a service which has its own bundled
// stream interceptors that must be run. It is checked against the srv argument
// by StreamServerInterceptor.
type ExtraStreamInterceptor interface {
	StreamInterceptor() grpc.StreamServerInterceptor
}
+
+// UnaryServerInterceptor returns a new unary server interceptor that runs bundled interceptors.
+func UnaryServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ if hasExtraInterceptor, ok := info.Server.(ExtraUnaryInterceptor); ok {
+ interceptor := hasExtraInterceptor.UnaryInterceptor()
+ return interceptor(ctx, req, info, handler)
+ }
+
+ return handler(ctx, req)
+}
+
+// StreamServerInterceptor returns a new stream server interceptor that runs bundled interceptors.
+func StreamServerInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ if hasExtraInterceptor, ok := srv.(ExtraStreamInterceptor); ok {
+ interceptor := hasExtraInterceptor.StreamInterceptor()
+ return interceptor(srv, stream, info, handler)
+ }
+
+ return handler(srv, stream)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/doc.go
new file mode 100644
index 0000000..9eceb4d
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/doc.go
@@ -0,0 +1,2 @@
+// Package streamtimeout defines middleware that cancels the context after a timeout if no new data has been received.
+package streamtimeout
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/streamtimeout.go b/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/streamtimeout.go
new file mode 100644
index 0000000..8f09fdb
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/streamtimeout/streamtimeout.go
@@ -0,0 +1,57 @@
+package streamtimeout
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/metadata"
+
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// MustStreamServerInterceptor returns a new stream server interceptor that cancels the context
+// after a timeout if no new data has been received.
+func MustStreamServerInterceptor(timeout time.Duration) grpc.StreamServerInterceptor {
+ if timeout <= 0 {
+ panic("timeout must be >= 0 for streaming timeout interceptor")
+ }
+
+ return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ ctx := stream.Context()
+ withCancel, internalCancelFn := context.WithCancelCause(ctx)
+ timer := time.AfterFunc(timeout, func() {
+ internalCancelFn(spiceerrors.WithCodeAndDetailsAsError(fmt.Errorf("operation took longer than allowed %v to complete", timeout), codes.DeadlineExceeded))
+ })
+ wrapper := &sendWrapper{stream, withCancel, timer, timeout}
+ return handler(srv, wrapper)
+ }
+}
+
// sendWrapper wraps a grpc.ServerStream to manage the timeout watchdog:
// the timer is re-armed on each successful send and stopped on failure or
// when trailers are set.
type sendWrapper struct {
	grpc.ServerStream

	ctx     context.Context // cancelable context returned by Context()
	timer   *time.Timer     // fires when no message has been sent for `timeout`
	timeout time.Duration
}

// Context returns the cancelable context so handlers observe the timeout.
func (s *sendWrapper) Context() context.Context {
	return s.ctx
}

// SetTrailer stops the watchdog timer; setting a trailer indicates the
// stream is finishing.
func (s *sendWrapper) SetTrailer(_ metadata.MD) {
	s.timer.Stop()
}
+
+func (s *sendWrapper) SendMsg(m any) error {
+ err := s.ServerStream.SendMsg(m)
+ if err != nil {
+ s.timer.Stop()
+ } else {
+ s.timer.Reset(s.timeout)
+ }
+ return err
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/doc.go b/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/doc.go
new file mode 100644
index 0000000..c05cacc
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/doc.go
@@ -0,0 +1,2 @@
+// Package usagemetrics defines middleware that adds usage data (e.g. dispatch counts) to the response.
+package usagemetrics
diff --git a/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/usagemetrics.go b/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/usagemetrics.go
new file mode 100644
index 0000000..32f5676
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/middleware/usagemetrics/usagemetrics.go
@@ -0,0 +1,128 @@
+package usagemetrics
+
+import (
+ "context"
+ "strconv"
+ "time"
+
+ "github.com/authzed/authzed-go/pkg/responsemeta"
+ "github.com/authzed/grpcutil"
+ "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "google.golang.org/grpc"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+)
+
var (
	// DispatchedCountLabels are the labels that DispatchedCountHistogram will
	// have by default.
	DispatchedCountLabels = []string{"method", "cached"}

	// DispatchedCountHistogram is the metric that SpiceDB uses to keep track
	// of the number of downstream dispatches that are performed to answer a
	// single query.
	DispatchedCountHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
		Namespace: "spicedb",
		Subsystem: "services",
		Name:      "dispatches",
		Help:      "Histogram of cluster dispatches performed by the instance.",
		Buckets:   []float64{1, 5, 10, 25, 50, 100, 250},
	}, DispatchedCountLabels)
)

// reporter builds a per-call serverReporter for the go-grpc-middleware
// interceptors package.
type reporter struct{}

// ServerReporter installs a metadata handle on the context and returns a
// serverReporter that reads it back once the call completes.
func (r *reporter) ServerReporter(ctx context.Context, callMeta interceptors.CallMeta) (interceptors.Reporter, context.Context) {
	_, methodName := grpcutil.SplitMethodName(callMeta.FullMethod())
	ctx = ContextWithHandle(ctx)
	return &serverReporter{ctx: ctx, methodName: methodName}, ctx
}

// serverReporter reports usage metrics for a single gRPC call.
type serverReporter struct {
	interceptors.NoopReporter
	ctx        context.Context
	methodName string
}

// PostCall records dispatch metrics and attaches them to the response
// trailer after the handler has finished.
func (r *serverReporter) PostCall(_ error, _ time.Duration) {
	responseMeta := FromContext(r.ctx)
	if responseMeta == nil {
		// The handler never recorded metadata; report zero counts.
		responseMeta = &dispatch.ResponseMeta{}
	}

	// Metrics are observed (and the trailer attempted) before the
	// cancellation check below, so histograms are always updated.
	err := annotateAndReportForMetadata(r.ctx, r.methodName, responseMeta)
	// if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite
	// this prevents logging unnecessary error messages
	if r.ctx.Err() != nil {
		return
	}
	if err != nil {
		log.Ctx(r.ctx).Warn().Err(err).Msg("usagemetrics: could not report metadata")
	}
}

// UnaryServerInterceptor implements a gRPC Middleware for reporting usage metrics
// in both the trailer of the request, as well as to the registered prometheus
// metrics.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
	return interceptors.UnaryServerInterceptor(&reporter{})
}

// StreamServerInterceptor implements a gRPC Middleware for reporting usage metrics
// in both the trailer of the request, as well as to the registered prometheus
// metrics.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
	return interceptors.StreamServerInterceptor(&reporter{})
}
+
+func annotateAndReportForMetadata(ctx context.Context, methodName string, metadata *dispatch.ResponseMeta) error {
+ DispatchedCountHistogram.WithLabelValues(methodName, "false").Observe(float64(metadata.DispatchCount))
+ DispatchedCountHistogram.WithLabelValues(methodName, "true").Observe(float64(metadata.CachedDispatchCount))
+
+ return responsemeta.SetResponseTrailerMetadata(ctx, map[responsemeta.ResponseMetadataTrailerKey]string{
+ responsemeta.DispatchedOperationsCount: strconv.Itoa(int(metadata.DispatchCount)),
+ responsemeta.CachedOperationsCount: strconv.Itoa(int(metadata.CachedDispatchCount)),
+ })
+}
+
// Create a new type to prevent context collisions
type responseMetaKey string

// metadataCtxKey is the context key under which the metadata handle is stored.
var metadataCtxKey responseMetaKey = "dispatched-response-meta"

// metaHandle is a mutable cell placed in the context by ContextWithHandle and
// filled in by SetInContext.
type metaHandle struct{ metadata *dispatch.ResponseMeta }
+
+// SetInContext should be called in a gRPC handler to correctly set the response metadata
+// for the dispatched request.
+func SetInContext(ctx context.Context, metadata *dispatch.ResponseMeta) {
+ possibleHandle := ctx.Value(metadataCtxKey)
+ if possibleHandle == nil {
+ return
+ }
+
+ handle := possibleHandle.(*metaHandle)
+ handle.metadata = metadata
+}
+
+// FromContext returns any metadata that was stored in the context.
+//
+// This is useful for testing that a handler is properly setting the context.
+func FromContext(ctx context.Context) *dispatch.ResponseMeta {
+ possibleHandle := ctx.Value(metadataCtxKey)
+ if possibleHandle == nil {
+ return nil
+ }
+ return possibleHandle.(*metaHandle).metadata
+}
+
+// ContextWithHandle creates a new context with a location to store metadata
+// returned from a dispatched request.
+//
+// This should only be called in middleware or testing functions.
+func ContextWithHandle(ctx context.Context) context.Context {
+ var handle metaHandle
+ return context.WithValue(ctx, metadataCtxKey, &handle)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/aliasing.go b/vendor/github.com/authzed/spicedb/internal/namespace/aliasing.go
new file mode 100644
index 0000000..adbaa94
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/namespace/aliasing.go
@@ -0,0 +1,82 @@
+package namespace
+
+import (
+ "sort"
+
+ "github.com/authzed/spicedb/pkg/schema"
+)
+
// computePermissionAliases computes a map of aliases between the various permissions in a
// namespace. A permission is considered an alias if it *directly* refers to another permission
// or relation without any other form of expression.
//
// The returned map points each aliasing permission at its final, non-aliased
// target. An error is returned if the aliases form a cycle.
func computePermissionAliases(typeDefinition *schema.ValidatedDefinition) (map[string]string, error) {
	aliases := map[string]string{}
	done := map[string]struct{}{}
	unresolvedAliases := map[string]string{}

	for _, rel := range typeDefinition.Namespace().Relation {
		// Ensure the relation has a rewrite...
		if rel.GetUsersetRewrite() == nil {
			done[rel.Name] = struct{}{}
			continue
		}

		// ... with a union ...
		union := rel.GetUsersetRewrite().GetUnion()
		if union == nil {
			done[rel.Name] = struct{}{}
			continue
		}

		// ... with a single child ...
		if len(union.Child) != 1 {
			done[rel.Name] = struct{}{}
			continue
		}

		// ... that is a computed userset.
		computedUserset := union.Child[0].GetComputedUserset()
		if computedUserset == nil {
			done[rel.Name] = struct{}{}
			continue
		}

		// If the aliased item is a relation, then we've found the alias target.
		aliasedPermOrRel := computedUserset.GetRelation()
		if !typeDefinition.IsPermission(aliasedPermOrRel) {
			done[rel.Name] = struct{}{}
			aliases[rel.Name] = aliasedPermOrRel
			continue
		}

		// Otherwise, add the permission to the working set.
		unresolvedAliases[rel.Name] = aliasedPermOrRel
	}

	// Iteratively resolve permission->permission aliases. Every pass must
	// resolve at least one entry; otherwise the remainder forms a cycle.
	for len(unresolvedAliases) > 0 {
		startingCount := len(unresolvedAliases)
		for relName, aliasedPermission := range unresolvedAliases {
			if _, ok := done[aliasedPermission]; ok {
				done[relName] = struct{}{}

				// Collapse chains: point directly at the target's own final
				// target when the target is itself an alias.
				if alias, ok := aliases[aliasedPermission]; ok {
					aliases[relName] = alias
				} else {
					aliases[relName] = aliasedPermission
				}
				// Deleting while ranging over a map is safe in Go.
				delete(unresolvedAliases, relName)
				continue
			}
		}
		if len(unresolvedAliases) == startingCount {
			// No progress was made: report the cycle with the participating
			// names sorted for a deterministic error message.
			keys := make([]string, 0, len(unresolvedAliases))
			for key := range unresolvedAliases {
				keys = append(keys, key)
			}
			sort.Strings(keys)
			return nil, NewPermissionsCycleErr(typeDefinition.Namespace().Name, keys)
		}
	}

	return aliases, nil
}
diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/annotate.go b/vendor/github.com/authzed/spicedb/internal/namespace/annotate.go
new file mode 100644
index 0000000..d85edff
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/namespace/annotate.go
@@ -0,0 +1,29 @@
+package namespace
+
+import "github.com/authzed/spicedb/pkg/schema"
+
+// AnnotateNamespace annotates the namespace in the type system with computed aliasing and cache key
+// metadata for more efficient dispatching.
+func AnnotateNamespace(def *schema.ValidatedDefinition) error {
+ aliases, aerr := computePermissionAliases(def)
+ if aerr != nil {
+ return aerr
+ }
+
+ cacheKeys, cerr := computeCanonicalCacheKeys(def, aliases)
+ if cerr != nil {
+ return cerr
+ }
+
+ for _, rel := range def.Namespace().Relation {
+ if alias, ok := aliases[rel.Name]; ok {
+ rel.AliasingRelation = alias
+ }
+
+ if cacheKey, ok := cacheKeys[rel.Name]; ok {
+ rel.CanonicalCacheKey = cacheKey
+ }
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/canonicalization.go b/vendor/github.com/authzed/spicedb/internal/namespace/canonicalization.go
new file mode 100644
index 0000000..24fa61e
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/namespace/canonicalization.go
@@ -0,0 +1,282 @@
+package namespace
+
+import (
+ "encoding/hex"
+ "hash/fnv"
+
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+
+ "github.com/dalzilio/rudd"
+
+ "github.com/authzed/spicedb/pkg/graph"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+)
+
// computedKeyPrefix marks cache keys derived from a BDD hash, distinguishing
// them from plain relation-name keys.
const computedKeyPrefix = "%"

// computeCanonicalCacheKeys computes a map from permission name to associated canonicalized
// cache key for each non-aliased permission in the given type system's namespace.
//
// Canonicalization works by taking each permission's userset rewrite expression and transforming
// it into a Binary Decision Diagram (BDD) via the `rudd` library.
//
// Each access of a relation or arrow is assigned a unique integer ID within the *namespace*,
// and the operations (+, -, &) are converted into binary operations.
//
// For example, for the namespace:
//
//	definition somenamespace {
//	   relation first: ...
//	   relation second: ...
//	   relation third: ...
//	   permission someperm = second + (first - third->something)
//	}
//
// We begin by assigning a unique integer index to each relation and arrow found for all
// expressions in the namespace:
//
//	definition somenamespace {
//	   relation first: ...
//	   ^ index 0
//	   relation second: ...
//	   ^ index 1
//	   relation third: ...
//	   ^ index 2
//	   permission someperm = second + (first - third->something)
//	                         ^ 1        ^ 0     ^ index 3
//	}
//
// These indexes are then used with the rudd library to build the expression:
//
//	someperm => `bdd.Or(bdd.Ithvar(1), bdd.And(bdd.Ithvar(0), bdd.NIthvar(2)))`
//
// The `rudd` library automatically handles associativity, and produces a hash representing the
// canonical representation of the binary expression. These hashes can then be used for caching,
// representing the same *logical* expressions for a permission, even if the relations have
// different names.
func computeCanonicalCacheKeys(typeDef *schema.ValidatedDefinition, aliasMap map[string]string) (map[string]string, error) {
	varMap, err := buildBddVarMap(typeDef.Namespace().Relation, aliasMap)
	if err != nil {
		return nil, err
	}

	// No relations at all: nothing to canonicalize.
	if varMap.Len() == 0 {
		return map[string]string{}, nil
	}

	bdd, err := rudd.New(varMap.Len())
	if err != nil {
		return nil, err
	}

	// For each permission, build a canonicalized cache key based on its expression.
	cacheKeys := make(map[string]string, len(typeDef.Namespace().Relation))
	for _, rel := range typeDef.Namespace().Relation {
		rewrite := rel.GetUsersetRewrite()
		if rewrite == nil {
			// If the relation has no rewrite (making it a pure relation), then its canonical
			// key is simply the relation's name.
			cacheKeys[rel.Name] = rel.Name
			continue
		}

		hasher := fnv.New64a()
		node, err := convertRewriteToBdd(rel, bdd, rewrite, varMap)
		if err != nil {
			return nil, err
		}

		// bdd.Print writes a serialization of the BDD node to the hasher;
		// logically-equal expressions therefore hash to the same key.
		bdd.Print(hasher, node)
		cacheKeys[rel.Name] = computedKeyPrefix + hex.EncodeToString(hasher.Sum(nil))
	}

	return cacheKeys, nil
}
+
// convertRewriteToBdd converts a userset rewrite into a BDD node: union maps
// to OR, intersection to AND, and exclusion to AND with the subtracted
// children negated.
func convertRewriteToBdd(relation *core.Relation, bdd *rudd.BDD, rewrite *core.UsersetRewrite, varMap bddVarMap) (rudd.Node, error) {
	switch rw := rewrite.RewriteOperation.(type) {
	case *core.UsersetRewrite_Union:
		return convertToBdd(relation, bdd, rw.Union, bdd.Or, func(childIndex int, varIndex int) rudd.Node {
			return bdd.Ithvar(varIndex)
		}, varMap)

	case *core.UsersetRewrite_Intersection:
		return convertToBdd(relation, bdd, rw.Intersection, bdd.And, func(childIndex int, varIndex int) rudd.Node {
			return bdd.Ithvar(varIndex)
		}, varMap)

	case *core.UsersetRewrite_Exclusion:
		// Exclusion: the first child is positive, all later children negated.
		return convertToBdd(relation, bdd, rw.Exclusion, bdd.And, func(childIndex int, varIndex int) rudd.Node {
			if childIndex == 0 {
				return bdd.Ithvar(varIndex)
			}
			return bdd.NIthvar(varIndex)
		}, varMap)

	default:
		return nil, spiceerrors.MustBugf("Unknown rewrite kind %v", rw)
	}
}
+
type (
	// combiner folds child BDD nodes into one (e.g. bdd.Or, bdd.And).
	combiner func(n ...rudd.Node) rudd.Node
	// builder produces the BDD node for a leaf variable, given the child's
	// position within the set operation (used to negate exclusion operands).
	builder func(childIndex int, varIndex int) rudd.Node
)

// convertToBdd converts the children of a set operation into BDD nodes and
// combines them with the given combiner.
func convertToBdd(relation *core.Relation, bdd *rudd.BDD, so *core.SetOperation, combiner combiner, builder builder, varMap bddVarMap) (rudd.Node, error) {
	values := make([]rudd.Node, 0, len(so.Child))
	for index, childOneof := range so.Child {
		switch child := childOneof.ChildType.(type) {
		case *core.SetOperation_Child_XThis:
			return nil, spiceerrors.MustBugf("use of _this is disallowed")

		case *core.SetOperation_Child_ComputedUserset:
			// Leaf: a direct relation/permission reference.
			cuIndex, err := varMap.Get(child.ComputedUserset.Relation)
			if err != nil {
				return nil, err
			}

			values = append(values, builder(index, cuIndex))

		case *core.SetOperation_Child_UsersetRewrite:
			// Nested expression: recurse.
			node, err := convertRewriteToBdd(relation, bdd, child.UsersetRewrite, varMap)
			if err != nil {
				return nil, err
			}

			values = append(values, node)

		case *core.SetOperation_Child_TupleToUserset:
			// Leaf: a `tupleset->relation` arrow.
			arrowIndex, err := varMap.GetArrow(child.TupleToUserset.Tupleset.Relation, child.TupleToUserset.ComputedUserset.Relation)
			if err != nil {
				return nil, err
			}

			values = append(values, builder(index, arrowIndex))

		case *core.SetOperation_Child_FunctionedTupleToUserset:
			switch child.FunctionedTupleToUserset.Function {
			case core.FunctionedTupleToUserset_FUNCTION_ANY:
				// ANY-functioned arrows share the plain arrow's variable.
				arrowIndex, err := varMap.GetArrow(child.FunctionedTupleToUserset.Tupleset.Relation, child.FunctionedTupleToUserset.ComputedUserset.Relation)
				if err != nil {
					return nil, err
				}

				values = append(values, builder(index, arrowIndex))

			case core.FunctionedTupleToUserset_FUNCTION_ALL:
				// ALL-functioned arrows get a distinct intersection variable.
				arrowIndex, err := varMap.GetIntersectionArrow(child.FunctionedTupleToUserset.Tupleset.Relation, child.FunctionedTupleToUserset.ComputedUserset.Relation)
				if err != nil {
					return nil, err
				}

				values = append(values, builder(index, arrowIndex))

			default:
				return nil, spiceerrors.MustBugf("unknown function %v", child.FunctionedTupleToUserset.Function)
			}

		case *core.SetOperation_Child_XNil:
			// `nil` uses the reserved variable index past all real variables.
			values = append(values, builder(index, varMap.Nil()))

		default:
			return nil, spiceerrors.MustBugf("unknown set operation child %T", child)
		}
	}
	return combiner(values...), nil
}
+
// bddVarMap maps relation and arrow names to their BDD variable indexes for
// a single namespace.
type bddVarMap struct {
	aliasMap map[string]string // permission name -> alias target
	varMap   map[string]int    // relation or arrow key -> BDD variable index
}

// GetArrow returns the variable index assigned to the `tupleset->relName`
// arrow, which must have been registered by buildBddVarMap.
func (bvm bddVarMap) GetArrow(tuplesetName string, relName string) (int, error) {
	key := tuplesetName + "->" + relName
	index, ok := bvm.varMap[key]
	if !ok {
		return -1, spiceerrors.MustBugf("missing arrow key %s in varMap", key)
	}
	return index, nil
}

// GetIntersectionArrow returns the variable index assigned to the
// ALL-functioned (`-(all)->`) arrow between the tupleset and relation.
func (bvm bddVarMap) GetIntersectionArrow(tuplesetName string, relName string) (int, error) {
	key := tuplesetName + "-(all)->" + relName
	index, ok := bvm.varMap[key]
	if !ok {
		return -1, spiceerrors.MustBugf("missing intersection arrow key %s in varMap", key)
	}
	return index, nil
}

// Nil returns the variable index reserved for `nil` expressions: one past
// the last assigned index (accounted for in Len).
func (bvm bddVarMap) Nil() int {
	return len(bvm.varMap)
}
+
+func (bvm bddVarMap) Get(relName string) (int, error) {
+ if alias, ok := bvm.aliasMap[relName]; ok {
+ return bvm.Get(alias)
+ }
+
+ index, ok := bvm.varMap[relName]
+ if !ok {
+ return -1, spiceerrors.MustBugf("missing key %s in varMap", relName)
+ }
+ return index, nil
+}
+
// Len returns the total number of BDD variables, including the reserved
// `nil` slot (see Nil).
func (bvm bddVarMap) Len() int {
	return len(bvm.varMap) + 1 // +1 for `nil`
}
+
// buildBddVarMap assigns a unique variable index to every non-aliased
// relation and to every distinct arrow used by the relations' rewrites.
func buildBddVarMap(relations []*core.Relation, aliasMap map[string]string) (bddVarMap, error) {
	varMap := map[string]int{}
	for _, rel := range relations {
		// Aliased permissions resolve to their target's variable via Get;
		// they do not get their own index.
		if _, ok := aliasMap[rel.Name]; ok {
			continue
		}

		varMap[rel.Name] = len(varMap)

		rewrite := rel.GetUsersetRewrite()
		if rewrite == nil {
			continue
		}

		// Walk the rewrite to register a variable for each arrow it uses.
		_, err := graph.WalkRewrite(rewrite, func(childOneof *core.SetOperation_Child) (interface{}, error) {
			switch child := childOneof.ChildType.(type) {
			case *core.SetOperation_Child_TupleToUserset:
				key := child.TupleToUserset.Tupleset.Relation + "->" + child.TupleToUserset.ComputedUserset.Relation
				if _, ok := varMap[key]; !ok {
					varMap[key] = len(varMap)
				}
			case *core.SetOperation_Child_FunctionedTupleToUserset:
				// ANY-functioned arrows deliberately share the plain `->` key
				// with TupleToUserset arrows; only ALL gets a distinct key.
				key := child.FunctionedTupleToUserset.Tupleset.Relation + "->" + child.FunctionedTupleToUserset.ComputedUserset.Relation

				switch child.FunctionedTupleToUserset.Function {
				case core.FunctionedTupleToUserset_FUNCTION_ANY:
					// Use the key.

				case core.FunctionedTupleToUserset_FUNCTION_ALL:
					key = child.FunctionedTupleToUserset.Tupleset.Relation + "-(all)->" + child.FunctionedTupleToUserset.ComputedUserset.Relation

				default:
					return nil, spiceerrors.MustBugf("unknown function %v", child.FunctionedTupleToUserset.Function)
				}

				if _, ok := varMap[key]; !ok {
					varMap[key] = len(varMap)
				}
			}
			return nil, nil
		})
		if err != nil {
			return bddVarMap{}, err
		}
	}
	return bddVarMap{
		aliasMap: aliasMap,
		varMap:   varMap,
	}, nil
}
diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/caveats.go b/vendor/github.com/authzed/spicedb/internal/namespace/caveats.go
new file mode 100644
index 0000000..5ddfa9d
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/namespace/caveats.go
@@ -0,0 +1,69 @@
+package namespace
+
+import (
+ "fmt"
+
+ "golang.org/x/exp/maps"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+)
+
+// ValidateCaveatDefinition validates the parameters and types within the given caveat
+// definition, including usage of the parameters.
+func ValidateCaveatDefinition(ts *caveattypes.TypeSet, caveat *core.CaveatDefinition) error {
+	// Decode the declared parameter types for the caveat.
+	parameterTypes, err := caveattypes.DecodeParameterTypes(ts, caveat.ParameterTypes)
+	if err != nil {
+		return schema.NewTypeWithSourceError(
+			fmt.Errorf("could not decode caveat parameters `%s`: %w", caveat.Name, err),
+			caveat,
+			caveat.Name,
+		)
+	}
+
+	// Deserialize the caveat expression, type checked against the declared parameters.
+	deserialized, err := caveats.DeserializeCaveatWithTypeSet(ts, caveat.SerializedExpression, parameterTypes)
+	if err != nil {
+		return schema.NewTypeWithSourceError(
+			fmt.Errorf("could not decode caveat `%s`: %w", caveat.Name, err),
+			caveat,
+			caveat.Name,
+		)
+	}
+
+	// Require at least one parameter. NOTE: this check intentionally runs after
+	// deserialization, so an invalid expression is reported before a missing
+	// parameter list.
+	if len(caveat.ParameterTypes) == 0 {
+		return schema.NewTypeWithSourceError(
+			fmt.Errorf("caveat `%s` must have at least one parameter defined", caveat.Name),
+			caveat,
+			caveat.Name,
+		)
+	}
+
+	// Collect the parameter names actually referenced by the expression.
+	referencedNames, err := deserialized.ReferencedParameters(maps.Keys(caveat.ParameterTypes))
+	if err != nil {
+		return err
+	}
+
+	// Ensure every declared parameter has a decodable type and is used by the
+	// caveat expression itself.
+	for paramName, paramType := range caveat.ParameterTypes {
+		_, err := caveattypes.DecodeParameterType(ts, paramType)
+		if err != nil {
+			return schema.NewTypeWithSourceError(
+				fmt.Errorf("type error for parameter `%s` for caveat `%s`: %w", paramName, caveat.Name, err),
+				caveat,
+				paramName,
+			)
+		}
+
+		if !referencedNames.Has(paramName) {
+			return schema.NewTypeWithSourceError(
+				NewUnusedCaveatParameterErr(caveat.Name, paramName),
+				caveat,
+				paramName,
+			)
+		}
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/doc.go b/vendor/github.com/authzed/spicedb/internal/namespace/doc.go
new file mode 100644
index 0000000..1546280
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/namespace/doc.go
@@ -0,0 +1,2 @@
+// Package namespace provides functions for dealing with and validating types, relations and caveats.
+package namespace
diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/errors.go b/vendor/github.com/authzed/spicedb/internal/namespace/errors.go
new file mode 100644
index 0000000..abe7fe6
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/namespace/errors.go
@@ -0,0 +1,171 @@
+package namespace
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/rs/zerolog"
+
+ "github.com/authzed/spicedb/internal/sharederrors"
+)
+
+// NamespaceNotFoundError occurs when a namespace was not found.
+type NamespaceNotFoundError struct {
+	error
+	// namespaceName is the name of the namespace that was not found.
+	namespaceName string
+}
+
+// NotFoundNamespaceName is the name of the namespace not found.
+func (err NamespaceNotFoundError) NotFoundNamespaceName() string {
+	return err.namespaceName
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err NamespaceNotFoundError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error).Str("namespace", err.namespaceName)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err NamespaceNotFoundError) DetailsMetadata() map[string]string {
+	return map[string]string{
+		"definition_name": err.namespaceName,
+	}
+}
+
+// RelationNotFoundError occurs when a relation was not found under a namespace.
+type RelationNotFoundError struct {
+	error
+	// namespaceName is the namespace that was searched.
+	namespaceName string
+	// relationName is the relation that was not found.
+	relationName string
+}
+
+// NamespaceName returns the name of the namespace in which the relation was not found.
+func (err RelationNotFoundError) NamespaceName() string {
+	return err.namespaceName
+}
+
+// NotFoundRelationName returns the name of the relation not found.
+func (err RelationNotFoundError) NotFoundRelationName() string {
+	return err.relationName
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err RelationNotFoundError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err RelationNotFoundError) DetailsMetadata() map[string]string {
+	return map[string]string{
+		"definition_name":             err.namespaceName,
+		"relation_or_permission_name": err.relationName,
+	}
+}
+
+// DuplicateRelationError occurs when a duplicate relation was found inside a namespace.
+type DuplicateRelationError struct {
+	error
+	// namespaceName is the namespace containing the duplicate.
+	namespaceName string
+	// relationName is the relation name defined more than once.
+	relationName string
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err DuplicateRelationError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err DuplicateRelationError) DetailsMetadata() map[string]string {
+	return map[string]string{
+		"definition_name":             err.namespaceName,
+		"relation_or_permission_name": err.relationName,
+	}
+}
+
+// PermissionsCycleError occurs when a cycle exists within permissions.
+type PermissionsCycleError struct {
+	error
+	// namespaceName is the namespace containing the cycle.
+	namespaceName string
+	// permissionNames lists the permissions participating in the cycle.
+	permissionNames []string
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err PermissionsCycleError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error).Str("namespace", err.namespaceName).Str("permissions", strings.Join(err.permissionNames, ", "))
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err PermissionsCycleError) DetailsMetadata() map[string]string {
+	return map[string]string{
+		"definition_name":  err.namespaceName,
+		"permission_names": strings.Join(err.permissionNames, ","),
+	}
+}
+
+// UnusedCaveatParameterError indicates that a caveat parameter is unused in the caveat expression.
+type UnusedCaveatParameterError struct {
+	error
+	// caveatName is the caveat declaring the parameter.
+	caveatName string
+	// paramName is the parameter that is never referenced.
+	paramName string
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err UnusedCaveatParameterError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error).Str("caveat", err.caveatName).Str("param", err.paramName)
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (err UnusedCaveatParameterError) DetailsMetadata() map[string]string {
+	return map[string]string{
+		"caveat_name":    err.caveatName,
+		"parameter_name": err.paramName,
+	}
+}
+
+// NewNamespaceNotFoundErr constructs a new namespace not found error.
+func NewNamespaceNotFoundErr(nsName string) error {
+ return NamespaceNotFoundError{
+ error: fmt.Errorf("object definition `%s` not found", nsName),
+ namespaceName: nsName,
+ }
+}
+
+// NewRelationNotFoundErr constructs a new relation not found error.
+func NewRelationNotFoundErr(nsName string, relationName string) error {
+	return RelationNotFoundError{
+		error:         fmt.Errorf("relation/permission `%s` not found under definition `%s`", relationName, nsName),
+		namespaceName: nsName,
+		relationName:  relationName,
+	}
+}
+
+// NewDuplicateRelationError constructs an error indicating that a relation was defined more than once in a namespace.
+func NewDuplicateRelationError(nsName string, relationName string) error {
+	return DuplicateRelationError{
+		error:         fmt.Errorf("found duplicate relation/permission name `%s` under definition `%s`", relationName, nsName),
+		namespaceName: nsName,
+		relationName:  relationName,
+	}
+}
+
+// NewPermissionsCycleErr constructs an error indicating that a cycle exists amongst permissions.
+func NewPermissionsCycleErr(nsName string, permissionNames []string) error {
+	return PermissionsCycleError{
+		error:           fmt.Errorf("under definition `%s`, there exists a cycle in permissions: %s", nsName, strings.Join(permissionNames, ", ")),
+		namespaceName:   nsName,
+		permissionNames: permissionNames,
+	}
+}
+
+// NewUnusedCaveatParameterErr constructs an error indicating that a parameter was unused in a caveat expression.
+func NewUnusedCaveatParameterErr(caveatName string, paramName string) error {
+	return UnusedCaveatParameterError{
+		error:      fmt.Errorf("parameter `%s` for caveat `%s` is unused", paramName, caveatName),
+		caveatName: caveatName,
+		paramName:  paramName,
+	}
+}
+
+// Compile-time assertions that these error types implement the shared
+// error interfaces.
+var (
+	_ sharederrors.UnknownNamespaceError = NamespaceNotFoundError{}
+	_ sharederrors.UnknownRelationError  = RelationNotFoundError{}
+)
diff --git a/vendor/github.com/authzed/spicedb/internal/namespace/util.go b/vendor/github.com/authzed/spicedb/internal/namespace/util.go
new file mode 100644
index 0000000..497bdfb
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/namespace/util.go
@@ -0,0 +1,148 @@
+package namespace
+
+import (
+ "context"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+)
+
+// ReadNamespaceAndRelation checks that the specified namespace and relation exist in the
+// datastore.
+//
+// Returns NamespaceNotFoundError if the namespace cannot be found.
+// Returns RelationNotFoundError if the relation was not found in the namespace.
+// Returns the direct downstream error for all other unknown error.
+func ReadNamespaceAndRelation(
+ ctx context.Context,
+ namespace string,
+ relation string,
+ ds datastore.Reader,
+) (*core.NamespaceDefinition, *core.Relation, error) {
+ config, _, err := ds.ReadNamespaceByName(ctx, namespace)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ for _, rel := range config.Relation {
+ if rel.Name == relation {
+ return config, rel, nil
+ }
+ }
+
+ return nil, nil, NewRelationNotFoundErr(namespace, relation)
+}
+
+// TypeAndRelationToCheck is a single check of a namespace+relation pair,
+// as consumed by CheckNamespaceAndRelations.
+type TypeAndRelationToCheck struct {
+	// NamespaceName is the namespace name to ensure exists.
+	NamespaceName string
+
+	// RelationName is the relation name to ensure exists under the namespace.
+	RelationName string
+
+	// AllowEllipsis, if true, allows for the ellipsis as the RelationName.
+	AllowEllipsis bool
+}
+
+// CheckNamespaceAndRelations ensures that the given namespace+relation checks all succeed. If any fail, returns an error.
+//
+// Returns NamespaceNotFoundError if the namespace cannot be found.
+// Returns RelationNotFoundError if the relation was not found in the namespace.
+// Returns the direct downstream error for all other unknown error.
+func CheckNamespaceAndRelations(ctx context.Context, checks []TypeAndRelationToCheck, ds datastore.Reader) error {
+ nsNames := mapz.NewSet[string]()
+ for _, toCheck := range checks {
+ nsNames.Insert(toCheck.NamespaceName)
+ }
+
+ if nsNames.IsEmpty() {
+ return nil
+ }
+
+ namespaces, err := ds.LookupNamespacesWithNames(ctx, nsNames.AsSlice())
+ if err != nil {
+ return err
+ }
+
+ mappedNamespaces := make(map[string]*core.NamespaceDefinition, len(namespaces))
+ for _, namespace := range namespaces {
+ mappedNamespaces[namespace.Definition.Name] = namespace.Definition
+ }
+
+ for _, toCheck := range checks {
+ nsDef, ok := mappedNamespaces[toCheck.NamespaceName]
+ if !ok {
+ return NewNamespaceNotFoundErr(toCheck.NamespaceName)
+ }
+
+ if toCheck.AllowEllipsis && toCheck.RelationName == datastore.Ellipsis {
+ continue
+ }
+
+ foundRelation := false
+ for _, rel := range nsDef.Relation {
+ if rel.Name == toCheck.RelationName {
+ foundRelation = true
+ break
+ }
+ }
+
+ if !foundRelation {
+ return NewRelationNotFoundErr(toCheck.NamespaceName, toCheck.RelationName)
+ }
+ }
+
+ return nil
+}
+
+// CheckNamespaceAndRelation checks that the specified namespace and relation exist in the
+// datastore.
+//
+// Returns datastore.NamespaceNotFoundError if the namespace cannot be found.
+// Returns RelationNotFoundError if the relation was not found in the namespace.
+// Returns the direct downstream error for all other unknown error.
+func CheckNamespaceAndRelation(
+ ctx context.Context,
+ namespace string,
+ relation string,
+ allowEllipsis bool,
+ ds datastore.Reader,
+) error {
+ config, _, err := ds.ReadNamespaceByName(ctx, namespace)
+ if err != nil {
+ return err
+ }
+
+ if allowEllipsis && relation == datastore.Ellipsis {
+ return nil
+ }
+
+ for _, rel := range config.Relation {
+ if rel.Name == relation {
+ return nil
+ }
+ }
+
+ return NewRelationNotFoundErr(namespace, relation)
+}
+
+// ListReferencedNamespaces returns the names of all namespaces referenced in the
+// given namespace definitions. This includes the namespaces themselves, as well as
+// any found in type information on relations.
+func ListReferencedNamespaces(nsdefs []*core.NamespaceDefinition) []string {
+ referencedNamespaceNamesSet := mapz.NewSet[string]()
+ for _, nsdef := range nsdefs {
+ referencedNamespaceNamesSet.Insert(nsdef.Name)
+
+ for _, relation := range nsdef.Relation {
+ if relation.GetTypeInformation() != nil {
+ for _, allowedRel := range relation.GetTypeInformation().AllowedDirectRelations {
+ referencedNamespaceNamesSet.Insert(allowedRel.GetNamespace())
+ }
+ }
+ }
+ }
+ return referencedNamespaceNamesSet.AsSlice()
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/relationships/doc.go b/vendor/github.com/authzed/spicedb/internal/relationships/doc.go
new file mode 100644
index 0000000..6e1bfc6
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/relationships/doc.go
@@ -0,0 +1,2 @@
+// Package relationships contains helper methods to validate relationships that are going to be written.
+package relationships
diff --git a/vendor/github.com/authzed/spicedb/internal/relationships/errors.go b/vendor/github.com/authzed/spicedb/internal/relationships/errors.go
new file mode 100644
index 0000000..3237e0b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/relationships/errors.go
@@ -0,0 +1,195 @@
+package relationships
+
+import (
+ "fmt"
+ "maps"
+ "sort"
+ "strings"
+
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/lithammer/fuzzysearch/fuzzy"
+
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// InvalidSubjectTypeError indicates that a write was attempted with a subject type which is not
+// allowed on relation.
+type InvalidSubjectTypeError struct {
+	error
+	// relationship is the relationship that failed validation.
+	relationship tuple.Relationship
+	// relationType is the (disallowed) allowed-relation form of the subject.
+	relationType *core.AllowedRelation
+	// additionalDetails holds extra key/value pairs merged into the gRPC
+	// error details (see GRPCStatus); may be nil.
+	additionalDetails map[string]string
+}
+
+// NewInvalidSubjectTypeError constructs a new error for attempting to write an invalid subject type.
+//
+// The message is specialized: when the subject would be allowed with a caveat,
+// the allowed caveats are listed; otherwise the closest allowed type (by fuzzy
+// match) is suggested, when one exists.
+func NewInvalidSubjectTypeError(
+	relationship tuple.Relationship,
+	relationType *core.AllowedRelation,
+	definition *schema.Definition,
+) error {
+	allowedTypes, err := definition.AllowedDirectRelationsAndWildcards(relationship.Resource.Relation)
+	if err != nil {
+		return err
+	}
+
+	// Special case: if the subject is uncaveated but only a caveated version is allowed, return
+	// a more descriptive error.
+	if relationship.OptionalCaveat == nil {
+		allowedCaveatsForSubject := mapz.NewSet[string]()
+
+		// Collect the caveats under which this subject type/relation *would* be
+		// allowed; the expiration requirement must match as well.
+		for _, allowedType := range allowedTypes {
+			if allowedType.RequiredCaveat != nil &&
+				allowedType.RequiredCaveat.CaveatName != "" &&
+				allowedType.Namespace == relationship.Subject.ObjectType &&
+				allowedType.GetRelation() == relationship.Subject.Relation &&
+				(allowedType.RequiredExpiration != nil) == (relationship.OptionalExpiration != nil) {
+				allowedCaveatsForSubject.Add(allowedType.RequiredCaveat.CaveatName)
+			}
+		}
+
+		if !allowedCaveatsForSubject.IsEmpty() {
+			return InvalidSubjectTypeError{
+				error: fmt.Errorf(
+					"subjects of type `%s` are not allowed on relation `%s#%s` without one of the following caveats: %s",
+					schema.SourceForAllowedRelation(relationType),
+					relationship.Resource.ObjectType,
+					relationship.Resource.Relation,
+					strings.Join(allowedCaveatsForSubject.AsSlice(), ","),
+				),
+				relationship: relationship,
+				relationType: relationType,
+				additionalDetails: map[string]string{
+					"allowed_caveats": strings.Join(allowedCaveatsForSubject.AsSlice(), ","),
+				},
+			}
+		}
+	}
+
+	// Otherwise, suggest the closest allowed type by fuzzy matching, if any.
+	allowedTypeStrings := make([]string, 0, len(allowedTypes))
+	for _, allowedType := range allowedTypes {
+		allowedTypeStrings = append(allowedTypeStrings, schema.SourceForAllowedRelation(allowedType))
+	}
+
+	matches := fuzzy.RankFind(schema.SourceForAllowedRelation(relationType), allowedTypeStrings)
+	sort.Sort(matches)
+	if len(matches) > 0 {
+		return InvalidSubjectTypeError{
+			error: fmt.Errorf(
+				"subjects of type `%s` are not allowed on relation `%s#%s`; did you mean `%s`?",
+				schema.SourceForAllowedRelation(relationType),
+				relationship.Resource.ObjectType,
+				relationship.Resource.Relation,
+				matches[0].Target,
+			),
+			relationship:      relationship,
+			relationType:      relationType,
+			additionalDetails: nil,
+		}
+	}
+
+	// No suggestion available; return the plain error.
+	return InvalidSubjectTypeError{
+		error: fmt.Errorf(
+			"subjects of type `%s` are not allowed on relation `%s#%s`",
+			schema.SourceForAllowedRelation(relationType),
+			relationship.Resource.ObjectType,
+			relationship.Resource.Relation,
+		),
+		relationship:      relationship,
+		relationType:      relationType,
+		additionalDetails: nil,
+	}
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err InvalidSubjectTypeError) GRPCStatus() *status.Status {
+ details := map[string]string{
+ "definition_name": err.relationship.Resource.ObjectType,
+ "relation_name": err.relationship.Resource.Relation,
+ "subject_type": schema.SourceForAllowedRelation(err.relationType),
+ }
+
+ if err.additionalDetails != nil {
+ maps.Copy(details, err.additionalDetails)
+ }
+
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_INVALID_SUBJECT_TYPE,
+ details,
+ ),
+ )
+}
+
+// CannotWriteToPermissionError indicates that a write was attempted on a permission.
+type CannotWriteToPermissionError struct {
+ error
+ rel tuple.Relationship
+}
+
+// NewCannotWriteToPermissionError constructs a new error for attempting to write to a permission.
+func NewCannotWriteToPermissionError(rel tuple.Relationship) CannotWriteToPermissionError {
+ return CannotWriteToPermissionError{
+ error: fmt.Errorf(
+ "cannot write a relationship to permission `%s` under definition `%s`",
+ rel.Resource.Relation,
+ rel.Resource.ObjectType,
+ ),
+ rel: rel,
+ }
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err CannotWriteToPermissionError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_CANNOT_UPDATE_PERMISSION,
+ map[string]string{
+ "definition_name": err.rel.Resource.ObjectType,
+ "permission_name": err.rel.Resource.Relation,
+ },
+ ),
+ )
+}
+
+// CaveatNotFoundError indicates that a caveat referenced in a relationship update was not found.
+type CaveatNotFoundError struct {
+	error
+	// relationship is the relationship whose caveat reference failed to resolve.
+	relationship tuple.Relationship
+}
+
+// NewCaveatNotFoundError constructs a new caveat not found error.
+// NOTE: relationship.OptionalCaveat must be non-nil; it is dereferenced below.
+func NewCaveatNotFoundError(relationship tuple.Relationship) CaveatNotFoundError {
+	return CaveatNotFoundError{
+		error: fmt.Errorf(
+			"the caveat `%s` was not found for relationship `%s`",
+			relationship.OptionalCaveat.CaveatName,
+			tuple.MustString(relationship),
+		),
+		relationship: relationship,
+	}
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err CaveatNotFoundError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.FailedPrecondition,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_UNKNOWN_CAVEAT,
+			map[string]string{
+				"caveat_name": err.relationship.OptionalCaveat.CaveatName,
+			},
+		),
+	)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/relationships/validation.go b/vendor/github.com/authzed/spicedb/internal/relationships/validation.go
new file mode 100644
index 0000000..ff4a6fb
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/relationships/validation.go
@@ -0,0 +1,280 @@
+package relationships
+
+import (
+ "context"
+
+ "github.com/samber/lo"
+
+ "github.com/authzed/spicedb/internal/namespace"
+ "github.com/authzed/spicedb/pkg/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ ns "github.com/authzed/spicedb/pkg/namespace"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// ValidateRelationshipUpdates performs validation on the given relationship updates, ensuring that
+// they can be applied against the datastore.
+func ValidateRelationshipUpdates(
+ ctx context.Context,
+ reader datastore.Reader,
+ caveatTypeSet *caveattypes.TypeSet,
+ updates []tuple.RelationshipUpdate,
+) error {
+ rels := lo.Map(updates, func(item tuple.RelationshipUpdate, _ int) tuple.Relationship {
+ return item.Relationship
+ })
+
+ // Load namespaces and caveats.
+ referencedNamespaceMap, referencedCaveatMap, err := loadNamespacesAndCaveats(ctx, rels, reader)
+ if err != nil {
+ return err
+ }
+
+ // Validate each updates's types.
+ for _, update := range updates {
+ option := ValidateRelationshipForCreateOrTouch
+ if update.Operation == tuple.UpdateOperationDelete {
+ option = ValidateRelationshipForDeletion
+ }
+
+ if err := ValidateOneRelationship(
+ referencedNamespaceMap,
+ referencedCaveatMap,
+ caveatTypeSet,
+ update.Relationship,
+ option,
+ ); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// ValidateRelationshipsForCreateOrTouch performs validation on the given relationships to be written, ensuring that
+// they can be applied against the datastore.
+//
+// NOTE: This method *cannot* be used for relationships that will be deleted.
+func ValidateRelationshipsForCreateOrTouch(
+ ctx context.Context,
+ reader datastore.Reader,
+ caveatTypeSet *caveattypes.TypeSet,
+ rels ...tuple.Relationship,
+) error {
+ // Load namespaces and caveats.
+ referencedNamespaceMap, referencedCaveatMap, err := loadNamespacesAndCaveats(ctx, rels, reader)
+ if err != nil {
+ return err
+ }
+
+ // Validate each relationship's types.
+ for _, rel := range rels {
+ if err := ValidateOneRelationship(
+ referencedNamespaceMap,
+ referencedCaveatMap,
+ caveatTypeSet,
+ rel,
+ ValidateRelationshipForCreateOrTouch,
+ ); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// loadNamespacesAndCaveats bulk-loads the namespace definitions and the
+// (context-bearing) caveat definitions referenced by the given relationships.
+// Either returned map may be nil when nothing of that kind is referenced.
+func loadNamespacesAndCaveats(ctx context.Context, rels []tuple.Relationship, reader datastore.Reader) (map[string]*schema.Definition, map[string]*core.CaveatDefinition, error) {
+	referencedNamespaceNames := mapz.NewSet[string]()
+	referencedCaveatNamesWithContext := mapz.NewSet[string]()
+	for _, rel := range rels {
+		referencedNamespaceNames.Insert(rel.Resource.ObjectType)
+		referencedNamespaceNames.Insert(rel.Subject.ObjectType)
+		// Only caveats referenced with a non-empty context need their
+		// definitions loaded (for context type checking downstream).
+		if hasNonEmptyCaveatContext(rel) {
+			referencedCaveatNamesWithContext.Insert(rel.OptionalCaveat.CaveatName)
+		}
+	}
+
+	var referencedNamespaceMap map[string]*schema.Definition
+	var referencedCaveatMap map[string]*core.CaveatDefinition
+
+	if !referencedNamespaceNames.IsEmpty() {
+		foundNamespaces, err := reader.LookupNamespacesWithNames(ctx, referencedNamespaceNames.AsSlice())
+		if err != nil {
+			return nil, nil, err
+		}
+		ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(reader))
+
+		// Wrap each loaded definition in the type system for later checks.
+		referencedNamespaceMap = make(map[string]*schema.Definition, len(foundNamespaces))
+		for _, nsDef := range foundNamespaces {
+			nts, err := schema.NewDefinition(ts, nsDef.Definition)
+			if err != nil {
+				return nil, nil, err
+			}
+
+			referencedNamespaceMap[nsDef.Definition.Name] = nts
+		}
+	}
+
+	if !referencedCaveatNamesWithContext.IsEmpty() {
+		foundCaveats, err := reader.LookupCaveatsWithNames(ctx, referencedCaveatNamesWithContext.AsSlice())
+		if err != nil {
+			return nil, nil, err
+		}
+
+		referencedCaveatMap = make(map[string]*core.CaveatDefinition, len(foundCaveats))
+		for _, caveatDef := range foundCaveats {
+			referencedCaveatMap[caveatDef.Definition.Name] = caveatDef.Definition
+		}
+	}
+	return referencedNamespaceMap, referencedCaveatMap, nil
+}
+
+// ValidationRelationshipRule is the rule to use for the validation.
+// See ValidateOneRelationship for how each rule is applied.
+type ValidationRelationshipRule int
+
+const (
+	// ValidateRelationshipForCreateOrTouch indicates that the validation should occur for a CREATE or TOUCH operation.
+	ValidateRelationshipForCreateOrTouch ValidationRelationshipRule = 0
+
+	// ValidateRelationshipForDeletion indicates that the validation should occur for a DELETE operation.
+	ValidateRelationshipForDeletion ValidationRelationshipRule = 1
+)
+
+// ValidateOneRelationship validates a single relationship for CREATE/TOUCH or DELETE.
+//
+// namespaceMap and caveatMap must already contain the definitions referenced by
+// the relationship (see loadNamespacesAndCaveats).
+func ValidateOneRelationship(
+	namespaceMap map[string]*schema.Definition,
+	caveatMap map[string]*core.CaveatDefinition,
+	caveatTypeSet *caveattypes.TypeSet,
+	rel tuple.Relationship,
+	rule ValidationRelationshipRule,
+) error {
+	// Validate the IDs of the resource and subject.
+	if err := tuple.ValidateResourceID(rel.Resource.ObjectID); err != nil {
+		return err
+	}
+
+	if err := tuple.ValidateSubjectID(rel.Subject.ObjectID); err != nil {
+		return err
+	}
+
+	// Validate the namespace and relation for the resource.
+	resourceTS, ok := namespaceMap[rel.Resource.ObjectType]
+	if !ok {
+		return namespace.NewNamespaceNotFoundErr(rel.Resource.ObjectType)
+	}
+
+	if !resourceTS.HasRelation(rel.Resource.Relation) {
+		return namespace.NewRelationNotFoundErr(rel.Resource.ObjectType, rel.Resource.Relation)
+	}
+
+	// Validate the namespace and relation for the subject.
+	subjectTS, ok := namespaceMap[rel.Subject.ObjectType]
+	if !ok {
+		return namespace.NewNamespaceNotFoundErr(rel.Subject.ObjectType)
+	}
+
+	// The ellipsis subject relation is not declared on definitions; the
+	// existence check is skipped for it.
+	if rel.Subject.Relation != tuple.Ellipsis {
+		if !subjectTS.HasRelation(rel.Subject.Relation) {
+			return namespace.NewRelationNotFoundErr(rel.Subject.ObjectType, rel.Subject.Relation)
+		}
+	}
+
+	// Validate that the relationship is not writing to a permission.
+	if resourceTS.IsPermission(rel.Resource.Relation) {
+		return NewCannotWriteToPermissionError(rel)
+	}
+
+	// Validate the subject against the allowed relation(s).
+	var caveat *core.AllowedCaveat
+	if rel.OptionalCaveat != nil {
+		caveat = ns.AllowedCaveat(rel.OptionalCaveat.CaveatName)
+	}
+
+	// Build the allowed-relation form of the subject (wildcard vs concrete),
+	// including any caveat and expiration requirement.
+	var relationToCheck *core.AllowedRelation
+	if rel.Subject.ObjectID == tuple.PublicWildcard {
+		relationToCheck = ns.AllowedPublicNamespaceWithCaveat(rel.Subject.ObjectType, caveat)
+	} else {
+		relationToCheck = ns.AllowedRelationWithCaveat(
+			rel.Subject.ObjectType,
+			rel.Subject.Relation,
+			caveat)
+	}
+
+	if rel.OptionalExpiration != nil {
+		relationToCheck = ns.WithExpiration(relationToCheck)
+	}
+
+	switch {
+	case rule == ValidateRelationshipForCreateOrTouch || caveat != nil:
+		// For writing or when the caveat was specified, the caveat must be a direct match.
+		isAllowed, err := resourceTS.HasAllowedRelation(
+			rel.Resource.Relation,
+			relationToCheck)
+		if err != nil {
+			return err
+		}
+
+		if isAllowed != schema.AllowedRelationValid {
+			return NewInvalidSubjectTypeError(rel, relationToCheck, resourceTS)
+		}
+
+	case rule == ValidateRelationshipForDeletion && caveat == nil:
+		// For deletion, the caveat *can* be ignored if not specified.
+		if rel.Subject.ObjectID == tuple.PublicWildcard {
+			isAllowed, err := resourceTS.IsAllowedPublicNamespace(rel.Resource.Relation, rel.Subject.ObjectType)
+			if err != nil {
+				return err
+			}
+
+			if isAllowed != schema.PublicSubjectAllowed {
+				return NewInvalidSubjectTypeError(rel, relationToCheck, resourceTS)
+			}
+		} else {
+			isAllowed, err := resourceTS.IsAllowedDirectRelation(rel.Resource.Relation, rel.Subject.ObjectType, rel.Subject.Relation)
+			if err != nil {
+				return err
+			}
+
+			if isAllowed != schema.DirectRelationValid {
+				return NewInvalidSubjectTypeError(rel, relationToCheck, resourceTS)
+			}
+		}
+
+	default:
+		return spiceerrors.MustBugf("unknown validate rule")
+	}
+
+	// Validate caveat and its context, if applicable.
+	if hasNonEmptyCaveatContext(rel) {
+		caveat, ok := caveatMap[rel.OptionalCaveat.CaveatName]
+		if !ok {
+			// Should ideally never happen since the caveat is type checked above, but just in case.
+			return NewCaveatNotFoundError(rel)
+		}
+
+		// Verify that the provided context information matches the types of the parameters defined.
+		_, err := caveats.ConvertContextToParameters(
+			caveatTypeSet,
+			rel.OptionalCaveat.Context.AsMap(),
+			caveat.ParameterTypes,
+			caveats.ErrorForUnknownParameters,
+		)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func hasNonEmptyCaveatContext(relationship tuple.Relationship) bool {
+ return relationship.OptionalCaveat != nil &&
+ relationship.OptionalCaveat.CaveatName != "" &&
+ relationship.OptionalCaveat.Context != nil &&
+ len(relationship.OptionalCaveat.Context.GetFields()) > 0
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go b/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go
new file mode 100644
index 0000000..05b3907
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go
@@ -0,0 +1,208 @@
+package shared
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strconv"
+
+ "github.com/rs/zerolog"
+ "google.golang.org/genproto/googleapis/rpc/errdetails"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/graph"
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/internal/sharederrors"
+ "github.com/authzed/spicedb/pkg/cursor"
+ "github.com/authzed/spicedb/pkg/datastore"
+ dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// ErrServiceReadOnly is an extended GRPC error returned when a service is in read-only mode.
+var ErrServiceReadOnly = mustMakeStatusReadonly()
+
+// mustMakeStatusReadonly builds the shared read-only status (codes.Unavailable)
+// with a machine-readable ErrorInfo detail attached. It panics if the detail
+// cannot be attached, which can only surface at package initialization time.
+func mustMakeStatusReadonly() error {
+	status, err := status.New(codes.Unavailable, "service read-only").WithDetails(&errdetails.ErrorInfo{
+		Reason: v1.ErrorReason_name[int32(v1.ErrorReason_ERROR_REASON_SERVICE_READ_ONLY)],
+		Domain: spiceerrors.Domain,
+	})
+	if err != nil {
+		panic("error constructing shared error type")
+	}
+	return status.Err()
+}
+
+// NewSchemaWriteDataValidationError creates a new error representing that a schema write cannot be
+// completed due to existing data that would be left unreferenced.
+// The message and args follow fmt.Errorf formatting semantics.
+func NewSchemaWriteDataValidationError(message string, args ...any) SchemaWriteDataValidationError {
+	return SchemaWriteDataValidationError{
+		error: fmt.Errorf(message, args...),
+	}
+}
+
+// SchemaWriteDataValidationError occurs when a schema cannot be applied due to leaving data unreferenced.
+type SchemaWriteDataValidationError struct {
+	error // embedded underlying error; supplies the Error() method
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err SchemaWriteDataValidationError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error)
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+// The error surfaces as InvalidArgument with the SCHEMA_TYPE_ERROR reason.
+func (err SchemaWriteDataValidationError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.InvalidArgument,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_SCHEMA_TYPE_ERROR,
+			map[string]string{},
+		),
+	)
+}
+
+// MaxDepthExceededError is an error returned when the maximum depth for dispatching has been exceeded.
+type MaxDepthExceededError struct {
+	*spiceerrors.WithAdditionalDetailsError
+
+	// AllowedMaximumDepth is the configured allowed maximum depth.
+	AllowedMaximumDepth uint32
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+// The error surfaces as ResourceExhausted with the MAXIMUM_DEPTH_EXCEEDED
+// reason and includes the configured depth limit in the error details.
+func (err MaxDepthExceededError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.ResourceExhausted,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_MAXIMUM_DEPTH_EXCEEDED,
+			err.AddToDetails(map[string]string{
+				"maximum_depth_allowed": strconv.Itoa(int(err.AllowedMaximumDepth)),
+			}),
+		),
+	)
+}
+
+// NewMaxDepthExceededError creates a new MaxDepthExceededError.
+// isCheckRequest selects a check-specific message that points at the
+// zed --explain debugging workflow; other request kinds get the generic message.
+func NewMaxDepthExceededError(allowedMaximumDepth uint32, isCheckRequest bool) error {
+	if isCheckRequest {
+		return MaxDepthExceededError{
+			spiceerrors.NewWithAdditionalDetailsError(fmt.Errorf("the check request has exceeded the allowable maximum depth of %d: this usually indicates a recursive or too deep data dependency. Try running zed with --explain to see the dependency. See: https://spicedb.dev/d/debug-max-depth-check", allowedMaximumDepth)),
+			allowedMaximumDepth,
+		}
+	}
+
+	return MaxDepthExceededError{
+		spiceerrors.NewWithAdditionalDetailsError(fmt.Errorf("the request has exceeded the allowable maximum depth of %d: this usually indicates a recursive or too deep data dependency. See: https://spicedb.dev/d/debug-max-depth", allowedMaximumDepth)),
+		allowedMaximumDepth,
+	}
+}
+
+// AsValidationError returns the SchemaWriteDataValidationError found in the
+// given error's chain, or nil if err does not wrap one.
+func AsValidationError(err error) *SchemaWriteDataValidationError {
+	var validationErr SchemaWriteDataValidationError
+	if errors.As(err, &validationErr) {
+		return &validationErr
+	}
+	return nil
+}
+
+// ConfigForErrors holds configuration consulted when rewriting internal errors
+// into public API errors.
+type ConfigForErrors struct {
+	// MaximumAPIDepth is the configured maximum dispatch depth, reported when
+	// a MaxDepthExceededError is produced.
+	MaximumAPIDepth uint32
+
+	// DebugTrace, if non-nil, is attached to the rewritten error's details.
+	DebugTrace *v1.DebugInformation
+}
+
+// RewriteErrorWithoutConfig converts an internal error into a public-facing
+// gRPC error without any error-rewriting configuration. Errors that require
+// config (e.g. max-depth errors) will be reported as bugs by rewriteError.
+func RewriteErrorWithoutConfig(ctx context.Context, err error) error {
+	return rewriteError(ctx, err, nil)
+}
+
+// RewriteError converts an internal error into an appropriate public-facing
+// gRPC error and, when a debug trace is present in the config, attaches it to
+// the rewritten error's additional details.
+func RewriteError(ctx context.Context, err error, config *ConfigForErrors) error {
+	rerr := rewriteError(ctx, err, config)
+	if config != nil && config.DebugTrace != nil {
+		// NOTE(review): the return value of WithAdditionalDetails is discarded;
+		// presumably it mutates rerr in place — confirm against the
+		// spiceerrors package before relying on the trace being attached.
+		spiceerrors.WithAdditionalDetails(rerr, spiceerrors.DebugTraceErrorDetailsKey, config.DebugTrace.String())
+	}
+	return rerr
+}
+
+// rewriteError maps known internal error types (schema/type errors, cursor
+// mismatches, datastore and graph errors, context cancellation) onto gRPC
+// status errors with appropriate codes and machine-readable reasons. Errors
+// that already carry a gRPC status pass through unchanged; unrecognized
+// errors are logged and returned as-is.
+func rewriteError(ctx context.Context, err error, config *ConfigForErrors) error {
+	// Check if the error can be directly used.
+	if _, ok := status.FromError(err); ok {
+		return err
+	}
+
+	// Otherwise, convert any graph/datastore errors.
+	var nsNotFoundError sharederrors.UnknownNamespaceError
+	var relationNotFoundError sharederrors.UnknownRelationError
+
+	var compilerError compiler.BaseCompilerError
+	var sourceError spiceerrors.WithSourceError
+	var typeError schema.TypeError
+	var maxDepthError dispatch.MaxDepthExceededError
+
+	switch {
+	case errors.As(err, &typeError):
+		return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_SCHEMA_TYPE_ERROR)
+	case errors.As(err, &compilerError):
+		return spiceerrors.WithCodeAndReason(err, codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_SCHEMA_PARSE_ERROR)
+	case errors.As(err, &sourceError):
+		return spiceerrors.WithCodeAndReason(err, codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_SCHEMA_PARSE_ERROR)
+
+	case errors.Is(err, cursor.ErrHashMismatch):
+		return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_INVALID_CURSOR)
+
+	case errors.As(err, &nsNotFoundError):
+		return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_DEFINITION)
+	case errors.As(err, &relationNotFoundError):
+		return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_RELATION_OR_PERMISSION)
+
+	case errors.As(err, &maxDepthError):
+		// Max-depth errors require the configured API depth to build the
+		// public error; reaching here without config is a caller bug.
+		if config == nil {
+			return spiceerrors.MustBugf("missing config for API error")
+		}
+
+		_, isCheckRequest := maxDepthError.Request.(*dispatchv1.DispatchCheckRequest)
+		return NewMaxDepthExceededError(config.MaximumAPIDepth, isCheckRequest)
+
+	case errors.As(err, &datastore.ReadOnlyError{}):
+		return ErrServiceReadOnly
+	case errors.As(err, &datastore.InvalidRevisionError{}):
+		return status.Errorf(codes.OutOfRange, "invalid zedtoken: %s", err)
+	case errors.As(err, &datastore.CaveatNameNotFoundError{}):
+		return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_CAVEAT)
+	case errors.As(err, &datastore.WatchDisabledError{}):
+		return status.Errorf(codes.FailedPrecondition, "%s", err)
+	case errors.As(err, &datastore.CounterAlreadyRegisteredError{}):
+		return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_COUNTER_ALREADY_REGISTERED)
+	case errors.As(err, &datastore.CounterNotRegisteredError{}):
+		return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_COUNTER_NOT_REGISTERED)
+
+	case errors.As(err, &graph.RelationMissingTypeInfoError{}):
+		return status.Errorf(codes.FailedPrecondition, "failed precondition: %s", err)
+	case errors.As(err, &graph.AlwaysFailError{}):
+		log.Ctx(ctx).Err(err).Msg("received internal error")
+		return status.Errorf(codes.Internal, "internal error: %s", err)
+	case errors.As(err, &graph.UnimplementedError{}):
+		return status.Errorf(codes.Unimplemented, "%s", err)
+	case errors.Is(err, context.DeadlineExceeded):
+		return status.Errorf(codes.DeadlineExceeded, "%s", err)
+	case errors.Is(err, context.Canceled):
+		// Prefer the cancellation cause: if it already carries a gRPC status,
+		// return it directly.
+		// NOTE(review): the inner `err :=` shadows the outer err, so the
+		// Canceled message below formats the cause, which may be nil when ctx
+		// itself was not canceled — confirm this is intended.
+		err := context.Cause(ctx)
+		if err != nil {
+			if _, ok := status.FromError(err); ok {
+				return err
+			}
+		}
+
+		return status.Errorf(codes.Canceled, "%s", err)
+	default:
+		log.Ctx(ctx).Err(err).Msg("received unexpected error")
+		return err
+	}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go b/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go
new file mode 100644
index 0000000..455de0a
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go
@@ -0,0 +1,52 @@
+package shared
+
+import (
+ "google.golang.org/grpc"
+
+ "github.com/authzed/spicedb/internal/middleware/servicespecific"
+)
+
+// WithUnaryServiceSpecificInterceptor is a helper to add a unary interceptor or interceptor
+// chain to a service.
+type WithUnaryServiceSpecificInterceptor struct {
+	// Unary is the interceptor (or chain) returned by UnaryInterceptor.
+	Unary grpc.UnaryServerInterceptor
+}
+
+// UnaryInterceptor implements servicespecific.ExtraUnaryInterceptor
+func (wussi WithUnaryServiceSpecificInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
+	return wussi.Unary
+}
+
+// WithStreamServiceSpecificInterceptor is a helper to add a stream interceptor or interceptor
+// chain to a service.
+type WithStreamServiceSpecificInterceptor struct {
+	// Stream is the interceptor (or chain) returned by StreamInterceptor.
+	Stream grpc.StreamServerInterceptor
+}
+
+// StreamInterceptor implements servicespecific.ExtraStreamInterceptor
+func (wsssi WithStreamServiceSpecificInterceptor) StreamInterceptor() grpc.StreamServerInterceptor {
+	return wsssi.Stream
+}
+
+// WithServiceSpecificInterceptors is a helper to add both a unary and stream interceptor
+// or interceptor chain to a service.
+type WithServiceSpecificInterceptors struct {
+	// Unary is the interceptor (or chain) returned by UnaryInterceptor.
+	Unary grpc.UnaryServerInterceptor
+	// Stream is the interceptor (or chain) returned by StreamInterceptor.
+	Stream grpc.StreamServerInterceptor
+}
+
+// UnaryInterceptor implements servicespecific.ExtraUnaryInterceptor
+func (wssi WithServiceSpecificInterceptors) UnaryInterceptor() grpc.UnaryServerInterceptor {
+	return wssi.Unary
+}
+
+// StreamInterceptor implements servicespecific.ExtraStreamInterceptor
+func (wssi WithServiceSpecificInterceptors) StreamInterceptor() grpc.StreamServerInterceptor {
+	return wssi.Stream
+}
+
+// Compile-time interface conformance checks.
+// NOTE(review): there is no matching assertion that
+// WithStreamServiceSpecificInterceptor implements ExtraStreamInterceptor,
+// though its StreamInterceptor method suggests it should — confirm upstream.
+var (
+	_ servicespecific.ExtraUnaryInterceptor  = WithUnaryServiceSpecificInterceptor{}
+	_ servicespecific.ExtraUnaryInterceptor  = WithServiceSpecificInterceptors{}
+	_ servicespecific.ExtraStreamInterceptor = WithServiceSpecificInterceptors{}
+)
diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go b/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go
new file mode 100644
index 0000000..83accde
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go
@@ -0,0 +1,474 @@
+package shared
+
+import (
+ "context"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/internal/namespace"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats"
+ nsdiff "github.com/authzed/spicedb/pkg/diff/namespace"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// ValidatedSchemaChanges is a set of validated schema changes that can be applied to the datastore.
+type ValidatedSchemaChanges struct {
+	// compiled is the compiled schema whose definitions were validated.
+	compiled *compiler.CompiledSchema
+	// validatedTypeSystems maps each object definition name to its validated definition.
+	validatedTypeSystems map[string]*schema.ValidatedDefinition
+	// newCaveatDefNames holds the names of all caveats in the new schema.
+	newCaveatDefNames *mapz.Set[string]
+	// newObjectDefNames holds the names of all object definitions in the new schema.
+	newObjectDefNames *mapz.Set[string]
+	// additiveOnly, when true, skips deletion of removed definitions on apply.
+	additiveOnly bool
+}
+
+// ValidateSchemaChanges validates the schema found in the compiled schema and returns a
+// ValidatedSchemaChanges, if fully validated.
+func ValidateSchemaChanges(ctx context.Context, compiled *compiler.CompiledSchema, caveatTypeSet *caveattypes.TypeSet, additiveOnly bool) (*ValidatedSchemaChanges, error) {
+	// 1) Validate the caveats defined.
+	newCaveatDefNames := mapz.NewSet[string]()
+	for _, caveatDef := range compiled.CaveatDefinitions {
+		if err := namespace.ValidateCaveatDefinition(caveatTypeSet, caveatDef); err != nil {
+			return nil, err
+		}
+
+		newCaveatDefNames.Insert(caveatDef.Name)
+	}
+
+	// 2) Validate the namespaces defined.
+	// The resolver is seeded with the compiled definitions so validation can
+	// resolve references within the new schema itself.
+	newObjectDefNames := mapz.NewSet[string]()
+	validatedTypeSystems := make(map[string]*schema.ValidatedDefinition, len(compiled.ObjectDefinitions))
+	res := schema.ResolverForPredefinedDefinitions(schema.PredefinedElements{
+		Definitions: compiled.ObjectDefinitions,
+		Caveats:     compiled.CaveatDefinitions,
+	})
+	ts := schema.NewTypeSystem(res)
+
+	for _, nsdef := range compiled.ObjectDefinitions {
+		vts, err := ts.GetValidatedDefinition(ctx, nsdef.GetName())
+		if err != nil {
+			return nil, err
+		}
+
+		validatedTypeSystems[nsdef.Name] = vts
+		newObjectDefNames.Insert(nsdef.Name)
+	}
+
+	return &ValidatedSchemaChanges{
+		compiled:             compiled,
+		validatedTypeSystems: validatedTypeSystems,
+		newCaveatDefNames:    newCaveatDefNames,
+		newObjectDefNames:    newObjectDefNames,
+		additiveOnly:         additiveOnly,
+	}, nil
+}
+
+// AppliedSchemaChanges holds information about the applied schema changes.
+type AppliedSchemaChanges struct {
+	// TotalOperationCount holds the total number of "dispatch" operations performed by the schema
+	// being applied. Computed as all definitions written plus all definitions removed.
+	TotalOperationCount int
+
+	// NewObjectDefNames contains the names of the newly added object definitions.
+	NewObjectDefNames []string
+
+	// RemovedObjectDefNames contains the names of the removed object definitions.
+	RemovedObjectDefNames []string
+
+	// NewCaveatDefNames contains the names of the newly added caveat definitions.
+	NewCaveatDefNames []string
+
+	// RemovedCaveatDefNames contains the names of the removed caveat definitions.
+	RemovedCaveatDefNames []string
+}
+
+// ApplySchemaChanges applies schema changes found in the validated changes struct, via the specified
+// ReadWriteTransaction. It loads the existing caveats and namespaces from the
+// transaction and delegates to ApplySchemaChangesOverExisting.
+func ApplySchemaChanges(ctx context.Context, rwt datastore.ReadWriteTransaction, caveatTypeSet *caveattypes.TypeSet, validated *ValidatedSchemaChanges) (*AppliedSchemaChanges, error) {
+	existingCaveats, err := rwt.ListAllCaveats(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	existingObjectDefs, err := rwt.ListAllNamespaces(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	return ApplySchemaChangesOverExisting(ctx, rwt, caveatTypeSet, validated, datastore.DefinitionsOf(existingCaveats), datastore.DefinitionsOf(existingObjectDefs))
+}
+
+// ApplySchemaChangesOverExisting applies schema changes found in the validated changes struct, against
+// existing caveat and object definitions given. Only definitions that actually
+// differ from their existing counterparts are written; removals are performed
+// only when the changes are not additive-only and are first checked for
+// relationships that would be orphaned.
+func ApplySchemaChangesOverExisting(
+	ctx context.Context,
+	rwt datastore.ReadWriteTransaction,
+	caveatTypeSet *caveattypes.TypeSet,
+	validated *ValidatedSchemaChanges,
+	existingCaveats []*core.CaveatDefinition,
+	existingObjectDefs []*core.NamespaceDefinition,
+) (*AppliedSchemaChanges, error) {
+	// Build a map of existing caveats to determine those being removed, if any.
+	existingCaveatDefMap := make(map[string]*core.CaveatDefinition, len(existingCaveats))
+	existingCaveatDefNames := mapz.NewSet[string]()
+
+	for _, existingCaveat := range existingCaveats {
+		existingCaveatDefMap[existingCaveat.Name] = existingCaveat
+		existingCaveatDefNames.Insert(existingCaveat.Name)
+	}
+
+	// For each caveat definition, perform a diff and ensure the changes will not result in type errors.
+	caveatDefsWithChanges := make([]*core.CaveatDefinition, 0, len(validated.compiled.CaveatDefinitions))
+	for _, caveatDef := range validated.compiled.CaveatDefinitions {
+		diff, err := sanityCheckCaveatChanges(ctx, rwt, caveatTypeSet, caveatDef, existingCaveatDefMap)
+		if err != nil {
+			return nil, err
+		}
+
+		if len(diff.Deltas()) > 0 {
+			caveatDefsWithChanges = append(caveatDefsWithChanges, caveatDef)
+		}
+	}
+
+	removedCaveatDefNames := existingCaveatDefNames.Subtract(validated.newCaveatDefNames)
+
+	// Build a map of existing definitions to determine those being removed, if any.
+	existingObjectDefMap := make(map[string]*core.NamespaceDefinition, len(existingObjectDefs))
+	existingObjectDefNames := mapz.NewSet[string]()
+	for _, existingDef := range existingObjectDefs {
+		existingObjectDefMap[existingDef.Name] = existingDef
+		existingObjectDefNames.Insert(existingDef.Name)
+	}
+
+	// For each definition, perform a diff and ensure the changes will not result in any
+	// breaking changes.
+	objectDefsWithChanges := make([]*core.NamespaceDefinition, 0, len(validated.compiled.ObjectDefinitions))
+	for _, nsdef := range validated.compiled.ObjectDefinitions {
+		diff, err := sanityCheckNamespaceChanges(ctx, rwt, nsdef, existingObjectDefMap)
+		if err != nil {
+			return nil, err
+		}
+
+		if len(diff.Deltas()) > 0 {
+			objectDefsWithChanges = append(objectDefsWithChanges, nsdef)
+
+			vts, ok := validated.validatedTypeSystems[nsdef.Name]
+			if !ok {
+				return nil, spiceerrors.MustBugf("validated type system not found for namespace `%s`", nsdef.Name)
+			}
+
+			// Annotate the definition before writing so that metadata from the
+			// validated type system is persisted with it.
+			if err := namespace.AnnotateNamespace(vts); err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	log.Ctx(ctx).
+		Trace().
+		Int("objectDefinitions", len(validated.compiled.ObjectDefinitions)).
+		Int("caveatDefinitions", len(validated.compiled.CaveatDefinitions)).
+		Int("objectDefsWithChanges", len(objectDefsWithChanges)).
+		Int("caveatDefsWithChanges", len(caveatDefsWithChanges)).
+		Msg("validated namespace definitions")
+
+	// Ensure that deleting namespaces will not result in any relationships left without associated
+	// schema.
+	removedObjectDefNames := existingObjectDefNames.Subtract(validated.newObjectDefNames)
+	if !validated.additiveOnly {
+		if err := removedObjectDefNames.ForEach(func(nsdefName string) error {
+			return ensureNoRelationshipsExist(ctx, rwt, nsdefName)
+		}); err != nil {
+			return nil, err
+		}
+	}
+
+	// Write the new/changed caveats.
+	if len(caveatDefsWithChanges) > 0 {
+		if err := rwt.WriteCaveats(ctx, caveatDefsWithChanges); err != nil {
+			return nil, err
+		}
+	}
+
+	// Write the new/changed namespaces.
+	if len(objectDefsWithChanges) > 0 {
+		if err := rwt.WriteNamespaces(ctx, objectDefsWithChanges...); err != nil {
+			return nil, err
+		}
+	}
+
+	if !validated.additiveOnly {
+		// Delete the removed namespaces.
+		if removedObjectDefNames.Len() > 0 {
+			if err := rwt.DeleteNamespaces(ctx, removedObjectDefNames.AsSlice()...); err != nil {
+				return nil, err
+			}
+		}
+
+		// Delete the removed caveats.
+		if !removedCaveatDefNames.IsEmpty() {
+			if err := rwt.DeleteCaveats(ctx, removedCaveatDefNames.AsSlice()); err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	log.Ctx(ctx).Trace().
+		Interface("objectDefinitions", validated.compiled.ObjectDefinitions).
+		Interface("caveatDefinitions", validated.compiled.CaveatDefinitions).
+		Object("addedOrChangedObjectDefinitions", validated.newObjectDefNames).
+		Object("removedObjectDefinitions", removedObjectDefNames).
+		Object("addedOrChangedCaveatDefinitions", validated.newCaveatDefNames).
+		Object("removedCaveatDefinitions", removedCaveatDefNames).
+		Msg("completed schema update")
+
+	return &AppliedSchemaChanges{
+		TotalOperationCount:   len(validated.compiled.ObjectDefinitions) + len(validated.compiled.CaveatDefinitions) + removedObjectDefNames.Len() + removedCaveatDefNames.Len(),
+		NewObjectDefNames:     validated.newObjectDefNames.Subtract(existingObjectDefNames).AsSlice(),
+		RemovedObjectDefNames: removedObjectDefNames.AsSlice(),
+		NewCaveatDefNames:     validated.newCaveatDefNames.Subtract(existingCaveatDefNames).AsSlice(),
+		RemovedCaveatDefNames: removedCaveatDefNames.AsSlice(),
+	}, nil
+}
+
+// sanityCheckCaveatChanges ensures that a caveat definition being written does not break
+// the types of the parameters that may already exist on relationships.
+// Removing a parameter or changing its type is rejected; all other deltas are allowed.
+func sanityCheckCaveatChanges(
+	_ context.Context,
+	_ datastore.ReadWriteTransaction,
+	caveatTypeSet *caveattypes.TypeSet,
+	caveatDef *core.CaveatDefinition,
+	existingDefs map[string]*core.CaveatDefinition,
+) (*caveatdiff.Diff, error) {
+	// Ensure that the updated namespace does not break the existing tuple data.
+	// A nil existing definition means the caveat is new; DiffCaveats handles that case.
+	existing := existingDefs[caveatDef.Name]
+	diff, err := caveatdiff.DiffCaveats(existing, caveatDef, caveatTypeSet)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, delta := range diff.Deltas() {
+		switch delta.Type {
+		case caveatdiff.RemovedParameter:
+			return diff, NewSchemaWriteDataValidationError("cannot remove parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name)
+
+		case caveatdiff.ParameterTypeChanged:
+			return diff, NewSchemaWriteDataValidationError("cannot change the type of parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name)
+		}
+	}
+
+	return diff, nil
+}
+
+// ensureNoRelationshipsExist ensures that no relationships exist within the namespace with the given name.
+// Both directions are checked: relationships whose resource is of the namespace
+// type, and relationships whose subject references it.
+func ensureNoRelationshipsExist(ctx context.Context, rwt datastore.ReadWriteTransaction, namespaceName string) error {
+	// Check the resource side: any relationship defined under the namespace.
+	qy, qyErr := rwt.QueryRelationships(
+		ctx,
+		datastore.RelationshipsFilter{OptionalResourceType: namespaceName},
+		options.WithLimit(options.LimitOne),
+		options.WithQueryShape(queryshape.FindResourceOfType),
+	)
+	if err := errorIfTupleIteratorReturnsTuples(
+		ctx,
+		qy,
+		qyErr,
+		"cannot delete object definition `%s`, as a relationship exists under it",
+		namespaceName,
+	); err != nil {
+		return err
+	}
+
+	// Check the subject side: any relationship referencing the namespace as subject.
+	qy, qyErr = rwt.ReverseQueryRelationships(
+		ctx,
+		datastore.SubjectsFilter{
+			SubjectType: namespaceName,
+		},
+		options.WithLimitForReverse(options.LimitOne),
+		options.WithQueryShapeForReverse(queryshape.FindSubjectOfType),
+	)
+	err := errorIfTupleIteratorReturnsTuples(
+		ctx,
+		qy,
+		qyErr,
+		"cannot delete object definition `%s`, as a relationship references it",
+		namespaceName,
+	)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// sanityCheckNamespaceChanges ensures that a namespace definition being written does not result
+// in breaking changes, such as relationships without associated defined schema object definitions
+// and relations. Removed relations and removed allowed subject types are each
+// checked for any surviving relationship data before the change is accepted.
+func sanityCheckNamespaceChanges(
+	ctx context.Context,
+	rwt datastore.ReadWriteTransaction,
+	nsdef *core.NamespaceDefinition,
+	existingDefs map[string]*core.NamespaceDefinition,
+) (*nsdiff.Diff, error) {
+	// Ensure that the updated namespace does not break the existing tuple data.
+	existing := existingDefs[nsdef.Name]
+	diff, err := nsdiff.DiffNamespaces(existing, nsdef)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, delta := range diff.Deltas() {
+		switch delta.Type {
+		case nsdiff.RemovedRelation:
+			// NOTE: We add the subject filters here to ensure the reverse relationship index is used
+			// by the datastores. As there is no index that has {namespace, relation} directly, but there
+			// *is* an index that has {subject_namespace, subject_relation, namespace, relation}, we can
+			// force the datastore to use the reverse index by adding the subject filters.
+			var previousRelation *core.Relation
+			for _, relation := range existing.Relation {
+				if relation.Name == delta.RelationName {
+					previousRelation = relation
+					break
+				}
+			}
+
+			if previousRelation == nil {
+				return nil, spiceerrors.MustBugf("relation `%s` not found in existing namespace definition", delta.RelationName)
+			}
+
+			// Build one subject selector per allowed subject type of the removed
+			// relation, distinguishing ellipsis ("...") from named subject relations.
+			subjectSelectors := make([]datastore.SubjectsSelector, 0, len(previousRelation.TypeInformation.AllowedDirectRelations))
+			for _, allowedType := range previousRelation.TypeInformation.AllowedDirectRelations {
+				if allowedType.GetRelation() == datastore.Ellipsis {
+					subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{
+						OptionalSubjectType: allowedType.Namespace,
+						RelationFilter: datastore.SubjectRelationFilter{
+							IncludeEllipsisRelation: true,
+						},
+					})
+				} else {
+					subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{
+						OptionalSubjectType: allowedType.Namespace,
+						RelationFilter: datastore.SubjectRelationFilter{
+							NonEllipsisRelation: allowedType.GetRelation(),
+						},
+					})
+				}
+			}
+
+			qy, qyErr := rwt.QueryRelationships(
+				ctx,
+				datastore.RelationshipsFilter{
+					OptionalResourceType:      nsdef.Name,
+					OptionalResourceRelation:  delta.RelationName,
+					OptionalSubjectsSelectors: subjectSelectors,
+				},
+				options.WithLimit(options.LimitOne),
+				options.WithQueryShape(queryshape.FindResourceOfTypeAndRelation),
+			)
+
+			err = errorIfTupleIteratorReturnsTuples(
+				ctx,
+				qy,
+				qyErr,
+				"cannot delete relation `%s` in object definition `%s`, as a relationship exists under it", delta.RelationName, nsdef.Name)
+			if err != nil {
+				return diff, err
+			}
+
+			// Also check for right sides of tuples.
+			qy, qyErr = rwt.ReverseQueryRelationships(
+				ctx,
+				datastore.SubjectsFilter{
+					SubjectType: nsdef.Name,
+					RelationFilter: datastore.SubjectRelationFilter{
+						NonEllipsisRelation: delta.RelationName,
+					},
+				},
+				options.WithLimitForReverse(options.LimitOne),
+				options.WithQueryShapeForReverse(queryshape.FindSubjectOfTypeAndRelation),
+			)
+			err = errorIfTupleIteratorReturnsTuples(
+				ctx,
+				qy,
+				qyErr,
+				"cannot delete relation `%s` in object definition `%s`, as a relationship references it", delta.RelationName, nsdef.Name)
+			if err != nil {
+				return diff, err
+			}
+
+		case nsdiff.RelationAllowedTypeRemoved:
+			// Narrow the query to exactly the removed allowed type: wildcard vs
+			// named subject relation, matching caveat (or none), and matching
+			// expiration requirement.
+			var optionalSubjectIds []string
+			var relationFilter datastore.SubjectRelationFilter
+			var optionalCaveatNameFilter datastore.CaveatNameFilter
+
+			if delta.AllowedType.GetPublicWildcard() != nil {
+				optionalSubjectIds = []string{tuple.PublicWildcard}
+			} else {
+				relationFilter = datastore.SubjectRelationFilter{
+					NonEllipsisRelation: delta.AllowedType.GetRelation(),
+				}
+			}
+
+			if delta.AllowedType.GetRequiredCaveat() != nil && delta.AllowedType.GetRequiredCaveat().CaveatName != "" {
+				optionalCaveatNameFilter = datastore.WithCaveatName(delta.AllowedType.GetRequiredCaveat().CaveatName)
+			} else {
+				optionalCaveatNameFilter = datastore.WithNoCaveat()
+			}
+
+			expirationOption := datastore.ExpirationFilterOptionNoExpiration
+			if delta.AllowedType.RequiredExpiration != nil {
+				expirationOption = datastore.ExpirationFilterOptionHasExpiration
+			}
+
+			qyr, qyrErr := rwt.QueryRelationships(
+				ctx,
+				datastore.RelationshipsFilter{
+					OptionalResourceType:     nsdef.Name,
+					OptionalResourceRelation: delta.RelationName,
+					OptionalSubjectsSelectors: []datastore.SubjectsSelector{
+						{
+							OptionalSubjectType: delta.AllowedType.Namespace,
+							OptionalSubjectIds:  optionalSubjectIds,
+							RelationFilter:      relationFilter,
+						},
+					},
+					OptionalCaveatNameFilter: optionalCaveatNameFilter,
+					OptionalExpirationOption: expirationOption,
+				},
+				options.WithLimit(options.LimitOne),
+				options.WithQueryShape(queryshape.FindResourceRelationForSubjectRelation),
+			)
+			err = errorIfTupleIteratorReturnsTuples(
+				ctx,
+				qyr,
+				qyrErr,
+				"cannot remove allowed type `%s` from relation `%s` in object definition `%s`, as a relationship exists with it",
+				schema.SourceForAllowedRelation(delta.AllowedType), delta.RelationName, nsdef.Name)
+			if err != nil {
+				return diff, err
+			}
+		}
+	}
+	return diff, nil
+}
+
+// errorIfTupleIteratorReturnsTuples takes a tuple iterator and any error that was generated
+// when the original iterator was created, and returns an error if iterator contains any tuples.
+// The message and args follow fmt.Errorf formatting semantics.
+func errorIfTupleIteratorReturnsTuples(_ context.Context, qy datastore.RelationshipIterator, qyErr error, message string, args ...interface{}) error {
+	if qyErr != nil {
+		return qyErr
+	}
+
+	// The loop body returns on the very first yielded item, so at most one
+	// element of the iterator is ever consumed.
+	for _, err := range qy {
+		if err != nil {
+			return err
+		}
+		return NewSchemaWriteDataValidationError(message, args...)
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go b/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go
new file mode 100644
index 0000000..819452e
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go
@@ -0,0 +1,332 @@
+package v1
+
+import (
+ "context"
+ "slices"
+ "sync"
+ "time"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/jzelinskie/stringz"
+ "google.golang.org/grpc/status"
+ "google.golang.org/protobuf/types/known/durationpb"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/graph"
+ "github.com/authzed/spicedb/internal/graph/computed"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/middleware/usagemetrics"
+ "github.com/authzed/spicedb/internal/namespace"
+ "github.com/authzed/spicedb/internal/services/shared"
+ "github.com/authzed/spicedb/internal/taskrunner"
+ "github.com/authzed/spicedb/internal/telemetry"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/genutil"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/genutil/slicez"
+ "github.com/authzed/spicedb/pkg/middleware/consistency"
+ dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// bulkChecker contains the logic to allow ExperimentalService/BulkCheckPermission and
+// PermissionsService/CheckBulkPermissions to share the same implementation.
+type bulkChecker struct {
+	// maxAPIDepth caps the dispatch depth allowed for a single check.
+	maxAPIDepth uint32
+	// maxCaveatContextSize caps the allowed caveat context size, in bytes.
+	maxCaveatContextSize int
+	// maxConcurrency bounds the number of check tasks run in parallel; must be > 0.
+	maxConcurrency uint16
+	// caveatTypeSet is the set of caveat types used when evaluating caveat expressions.
+	caveatTypeSet *caveattypes.TypeSet
+
+	// dispatch is the dispatcher used to execute the (batched) checks.
+	dispatch dispatch.Dispatcher
+	// dispatchChunkSize is the number of resource IDs dispatched per batched check.
+	dispatchChunkSize uint16
+}
+
+// maxBulkCheckCount is the maximum number of items allowed in a single bulk check request.
+const maxBulkCheckCount = 10000
+
+// checkBulkPermissions performs the shared bulk-check implementation: it groups the
+// requested items by (resource type, permission, subject, caveat context), dispatches
+// each group (chunked by dispatchChunkSize) concurrently, and reassembles the results
+// into the order of the original request items.
+func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) {
+	telemetry.RecordLogicalChecks(uint64(len(req.Items)))
+
+	atRevision, checkedAt, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(req.Items) > maxBulkCheckCount {
+		return nil, NewExceedsMaximumChecksErr(uint64(len(req.Items)), maxBulkCheckCount)
+	}
+
+	// Compute a hash for each requested item and record its index(es) for the items, to be used for sorting of results.
+	itemCount, err := genutil.EnsureUInt32(len(req.Items))
+	if err != nil {
+		return nil, err
+	}
+
+	// A multimap is used because the same item may appear multiple times in the request;
+	// each occurrence must receive the (shared) result.
+	itemIndexByHash := mapz.NewMultiMapWithCap[string, int](itemCount)
+	for index, item := range req.Items {
+		itemHash, err := computeCheckBulkPermissionsItemHash(item)
+		if err != nil {
+			return nil, err
+		}
+
+		itemIndexByHash.Add(itemHash, index)
+	}
+
+	// Identify checks with same permission+subject over different resources and group them. This is doable because
+	// the dispatching system already internally supports this kind of batching for performance.
+	groupedItems, err := groupItems(ctx, groupingParameters{
+		atRevision:           atRevision,
+		maxCaveatContextSize: bc.maxCaveatContextSize,
+		maximumAPIDepth:      bc.maxAPIDepth,
+		withTracing:          req.WithTracing,
+	}, req.Items)
+	if err != nil {
+		return nil, err
+	}
+
+	// Guards orderedPairs and respMetadata, which are written from concurrent tasks.
+	bulkResponseMutex := sync.Mutex{}
+
+	spiceerrors.DebugAssert(func() bool {
+		return bc.maxConcurrency > 0
+	}, "max concurrency must be greater than 0 in bulk check")
+
+	tr := taskrunner.NewPreloadedTaskRunner(ctx, bc.maxConcurrency, len(groupedItems))
+
+	respMetadata := &dispatchv1.ResponseMeta{
+		DispatchCount:       1,
+		CachedDispatchCount: 0,
+		DepthRequired:       1,
+		DebugInfo:           nil,
+	}
+	usagemetrics.SetInContext(ctx, respMetadata)
+
+	orderedPairs := make([]*v1.CheckBulkPermissionsPair, len(req.Items))
+
+	// addPair stores the pair at every request index that hashed to the same item,
+	// restoring the caller's original ordering. Callers must hold bulkResponseMutex.
+	addPair := func(pair *v1.CheckBulkPermissionsPair) error {
+		pairItemHash, err := computeCheckBulkPermissionsItemHash(pair.Request)
+		if err != nil {
+			return err
+		}
+
+		found, ok := itemIndexByHash.Get(pairItemHash)
+		if !ok {
+			return spiceerrors.MustBugf("missing expected item hash")
+		}
+
+		for _, index := range found {
+			orderedPairs[index] = pair
+		}
+
+		return nil
+	}
+
+	// appendResultsForError records a per-item gRPC error result for every resource ID
+	// in the failed group. Non-gRPC errors fail the entire bulk request instead.
+	appendResultsForError := func(params *computed.CheckParameters, resourceIDs []string, err error) error {
+		rewritten := shared.RewriteError(ctx, err, &shared.ConfigForErrors{
+			MaximumAPIDepth: bc.maxAPIDepth,
+		})
+		statusResp, ok := status.FromError(rewritten)
+		if !ok {
+			// If error is not a gRPC Status, fail the entire bulk check request.
+			return err
+		}
+
+		bulkResponseMutex.Lock()
+		defer bulkResponseMutex.Unlock()
+
+		for _, resourceID := range resourceIDs {
+			reqItem, err := requestItemFromResourceAndParameters(params, resourceID)
+			if err != nil {
+				return err
+			}
+
+			if err := addPair(&v1.CheckBulkPermissionsPair{
+				Request: reqItem,
+				Response: &v1.CheckBulkPermissionsPair_Error{
+					Error: statusResp.Proto(),
+				},
+			}); err != nil {
+				return err
+			}
+		}
+
+		return nil
+	}
+
+	// appendResultsForCheck records the successful dispatch results (and any per-resource
+	// debug traces) for a group, and accumulates the dispatch metadata.
+	appendResultsForCheck := func(
+		params *computed.CheckParameters,
+		resourceIDs []string,
+		metadata *dispatchv1.ResponseMeta,
+		debugInfos []*dispatchv1.DebugInformation,
+		results map[string]*dispatchv1.ResourceCheckResult,
+	) error {
+		bulkResponseMutex.Lock()
+		defer bulkResponseMutex.Unlock()
+
+		ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+
+		// The full schema is only needed (and only fetched) when tracing was requested.
+		schemaText := ""
+		if len(debugInfos) > 0 {
+			schema, err := getFullSchema(ctx, ds)
+			if err != nil {
+				return err
+			}
+			schemaText = schema
+		}
+
+		for _, resourceID := range resourceIDs {
+			var debugTrace *v1.DebugInformation
+			if len(debugInfos) > 0 {
+				// Find the debug info that matches the resource ID.
+				var debugInfo *dispatchv1.DebugInformation
+				for _, di := range debugInfos {
+					if slices.Contains(di.Check.Request.ResourceIds, resourceID) {
+						debugInfo = di
+						break
+					}
+				}
+
+				if debugInfo != nil {
+					// Synthesize a new debug information with a trace "wrapping" the (potentially batched)
+					// trace.
+					localResults := make(map[string]*dispatchv1.ResourceCheckResult, 1)
+					if result, ok := results[resourceID]; ok {
+						localResults[resourceID] = result
+					}
+					wrappedDebugInfo := &dispatchv1.DebugInformation{
+						Check: &dispatchv1.CheckDebugTrace{
+							Request: &dispatchv1.DispatchCheckRequest{
+								ResourceRelation: debugInfo.Check.Request.ResourceRelation,
+								ResourceIds:      []string{resourceID},
+								Subject:          debugInfo.Check.Request.Subject,
+								ResultsSetting:   debugInfo.Check.Request.ResultsSetting,
+								Debug:            debugInfo.Check.Request.Debug,
+							},
+							ResourceRelationType: debugInfo.Check.ResourceRelationType,
+							IsCachedResult:       false,
+							SubProblems: []*dispatchv1.CheckDebugTrace{
+								debugInfo.Check,
+							},
+							Results:  localResults,
+							Duration: durationpb.New(time.Duration(0)),
+							TraceId:  graph.NewTraceID(),
+							SourceId: debugInfo.Check.SourceId,
+						},
+					}
+
+					// Convert to debug information.
+					dt, err := convertCheckDispatchDebugInformationWithSchema(ctx, params.CaveatContext, wrappedDebugInfo, ds, bc.caveatTypeSet, schemaText)
+					if err != nil {
+						return err
+					}
+					debugTrace = dt
+				}
+			}
+
+			reqItem, err := requestItemFromResourceAndParameters(params, resourceID)
+			if err != nil {
+				return err
+			}
+
+			if err := addPair(&v1.CheckBulkPermissionsPair{
+				Request:  reqItem,
+				Response: pairItemFromCheckResult(results[resourceID], debugTrace),
+			}); err != nil {
+				return err
+			}
+		}
+
+		respMetadata.DispatchCount += metadata.DispatchCount
+		respMetadata.CachedDispatchCount += metadata.CachedDispatchCount
+		return nil
+	}
+
+	for _, group := range groupedItems {
+		// Shadow the loop variable for capture by the task closures (pre-Go 1.22 semantics).
+		group := group
+
+		slicez.ForEachChunk(group.resourceIDs, bc.dispatchChunkSize, func(resourceIDs []string) {
+			tr.Add(func(ctx context.Context) error {
+				ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+
+				// Ensure the check namespaces and relations are valid.
+				err := namespace.CheckNamespaceAndRelations(ctx,
+					[]namespace.TypeAndRelationToCheck{
+						{
+							NamespaceName: group.params.ResourceType.ObjectType,
+							RelationName:  group.params.ResourceType.Relation,
+							AllowEllipsis: false,
+						},
+						{
+							NamespaceName: group.params.Subject.ObjectType,
+							RelationName:  stringz.DefaultEmpty(group.params.Subject.Relation, graph.Ellipsis),
+							AllowEllipsis: true,
+						},
+					}, ds)
+				if err != nil {
+					return appendResultsForError(group.params, resourceIDs, err)
+				}
+
+				// Call bulk check to compute the check result(s) for the resource ID(s).
+				rcr, metadata, debugInfos, err := computed.ComputeBulkCheck(ctx, bc.dispatch, bc.caveatTypeSet, *group.params, resourceIDs, bc.dispatchChunkSize)
+				if err != nil {
+					return appendResultsForError(group.params, resourceIDs, err)
+				}
+
+				return appendResultsForCheck(group.params, resourceIDs, metadata, debugInfos, rcr)
+			})
+		})
+	}
+
+	// Run the checks in parallel.
+	if err := tr.StartAndWait(); err != nil {
+		return nil, err
+	}
+
+	return &v1.CheckBulkPermissionsResponse{CheckedAt: checkedAt, Pairs: orderedPairs}, nil
+}
+
+// toCheckBulkPermissionsRequest translates a deprecated BulkCheckPermissionRequest into
+// the equivalent CheckBulkPermissionsRequest so both APIs share one implementation.
+func toCheckBulkPermissionsRequest(req *v1.BulkCheckPermissionRequest) *v1.CheckBulkPermissionsRequest {
+	converted := &v1.CheckBulkPermissionsRequest{
+		Items: make([]*v1.CheckBulkPermissionsRequestItem, 0, len(req.Items)),
+	}
+	for _, requested := range req.Items {
+		converted.Items = append(converted.Items, &v1.CheckBulkPermissionsRequestItem{
+			Resource:   requested.Resource,
+			Permission: requested.Permission,
+			Subject:    requested.Subject,
+			Context:    requested.Context,
+		})
+	}
+	return converted
+}
+
+// toBulkCheckPermissionResponse translates a CheckBulkPermissionsResponse back into the
+// deprecated BulkCheckPermissionResponse shape, preserving per-pair item/error results.
+func toBulkCheckPermissionResponse(resp *v1.CheckBulkPermissionsResponse) *v1.BulkCheckPermissionResponse {
+	converted := &v1.BulkCheckPermissionResponse{
+		CheckedAt: resp.CheckedAt,
+		Pairs:     make([]*v1.BulkCheckPermissionPair, 0, len(resp.Pairs)),
+	}
+
+	for _, pair := range resp.Pairs {
+		translated := &v1.BulkCheckPermissionPair{
+			Request: &v1.BulkCheckPermissionRequestItem{
+				Resource:   pair.Request.Resource,
+				Permission: pair.Request.Permission,
+				Subject:    pair.Request.Subject,
+				Context:    pair.Request.Context,
+			},
+		}
+
+		switch response := pair.Response.(type) {
+		case *v1.CheckBulkPermissionsPair_Item:
+			translated.Response = &v1.BulkCheckPermissionPair_Item{
+				Item: &v1.BulkCheckPermissionResponseItem{
+					Permissionship:    response.Item.Permissionship,
+					PartialCaveatInfo: response.Item.PartialCaveatInfo,
+				},
+			}
+		case *v1.CheckBulkPermissionsPair_Error:
+			translated.Response = &v1.BulkCheckPermissionPair_Error{
+				Error: response.Error,
+			}
+		default:
+			// An unknown variant indicates a programming error upstream.
+			panic("unknown CheckBulkPermissionResponse pair response type")
+		}
+
+		converted.Pairs = append(converted.Pairs, translated)
+	}
+
+	return converted
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go b/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go
new file mode 100644
index 0000000..712f9ec
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go
@@ -0,0 +1,238 @@
+package v1
+
+import (
+ "cmp"
+ "context"
+ "slices"
+ "strings"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ cexpr "github.com/authzed/spicedb/internal/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/generator"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// ConvertCheckDispatchDebugInformation converts dispatch debug information found in the response metadata
+// into DebugInformation returnable to the API. Returns (nil, nil) when debugInfo is nil.
+func ConvertCheckDispatchDebugInformation(
+	ctx context.Context,
+	caveatTypeSet *caveattypes.TypeSet,
+	caveatContext map[string]any,
+	debugInfo *dispatch.DebugInformation,
+	reader datastore.Reader,
+) (*v1.DebugInformation, error) {
+	if debugInfo == nil {
+		return nil, nil
+	}
+
+	// The generated schema is embedded in the returned debug information so that
+	// callers can reproduce the check.
+	schema, err := getFullSchema(ctx, reader)
+	if err != nil {
+		return nil, err
+	}
+
+	return convertCheckDispatchDebugInformationWithSchema(ctx, caveatContext, debugInfo, reader, caveatTypeSet, schema)
+}
+
+// getFullSchema generates the full schema text (caveat definitions followed by
+// namespace definitions) from the given reader.
+func getFullSchema(ctx context.Context, reader datastore.Reader) (string, error) {
+	caveats, err := reader.ListAllCaveats(ctx)
+	if err != nil {
+		return "", err
+	}
+
+	namespaces, err := reader.ListAllNamespaces(ctx)
+	if err != nil {
+		return "", err
+	}
+
+	definitions := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats))
+	for _, caveatDef := range caveats {
+		definitions = append(definitions, caveatDef.Definition)
+	}
+	for _, namespaceDef := range namespaces {
+		definitions = append(definitions, namespaceDef.Definition)
+	}
+
+	generated, _, err := generator.GenerateSchema(definitions)
+	if err != nil {
+		return "", err
+	}
+	return generated, nil
+}
+
+// convertCheckDispatchDebugInformationWithSchema converts dispatch debug information into
+// API debug information, using already-generated schema text rather than re-reading it.
+func convertCheckDispatchDebugInformationWithSchema(
+	ctx context.Context,
+	caveatContext map[string]any,
+	debugInfo *dispatch.DebugInformation,
+	reader datastore.Reader,
+	caveatTypeSet *caveattypes.TypeSet,
+	schema string,
+) (*v1.DebugInformation, error) {
+	checkTrace, err := convertCheckTrace(ctx, caveatContext, debugInfo.Check, reader, caveatTypeSet)
+	if err != nil {
+		return nil, err
+	}
+
+	return &v1.DebugInformation{
+		Check:      checkTrace,
+		SchemaUsed: strings.TrimSpace(schema),
+	}, nil
+}
+
+// convertCheckTrace converts a single dispatch CheckDebugTrace (and, recursively, its
+// sub-problems) into the API CheckDebugTrace form, re-evaluating a single caveated
+// result where possible to surface caveat evaluation information.
+func convertCheckTrace(ctx context.Context, caveatContext map[string]any, ct *dispatch.CheckDebugTrace, reader datastore.Reader, caveatTypeSet *caveattypes.TypeSet) (*v1.CheckDebugTrace, error) {
+	permissionType := v1.CheckDebugTrace_PERMISSION_TYPE_UNSPECIFIED
+	if ct.ResourceRelationType == dispatch.CheckDebugTrace_PERMISSION {
+		permissionType = v1.CheckDebugTrace_PERMISSION_TYPE_PERMISSION
+	} else if ct.ResourceRelationType == dispatch.CheckDebugTrace_RELATION {
+		permissionType = v1.CheckDebugTrace_PERMISSION_TYPE_RELATION
+	}
+
+	// The API represents the ellipsis relation as an empty optional relation.
+	subRelation := ct.Request.Subject.Relation
+	if subRelation == tuple.Ellipsis {
+		subRelation = ""
+	}
+
+	// Determine overall permissionship: any full member wins; otherwise caveated
+	// members mark the result as conditional.
+	permissionship := v1.CheckDebugTrace_PERMISSIONSHIP_NO_PERMISSION
+	var partialResults []*dispatch.ResourceCheckResult
+	for _, checkResult := range ct.Results {
+		if checkResult.Membership == dispatch.ResourceCheckResult_MEMBER {
+			permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_HAS_PERMISSION
+			break
+		}
+
+		if checkResult.Membership == dispatch.ResourceCheckResult_CAVEATED_MEMBER && permissionship != v1.CheckDebugTrace_PERMISSIONSHIP_HAS_PERMISSION {
+			permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION
+			partialResults = append(partialResults, checkResult)
+		}
+	}
+
+	var caveatEvalInfo *v1.CaveatEvalInfo
+
+	// NOTE: Bulk check gives the *fully resolved* results, rather than the result pre-caveat
+	// evaluation. In that case, we skip re-evaluating here.
+	// TODO(jschorr): Add support for evaluating *each* result distinctly.
+	if permissionship == v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION && len(partialResults) == 1 &&
+		len(partialResults[0].MissingExprFields) == 0 {
+		partialCheckResult := partialResults[0]
+		spiceerrors.DebugAssertNotNil(partialCheckResult.Expression, "got nil caveat expression")
+
+		computedResult, err := cexpr.RunSingleCaveatExpression(ctx, caveatTypeSet, partialCheckResult.Expression, caveatContext, reader, cexpr.RunCaveatExpressionWithDebugInformation)
+		if err != nil {
+			return nil, err
+		}
+
+		var partialCaveatInfo *v1.PartialCaveatInfo
+		caveatResult := v1.CaveatEvalInfo_RESULT_FALSE
+		if computedResult.Value() {
+			caveatResult = v1.CaveatEvalInfo_RESULT_TRUE
+		} else if computedResult.IsPartial() {
+			caveatResult = v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT
+			missingNames, _ := computedResult.MissingVarNames()
+			partialCaveatInfo = &v1.PartialCaveatInfo{
+				MissingRequiredContext: missingNames,
+			}
+		}
+
+		exprString, contextStruct, err := cexpr.BuildDebugInformation(computedResult)
+		if err != nil {
+			return nil, err
+		}
+
+		caveatName := ""
+		if partialCheckResult.Expression.GetCaveat() != nil {
+			caveatName = partialCheckResult.Expression.GetCaveat().CaveatName
+		}
+
+		caveatEvalInfo = &v1.CaveatEvalInfo{
+			Expression:        exprString,
+			Result:            caveatResult,
+			Context:           contextStruct,
+			PartialCaveatInfo: partialCaveatInfo,
+			CaveatName:        caveatName,
+		}
+	}
+
+	// If there is more than a single result, mark the overall permissionship
+	// as unspecified if *all* results needed to be true and at least one is not.
+	if len(ct.Request.ResourceIds) > 1 && ct.Request.ResultsSetting == dispatch.DispatchCheckRequest_REQUIRE_ALL_RESULTS {
+		for _, resourceID := range ct.Request.ResourceIds {
+			if result, ok := ct.Results[resourceID]; !ok || result.Membership != dispatch.ResourceCheckResult_MEMBER {
+				permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_UNSPECIFIED
+				break
+			}
+		}
+	}
+
+	if len(ct.SubProblems) > 0 {
+		subProblems := make([]*v1.CheckDebugTrace, 0, len(ct.SubProblems))
+		for _, subProblem := range ct.SubProblems {
+			converted, err := convertCheckTrace(ctx, caveatContext, subProblem, reader, caveatTypeSet)
+			if err != nil {
+				return nil, err
+			}
+
+			subProblems = append(subProblems, converted)
+		}
+
+		// Sort sub-problems by resource reference for deterministic output.
+		// FIX: the comparator previously compared a.Resource against itself,
+		// leaving the slice effectively unsorted.
+		slices.SortFunc(subProblems, func(a, b *v1.CheckDebugTrace) int {
+			return cmp.Compare(tuple.V1StringObjectRef(a.Resource), tuple.V1StringObjectRef(b.Resource))
+		})
+
+		return &v1.CheckDebugTrace{
+			TraceOperationId: ct.TraceId,
+			Resource: &v1.ObjectReference{
+				ObjectType: ct.Request.ResourceRelation.Namespace,
+				ObjectId:   strings.Join(ct.Request.ResourceIds, ","),
+			},
+			Permission:     ct.Request.ResourceRelation.Relation,
+			PermissionType: permissionType,
+			Subject: &v1.SubjectReference{
+				Object: &v1.ObjectReference{
+					ObjectType: ct.Request.Subject.Namespace,
+					ObjectId:   ct.Request.Subject.ObjectId,
+				},
+				OptionalRelation: subRelation,
+			},
+			CaveatEvaluationInfo: caveatEvalInfo,
+			Result:               permissionship,
+			Resolution: &v1.CheckDebugTrace_SubProblems_{
+				SubProblems: &v1.CheckDebugTrace_SubProblems{
+					Traces: subProblems,
+				},
+			},
+			Duration: ct.Duration,
+			Source:   ct.SourceId,
+		}, nil
+	}
+
+	return &v1.CheckDebugTrace{
+		TraceOperationId: ct.TraceId,
+		Resource: &v1.ObjectReference{
+			ObjectType: ct.Request.ResourceRelation.Namespace,
+			ObjectId:   strings.Join(ct.Request.ResourceIds, ","),
+		},
+		Permission:     ct.Request.ResourceRelation.Relation,
+		PermissionType: permissionType,
+		Subject: &v1.SubjectReference{
+			Object: &v1.ObjectReference{
+				ObjectType: ct.Request.Subject.Namespace,
+				ObjectId:   ct.Request.Subject.ObjectId,
+			},
+			OptionalRelation: subRelation,
+		},
+		CaveatEvaluationInfo: caveatEvalInfo,
+		Result:               permissionship,
+		Resolution: &v1.CheckDebugTrace_WasCachedResult{
+			WasCachedResult: ct.IsCachedResult,
+		},
+		Duration: ct.Duration,
+		Source:   ct.SourceId,
+	}, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go b/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go
new file mode 100644
index 0000000..6de6749
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go
@@ -0,0 +1,511 @@
+package v1
+
+import (
+ "fmt"
+ "strconv"
+
+ "github.com/rs/zerolog"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+ "google.golang.org/protobuf/proto"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// ExceedsMaximumLimitError occurs when a limit that is too large is given to a call.
+type ExceedsMaximumLimitError struct {
+ error
+ providedLimit uint64
+ maxLimitAllowed uint64
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err ExceedsMaximumLimitError) MarshalZerologObject(e *zerolog.Event) {
+ e.Err(err.error).Uint64("providedLimit", err.providedLimit).Uint64("maxLimitAllowed", err.maxLimitAllowed)
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err ExceedsMaximumLimitError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_EXCEEDS_MAXIMUM_ALLOWABLE_LIMIT,
+ map[string]string{
+ "limit_provided": strconv.FormatUint(err.providedLimit, 10),
+ "maximum_limit_allowed": strconv.FormatUint(err.maxLimitAllowed, 10),
+ },
+ ),
+ )
+}
+
+// NewExceedsMaximumLimitErr creates a new error representing that the limit specified was too large.
+func NewExceedsMaximumLimitErr(providedLimit uint64, maxLimitAllowed uint64) ExceedsMaximumLimitError {
+ return ExceedsMaximumLimitError{
+ error: fmt.Errorf("provided limit %d is greater than maximum allowed of %d", providedLimit, maxLimitAllowed),
+ providedLimit: providedLimit,
+ maxLimitAllowed: maxLimitAllowed,
+ }
+}
+
+// ExceedsMaximumChecksError occurs when too many checks are given to a call.
+type ExceedsMaximumChecksError struct {
+ error
+ checkCount uint64
+ maxCountAllowed uint64
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err ExceedsMaximumChecksError) MarshalZerologObject(e *zerolog.Event) {
+ e.Err(err.error).Uint64("checkCount", err.checkCount).Uint64("maxCountAllowed", err.maxCountAllowed)
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err ExceedsMaximumChecksError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_UNSPECIFIED,
+ map[string]string{
+ "check_count": strconv.FormatUint(err.checkCount, 10),
+ "maximum_checks_allowed": strconv.FormatUint(err.maxCountAllowed, 10),
+ },
+ ),
+ )
+}
+
+// NewExceedsMaximumChecksErr creates a new error representing that too many updates were given to a BulkCheckPermissions call.
+func NewExceedsMaximumChecksErr(checkCount uint64, maxCountAllowed uint64) ExceedsMaximumChecksError {
+ return ExceedsMaximumChecksError{
+ error: fmt.Errorf("check count of %d is greater than maximum allowed of %d", checkCount, maxCountAllowed),
+ checkCount: checkCount,
+ maxCountAllowed: maxCountAllowed,
+ }
+}
+
+// ExceedsMaximumUpdatesError occurs when too many updates are given to a call.
+type ExceedsMaximumUpdatesError struct {
+ error
+ updateCount uint64
+ maxCountAllowed uint64
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err ExceedsMaximumUpdatesError) MarshalZerologObject(e *zerolog.Event) {
+ e.Err(err.error).Uint64("updateCount", err.updateCount).Uint64("maxCountAllowed", err.maxCountAllowed)
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err ExceedsMaximumUpdatesError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_TOO_MANY_UPDATES_IN_REQUEST,
+ map[string]string{
+ "update_count": strconv.FormatUint(err.updateCount, 10),
+ "maximum_updates_allowed": strconv.FormatUint(err.maxCountAllowed, 10),
+ },
+ ),
+ )
+}
+
+// NewExceedsMaximumUpdatesErr creates a new error representing that too many updates were given to a WriteRelationships call.
+func NewExceedsMaximumUpdatesErr(updateCount uint64, maxCountAllowed uint64) ExceedsMaximumUpdatesError {
+ return ExceedsMaximumUpdatesError{
+ error: fmt.Errorf("update count of %d is greater than maximum allowed of %d", updateCount, maxCountAllowed),
+ updateCount: updateCount,
+ maxCountAllowed: maxCountAllowed,
+ }
+}
+
+// ExceedsMaximumPreconditionsError occurs when too many preconditions are given to a call.
+type ExceedsMaximumPreconditionsError struct {
+ error
+ preconditionCount uint64
+ maxCountAllowed uint64
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err ExceedsMaximumPreconditionsError) MarshalZerologObject(e *zerolog.Event) {
+ e.Err(err.error).Uint64("preconditionCount", err.preconditionCount).Uint64("maxCountAllowed", err.maxCountAllowed)
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err ExceedsMaximumPreconditionsError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_TOO_MANY_PRECONDITIONS_IN_REQUEST,
+ map[string]string{
+ "precondition_count": strconv.FormatUint(err.preconditionCount, 10),
+ "maximum_updates_allowed": strconv.FormatUint(err.maxCountAllowed, 10),
+ },
+ ),
+ )
+}
+
+// NewExceedsMaximumPreconditionsErr creates a new error representing that too many preconditions were given to a call.
+func NewExceedsMaximumPreconditionsErr(preconditionCount uint64, maxCountAllowed uint64) ExceedsMaximumPreconditionsError {
+ return ExceedsMaximumPreconditionsError{
+ error: fmt.Errorf(
+ "precondition count of %d is greater than maximum allowed of %d",
+ preconditionCount,
+ maxCountAllowed),
+ preconditionCount: preconditionCount,
+ maxCountAllowed: maxCountAllowed,
+ }
+}
+
+// PreconditionFailedError occurs when the precondition to a write tuple call does not match.
+type PreconditionFailedError struct {
+ error
+ precondition *v1.Precondition
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err PreconditionFailedError) MarshalZerologObject(e *zerolog.Event) {
+ e.Err(err.error).Interface("precondition", err.precondition)
+}
+
+// NewPreconditionFailedErr constructs a new precondition failed error.
+// NOTE(review): unlike the other constructors in this file, this returns the
+// error interface rather than the concrete PreconditionFailedError type —
+// callers needing the concrete type must use errors.As.
+func NewPreconditionFailedErr(precondition *v1.Precondition) error {
+	return PreconditionFailedError{
+		error:        fmt.Errorf("unable to satisfy write precondition `%s`", precondition),
+		precondition: precondition,
+	}
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err PreconditionFailedError) GRPCStatus() *status.Status {
+ metadata := map[string]string{
+ "precondition_operation": v1.Precondition_Operation_name[int32(err.precondition.Operation)],
+ }
+
+ if err.precondition.Filter.ResourceType != "" {
+ metadata["precondition_resource_type"] = err.precondition.Filter.ResourceType
+ }
+
+ if err.precondition.Filter.OptionalResourceId != "" {
+ metadata["precondition_resource_id"] = err.precondition.Filter.OptionalResourceId
+ }
+
+ if err.precondition.Filter.OptionalResourceIdPrefix != "" {
+ metadata["precondition_resource_id_prefix"] = err.precondition.Filter.OptionalResourceIdPrefix
+ }
+
+ if err.precondition.Filter.OptionalRelation != "" {
+ metadata["precondition_relation"] = err.precondition.Filter.OptionalRelation
+ }
+
+ if err.precondition.Filter.OptionalSubjectFilter != nil {
+ metadata["precondition_subject_type"] = err.precondition.Filter.OptionalSubjectFilter.SubjectType
+
+ if err.precondition.Filter.OptionalSubjectFilter.OptionalSubjectId != "" {
+ metadata["precondition_subject_id"] = err.precondition.Filter.OptionalSubjectFilter.OptionalSubjectId
+ }
+
+ if err.precondition.Filter.OptionalSubjectFilter.OptionalRelation != nil {
+ metadata["precondition_subject_relation"] = err.precondition.Filter.OptionalSubjectFilter.OptionalRelation.Relation
+ }
+ }
+
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.FailedPrecondition,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_WRITE_OR_DELETE_PRECONDITION_FAILURE,
+ metadata,
+ ),
+ )
+}
+
+// DuplicateRelationErrorshipError indicates that an update was attempted on the same relationship.
+type DuplicateRelationErrorshipError struct {
+ error
+ update *v1.RelationshipUpdate
+}
+
+// NewDuplicateRelationshipErr constructs a new invalid subject error.
+func NewDuplicateRelationshipErr(update *v1.RelationshipUpdate) DuplicateRelationErrorshipError {
+ return DuplicateRelationErrorshipError{
+ error: fmt.Errorf(
+ "found more than one update with relationship `%s` in this request; a relationship can only be specified in an update once per overall WriteRelationships request",
+ tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship),
+ ),
+ update: update,
+ }
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err DuplicateRelationErrorshipError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_UPDATES_ON_SAME_RELATIONSHIP,
+ map[string]string{
+ "definition_name": err.update.Relationship.Resource.ObjectType,
+ "relationship": tuple.MustV1StringRelationship(err.update.Relationship),
+ },
+ ),
+ )
+}
+
+// ErrMaxRelationshipContextError indicates an attempt to write a relationship that exceeded the maximum
+// configured context size.
+type ErrMaxRelationshipContextError struct {
+ error
+ update *v1.RelationshipUpdate
+ maxAllowedSize int
+}
+
+// NewMaxRelationshipContextError constructs a new max relationship context error.
+func NewMaxRelationshipContextError(update *v1.RelationshipUpdate, maxAllowedSize int) ErrMaxRelationshipContextError {
+ return ErrMaxRelationshipContextError{
+ error: fmt.Errorf(
+ "provided relationship `%s` exceeded maximum allowed caveat size of %d",
+ tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship),
+ maxAllowedSize,
+ ),
+ update: update,
+ maxAllowedSize: maxAllowedSize,
+ }
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err ErrMaxRelationshipContextError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_MAX_RELATIONSHIP_CONTEXT_SIZE,
+ map[string]string{
+ "relationship": tuple.V1StringRelationshipWithoutCaveatOrExpiration(err.update.Relationship),
+ "max_allowed_size": strconv.Itoa(err.maxAllowedSize),
+ "context_size": strconv.Itoa(proto.Size(err.update.Relationship)),
+ },
+ ),
+ )
+}
+
+// CouldNotTransactionallyDeleteError indicates that a deletion could not occur transactionally.
+type CouldNotTransactionallyDeleteError struct {
+ error
+ limit uint32
+ filter *v1.RelationshipFilter
+}
+
+// NewCouldNotTransactionallyDeleteErr constructs a new could not transactionally deleter error.
+func NewCouldNotTransactionallyDeleteErr(filter *v1.RelationshipFilter, limit uint32) CouldNotTransactionallyDeleteError {
+ return CouldNotTransactionallyDeleteError{
+ error: fmt.Errorf(
+ "found more than %d relationships to be deleted and partial deletion was not requested",
+ limit,
+ ),
+ limit: limit,
+ filter: filter,
+ }
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err CouldNotTransactionallyDeleteError) GRPCStatus() *status.Status {
+ metadata := map[string]string{
+ "limit": strconv.Itoa(int(err.limit)),
+ "filter_resource_type": err.filter.ResourceType,
+ }
+
+ if err.filter.OptionalResourceId != "" {
+ metadata["filter_resource_id"] = err.filter.OptionalResourceId
+ }
+
+ if err.filter.OptionalRelation != "" {
+ metadata["filter_relation"] = err.filter.OptionalRelation
+ }
+
+ if err.filter.OptionalSubjectFilter != nil {
+ metadata["filter_subject_type"] = err.filter.OptionalSubjectFilter.SubjectType
+
+ if err.filter.OptionalSubjectFilter.OptionalSubjectId != "" {
+ metadata["filter_subject_id"] = err.filter.OptionalSubjectFilter.OptionalSubjectId
+ }
+
+ if err.filter.OptionalSubjectFilter.OptionalRelation != nil {
+ metadata["filter_subject_relation"] = err.filter.OptionalSubjectFilter.OptionalRelation.Relation
+ }
+ }
+
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_TOO_MANY_RELATIONSHIPS_FOR_TRANSACTIONAL_DELETE,
+ metadata,
+ ),
+ )
+}
+
+// InvalidCursorError indicates that an invalid cursor was found.
+type InvalidCursorError struct {
+ error
+ reason string
+}
+
+// NewInvalidCursorErr constructs a new invalid cursor error.
+func NewInvalidCursorErr(reason string) InvalidCursorError {
+	return InvalidCursorError{
+		error: fmt.Errorf(
+			"the cursor provided is not valid: %s",
+			reason,
+		),
+		// FIX: the reason field was previously left unset, causing GRPCStatus to
+		// emit an empty "reason" detail in the returned error metadata.
+		reason: reason,
+	}
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err InvalidCursorError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.FailedPrecondition,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_INVALID_CURSOR,
+ map[string]string{
+ "reason": err.reason,
+ },
+ ),
+ )
+}
+
+// InvalidFilterError indicates the specified relationship filter was invalid.
+type InvalidFilterError struct {
+ error
+
+ filter string
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err InvalidFilterError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_INVALID_FILTER,
+ map[string]string{
+ "filter": err.filter,
+ },
+ ),
+ )
+}
+
+// NewInvalidFilterErr constructs a new invalid filter error.
+func NewInvalidFilterErr(reason string, filter string) InvalidFilterError {
+ return InvalidFilterError{
+ error: fmt.Errorf(
+ "the relationship filter provided is not valid: %s", reason,
+ ),
+ filter: filter,
+ }
+}
+
+// NewEmptyPreconditionErr constructs a new empty precondition error.
+func NewEmptyPreconditionErr() EmptyPreconditionError {
+ return EmptyPreconditionError{
+ error: fmt.Errorf(
+ "one of the specified preconditions is empty",
+ ),
+ }
+}
+
+// EmptyPreconditionError indicates an empty precondition was found.
+type EmptyPreconditionError struct {
+ error
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err EmptyPreconditionError) GRPCStatus() *status.Status {
+ // TODO(jschorr): Put a proper error reason in here.
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_UNSPECIFIED,
+ map[string]string{},
+ ),
+ )
+}
+
+// NewNotAPermissionError constructs a new not a permission error.
+func NewNotAPermissionError(relationName string) NotAPermissionError {
+ return NotAPermissionError{
+ error: fmt.Errorf(
+ "the relation `%s` is not a permission", relationName,
+ ),
+ relationName: relationName,
+ }
+}
+
+// NotAPermissionError indicates that the relation is not a permission.
+type NotAPermissionError struct {
+ error
+ relationName string
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err NotAPermissionError) GRPCStatus() *status.Status {
+ return spiceerrors.WithCodeAndDetails(
+ err,
+ codes.InvalidArgument,
+ spiceerrors.ForReason(
+ v1.ErrorReason_ERROR_REASON_UNKNOWN_RELATION_OR_PERMISSION,
+ map[string]string{
+ "relationName": err.relationName,
+ },
+ ),
+ )
+}
+
+// defaultIfZero returns defaultValue when value is the zero value of T, and
+// returns value unchanged otherwise.
+func defaultIfZero[T comparable](value T, defaultValue T) T {
+	var zeroValue T
+	if value != zeroValue {
+		return value
+	}
+	return defaultValue
+}
+
+// TransactionMetadataTooLargeError indicates that the metadata for a transaction is too large.
+type TransactionMetadataTooLargeError struct {
+ error
+ metadataSize int
+ maxSize int
+}
+
+// NewTransactionMetadataTooLargeErr constructs a new transaction metadata too large error.
+func NewTransactionMetadataTooLargeErr(metadataSize int, maxSize int) TransactionMetadataTooLargeError {
+ return TransactionMetadataTooLargeError{
+ error: fmt.Errorf("metadata size of %d is greater than maximum allowed of %d", metadataSize, maxSize),
+ metadataSize: metadataSize,
+ maxSize: maxSize,
+ }
+}
+
+// MarshalZerologObject implements zerolog object marshalling.
+func (err TransactionMetadataTooLargeError) MarshalZerologObject(e *zerolog.Event) {
+	e.Err(err.error).Int("metadataSize", err.metadataSize).Int("maxSize", err.maxSize)
+}
+
+// GRPCStatus implements retrieving the gRPC status for the error.
+func (err TransactionMetadataTooLargeError) GRPCStatus() *status.Status {
+	return spiceerrors.WithCodeAndDetails(
+		err,
+		codes.InvalidArgument,
+		spiceerrors.ForReason(
+			v1.ErrorReason_ERROR_REASON_TRANSACTION_METADATA_TOO_LARGE,
+			map[string]string{
+				"metadata_byte_size":                 strconv.Itoa(err.metadataSize),
+				"maximum_allowed_metadata_byte_size": strconv.Itoa(err.maxSize),
+			},
+		),
+	)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go b/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go
new file mode 100644
index 0000000..0e4b4a7
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go
@@ -0,0 +1,824 @@
+package v1
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "slices"
+ "sort"
+ "strings"
+ "time"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/jzelinskie/stringz"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/internal/middleware"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/middleware/handwrittenvalidation"
+ "github.com/authzed/spicedb/internal/middleware/streamtimeout"
+ "github.com/authzed/spicedb/internal/middleware/usagemetrics"
+ "github.com/authzed/spicedb/internal/relationships"
+ "github.com/authzed/spicedb/internal/services/shared"
+ "github.com/authzed/spicedb/internal/services/v1/options"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/cursor"
+ "github.com/authzed/spicedb/pkg/datastore"
+ dsoptions "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/middleware/consistency"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/zedtoken"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
+ "github.com/samber/lo"
+)
+
+const (
+	// defaultExportBatchSizeFallback is used when the configured default
+	// export batch size is zero/unset.
+	defaultExportBatchSizeFallback = 1_000
+
+	// maxExportBatchSizeFallback is used when neither the experimental server
+	// config nor the permissions server config provides a maximum export
+	// batch size.
+	maxExportBatchSizeFallback = 10_000
+
+	// streamReadTimeoutFallbackSeconds is the stream read timeout, in
+	// seconds, used when the configured timeout is zero/unset.
+	streamReadTimeoutFallbackSeconds = 600
+)
+
+// NewExperimentalServer creates a ExperimentalServiceServer instance.
+//
+// Zero-valued entries in the experimental options are replaced with fallback
+// values (logged as warnings) so the server always runs with usable limits.
+func NewExperimentalServer(dispatch dispatch.Dispatcher, permServerConfig PermissionsServerConfig, opts ...options.ExperimentalServerOptionsOption) v1.ExperimentalServiceServer {
+	config := options.NewExperimentalServerOptionsWithOptionsAndDefaults(opts...)
+	if config.DefaultExportBatchSize == 0 {
+		log.
+			Warn().
+			Uint32("specified", config.DefaultExportBatchSize).
+			Uint32("fallback", defaultExportBatchSizeFallback).
+			Msg("experimental server config specified invalid DefaultExportBatchSize, setting to fallback")
+		config.DefaultExportBatchSize = defaultExportBatchSizeFallback
+	}
+	if config.MaxExportBatchSize == 0 {
+		// Prefer the permissions server's bulk export limit before falling
+		// back to the hardcoded default.
+		fallback := permServerConfig.MaxBulkExportRelationshipsLimit
+		if fallback == 0 {
+			fallback = maxExportBatchSizeFallback
+		}
+
+		log.
+			Warn().
+			Uint32("specified", config.MaxExportBatchSize).
+			Uint32("fallback", fallback).
+			Msg("experimental server config specified invalid MaxExportBatchSize, setting to fallback")
+		config.MaxExportBatchSize = fallback
+	}
+	if config.StreamReadTimeout == 0 {
+		log.
+			Warn().
+			Stringer("specified", config.StreamReadTimeout).
+			Stringer("fallback", streamReadTimeoutFallbackSeconds*time.Second).
+			Msg("experimental server config specified invalid StreamReadTimeout, setting to fallback")
+		config.StreamReadTimeout = streamReadTimeoutFallbackSeconds * time.Second
+	}
+
+	chunkSize := permServerConfig.DispatchChunkSize
+	if chunkSize == 0 {
+		log.
+			Warn().
+			Msg("experimental server config specified invalid DispatchChunkSize, defaulting to 100")
+		chunkSize = 100
+	}
+
+	return &experimentalServer{
+		WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{
+			Unary: middleware.ChainUnaryServer(
+				grpcvalidate.UnaryServerInterceptor(),
+				handwrittenvalidation.UnaryServerInterceptor,
+				usagemetrics.UnaryServerInterceptor(),
+			),
+			Stream: middleware.ChainStreamServer(
+				grpcvalidate.StreamServerInterceptor(),
+				handwrittenvalidation.StreamServerInterceptor,
+				usagemetrics.StreamServerInterceptor(),
+				// The read timeout ensures a stalled client cannot hold a
+				// stream open indefinitely.
+				streamtimeout.MustStreamServerInterceptor(config.StreamReadTimeout),
+			),
+		},
+		maxBatchSize:  uint64(config.MaxExportBatchSize),
+		caveatTypeSet: caveattypes.TypeSetOrDefault(permServerConfig.CaveatTypeSet),
+		bulkChecker: &bulkChecker{
+			maxAPIDepth:          permServerConfig.MaximumAPIDepth,
+			maxCaveatContextSize: permServerConfig.MaxCaveatContextSize,
+			maxConcurrency:       config.BulkCheckMaxConcurrency,
+			dispatch:             dispatch,
+			dispatchChunkSize:    chunkSize,
+			caveatTypeSet:        caveattypes.TypeSetOrDefault(permServerConfig.CaveatTypeSet),
+		},
+	}
+}
+
+// experimentalServer implements v1.ExperimentalServiceServer.
+type experimentalServer struct {
+	v1.UnimplementedExperimentalServiceServer
+	shared.WithServiceSpecificInterceptors
+
+	// maxBatchSize caps the number of relationships returned per export batch.
+	maxBatchSize uint64
+
+	// bulkChecker services the bulk permission-check API.
+	bulkChecker *bulkChecker
+
+	// caveatTypeSet is the caveat type set used for validation and reflection.
+	caveatTypeSet *caveattypes.TypeSet
+}
+
+// bulkLoadAdapter adapts a BulkImportRelationships client stream into a
+// source of tuple.Relationship values consumable by a datastore bulk load.
+type bulkLoadAdapter struct {
+	stream                 v1.ExperimentalService_BulkImportRelationshipsServer
+	referencedNamespaceMap map[string]*schema.Definition
+	referencedCaveatMap    map[string]*core.CaveatDefinition
+	current                tuple.Relationship      // scratch value reused across Next calls
+	caveat                 core.ContextualizedCaveat // scratch caveat reused across Next calls
+	caveatTypeSet          *caveattypes.TypeSet
+
+	// awaitingNamespaces/awaitingCaveats hold names referenced by the current
+	// batch that have not yet been loaded; when non-empty, Next pauses so the
+	// caller can resolve them.
+	awaitingNamespaces []string
+	awaitingCaveats    []string
+
+	currentBatch []*v1.Relationship // batch most recently received from the stream
+	numSent      int                // count of relationships from currentBatch already yielded
+	err          error              // terminal stream error (including io.EOF)
+}
+
+// Next returns the next relationship received over the stream, pulling a
+// fresh batch whenever the current one is exhausted.
+//
+// It returns (nil, nil) in two cases: the client closed the stream (io.EOF),
+// or the latest batch references namespaces/caveats that have not yet been
+// loaded — in which case the caller is expected to fill them in and resume.
+// The returned pointer refers to a value reused across calls, so it must be
+// consumed before the next call.
+func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) {
+	for a.err == nil && a.numSent == len(a.currentBatch) {
+		// Load a new batch
+		batch, err := a.stream.Recv()
+		if err != nil {
+			a.err = err
+			if errors.Is(a.err, io.EOF) {
+				return nil, nil
+			}
+			return nil, a.err
+		}
+
+		a.currentBatch = batch.Relationships
+		a.numSent = 0
+
+		a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats(
+			a.currentBatch,
+			a.referencedNamespaceMap,
+			a.referencedCaveatMap,
+		)
+	}
+
+	if len(a.awaitingNamespaces) > 0 || len(a.awaitingCaveats) > 0 {
+		// Shut down the stream to give our caller a chance to fill in this information
+		return nil, nil
+	}
+
+	// Copy the incoming proto into the reused scratch relationship to avoid a
+	// per-relationship allocation.
+	a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType
+	a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId
+	a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation
+	a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType
+	a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId
+	// An omitted subject relation defaults to the ellipsis ("...").
+	a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis)
+
+	if a.currentBatch[a.numSent].OptionalCaveat != nil {
+		a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName
+		a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context
+		a.current.OptionalCaveat = &a.caveat
+	} else {
+		a.current.OptionalCaveat = nil
+	}
+
+	if a.currentBatch[a.numSent].OptionalExpiresAt != nil {
+		t := a.currentBatch[a.numSent].OptionalExpiresAt.AsTime()
+		a.current.OptionalExpiration = &t
+	} else {
+		a.current.OptionalExpiration = nil
+	}
+
+	// Integrity information is never accepted over this API.
+	a.current.OptionalIntegrity = nil
+
+	if err := relationships.ValidateOneRelationship(
+		a.referencedNamespaceMap,
+		a.referencedCaveatMap,
+		a.caveatTypeSet,
+		a.current,
+		relationships.ValidateRelationshipForCreateOrTouch,
+	); err != nil {
+		return nil, err
+	}
+
+	a.numSent++
+	return &a.current, nil
+}
+
+// extractBatchNewReferencedNamespacesAndCaveats returns the namespace names
+// and caveat names referenced by the batch that are not already present in
+// the supplied maps of loaded definitions.
+func extractBatchNewReferencedNamespacesAndCaveats(
+	batch []*v1.Relationship,
+	existingNamespaces map[string]*schema.Definition,
+	existingCaveats map[string]*core.CaveatDefinition,
+) ([]string, []string) {
+	namespaceSet := make(map[string]struct{}, 2)
+	caveatSet := make(map[string]struct{}, 0)
+	for _, rel := range batch {
+		// Both sides of the relationship can reference a namespace.
+		for _, nsName := range []string{rel.Resource.ObjectType, rel.Subject.Object.ObjectType} {
+			if _, found := existingNamespaces[nsName]; !found {
+				namespaceSet[nsName] = struct{}{}
+			}
+		}
+		if caveat := rel.OptionalCaveat; caveat != nil {
+			if _, found := existingCaveats[caveat.CaveatName]; !found {
+				caveatSet[caveat.CaveatName] = struct{}{}
+			}
+		}
+	}
+
+	return lo.Keys(namespaceSet), lo.Keys(caveatSet)
+}
+
+// TODO: this is now duplicate code with ImportBulkRelationships
+
+// BulkImportRelationships writes every relationship received over the stream
+// within a single read-write transaction, returning the number loaded.
+// Namespaces and caveats referenced by incoming batches are looked up lazily
+// whenever the adapter pauses awaiting them.
+func (es *experimentalServer) BulkImportRelationships(stream v1.ExperimentalService_BulkImportRelationshipsServer) error {
+	ds := datastoremw.MustFromContext(stream.Context())
+
+	var numWritten uint64
+	if _, err := ds.ReadWriteTx(stream.Context(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+		loadedNamespaces := make(map[string]*schema.Definition, 2)
+		loadedCaveats := make(map[string]*core.CaveatDefinition, 0)
+
+		adapter := &bulkLoadAdapter{
+			stream:                 stream,
+			referencedNamespaceMap: loadedNamespaces,
+			referencedCaveatMap:    loadedCaveats,
+			current:                tuple.Relationship{},
+			caveat:                 core.ContextualizedCaveat{},
+			caveatTypeSet:          es.caveatTypeSet,
+		}
+		resolver := schema.ResolverForDatastoreReader(rwt)
+		ts := schema.NewTypeSystem(resolver)
+
+		// BulkLoad runs repeatedly: each pass ends either on stream end, on
+		// error, or when the adapter pauses awaiting namespace/caveat
+		// definitions. The loop body accumulates the previous pass's count
+		// and resolves any awaited definitions before the next pass.
+		var streamWritten uint64
+		var err error
+		for ; adapter.err == nil && err == nil; streamWritten, err = rwt.BulkLoad(stream.Context(), adapter) {
+			numWritten += streamWritten
+
+			// The stream has terminated because we're awaiting namespace and/or caveat information
+			if len(adapter.awaitingNamespaces) > 0 {
+				nsDefs, err := rwt.LookupNamespacesWithNames(stream.Context(), adapter.awaitingNamespaces)
+				if err != nil {
+					return err
+				}
+
+				for _, nsDef := range nsDefs {
+					newDef, err := schema.NewDefinition(ts, nsDef.Definition)
+					if err != nil {
+						return err
+					}
+
+					loadedNamespaces[nsDef.Definition.Name] = newDef
+				}
+				adapter.awaitingNamespaces = nil
+			}
+
+			if len(adapter.awaitingCaveats) > 0 {
+				caveats, err := rwt.LookupCaveatsWithNames(stream.Context(), adapter.awaitingCaveats)
+				if err != nil {
+					return err
+				}
+
+				for _, caveat := range caveats {
+					loadedCaveats[caveat.Definition.Name] = caveat.Definition
+				}
+				adapter.awaitingCaveats = nil
+			}
+		}
+		// Account for the relationships written by the final BulkLoad pass,
+		// which the loop body never added.
+		numWritten += streamWritten
+
+		return err
+	}, dsoptions.WithDisableRetries(true)); err != nil {
+		return shared.RewriteErrorWithoutConfig(stream.Context(), err)
+	}
+
+	usagemetrics.SetInContext(stream.Context(), &dispatchv1.ResponseMeta{
+		// One request for the whole load
+		DispatchCount: 1,
+	})
+
+	return stream.SendAndClose(&v1.BulkImportRelationshipsResponse{
+		NumLoaded: numWritten,
+	})
+}
+
+// TODO: this is now duplicate code with ExportBulkRelationships
+
+// BulkExportRelationships streams all matching relationships back to the
+// client, resolving the revision via the consistency middleware and
+// delegating the actual export to BulkExport.
+func (es *experimentalServer) BulkExportRelationships(
+	req *v1.BulkExportRelationshipsRequest,
+	resp grpc.ServerStreamingServer[v1.BulkExportRelationshipsResponse],
+) error {
+	ctx := resp.Context()
+	atRevision, _, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	return BulkExport(ctx, datastoremw.MustFromContext(ctx), es.maxBatchSize, req, atRevision, resp.Send)
+}
+
+// BulkExport implements the BulkExportRelationships API functionality. Given a datastore.Datastore, it will
+// export stream via the sender all relationships matched by the incoming request.
+// If no cursor is provided, it will fallback to the provided revision.
+func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.BulkExportRelationshipsResponse) error) error {
+	if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize {
+		return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize))
+	}
+
+	// A provided cursor carries the revision, the namespace being exported
+	// and the last relationship returned; otherwise export starts at the
+	// fallback revision from the first namespace.
+	atRevision := fallbackRevision
+	var curNamespace string
+	var cur dsoptions.Cursor
+	if req.OptionalCursor != nil {
+		var err error
+		atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor)
+		if err != nil {
+			return shared.RewriteErrorWithoutConfig(ctx, err)
+		}
+	}
+
+	reader := ds.SnapshotReader(atRevision)
+
+	namespaces, err := reader.ListAllNamespaces(ctx)
+	if err != nil {
+		return shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	// Make sure the namespaces are always in a stable order
+	slices.SortFunc(namespaces, func(
+		lhs datastore.RevisionedDefinition[*core.NamespaceDefinition],
+		rhs datastore.RevisionedDefinition[*core.NamespaceDefinition],
+	) int {
+		return strings.Compare(lhs.Definition.Name, rhs.Definition.Name)
+	})
+
+	// Skip the namespaces that are already fully returned
+	for cur != nil && len(namespaces) > 0 && namespaces[0].Definition.Name < curNamespace {
+		namespaces = namespaces[1:]
+	}
+
+	limit := batchSize
+	if req.OptionalLimit > 0 {
+		limit = uint64(req.OptionalLimit)
+	}
+
+	// Pre-allocate all of the relationships that we might need in order to
+	// make export easier and faster for the garbage collector.
+	relsArray := make([]v1.Relationship, limit)
+	objArray := make([]v1.ObjectReference, limit)
+	subArray := make([]v1.SubjectReference, limit)
+	subObjArray := make([]v1.ObjectReference, limit)
+	caveatArray := make([]v1.ContextualizedCaveat, limit)
+	for i := range relsArray {
+		relsArray[i].Resource = &objArray[i]
+		relsArray[i].Subject = &subArray[i]
+		relsArray[i].Subject.Object = &subObjArray[i]
+	}
+
+	emptyRels := make([]*v1.Relationship, limit)
+	for _, ns := range namespaces {
+		rels := emptyRels
+
+		// Reset the cursor between namespaces.
+		if ns.Definition.Name != curNamespace {
+			cur = nil
+		}
+
+		// Skip this namespace if a resource type filter was specified.
+		if req.OptionalRelationshipFilter != nil && req.OptionalRelationshipFilter.ResourceType != "" {
+			if ns.Definition.Name != req.OptionalRelationshipFilter.ResourceType {
+				continue
+			}
+		}
+
+		// Setup the filter to use for the relationships.
+		relationshipFilter := datastore.RelationshipsFilter{OptionalResourceType: ns.Definition.Name}
+		if req.OptionalRelationshipFilter != nil {
+			rf, err := datastore.RelationshipsFilterFromPublicFilter(req.OptionalRelationshipFilter)
+			if err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			// Overload the namespace name with the one from the request, because each iteration is for a different namespace.
+			rf.OptionalResourceType = ns.Definition.Name
+			relationshipFilter = rf
+		}
+
+		// We want to keep iterating as long as we're sending full batches.
+		// To bootstrap this loop, we enter the first time with a full rels
+		// slice of dummy rels that were never sent.
+		for uint64(len(rels)) == limit {
+			// Lop off any rels we've already sent
+			rels = rels[:0]
+
+			// relFn copies each streamed relationship into the pre-allocated
+			// slot for its offset within the batch.
+			relFn := func(rel tuple.Relationship) {
+				offset := len(rels)
+				rels = append(rels, &relsArray[offset]) // nozero
+
+				v1Rel := &relsArray[offset]
+				v1Rel.Resource.ObjectType = rel.RelationshipReference.Resource.ObjectType
+				v1Rel.Resource.ObjectId = rel.RelationshipReference.Resource.ObjectID
+				v1Rel.Relation = rel.RelationshipReference.Resource.Relation
+				v1Rel.Subject.Object.ObjectType = rel.RelationshipReference.Subject.ObjectType
+				v1Rel.Subject.Object.ObjectId = rel.RelationshipReference.Subject.ObjectID
+				v1Rel.Subject.OptionalRelation = denormalizeSubjectRelation(rel.RelationshipReference.Subject.Relation)
+
+				if rel.OptionalCaveat != nil {
+					caveatArray[offset].CaveatName = rel.OptionalCaveat.CaveatName
+					caveatArray[offset].Context = rel.OptionalCaveat.Context
+					v1Rel.OptionalCaveat = &caveatArray[offset]
+				} else {
+					v1Rel.OptionalCaveat = nil
+				}
+
+				if rel.OptionalExpiration != nil {
+					v1Rel.OptionalExpiresAt = timestamppb.New(*rel.OptionalExpiration)
+				} else {
+					v1Rel.OptionalExpiresAt = nil
+				}
+			}
+
+			cur, err = queryForEach(
+				ctx,
+				reader,
+				relationshipFilter,
+				relFn,
+				dsoptions.WithLimit(&limit),
+				dsoptions.WithAfter(cur),
+				dsoptions.WithSort(dsoptions.ByResource),
+				dsoptions.WithQueryShape(queryshape.Varying),
+			)
+			if err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			if len(rels) == 0 {
+				continue
+			}
+
+			// Encode a cursor pointing at the last relationship sent so the
+			// client can resume the export from this exact position.
+			encoded, err := cursor.Encode(&implv1.DecodedCursor{
+				VersionOneof: &implv1.DecodedCursor_V1{
+					V1: &implv1.V1Cursor{
+						Revision: atRevision.String(),
+						Sections: []string{
+							ns.Definition.Name,
+							tuple.MustString(*dsoptions.ToRelationship(cur)),
+						},
+					},
+				},
+			})
+			if err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			if err := sender(&v1.BulkExportRelationshipsResponse{
+				AfterResultCursor: encoded,
+				Relationships:     rels,
+			}); err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+		}
+	}
+	return nil
+}
+
+// BulkCheckPermission runs a batch of permission checks in a single request,
+// delegating to the shared bulk checker.
+func (es *experimentalServer) BulkCheckPermission(ctx context.Context, req *v1.BulkCheckPermissionRequest) (*v1.BulkCheckPermissionResponse, error) {
+	result, err := es.bulkChecker.checkBulkPermissions(ctx, toCheckBulkPermissionsRequest(req))
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	return toBulkCheckPermissionResponse(result), nil
+}
+
+// ExperimentalReflectSchema returns the current schema's definitions and
+// caveats in API form, optionally restricted by the request's filters.
+func (es *experimentalServer) ExperimentalReflectSchema(ctx context.Context, req *v1.ExperimentalReflectSchemaRequest) (*v1.ExperimentalReflectSchemaResponse, error) {
+	// Get the current schema.
+	schema, atRevision, err := loadCurrentSchema(ctx)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	filters, err := newexpSchemaFilters(req.OptionalFilters)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	definitions := make([]*v1.ExpDefinition, 0, len(schema.ObjectDefinitions))
+	if filters.HasNamespaces() {
+		for _, ns := range schema.ObjectDefinitions {
+			def, err := expNamespaceAPIRepr(ns, filters)
+			if err != nil {
+				return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			// A nil def means this namespace was filtered out.
+			if def != nil {
+				definitions = append(definitions, def)
+			}
+		}
+	}
+
+	caveats := make([]*v1.ExpCaveat, 0, len(schema.CaveatDefinitions))
+	if filters.HasCaveats() {
+		for _, cd := range schema.CaveatDefinitions {
+			caveat, err := expCaveatAPIRepr(cd, filters, es.caveatTypeSet)
+			if err != nil {
+				return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			// A nil caveat means this caveat was filtered out.
+			if caveat != nil {
+				caveats = append(caveats, caveat)
+			}
+		}
+	}
+
+	return &v1.ExperimentalReflectSchemaResponse{
+		Definitions: definitions,
+		Caveats:     caveats,
+		ReadAt:      zedtoken.MustNewFromRevision(atRevision),
+	}, nil
+}
+
+// ExperimentalDiffSchema computes the diff between the stored schema (at the
+// consistency-selected revision) and the comparison schema in the request,
+// returning the differences in API form.
+func (es *experimentalServer) ExperimentalDiffSchema(ctx context.Context, req *v1.ExperimentalDiffSchemaRequest) (*v1.ExperimentalDiffSchemaResponse, error) {
+	atRevision, _, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		// Wrap for consistency with the other handlers in this service, which
+		// all rewrite errors into API-friendly gRPC statuses.
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	diff, existingSchema, comparisonSchema, err := schemaDiff(ctx, req.ComparisonSchema, es.caveatTypeSet)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	resp, err := expConvertDiff(diff, existingSchema, comparisonSchema, atRevision, es.caveatTypeSet)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	return resp, nil
+}
+
+// ExperimentalComputablePermissions returns the relations/permissions, across
+// all definitions, that can be reached from the given definition+relation
+// (i.e. those whose computed value may depend on it), sorted by definition
+// then relation name.
+func (es *experimentalServer) ExperimentalComputablePermissions(ctx context.Context, req *v1.ExperimentalComputablePermissionsRequest) (*v1.ExperimentalComputablePermissionsResponse, error) {
+	atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+	ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds))
+	vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	// An omitted relation name means "the definition itself" (ellipsis).
+	relationName := req.RelationName
+	if relationName == "" {
+		relationName = tuple.Ellipsis
+	} else {
+		if _, ok := vdef.GetRelation(relationName); !ok {
+			return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, relationName))
+		}
+	}
+
+	allNamespaces, err := ds.ListAllNamespaces(ctx)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	allDefinitions := make([]*core.NamespaceDefinition, 0, len(allNamespaces))
+	for _, ns := range allNamespaces {
+		allDefinitions = append(allDefinitions, ns.Definition)
+	}
+
+	rg := vdef.Reachability()
+	rr, err := rg.RelationsEncounteredForSubject(ctx, allDefinitions, &core.RelationReference{
+		Namespace: req.DefinitionName,
+		Relation:  relationName,
+	})
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	relations := make([]*v1.ExpRelationReference, 0, len(rr))
+	for _, r := range rr {
+		// Exclude the starting relation itself from the results.
+		// NOTE(review): this compares against req.RelationName rather than the
+		// defaulted relationName, so when the request omitted the relation
+		// this skip never matches the ellipsis — confirm intended.
+		if r.Namespace == req.DefinitionName && r.Relation == req.RelationName {
+			continue
+		}
+
+		if req.OptionalDefinitionNameFilter != "" && !strings.HasPrefix(r.Namespace, req.OptionalDefinitionNameFilter) {
+			continue
+		}
+
+		def, err := ts.GetValidatedDefinition(ctx, r.Namespace)
+		if err != nil {
+			return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+		}
+
+		relations = append(relations, &v1.ExpRelationReference{
+			DefinitionName: r.Namespace,
+			RelationName:   r.Relation,
+			IsPermission:   def.IsPermission(r.Relation),
+		})
+	}
+
+	// Stable ordering: definition name first, then relation name.
+	sort.Slice(relations, func(i, j int) bool {
+		if relations[i].DefinitionName == relations[j].DefinitionName {
+			return relations[i].RelationName < relations[j].RelationName
+		}
+		return relations[i].DefinitionName < relations[j].DefinitionName
+	})
+
+	return &v1.ExperimentalComputablePermissionsResponse{
+		Permissions: relations,
+		ReadAt:      revisionReadAt,
+	}, nil
+}
+
+// ExperimentalDependentRelations returns the relations/permissions upon which
+// the given permission depends, sorted by definition then relation name. The
+// named relation must exist and must be a permission (not a plain relation).
+func (es *experimentalServer) ExperimentalDependentRelations(ctx context.Context, req *v1.ExperimentalDependentRelationsRequest) (*v1.ExperimentalDependentRelationsResponse, error) {
+	atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+	ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds))
+	vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	_, ok := vdef.GetRelation(req.PermissionName)
+	if !ok {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, req.PermissionName))
+	}
+
+	if !vdef.IsPermission(req.PermissionName) {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, NewNotAPermissionError(req.PermissionName))
+	}
+
+	rg := vdef.Reachability()
+	rr, err := rg.RelationsEncounteredForResource(ctx, &core.RelationReference{
+		Namespace: req.DefinitionName,
+		Relation:  req.PermissionName,
+	})
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	relations := make([]*v1.ExpRelationReference, 0, len(rr))
+	for _, r := range rr {
+		// Exclude the requested permission itself from the results.
+		if r.Namespace == req.DefinitionName && r.Relation == req.PermissionName {
+			continue
+		}
+
+		// Note: this shadows the outer TypeSystem "ts" with the looked-up
+		// definition for the remainder of the loop body.
+		ts, err := ts.GetDefinition(ctx, r.Namespace)
+		if err != nil {
+			return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+		}
+
+		relations = append(relations, &v1.ExpRelationReference{
+			DefinitionName: r.Namespace,
+			RelationName:   r.Relation,
+			IsPermission:   ts.IsPermission(r.Relation),
+		})
+	}
+
+	// Stable ordering: definition name first, then relation name.
+	sort.Slice(relations, func(i, j int) bool {
+		if relations[i].DefinitionName == relations[j].DefinitionName {
+			return relations[i].RelationName < relations[j].RelationName
+		}
+
+		return relations[i].DefinitionName < relations[j].DefinitionName
+	})
+
+	return &v1.ExperimentalDependentRelationsResponse{
+		Relations: relations,
+		ReadAt:    revisionReadAt,
+	}, nil
+}
+
+// ExperimentalRegisterRelationshipCounter registers a named counter over the
+// relationships matching the given filter. The name must be non-empty and the
+// filter is validated against the schema before registration.
+func (es *experimentalServer) ExperimentalRegisterRelationshipCounter(ctx context.Context, req *v1.ExperimentalRegisterRelationshipCounterRequest) (*v1.ExperimentalRegisterRelationshipCounterResponse, error) {
+	ds := datastoremw.MustFromContext(ctx)
+
+	if req.Name == "" {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED))
+	}
+
+	_, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+		// Validate before registering so an unregisterable filter is rejected
+		// with a schema error rather than stored.
+		if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, rwt); err != nil {
+			return err
+		}
+
+		coreFilter := datastore.CoreFilterFromRelationshipFilter(req.RelationshipFilter)
+		return rwt.RegisterCounter(ctx, req.Name, coreFilter)
+	})
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	return &v1.ExperimentalRegisterRelationshipCounterResponse{}, nil
+}
+
+// ExperimentalUnregisterRelationshipCounter removes a previously-registered
+// relationship counter by name. The name must be non-empty.
+func (es *experimentalServer) ExperimentalUnregisterRelationshipCounter(ctx context.Context, req *v1.ExperimentalUnregisterRelationshipCounterRequest) (*v1.ExperimentalUnregisterRelationshipCounterResponse, error) {
+	ds := datastoremw.MustFromContext(ctx)
+
+	if req.Name == "" {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED))
+	}
+
+	if _, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+		return rwt.UnregisterCounter(ctx, req.Name)
+	}); err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	return &v1.ExperimentalUnregisterRelationshipCounterResponse{}, nil
+}
+
+// ExperimentalCountRelationships reads the current value of a registered
+// relationship counter at the datastore's head revision.
+func (es *experimentalServer) ExperimentalCountRelationships(ctx context.Context, req *v1.ExperimentalCountRelationshipsRequest) (*v1.ExperimentalCountRelationshipsResponse, error) {
+	if req.Name == "" {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED))
+	}
+
+	ds := datastoremw.MustFromContext(ctx)
+	headRev, err := ds.HeadRevision(ctx)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	snapshotReader := ds.SnapshotReader(headRev)
+	count, err := snapshotReader.CountRelationships(ctx, req.Name)
+	if err != nil {
+		return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	// The API reports the count as unsigned; a negative count from the
+	// datastore indicates a bug.
+	uintCount, err := safecast.ToUint64(count)
+	if err != nil {
+		return nil, spiceerrors.MustBugf("count should not be negative")
+	}
+
+	return &v1.ExperimentalCountRelationshipsResponse{
+		CounterResult: &v1.ExperimentalCountRelationshipsResponse_ReadCounterValue{
+			ReadCounterValue: &v1.ReadCounterValue{
+				RelationshipCount: uintCount,
+				ReadAt:            zedtoken.MustNewFromRevision(headRev),
+			},
+		},
+	}, nil
+}
+
+// queryForEach runs a relationships query with the given filter and options,
+// invoking fn for every relationship returned. It returns a cursor pointing
+// at the final relationship seen (nil when the query yielded none).
+func queryForEach(
+	ctx context.Context,
+	reader datastore.Reader,
+	filter datastore.RelationshipsFilter,
+	fn func(rel tuple.Relationship),
+	opts ...dsoptions.QueryOptionsOption,
+) (dsoptions.Cursor, error) {
+	iter, err := reader.QueryRelationships(ctx, filter, opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	var lastCursor dsoptions.Cursor
+	for rel, iterErr := range iter {
+		if iterErr != nil {
+			return nil, iterErr
+		}
+
+		fn(rel)
+		lastCursor = dsoptions.ToCursor(rel)
+	}
+	return lastCursor, nil
+}
+
+// decodeCursor unpacks a client-supplied export cursor into its revision, the
+// namespace currently being exported, and a datastore cursor positioned at
+// the last relationship returned.
+func decodeCursor(ds datastore.ReadOnlyDatastore, encoded *v1.Cursor) (datastore.Revision, string, dsoptions.Cursor, error) {
+	decoded, err := cursor.Decode(encoded)
+	if err != nil {
+		return datastore.NoRevision, "", nil, err
+	}
+
+	v1Cursor := decoded.GetV1()
+	if v1Cursor == nil {
+		return datastore.NoRevision, "", nil, errors.New("malformed cursor: no V1 in OneOf")
+	}
+
+	// Section layout: [0] namespace name, [1] encoded relationship tuple.
+	sections := v1Cursor.GetSections()
+	if len(sections) != 2 {
+		return datastore.NoRevision, "", nil, errors.New("malformed cursor: wrong number of components")
+	}
+
+	atRevision, err := ds.RevisionFromString(v1Cursor.Revision)
+	if err != nil {
+		return datastore.NoRevision, "", nil, err
+	}
+
+	parsedRel, err := tuple.Parse(sections[1])
+	if err != nil {
+		return datastore.NoRevision, "", nil, fmt.Errorf("malformed cursor: invalid encoded relation tuple: %w", err)
+	}
+
+	// Returns the current namespace and the cursor.
+	return atRevision, sections[0], dsoptions.ToCursor(parsedRel), nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go b/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go
new file mode 100644
index 0000000..8ef6c25
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go
@@ -0,0 +1,720 @@
+package v1
+
+import (
+ "sort"
+ "strings"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "golang.org/x/exp/maps"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/diff"
+ caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats"
+ nsdiff "github.com/authzed/spicedb/pkg/diff/namespace"
+ "github.com/authzed/spicedb/pkg/namespace"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/zedtoken"
+)
+
// expSchemaFilters wraps a set of ExpSchemaFilter messages and answers
// prefix-match queries over definition, caveat, relation and permission
// names. An empty filter list matches everything.
type expSchemaFilters struct {
	filters []*v1.ExpSchemaFilter
}
+
+func newexpSchemaFilters(filters []*v1.ExpSchemaFilter) (*expSchemaFilters, error) {
+ for _, filter := range filters {
+ if filter.OptionalDefinitionNameFilter != "" {
+ if filter.OptionalCaveatNameFilter != "" {
+ return nil, NewInvalidFilterErr("cannot filter by both definition and caveat name", filter.String())
+ }
+ }
+
+ if filter.OptionalRelationNameFilter != "" {
+ if filter.OptionalDefinitionNameFilter == "" {
+ return nil, NewInvalidFilterErr("relation name match requires definition name match", filter.String())
+ }
+
+ if filter.OptionalPermissionNameFilter != "" {
+ return nil, NewInvalidFilterErr("cannot filter by both relation and permission name", filter.String())
+ }
+ }
+
+ if filter.OptionalPermissionNameFilter != "" {
+ if filter.OptionalDefinitionNameFilter == "" {
+ return nil, NewInvalidFilterErr("permission name match requires definition name match", filter.String())
+ }
+ }
+ }
+
+ return &expSchemaFilters{filters: filters}, nil
+}
+
+func (sf *expSchemaFilters) HasNamespaces() bool {
+ if len(sf.filters) == 0 {
+ return true
+ }
+
+ for _, filter := range sf.filters {
+ if filter.OptionalDefinitionNameFilter != "" {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (sf *expSchemaFilters) HasCaveats() bool {
+ if len(sf.filters) == 0 {
+ return true
+ }
+
+ for _, filter := range sf.filters {
+ if filter.OptionalCaveatNameFilter != "" {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (sf *expSchemaFilters) HasNamespace(namespaceName string) bool {
+ if len(sf.filters) == 0 {
+ return true
+ }
+
+ hasDefinitionFilter := false
+ for _, filter := range sf.filters {
+ if filter.OptionalDefinitionNameFilter == "" {
+ continue
+ }
+
+ hasDefinitionFilter = true
+ isMatch := strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter)
+ if isMatch {
+ return true
+ }
+ }
+
+ return !hasDefinitionFilter
+}
+
+func (sf *expSchemaFilters) HasCaveat(caveatName string) bool {
+ if len(sf.filters) == 0 {
+ return true
+ }
+
+ hasCaveatFilter := false
+ for _, filter := range sf.filters {
+ if filter.OptionalCaveatNameFilter == "" {
+ continue
+ }
+
+ hasCaveatFilter = true
+ isMatch := strings.HasPrefix(caveatName, filter.OptionalCaveatNameFilter)
+ if isMatch {
+ return true
+ }
+ }
+
+ return !hasCaveatFilter
+}
+
+func (sf *expSchemaFilters) HasRelation(namespaceName, relationName string) bool {
+ if len(sf.filters) == 0 {
+ return true
+ }
+
+ hasRelationFilter := false
+ for _, filter := range sf.filters {
+ if filter.OptionalRelationNameFilter == "" {
+ continue
+ }
+
+ hasRelationFilter = true
+ isMatch := strings.HasPrefix(relationName, filter.OptionalRelationNameFilter)
+ if !isMatch {
+ continue
+ }
+
+ isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter)
+ if isMatch {
+ return true
+ }
+ }
+
+ return !hasRelationFilter
+}
+
+func (sf *expSchemaFilters) HasPermission(namespaceName, permissionName string) bool {
+ if len(sf.filters) == 0 {
+ return true
+ }
+
+ hasPermissionFilter := false
+ for _, filter := range sf.filters {
+ if filter.OptionalPermissionNameFilter == "" {
+ continue
+ }
+
+ hasPermissionFilter = true
+ isMatch := strings.HasPrefix(permissionName, filter.OptionalPermissionNameFilter)
+ if !isMatch {
+ continue
+ }
+
+ isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter)
+ if isMatch {
+ return true
+ }
+ }
+
+ return !hasPermissionFilter
+}
+
// expConvertDiff converts a schema diff into an API response.
//
// existingSchema is the schema before the change and comparisonSchema the
// schema after it: added and changed elements are rendered from
// comparisonSchema, while removed elements are rendered from existingSchema.
// The returned response carries a ZedToken for atRevision.
func expConvertDiff(
	diff *diff.SchemaDiff,
	existingSchema *diff.DiffableSchema,
	comparisonSchema *diff.DiffableSchema,
	atRevision datastore.Revision,
	caveatTypeSet *caveattypes.TypeSet,
) (*v1.ExperimentalDiffSchemaResponse, error) {
	// Pre-size for the top-level adds/removes/changes; per-delta entries for
	// changed namespaces/caveats are appended beyond this capacity as needed.
	size := len(diff.AddedNamespaces) + len(diff.RemovedNamespaces) + len(diff.AddedCaveats) + len(diff.RemovedCaveats) + len(diff.ChangedNamespaces) + len(diff.ChangedCaveats)
	diffs := make([]*v1.ExpSchemaDiff, 0, size)

	// Add/remove namespaces.
	for _, ns := range diff.AddedNamespaces {
		nsDef, err := expNamespaceAPIReprForName(ns, comparisonSchema)
		if err != nil {
			return nil, err
		}

		diffs = append(diffs, &v1.ExpSchemaDiff{
			Diff: &v1.ExpSchemaDiff_DefinitionAdded{
				DefinitionAdded: nsDef,
			},
		})
	}

	for _, ns := range diff.RemovedNamespaces {
		nsDef, err := expNamespaceAPIReprForName(ns, existingSchema)
		if err != nil {
			return nil, err
		}

		diffs = append(diffs, &v1.ExpSchemaDiff{
			Diff: &v1.ExpSchemaDiff_DefinitionRemoved{
				DefinitionRemoved: nsDef,
			},
		})
	}

	// Add/remove caveats.
	for _, caveat := range diff.AddedCaveats {
		caveatDef, err := expCaveatAPIReprForName(caveat, comparisonSchema, caveatTypeSet)
		if err != nil {
			return nil, err
		}

		diffs = append(diffs, &v1.ExpSchemaDiff{
			Diff: &v1.ExpSchemaDiff_CaveatAdded{
				CaveatAdded: caveatDef,
			},
		})
	}

	for _, caveat := range diff.RemovedCaveats {
		caveatDef, err := expCaveatAPIReprForName(caveat, existingSchema, caveatTypeSet)
		if err != nil {
			return nil, err
		}

		diffs = append(diffs, &v1.ExpSchemaDiff{
			Diff: &v1.ExpSchemaDiff_CaveatRemoved{
				CaveatRemoved: caveatDef,
			},
		})
	}

	// Changed namespaces: each namespace-level delta maps to one API diff entry.
	for nsName, nsDiff := range diff.ChangedNamespaces {
		for _, delta := range nsDiff.Deltas() {
			switch delta.Type {
			case nsdiff.AddedPermission:
				permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName)
				}

				perm, err := expPermissionAPIRepr(permission, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_PermissionAdded{
						PermissionAdded: perm,
					},
				})

			case nsdiff.AddedRelation:
				relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
				}

				rel, err := expRelationAPIRepr(relation, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_RelationAdded{
						RelationAdded: rel,
					},
				})

			case nsdiff.ChangedPermissionComment:
				permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName)
				}

				perm, err := expPermissionAPIRepr(permission, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_PermissionDocCommentChanged{
						PermissionDocCommentChanged: perm,
					},
				})

			case nsdiff.ChangedPermissionImpl:
				permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName)
				}

				perm, err := expPermissionAPIRepr(permission, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_PermissionExprChanged{
						PermissionExprChanged: perm,
					},
				})

			case nsdiff.ChangedRelationComment:
				relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
				}

				rel, err := expRelationAPIRepr(relation, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_RelationDocCommentChanged{
						RelationDocCommentChanged: rel,
					},
				})

			case nsdiff.LegacyChangedRelationImpl:
				return nil, spiceerrors.MustBugf("legacy relation implementation changes are not supported")

			case nsdiff.NamespaceCommentsChanged:
				def, err := expNamespaceAPIReprForName(nsName, comparisonSchema)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_DefinitionDocCommentChanged{
						DefinitionDocCommentChanged: def,
					},
				})

			case nsdiff.RelationAllowedTypeRemoved:
				relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
				}

				rel, err := expRelationAPIRepr(relation, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_RelationSubjectTypeRemoved{
						RelationSubjectTypeRemoved: &v1.ExpRelationSubjectTypeChange{
							Relation:           rel,
							ChangedSubjectType: expTypeAPIRepr(delta.AllowedType),
						},
					},
				})

			case nsdiff.RelationAllowedTypeAdded:
				relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
				}

				rel, err := expRelationAPIRepr(relation, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_RelationSubjectTypeAdded{
						RelationSubjectTypeAdded: &v1.ExpRelationSubjectTypeChange{
							Relation:           rel,
							ChangedSubjectType: expTypeAPIRepr(delta.AllowedType),
						},
					},
				})

			case nsdiff.RemovedPermission:
				// Removed elements are looked up in the pre-change schema.
				permission, ok := existingSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					// NOTE(review): message says "relation" though this lookup is for a permission.
					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
				}

				perm, err := expPermissionAPIRepr(permission, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_PermissionRemoved{
						PermissionRemoved: perm,
					},
				})

			case nsdiff.RemovedRelation:
				relation, ok := existingSchema.GetRelation(nsName, delta.RelationName)
				if !ok {
					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
				}

				rel, err := expRelationAPIRepr(relation, nsName, nil)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_RelationRemoved{
						RelationRemoved: rel,
					},
				})

			case nsdiff.NamespaceAdded:
				return nil, spiceerrors.MustBugf("should be handled above")

			case nsdiff.NamespaceRemoved:
				return nil, spiceerrors.MustBugf("should be handled above")

			default:
				return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type)
			}
		}
	}

	// Changed caveats: each caveat-level delta maps to one API diff entry.
	for caveatName, caveatDiff := range diff.ChangedCaveats {
		for _, delta := range caveatDiff.Deltas() {
			switch delta.Type {
			case caveatdiff.CaveatCommentsChanged:
				caveat, err := expCaveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_CaveatDocCommentChanged{
						CaveatDocCommentChanged: caveat,
					},
				})

			case caveatdiff.AddedParameter:
				paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_CaveatParameterAdded{
						CaveatParameterAdded: paramDef,
					},
				})

			case caveatdiff.RemovedParameter:
				paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_CaveatParameterRemoved{
						CaveatParameterRemoved: paramDef,
					},
				})

			case caveatdiff.ParameterTypeChanged:
				// Both representations are needed: the new parameter plus its previous type.
				previousParamDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet)
				if err != nil {
					return nil, err
				}

				paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_CaveatParameterTypeChanged{
						CaveatParameterTypeChanged: &v1.ExpCaveatParameterTypeChange{
							Parameter:    paramDef,
							PreviousType: previousParamDef.Type,
						},
					},
				})

			case caveatdiff.CaveatExpressionChanged:
				caveat, err := expCaveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet)
				if err != nil {
					return nil, err
				}

				diffs = append(diffs, &v1.ExpSchemaDiff{
					Diff: &v1.ExpSchemaDiff_CaveatExprChanged{
						CaveatExprChanged: caveat,
					},
				})

			case caveatdiff.CaveatAdded:
				return nil, spiceerrors.MustBugf("should be handled above")

			case caveatdiff.CaveatRemoved:
				return nil, spiceerrors.MustBugf("should be handled above")

			default:
				return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type)
			}
		}
	}

	return &v1.ExperimentalDiffSchemaResponse{
		Diffs:  diffs,
		ReadAt: zedtoken.MustNewFromRevision(atRevision),
	}, nil
}
+
+// expNamespaceAPIReprForName builds an API representation of a namespace.
+func expNamespaceAPIReprForName(namespaceName string, schema *diff.DiffableSchema) (*v1.ExpDefinition, error) {
+ nsDef, ok := schema.GetNamespace(namespaceName)
+ if !ok {
+ return nil, spiceerrors.MustBugf("namespace %q not found in schema", namespaceName)
+ }
+
+ return expNamespaceAPIRepr(nsDef, nil)
+}
+
+func expNamespaceAPIRepr(nsDef *core.NamespaceDefinition, expSchemaFilters *expSchemaFilters) (*v1.ExpDefinition, error) {
+ if expSchemaFilters != nil && !expSchemaFilters.HasNamespace(nsDef.Name) {
+ return nil, nil
+ }
+
+ relations := make([]*v1.ExpRelation, 0, len(nsDef.Relation))
+ permissions := make([]*v1.ExpPermission, 0, len(nsDef.Relation))
+
+ for _, rel := range nsDef.Relation {
+ if namespace.GetRelationKind(rel) == iv1.RelationMetadata_PERMISSION {
+ permission, err := expPermissionAPIRepr(rel, nsDef.Name, expSchemaFilters)
+ if err != nil {
+ return nil, err
+ }
+
+ if permission != nil {
+ permissions = append(permissions, permission)
+ }
+ continue
+ }
+
+ relation, err := expRelationAPIRepr(rel, nsDef.Name, expSchemaFilters)
+ if err != nil {
+ return nil, err
+ }
+
+ if relation != nil {
+ relations = append(relations, relation)
+ }
+ }
+
+ comments := namespace.GetComments(nsDef.Metadata)
+ return &v1.ExpDefinition{
+ Name: nsDef.Name,
+ Comment: strings.Join(comments, "\n"),
+ Relations: relations,
+ Permissions: permissions,
+ }, nil
+}
+
+// expPermissionAPIRepr builds an API representation of a permission.
+func expPermissionAPIRepr(relation *core.Relation, parentDefName string, expSchemaFilters *expSchemaFilters) (*v1.ExpPermission, error) {
+ if expSchemaFilters != nil && !expSchemaFilters.HasPermission(parentDefName, relation.Name) {
+ return nil, nil
+ }
+
+ comments := namespace.GetComments(relation.Metadata)
+ return &v1.ExpPermission{
+ Name: relation.Name,
+ Comment: strings.Join(comments, "\n"),
+ ParentDefinitionName: parentDefName,
+ }, nil
+}
+
+// expRelationAPIRepr builds an API representation of a relation.
+func expRelationAPIRepr(relation *core.Relation, parentDefName string, expSchemaFilters *expSchemaFilters) (*v1.ExpRelation, error) {
+ if expSchemaFilters != nil && !expSchemaFilters.HasRelation(parentDefName, relation.Name) {
+ return nil, nil
+ }
+
+ comments := namespace.GetComments(relation.Metadata)
+
+ var subjectTypes []*v1.ExpTypeReference
+ if relation.TypeInformation != nil {
+ subjectTypes = make([]*v1.ExpTypeReference, 0, len(relation.TypeInformation.AllowedDirectRelations))
+ for _, subjectType := range relation.TypeInformation.AllowedDirectRelations {
+ typeref := expTypeAPIRepr(subjectType)
+ subjectTypes = append(subjectTypes, typeref)
+ }
+ }
+
+ return &v1.ExpRelation{
+ Name: relation.Name,
+ Comment: strings.Join(comments, "\n"),
+ ParentDefinitionName: parentDefName,
+ SubjectTypes: subjectTypes,
+ }, nil
+}
+
+// expTypeAPIRepr builds an API representation of a type.
+func expTypeAPIRepr(subjectType *core.AllowedRelation) *v1.ExpTypeReference {
+ typeref := &v1.ExpTypeReference{
+ SubjectDefinitionName: subjectType.Namespace,
+ Typeref: &v1.ExpTypeReference_IsTerminalSubject{},
+ }
+
+ if subjectType.GetRelation() != tuple.Ellipsis && subjectType.GetRelation() != "" {
+ typeref.Typeref = &v1.ExpTypeReference_OptionalRelationName{
+ OptionalRelationName: subjectType.GetRelation(),
+ }
+ } else if subjectType.GetPublicWildcard() != nil {
+ typeref.Typeref = &v1.ExpTypeReference_IsPublicWildcard{
+ IsPublicWildcard: true,
+ }
+ }
+
+ if subjectType.GetRequiredCaveat() != nil {
+ typeref.OptionalCaveatName = subjectType.GetRequiredCaveat().CaveatName
+ }
+
+ return typeref
+}
+
+// expCaveatAPIReprForName builds an API representation of a caveat.
+func expCaveatAPIReprForName(caveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveat, error) {
+ caveatDef, ok := schema.GetCaveat(caveatName)
+ if !ok {
+ return nil, spiceerrors.MustBugf("caveat %q not found in schema", caveatName)
+ }
+
+ return expCaveatAPIRepr(caveatDef, nil, caveatTypeSet)
+}
+
+// expCaveatAPIRepr builds an API representation of a caveat.
+func expCaveatAPIRepr(caveatDef *core.CaveatDefinition, expSchemaFilters *expSchemaFilters, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveat, error) {
+ if expSchemaFilters != nil && !expSchemaFilters.HasCaveat(caveatDef.Name) {
+ return nil, nil
+ }
+
+ parameters := make([]*v1.ExpCaveatParameter, 0, len(caveatDef.ParameterTypes))
+ paramNames := maps.Keys(caveatDef.ParameterTypes)
+ sort.Strings(paramNames)
+
+ for _, paramName := range paramNames {
+ paramType, ok := caveatDef.ParameterTypes[paramName]
+ if !ok {
+ return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, caveatDef.Name)
+ }
+
+ decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType)
+ if err != nil {
+ return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err)
+ }
+
+ parameters = append(parameters, &v1.ExpCaveatParameter{
+ Name: paramName,
+ Type: decoded.String(),
+ ParentCaveatName: caveatDef.Name,
+ })
+ }
+
+ parameterTypes, err := caveattypes.DecodeParameterTypes(caveatTypeSet, caveatDef.ParameterTypes)
+ if err != nil {
+ return nil, spiceerrors.MustBugf("invalid caveat parameters: %v", err)
+ }
+
+ deserializedExpression, err := caveats.DeserializeCaveatWithTypeSet(caveatTypeSet, caveatDef.SerializedExpression, parameterTypes)
+ if err != nil {
+ return nil, spiceerrors.MustBugf("invalid caveat expression bytes: %v", err)
+ }
+
+ exprString, err := deserializedExpression.ExprString()
+ if err != nil {
+ return nil, spiceerrors.MustBugf("invalid caveat expression: %v", err)
+ }
+
+ comments := namespace.GetComments(caveatDef.Metadata)
+ return &v1.ExpCaveat{
+ Name: caveatDef.Name,
+ Comment: strings.Join(comments, "\n"),
+ Parameters: parameters,
+ Expression: exprString,
+ }, nil
+}
+
+// expCaveatAPIParamRepr builds an API representation of a caveat parameter.
+func expCaveatAPIParamRepr(paramName, parentCaveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveatParameter, error) {
+ caveatDef, ok := schema.GetCaveat(parentCaveatName)
+ if !ok {
+ return nil, spiceerrors.MustBugf("caveat %q not found in schema", parentCaveatName)
+ }
+
+ paramType, ok := caveatDef.ParameterTypes[paramName]
+ if !ok {
+ return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, parentCaveatName)
+ }
+
+ decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType)
+ if err != nil {
+ return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err)
+ }
+
+ return &v1.ExpCaveatParameter{
+ Name: paramName,
+ Type: decoded.String(),
+ ParentCaveatName: parentCaveatName,
+ }, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go b/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go
new file mode 100644
index 0000000..99b681d
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go
@@ -0,0 +1,72 @@
+package v1
+
+import (
+ "context"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/internal/graph/computed"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
// groupedCheckParameters holds one set of computed check parameters together
// with all resource IDs whose bulk-check items share those parameters.
type groupedCheckParameters struct {
	params      *computed.CheckParameters
	resourceIDs []string
}
+
// groupingParameters carries request-level settings that apply to every item
// in a bulk-check grouping pass.
type groupingParameters struct {
	atRevision           datastore.Revision // revision at which all checks are evaluated
	maximumAPIDepth      uint32             // dispatch depth limit for each check
	maxCaveatContextSize int                // byte limit for per-item caveat context
	withTracing          bool               // when true, enables basic debug tracing on checks
}
+
+// groupItems takes a slice of CheckBulkPermissionsRequestItem and groups them based
+// on using the same permission, subject type, subject id, and caveat.
+func groupItems(ctx context.Context, params groupingParameters, items []*v1.CheckBulkPermissionsRequestItem) (map[string]*groupedCheckParameters, error) {
+ res := make(map[string]*groupedCheckParameters)
+
+ for _, item := range items {
+ hash, err := computeCheckBulkPermissionsItemHashWithoutResourceID(item)
+ if err != nil {
+ return nil, err
+ }
+
+ if _, ok := res[hash]; !ok {
+ caveatContext, err := GetCaveatContext(ctx, item.Context, params.maxCaveatContextSize)
+ if err != nil {
+ return nil, err
+ }
+
+ res[hash] = &groupedCheckParameters{
+ params: checkParametersFromCheckBulkPermissionsRequestItem(item, params, caveatContext),
+ resourceIDs: []string{item.Resource.ObjectId},
+ }
+ } else {
+ res[hash].resourceIDs = append(res[hash].resourceIDs, item.Resource.ObjectId)
+ }
+ }
+
+ return res, nil
+}
+
+func checkParametersFromCheckBulkPermissionsRequestItem(
+ bc *v1.CheckBulkPermissionsRequestItem,
+ params groupingParameters,
+ caveatContext map[string]any,
+) *computed.CheckParameters {
+ debugOption := computed.NoDebugging
+ if params.withTracing {
+ debugOption = computed.BasicDebuggingEnabled
+ }
+
+ return &computed.CheckParameters{
+ ResourceType: tuple.RR(bc.Resource.ObjectType, bc.Permission),
+ Subject: tuple.ONR(bc.Subject.Object.ObjectType, bc.Subject.Object.ObjectId, normalizeSubjectRelation(bc.Subject)),
+ CaveatContext: caveatContext,
+ AtRevision: params.atRevision,
+ MaximumDepth: params.maximumAPIDepth,
+ DebugOption: debugOption,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go
new file mode 100644
index 0000000..1754669
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go
@@ -0,0 +1,110 @@
+package v1
+
+import (
+ "strconv"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "google.golang.org/protobuf/types/known/structpb"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+func computeCheckBulkPermissionsItemHashWithoutResourceID(req *v1.CheckBulkPermissionsRequestItem) (string, error) {
+ return computeCallHash("v1.checkbulkpermissionsrequestitem", nil, map[string]any{
+ "resource-type": req.Resource.ObjectType,
+ "permission": req.Permission,
+ "subject-type": req.Subject.Object.ObjectType,
+ "subject-id": req.Subject.Object.ObjectId,
+ "subject-relation": req.Subject.OptionalRelation,
+ "context": req.Context,
+ })
+}
+
+func computeCheckBulkPermissionsItemHash(req *v1.CheckBulkPermissionsRequestItem) (string, error) {
+ return computeCallHash("v1.checkbulkpermissionsrequestitem", nil, map[string]any{
+ "resource-type": req.Resource.ObjectType,
+ "resource-id": req.Resource.ObjectId,
+ "permission": req.Permission,
+ "subject-type": req.Subject.Object.ObjectType,
+ "subject-id": req.Subject.Object.ObjectId,
+ "subject-relation": req.Subject.OptionalRelation,
+ "context": req.Context,
+ })
+}
+
+func computeReadRelationshipsRequestHash(req *v1.ReadRelationshipsRequest) (string, error) {
+ osf := req.RelationshipFilter.OptionalSubjectFilter
+ if osf == nil {
+ osf = &v1.SubjectFilter{}
+ }
+
+ srf := "(none)"
+ if osf.OptionalRelation != nil {
+ srf = osf.OptionalRelation.Relation
+ }
+
+ return computeCallHash("v1.readrelationships", req.Consistency, map[string]any{
+ "filter-resource-type": req.RelationshipFilter.ResourceType,
+ "filter-relation": req.RelationshipFilter.OptionalRelation,
+ "filter-resource-id": req.RelationshipFilter.OptionalResourceId,
+ "subject-type": osf.SubjectType,
+ "subject-relation": srf,
+ "subject-resource-id": osf.OptionalSubjectId,
+ "limit": req.OptionalLimit,
+ })
+}
+
+func computeLRRequestHash(req *v1.LookupResourcesRequest) (string, error) {
+ return computeCallHash("v1.lookupresources", req.Consistency, map[string]any{
+ "resource-type": req.ResourceObjectType,
+ "permission": req.Permission,
+ "subject": tuple.V1StringSubjectRef(req.Subject),
+ "limit": req.OptionalLimit,
+ "context": req.Context,
+ })
+}
+
+func computeCallHash(apiName string, consistency *v1.Consistency, arguments map[string]any) (string, error) {
+ stringArguments := make(map[string]string, len(arguments)+1)
+
+ if consistency == nil {
+ consistency = &v1.Consistency{
+ Requirement: &v1.Consistency_MinimizeLatency{
+ MinimizeLatency: true,
+ },
+ }
+ }
+
+ consistencyBytes, err := consistency.MarshalVT()
+ if err != nil {
+ return "", err
+ }
+
+ stringArguments["consistency"] = string(consistencyBytes)
+
+ for argName, argValue := range arguments {
+ if argName == "consistency" {
+ return "", spiceerrors.MustBugf("cannot specify consistency in the arguments")
+ }
+
+ switch v := argValue.(type) {
+ case string:
+ stringArguments[argName] = v
+
+ case int:
+ stringArguments[argName] = strconv.Itoa(v)
+
+ case uint32:
+ stringArguments[argName] = strconv.Itoa(int(v))
+
+ case *structpb.Struct:
+ stringArguments[argName] = caveats.StableContextStringForHashing(v)
+
+ default:
+ return "", spiceerrors.MustBugf("unknown argument type in compute call hash")
+ }
+ }
+ return computeAPICallHash(apiName, stringArguments)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go
new file mode 100644
index 0000000..fad4a40
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go
@@ -0,0 +1,52 @@
+//go:build !wasm
+// +build !wasm
+
+package v1
+
+import (
+ "fmt"
+ "sort"
+
+ "github.com/cespare/xxhash/v2"
+ "golang.org/x/exp/maps"
+)
+
+func computeAPICallHash(apiName string, arguments map[string]string) (string, error) {
+ hasher := xxhash.New()
+ _, err := hasher.WriteString(apiName)
+ if err != nil {
+ return "", err
+ }
+
+ _, err = hasher.WriteString(":")
+ if err != nil {
+ return "", err
+ }
+
+ keys := maps.Keys(arguments)
+ sort.Strings(keys)
+
+ for _, key := range keys {
+ _, err = hasher.WriteString(key)
+ if err != nil {
+ return "", err
+ }
+
+ _, err = hasher.WriteString(":")
+ if err != nil {
+ return "", err
+ }
+
+ _, err = hasher.WriteString(arguments[key])
+ if err != nil {
+ return "", err
+ }
+
+ _, err = hasher.WriteString(";")
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return fmt.Sprintf("%x", hasher.Sum(nil)), nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go
new file mode 100644
index 0000000..4c75aa0
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go
@@ -0,0 +1,50 @@
+package v1
+
+import (
+ "crypto/sha256"
+ "fmt"
+ "sort"
+
+ "golang.org/x/exp/maps"
+)
+
+func computeAPICallHash(apiName string, arguments map[string]string) (string, error) {
+ h := sha256.New()
+
+ _, err := h.Write([]byte(apiName))
+ if err != nil {
+ return "", err
+ }
+
+ _, err = h.Write([]byte(":"))
+ if err != nil {
+ return "", err
+ }
+
+ keys := maps.Keys(arguments)
+ sort.Strings(keys)
+
+ for _, key := range keys {
+ _, err = h.Write([]byte(key))
+ if err != nil {
+ return "", err
+ }
+
+ _, err = h.Write([]byte(":"))
+ if err != nil {
+ return "", err
+ }
+
+ _, err = h.Write([]byte(arguments[key]))
+ if err != nil {
+ return "", err
+ }
+
+ _, err = h.Write([]byte(";"))
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return fmt.Sprintf("%x", h.Sum(nil)), nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go b/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go
new file mode 100644
index 0000000..d309c3b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go
@@ -0,0 +1,12 @@
+package options
+
+import "time"
+
+//go:generate go run github.com/ecordell/optgen -output zz_generated.query_options.go . ExperimentalServerOptions
+
// ExperimentalServerOptions configures the experimental service surface.
// Default values come from the `default` struct tags, applied by the optgen
// generated NewExperimentalServerOptionsWithOptionsAndDefaults constructor.
type ExperimentalServerOptions struct {
	StreamReadTimeout       time.Duration `debugmap:"visible" default:"600s"`
	DefaultExportBatchSize  uint32        `debugmap:"visible" default:"1_000"`
	MaxExportBatchSize      uint32        `debugmap:"visible" default:"100_000"`
	BulkCheckMaxConcurrency uint16        `debugmap:"visible" default:"50"`
}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go b/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go
new file mode 100644
index 0000000..5b75b5f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go
@@ -0,0 +1,93 @@
+// Code generated by github.com/ecordell/optgen. DO NOT EDIT.
+package options
+
+import (
+ defaults "github.com/creasty/defaults"
+ helpers "github.com/ecordell/optgen/helpers"
+ "time"
+)
+
// ExperimentalServerOptionsOption mutates an ExperimentalServerOptions in place.
type ExperimentalServerOptionsOption func(e *ExperimentalServerOptions)
+
// NewExperimentalServerOptionsWithOptions creates a new ExperimentalServerOptions with the passed in options set
func NewExperimentalServerOptionsWithOptions(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions {
	// Starts from the zero value; `default` struct tags are NOT applied here.
	e := &ExperimentalServerOptions{}
	for _, o := range opts {
		o(e)
	}
	return e
}
+
// NewExperimentalServerOptionsWithOptionsAndDefaults creates a new ExperimentalServerOptions with the passed in options set starting from the defaults
func NewExperimentalServerOptionsWithOptionsAndDefaults(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions {
	e := &ExperimentalServerOptions{}
	// Applies the `default` struct tags before options, so options win.
	defaults.MustSet(e)
	for _, o := range opts {
		o(e)
	}
	return e
}
+
// ToOption returns a new ExperimentalServerOptionsOption that sets the values from the passed in ExperimentalServerOptions
func (e *ExperimentalServerOptions) ToOption() ExperimentalServerOptionsOption {
	return func(to *ExperimentalServerOptions) {
		to.StreamReadTimeout = e.StreamReadTimeout
		to.DefaultExportBatchSize = e.DefaultExportBatchSize
		to.MaxExportBatchSize = e.MaxExportBatchSize
		to.BulkCheckMaxConcurrency = e.BulkCheckMaxConcurrency
	}
}
+
// DebugMap returns a map form of ExperimentalServerOptions for debugging
func (e ExperimentalServerOptions) DebugMap() map[string]any {
	debugMap := map[string]any{}
	// The second argument mirrors the `debugmap:"visible"` tags: false means not redacted.
	debugMap["StreamReadTimeout"] = helpers.DebugValue(e.StreamReadTimeout, false)
	debugMap["DefaultExportBatchSize"] = helpers.DebugValue(e.DefaultExportBatchSize, false)
	debugMap["MaxExportBatchSize"] = helpers.DebugValue(e.MaxExportBatchSize, false)
	debugMap["BulkCheckMaxConcurrency"] = helpers.DebugValue(e.BulkCheckMaxConcurrency, false)
	return debugMap
}
+
// ExperimentalServerOptionsWithOptions configures an existing ExperimentalServerOptions with the passed in options set
func ExperimentalServerOptionsWithOptions(e *ExperimentalServerOptions, opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions {
	for _, o := range opts {
		o(e)
	}
	return e
}
+
// WithOptions configures the receiver ExperimentalServerOptions with the passed in options set
func (e *ExperimentalServerOptions) WithOptions(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions {
	for _, o := range opts {
		o(e)
	}
	return e
}
+
// WithStreamReadTimeout returns an option that can set StreamReadTimeout on a ExperimentalServerOptions
func WithStreamReadTimeout(streamReadTimeout time.Duration) ExperimentalServerOptionsOption {
	return func(e *ExperimentalServerOptions) {
		e.StreamReadTimeout = streamReadTimeout
	}
}
+
// WithDefaultExportBatchSize returns an option that can set DefaultExportBatchSize on a ExperimentalServerOptions
func WithDefaultExportBatchSize(defaultExportBatchSize uint32) ExperimentalServerOptionsOption {
	return func(e *ExperimentalServerOptions) {
		e.DefaultExportBatchSize = defaultExportBatchSize
	}
}
+
// WithMaxExportBatchSize returns an option that can set MaxExportBatchSize on a ExperimentalServerOptions
func WithMaxExportBatchSize(maxExportBatchSize uint32) ExperimentalServerOptionsOption {
	return func(e *ExperimentalServerOptions) {
		e.MaxExportBatchSize = maxExportBatchSize
	}
}
+
// WithBulkCheckMaxConcurrency returns an option that can set BulkCheckMaxConcurrency on a ExperimentalServerOptions
func WithBulkCheckMaxConcurrency(bulkCheckMaxConcurrency uint16) ExperimentalServerOptionsOption {
	return func(e *ExperimentalServerOptions) {
		e.BulkCheckMaxConcurrency = bulkCheckMaxConcurrency
	}
}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go b/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go
new file mode 100644
index 0000000..da6dd18
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go
@@ -0,0 +1,1094 @@
+package v1
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "slices"
+ "strings"
+
+ "github.com/authzed/authzed-go/pkg/requestmeta"
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/jzelinskie/stringz"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/types/known/structpb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+
+ cexpr "github.com/authzed/spicedb/internal/caveats"
+ dispatchpkg "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/graph"
+ "github.com/authzed/spicedb/internal/graph/computed"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/middleware/usagemetrics"
+ "github.com/authzed/spicedb/internal/namespace"
+ "github.com/authzed/spicedb/internal/relationships"
+ "github.com/authzed/spicedb/internal/services/shared"
+ "github.com/authzed/spicedb/internal/telemetry"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/cursor"
+ "github.com/authzed/spicedb/pkg/datastore"
+ dsoptions "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/middleware/consistency"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+)
+
+// rewriteError converts an internal error into its public API form via
+// shared.RewriteError, supplying the server's configured maximum API depth
+// so depth-related errors can be reported accurately.
+func (ps *permissionServer) rewriteError(ctx context.Context, err error) error {
+	return shared.RewriteError(ctx, err, &shared.ConfigForErrors{
+		MaximumAPIDepth: ps.config.MaximumAPIDepth,
+	})
+}
+
+// rewriteErrorWithOptionalDebugTrace behaves like rewriteError but also
+// attaches the provided debug trace (which may be nil) to the rewritten error.
+func (ps *permissionServer) rewriteErrorWithOptionalDebugTrace(ctx context.Context, err error, debugTrace *v1.DebugInformation) error {
+	return shared.RewriteError(ctx, err, &shared.ConfigForErrors{
+		MaximumAPIDepth: ps.config.MaximumAPIDepth,
+		DebugTrace:      debugTrace,
+	})
+}
+
+// CheckPermission implements the v1 PermissionsService Check API: it determines
+// whether the request's subject has the request's permission on the resource at
+// the revision chosen by the consistency middleware, optionally collecting and
+// returning a debug trace.
+func (ps *permissionServer) CheckPermission(ctx context.Context, req *v1.CheckPermissionRequest) (*v1.CheckPermissionResponse, error) {
+	telemetry.RecordLogicalChecks(1)
+
+	// Resolve the revision implied by the request's consistency preference.
+	atRevision, checkedAt, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return nil, ps.rewriteError(ctx, err)
+	}
+
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+
+	caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize)
+	if err != nil {
+		return nil, ps.rewriteError(ctx, err)
+	}
+
+	// Validate that both the resource type/permission and the subject
+	// type/relation exist in the schema before dispatching.
+	if err := namespace.CheckNamespaceAndRelations(ctx,
+		[]namespace.TypeAndRelationToCheck{
+			{
+				NamespaceName: req.Resource.ObjectType,
+				RelationName:  req.Permission,
+				AllowEllipsis: false,
+			},
+			{
+				NamespaceName: req.Subject.Object.ObjectType,
+				RelationName:  normalizeSubjectRelation(req.Subject),
+				AllowEllipsis: true,
+			},
+		}, ds); err != nil {
+		return nil, ps.rewriteError(ctx, err)
+	}
+
+	// Debugging can be requested either via gRPC request metadata or the
+	// WithTracing field on the request itself.
+	debugOption := computed.NoDebugging
+
+	if md, ok := metadata.FromIncomingContext(ctx); ok {
+		_, isDebuggingEnabled := md[string(requestmeta.RequestDebugInformation)]
+		if isDebuggingEnabled {
+			debugOption = computed.BasicDebuggingEnabled
+		}
+	}
+
+	if req.WithTracing {
+		debugOption = computed.BasicDebuggingEnabled
+	}
+
+	cr, metadata, err := computed.ComputeCheck(ctx, ps.dispatch,
+		ps.config.CaveatTypeSet,
+		computed.CheckParameters{
+			ResourceType:  tuple.RR(req.Resource.ObjectType, req.Permission),
+			Subject:       tuple.ONR(req.Subject.Object.ObjectType, req.Subject.Object.ObjectId, normalizeSubjectRelation(req.Subject)),
+			CaveatContext: caveatContext,
+			AtRevision:    atRevision,
+			MaximumDepth:  ps.config.MaximumAPIDepth,
+			DebugOption:   debugOption,
+		},
+		req.Resource.ObjectId,
+		ps.config.DispatchChunkSize,
+	)
+	// Record usage metrics even when the check itself errored.
+	usagemetrics.SetInContext(ctx, metadata)
+
+	var debugTrace *v1.DebugInformation
+	if debugOption != computed.NoDebugging && metadata.DebugInfo != nil {
+		// Convert the dispatch debug information into API debug information.
+		converted, cerr := ConvertCheckDispatchDebugInformation(ctx, ps.config.CaveatTypeSet, caveatContext, metadata.DebugInfo, ds)
+		if cerr != nil {
+			return nil, ps.rewriteError(ctx, cerr)
+		}
+		debugTrace = converted
+	}
+
+	if err != nil {
+		// If the error already contains debug information, rewrite it. This can happen if
+		// a dispatch error occurs and debug was requested.
+		if dispatchDebugInfo, ok := spiceerrors.GetDetails[*dispatch.DebugInformation](err); ok {
+			// Convert the dispatch debug information into API debug information.
+			converted, cerr := ConvertCheckDispatchDebugInformation(ctx, ps.config.CaveatTypeSet, caveatContext, dispatchDebugInfo, ds)
+			if cerr != nil {
+				return nil, ps.rewriteError(ctx, cerr)
+			}
+
+			if converted != nil {
+				return nil, spiceerrors.AppendDetailsMetadata(err, spiceerrors.DebugTraceErrorDetailsKey, converted.String())
+			}
+		}
+
+		return nil, ps.rewriteErrorWithOptionalDebugTrace(ctx, err, debugTrace)
+	}
+
+	permissionship, partialCaveat := checkResultToAPITypes(cr)
+
+	return &v1.CheckPermissionResponse{
+		CheckedAt:         checkedAt,
+		Permissionship:    permissionship,
+		PartialCaveatInfo: partialCaveat,
+		DebugTrace:        debugTrace,
+	}, nil
+}
+
+// checkResultToAPITypes maps a dispatch-level check result onto the public API
+// permissionship enum, along with partial caveat information (the missing
+// context fields) when the membership was conditional on a caveat.
+func checkResultToAPITypes(cr *dispatch.ResourceCheckResult) (v1.CheckPermissionResponse_Permissionship, *v1.PartialCaveatInfo) {
+	var partialCaveat *v1.PartialCaveatInfo
+	// Anything other than MEMBER or CAVEATED_MEMBER is reported as no permission.
+	permissionship := v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION
+	if cr.Membership == dispatch.ResourceCheckResult_MEMBER {
+		permissionship = v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION
+	} else if cr.Membership == dispatch.ResourceCheckResult_CAVEATED_MEMBER {
+		permissionship = v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION
+		partialCaveat = &v1.PartialCaveatInfo{
+			MissingRequiredContext: cr.MissingExprFields,
+		}
+	}
+	return permissionship, partialCaveat
+}
+
+// CheckBulkPermissions implements the v1 CheckBulkPermissions API by delegating
+// to the server's bulkChecker and rewriting any resulting error into API form.
+func (ps *permissionServer) CheckBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) {
+	res, err := ps.bulkChecker.checkBulkPermissions(ctx, req)
+	if err != nil {
+		return nil, ps.rewriteError(ctx, err)
+	}
+
+	return res, nil
+}
+
+// pairItemFromCheckResult wraps a dispatch check result (and an optional debug
+// trace) into the per-item response type used by CheckBulkPermissions.
+func pairItemFromCheckResult(checkResult *dispatch.ResourceCheckResult, debugTrace *v1.DebugInformation) *v1.CheckBulkPermissionsPair_Item {
+	permissionship, partialCaveat := checkResultToAPITypes(checkResult)
+	return &v1.CheckBulkPermissionsPair_Item{
+		Item: &v1.CheckBulkPermissionsResponseItem{
+			Permissionship:    permissionship,
+			PartialCaveatInfo: partialCaveat,
+			DebugTrace:        debugTrace,
+		},
+	}
+}
+
+// requestItemFromResourceAndParameters reconstructs a CheckBulkPermissionsRequestItem
+// from internal check parameters and a resource ID, including the caveat context
+// (re-encoded as a protobuf Struct) when one is present. It returns an error only
+// if the caveat context cannot be represented as a Struct.
+func requestItemFromResourceAndParameters(params *computed.CheckParameters, resourceID string) (*v1.CheckBulkPermissionsRequestItem, error) {
+	item := &v1.CheckBulkPermissionsRequestItem{
+		Resource: &v1.ObjectReference{
+			ObjectType: params.ResourceType.ObjectType,
+			ObjectId:   resourceID,
+		},
+		Permission: params.ResourceType.Relation,
+		Subject: &v1.SubjectReference{
+			Object: &v1.ObjectReference{
+				ObjectType: params.Subject.ObjectType,
+				ObjectId:   params.Subject.ObjectID,
+			},
+			// Ellipsis relations are surfaced as an empty optional relation.
+			OptionalRelation: denormalizeSubjectRelation(params.Subject.Relation),
+		},
+	}
+	if len(params.CaveatContext) > 0 {
+		var err error
+		item.Context, err = structpb.NewStruct(params.CaveatContext)
+		if err != nil {
+			return nil, fmt.Errorf("caveat context wasn't properly validated: %w", err)
+		}
+	}
+	return item, nil
+}
+
+// ExpandPermissionTree implements the v1 ExpandPermissionTree API: it performs a
+// shallow dispatch expand of the requested resource's permission and translates
+// the resulting internal tree into the public API tree representation.
+func (ps *permissionServer) ExpandPermissionTree(ctx context.Context, req *v1.ExpandPermissionTreeRequest) (*v1.ExpandPermissionTreeResponse, error) {
+	telemetry.RecordLogicalChecks(1)
+
+	atRevision, expandedAt, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return nil, ps.rewriteError(ctx, err)
+	}
+
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+
+	// Ensure the resource type and permission exist in the schema.
+	err = namespace.CheckNamespaceAndRelation(ctx, req.Resource.ObjectType, req.Permission, false, ds)
+	if err != nil {
+		return nil, ps.rewriteError(ctx, err)
+	}
+
+	// The traversal bloom filter is used by dispatch to detect cycles.
+	bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth))
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := ps.dispatch.DispatchExpand(ctx, &dispatch.DispatchExpandRequest{
+		Metadata: &dispatch.ResolverMeta{
+			AtRevision:     atRevision.String(),
+			DepthRemaining: ps.config.MaximumAPIDepth,
+			TraversalBloom: bf,
+		},
+		ResourceAndRelation: &core.ObjectAndRelation{
+			Namespace: req.Resource.ObjectType,
+			ObjectId:  req.Resource.ObjectId,
+			Relation:  req.Permission,
+		},
+		ExpansionMode: dispatch.DispatchExpandRequest_SHALLOW,
+	})
+	usagemetrics.SetInContext(ctx, resp.Metadata)
+	if err != nil {
+		return nil, ps.rewriteError(ctx, err)
+	}
+
+	// TODO(jschorr): Change to either using shared interfaces for nodes, or switch the internal
+	// dispatched expand to return V1 node types.
+	return &v1.ExpandPermissionTreeResponse{
+		TreeRoot:   TranslateExpansionTree(resp.TreeNode),
+		ExpandedAt: expandedAt,
+	}, nil
+}
+
+// TranslateRelationshipTree translates a V1 PermissionRelationshipTree into a RelationTupleTreeNode.
+// It is the inverse of TranslateExpansionTree and panics on unknown set
+// operations or node types.
+func TranslateRelationshipTree(tree *v1.PermissionRelationshipTree) *core.RelationTupleTreeNode {
+	// Carry over the expanded object/relation, when present, onto the core node.
+	var expanded *core.ObjectAndRelation
+	if tree.ExpandedObject != nil {
+		expanded = &core.ObjectAndRelation{
+			Namespace: tree.ExpandedObject.ObjectType,
+			ObjectId:  tree.ExpandedObject.ObjectId,
+			Relation:  tree.ExpandedRelation,
+		}
+	}
+
+	switch t := tree.TreeType.(type) {
+	case *v1.PermissionRelationshipTree_Intermediate:
+		// Intermediate nodes are set operations over recursively translated children.
+		var operation core.SetOperationUserset_Operation
+		switch t.Intermediate.Operation {
+		case v1.AlgebraicSubjectSet_OPERATION_EXCLUSION:
+			operation = core.SetOperationUserset_EXCLUSION
+		case v1.AlgebraicSubjectSet_OPERATION_INTERSECTION:
+			operation = core.SetOperationUserset_INTERSECTION
+		case v1.AlgebraicSubjectSet_OPERATION_UNION:
+			operation = core.SetOperationUserset_UNION
+		default:
+			panic("unknown set operation")
+		}
+
+		children := []*core.RelationTupleTreeNode{}
+		for _, child := range t.Intermediate.Children {
+			children = append(children, TranslateRelationshipTree(child))
+		}
+
+		return &core.RelationTupleTreeNode{
+			NodeType: &core.RelationTupleTreeNode_IntermediateNode{
+				IntermediateNode: &core.SetOperationUserset{
+					Operation:  operation,
+					ChildNodes: children,
+				},
+			},
+			Expanded: expanded,
+		}
+
+	case *v1.PermissionRelationshipTree_Leaf:
+		// Leaf nodes carry direct subjects; an empty optional relation maps to ellipsis.
+		var subjects []*core.DirectSubject
+		for _, subj := range t.Leaf.Subjects {
+			subjects = append(subjects, &core.DirectSubject{
+				Subject: &core.ObjectAndRelation{
+					Namespace: subj.Object.ObjectType,
+					ObjectId:  subj.Object.ObjectId,
+					Relation:  stringz.DefaultEmpty(subj.OptionalRelation, graph.Ellipsis),
+				},
+			})
+		}
+
+		return &core.RelationTupleTreeNode{
+			NodeType: &core.RelationTupleTreeNode_LeafNode{
+				LeafNode: &core.DirectSubjects{Subjects: subjects},
+			},
+			Expanded: expanded,
+		}
+
+	default:
+		panic("unknown type of expansion tree node")
+	}
+}
+
+// TranslateExpansionTree translates a core RelationTupleTreeNode into a V1
+// PermissionRelationshipTree. It is the inverse of TranslateRelationshipTree
+// and panics on unknown set operations or node types.
+func TranslateExpansionTree(node *core.RelationTupleTreeNode) *v1.PermissionRelationshipTree {
+	switch t := node.NodeType.(type) {
+	case *core.RelationTupleTreeNode_IntermediateNode:
+		var operation v1.AlgebraicSubjectSet_Operation
+		switch t.IntermediateNode.Operation {
+		case core.SetOperationUserset_EXCLUSION:
+			operation = v1.AlgebraicSubjectSet_OPERATION_EXCLUSION
+		case core.SetOperationUserset_INTERSECTION:
+			operation = v1.AlgebraicSubjectSet_OPERATION_INTERSECTION
+		case core.SetOperationUserset_UNION:
+			operation = v1.AlgebraicSubjectSet_OPERATION_UNION
+		default:
+			panic("unknown set operation")
+		}
+
+		var children []*v1.PermissionRelationshipTree
+		for _, child := range node.GetIntermediateNode().ChildNodes {
+			children = append(children, TranslateExpansionTree(child))
+		}
+
+		// Only populate the expanded object/relation when present on the core node.
+		var objRef *v1.ObjectReference
+		var objRel string
+		if node.Expanded != nil {
+			objRef = &v1.ObjectReference{
+				ObjectType: node.Expanded.Namespace,
+				ObjectId:   node.Expanded.ObjectId,
+			}
+			objRel = node.Expanded.Relation
+		}
+
+		return &v1.PermissionRelationshipTree{
+			TreeType: &v1.PermissionRelationshipTree_Intermediate{
+				Intermediate: &v1.AlgebraicSubjectSet{
+					Operation: operation,
+					Children:  children,
+				},
+			},
+			ExpandedObject:   objRef,
+			ExpandedRelation: objRel,
+		}
+
+	case *core.RelationTupleTreeNode_LeafNode:
+		var subjects []*v1.SubjectReference
+		for _, found := range t.LeafNode.Subjects {
+			subjects = append(subjects, &v1.SubjectReference{
+				Object: &v1.ObjectReference{
+					ObjectType: found.Subject.Namespace,
+					ObjectId:   found.Subject.ObjectId,
+				},
+				// Ellipsis relations are surfaced as an empty optional relation.
+				OptionalRelation: denormalizeSubjectRelation(found.Subject.Relation),
+			})
+		}
+
+		if node.Expanded == nil {
+			return &v1.PermissionRelationshipTree{
+				TreeType: &v1.PermissionRelationshipTree_Leaf{
+					Leaf: &v1.DirectSubjectSet{
+						Subjects: subjects,
+					},
+				},
+			}
+		}
+
+		return &v1.PermissionRelationshipTree{
+			TreeType: &v1.PermissionRelationshipTree_Leaf{
+				Leaf: &v1.DirectSubjectSet{
+					Subjects: subjects,
+				},
+			},
+			ExpandedObject: &v1.ObjectReference{
+				ObjectType: node.Expanded.Namespace,
+				ObjectId:   node.Expanded.ObjectId,
+			},
+			ExpandedRelation: node.Expanded.Relation,
+		}
+
+	default:
+		panic("unknown type of expansion tree node")
+	}
+}
+
+// lrv2CursorFlag is the cursor flag marking cursors produced by the
+// LookupResources v2 implementation; all incoming cursors must carry it.
+const lrv2CursorFlag = "lrv2"
+
+// LookupResources implements the v1 LookupResources streaming API: it streams
+// back the IDs of all resources of the requested type on which the subject has
+// the requested permission, supporting cursors and an optional result limit.
+func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp v1.PermissionsService_LookupResourcesServer) error {
+	// NOTE: LRv2 is the only valid option, and we'll expect that all cursors include that flag.
+	// This is to preserve backward-compatibility in the meantime.
+	if req.OptionalCursor != nil {
+		_, _, err := cursor.GetCursorFlag(req.OptionalCursor, lrv2CursorFlag)
+		if err != nil {
+			return ps.rewriteError(resp.Context(), err)
+		}
+	}
+
+	if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxLookupResourcesLimit {
+		return ps.rewriteError(resp.Context(), NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxLookupResourcesLimit)))
+	}
+
+	ctx := resp.Context()
+
+	atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return ps.rewriteError(ctx, err)
+	}
+
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+
+	// Validate both sides of the lookup against the schema before dispatching.
+	if err := namespace.CheckNamespaceAndRelations(ctx,
+		[]namespace.TypeAndRelationToCheck{
+			{
+				NamespaceName: req.ResourceObjectType,
+				RelationName:  req.Permission,
+				AllowEllipsis: false,
+			},
+			{
+				NamespaceName: req.Subject.Object.ObjectType,
+				RelationName:  normalizeSubjectRelation(req.Subject),
+				AllowEllipsis: true,
+			},
+		}, ds); err != nil {
+		return ps.rewriteError(ctx, err)
+	}
+
+	respMetadata := &dispatch.ResponseMeta{
+		DispatchCount:       1,
+		CachedDispatchCount: 0,
+		DepthRequired:       1,
+		DebugInfo:           nil,
+	}
+	usagemetrics.SetInContext(ctx, respMetadata)
+
+	var currentCursor *dispatch.Cursor
+
+	// The request hash binds cursors to this exact request shape so a cursor
+	// cannot be replayed against a different request.
+	lrRequestHash, err := computeLRRequestHash(req)
+	if err != nil {
+		return ps.rewriteError(ctx, err)
+	}
+
+	if req.OptionalCursor != nil {
+		decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
+		if err != nil {
+			return ps.rewriteError(ctx, err)
+		}
+		currentCursor = decodedCursor
+	}
+
+	alreadyPublishedPermissionedResourceIds := map[string]struct{}{}
+	var totalCountPublished uint64
+	defer func() {
+		telemetry.RecordLogicalChecks(totalCountPublished)
+	}()
+
+	// The stream callback converts each dispatched result into an API response,
+	// deduplicating fully-permissioned resource IDs when no limit was given.
+	stream := dispatchpkg.NewHandlingDispatchStream(ctx, func(result *dispatch.DispatchLookupResources2Response) error {
+		found := result.Resource
+
+		dispatchpkg.AddResponseMetadata(respMetadata, result.Metadata)
+		currentCursor = result.AfterResponseCursor
+
+		var partial *v1.PartialCaveatInfo
+		permissionship := v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION
+		if len(found.MissingContextParams) > 0 {
+			permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION
+			partial = &v1.PartialCaveatInfo{
+				MissingRequiredContext: found.MissingContextParams,
+			}
+		} else if req.OptionalLimit == 0 {
+			if _, ok := alreadyPublishedPermissionedResourceIds[found.ResourceId]; ok {
+				// Skip publishing the duplicate.
+				return nil
+			}
+
+			// TODO(jschorr): Investigate something like a Trie here for better memory efficiency.
+			alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{}
+		}
+
+		// Each published result carries a resumption cursor flagged as LRv2.
+		encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, map[string]string{
+			lrv2CursorFlag: "1",
+		})
+		if err != nil {
+			return ps.rewriteError(ctx, err)
+		}
+
+		err = resp.Send(&v1.LookupResourcesResponse{
+			LookedUpAt:        revisionReadAt,
+			ResourceObjectId:  found.ResourceId,
+			Permissionship:    permissionship,
+			PartialCaveatInfo: partial,
+			AfterResultCursor: encodedCursor,
+		})
+		if err != nil {
+			return err
+		}
+
+		totalCountPublished++
+		return nil
+	})
+
+	bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth))
+	if err != nil {
+		return err
+	}
+
+	err = ps.dispatch.DispatchLookupResources2(
+		&dispatch.DispatchLookupResources2Request{
+			Metadata: &dispatch.ResolverMeta{
+				AtRevision:     atRevision.String(),
+				DepthRemaining: ps.config.MaximumAPIDepth,
+				TraversalBloom: bf,
+			},
+			ResourceRelation: &core.RelationReference{
+				Namespace: req.ResourceObjectType,
+				Relation:  req.Permission,
+			},
+			SubjectRelation: &core.RelationReference{
+				Namespace: req.Subject.Object.ObjectType,
+				Relation:  normalizeSubjectRelation(req.Subject),
+			},
+			SubjectIds: []string{req.Subject.Object.ObjectId},
+			TerminalSubject: &core.ObjectAndRelation{
+				Namespace: req.Subject.Object.ObjectType,
+				ObjectId:  req.Subject.Object.ObjectId,
+				Relation:  normalizeSubjectRelation(req.Subject),
+			},
+			Context:        req.Context,
+			OptionalCursor: currentCursor,
+			OptionalLimit:  req.OptionalLimit,
+		},
+		stream)
+	if err != nil {
+		return ps.rewriteError(ctx, err)
+	}
+
+	return nil
+}
+
+// LookupSubjects implements the v1 LookupSubjects streaming API: it streams
+// back all subjects of the requested type that have the requested permission
+// on the given resource, resolving caveats against the supplied context.
+func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, resp v1.PermissionsService_LookupSubjectsServer) error {
+	ctx := resp.Context()
+
+	if req.OptionalConcreteLimit != 0 {
+		return ps.rewriteError(ctx, status.Errorf(codes.Unimplemented, "concrete limit is not yet supported"))
+	}
+
+	atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return ps.rewriteError(ctx, err)
+	}
+
+	ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+
+	caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize)
+	if err != nil {
+		return ps.rewriteError(ctx, err)
+	}
+
+	// Validate both the resource permission and the subject type/relation
+	// against the schema before dispatching.
+	if err := namespace.CheckNamespaceAndRelations(ctx,
+		[]namespace.TypeAndRelationToCheck{
+			{
+				NamespaceName: req.Resource.ObjectType,
+				RelationName:  req.Permission,
+				AllowEllipsis: false,
+			},
+			{
+				NamespaceName: req.SubjectObjectType,
+				RelationName:  stringz.DefaultEmpty(req.OptionalSubjectRelation, tuple.Ellipsis),
+				AllowEllipsis: true,
+			},
+		}, ds); err != nil {
+		return ps.rewriteError(ctx, err)
+	}
+
+	respMetadata := &dispatch.ResponseMeta{
+		DispatchCount:       0,
+		CachedDispatchCount: 0,
+		DepthRequired:       0,
+		DebugInfo:           nil,
+	}
+	usagemetrics.SetInContext(ctx, respMetadata)
+
+	var totalCountPublished uint64
+	defer func() {
+		telemetry.RecordLogicalChecks(totalCountPublished)
+	}()
+
+	// The stream callback resolves caveats on each found subject (and its
+	// exclusions) and publishes the surviving subjects to the caller.
+	stream := dispatchpkg.NewHandlingDispatchStream(ctx, func(result *dispatch.DispatchLookupSubjectsResponse) error {
+		foundSubjects, ok := result.FoundSubjectsByResourceId[req.Resource.ObjectId]
+		if !ok {
+			return fmt.Errorf("missing resource ID in returned LS")
+		}
+
+		for _, foundSubject := range foundSubjects.FoundSubjects {
+			excludedSubjectIDs := make([]string, 0, len(foundSubject.ExcludedSubjects))
+			for _, excludedSubject := range foundSubject.ExcludedSubjects {
+				excludedSubjectIDs = append(excludedSubjectIDs, excludedSubject.SubjectId)
+			}
+
+			excludedSubjects := make([]*v1.ResolvedSubject, 0, len(foundSubject.ExcludedSubjects))
+			for _, excludedSubject := range foundSubject.ExcludedSubjects {
+				resolvedExcludedSubject, err := foundSubjectToResolvedSubject(ctx, excludedSubject, caveatContext, ds, ps.config.CaveatTypeSet)
+				if err != nil {
+					return err
+				}
+
+				// A nil resolution means the exclusion's caveat evaluated to false; drop it.
+				if resolvedExcludedSubject == nil {
+					continue
+				}
+
+				excludedSubjects = append(excludedSubjects, resolvedExcludedSubject)
+			}
+
+			subject, err := foundSubjectToResolvedSubject(ctx, foundSubject, caveatContext, ds, ps.config.CaveatTypeSet)
+			if err != nil {
+				return err
+			}
+			if subject == nil {
+				continue
+			}
+
+			err = resp.Send(&v1.LookupSubjectsResponse{
+				Subject:            subject,
+				ExcludedSubjects:   excludedSubjects,
+				LookedUpAt:         revisionReadAt,
+				SubjectObjectId:    foundSubject.SubjectId,   // Deprecated
+				ExcludedSubjectIds: excludedSubjectIDs,       // Deprecated
+				Permissionship:     subject.Permissionship,   // Deprecated
+				PartialCaveatInfo:  subject.PartialCaveatInfo, // Deprecated
+			})
+			if err != nil {
+				return err
+			}
+		}
+
+		totalCountPublished++
+		dispatchpkg.AddResponseMetadata(respMetadata, result.Metadata)
+		return nil
+	})
+
+	bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth))
+	if err != nil {
+		return err
+	}
+
+	err = ps.dispatch.DispatchLookupSubjects(
+		&dispatch.DispatchLookupSubjectsRequest{
+			Metadata: &dispatch.ResolverMeta{
+				AtRevision:     atRevision.String(),
+				DepthRemaining: ps.config.MaximumAPIDepth,
+				TraversalBloom: bf,
+			},
+			ResourceRelation: &core.RelationReference{
+				Namespace: req.Resource.ObjectType,
+				Relation:  req.Permission,
+			},
+			ResourceIds: []string{req.Resource.ObjectId},
+			SubjectRelation: &core.RelationReference{
+				Namespace: req.SubjectObjectType,
+				Relation:  stringz.DefaultEmpty(req.OptionalSubjectRelation, tuple.Ellipsis),
+			},
+		},
+		stream)
+	if err != nil {
+		return ps.rewriteError(ctx, err)
+	}
+
+	return nil
+}
+
+// foundSubjectToResolvedSubject converts a dispatch FoundSubject into an API
+// ResolvedSubject, evaluating any attached caveat expression against the
+// caller-supplied context. It returns (nil, nil) when the caveat evaluates to
+// false, signaling that the subject should be skipped entirely.
+func foundSubjectToResolvedSubject(ctx context.Context, foundSubject *dispatch.FoundSubject, caveatContext map[string]any, ds datastore.CaveatReader, caveatTypeSet *caveattypes.TypeSet) (*v1.ResolvedSubject, error) {
+	var partialCaveat *v1.PartialCaveatInfo
+	permissionship := v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION
+	if foundSubject.GetCaveatExpression() != nil {
+		permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION
+
+		cr, err := cexpr.RunSingleCaveatExpression(ctx, caveatTypeSet, foundSubject.GetCaveatExpression(), caveatContext, ds, cexpr.RunCaveatExpressionNoDebugging)
+		if err != nil {
+			return nil, err
+		}
+
+		// True => concrete permission; partial => report missing context vars;
+		// false => subject is filtered out.
+		if cr.Value() {
+			permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION
+		} else if cr.IsPartial() {
+			missingFields, _ := cr.MissingVarNames()
+			partialCaveat = &v1.PartialCaveatInfo{
+				MissingRequiredContext: missingFields,
+			}
+		} else {
+			// Skip this found subject.
+			return nil, nil
+		}
+	}
+
+	return &v1.ResolvedSubject{
+		SubjectObjectId:   foundSubject.SubjectId,
+		Permissionship:    permissionship,
+		PartialCaveatInfo: partialCaveat,
+	}, nil
+}
+
+// normalizeSubjectRelation maps an empty optional subject relation to the
+// internal ellipsis relation; non-empty relations pass through unchanged.
+func normalizeSubjectRelation(sub *v1.SubjectReference) string {
+	if sub.OptionalRelation == "" {
+		return graph.Ellipsis
+	}
+	return sub.OptionalRelation
+}
+
+// denormalizeSubjectRelation is the inverse of normalizeSubjectRelation: it
+// maps the internal ellipsis relation back to the empty string used by the
+// public API's optional relation fields.
+func denormalizeSubjectRelation(relation string) string {
+	if relation == graph.Ellipsis {
+		return ""
+	}
+	return relation
+}
+
+// GetCaveatContext converts a request's caveat context Struct into a plain map,
+// enforcing the configured maximum serialized size (a non-positive
+// maxCaveatContextSize disables the check). A nil input yields a nil map.
+func GetCaveatContext(ctx context.Context, caveatCtx *structpb.Struct, maxCaveatContextSize int) (map[string]any, error) {
+	var caveatContext map[string]any
+	if caveatCtx != nil {
+		// The size check uses the proto-serialized size of the incoming Struct.
+		if size := proto.Size(caveatCtx); maxCaveatContextSize > 0 && size > maxCaveatContextSize {
+			return nil, shared.RewriteError(
+				ctx,
+				status.Errorf(
+					codes.InvalidArgument,
+					"request caveat context should have less than %d bytes but had %d",
+					maxCaveatContextSize,
+					size,
+				),
+				nil,
+			)
+		}
+		caveatContext = caveatCtx.AsMap()
+	}
+	return caveatContext, nil
+}
+
+// loadBulkAdapter adapts an ImportBulkRelationships gRPC client stream into the
+// relationship-iterator shape expected by the datastore's BulkLoad. It reuses a
+// single tuple.Relationship and caveat as scratch space to avoid per-row
+// allocations; Next therefore returns a pointer whose contents are overwritten
+// on each call.
+type loadBulkAdapter struct {
+	stream                 grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]
+	referencedNamespaceMap map[string]*schema.Definition
+	referencedCaveatMap    map[string]*core.CaveatDefinition
+	current                tuple.Relationship        // scratch relationship returned by Next
+	caveat                 core.ContextualizedCaveat // scratch caveat referenced by current
+	caveatTypeSet          *caveattypes.TypeSet
+
+	// Namespaces/caveats referenced by the current batch but not yet loaded;
+	// Next pauses (returns nil, nil) until the caller fills them in.
+	awaitingNamespaces []string
+	awaitingCaveats    []string
+
+	currentBatch []*v1.Relationship
+	numSent      int
+	err          error // sticky stream error (including io.EOF)
+}
+
+// Next returns the next validated relationship from the stream, or (nil, nil)
+// when the stream is exhausted (io.EOF) or when new namespaces/caveats must be
+// resolved by the caller before continuing. The returned pointer aliases the
+// adapter's scratch relationship and is only valid until the next call.
+func (a *loadBulkAdapter) Next(_ context.Context) (*tuple.Relationship, error) {
+	for a.err == nil && a.numSent == len(a.currentBatch) {
+		// Load a new batch
+		batch, err := a.stream.Recv()
+		if err != nil {
+			a.err = err
+			if errors.Is(a.err, io.EOF) {
+				return nil, nil
+			}
+			return nil, a.err
+		}
+
+		a.currentBatch = batch.Relationships
+		a.numSent = 0
+
+		a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats(
+			a.currentBatch,
+			a.referencedNamespaceMap,
+			a.referencedCaveatMap,
+		)
+	}
+
+	if len(a.awaitingNamespaces) > 0 || len(a.awaitingCaveats) > 0 {
+		// Shut down the stream to give our caller a chance to fill in this information
+		return nil, nil
+	}
+
+	// Copy the next streamed relationship into the reusable scratch value.
+	a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType
+	a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId
+	a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation
+	a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType
+	a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId
+	a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis)
+
+	if a.currentBatch[a.numSent].OptionalCaveat != nil {
+		a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName
+		a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context
+		a.current.OptionalCaveat = &a.caveat
+	} else {
+		a.current.OptionalCaveat = nil
+	}
+
+	if a.currentBatch[a.numSent].OptionalExpiresAt != nil {
+		t := a.currentBatch[a.numSent].OptionalExpiresAt.AsTime()
+		a.current.OptionalExpiration = &t
+	} else {
+		a.current.OptionalExpiration = nil
+	}
+
+	a.current.OptionalIntegrity = nil
+
+	// Validate the relationship against the loaded schema before handing it out.
+	if err := relationships.ValidateOneRelationship(
+		a.referencedNamespaceMap,
+		a.referencedCaveatMap,
+		a.caveatTypeSet,
+		a.current,
+		relationships.ValidateRelationshipForCreateOrTouch,
+	); err != nil {
+		return nil, err
+	}
+
+	a.numSent++
+	return &a.current, nil
+}
+
+// ImportBulkRelationships implements the v1 ImportBulkRelationships API: it
+// streams batches of relationships from the client into a single datastore
+// read-write transaction via BulkLoad, resolving referenced namespaces and
+// caveats on demand as new ones appear in the stream.
+func (ps *permissionServer) ImportBulkRelationships(stream grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error {
+	ds := datastoremw.MustFromContext(stream.Context())
+
+	var numWritten uint64
+	if _, err := ds.ReadWriteTx(stream.Context(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+		loadedNamespaces := make(map[string]*schema.Definition, 2)
+		loadedCaveats := make(map[string]*core.CaveatDefinition, 0)
+
+		adapter := &loadBulkAdapter{
+			stream:                 stream,
+			referencedNamespaceMap: loadedNamespaces,
+			referencedCaveatMap:    loadedCaveats,
+			caveat:                 core.ContextualizedCaveat{},
+			caveatTypeSet:          ps.config.CaveatTypeSet,
+		}
+		resolver := schema.ResolverForDatastoreReader(rwt)
+		ts := schema.NewTypeSystem(resolver)
+
+		// BulkLoad runs in the loop's post-statement: each iteration first
+		// resolves any namespaces/caveats the adapter paused on, then the
+		// post-statement resumes loading until the adapter pauses again or
+		// the stream ends.
+		var streamWritten uint64
+		var err error
+		for ; adapter.err == nil && err == nil; streamWritten, err = rwt.BulkLoad(stream.Context(), adapter) {
+			numWritten += streamWritten
+
+			// The stream has terminated because we're awaiting namespace and/or caveat information
+			if len(adapter.awaitingNamespaces) > 0 {
+				nsDefs, err := rwt.LookupNamespacesWithNames(stream.Context(), adapter.awaitingNamespaces)
+				if err != nil {
+					return err
+				}
+
+				for _, nsDef := range nsDefs {
+					newDef, err := schema.NewDefinition(ts, nsDef.Definition)
+					if err != nil {
+						return err
+					}
+
+					loadedNamespaces[nsDef.Definition.Name] = newDef
+				}
+				adapter.awaitingNamespaces = nil
+			}
+
+			if len(adapter.awaitingCaveats) > 0 {
+				caveats, err := rwt.LookupCaveatsWithNames(stream.Context(), adapter.awaitingCaveats)
+				if err != nil {
+					return err
+				}
+
+				for _, caveat := range caveats {
+					loadedCaveats[caveat.Definition.Name] = caveat.Definition
+				}
+				adapter.awaitingCaveats = nil
+			}
+		}
+		// Account for the final BulkLoad call that terminated the loop.
+		numWritten += streamWritten
+
+		return err
+	}, dsoptions.WithDisableRetries(true)); err != nil {
+		return shared.RewriteErrorWithoutConfig(stream.Context(), err)
+	}
+
+	usagemetrics.SetInContext(stream.Context(), &dispatch.ResponseMeta{
+		// One request for the whole load
+		DispatchCount: 1,
+	})
+
+	return stream.SendAndClose(&v1.ImportBulkRelationshipsResponse{
+		NumLoaded: numWritten,
+	})
+}
+
+// ExportBulkRelationships implements the v1 ExportBulkRelationships API by
+// resolving the request's consistency revision and delegating the actual
+// export streaming to ExportBulk.
+func (ps *permissionServer) ExportBulkRelationships(
+	req *v1.ExportBulkRelationshipsRequest,
+	resp grpc.ServerStreamingServer[v1.ExportBulkRelationshipsResponse],
+) error {
+	ctx := resp.Context()
+	atRevision, _, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	return ExportBulk(ctx, datastoremw.MustFromContext(ctx), uint64(ps.config.MaxBulkExportRelationshipsLimit), req, atRevision, resp.Send)
+}
+
+// ExportBulk implements the ExportBulkRelationships API functionality. Given a datastore.Datastore, it will
+// export stream via the sender all relationships matched by the incoming request.
+// If no cursor is provided, it will fallback to the provided revision.
+// Relationships are exported namespace-by-namespace in stable (sorted) order,
+// in batches of at most batchSize (or the request's smaller optional limit),
+// with each sent batch carrying a resumption cursor.
+func ExportBulk(ctx context.Context, ds datastore.Datastore, batchSize uint64, req *v1.ExportBulkRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.ExportBulkRelationshipsResponse) error) error {
+	if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize {
+		return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize))
+	}
+
+	// A cursor, when present, overrides the fallback revision and records the
+	// namespace and relationship position at which to resume.
+	atRevision := fallbackRevision
+	var curNamespace string
+	var cur dsoptions.Cursor
+	if req.OptionalCursor != nil {
+		var err error
+		atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor)
+		if err != nil {
+			return shared.RewriteErrorWithoutConfig(ctx, err)
+		}
+	}
+
+	reader := ds.SnapshotReader(atRevision)
+
+	namespaces, err := reader.ListAllNamespaces(ctx)
+	if err != nil {
+		return shared.RewriteErrorWithoutConfig(ctx, err)
+	}
+
+	// Make sure the namespaces are always in a stable order
+	slices.SortFunc(namespaces, func(
+		lhs datastore.RevisionedDefinition[*core.NamespaceDefinition],
+		rhs datastore.RevisionedDefinition[*core.NamespaceDefinition],
+	) int {
+		return strings.Compare(lhs.Definition.Name, rhs.Definition.Name)
+	})
+
+	// Skip the namespaces that are already fully returned
+	for cur != nil && len(namespaces) > 0 && namespaces[0].Definition.Name < curNamespace {
+		namespaces = namespaces[1:]
+	}
+
+	limit := batchSize
+	if req.OptionalLimit > 0 {
+		limit = uint64(req.OptionalLimit)
+	}
+
+	// Pre-allocate all of the relationships that we might need in order to
+	// make export easier and faster for the garbage collector.
+	relsArray := make([]v1.Relationship, limit)
+	objArray := make([]v1.ObjectReference, limit)
+	subArray := make([]v1.SubjectReference, limit)
+	subObjArray := make([]v1.ObjectReference, limit)
+	caveatArray := make([]v1.ContextualizedCaveat, limit)
+	for i := range relsArray {
+		relsArray[i].Resource = &objArray[i]
+		relsArray[i].Subject = &subArray[i]
+		relsArray[i].Subject.Object = &subObjArray[i]
+	}
+
+	emptyRels := make([]*v1.Relationship, limit)
+	// The number of batches/dispatches for the purpose of usage metrics
+	var batches uint32
+	for _, ns := range namespaces {
+		rels := emptyRels
+
+		// Reset the cursor between namespaces.
+		if ns.Definition.Name != curNamespace {
+			cur = nil
+		}
+
+		// Skip this namespace if a resource type filter was specified.
+		if req.OptionalRelationshipFilter != nil && req.OptionalRelationshipFilter.ResourceType != "" {
+			if ns.Definition.Name != req.OptionalRelationshipFilter.ResourceType {
+				continue
+			}
+		}
+
+		// Setup the filter to use for the relationships.
+		relationshipFilter := datastore.RelationshipsFilter{OptionalResourceType: ns.Definition.Name}
+		if req.OptionalRelationshipFilter != nil {
+			rf, err := datastore.RelationshipsFilterFromPublicFilter(req.OptionalRelationshipFilter)
+			if err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			// Overload the namespace name with the one from the request, because each iteration is for a different namespace.
+			rf.OptionalResourceType = ns.Definition.Name
+			relationshipFilter = rf
+		}
+
+		// We want to keep iterating as long as we're sending full batches.
+		// To bootstrap this loop, we enter the first time with a full rels
+		// slice of dummy rels that were never sent.
+		for uint64(len(rels)) == limit {
+			// Lop off any rels we've already sent
+			rels = rels[:0]
+
+			// relFn copies each streamed relationship into the pre-allocated
+			// scratch arrays, avoiding per-row allocations.
+			relFn := func(rel tuple.Relationship) {
+				offset := len(rels)
+				rels = append(rels, &relsArray[offset]) // nozero
+
+				v1Rel := &relsArray[offset]
+				v1Rel.Resource.ObjectType = rel.RelationshipReference.Resource.ObjectType
+				v1Rel.Resource.ObjectId = rel.RelationshipReference.Resource.ObjectID
+				v1Rel.Relation = rel.RelationshipReference.Resource.Relation
+				v1Rel.Subject.Object.ObjectType = rel.RelationshipReference.Subject.ObjectType
+				v1Rel.Subject.Object.ObjectId = rel.RelationshipReference.Subject.ObjectID
+				v1Rel.Subject.OptionalRelation = denormalizeSubjectRelation(rel.RelationshipReference.Subject.Relation)
+
+				if rel.OptionalCaveat != nil {
+					caveatArray[offset].CaveatName = rel.OptionalCaveat.CaveatName
+					caveatArray[offset].Context = rel.OptionalCaveat.Context
+					v1Rel.OptionalCaveat = &caveatArray[offset]
+				} else {
+					// Clear stale scratch data from a previous batch.
+					caveatArray[offset].CaveatName = ""
+					caveatArray[offset].Context = nil
+					v1Rel.OptionalCaveat = nil
+				}
+
+				if rel.OptionalExpiration != nil {
+					v1Rel.OptionalExpiresAt = timestamppb.New(*rel.OptionalExpiration)
+				} else {
+					v1Rel.OptionalExpiresAt = nil
+				}
+			}
+
+			cur, err = queryForEach(
+				ctx,
+				reader,
+				relationshipFilter,
+				relFn,
+				dsoptions.WithLimit(&limit),
+				dsoptions.WithAfter(cur),
+				dsoptions.WithSort(dsoptions.ByResource),
+				dsoptions.WithQueryShape(queryshape.Varying),
+			)
+			if err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			if len(rels) == 0 {
+				continue
+			}
+
+			// The cursor records the namespace and last relationship sent so
+			// a subsequent request can resume exactly where this batch ended.
+			encoded, err := cursor.Encode(&implv1.DecodedCursor{
+				VersionOneof: &implv1.DecodedCursor_V1{
+					V1: &implv1.V1Cursor{
+						Revision: atRevision.String(),
+						Sections: []string{
+							ns.Definition.Name,
+							tuple.MustString(*dsoptions.ToRelationship(cur)),
+						},
+					},
+				},
+			})
+			if err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+
+			if err := sender(&v1.ExportBulkRelationshipsResponse{
+				AfterResultCursor: encoded,
+				Relationships:     rels,
+			}); err != nil {
+				return shared.RewriteErrorWithoutConfig(ctx, err)
+			}
+			// Increment batches for usagemetrics
+			batches++
+		}
+	}
+
+	// Record usage metrics
+	respMetadata := &dispatch.ResponseMeta{
+		DispatchCount: batches,
+	}
+	usagemetrics.SetInContext(ctx, respMetadata)
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go b/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go
new file mode 100644
index 0000000..c34d5d5
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go
@@ -0,0 +1,54 @@
+package v1
+
+import (
+ "context"
+ "fmt"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+)
+
+// limitOne caps precondition lookups at a single row: one matching
+// relationship is enough to decide either precondition operation.
+var limitOne uint64 = 1
+
+// checkPreconditions verifies each supplied precondition against the datastore
+// within the given read-write transaction, returning an error for the first
+// precondition that is not met.
+func checkPreconditions(
+	ctx context.Context,
+	rwt datastore.ReadWriteTransaction,
+	preconditions []*v1.Precondition,
+) error {
+	for _, precondition := range preconditions {
+		filter, err := datastore.RelationshipsFilterFromPublicFilter(precondition.Filter)
+		if err != nil {
+			return fmt.Errorf("error converting filter: %w", err)
+		}
+
+		// A single result decides either operation, so cap the query at one row.
+		it, err := rwt.QueryRelationships(ctx, filter, options.WithLimit(&limitOne), options.WithQueryShape(queryshape.Varying))
+		if err != nil {
+			return fmt.Errorf("error reading relationships: %w", err)
+		}
+
+		_, found, err := datastore.FirstRelationshipIn(it)
+		if err != nil {
+			return fmt.Errorf("error reading relationships from iterator: %w", err)
+		}
+
+		switch precondition.Operation {
+		case v1.Precondition_OPERATION_MUST_NOT_MATCH:
+			if found {
+				return NewPreconditionFailedErr(precondition)
+			}
+		case v1.Precondition_OPERATION_MUST_MATCH:
+			if !found {
+				return NewPreconditionFailedErr(precondition)
+			}
+		default:
+			return fmt.Errorf("unspecified precondition operation: %s", precondition.Operation)
+		}
+	}
+
+	return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go
new file mode 100644
index 0000000..723a8d3
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go
@@ -0,0 +1,720 @@
+package v1
+
+import (
+ "sort"
+ "strings"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "golang.org/x/exp/maps"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/diff"
+ caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats"
+ nsdiff "github.com/authzed/spicedb/pkg/diff/namespace"
+ "github.com/authzed/spicedb/pkg/namespace"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/zedtoken"
+)
+
+// schemaFilters wraps a set of ReflectionSchemaFilter values and provides
+// prefix-matching helpers over definitions, caveats, relations and
+// permissions. An empty filter set matches everything.
+type schemaFilters struct {
+	filters []*v1.ReflectionSchemaFilter
+}
+
+// newSchemaFilters validates the given filters and wraps them in a
+// schemaFilters helper. An error is returned for any disallowed filter
+// combination: definition+caveat filters are mutually exclusive, and
+// relation/permission filters require a definition filter and exclude each
+// other.
+func newSchemaFilters(filters []*v1.ReflectionSchemaFilter) (*schemaFilters, error) {
+	for _, f := range filters {
+		hasDefinition := f.OptionalDefinitionNameFilter != ""
+		hasCaveat := f.OptionalCaveatNameFilter != ""
+		hasRelation := f.OptionalRelationNameFilter != ""
+		hasPermission := f.OptionalPermissionNameFilter != ""
+
+		if hasDefinition && hasCaveat {
+			return nil, NewInvalidFilterErr("cannot filter by both definition and caveat name", f.String())
+		}
+
+		if hasRelation && !hasDefinition {
+			return nil, NewInvalidFilterErr("relation name match requires definition name match", f.String())
+		}
+
+		if hasRelation && hasPermission {
+			return nil, NewInvalidFilterErr("cannot filter by both relation and permission name", f.String())
+		}
+
+		if hasPermission && !hasDefinition {
+			return nil, NewInvalidFilterErr("permission name match requires definition name match", f.String())
+		}
+	}
+
+	return &schemaFilters{filters: filters}, nil
+}
+
+// HasNamespaces returns whether the filters select any namespaces at all.
+// An empty filter list selects everything.
+func (sf *schemaFilters) HasNamespaces() bool {
+	if len(sf.filters) == 0 {
+		return true
+	}
+
+	for _, f := range sf.filters {
+		if f.OptionalDefinitionNameFilter == "" {
+			continue
+		}
+		return true
+	}
+
+	return false
+}
+
+// HasCaveats returns whether the filters select any caveats at all.
+// An empty filter list selects everything.
+func (sf *schemaFilters) HasCaveats() bool {
+	if len(sf.filters) == 0 {
+		return true
+	}
+
+	for _, f := range sf.filters {
+		if f.OptionalCaveatNameFilter == "" {
+			continue
+		}
+		return true
+	}
+
+	return false
+}
+
+// HasNamespace returns whether the given namespace name is selected by the
+// filters: true when no filters exist, when no filter constrains definition
+// names, or when any definition-name filter is a prefix of the name.
+func (sf *schemaFilters) HasNamespace(namespaceName string) bool {
+	if len(sf.filters) == 0 {
+		return true
+	}
+
+	sawDefinitionFilter := false
+	for _, f := range sf.filters {
+		prefix := f.OptionalDefinitionNameFilter
+		if prefix == "" {
+			continue
+		}
+
+		sawDefinitionFilter = true
+		if strings.HasPrefix(namespaceName, prefix) {
+			return true
+		}
+	}
+
+	// With no definition filters present, namespaces are unconstrained.
+	return !sawDefinitionFilter
+}
+
+// HasCaveat returns whether the given caveat name is selected by the filters:
+// true when no filters exist, when no filter constrains caveat names, or when
+// any caveat-name filter is a prefix of the name.
+func (sf *schemaFilters) HasCaveat(caveatName string) bool {
+	if len(sf.filters) == 0 {
+		return true
+	}
+
+	sawCaveatFilter := false
+	for _, f := range sf.filters {
+		prefix := f.OptionalCaveatNameFilter
+		if prefix == "" {
+			continue
+		}
+
+		sawCaveatFilter = true
+		if strings.HasPrefix(caveatName, prefix) {
+			return true
+		}
+	}
+
+	// With no caveat filters present, caveats are unconstrained.
+	return !sawCaveatFilter
+}
+
+// HasRelation returns whether the given relation (within the given namespace)
+// is selected by the filters: true when no filters exist, when no filter
+// constrains relation names, or when some filter prefix-matches both the
+// relation name and its parent definition name.
+func (sf *schemaFilters) HasRelation(namespaceName, relationName string) bool {
+	if len(sf.filters) == 0 {
+		return true
+	}
+
+	sawRelationFilter := false
+	for _, f := range sf.filters {
+		if f.OptionalRelationNameFilter == "" {
+			continue
+		}
+
+		sawRelationFilter = true
+
+		// Both the relation name and its parent definition must match.
+		if strings.HasPrefix(relationName, f.OptionalRelationNameFilter) &&
+			strings.HasPrefix(namespaceName, f.OptionalDefinitionNameFilter) {
+			return true
+		}
+	}
+
+	return !sawRelationFilter
+}
+
+// HasPermission returns whether the given permission (within the given
+// namespace) is selected by the filters: true when no filters exist, when no
+// filter constrains permission names, or when some filter prefix-matches both
+// the permission name and its parent definition name.
+func (sf *schemaFilters) HasPermission(namespaceName, permissionName string) bool {
+	if len(sf.filters) == 0 {
+		return true
+	}
+
+	sawPermissionFilter := false
+	for _, f := range sf.filters {
+		if f.OptionalPermissionNameFilter == "" {
+			continue
+		}
+
+		sawPermissionFilter = true
+
+		// Both the permission name and its parent definition must match.
+		if strings.HasPrefix(permissionName, f.OptionalPermissionNameFilter) &&
+			strings.HasPrefix(namespaceName, f.OptionalDefinitionNameFilter) {
+			return true
+		}
+	}
+
+	return !sawPermissionFilter
+}
+
+// convertDiff converts a schema diff into an API response.
+//
+// existingSchema is the schema currently stored; comparisonSchema is the one
+// being compared against it. Added/removed namespaces and caveats are emitted
+// as whole-definition diffs; changed ones are expanded delta-by-delta below.
+// atRevision is echoed back to the caller as the ReadAt zedtoken.
+//
+// NOTE(review): the parameter named `diff` shadows the imported `diff`
+// package inside this function body; the package is only needed for the
+// signature types, so this is benign here, but confirm before reusing the
+// package inside the body.
+func convertDiff(
+	diff *diff.SchemaDiff,
+	existingSchema *diff.DiffableSchema,
+	comparisonSchema *diff.DiffableSchema,
+	atRevision datastore.Revision,
+	caveatTypeSet *caveattypes.TypeSet,
+) (*v1.DiffSchemaResponse, error) {
+	// Preallocate for the common case of one diff entry per change.
+	size := len(diff.AddedNamespaces) + len(diff.RemovedNamespaces) + len(diff.AddedCaveats) + len(diff.RemovedCaveats) + len(diff.ChangedNamespaces) + len(diff.ChangedCaveats)
+	diffs := make([]*v1.ReflectionSchemaDiff, 0, size)
+
+	// Add/remove namespaces.
+	for _, ns := range diff.AddedNamespaces {
+		// Added namespaces only exist in the comparison schema.
+		nsDef, err := namespaceAPIReprForName(ns, comparisonSchema)
+		if err != nil {
+			return nil, err
+		}
+
+		diffs = append(diffs, &v1.ReflectionSchemaDiff{
+			Diff: &v1.ReflectionSchemaDiff_DefinitionAdded{
+				DefinitionAdded: nsDef,
+			},
+		})
+	}
+
+	for _, ns := range diff.RemovedNamespaces {
+		// Removed namespaces only exist in the existing schema.
+		nsDef, err := namespaceAPIReprForName(ns, existingSchema)
+		if err != nil {
+			return nil, err
+		}
+
+		diffs = append(diffs, &v1.ReflectionSchemaDiff{
+			Diff: &v1.ReflectionSchemaDiff_DefinitionRemoved{
+				DefinitionRemoved: nsDef,
+			},
+		})
+	}
+
+	// Add/remove caveats.
+	for _, caveat := range diff.AddedCaveats {
+		caveatDef, err := caveatAPIReprForName(caveat, comparisonSchema, caveatTypeSet)
+		if err != nil {
+			return nil, err
+		}
+
+		diffs = append(diffs, &v1.ReflectionSchemaDiff{
+			Diff: &v1.ReflectionSchemaDiff_CaveatAdded{
+				CaveatAdded: caveatDef,
+			},
+		})
+	}
+
+	for _, caveat := range diff.RemovedCaveats {
+		caveatDef, err := caveatAPIReprForName(caveat, existingSchema, caveatTypeSet)
+		if err != nil {
+			return nil, err
+		}
+
+		diffs = append(diffs, &v1.ReflectionSchemaDiff{
+			Diff: &v1.ReflectionSchemaDiff_CaveatRemoved{
+				CaveatRemoved: caveatDef,
+			},
+		})
+	}
+
+	// Changed namespaces: expand each namespace's deltas into individual diffs.
+	// Note that additions/changes are resolved against the comparison schema
+	// while removals are resolved against the existing schema.
+	for nsName, nsDiff := range diff.ChangedNamespaces {
+		for _, delta := range nsDiff.Deltas() {
+			switch delta.Type {
+			case nsdiff.AddedPermission:
+				permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				perm, err := permissionAPIRepr(permission, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_PermissionAdded{
+						PermissionAdded: perm,
+					},
+				})
+
+			case nsdiff.AddedRelation:
+				relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				rel, err := relationAPIRepr(relation, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_RelationAdded{
+						RelationAdded: rel,
+					},
+				})
+
+			case nsdiff.ChangedPermissionComment:
+				permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				perm, err := permissionAPIRepr(permission, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_PermissionDocCommentChanged{
+						PermissionDocCommentChanged: perm,
+					},
+				})
+
+			case nsdiff.ChangedPermissionImpl:
+				permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				perm, err := permissionAPIRepr(permission, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_PermissionExprChanged{
+						PermissionExprChanged: perm,
+					},
+				})
+
+			case nsdiff.ChangedRelationComment:
+				relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				rel, err := relationAPIRepr(relation, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_RelationDocCommentChanged{
+						RelationDocCommentChanged: rel,
+					},
+				})
+
+			case nsdiff.LegacyChangedRelationImpl:
+				return nil, spiceerrors.MustBugf("legacy relation implementation changes are not supported")
+
+			case nsdiff.NamespaceCommentsChanged:
+				def, err := namespaceAPIReprForName(nsName, comparisonSchema)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_DefinitionDocCommentChanged{
+						DefinitionDocCommentChanged: def,
+					},
+				})
+
+			case nsdiff.RelationAllowedTypeRemoved:
+				relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				rel, err := relationAPIRepr(relation, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_RelationSubjectTypeRemoved{
+						RelationSubjectTypeRemoved: &v1.ReflectionRelationSubjectTypeChange{
+							Relation:           rel,
+							ChangedSubjectType: typeAPIRepr(delta.AllowedType),
+						},
+					},
+				})
+
+			case nsdiff.RelationAllowedTypeAdded:
+				relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				rel, err := relationAPIRepr(relation, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_RelationSubjectTypeAdded{
+						RelationSubjectTypeAdded: &v1.ReflectionRelationSubjectTypeChange{
+							Relation:           rel,
+							ChangedSubjectType: typeAPIRepr(delta.AllowedType),
+						},
+					},
+				})
+
+			case nsdiff.RemovedPermission:
+				permission, ok := existingSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				perm, err := permissionAPIRepr(permission, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_PermissionRemoved{
+						PermissionRemoved: perm,
+					},
+				})
+
+			case nsdiff.RemovedRelation:
+				relation, ok := existingSchema.GetRelation(nsName, delta.RelationName)
+				if !ok {
+					return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName)
+				}
+
+				rel, err := relationAPIRepr(relation, nsName, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_RelationRemoved{
+						RelationRemoved: rel,
+					},
+				})
+
+			case nsdiff.NamespaceAdded:
+				return nil, spiceerrors.MustBugf("should be handled above")
+
+			case nsdiff.NamespaceRemoved:
+				return nil, spiceerrors.MustBugf("should be handled above")
+
+			default:
+				return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type)
+			}
+		}
+	}
+
+	// Changed caveats: expand each caveat's deltas into individual diffs.
+	// As above, removals resolve against the existing schema; everything else
+	// resolves against the comparison schema.
+	for caveatName, caveatDiff := range diff.ChangedCaveats {
+		for _, delta := range caveatDiff.Deltas() {
+			switch delta.Type {
+			case caveatdiff.CaveatCommentsChanged:
+				caveat, err := caveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_CaveatDocCommentChanged{
+						CaveatDocCommentChanged: caveat,
+					},
+				})
+
+			case caveatdiff.AddedParameter:
+				paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_CaveatParameterAdded{
+						CaveatParameterAdded: paramDef,
+					},
+				})
+
+			case caveatdiff.RemovedParameter:
+				paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_CaveatParameterRemoved{
+						CaveatParameterRemoved: paramDef,
+					},
+				})
+
+			case caveatdiff.ParameterTypeChanged:
+				// Type changes report both the new parameter and its previous type,
+				// so the parameter is looked up in both schemas.
+				previousParamDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet)
+				if err != nil {
+					return nil, err
+				}
+
+				paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_CaveatParameterTypeChanged{
+						CaveatParameterTypeChanged: &v1.ReflectionCaveatParameterTypeChange{
+							Parameter:    paramDef,
+							PreviousType: previousParamDef.Type,
+						},
+					},
+				})
+
+			case caveatdiff.CaveatExpressionChanged:
+				caveat, err := caveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet)
+				if err != nil {
+					return nil, err
+				}
+
+				diffs = append(diffs, &v1.ReflectionSchemaDiff{
+					Diff: &v1.ReflectionSchemaDiff_CaveatExprChanged{
+						CaveatExprChanged: caveat,
+					},
+				})
+
+			case caveatdiff.CaveatAdded:
+				return nil, spiceerrors.MustBugf("should be handled above")
+
+			case caveatdiff.CaveatRemoved:
+				return nil, spiceerrors.MustBugf("should be handled above")
+
+			default:
+				return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type)
+			}
+		}
+	}
+
+	return &v1.DiffSchemaResponse{
+		Diffs:  diffs,
+		ReadAt: zedtoken.MustNewFromRevision(atRevision),
+	}, nil
+}
+
+// namespaceAPIReprForName looks up the named namespace in the given schema
+// and converts it into its API representation. It is a bug for the namespace
+// to be missing.
+func namespaceAPIReprForName(namespaceName string, schema *diff.DiffableSchema) (*v1.ReflectionDefinition, error) {
+	def, found := schema.GetNamespace(namespaceName)
+	if !found {
+		return nil, spiceerrors.MustBugf("namespace %q not found in schema", namespaceName)
+	}
+
+	return namespaceAPIRepr(def, nil)
+}
+
+// namespaceAPIRepr converts a namespace definition into its API
+// representation, applying the optional schema filters to the namespace and
+// to each of its relations and permissions. A nil result (with nil error)
+// indicates the namespace was filtered out.
+func namespaceAPIRepr(nsDef *core.NamespaceDefinition, schemaFilters *schemaFilters) (*v1.ReflectionDefinition, error) {
+	if schemaFilters != nil && !schemaFilters.HasNamespace(nsDef.Name) {
+		return nil, nil
+	}
+
+	relations := make([]*v1.ReflectionRelation, 0, len(nsDef.Relation))
+	permissions := make([]*v1.ReflectionPermission, 0, len(nsDef.Relation))
+
+	for _, rel := range nsDef.Relation {
+		switch namespace.GetRelationKind(rel) {
+		case iv1.RelationMetadata_PERMISSION:
+			perm, err := permissionAPIRepr(rel, nsDef.Name, schemaFilters)
+			if err != nil {
+				return nil, err
+			}
+			// A nil representation means the permission was filtered out.
+			if perm != nil {
+				permissions = append(permissions, perm)
+			}
+
+		default:
+			relation, err := relationAPIRepr(rel, nsDef.Name, schemaFilters)
+			if err != nil {
+				return nil, err
+			}
+			// A nil representation means the relation was filtered out.
+			if relation != nil {
+				relations = append(relations, relation)
+			}
+		}
+	}
+
+	return &v1.ReflectionDefinition{
+		Name:        nsDef.Name,
+		Comment:     strings.Join(namespace.GetComments(nsDef.Metadata), "\n"),
+		Relations:   relations,
+		Permissions: permissions,
+	}, nil
+}
+
+// permissionAPIRepr builds an API representation of a permission. A nil
+// result (with nil error) indicates the permission was filtered out.
+func permissionAPIRepr(relation *core.Relation, parentDefName string, schemaFilters *schemaFilters) (*v1.ReflectionPermission, error) {
+	if schemaFilters != nil && !schemaFilters.HasPermission(parentDefName, relation.Name) {
+		return nil, nil
+	}
+
+	return &v1.ReflectionPermission{
+		Name:                 relation.Name,
+		Comment:              strings.Join(namespace.GetComments(relation.Metadata), "\n"),
+		ParentDefinitionName: parentDefName,
+	}, nil
+}
+
+// relationAPIRepr builds an API representation of a relation, including the
+// subject types allowed by its type information. A nil result (with nil
+// error) indicates the relation was filtered out by the schema filters.
+// (The doc comment previously named a nonexistent `relationAPIRepresentation`.)
+func relationAPIRepr(relation *core.Relation, parentDefName string, schemaFilters *schemaFilters) (*v1.ReflectionRelation, error) {
+	if schemaFilters != nil && !schemaFilters.HasRelation(parentDefName, relation.Name) {
+		return nil, nil
+	}
+
+	comments := namespace.GetComments(relation.Metadata)
+
+	// subjectTypes stays nil when the relation carries no type information.
+	var subjectTypes []*v1.ReflectionTypeReference
+	if relation.TypeInformation != nil {
+		subjectTypes = make([]*v1.ReflectionTypeReference, 0, len(relation.TypeInformation.AllowedDirectRelations))
+		for _, subjectType := range relation.TypeInformation.AllowedDirectRelations {
+			typeref := typeAPIRepr(subjectType)
+			subjectTypes = append(subjectTypes, typeref)
+		}
+	}
+
+	return &v1.ReflectionRelation{
+		Name:                 relation.Name,
+		Comment:              strings.Join(comments, "\n"),
+		ParentDefinitionName: parentDefName,
+		SubjectTypes:         subjectTypes,
+	}, nil
+}
+
+// typeAPIRepr converts an allowed-relation type into its API representation.
+// The Typeref defaults to a terminal subject and is overridden for subjects
+// with an explicit (non-ellipsis) relation or a public wildcard.
+func typeAPIRepr(subjectType *core.AllowedRelation) *v1.ReflectionTypeReference {
+	ref := &v1.ReflectionTypeReference{
+		SubjectDefinitionName: subjectType.Namespace,
+		Typeref:               &v1.ReflectionTypeReference_IsTerminalSubject{},
+	}
+
+	switch {
+	case subjectType.GetRelation() != tuple.Ellipsis && subjectType.GetRelation() != "":
+		ref.Typeref = &v1.ReflectionTypeReference_OptionalRelationName{
+			OptionalRelationName: subjectType.GetRelation(),
+		}
+
+	case subjectType.GetPublicWildcard() != nil:
+		ref.Typeref = &v1.ReflectionTypeReference_IsPublicWildcard{
+			IsPublicWildcard: true,
+		}
+	}
+
+	if caveat := subjectType.GetRequiredCaveat(); caveat != nil {
+		ref.OptionalCaveatName = caveat.CaveatName
+	}
+
+	return ref
+}
+
+// caveatAPIReprForName looks up the named caveat in the given schema and
+// converts it into its API representation. It is a bug for the caveat to be
+// missing.
+func caveatAPIReprForName(caveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveat, error) {
+	def, found := schema.GetCaveat(caveatName)
+	if !found {
+		return nil, spiceerrors.MustBugf("caveat %q not found in schema", caveatName)
+	}
+
+	return caveatAPIRepr(def, nil, caveatTypeSet)
+}
+
+// caveatAPIRepr builds an API representation of a caveat, including its
+// parameters (sorted by name for deterministic output) and the string form of
+// its deserialized expression. A nil result (with nil error) indicates the
+// caveat was filtered out.
+func caveatAPIRepr(caveatDef *core.CaveatDefinition, schemaFilters *schemaFilters, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveat, error) {
+	if schemaFilters != nil && !schemaFilters.HasCaveat(caveatDef.Name) {
+		return nil, nil
+	}
+
+	parameters := make([]*v1.ReflectionCaveatParameter, 0, len(caveatDef.ParameterTypes))
+	paramNames := maps.Keys(caveatDef.ParameterTypes)
+	sort.Strings(paramNames)
+
+	for _, paramName := range paramNames {
+		paramType, ok := caveatDef.ParameterTypes[paramName]
+		if !ok {
+			return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, caveatDef.Name)
+		}
+
+		decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType)
+		if err != nil {
+			return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err)
+		}
+
+		parameters = append(parameters, &v1.ReflectionCaveatParameter{
+			Name:             paramName,
+			Type:             decoded.String(),
+			ParentCaveatName: caveatDef.Name,
+		})
+	}
+
+	parameterTypes, err := caveattypes.DecodeParameterTypes(caveatTypeSet, caveatDef.ParameterTypes)
+	if err != nil {
+		return nil, spiceerrors.MustBugf("invalid caveat parameters: %v", err)
+	}
+
+	// Fixed: the local was previously named `deserializedReflectionression`,
+	// the residue of a mechanical "Expression" rename.
+	deserializedExpression, err := caveats.DeserializeCaveatWithTypeSet(caveatTypeSet, caveatDef.SerializedExpression, parameterTypes)
+	if err != nil {
+		return nil, spiceerrors.MustBugf("invalid caveat expression bytes: %v", err)
+	}
+
+	exprString, err := deserializedExpression.ExprString()
+	if err != nil {
+		return nil, spiceerrors.MustBugf("invalid caveat expression: %v", err)
+	}
+
+	comments := namespace.GetComments(caveatDef.Metadata)
+	return &v1.ReflectionCaveat{
+		Name:       caveatDef.Name,
+		Comment:    strings.Join(comments, "\n"),
+		Parameters: parameters,
+		Expression: exprString,
+	}, nil
+}
+
+// caveatAPIParamRepr builds an API representation of a single caveat
+// parameter, resolving it from the named caveat in the given schema. It is a
+// bug for the caveat or the parameter to be missing.
+// (The doc comment previously named a nonexistent `caveatAPIParamRepresentation`.)
+func caveatAPIParamRepr(paramName, parentCaveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveatParameter, error) {
+	caveatDef, ok := schema.GetCaveat(parentCaveatName)
+	if !ok {
+		return nil, spiceerrors.MustBugf("caveat %q not found in schema", parentCaveatName)
+	}
+
+	paramType, ok := caveatDef.ParameterTypes[paramName]
+	if !ok {
+		return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, parentCaveatName)
+	}
+
+	decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType)
+	if err != nil {
+		return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err)
+	}
+
+	return &v1.ReflectionCaveatParameter{
+		Name:             paramName,
+		Type:             decoded.String(),
+		ParentCaveatName: parentCaveatName,
+	}, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go
new file mode 100644
index 0000000..a572216
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go
@@ -0,0 +1,76 @@
+package v1
+
+import (
+ "context"
+
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/diff"
+ "github.com/authzed/spicedb/pkg/middleware/consistency"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
+// loadCurrentSchema reads every namespace and caveat definition from the
+// datastore at the consistency revision found in the context, returning them
+// as a DiffableSchema along with the revision used.
+func loadCurrentSchema(ctx context.Context) (*diff.DiffableSchema, datastore.Revision, error) {
+	ds := datastoremw.MustFromContext(ctx)
+
+	atRevision, _, err := consistency.RevisionFromContext(ctx)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	reader := ds.SnapshotReader(atRevision)
+
+	nsWithRevisions, err := reader.ListAllNamespaces(ctx)
+	if err != nil {
+		return nil, atRevision, err
+	}
+
+	caveatsWithRevisions, err := reader.ListAllCaveats(ctx)
+	if err != nil {
+		return nil, atRevision, err
+	}
+
+	// Strip the revision wrappers; only the definitions matter for diffing.
+	objectDefs := make([]*core.NamespaceDefinition, 0, len(nsWithRevisions))
+	for _, entry := range nsWithRevisions {
+		objectDefs = append(objectDefs, entry.Definition)
+	}
+
+	caveatDefs := make([]*core.CaveatDefinition, 0, len(caveatsWithRevisions))
+	for _, entry := range caveatsWithRevisions {
+		caveatDefs = append(caveatDefs, entry.Definition)
+	}
+
+	return &diff.DiffableSchema{
+		ObjectDefinitions: objectDefs,
+		CaveatDefinitions: caveatDefs,
+	}, atRevision, nil
+}
+
+// schemaDiff loads the current schema from the datastore, compiles the given
+// comparison schema string, and returns the diff between them along with both
+// the existing and comparison schemas.
+func schemaDiff(ctx context.Context, comparisonSchemaString string, caveatTypeSet *caveattypes.TypeSet) (*diff.SchemaDiff, *diff.DiffableSchema, *diff.DiffableSchema, error) {
+	existingSchema, _, err := loadCurrentSchema(ctx)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	// Compile the comparison schema.
+	compiled, err := compiler.Compile(compiler.InputSchema{
+		Source:       input.Source("schema"),
+		SchemaString: comparisonSchemaString,
+	}, compiler.AllowUnprefixedObjectType(), compiler.CaveatTypeSet(caveatTypeSet))
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	comparisonSchema := diff.NewDiffableSchemaFromCompiledSchema(compiled)
+
+	// Fixed: the local was previously named `diff`, shadowing the imported
+	// `diff` package for the remainder of the function.
+	computedDiff, err := diff.DiffSchemas(*existingSchema, comparisonSchema, caveatTypeSet)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	return computedDiff, existingSchema, &comparisonSchema, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go b/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go
new file mode 100644
index 0000000..f0b2138
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go
@@ -0,0 +1,576 @@
+package v1
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
+ "github.com/jzelinskie/stringz"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "go.opentelemetry.io/otel/trace"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/types/known/structpb"
+
+ "github.com/authzed/spicedb/internal/dispatch"
+ "github.com/authzed/spicedb/internal/middleware"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/middleware/handwrittenvalidation"
+ "github.com/authzed/spicedb/internal/middleware/streamtimeout"
+ "github.com/authzed/spicedb/internal/middleware/usagemetrics"
+ "github.com/authzed/spicedb/internal/namespace"
+ "github.com/authzed/spicedb/internal/relationships"
+ "github.com/authzed/spicedb/internal/services/shared"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/cursor"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/datastore/options"
+ "github.com/authzed/spicedb/pkg/datastore/pagination"
+ "github.com/authzed/spicedb/pkg/datastore/queryshape"
+ "github.com/authzed/spicedb/pkg/genutil"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/middleware/consistency"
+ dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/zedtoken"
+)
+
+var writeUpdateCounter = promauto.NewHistogramVec(prometheus.HistogramOpts{
+ Namespace: "spicedb",
+ Subsystem: "v1",
+ Name: "write_relationships_updates",
+ Help: "The update counts for the WriteRelationships calls",
+ Buckets: []float64{0, 1, 2, 5, 10, 15, 25, 50, 100, 250, 500, 1000},
+}, []string{"kind"})
+
+const MaximumTransactionMetadataSize = 65000 // bytes. Limited by the BLOB size used in MySQL driver
+
+// PermissionsServerConfig is configuration for the permissions server.
+type PermissionsServerConfig struct {
+ // MaxUpdatesPerWrite holds the maximum number of updates allowed per
+ // WriteRelationships call.
+ MaxUpdatesPerWrite uint16
+
+ // MaxPreconditionsCount holds the maximum number of preconditions allowed
+ // on a WriteRelationships or DeleteRelationships call.
+ MaxPreconditionsCount uint16
+
+ // MaximumAPIDepth is the default/starting depth remaining for API calls made
+ // to the permissions server.
+ MaximumAPIDepth uint32
+
+	// DispatchChunkSize is the maximum number of elements to dispatch in a dispatch call
+ DispatchChunkSize uint16
+
+ // StreamingAPITimeout is the timeout for streaming APIs when no response has been
+ // recently received.
+ StreamingAPITimeout time.Duration
+
+ // MaxCaveatContextSize defines the maximum length of the request caveat context in bytes
+ MaxCaveatContextSize int
+
+ // MaxRelationshipContextSize defines the maximum length of a relationship's context in bytes
+ MaxRelationshipContextSize int
+
+ // MaxDatastoreReadPageSize defines the maximum number of relationships loaded from the
+ // datastore in one query.
+ MaxDatastoreReadPageSize uint64
+
+ // MaxCheckBulkConcurrency defines the maximum number of concurrent checks that can be
+ // made in a single CheckBulkPermissions call.
+ MaxCheckBulkConcurrency uint16
+
+ // MaxReadRelationshipsLimit defines the maximum number of relationships that can be read
+ // in a single ReadRelationships call.
+ MaxReadRelationshipsLimit uint32
+
+ // MaxDeleteRelationshipsLimit defines the maximum number of relationships that can be deleted
+ // in a single DeleteRelationships call.
+ MaxDeleteRelationshipsLimit uint32
+
+ // MaxLookupResourcesLimit defines the maximum number of resources that can be looked up in a
+ // single LookupResources call.
+ MaxLookupResourcesLimit uint32
+
+ // MaxBulkExportRelationshipsLimit defines the maximum number of relationships that can be
+ // exported in a single BulkExportRelationships call.
+ MaxBulkExportRelationshipsLimit uint32
+
+ // ExpiringRelationshipsEnabled defines whether or not expiring relationships are enabled.
+ ExpiringRelationshipsEnabled bool
+
+ // CaveatTypeSet is the set of caveat types to use for caveats. If not specified,
+ // the default type set is used.
+ CaveatTypeSet *caveattypes.TypeSet
+}
+
+// NewPermissionsServer creates a PermissionsServiceServer instance.
+func NewPermissionsServer(
+ dispatch dispatch.Dispatcher,
+ config PermissionsServerConfig,
+) v1.PermissionsServiceServer {
+ configWithDefaults := PermissionsServerConfig{
+ MaxPreconditionsCount: defaultIfZero(config.MaxPreconditionsCount, 1000),
+ MaxUpdatesPerWrite: defaultIfZero(config.MaxUpdatesPerWrite, 1000),
+ MaximumAPIDepth: defaultIfZero(config.MaximumAPIDepth, 50),
+ StreamingAPITimeout: defaultIfZero(config.StreamingAPITimeout, 30*time.Second),
+ MaxCaveatContextSize: defaultIfZero(config.MaxCaveatContextSize, 4096),
+ MaxRelationshipContextSize: defaultIfZero(config.MaxRelationshipContextSize, 25_000),
+ MaxDatastoreReadPageSize: defaultIfZero(config.MaxDatastoreReadPageSize, 1_000),
+ MaxReadRelationshipsLimit: defaultIfZero(config.MaxReadRelationshipsLimit, 1_000),
+ MaxDeleteRelationshipsLimit: defaultIfZero(config.MaxDeleteRelationshipsLimit, 1_000),
+ MaxLookupResourcesLimit: defaultIfZero(config.MaxLookupResourcesLimit, 1_000),
+ MaxBulkExportRelationshipsLimit: defaultIfZero(config.MaxBulkExportRelationshipsLimit, 100_000),
+ DispatchChunkSize: defaultIfZero(config.DispatchChunkSize, 100),
+ MaxCheckBulkConcurrency: defaultIfZero(config.MaxCheckBulkConcurrency, 50),
+ CaveatTypeSet: caveattypes.TypeSetOrDefault(config.CaveatTypeSet),
+ ExpiringRelationshipsEnabled: true,
+ }
+
+ return &permissionServer{
+ dispatch: dispatch,
+ config: configWithDefaults,
+ WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{
+ Unary: middleware.ChainUnaryServer(
+ grpcvalidate.UnaryServerInterceptor(),
+ handwrittenvalidation.UnaryServerInterceptor,
+ usagemetrics.UnaryServerInterceptor(),
+ ),
+ Stream: middleware.ChainStreamServer(
+ grpcvalidate.StreamServerInterceptor(),
+ handwrittenvalidation.StreamServerInterceptor,
+ usagemetrics.StreamServerInterceptor(),
+ streamtimeout.MustStreamServerInterceptor(configWithDefaults.StreamingAPITimeout),
+ ),
+ },
+ bulkChecker: &bulkChecker{
+ maxAPIDepth: configWithDefaults.MaximumAPIDepth,
+ maxCaveatContextSize: configWithDefaults.MaxCaveatContextSize,
+ maxConcurrency: configWithDefaults.MaxCheckBulkConcurrency,
+ dispatch: dispatch,
+ dispatchChunkSize: configWithDefaults.DispatchChunkSize,
+ caveatTypeSet: configWithDefaults.CaveatTypeSet,
+ },
+ }
+}
+
+type permissionServer struct {
+ v1.UnimplementedPermissionsServiceServer
+ shared.WithServiceSpecificInterceptors
+
+ dispatch dispatch.Dispatcher
+ config PermissionsServerConfig
+
+ bulkChecker *bulkChecker
+}
+
+func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest, resp v1.PermissionsService_ReadRelationshipsServer) error {
+ if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxReadRelationshipsLimit {
+ return ps.rewriteError(resp.Context(), NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxReadRelationshipsLimit)))
+ }
+
+ ctx := resp.Context()
+ atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+
+ if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, ds); err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
+ DispatchCount: 1,
+ })
+
+ limit := uint64(0)
+ var startCursor options.Cursor
+
+ rrRequestHash, err := computeReadRelationshipsRequestHash(req)
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ if req.OptionalCursor != nil {
+ decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, rrRequestHash)
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ if len(decodedCursor.Sections) != 1 {
+ return ps.rewriteError(ctx, NewInvalidCursorErr("did not find expected resume relationship"))
+ }
+
+ parsed, err := tuple.Parse(decodedCursor.Sections[0])
+ if err != nil {
+ return ps.rewriteError(ctx, NewInvalidCursorErr("could not parse resume relationship"))
+ }
+
+ startCursor = options.ToCursor(parsed)
+ }
+
+ pageSize := ps.config.MaxDatastoreReadPageSize
+ if req.OptionalLimit > 0 {
+ limit = uint64(req.OptionalLimit)
+ if limit < pageSize {
+ pageSize = limit
+ }
+ }
+
+ dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(req.RelationshipFilter)
+ if err != nil {
+ return ps.rewriteError(ctx, fmt.Errorf("error filtering: %w", err))
+ }
+
+ it, err := pagination.NewPaginatedIterator(
+ ctx,
+ ds,
+ dsFilter,
+ pageSize,
+ options.ByResource,
+ startCursor,
+ queryshape.Varying,
+ )
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ response := &v1.ReadRelationshipsResponse{
+ ReadAt: revisionReadAt,
+ Relationship: &v1.Relationship{
+ Resource: &v1.ObjectReference{},
+ Subject: &v1.SubjectReference{
+ Object: &v1.ObjectReference{},
+ },
+ },
+ }
+
+ dispatchCursor := &dispatchv1.Cursor{
+ DispatchVersion: 1,
+ Sections: []string{""},
+ }
+
+ var returnedCount uint64
+ for rel, err := range it {
+ if err != nil {
+ return ps.rewriteError(ctx, fmt.Errorf("error when reading tuples: %w", err))
+ }
+
+ if limit > 0 && returnedCount >= limit {
+ break
+ }
+
+ dispatchCursor.Sections[0] = tuple.StringWithoutCaveatOrExpiration(rel)
+ encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, nil)
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ tuple.CopyToV1Relationship(rel, response.Relationship)
+ response.AfterResultCursor = encodedCursor
+
+ err = resp.Send(response)
+ if err != nil {
+ return ps.rewriteError(ctx, fmt.Errorf("error when streaming tuple: %w", err))
+ }
+ returnedCount++
+ }
+ return nil
+}
+
+func (ps *permissionServer) WriteRelationships(ctx context.Context, req *v1.WriteRelationshipsRequest) (*v1.WriteRelationshipsResponse, error) {
+ if err := ps.validateTransactionMetadata(req.OptionalTransactionMetadata); err != nil {
+ return nil, ps.rewriteError(ctx, err)
+ }
+
+ ds := datastoremw.MustFromContext(ctx)
+
+ span := trace.SpanFromContext(ctx)
+ span.AddEvent("validating mutations")
+ // Ensure that the updates and preconditions are not over the configured limits.
+ if len(req.Updates) > int(ps.config.MaxUpdatesPerWrite) {
+ return nil, ps.rewriteError(
+ ctx,
+ NewExceedsMaximumUpdatesErr(uint64(len(req.Updates)), uint64(ps.config.MaxUpdatesPerWrite)),
+ )
+ }
+
+ if len(req.OptionalPreconditions) > int(ps.config.MaxPreconditionsCount) {
+ return nil, ps.rewriteError(
+ ctx,
+ NewExceedsMaximumPreconditionsErr(uint64(len(req.OptionalPreconditions)), uint64(ps.config.MaxPreconditionsCount)),
+ )
+ }
+
+ // Check for duplicate updates and create the set of caveat names to load.
+ updateRelationshipSet := mapz.NewSet[string]()
+ for _, update := range req.Updates {
+ // TODO(jschorr): Change to struct-based keys.
+ tupleStr := tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship)
+ if !updateRelationshipSet.Add(tupleStr) {
+ return nil, ps.rewriteError(
+ ctx,
+ NewDuplicateRelationshipErr(update),
+ )
+ }
+ if proto.Size(update.Relationship.OptionalCaveat) > ps.config.MaxRelationshipContextSize {
+ return nil, ps.rewriteError(
+ ctx,
+ NewMaxRelationshipContextError(update, ps.config.MaxRelationshipContextSize),
+ )
+ }
+
+ if !ps.config.ExpiringRelationshipsEnabled && update.Relationship.OptionalExpiresAt != nil {
+ return nil, ps.rewriteError(
+ ctx,
+ fmt.Errorf("support for expiring relationships is not enabled"),
+ )
+ }
+ }
+
+ // Execute the write operation(s).
+ span.AddEvent("read write transaction")
+ relUpdates, err := tuple.UpdatesFromV1RelationshipUpdates(req.Updates)
+ if err != nil {
+ return nil, ps.rewriteError(ctx, err)
+ }
+
+ revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+ span.AddEvent("preconditions")
+
+ // Validate the preconditions.
+ for _, precond := range req.OptionalPreconditions {
+ if err := validatePrecondition(ctx, precond, rwt); err != nil {
+ return err
+ }
+ }
+
+ // Validate the updates.
+ span.AddEvent("validate updates")
+ err := relationships.ValidateRelationshipUpdates(ctx, rwt, ps.config.CaveatTypeSet, relUpdates)
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ dispatchCount, err := genutil.EnsureUInt32(len(req.OptionalPreconditions) + 1)
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
+ // One request per precondition and one request for the actual writes.
+ DispatchCount: dispatchCount,
+ })
+
+ span.AddEvent("preconditions")
+ if err := checkPreconditions(ctx, rwt, req.OptionalPreconditions); err != nil {
+ return err
+ }
+
+ span.AddEvent("write relationships")
+ return rwt.WriteRelationships(ctx, relUpdates)
+ }, options.WithMetadata(req.OptionalTransactionMetadata))
+ if err != nil {
+ return nil, ps.rewriteError(ctx, err)
+ }
+
+ // Log a metric of the counts of the different kinds of update operations.
+ updateCountByOperation := make(map[v1.RelationshipUpdate_Operation]int, 0)
+ for _, update := range req.Updates {
+ updateCountByOperation[update.Operation]++
+ }
+
+ for kind, count := range updateCountByOperation {
+ writeUpdateCounter.WithLabelValues(v1.RelationshipUpdate_Operation_name[int32(kind)]).Observe(float64(count))
+ }
+
+ return &v1.WriteRelationshipsResponse{
+ WrittenAt: zedtoken.MustNewFromRevision(revision),
+ }, nil
+}
+
+func (ps *permissionServer) validateTransactionMetadata(metadata *structpb.Struct) error {
+ if metadata == nil {
+ return nil
+ }
+
+ b, err := metadata.MarshalJSON()
+ if err != nil {
+ return err
+ }
+
+ if len(b) > MaximumTransactionMetadataSize {
+ return NewTransactionMetadataTooLargeErr(len(b), MaximumTransactionMetadataSize)
+ }
+
+ return nil
+}
+
+func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.DeleteRelationshipsRequest) (*v1.DeleteRelationshipsResponse, error) {
+ if err := ps.validateTransactionMetadata(req.OptionalTransactionMetadata); err != nil {
+ return nil, ps.rewriteError(ctx, err)
+ }
+
+ if len(req.OptionalPreconditions) > int(ps.config.MaxPreconditionsCount) {
+ return nil, ps.rewriteError(
+ ctx,
+ NewExceedsMaximumPreconditionsErr(uint64(len(req.OptionalPreconditions)), uint64(ps.config.MaxPreconditionsCount)),
+ )
+ }
+
+ if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxDeleteRelationshipsLimit {
+ return nil, ps.rewriteError(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxDeleteRelationshipsLimit)))
+ }
+
+ ds := datastoremw.MustFromContext(ctx)
+ deletionProgress := v1.DeleteRelationshipsResponse_DELETION_PROGRESS_COMPLETE
+
+ var deletedRelationshipCount uint64
+ revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+ if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, rwt); err != nil {
+ return err
+ }
+
+ dispatchCount, err := genutil.EnsureUInt32(len(req.OptionalPreconditions) + 1)
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
+ // One request per precondition and one request for the actual delete.
+ DispatchCount: dispatchCount,
+ })
+
+ for _, precond := range req.OptionalPreconditions {
+ if err := validatePrecondition(ctx, precond, rwt); err != nil {
+ return err
+ }
+ }
+
+ if err := checkPreconditions(ctx, rwt, req.OptionalPreconditions); err != nil {
+ return err
+ }
+
+ // If a limit was specified but partial deletion is not allowed, we need to check if the
+ // number of relationships to be deleted exceeds the limit.
+ if req.OptionalLimit > 0 && !req.OptionalAllowPartialDeletions {
+ limit := uint64(req.OptionalLimit)
+ limitPlusOne := limit + 1
+ filter, err := datastore.RelationshipsFilterFromPublicFilter(req.RelationshipFilter)
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ it, err := rwt.QueryRelationships(ctx, filter, options.WithLimit(&limitPlusOne), options.WithQueryShape(queryshape.Varying))
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ counter := uint64(0)
+ for _, err := range it {
+ if err != nil {
+ return ps.rewriteError(ctx, err)
+ }
+
+ if counter == limit {
+ return ps.rewriteError(ctx, NewCouldNotTransactionallyDeleteErr(req.RelationshipFilter, req.OptionalLimit))
+ }
+
+ counter++
+ }
+ }
+
+ // Delete with the specified limit.
+ if req.OptionalLimit > 0 {
+ deleteLimit := uint64(req.OptionalLimit)
+ drc, reachedLimit, err := rwt.DeleteRelationships(ctx, req.RelationshipFilter, options.WithDeleteLimit(&deleteLimit))
+ if err != nil {
+ return err
+ }
+
+ if reachedLimit {
+ deletionProgress = v1.DeleteRelationshipsResponse_DELETION_PROGRESS_PARTIAL
+ }
+
+ deletedRelationshipCount = drc
+ return nil
+ }
+
+ // Otherwise, kick off an unlimited deletion.
+ deletedRelationshipCount, _, err = rwt.DeleteRelationships(ctx, req.RelationshipFilter)
+ return err
+ }, options.WithMetadata(req.OptionalTransactionMetadata))
+ if err != nil {
+ return nil, ps.rewriteError(ctx, err)
+ }
+
+ return &v1.DeleteRelationshipsResponse{
+ DeletedAt: zedtoken.MustNewFromRevision(revision),
+ DeletionProgress: deletionProgress,
+ RelationshipsDeletedCount: deletedRelationshipCount,
+ }, nil
+}
+
+var emptyPrecondition = &v1.Precondition{}
+
+func validatePrecondition(ctx context.Context, precond *v1.Precondition, reader datastore.Reader) error {
+ if precond.EqualVT(emptyPrecondition) || precond.Filter == nil {
+ return NewEmptyPreconditionErr()
+ }
+
+ return validateRelationshipsFilter(ctx, precond.Filter, reader)
+}
+
+func checkFilterComponent(ctx context.Context, objectType, optionalRelation string, ds datastore.Reader) error {
+ if objectType == "" {
+ return nil
+ }
+
+ relationToTest := stringz.DefaultEmpty(optionalRelation, datastore.Ellipsis)
+ allowEllipsis := optionalRelation == ""
+ return namespace.CheckNamespaceAndRelation(ctx, objectType, relationToTest, allowEllipsis, ds)
+}
+
+func validateRelationshipsFilter(ctx context.Context, filter *v1.RelationshipFilter, ds datastore.Reader) error {
+ // ResourceType is optional, so only check the relation if it is specified.
+ if filter.ResourceType != "" {
+ if err := checkFilterComponent(ctx, filter.ResourceType, filter.OptionalRelation, ds); err != nil {
+ return err
+ }
+ }
+
+ // SubjectFilter is optional, so only check if it is specified.
+ if subjectFilter := filter.OptionalSubjectFilter; subjectFilter != nil {
+ subjectRelation := ""
+ if subjectFilter.OptionalRelation != nil {
+ subjectRelation = subjectFilter.OptionalRelation.Relation
+ }
+ if err := checkFilterComponent(ctx, subjectFilter.SubjectType, subjectRelation, ds); err != nil {
+ return err
+ }
+ }
+
+ // Ensure the resource ID and the resource ID prefix are not set at the same time.
+ if filter.OptionalResourceId != "" && filter.OptionalResourceIdPrefix != "" {
+ return NewInvalidFilterErr("resource_id and resource_id_prefix cannot be set at the same time", filter.String())
+ }
+
+ // Ensure that at least one field is set.
+ return checkIfFilterIsEmpty(filter)
+}
+
+func checkIfFilterIsEmpty(filter *v1.RelationshipFilter) error {
+ if filter.ResourceType == "" &&
+ filter.OptionalResourceId == "" &&
+ filter.OptionalResourceIdPrefix == "" &&
+ filter.OptionalRelation == "" &&
+ filter.OptionalSubjectFilter == nil {
+ return NewInvalidFilterErr("at least one field must be set", filter.String())
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go b/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go
new file mode 100644
index 0000000..14faf3d
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go
@@ -0,0 +1,375 @@
+package v1
+
+import (
+ "context"
+ "sort"
+ "strings"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/internal/middleware"
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/middleware/usagemetrics"
+ "github.com/authzed/spicedb/internal/services/shared"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/genutil"
+ "github.com/authzed/spicedb/pkg/middleware/consistency"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/generator"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/zedtoken"
+)
+
+// NewSchemaServer creates a SchemaServiceServer instance.
+func NewSchemaServer(caveatTypeSet *caveattypes.TypeSet, additiveOnly bool, expiringRelsEnabled bool) v1.SchemaServiceServer {
+ cts := caveattypes.TypeSetOrDefault(caveatTypeSet)
+ return &schemaServer{
+ WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{
+ Unary: middleware.ChainUnaryServer(
+ grpcvalidate.UnaryServerInterceptor(),
+ usagemetrics.UnaryServerInterceptor(),
+ ),
+ Stream: middleware.ChainStreamServer(
+ grpcvalidate.StreamServerInterceptor(),
+ usagemetrics.StreamServerInterceptor(),
+ ),
+ },
+ additiveOnly: additiveOnly,
+ expiringRelsEnabled: expiringRelsEnabled,
+ caveatTypeSet: cts,
+ }
+}
+
+type schemaServer struct {
+ v1.UnimplementedSchemaServiceServer
+ shared.WithServiceSpecificInterceptors
+
+ caveatTypeSet *caveattypes.TypeSet
+ additiveOnly bool
+ expiringRelsEnabled bool
+}
+
+func (ss *schemaServer) rewriteError(ctx context.Context, err error) error {
+ return shared.RewriteError(ctx, err, nil)
+}
+
+func (ss *schemaServer) ReadSchema(ctx context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) {
+ // Schema is always read from the head revision.
+ ds := datastoremw.MustFromContext(ctx)
+ headRevision, err := ds.HeadRevision(ctx)
+ if err != nil {
+ return nil, ss.rewriteError(ctx, err)
+ }
+
+ reader := ds.SnapshotReader(headRevision)
+
+ nsDefs, err := reader.ListAllNamespaces(ctx)
+ if err != nil {
+ return nil, ss.rewriteError(ctx, err)
+ }
+
+ caveatDefs, err := reader.ListAllCaveats(ctx)
+ if err != nil {
+ return nil, ss.rewriteError(ctx, err)
+ }
+
+ if len(nsDefs) == 0 {
+ return nil, status.Errorf(codes.NotFound, "No schema has been defined; please call WriteSchema to start")
+ }
+
+ schemaDefinitions := make([]compiler.SchemaDefinition, 0, len(nsDefs)+len(caveatDefs))
+ for _, caveatDef := range caveatDefs {
+ schemaDefinitions = append(schemaDefinitions, caveatDef.Definition)
+ }
+
+ for _, nsDef := range nsDefs {
+ schemaDefinitions = append(schemaDefinitions, nsDef.Definition)
+ }
+
+ schemaText, _, err := generator.GenerateSchema(schemaDefinitions)
+ if err != nil {
+ return nil, ss.rewriteError(ctx, err)
+ }
+
+ dispatchCount, err := genutil.EnsureUInt32(len(nsDefs) + len(caveatDefs))
+ if err != nil {
+ return nil, ss.rewriteError(ctx, err)
+ }
+
+ usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
+ DispatchCount: dispatchCount,
+ })
+
+ return &v1.ReadSchemaResponse{
+ SchemaText: schemaText,
+ ReadAt: zedtoken.MustNewFromRevision(headRevision),
+ }, nil
+}
+
+func (ss *schemaServer) WriteSchema(ctx context.Context, in *v1.WriteSchemaRequest) (*v1.WriteSchemaResponse, error) {
+ log.Ctx(ctx).Trace().Str("schema", in.GetSchema()).Msg("requested Schema to be written")
+
+ ds := datastoremw.MustFromContext(ctx)
+
+ // Compile the schema into the namespace definitions.
+ opts := make([]compiler.Option, 0, 3)
+ if !ss.expiringRelsEnabled {
+ opts = append(opts, compiler.DisallowExpirationFlag())
+ }
+
+ opts = append(opts, compiler.CaveatTypeSet(ss.caveatTypeSet))
+
+ compiled, err := compiler.Compile(compiler.InputSchema{
+ Source: input.Source("schema"),
+ SchemaString: in.GetSchema(),
+ }, compiler.AllowUnprefixedObjectType(), opts...)
+ if err != nil {
+ return nil, ss.rewriteError(ctx, err)
+ }
+ log.Ctx(ctx).Trace().Int("objectDefinitions", len(compiled.ObjectDefinitions)).Int("caveatDefinitions", len(compiled.CaveatDefinitions)).Msg("compiled namespace definitions")
+
+ // Do as much validation as we can before talking to the datastore.
+ validated, err := shared.ValidateSchemaChanges(ctx, compiled, ss.caveatTypeSet, ss.additiveOnly)
+ if err != nil {
+ return nil, ss.rewriteError(ctx, err)
+ }
+
+ // Update the schema.
+ revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
+ applied, err := shared.ApplySchemaChanges(ctx, rwt, ss.caveatTypeSet, validated)
+ if err != nil {
+ return err
+ }
+
+ dispatchCount, err := genutil.EnsureUInt32(applied.TotalOperationCount)
+ if err != nil {
+ return err
+ }
+
+ usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
+ DispatchCount: dispatchCount,
+ })
+ return nil
+ })
+ if err != nil {
+ return nil, ss.rewriteError(ctx, err)
+ }
+
+ return &v1.WriteSchemaResponse{
+ WrittenAt: zedtoken.MustNewFromRevision(revision),
+ }, nil
+}
+
+func (ss *schemaServer) ReflectSchema(ctx context.Context, req *v1.ReflectSchemaRequest) (*v1.ReflectSchemaResponse, error) {
+ // Get the current schema.
+ schema, atRevision, err := loadCurrentSchema(ctx)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ filters, err := newSchemaFilters(req.OptionalFilters)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ definitions := make([]*v1.ReflectionDefinition, 0, len(schema.ObjectDefinitions))
+ if filters.HasNamespaces() {
+ for _, ns := range schema.ObjectDefinitions {
+ def, err := namespaceAPIRepr(ns, filters)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ if def != nil {
+ definitions = append(definitions, def)
+ }
+ }
+ }
+
+ caveats := make([]*v1.ReflectionCaveat, 0, len(schema.CaveatDefinitions))
+ if filters.HasCaveats() {
+ for _, cd := range schema.CaveatDefinitions {
+ caveat, err := caveatAPIRepr(cd, filters, ss.caveatTypeSet)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ if caveat != nil {
+ caveats = append(caveats, caveat)
+ }
+ }
+ }
+
+ return &v1.ReflectSchemaResponse{
+ Definitions: definitions,
+ Caveats: caveats,
+ ReadAt: zedtoken.MustNewFromRevision(atRevision),
+ }, nil
+}
+
+func (ss *schemaServer) DiffSchema(ctx context.Context, req *v1.DiffSchemaRequest) (*v1.DiffSchemaResponse, error) {
+ atRevision, _, err := consistency.RevisionFromContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ diff, existingSchema, comparisonSchema, err := schemaDiff(ctx, req.ComparisonSchema, ss.caveatTypeSet)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ resp, err := convertDiff(diff, existingSchema, comparisonSchema, atRevision, ss.caveatTypeSet)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ return resp, nil
+}
+
+func (ss *schemaServer) ComputablePermissions(ctx context.Context, req *v1.ComputablePermissionsRequest) (*v1.ComputablePermissionsResponse, error) {
+ atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+ ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds))
+ vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ relationName := req.RelationName
+ if relationName == "" {
+ relationName = tuple.Ellipsis
+ } else {
+ if _, ok := vdef.GetRelation(relationName); !ok {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, relationName))
+ }
+ }
+
+ allNamespaces, err := ds.ListAllNamespaces(ctx)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ allDefinitions := make([]*core.NamespaceDefinition, 0, len(allNamespaces))
+ for _, ns := range allNamespaces {
+ allDefinitions = append(allDefinitions, ns.Definition)
+ }
+
+ rg := vdef.Reachability()
+ rr, err := rg.RelationsEncounteredForSubject(ctx, allDefinitions, &core.RelationReference{
+ Namespace: req.DefinitionName,
+ Relation: relationName,
+ })
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ relations := make([]*v1.ReflectionRelationReference, 0, len(rr))
+ for _, r := range rr {
+ if r.Namespace == req.DefinitionName && r.Relation == req.RelationName {
+ continue
+ }
+
+ if req.OptionalDefinitionNameFilter != "" && !strings.HasPrefix(r.Namespace, req.OptionalDefinitionNameFilter) {
+ continue
+ }
+
+ ts, err := ts.GetDefinition(ctx, r.Namespace)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ relations = append(relations, &v1.ReflectionRelationReference{
+ DefinitionName: r.Namespace,
+ RelationName: r.Relation,
+ IsPermission: ts.IsPermission(r.Relation),
+ })
+ }
+
+ sort.Slice(relations, func(i, j int) bool {
+ if relations[i].DefinitionName == relations[j].DefinitionName {
+ return relations[i].RelationName < relations[j].RelationName
+ }
+ return relations[i].DefinitionName < relations[j].DefinitionName
+ })
+
+ return &v1.ComputablePermissionsResponse{
+ Permissions: relations,
+ ReadAt: revisionReadAt,
+ }, nil
+}
+
+func (ss *schemaServer) DependentRelations(ctx context.Context, req *v1.DependentRelationsRequest) (*v1.DependentRelationsResponse, error) {
+ atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
+ ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds))
+ vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ _, ok := vdef.GetRelation(req.PermissionName)
+ if !ok {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, req.PermissionName))
+ }
+
+ if !vdef.IsPermission(req.PermissionName) {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, NewNotAPermissionError(req.PermissionName))
+ }
+
+ rg := vdef.Reachability()
+ rr, err := rg.RelationsEncounteredForResource(ctx, &core.RelationReference{
+ Namespace: req.DefinitionName,
+ Relation: req.PermissionName,
+ })
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ relations := make([]*v1.ReflectionRelationReference, 0, len(rr))
+ for _, r := range rr {
+ if r.Namespace == req.DefinitionName && r.Relation == req.PermissionName {
+ continue
+ }
+
+ ts, err := ts.GetDefinition(ctx, r.Namespace)
+ if err != nil {
+ return nil, shared.RewriteErrorWithoutConfig(ctx, err)
+ }
+
+ relations = append(relations, &v1.ReflectionRelationReference{
+ DefinitionName: r.Namespace,
+ RelationName: r.Relation,
+ IsPermission: ts.IsPermission(r.Relation),
+ })
+ }
+
+ sort.Slice(relations, func(i, j int) bool {
+ if relations[i].DefinitionName == relations[j].DefinitionName {
+ return relations[i].RelationName < relations[j].RelationName
+ }
+
+ return relations[i].DefinitionName < relations[j].DefinitionName
+ })
+
+ return &v1.DependentRelationsResponse{
+ Relations: relations,
+ ReadAt: revisionReadAt,
+ }, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go b/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go
new file mode 100644
index 0000000..ef13a26
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go
@@ -0,0 +1,190 @@
+package v1
+
+import (
+ "context"
+ "errors"
+ "slices"
+ "time"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
+ "github.com/authzed/spicedb/internal/middleware/usagemetrics"
+ "github.com/authzed/spicedb/internal/services/shared"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/zedtoken"
+)
+
+// watchServer implements the v1 WatchService by streaming datastore changes
+// to clients.
+type watchServer struct {
+	v1.UnimplementedWatchServiceServer
+	shared.WithStreamServiceSpecificInterceptor
+
+	// heartbeatDuration is the interval at which checkpoint heartbeats are
+	// requested from the datastore watch (WatchOptions.CheckpointInterval).
+	heartbeatDuration time.Duration
+}
+
+// NewWatchServer creates an instance of the watch server.
+// heartbeatDuration controls how frequently checkpoint heartbeats are emitted
+// on the watch stream.
+func NewWatchServer(heartbeatDuration time.Duration) v1.WatchServiceServer {
+	s := &watchServer{
+		WithStreamServiceSpecificInterceptor: shared.WithStreamServiceSpecificInterceptor{
+			Stream: grpcvalidate.StreamServerInterceptor(),
+		},
+		heartbeatDuration: heartbeatDuration,
+	}
+	return s
+}
+
+// Watch implements the WatchService Watch RPC: it streams relationship,
+// schema and checkpoint updates from the datastore to the client until the
+// stream is canceled or the underlying watch fails.
+func (ws *watchServer) Watch(req *v1.WatchRequest, stream v1.WatchService_WatchServer) error {
+	// Object-type filters and relationship filters are mutually exclusive
+	// whenever relationship updates are being watched (the default when no
+	// kinds are specified).
+	if len(req.GetOptionalUpdateKinds()) == 0 ||
+		slices.Contains(req.GetOptionalUpdateKinds(), v1.WatchKind_WATCH_KIND_UNSPECIFIED) ||
+		slices.Contains(req.GetOptionalUpdateKinds(), v1.WatchKind_WATCH_KIND_INCLUDE_RELATIONSHIP_UPDATES) {
+		if len(req.GetOptionalObjectTypes()) > 0 && len(req.OptionalRelationshipFilters) > 0 {
+			return status.Errorf(codes.InvalidArgument, "cannot specify both object types and relationship filters")
+		}
+	}
+
+	objectTypes := mapz.NewSet[string](req.GetOptionalObjectTypes()...)
+
+	ctx := stream.Context()
+	ds := datastoremw.MustFromContext(ctx)
+
+	// Start from the caller-provided cursor when present; otherwise from the
+	// datastore's optimized (recent) revision.
+	var afterRevision datastore.Revision
+	if req.OptionalStartCursor != nil && req.OptionalStartCursor.Token != "" {
+		decodedRevision, err := zedtoken.DecodeRevision(req.OptionalStartCursor, ds)
+		if err != nil {
+			return status.Errorf(codes.InvalidArgument, "failed to decode start revision: %s", err)
+		}
+
+		afterRevision = decodedRevision
+	} else {
+		var err error
+		afterRevision, err = ds.OptimizedRevision(ctx)
+		if err != nil {
+			return status.Errorf(codes.Unavailable, "failed to start watch: %s", err)
+		}
+	}
+
+	reader := ds.SnapshotReader(afterRevision)
+
+	filters, err := buildRelationshipFilters(req, stream, reader, ws, ctx)
+	if err != nil {
+		return err
+	}
+
+	usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
+		DispatchCount: 1,
+	})
+
+	updates, errchan := ds.Watch(ctx, afterRevision, datastore.WatchOptions{
+		Content:            convertWatchKindToContent(req.OptionalUpdateKinds),
+		CheckpointInterval: ws.heartbeatDuration,
+	})
+	for {
+		select {
+		case update, ok := <-updates:
+			if ok {
+				// Relationship changes are filtered client-side by object type
+				// and/or relationship filter before being sent.
+				filteredRelationshipUpdates := filterRelationshipUpdates(objectTypes, filters, update.RelationshipChanges)
+				if len(filteredRelationshipUpdates) > 0 {
+					converted, err := tuple.UpdatesToV1RelationshipUpdates(filteredRelationshipUpdates)
+					if err != nil {
+						return status.Errorf(codes.Internal, "failed to convert updates: %s", err)
+					}
+
+					if err := stream.Send(&v1.WatchResponse{
+						Updates:                     converted,
+						ChangesThrough:              zedtoken.MustNewFromRevision(update.Revision),
+						OptionalTransactionMetadata: update.Metadata,
+					}); err != nil {
+						return status.Errorf(codes.Canceled, "watch canceled by user: %s", err)
+					}
+				}
+				// Schema changes are surfaced only as a SchemaUpdated signal.
+				if len(update.ChangedDefinitions) > 0 || len(update.DeletedCaveats) > 0 || len(update.DeletedNamespaces) > 0 {
+					if err := stream.Send(&v1.WatchResponse{
+						SchemaUpdated:               true,
+						ChangesThrough:              zedtoken.MustNewFromRevision(update.Revision),
+						OptionalTransactionMetadata: update.Metadata,
+					}); err != nil {
+						return status.Errorf(codes.Canceled, "watch canceled by user: %s", err)
+					}
+				}
+				// Checkpoints let clients advance their cursor even when no
+				// matching updates occurred.
+				if update.IsCheckpoint {
+					if err := stream.Send(&v1.WatchResponse{
+						IsCheckpoint:                update.IsCheckpoint,
+						ChangesThrough:              zedtoken.MustNewFromRevision(update.Revision),
+						OptionalTransactionMetadata: update.Metadata,
+					}); err != nil {
+						return status.Errorf(codes.Canceled, "watch canceled by user: %s", err)
+					}
+				}
+			}
+		case err := <-errchan:
+			// Map datastore watch failures onto appropriate gRPC codes.
+			switch {
+			case errors.As(err, &datastore.WatchCanceledError{}):
+				return status.Errorf(codes.Canceled, "watch canceled by user: %s", err)
+			case errors.As(err, &datastore.WatchDisconnectedError{}):
+				return status.Errorf(codes.ResourceExhausted, "watch disconnected: %s", err)
+			default:
+				return status.Errorf(codes.Internal, "watch error: %s", err)
+			}
+		}
+	}
+}
+
+// buildRelationshipFilters validates each requested relationship filter
+// against the schema (via the snapshot reader) and converts it into a
+// datastore-level RelationshipsFilter. Validation errors are rewritten into
+// their public form; parse errors become InvalidArgument.
+func buildRelationshipFilters(req *v1.WatchRequest, stream v1.WatchService_WatchServer, reader datastore.Reader, ws *watchServer, ctx context.Context) ([]datastore.RelationshipsFilter, error) {
+	filters := make([]datastore.RelationshipsFilter, 0, len(req.OptionalRelationshipFilters))
+	for _, filter := range req.OptionalRelationshipFilters {
+		if err := validateRelationshipsFilter(stream.Context(), filter, reader); err != nil {
+			return nil, ws.rewriteError(ctx, err)
+		}
+
+		dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter)
+		if err != nil {
+			return nil, status.Errorf(codes.InvalidArgument, "failed to parse relationship filter: %s", err)
+		}
+
+		filters = append(filters, dsFilter)
+	}
+	return filters, nil
+}
+
+// rewriteError converts internal errors into their public gRPC form using the
+// shared rewriting logic, with no service-specific error configuration.
+func (ws *watchServer) rewriteError(ctx context.Context, err error) error {
+	return shared.RewriteError(ctx, err, &shared.ConfigForErrors{})
+}
+
+// filterRelationshipUpdates returns the subset of updates whose resource
+// object type is in objectTypes (when non-empty) and which match at least one
+// of the given filters (when any are provided). With no object types and no
+// filters, the input slice is returned unchanged.
+func filterRelationshipUpdates(objectTypes *mapz.Set[string], filters []datastore.RelationshipsFilter, updates []tuple.RelationshipUpdate) []tuple.RelationshipUpdate {
+	if objectTypes.IsEmpty() && len(filters) == 0 {
+		return updates
+	}
+
+	filtered := make([]tuple.RelationshipUpdate, 0, len(updates))
+	for _, update := range updates {
+		objectType := update.Relationship.Resource.ObjectType
+		if !objectTypes.IsEmpty() && !objectTypes.Has(objectType) {
+			continue
+		}
+
+		if len(filters) > 0 {
+			// If there are filters, we need to check if the update matches any of them.
+			matched := false
+			for _, filter := range filters {
+				if filter.Test(update.Relationship) {
+					matched = true
+					break
+				}
+			}
+
+			if !matched {
+				continue
+			}
+		}
+
+		filtered = append(filtered, update)
+	}
+
+	return filtered
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go b/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go
new file mode 100644
index 0000000..08910a1
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go
@@ -0,0 +1,22 @@
+package v1
+
+import (
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/spicedb/pkg/datastore"
+)
+
+// convertWatchKindToContent maps the requested public WatchKinds onto the
+// datastore's WatchContent bitmask. Relationship updates are always included
+// (the initial value), even when no kinds are specified.
+func convertWatchKindToContent(kinds []v1.WatchKind) datastore.WatchContent {
+	res := datastore.WatchRelationships
+	for _, kind := range kinds {
+		switch kind {
+		case v1.WatchKind_WATCH_KIND_INCLUDE_RELATIONSHIP_UPDATES:
+			res |= datastore.WatchRelationships
+		case v1.WatchKind_WATCH_KIND_INCLUDE_SCHEMA_UPDATES:
+			res |= datastore.WatchSchema
+		case v1.WatchKind_WATCH_KIND_INCLUDE_CHECKPOINTS:
+			res |= datastore.WatchCheckpoints
+		}
+	}
+	return res
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/sharederrors/interfaces.go b/vendor/github.com/authzed/spicedb/internal/sharederrors/interfaces.go
new file mode 100644
index 0000000..852c28f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/sharederrors/interfaces.go
@@ -0,0 +1,16 @@
+package sharederrors
+
+// NOTE(review): these interfaces are presumably satisfied by error types in
+// other packages and matched via type assertions/errors.As — confirm usage.
+
+// UnknownNamespaceError is an error raised when a namespace was not found.
+type UnknownNamespaceError interface {
+	// NotFoundNamespaceName is the name of the namespace that was not found.
+	NotFoundNamespaceName() string
+}
+
+// UnknownRelationError is an error raised when a relation was not found.
+type UnknownRelationError interface {
+	// NamespaceName is the name of the namespace under which the relation was not found.
+	NamespaceName() string
+
+	// NotFoundRelationName is the name of the relation that was not found.
+	NotFoundRelationName() string
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/taskrunner/doc.go b/vendor/github.com/authzed/spicedb/internal/taskrunner/doc.go
new file mode 100644
index 0000000..a6000de
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/taskrunner/doc.go
@@ -0,0 +1,2 @@
+// Package taskrunner contains helper code to run concurrent code.
+package taskrunner
diff --git a/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go b/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go
new file mode 100644
index 0000000..a088e35
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/taskrunner/preloadedtaskrunner.go
@@ -0,0 +1,153 @@
+package taskrunner
+
+import (
+ "context"
+ "sync"
+)
+
+// PreloadedTaskRunner is a task runner that invokes a series of preloaded tasks,
+// running until the tasks are completed, the context is canceled or an error is
+// returned by one of the tasks (which cancels the context).
+type PreloadedTaskRunner struct {
+	// ctx holds the context given to the task runner and annotated with the cancel
+	// function.
+	ctx    context.Context
+	cancel func()
+
+	// sem is a chan of length `concurrencyLimit` used to ensure the task runner does
+	// not exceed the concurrencyLimit with spawned goroutines.
+	sem chan struct{}
+
+	// wg counts one unit per added task; each task is marked done when it
+	// completes or when the runner is canceled.
+	wg sync.WaitGroup
+
+	lock  sync.Mutex
+	err   error      // GUARDED_BY(lock)
+	tasks []TaskFunc // GUARDED_BY(lock)
+}
+
+// NewPreloadedTaskRunner creates a task runner that executes tasks registered
+// via Add with at most concurrencyLimit concurrent goroutines, pre-allocating
+// room for initialCapacity tasks.
+func NewPreloadedTaskRunner(ctx context.Context, concurrencyLimit uint16, initialCapacity int) *PreloadedTaskRunner {
+	// Ensure a concurrency level of at least 1.
+	if concurrencyLimit <= 0 {
+		concurrencyLimit = 1
+	}
+
+	ctxWithCancel, cancel := context.WithCancel(ctx)
+	return &PreloadedTaskRunner{
+		ctx:    ctxWithCancel,
+		cancel: cancel,
+		sem:    make(chan struct{}, concurrencyLimit),
+		tasks:  make([]TaskFunc, 0, initialCapacity),
+	}
+}
+
+// Add adds the given task function to be run.
+// NOTE(review): tasks is appended here without holding lock despite its
+// GUARDED_BY annotation, so Add must only be called before Start and never
+// concurrently with it — confirm with callers.
+func (tr *PreloadedTaskRunner) Add(f TaskFunc) {
+	tr.tasks = append(tr.tasks, f)
+	tr.wg.Add(1)
+}
+
+// Start starts running the tasks in the task runner. This does *not* wait for the tasks
+// to complete, but rather returns immediately.
+// At most one spawn attempt is made per task; the semaphore bounds how many
+// goroutines actually start.
+func (tr *PreloadedTaskRunner) Start() {
+	for range tr.tasks {
+		tr.spawnIfAvailable()
+	}
+}
+
+// StartAndWait starts running the tasks in the task runner and waits for them to complete.
+// It returns the first error recorded by any task (or the context error on
+// cancelation), if any.
+func (tr *PreloadedTaskRunner) StartAndWait() error {
+	tr.Start()
+	tr.wg.Wait()
+
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+
+	return tr.err
+}
+
+// spawnIfAvailable starts a worker goroutine if a semaphore slot is free and
+// the context has not been canceled; otherwise it is a no-op.
+func (tr *PreloadedTaskRunner) spawnIfAvailable() {
+	// To spawn a runner, write a struct{} to the sem channel. If the task runner
+	// is already at the concurrency limit, then this chan write will fail,
+	// and nothing will be spawned. This also checks if the context has already
+	// been canceled, in which case nothing needs to be done.
+	select {
+	case tr.sem <- struct{}{}:
+		go tr.runner()
+
+	case <-tr.ctx.Done():
+		// If the context was canceled, nothing more to do.
+		tr.emptyForCancel()
+		return
+
+	default:
+		return
+	}
+}
+
+// runner is the worker loop: it repeatedly pops and executes pending tasks
+// until none remain or the context is canceled.
+// NOTE(review): unlike TaskRunner.runner, the sem slot is not returned on
+// exit; harmless here since the task set is preloaded and Start attempts at
+// most one spawn per task — confirm.
+func (tr *PreloadedTaskRunner) runner() {
+	for {
+		select {
+		case <-tr.ctx.Done():
+			// If the context was canceled, nothing more to do.
+			tr.emptyForCancel()
+			return
+
+		default:
+			// Select a task from the list, if any.
+			task := tr.selectTask()
+			if task == nil {
+				return
+			}
+
+			// Run the task. If an error occurs, store it and cancel any further tasks.
+			err := task(tr.ctx)
+			if err != nil {
+				tr.storeErrorAndCancel(err)
+			}
+			tr.wg.Done()
+		}
+	}
+}
+
+// selectTask pops the next pending task under the lock, or returns nil when
+// none remain.
+func (tr *PreloadedTaskRunner) selectTask() TaskFunc {
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+
+	if len(tr.tasks) == 0 {
+		return nil
+	}
+
+	task := tr.tasks[0]
+	tr.tasks[0] = nil // to free the reference once the task completes.
+	tr.tasks = tr.tasks[1:]
+	return task
+}
+
+// storeErrorAndCancel records the first error encountered and cancels the
+// shared context, stopping further task execution. Subsequent errors are
+// ignored.
+func (tr *PreloadedTaskRunner) storeErrorAndCancel(err error) {
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+
+	if tr.err == nil {
+		tr.err = err
+		tr.cancel()
+	}
+}
+
+// emptyForCancel drains all remaining tasks after cancelation, recording the
+// context error (if none was already set) and marking each pending task done
+// so that waiters unblock.
+func (tr *PreloadedTaskRunner) emptyForCancel() {
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+
+	if tr.err == nil {
+		tr.err = tr.ctx.Err()
+	}
+
+	for {
+		if len(tr.tasks) == 0 {
+			break
+		}
+
+		tr.tasks[0] = nil // to free the reference
+		tr.tasks = tr.tasks[1:]
+		tr.wg.Done()
+	}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go b/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go
new file mode 100644
index 0000000..1b519ed
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/taskrunner/taskrunner.go
@@ -0,0 +1,168 @@
+package taskrunner
+
+import (
+ "context"
+ "sync"
+)
+
+// TaskRunner is a helper which runs a series of scheduled tasks against a defined
+// limit of goroutines.
+type TaskRunner struct {
+	// ctx holds the context given to the task runner and annotated with the cancel
+	// function.
+	ctx    context.Context
+	cancel func()
+
+	// sem is a chan of length `concurrencyLimit` used to ensure the task runner does
+	// not exceed the concurrencyLimit with spawned goroutines.
+	sem chan struct{}
+
+	// wg counts one unit per scheduled task; Wait blocks until all are done.
+	wg sync.WaitGroup
+
+	lock  sync.Mutex
+	tasks []TaskFunc // GUARDED_BY(lock)
+
+	// err holds the error returned by any task, if any. If the context is canceled,
+	// this err will hold the cancelation error.
+	err error // GUARDED_BY(lock)
+}
+
+// TaskFunc defines functions representing tasks.
+type TaskFunc func(ctx context.Context) error
+
+// NewTaskRunner creates a new task runner with the given starting context and
+// concurrency limit. The TaskRunner will schedule no more goroutines than the
+// specified concurrencyLimit. If the given context is canceled, then all tasks
+// started after that point will also be canceled and the error returned. If
+// a task returns an error, the context provided to all tasks is also canceled.
+func NewTaskRunner(ctx context.Context, concurrencyLimit uint16) *TaskRunner {
+	// Ensure a concurrency level of at least 1.
+	if concurrencyLimit < 1 {
+		concurrencyLimit = 1
+	}
+
+	ctxWithCancel, cancel := context.WithCancel(ctx)
+	return &TaskRunner{
+		ctx:    ctxWithCancel,
+		cancel: cancel,
+		sem:    make(chan struct{}, concurrencyLimit),
+		tasks:  make([]TaskFunc, 0),
+	}
+}
+
+// Schedule schedules a task to be run. This is safe to call from within another
+// task handler function and immediately returns.
+// If the runner has already recorded an error (including cancelation), the
+// task is silently dropped (see addTask).
+func (tr *TaskRunner) Schedule(f TaskFunc) {
+	if tr.addTask(f) {
+		tr.spawnIfAvailable()
+	}
+}
+
+// spawnIfAvailable starts a worker goroutine if a semaphore slot is free and
+// the context has not been canceled; otherwise it is a no-op.
+func (tr *TaskRunner) spawnIfAvailable() {
+	// To spawn a runner, write a struct{} to the sem channel. If the task runner
+	// is already at the concurrency limit, then this chan write will fail,
+	// and nothing will be spawned. This also checks if the context has already
+	// been canceled, in which case nothing needs to be done.
+	select {
+	case tr.sem <- struct{}{}:
+		go tr.runner()
+
+	case <-tr.ctx.Done():
+		return
+
+	default:
+		return
+	}
+}
+
+// runner is the worker loop: it repeatedly pops and executes scheduled tasks,
+// returning its semaphore slot when no tasks remain (so a later Schedule can
+// spawn a fresh worker).
+func (tr *TaskRunner) runner() {
+	for {
+		select {
+		case <-tr.ctx.Done():
+			// If the context was canceled, mark all the remaining tasks as "Done".
+			tr.emptyForCancel()
+			return
+
+		default:
+			// Select a task from the list, if any.
+			task := tr.selectTask()
+			if task == nil {
+				// If there are no further tasks, then "return" the struct{} by reading
+				// it from the channel (freeing a slot potentially for another worker
+				// to be spawned later).
+				<-tr.sem
+				return
+			}
+
+			// Run the task. If an error occurs, store it and cancel any further tasks.
+			err := task(tr.ctx)
+			if err != nil {
+				tr.storeErrorAndCancel(err)
+			}
+			tr.wg.Done()
+		}
+	}
+}
+
+// addTask registers a task under the lock, returning false (and dropping the
+// task) if the runner has already recorded an error, including cancelation.
+func (tr *TaskRunner) addTask(f TaskFunc) bool {
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+
+	if tr.err != nil {
+		return false
+	}
+
+	tr.wg.Add(1)
+	tr.tasks = append(tr.tasks, f)
+	return true
+}
+
+// selectTask pops the next pending task under the lock, or returns nil when
+// none remain.
+func (tr *TaskRunner) selectTask() TaskFunc {
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+
+	if len(tr.tasks) == 0 {
+		return nil
+	}
+
+	task := tr.tasks[0]
+	tr.tasks = tr.tasks[1:]
+	return task
+}
+
+// storeErrorAndCancel records the first error encountered and cancels the
+// shared context, stopping further task execution. Subsequent errors are
+// ignored.
+func (tr *TaskRunner) storeErrorAndCancel(err error) {
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+
+	if tr.err == nil {
+		tr.err = err
+		tr.cancel()
+	}
+}
+
+// emptyForCancel drains all remaining tasks after cancelation, recording the
+// context error (if none was already set) and marking each pending task done
+// so that Wait unblocks.
+func (tr *TaskRunner) emptyForCancel() {
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+
+	if tr.err == nil {
+		tr.err = tr.ctx.Err()
+	}
+
+	for {
+		if len(tr.tasks) == 0 {
+			break
+		}
+
+		tr.tasks = tr.tasks[1:]
+		tr.wg.Done()
+	}
+}
+
+// Wait waits for all tasks to be completed, or a task to raise an error,
+// or the parent context to have been canceled.
+// It returns the first recorded error, if any.
+func (tr *TaskRunner) Wait() error {
+	tr.wg.Wait()
+
+	tr.lock.Lock()
+	defer tr.lock.Unlock()
+	return tr.err
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/telemetry/doc.go b/vendor/github.com/authzed/spicedb/internal/telemetry/doc.go
new file mode 100644
index 0000000..acf0c0b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/telemetry/doc.go
@@ -0,0 +1,6 @@
+// Package telemetry implements a client for reporting telemetry data used to
+// prioritize development of SpiceDB.
+//
+// For more information, see:
+// https://github.com/authzed/spicedb/blob/main/TELEMETRY.md
+package telemetry
diff --git a/vendor/github.com/authzed/spicedb/internal/telemetry/logicalchecks.go b/vendor/github.com/authzed/spicedb/internal/telemetry/logicalchecks.go
new file mode 100644
index 0000000..3aa91a4
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/telemetry/logicalchecks.go
@@ -0,0 +1,16 @@
+package telemetry
+
+import "sync/atomic"
+
+// logicalChecksCountTotal accumulates the logical check count between
+// telemetry report scrapes; loadLogicalChecksCount resets it on read.
+var logicalChecksCountTotal atomic.Uint64
+
+// RecordLogicalChecks records the number of logical checks performed by the server.
+func RecordLogicalChecks(logicalCheckCount uint64) {
+	logicalChecksCountTotal.Add(logicalCheckCount)
+}
+
+// loadLogicalChecksCount returns the total number of logical checks performed by the server,
+// zeroing out the existing count as well.
+func loadLogicalChecksCount() uint64 {
+	return logicalChecksCountTotal.Swap(0)
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/telemetry/metrics.go b/vendor/github.com/authzed/spicedb/internal/telemetry/metrics.go
new file mode 100644
index 0000000..4323297
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/telemetry/metrics.go
@@ -0,0 +1,203 @@
+package telemetry
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "runtime"
+ "runtime/debug"
+ "strconv"
+ "time"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/prometheus/client_golang/prometheus"
+ dto "github.com/prometheus/client_model/go"
+ "golang.org/x/sync/errgroup"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/internal/middleware/usagemetrics"
+ "github.com/authzed/spicedb/pkg/datastore"
+ "github.com/authzed/spicedb/pkg/promutil"
+)
+
+// SpiceDBClusterInfoCollector builds a Prometheus collector that exports a
+// single constant "info" gauge describing this SpiceDB node: cluster ID
+// (taken from the datastore's statistics), hostname, build version, OS/arch,
+// Go version, vCPU count and datastore engine.
+func SpiceDBClusterInfoCollector(ctx context.Context, subsystem, dsEngine string, ds datastore.Datastore) (promutil.CollectorFunc, error) {
+	nodeID, err := os.Hostname()
+	if err != nil {
+		return nil, fmt.Errorf("unable to get hostname: %w", err)
+	}
+
+	dbStats, err := ds.Statistics(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("unable to query DB stats: %w", err)
+	}
+
+	clusterID := dbStats.UniqueID
+	buildInfo, ok := debug.ReadBuildInfo()
+	if !ok {
+		return nil, fmt.Errorf("failed to read BuildInfo")
+	}
+
+	// All values are captured once here; the returned collector emits them as
+	// constant labels on each scrape.
+	return func(ch chan<- prometheus.Metric) {
+		ch <- prometheus.MustNewConstMetric(prometheus.NewDesc(
+			prometheus.BuildFQName("spicedb", subsystem, "info"),
+			"Information about the SpiceDB environment.",
+			nil,
+			prometheus.Labels{
+				"cluster_id": clusterID,
+				"node_id":    nodeID,
+				"version":    cobrautil.VersionWithFallbacks(buildInfo),
+				"os":         runtime.GOOS,
+				"arch":       runtime.GOARCH,
+				"go":         buildInfo.GoVersion,
+				"vcpu":       strconv.Itoa(runtime.NumCPU()),
+				"ds_engine":  dsEngine,
+			},
+		), prometheus.GaugeValue, 1)
+	}, nil
+}
+
+// RegisterTelemetryCollector registers a collector for the various pieces of
+// data required by SpiceDB telemetry.
+// It returns a dedicated registry containing the cluster info gauge plus the
+// object-definition, relationship, dispatch and logical-check metrics.
+func RegisterTelemetryCollector(datastoreEngine string, ds datastore.Datastore) (*prometheus.Registry, error) {
+	registry := prometheus.NewRegistry()
+
+	// Bounded context for the startup datastore queries below.
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	infoCollector, err := SpiceDBClusterInfoCollector(ctx, "telemetry", datastoreEngine, ds)
+	if err != nil {
+		return nil, fmt.Errorf("unable create info collector: %w", err)
+	}
+
+	if err := registry.Register(infoCollector); err != nil {
+		return nil, fmt.Errorf("unable to register telemetry collector: %w", err)
+	}
+
+	nodeID, err := os.Hostname()
+	if err != nil {
+		return nil, fmt.Errorf("unable to get hostname: %w", err)
+	}
+
+	dbStats, err := ds.Statistics(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("unable to query DB stats: %w", err)
+	}
+	clusterID := dbStats.UniqueID
+
+	// cluster_id/node_id are attached as constant labels to every metric so
+	// that reports from different nodes can be distinguished downstream.
+	if err := registry.Register(&collector{
+		ds: ds,
+		objectDefsDesc: prometheus.NewDesc(
+			prometheus.BuildFQName("spicedb", "telemetry", "object_definitions_total"),
+			"Count of the number of objects defined by the schema.",
+			nil,
+			prometheus.Labels{
+				"cluster_id": clusterID,
+				"node_id":    nodeID,
+			},
+		),
+		relationshipsDesc: prometheus.NewDesc(
+			prometheus.BuildFQName("spicedb", "telemetry", "relationships_estimate_total"),
+			"Count of the estimated number of stored relationships.",
+			nil,
+			prometheus.Labels{
+				"cluster_id": clusterID,
+				"node_id":    nodeID,
+			},
+		),
+		dispatchedDesc: prometheus.NewDesc(
+			prometheus.BuildFQName("spicedb", "telemetry", "dispatches"),
+			"Histogram of cluster dispatches performed by the instance.",
+			usagemetrics.DispatchedCountLabels,
+			prometheus.Labels{
+				"cluster_id": clusterID,
+				"node_id":    nodeID,
+			},
+		),
+		logicalChecksDec: prometheus.NewDesc(
+			prometheus.BuildFQName("spicedb", "telemetry", "logical_checks_total"),
+			"Count of the number of logical checks made.",
+			usagemetrics.DispatchedCountLabels,
+			prometheus.Labels{
+				"cluster_id": clusterID,
+				"node_id":    nodeID,
+			},
+		),
+	}); err != nil {
+		return nil, fmt.Errorf("unable to register telemetry collector: %w", err)
+	}
+
+	return registry, nil
+}
+
+// collector implements prometheus.Collector for the telemetry metrics
+// registered by RegisterTelemetryCollector.
+type collector struct {
+	ds                datastore.Datastore
+	objectDefsDesc    *prometheus.Desc
+	relationshipsDesc *prometheus.Desc
+	dispatchedDesc    *prometheus.Desc
+	logicalChecksDec  *prometheus.Desc
+}
+
+// Compile-time assertion that collector satisfies prometheus.Collector.
+var _ prometheus.Collector = &collector{}
+
+// Describe sends the static descriptors of all telemetry metrics.
+func (c *collector) Describe(ch chan<- *prometheus.Desc) {
+	ch <- c.objectDefsDesc
+	ch <- c.relationshipsDesc
+	ch <- c.dispatchedDesc
+	ch <- c.logicalChecksDec
+}
+
+// Collect gathers datastore statistics, the accumulated logical-check count,
+// and the dispatch histogram (re-emitted under this collector's descriptors).
+func (c *collector) Collect(ch chan<- prometheus.Metric) {
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+
+	dsStats, err := c.ds.Statistics(ctx)
+	if err != nil {
+		// NOTE(review): collection continues with whatever Statistics returned
+		// on error — presumably a usable zero value; confirm this is intended
+		// best-effort behavior.
+		log.Warn().Err(err).Msg("unable to collect datastore statistics")
+	}
+
+	// Reading the count also resets it, so each scrape reports the delta.
+	logicalChecksCount := loadLogicalChecksCount()
+
+	ch <- prometheus.MustNewConstMetric(c.objectDefsDesc, prometheus.GaugeValue, float64(len(dsStats.ObjectTypeStatistics)))
+	ch <- prometheus.MustNewConstMetric(c.relationshipsDesc, prometheus.GaugeValue, float64(dsStats.EstimatedRelationshipCount))
+	ch <- prometheus.MustNewConstMetric(c.logicalChecksDec, prometheus.GaugeValue, float64(logicalChecksCount))
+
+	// Stream the dispatch histogram's metrics through a channel so they can be
+	// decoded and re-emitted under dispatchedDesc with the telemetry labels.
+	dispatchedCountMetrics := make(chan prometheus.Metric)
+	g := errgroup.Group{}
+	g.Go(func() error {
+		for metric := range dispatchedCountMetrics {
+			var m dto.Metric
+			if err := metric.Write(&m); err != nil {
+				return fmt.Errorf("error writing metric: %w", err)
+			}
+
+			buckets := make(map[float64]uint64, len(m.Histogram.Bucket))
+			for _, bucket := range m.Histogram.Bucket {
+				buckets[*bucket.UpperBound] = *bucket.CumulativeCount
+			}
+
+			// Re-order the label values to match DispatchedCountLabels.
+			dynamicLabels := make([]string, len(usagemetrics.DispatchedCountLabels))
+			for i, labelName := range usagemetrics.DispatchedCountLabels {
+				for _, labelVal := range m.Label {
+					if *labelVal.Name == labelName {
+						dynamicLabels[i] = *labelVal.Value
+					}
+				}
+			}
+			ch <- prometheus.MustNewConstHistogram(
+				c.dispatchedDesc,
+				*m.Histogram.SampleCount,
+				*m.Histogram.SampleSum,
+				buckets,
+				dynamicLabels...,
+			)
+		}
+		return nil
+	})
+
+	usagemetrics.DispatchedCountHistogram.Collect(dispatchedCountMetrics)
+	close(dispatchedCountMetrics)
+
+	if err := g.Wait(); err != nil {
+		log.Error().Err(err).Msg("error collecting metrics")
+	}
+}
diff --git a/vendor/github.com/authzed/spicedb/internal/telemetry/reporter.go b/vendor/github.com/authzed/spicedb/internal/telemetry/reporter.go
new file mode 100644
index 0000000..8296b3c
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/internal/telemetry/reporter.go
@@ -0,0 +1,234 @@
+package telemetry
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "math/rand"
+ "net/http"
+ "net/url"
+ "time"
+
+ prompb "buf.build/gen/go/prometheus/prometheus/protocolbuffers/go"
+ "github.com/cenkalti/backoff/v4"
+ "github.com/gogo/protobuf/proto"
+ "github.com/golang/snappy"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/common/expfmt"
+ "github.com/prometheus/common/model"
+
+ log "github.com/authzed/spicedb/internal/logging"
+ "github.com/authzed/spicedb/pkg/x509util"
+)
+
+// Defaults and bounds for the telemetry reporter.
+const (
+	// DefaultEndpoint is the endpoint to which telemetry will report if none
+	// other is specified.
+	DefaultEndpoint = "https://telemetry.authzed.com"
+
+	// DefaultInterval is the default amount of time to wait between telemetry
+	// reports.
+	DefaultInterval = 1 * time.Hour
+
+	// MaxElapsedTimeBetweenReports is the maximum amount of time that the
+	// telemetry reporter will attempt to write to the telemetry endpoint
+	// before terminating the reporter.
+	MaxElapsedTimeBetweenReports = 168 * time.Hour
+
+	// MinimumAllowedInterval is the minimum amount of time one can request
+	// between telemetry reports.
+	MinimumAllowedInterval = 1 * time.Minute
+)
+
+// writeTimeSeries pushes the given series to endpoint using the Prometheus
+// remote-write protocol: a snappy-compressed WriteRequest protobuf over HTTP
+// POST. Non-2xx responses are returned as errors including the body.
+func writeTimeSeries(ctx context.Context, client *http.Client, endpoint string, ts []*prompb.TimeSeries) error {
+	// Reference upstream client:
+	// https://github.com/prometheus/prometheus/blob/6555cc68caf8d8f323056e497ae7bb1e32a81667/storage/remote/client.go#L191
+	pbBytes, err := proto.Marshal(&prompb.WriteRequest{
+		Timeseries: ts,
+	})
+	if err != nil {
+		return fmt.Errorf("failed to marshal Prometheus remote write protobuf: %w", err)
+	}
+	compressedPB := snappy.Encode(nil, pbBytes)
+
+	r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(compressedPB))
+	if err != nil {
+		return fmt.Errorf("failed to create Prometheus remote write http request: %w", err)
+	}
+
+	// Headers required by the remote-write v0.1.0 protocol.
+	r.Header.Add("X-Prometheus-Remote-Write-Version", "0.1.0")
+	r.Header.Add("Content-Encoding", "snappy")
+	r.Header.Set("Content-Type", "application/x-protobuf")
+
+	resp, err := client.Do(r)
+	if err != nil {
+		return fmt.Errorf("failed to send Prometheus remote write: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode/100 != 2 {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf(
+			"unexpected Prometheus remote write response: %d: %s",
+			resp.StatusCode,
+			string(body),
+		)
+	}
+
+	return nil
+}
+
+// discoverTimeseries gathers all metrics from the registry and converts each
+// sample into a remote-write TimeSeries, stamping samples that lack their own
+// timestamp with the current time.
+func discoverTimeseries(registry *prometheus.Registry) (allTS []*prompb.TimeSeries, err error) {
+	metricFams, err := registry.Gather()
+	if err != nil {
+		return nil, fmt.Errorf("failed to gather telemetry metrics: %w", err)
+	}
+
+	// Millisecond-precision "now", used as the default sample timestamp.
+	defaultTimestamp := model.Time(time.Now().UnixNano() / int64(time.Millisecond))
+	sampleVector, err := expfmt.ExtractSamples(&expfmt.DecodeOptions{
+		Timestamp: defaultTimestamp,
+	}, metricFams...)
+	if err != nil {
+		return nil, fmt.Errorf("unable to extract sample from metrics families: %w", err)
+	}
+
+	for _, sample := range sampleVector {
+		allTS = append(allTS, &prompb.TimeSeries{
+			Labels: convertLabels(sample.Metric),
+			Samples: []*prompb.Sample{{
+				Value:     float64(sample.Value),
+				Timestamp: int64(sample.Timestamp),
+			}},
+		})
+	}
+
+	return
+}
+
+// discoverAndWriteMetrics gathers the registry's current series and pushes
+// them to the telemetry endpoint in a single remote-write request.
+func discoverAndWriteMetrics(
+	ctx context.Context,
+	registry *prometheus.Registry,
+	client *http.Client,
+	endpoint string,
+) error {
+	ts, err := discoverTimeseries(registry)
+	if err != nil {
+		return err
+	}
+
+	return writeTimeSeries(ctx, client, endpoint, ts)
+}
+
+// Reporter is a long-running function that reports telemetry until its
+// context is canceled or reporting fails permanently.
+type Reporter func(ctx context.Context) error
+
+// RemoteReporter creates a telemetry reporter with the specified parameters, or errors
+// if the configuration was invalid.
+// The returned Reporter pushes the registry's metrics to endpoint on the
+// given interval, backing off exponentially on failures and giving up after
+// MaxElapsedTimeBetweenReports without a successful report.
+func RemoteReporter(
+	registry *prometheus.Registry,
+	endpoint string,
+	caOverridePath string,
+	interval time.Duration,
+) (Reporter, error) {
+	if _, err := url.Parse(endpoint); err != nil {
+		return nil, fmt.Errorf("invalid telemetry endpoint: %w", err)
+	}
+	if interval < MinimumAllowedInterval {
+		return nil, fmt.Errorf("invalid telemetry reporting interval: %s < %s", interval, MinimumAllowedInterval)
+	}
+	if endpoint == DefaultEndpoint && interval != DefaultInterval {
+		return nil, fmt.Errorf("cannot change the telemetry reporting interval for the default endpoint")
+	}
+
+	// Optionally trust a custom CA bundle for the telemetry endpoint.
+	client := &http.Client{}
+	if caOverridePath != "" {
+		pool, err := x509util.CustomCertPool(caOverridePath)
+		if err != nil {
+			return nil, fmt.Errorf("invalid custom cert pool path `%s`: %w", caOverridePath, err)
+		}
+
+		t := &http.Transport{
+			TLSClientConfig: &tls.Config{
+				RootCAs:    pool,
+				MinVersion: tls.VersionTLS12,
+			},
+		}
+
+		client.Transport = t
+	}
+
+	return func(ctx context.Context) error {
+		// nolint:gosec
+		// G404 use of non cryptographically secure random number generator is not a security concern here,
+		// as this is only used to smear the startup delay out over 10% of the reporting interval
+		startupDelay := time.Duration(rand.Int63n(int64(interval.Seconds()/10))) * time.Second
+
+		log.Ctx(ctx).Info().
+			Stringer("interval", interval).
+			Str("endpoint", endpoint).
+			Stringer("next", startupDelay).
+			Msg("telemetry reporter scheduled")
+
+		// On failure, reports retry with exponential backoff starting at the
+		// configured interval; a success resets the backoff.
+		backoffInterval := backoff.NewExponentialBackOff()
+		backoffInterval.InitialInterval = interval
+		backoffInterval.MaxInterval = MaxElapsedTimeBetweenReports
+		backoffInterval.MaxElapsedTime = 0
+
+		// Must reset the backoff object after changing parameters
+		backoffInterval.Reset()
+
+		ticker := time.After(startupDelay)
+
+		for {
+			select {
+			case <-ticker:
+				nextPush := backoffInterval.InitialInterval
+				if err := discoverAndWriteMetrics(ctx, registry, client, endpoint); err != nil {
+					nextPush = backoffInterval.NextBackOff()
+					log.Ctx(ctx).Warn().
+						Err(err).
+						Str("endpoint", endpoint).
+						Stringer("next", nextPush).
+						Msg("failed to push telemetry metric")
+				} else {
+					log.Ctx(ctx).Debug().
+						Str("endpoint", endpoint).
+						Stringer("next", nextPush).
+						Msg("reported telemetry")
+					backoffInterval.Reset()
+				}
+				if nextPush == backoff.Stop {
+					return fmt.Errorf(
+						"exceeded maximum time between successful reports of %s",
+						MaxElapsedTimeBetweenReports,
+					)
+				}
+				ticker = time.After(nextPush)
+
+			case <-ctx.Done():
+				return nil
+			}
+		}
+	}, nil
+}
+
+// DisabledReporter is a Reporter that reports nothing, logging once that
+// telemetry is disabled.
+func DisabledReporter(ctx context.Context) error {
+	log.Ctx(ctx).Info().Msg("telemetry disabled")
+	return nil
+}
+
+// SilentlyDisabledReporter is a Reporter that reports nothing and logs
+// nothing.
+func SilentlyDisabledReporter(_ context.Context) error {
+	return nil
+}
+
+// convertLabels converts a Prometheus model metric's label set into
+// remote-write protobuf labels. Map iteration means the output order is not
+// deterministic.
+func convertLabels(labels model.Metric) []*prompb.Label {
+	out := make([]*prompb.Label, 0, len(labels))
+	for name, value := range labels {
+		out = append(out, &prompb.Label{
+			Name:  string(name),
+			Value: string(value),
+		})
+	}
+	return out
+}