diff options
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal/graph')
17 files changed, 5516 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/graph/check.go b/vendor/github.com/authzed/spicedb/internal/graph/check.go new file mode 100644 index 0000000..65bcb50 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/check.go @@ -0,0 +1,1354 @@ +package graph + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/samber/lo" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph/hints" + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/taskrunner" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/middleware/nodeid" + nspkg "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +var tracer = otel.Tracer("spicedb/internal/graph/check") + +var dispatchChunkCountHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "spicedb_check_dispatch_chunk_count", + Help: "number of chunks when dispatching in check", + Buckets: []float64{1, 2, 3, 5, 10, 25, 100, 250}, +}) + +var directDispatchQueryHistogram = prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "spicedb_check_direct_dispatch_query_count", + Help: "number of queries made per direct dispatch", + Buckets: []float64{1, 2}, +}) + +const noOriginalRelation = "" + +func init() { + prometheus.MustRegister(directDispatchQueryHistogram) + prometheus.MustRegister(dispatchChunkCountHistogram) +} + +// NewConcurrentChecker creates an instance of ConcurrentChecker. +func NewConcurrentChecker(d dispatch.Check, concurrencyLimit uint16, dispatchChunkSize uint16) *ConcurrentChecker { + return &ConcurrentChecker{d, concurrencyLimit, dispatchChunkSize} +} + +// ConcurrentChecker exposes a method to perform Check requests, and delegates subproblems to the +// provided dispatch.Check instance. +type ConcurrentChecker struct { + d dispatch.Check + concurrencyLimit uint16 + dispatchChunkSize uint16 +} + +// ValidatedCheckRequest represents a request after it has been validated and parsed for internal +// consumption. +type ValidatedCheckRequest struct { + *v1.DispatchCheckRequest + Revision datastore.Revision + + // OriginalRelationName is the original relation/permission name that was used in the request, + // before being changed due to aliasing. + OriginalRelationName string +} + +// currentRequestContext holds context information for the current request being +// processed. +type currentRequestContext struct { + // parentReq is the parent request being processed. + parentReq ValidatedCheckRequest + + // filteredResourceIDs are those resource IDs to be checked after filtering for + // any resource IDs found directly matching the incoming subject. + // + // For example, a check of resources `user:{tom,sarah,fred}` and subject `user:sarah` will + // result in this slice containing `tom` and `fred`, but not `sarah`, as she was found as a + // match. + // + // This check and filter occurs via the filterForFoundMemberResource function in the + // checkInternal function before the rest of the checking logic is run. This slice should never + // be empty. + filteredResourceIDs []string + + // resultsSetting is the results setting to use for this request and all subsequent + // requests. + resultsSetting v1.DispatchCheckRequest_ResultsSetting + + // dispatchChunkSize is the maximum number of resource IDs that can be specified in each dispatch. + dispatchChunkSize uint16 +} + +// Check performs a check request with the provided request and context +func (cc *ConcurrentChecker) Check(ctx context.Context, req ValidatedCheckRequest, relation *core.Relation) (*v1.DispatchCheckResponse, error) { + var startTime *time.Time + if req.Debug != v1.DispatchCheckRequest_NO_DEBUG { + now := time.Now() + startTime = &now + } + + resolved := cc.checkInternal(ctx, req, relation) + resolved.Resp.Metadata = addCallToResponseMetadata(resolved.Resp.Metadata) + if req.Debug == v1.DispatchCheckRequest_NO_DEBUG { + return resolved.Resp, resolved.Err + } + + nodeID, err := nodeid.FromContext(ctx) + if err != nil { + // NOTE: we ignore this error here as if the node ID is missing, the debug + // trace is still valid. + log.Err(err).Msg("failed to get node ID") + } + + // Add debug information if requested. + debugInfo := resolved.Resp.Metadata.DebugInfo + if debugInfo == nil { + debugInfo = &v1.DebugInformation{ + Check: &v1.CheckDebugTrace{ + TraceId: NewTraceID(), + SourceId: nodeID, + }, + } + } else if debugInfo.Check != nil && debugInfo.Check.SourceId == "" { + debugInfo.Check.SourceId = nodeID + } + + // Remove the traversal bloom from the debug request to save some data over the + // wire. + clonedRequest := req.DispatchCheckRequest.CloneVT() + clonedRequest.Metadata.TraversalBloom = nil + + debugInfo.Check.Request = clonedRequest + debugInfo.Check.Duration = durationpb.New(time.Since(*startTime)) + + if nspkg.GetRelationKind(relation) == iv1.RelationMetadata_PERMISSION { + debugInfo.Check.ResourceRelationType = v1.CheckDebugTrace_PERMISSION + } else if nspkg.GetRelationKind(relation) == iv1.RelationMetadata_RELATION { + debugInfo.Check.ResourceRelationType = v1.CheckDebugTrace_RELATION + } + + // Build the results for the debug trace. + results := make(map[string]*v1.ResourceCheckResult, len(req.DispatchCheckRequest.ResourceIds)) + for _, resourceID := range req.DispatchCheckRequest.ResourceIds { + if found, ok := resolved.Resp.ResultsByResourceId[resourceID]; ok { + results[resourceID] = found + } + } + debugInfo.Check.Results = results + + // If there is existing debug information in the error, then place it as the subproblem of the current + // debug information. + if existingDebugInfo, ok := spiceerrors.GetDetails[*v1.DebugInformation](resolved.Err); ok { + debugInfo.Check.SubProblems = []*v1.CheckDebugTrace{existingDebugInfo.Check} + } + + resolved.Resp.Metadata.DebugInfo = debugInfo + + // If there is an error and it is already a gRPC error, add the debug information + // into the details portion of the payload. This allows the client to see the debug + // information, as gRPC will only return the error. + updatedErr := spiceerrors.WithReplacedDetails(resolved.Err, debugInfo) + return resolved.Resp, updatedErr +} + +func (cc *ConcurrentChecker) checkInternal(ctx context.Context, req ValidatedCheckRequest, relation *core.Relation) CheckResult { + spiceerrors.DebugAssert(func() bool { + return relation.GetUsersetRewrite() != nil || relation.GetTypeInformation() != nil + }, "found relation without type information") + + // Ensure that we have at least one resource ID for which to execute the check. + if len(req.ResourceIds) == 0 { + return checkResultError( + spiceerrors.MustBugf("empty resource IDs given to dispatched check"), + emptyMetadata, + ) + } + + // Ensure that we are not performing a check for a wildcard as the subject. + if req.Subject.ObjectId == tuple.PublicWildcard { + return checkResultError(NewWildcardNotAllowedErr("cannot perform check on wildcard subject", "subject.object_id"), emptyMetadata) + } + + // Deduplicate any incoming resource IDs. + resourceIds := lo.Uniq(req.ResourceIds) + + // Filter the incoming resource IDs for any which match the subject directly. For example, if we receive + // a check for resource `user:{tom, fred, sarah}#...` and a subject of `user:sarah#...`, then we know + // that `user:sarah#...` is a valid "member" of the resource, as it matches exactly. + // + // If the filtering results in no further resource IDs to check, or a result is found and a single + // result is allowed, we terminate early. + membershipSet, filteredResourcesIds := filterForFoundMemberResource(req.ResourceRelation, resourceIds, req.Subject) + if membershipSet.HasDeterminedMember() && req.DispatchCheckRequest.ResultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT { + return checkResultsForMembership(membershipSet, emptyMetadata) + } + + // Filter for check hints, if any. + if len(req.CheckHints) > 0 { + subject := tuple.FromCoreObjectAndRelation(req.Subject) + filteredResourcesIdsSet := mapz.NewSet(filteredResourcesIds...) + for _, checkHint := range req.CheckHints { + resourceID, ok := hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.ResourceRelation.Relation, subject) + if ok { + filteredResourcesIdsSet.Delete(resourceID) + continue + } + + if req.OriginalRelationName != "" { + resourceID, ok = hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.OriginalRelationName, subject) + if ok { + filteredResourcesIdsSet.Delete(resourceID) + } + } + } + filteredResourcesIds = filteredResourcesIdsSet.AsSlice() + } + + if len(filteredResourcesIds) == 0 { + return combineWithCheckHints(combineResultWithFoundResources(noMembers(), membershipSet), req) + } + + // NOTE: We can always allow a single result if we're only trying to find the results for a + // single resource ID. This "reset" allows for short circuiting of downstream dispatched calls. + resultsSetting := req.ResultsSetting + if len(filteredResourcesIds) == 1 { + resultsSetting = v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT + } + + crc := currentRequestContext{ + parentReq: req, + filteredResourceIDs: filteredResourcesIds, + resultsSetting: resultsSetting, + dispatchChunkSize: cc.dispatchChunkSize, + } + + if req.Debug == v1.DispatchCheckRequest_ENABLE_TRACE_DEBUGGING { + crc.dispatchChunkSize = 1 + } + + if relation.UsersetRewrite == nil { + return combineWithCheckHints(combineResultWithFoundResources(cc.checkDirect(ctx, crc, relation), membershipSet), req) + } + + return combineWithCheckHints(combineResultWithFoundResources(cc.checkUsersetRewrite(ctx, crc, relation.UsersetRewrite), membershipSet), req) +} + +func combineWithComputedHints(result CheckResult, hints map[string]*v1.ResourceCheckResult) CheckResult { + if len(hints) == 0 { + return result + } + + for resourceID, hint := range hints { + if _, ok := result.Resp.ResultsByResourceId[resourceID]; ok { + return checkResultError( + spiceerrors.MustBugf("check hint for resource ID %q, which already exists", resourceID), + emptyMetadata, + ) + } + + if result.Resp.ResultsByResourceId == nil { + result.Resp.ResultsByResourceId = make(map[string]*v1.ResourceCheckResult) + } + result.Resp.ResultsByResourceId[resourceID] = hint + } + + return result +} + +func combineWithCheckHints(result CheckResult, req ValidatedCheckRequest) CheckResult { + if len(req.CheckHints) == 0 { + return result + } + + subject := tuple.FromCoreObjectAndRelation(req.Subject) + for _, checkHint := range req.CheckHints { + resourceID, ok := hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.ResourceRelation.Relation, subject) + if !ok { + if req.OriginalRelationName != "" { + resourceID, ok = hints.AsCheckHintForComputedUserset(checkHint, req.ResourceRelation.Namespace, req.OriginalRelationName, subject) + } + + if !ok { + continue + } + } + + if result.Resp.ResultsByResourceId == nil { + result.Resp.ResultsByResourceId = make(map[string]*v1.ResourceCheckResult) + } + + if _, ok := result.Resp.ResultsByResourceId[resourceID]; ok { + return checkResultError( + spiceerrors.MustBugf("check hint for resource ID %q, which already exists", resourceID), + emptyMetadata, + ) + } + + result.Resp.ResultsByResourceId[resourceID] = checkHint.Result + } + + return result +} + +func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequestContext, relation *core.Relation) CheckResult { + ctx, span := tracer.Start(ctx, "checkDirect") + defer span.End() + + // Build a filter for finding the direct relationships for the check. There are three + // classes of relationships to be found: + // 1) the target subject itself, if allowed on this relation + // 2) the wildcard form of the target subject, if a wildcard is allowed on this relation + // 3) Otherwise, any non-terminal (non-`...`) subjects, if allowed on this relation, to be + // redispatched outward + totalNonTerminals := 0 + totalDirectSubjects := 0 + totalWildcardSubjects := 0 + + defer func() { + if totalNonTerminals > 0 { + span.SetName("non terminal") + } else if totalDirectSubjects > 0 { + span.SetName("terminal") + } else { + span.SetName("wildcard subject") + } + }() + log.Ctx(ctx).Trace().Object("direct", crc.parentReq).Send() + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + + directSubjectsAndWildcardsWithoutCaveats := 0 + directSubjectsAndWildcardsWithoutExpiration := 0 + nonTerminalsWithoutCaveats := 0 + nonTerminalsWithoutExpiration := 0 + + for _, allowedDirectRelation := range relation.GetTypeInformation().GetAllowedDirectRelations() { + // If the namespace of the allowed direct relation matches the subject type, there are two + // cases to optimize: + // 1) Finding the target subject itself, as a direct lookup + // 2) Finding a wildcard for the subject type+relation + if allowedDirectRelation.GetNamespace() == crc.parentReq.Subject.Namespace { + if allowedDirectRelation.GetPublicWildcard() != nil { + totalWildcardSubjects++ + } else if allowedDirectRelation.GetRelation() == crc.parentReq.Subject.Relation { + totalDirectSubjects++ + } + + if allowedDirectRelation.RequiredCaveat == nil { + directSubjectsAndWildcardsWithoutCaveats++ + } + + if allowedDirectRelation.RequiredExpiration == nil { + directSubjectsAndWildcardsWithoutExpiration++ + } + } + + // If the relation found is not an ellipsis, then this is a nested relation that + // might need to be followed, so indicate that such relationships should be returned + // + // TODO(jschorr): Use type information to *further* optimize this query around which nested + // relations can reach the target subject type. + if allowedDirectRelation.GetRelation() != tuple.Ellipsis { + totalNonTerminals++ + if allowedDirectRelation.RequiredCaveat == nil { + nonTerminalsWithoutCaveats++ + } + if allowedDirectRelation.RequiredExpiration == nil { + nonTerminalsWithoutExpiration++ + } + } + } + + nonTerminalsCanHaveCaveats := totalNonTerminals != nonTerminalsWithoutCaveats + nonTerminalsCanHaveExpiration := totalNonTerminals != nonTerminalsWithoutExpiration + hasNonTerminals := totalNonTerminals > 0 + + foundResources := NewMembershipSet() + + // If the direct subject or a wildcard form can be found, issue a query for just that + // subject. + var queryCount float64 + defer func() { + directDispatchQueryHistogram.Observe(queryCount) + }() + + hasDirectSubject := totalDirectSubjects > 0 + hasWildcardSubject := totalWildcardSubjects > 0 + if hasDirectSubject || hasWildcardSubject { + directSubjectOrWildcardCanHaveCaveats := directSubjectsAndWildcardsWithoutCaveats != (totalDirectSubjects + totalWildcardSubjects) + directSubjectOrWildcardCanHaveExpiration := directSubjectsAndWildcardsWithoutExpiration != (totalDirectSubjects + totalWildcardSubjects) + + subjectSelectors := []datastore.SubjectsSelector{} + + if hasDirectSubject { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: crc.parentReq.Subject.Namespace, + OptionalSubjectIds: []string{crc.parentReq.Subject.ObjectId}, + RelationFilter: datastore.SubjectRelationFilter{}.WithRelation(crc.parentReq.Subject.Relation), + }) + } + + if hasWildcardSubject { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: crc.parentReq.Subject.Namespace, + OptionalSubjectIds: []string{tuple.PublicWildcard}, + RelationFilter: datastore.SubjectRelationFilter{}.WithEllipsisRelation(), + }) + } + + filter := datastore.RelationshipsFilter{ + OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, + OptionalResourceIds: crc.filteredResourceIDs, + OptionalResourceRelation: crc.parentReq.ResourceRelation.Relation, + OptionalSubjectsSelectors: subjectSelectors, + } + + it, err := ds.QueryRelationships(ctx, filter, + options.WithSkipCaveats(!directSubjectOrWildcardCanHaveCaveats), + options.WithSkipExpiration(!directSubjectOrWildcardCanHaveExpiration), + options.WithQueryShape(queryshape.CheckPermissionSelectDirectSubjects), + ) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + queryCount += 1.0 + + // Find the matching subject(s). + for rel, err := range it { + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + // If the subject of the relationship matches the target subject, then we've found + // a result. + foundResources.AddDirectMember(rel.Resource.ObjectID, rel.OptionalCaveat) + if crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT && foundResources.HasDeterminedMember() { + return checkResultsForMembership(foundResources, emptyMetadata) + } + } + } + + // Filter down the resource IDs for further dispatch based on whether they exist as found + // subjects in the existing membership set. + furtherFilteredResourceIDs := make([]string, 0, len(crc.filteredResourceIDs)-foundResources.Size()) + for _, resourceID := range crc.filteredResourceIDs { + if foundResources.HasConcreteResourceID(resourceID) { + continue + } + + furtherFilteredResourceIDs = append(furtherFilteredResourceIDs, resourceID) + } + + // If there are no possible non-terminals, then the check is completed. + if !hasNonTerminals || len(furtherFilteredResourceIDs) == 0 { + return checkResultsForMembership(foundResources, emptyMetadata) + } + + // Otherwise, for any remaining resource IDs, query for redispatch. + filter := datastore.RelationshipsFilter{ + OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, + OptionalResourceIds: furtherFilteredResourceIDs, + OptionalResourceRelation: crc.parentReq.ResourceRelation.Relation, + OptionalSubjectsSelectors: []datastore.SubjectsSelector{ + { + RelationFilter: datastore.SubjectRelationFilter{}.WithOnlyNonEllipsisRelations(), + }, + }, + } + + it, err := ds.QueryRelationships(ctx, filter, + options.WithSkipCaveats(!nonTerminalsCanHaveCaveats), + options.WithSkipExpiration(!nonTerminalsCanHaveExpiration), + options.WithQueryShape(queryshape.CheckPermissionSelectIndirectSubjects), + ) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + queryCount += 1.0 + + // Build the set of subjects over which to dispatch, along with metadata for + // mapping over caveats (if any). + checksToDispatch := newCheckDispatchSet() + for rel, err := range it { + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + checksToDispatch.addForRelationship(rel) + } + + // Dispatch and map to the associated resource ID(s). + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) + result := union(ctx, crc, toDispatch, func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) CheckResult { + // If there are caveats on any of the incoming relationships for the subjects to dispatch, then we must require all + // results to be found, as we need to ensure that all caveats are used for building the final expression. + resultsSetting := crc.resultsSetting + if dd.hasIncomingCaveats { + resultsSetting = v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS + } + + childResult := cc.dispatch(ctx, crc, ValidatedCheckRequest{ + &v1.DispatchCheckRequest{ + ResourceRelation: dd.resourceType.ToCoreRR(), + ResourceIds: dd.resourceIds, + Subject: crc.parentReq.Subject, + ResultsSetting: resultsSetting, + + Metadata: decrementDepth(crc.parentReq.Metadata), + Debug: crc.parentReq.Debug, + CheckHints: crc.parentReq.CheckHints, + }, + crc.parentReq.Revision, + noOriginalRelation, + }) + + if childResult.Err != nil { + return childResult + } + + return mapFoundResources(childResult, dd.resourceType, checksToDispatch) + }, cc.concurrencyLimit) + + return combineResultWithFoundResources(result, foundResources) +} + +func mapFoundResources(result CheckResult, resourceType tuple.RelationReference, checksToDispatch *checkDispatchSet) CheckResult { + // Map any resources found to the parent resource IDs. + membershipSet := NewMembershipSet() + for foundResourceID, result := range result.Resp.ResultsByResourceId { + resourceIDAndCaveats := checksToDispatch.mappingsForSubject(resourceType.ObjectType, foundResourceID, resourceType.Relation) + + spiceerrors.DebugAssert(func() bool { + return len(resourceIDAndCaveats) > 0 + }, "found resource ID without associated caveats") + + for _, riac := range resourceIDAndCaveats { + membershipSet.AddMemberWithParentCaveat(riac.resourceID, result.Expression, riac.caveat) + } + } + + if membershipSet.IsEmpty() { + return noMembersWithMetadata(result.Resp.Metadata) + } + + return checkResultsForMembership(membershipSet, result.Resp.Metadata) +} + +func (cc *ConcurrentChecker) checkUsersetRewrite(ctx context.Context, crc currentRequestContext, rewrite *core.UsersetRewrite) CheckResult { + switch rw := rewrite.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + if len(rw.Union.Child) > 1 { + var span trace.Span + ctx, span = tracer.Start(ctx, "+") + defer span.End() + } + return union(ctx, crc, rw.Union.Child, cc.runSetOperation, cc.concurrencyLimit) + case *core.UsersetRewrite_Intersection: + ctx, span := tracer.Start(ctx, "&") + defer span.End() + return all(ctx, crc, rw.Intersection.Child, cc.runSetOperation, cc.concurrencyLimit) + case *core.UsersetRewrite_Exclusion: + ctx, span := tracer.Start(ctx, "-") + defer span.End() + return difference(ctx, crc, rw.Exclusion.Child, cc.runSetOperation, cc.concurrencyLimit) + default: + return checkResultError(spiceerrors.MustBugf("unknown userset rewrite operator"), emptyMetadata) + } +} + +func (cc *ConcurrentChecker) dispatch(ctx context.Context, _ currentRequestContext, req ValidatedCheckRequest) CheckResult { + log.Ctx(ctx).Trace().Object("dispatch", req).Send() + result, err := cc.d.DispatchCheck(ctx, req.DispatchCheckRequest) + return CheckResult{result, err} +} + +func (cc *ConcurrentChecker) runSetOperation(ctx context.Context, crc currentRequestContext, childOneof *core.SetOperation_Child) CheckResult { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_XThis: + return checkResultError(spiceerrors.MustBugf("use of _this is unsupported; please rewrite your schema"), emptyMetadata) + case *core.SetOperation_Child_ComputedUserset: + return cc.checkComputedUserset(ctx, crc, child.ComputedUserset, nil, nil) + case *core.SetOperation_Child_UsersetRewrite: + return cc.checkUsersetRewrite(ctx, crc, child.UsersetRewrite) + case *core.SetOperation_Child_TupleToUserset: + return checkTupleToUserset(ctx, cc, crc, child.TupleToUserset) + case *core.SetOperation_Child_FunctionedTupleToUserset: + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + return checkTupleToUserset(ctx, cc, crc, child.FunctionedTupleToUserset) + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + return checkIntersectionTupleToUserset(ctx, cc, crc, child.FunctionedTupleToUserset) + + default: + return checkResultError(spiceerrors.MustBugf("unknown userset function `%s`", child.FunctionedTupleToUserset.Function), emptyMetadata) + } + + case *core.SetOperation_Child_XNil: + return noMembers() + default: + return checkResultError(spiceerrors.MustBugf("unknown set operation child `%T` in check", child), emptyMetadata) + } +} + +func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc currentRequestContext, cu *core.ComputedUserset, rr *tuple.RelationReference, resourceIds []string) CheckResult { + ctx, span := tracer.Start(ctx, cu.Relation) + defer span.End() + + var startNamespace string + var targetResourceIds []string + if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT { + if rr == nil || len(resourceIds) == 0 { + return checkResultError(spiceerrors.MustBugf("computed userset for tupleset without tuples"), emptyMetadata) + } + + startNamespace = rr.ObjectType + targetResourceIds = resourceIds + } else if cu.Object == core.ComputedUserset_TUPLE_OBJECT { + if rr != nil { + return checkResultError(spiceerrors.MustBugf("computed userset for tupleset with wrong object type"), emptyMetadata) + } + + startNamespace = crc.parentReq.ResourceRelation.Namespace + targetResourceIds = crc.filteredResourceIDs + } + + targetRR := &core.RelationReference{ + Namespace: startNamespace, + Relation: cu.Relation, + } + + // If we will be dispatching to the goal's ONR, then we know that the ONR is a member. + membershipSet, updatedTargetResourceIds := filterForFoundMemberResource(targetRR, targetResourceIds, crc.parentReq.Subject) + if (membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT) || len(updatedTargetResourceIds) == 0 { + return checkResultsForMembership(membershipSet, emptyMetadata) + } + + // Check if the target relation exists. If not, return nothing. This is only necessary + // for TTU-based computed usersets, as directly computed ones reference relations within + // the same namespace as the caller, and thus must be fully typed checked. + if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + err := namespace.CheckNamespaceAndRelation(ctx, targetRR.Namespace, targetRR.Relation, true, ds) + if err != nil { + if errors.As(err, &namespace.RelationNotFoundError{}) { + return noMembers() + } + + return checkResultError(err, emptyMetadata) + } + } + + result := cc.dispatch(ctx, crc, ValidatedCheckRequest{ + &v1.DispatchCheckRequest{ + ResourceRelation: targetRR, + ResourceIds: updatedTargetResourceIds, + Subject: crc.parentReq.Subject, + ResultsSetting: crc.resultsSetting, + Metadata: decrementDepth(crc.parentReq.Metadata), + Debug: crc.parentReq.Debug, + CheckHints: crc.parentReq.CheckHints, + }, + crc.parentReq.Revision, + noOriginalRelation, + }) + return combineResultWithFoundResources(result, membershipSet) +} + +type Traits struct { + HasCaveats bool + HasExpiration bool +} + +// TraitsForArrowRelation returns traits such as HasCaveats and HasExpiration if *any* of the subject +// types of the given relation support caveats or expiration. +func TraitsForArrowRelation(ctx context.Context, reader datastore.Reader, namespaceName string, relationName string) (Traits, error) { + // TODO(jschorr): Change to use the type system once we wire it through Check dispatch. + nsDef, _, err := reader.ReadNamespaceByName(ctx, namespaceName) + if err != nil { + return Traits{}, err + } + + var relation *core.Relation + for _, rel := range nsDef.Relation { + if rel.Name == relationName { + relation = rel + break + } + } + + if relation == nil || relation.TypeInformation == nil { + return Traits{}, fmt.Errorf("relation %q not found", relationName) + } + + hasCaveats := false + hasExpiration := false + + for _, allowedDirectRelation := range relation.TypeInformation.GetAllowedDirectRelations() { + if allowedDirectRelation.RequiredCaveat != nil { + hasCaveats = true + } + + if allowedDirectRelation.RequiredExpiration != nil { + hasExpiration = true + } + } + + return Traits{ + HasCaveats: hasCaveats, + HasExpiration: hasExpiration, + }, nil +} + +func queryOptionsForArrowRelation(ctx context.Context, ds datastore.Reader, namespaceName string, relationName string) ([]options.QueryOptionsOption, error) { + opts := make([]options.QueryOptionsOption, 0, 3) + opts = append(opts, options.WithQueryShape(queryshape.AllSubjectsForResources)) + + traits, err := TraitsForArrowRelation(ctx, ds, namespaceName, relationName) + if err != nil { + return nil, err + } + + if !traits.HasCaveats { + opts = append(opts, options.WithSkipCaveats(true)) + } + + if !traits.HasExpiration { + opts = append(opts, options.WithSkipExpiration(true)) + } + + return opts, nil +} + +func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) (*MembershipSet, []string) { + if resourceRelation.Namespace != subject.Namespace || resourceRelation.Relation != subject.Relation { + return nil, resourceIds + } + + for index, resourceID := range resourceIds { + if subject.ObjectId == resourceID { + membershipSet := NewMembershipSet() + membershipSet.AddDirectMember(resourceID, nil) + return membershipSet, removeIndexFromSlice(resourceIds, index) + } + } + + return nil, resourceIds +} + +func removeIndexFromSlice[T any](s []T, index int) []T { + cpy := make([]T, 0, len(s)-1) + cpy = append(cpy, s[:index]...) + return append(cpy, s[index+1:]...) +} + +type relation interface { + GetRelation() string +} + +type ttu[T relation] interface { + GetComputedUserset() *core.ComputedUserset + GetTupleset() T +} + +type checkResultWithType struct { + CheckResult + + relationType tuple.RelationReference +} + +func checkIntersectionTupleToUserset( + ctx context.Context, + cc *ConcurrentChecker, + crc currentRequestContext, + ttu *core.FunctionedTupleToUserset, +) CheckResult { + // TODO(jschorr): use check hints here + ctx, span := tracer.Start(ctx, ttu.GetTupleset().GetRelation()+"-(all)->"+ttu.GetComputedUserset().Relation) + defer span.End() + + // Query for the subjects over which to walk the TTU. + log.Ctx(ctx).Trace().Object("intersectionttu", crc.parentReq).Send() + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, + OptionalResourceIds: crc.filteredResourceIDs, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + }, queryOpts...) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + checksToDispatch := newCheckDispatchSet() + subjectsByResourceID := mapz.NewMultiMap[string, tuple.ObjectAndRelation]() + for rel, err := range it { + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + checksToDispatch.addForRelationship(rel) + subjectsByResourceID.Add(rel.Resource.ObjectID, rel.Subject) + } + + // Convert the subjects into batched requests. + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) + if len(toDispatch) == 0 { + return noMembers() + } + + // Run the dispatch for all the chunks. Unlike a standard TTU, we do *not* perform mapping here, + // as we need to access the results on a per subject basis. Instead, we keep each result and map + // by the relation type of the dispatched subject. + chunkResults, err := run( + ctx, + currentRequestContext{ + parentReq: crc.parentReq, + filteredResourceIDs: crc.filteredResourceIDs, + resultsSetting: v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS, + dispatchChunkSize: crc.dispatchChunkSize, + }, + toDispatch, + func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) checkResultWithType { + resourceType := dd.resourceType + childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), &resourceType, dd.resourceIds) + return checkResultWithType{ + CheckResult: childResult, + relationType: dd.resourceType, + } + }, + cc.concurrencyLimit, + ) + if err != nil { + return checkResultError(err, emptyMetadata) + } + + // Create a membership set per-subject-type, representing the membership for each of the dispatched subjects. + resultsByDispatchedSubject := map[tuple.RelationReference]*MembershipSet{} + combinedMetadata := emptyMetadata + for _, result := range chunkResults { + if result.Err != nil { + return checkResultError(result.Err, emptyMetadata) + } + + if _, ok := resultsByDispatchedSubject[result.relationType]; !ok { + resultsByDispatchedSubject[result.relationType] = NewMembershipSet() + } + + resultsByDispatchedSubject[result.relationType].UnionWith(result.Resp.ResultsByResourceId) + combinedMetadata = combineResponseMetadata(ctx, combinedMetadata, result.Resp.Metadata) + } + + // For each resource ID, check that there exist some sort of permission for *each* subject. If not, then the + // intersection for that resource fails. If all subjects have some sort of permission, then the resource ID is + // a member, perhaps caveated. + resourcesFound := NewMembershipSet() + for _, resourceID := range subjectsByResourceID.Keys() { + subjects, _ := subjectsByResourceID.Get(resourceID) + if len(subjects) == 0 { + return checkResultError(spiceerrors.MustBugf("no subjects found for resource ID %s", resourceID), emptyMetadata) + } + + hasAllSubjects := true + caveats := make([]*core.CaveatExpression, 0, len(subjects)) + + // Check each of the subjects found for the resource ID and ensure that membership (at least caveated) + // was found for each. If any are not found, then the resource ID is not a member. + // We also collect up the caveats for each subject, as they will be added to the final result. + for _, subject := range subjects { + subjectTypeKey := subject.RelationReference() + results, ok := resultsByDispatchedSubject[subjectTypeKey] + if !ok { + hasAllSubjects = false + break + } + + hasMembership, caveat := results.GetResourceID(subject.ObjectID) + if !hasMembership { + hasAllSubjects = false + break + } + + if caveat != nil { + caveats = append(caveats, caveat) + } + + // Add any caveats on the subject from the starting relationship(s) as well. + resourceIDAndCaveats := checksToDispatch.mappingsForSubject(subject.ObjectType, subject.ObjectID, subject.Relation) + for _, riac := range resourceIDAndCaveats { + if riac.caveat != nil { + caveats = append(caveats, wrapCaveat(riac.caveat)) + } + } + } + + if !hasAllSubjects { + continue + } + + // Add the member to the membership set, with the caveats for each (if any). + resourcesFound.AddMemberWithOptionalCaveats(resourceID, caveats) + } + + return checkResultsForMembership(resourcesFound, combinedMetadata) +} + +func checkTupleToUserset[T relation]( + ctx context.Context, + cc *ConcurrentChecker, + crc currentRequestContext, + ttu ttu[T], +) CheckResult { + filteredResourceIDs := crc.filteredResourceIDs + hintsToReturn := make(map[string]*v1.ResourceCheckResult, len(crc.parentReq.CheckHints)) + if len(crc.parentReq.CheckHints) > 0 { + filteredResourcesIdsSet := mapz.NewSet(crc.filteredResourceIDs...) + + for _, checkHint := range crc.parentReq.CheckHints { + resourceID, ok := hints.AsCheckHintForArrow( + checkHint, + crc.parentReq.ResourceRelation.Namespace, + ttu.GetTupleset().GetRelation(), + ttu.GetComputedUserset().Relation, + tuple.FromCoreObjectAndRelation(crc.parentReq.Subject), + ) + if !ok { + continue + } + + filteredResourcesIdsSet.Delete(resourceID) + hintsToReturn[resourceID] = checkHint.Result + } + + filteredResourceIDs = filteredResourcesIdsSet.AsSlice() + } + + if len(filteredResourceIDs) == 0 { + return combineWithComputedHints(noMembers(), hintsToReturn) + } + + ctx, span := tracer.Start(ctx, ttu.GetTupleset().GetRelation()+"->"+ttu.GetComputedUserset().Relation) + defer span.End() + + log.Ctx(ctx).Trace().Object("ttu", crc.parentReq).Send() + ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + + queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, + OptionalResourceIds: filteredResourceIDs, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + }, queryOpts...) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + + checksToDispatch := newCheckDispatchSet() + for rel, err := range it { + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + checksToDispatch.addForRelationship(rel) + } + + toDispatch := checksToDispatch.dispatchChunks(crc.dispatchChunkSize) + return combineWithComputedHints(union( + ctx, + crc, + toDispatch, + func(ctx context.Context, crc currentRequestContext, dd checkDispatchChunk) CheckResult { + resourceType := dd.resourceType + childResult := cc.checkComputedUserset(ctx, crc, ttu.GetComputedUserset(), &resourceType, dd.resourceIds) + if childResult.Err != nil { + return childResult + } + + return mapFoundResources(childResult, dd.resourceType, checksToDispatch) + }, + cc.concurrencyLimit, + ), hintsToReturn) +} + +func withDistinctMetadata(ctx context.Context, result CheckResult) CheckResult { + // NOTE: This is necessary to ensure unique debug information on the request and that debug + // information from the child metadata is *not* copied over. + clonedResp := result.Resp.CloneVT() + clonedResp.Metadata = combineResponseMetadata(ctx, emptyMetadata, clonedResp.Metadata) + return CheckResult{ + Resp: clonedResp, + Err: result.Err, + } +} + +// run runs all the children in parallel and returns the full set of results. +func run[T any, R withError]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) R, + concurrencyLimit uint16, +) ([]R, error) { + if len(children) == 0 { + return nil, nil + } + + if len(children) == 1 { + return []R{handler(ctx, crc, children[0])}, nil + } + + resultChan := make(chan R, len(children)) + childCtx, cancelFn := context.WithCancel(ctx) + dispatchAllAsync(childCtx, crc, children, handler, resultChan, concurrencyLimit) + defer cancelFn() + + results := make([]R, 0, len(children)) + for i := 0; i < len(children); i++ { + select { + case result := <-resultChan: + results = append(results, result) + + case <-ctx.Done(): + log.Ctx(ctx).Trace().Msg("anyCanceled") + return nil, ctx.Err() + } + } + + return results, nil +} + +// union returns whether any one of the lazy checks pass, and is used for union. +func union[T any]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult, + concurrencyLimit uint16, +) CheckResult { + if len(children) == 0 { + return noMembers() + } + + if len(children) == 1 { + return withDistinctMetadata(ctx, handler(ctx, crc, children[0])) + } + + resultChan := make(chan CheckResult, len(children)) + childCtx, cancelFn := context.WithCancel(ctx) + dispatchAllAsync(childCtx, crc, children, handler, resultChan, concurrencyLimit) + defer cancelFn() + + responseMetadata := emptyMetadata + membershipSet := NewMembershipSet() + + for i := 0; i < len(children); i++ { + select { + case result := <-resultChan: + log.Ctx(ctx).Trace().Object("anyResult", result.Resp).Send() + responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata) + if result.Err != nil { + return checkResultError(result.Err, responseMetadata) + } + + membershipSet.UnionWith(result.Resp.ResultsByResourceId) + if membershipSet.HasDeterminedMember() && crc.resultsSetting == v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT { + return checkResultsForMembership(membershipSet, responseMetadata) + } + + case <-ctx.Done(): + log.Ctx(ctx).Trace().Msg("anyCanceled") + return checkResultError(context.Canceled, responseMetadata) + } + } + + return checkResultsForMembership(membershipSet, responseMetadata) +} + +// all returns whether all of the lazy checks pass, and is used for intersection. +func all[T any]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult, + concurrencyLimit uint16, +) CheckResult { + if len(children) == 0 { + return noMembers() + } + + if len(children) == 1 { + return withDistinctMetadata(ctx, handler(ctx, crc, children[0])) + } + + responseMetadata := emptyMetadata + + resultChan := make(chan CheckResult, len(children)) + childCtx, cancelFn := context.WithCancel(ctx) + dispatchAllAsync(childCtx, currentRequestContext{ + parentReq: crc.parentReq, + filteredResourceIDs: crc.filteredResourceIDs, + resultsSetting: v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS, + dispatchChunkSize: crc.dispatchChunkSize, + }, children, handler, resultChan, concurrencyLimit) + defer cancelFn() + + var membershipSet *MembershipSet + for i := 0; i < len(children); i++ { + select { + case result := <-resultChan: + responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata) + if result.Err != nil { + return checkResultError(result.Err, responseMetadata) + } + + if membershipSet == nil { + membershipSet = NewMembershipSet() + membershipSet.UnionWith(result.Resp.ResultsByResourceId) + } else { + membershipSet.IntersectWith(result.Resp.ResultsByResourceId) + } + + if membershipSet.IsEmpty() { + return noMembersWithMetadata(responseMetadata) + } + case <-ctx.Done(): + return checkResultError(context.Canceled, responseMetadata) + } + } + + return checkResultsForMembership(membershipSet, responseMetadata) +} + +// difference returns whether the first lazy check passes and none of the subsequent checks pass. +func difference[T any]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) CheckResult, + concurrencyLimit uint16, +) CheckResult { + if len(children) == 0 { + return noMembers() + } + + if len(children) == 1 { + return checkResultError(spiceerrors.MustBugf("difference requires more than a single child"), emptyMetadata) + } + + childCtx, cancelFn := context.WithCancel(ctx) + baseChan := make(chan CheckResult, 1) + othersChan := make(chan CheckResult, len(children)-1) + + go func() { + result := handler(childCtx, currentRequestContext{ + parentReq: crc.parentReq, + filteredResourceIDs: crc.filteredResourceIDs, + resultsSetting: v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS, + dispatchChunkSize: crc.dispatchChunkSize, + }, children[0]) + baseChan <- result + }() + + dispatchAllAsync(childCtx, currentRequestContext{ + parentReq: crc.parentReq, + filteredResourceIDs: crc.filteredResourceIDs, + resultsSetting: v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS, + dispatchChunkSize: crc.dispatchChunkSize, + }, children[1:], handler, othersChan, concurrencyLimit-1) + defer cancelFn() + + responseMetadata := emptyMetadata + membershipSet := NewMembershipSet() + + // Wait for the base set to return. + select { + case base := <-baseChan: + responseMetadata = combineResponseMetadata(ctx, responseMetadata, base.Resp.Metadata) + + if base.Err != nil { + return checkResultError(base.Err, responseMetadata) + } + + membershipSet.UnionWith(base.Resp.ResultsByResourceId) + if membershipSet.IsEmpty() { + return noMembersWithMetadata(responseMetadata) + } + + case <-ctx.Done(): + return checkResultError(context.Canceled, responseMetadata) + } + + // Subtract the remaining sets. + for i := 1; i < len(children); i++ { + select { + case sub := <-othersChan: + responseMetadata = combineResponseMetadata(ctx, responseMetadata, sub.Resp.Metadata) + + if sub.Err != nil { + return checkResultError(sub.Err, responseMetadata) + } + + membershipSet.Subtract(sub.Resp.ResultsByResourceId) + if membershipSet.IsEmpty() { + return noMembersWithMetadata(responseMetadata) + } + + case <-ctx.Done(): + return checkResultError(context.Canceled, responseMetadata) + } + } + + return checkResultsForMembership(membershipSet, responseMetadata) +} + +type withError interface { + ResultError() error +} + +func dispatchAllAsync[T any, R withError]( + ctx context.Context, + crc currentRequestContext, + children []T, + handler func(ctx context.Context, crc currentRequestContext, child T) R, + resultChan chan<- R, + concurrencyLimit uint16, +) { + tr := taskrunner.NewPreloadedTaskRunner(ctx, concurrencyLimit, len(children)) + for _, currentChild := range children { + currentChild := currentChild + tr.Add(func(ctx context.Context) error { + result := handler(ctx, crc, currentChild) + resultChan <- result + return result.ResultError() + }) + } + + tr.Start() +} + +func noMembers() CheckResult { + return CheckResult{ + &v1.DispatchCheckResponse{ + Metadata: emptyMetadata, + }, + nil, + } +} + +func noMembersWithMetadata(metadata *v1.ResponseMeta) CheckResult { + return CheckResult{ + &v1.DispatchCheckResponse{ + Metadata: metadata, + }, + nil, + } +} + +func checkResultsForMembership(foundMembership *MembershipSet, subProblemMetadata *v1.ResponseMeta) CheckResult { + return CheckResult{ + &v1.DispatchCheckResponse{ + Metadata: ensureMetadata(subProblemMetadata), + ResultsByResourceId: foundMembership.AsCheckResultsMap(), + }, + nil, + } +} + +func checkResultError(err error, subProblemMetadata *v1.ResponseMeta) CheckResult { + return CheckResult{ + &v1.DispatchCheckResponse{ + Metadata: ensureMetadata(subProblemMetadata), + }, + err, + } +} + +func combineResultWithFoundResources(result CheckResult, foundResources *MembershipSet) CheckResult { + if result.Err != nil { + return result + } + + if foundResources.IsEmpty() { + return result + } + + foundResources.UnionWith(result.Resp.ResultsByResourceId) + return CheckResult{ + Resp: &v1.DispatchCheckResponse{ + ResultsByResourceId: foundResources.AsCheckResultsMap(), + Metadata: result.Resp.Metadata, + }, + Err: result.Err, + } +} + +func combineResponseMetadata(ctx context.Context, existing *v1.ResponseMeta, responseMetadata *v1.ResponseMeta) *v1.ResponseMeta { + combined := &v1.ResponseMeta{ + DispatchCount: existing.DispatchCount + responseMetadata.DispatchCount, + DepthRequired: max(existing.DepthRequired, responseMetadata.DepthRequired), + CachedDispatchCount: existing.CachedDispatchCount + responseMetadata.CachedDispatchCount, + } + + if existing.DebugInfo == nil && responseMetadata.DebugInfo == nil { + return combined + } + + nodeID, err := nodeid.FromContext(ctx) + if err != nil { + log.Err(err).Msg("failed to get nodeID from context") + } + + debugInfo := &v1.DebugInformation{ + Check: &v1.CheckDebugTrace{ + TraceId: NewTraceID(), + SourceId: nodeID, + }, + } + + if existing.DebugInfo != nil { + if existing.DebugInfo.Check.Request != nil { + debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, existing.DebugInfo.Check) + } else { + debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, existing.DebugInfo.Check.SubProblems...) + } + } + + if responseMetadata.DebugInfo != nil { + if responseMetadata.DebugInfo.Check.Request != nil { + debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, responseMetadata.DebugInfo.Check) + } else { + debugInfo.Check.SubProblems = append(debugInfo.Check.SubProblems, responseMetadata.DebugInfo.Check.SubProblems...) + } + } + + combined.DebugInfo = debugInfo + return combined +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go b/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go new file mode 100644 index 0000000..ed3f3cb --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/checkdispatchset.go @@ -0,0 +1,144 @@ +package graph + +import ( + "sort" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/genutil/slicez" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// checkDispatchSet is the set of subjects over which check will need to dispatch +// as subproblems in order to answer the parent problem. +type checkDispatchSet struct { + // bySubjectType is a map from the type of subject to the set of subjects of that type + // over which to dispatch, along with information indicating whether caveats are present + // for that chunk. + bySubjectType map[tuple.RelationReference]map[string]bool + + // bySubject is a map from the subject to the set of resources for which the subject + // has a relationship, along with the caveats that apply to that relationship. + bySubject *mapz.MultiMap[tuple.ObjectAndRelation, resourceIDAndCaveat] +} + +// checkDispatchChunk is a chunk of subjects over which to dispatch a check operation. +type checkDispatchChunk struct { + // resourceType is the type of the subjects in this chunk. + resourceType tuple.RelationReference + + // resourceIds is the set of subjects in this chunk. + resourceIds []string + + // hasIncomingCaveats is true if any of the subjects in this chunk have incoming caveats. + // This is used to determine whether the check operation should be dispatched requiring + // all results. + hasIncomingCaveats bool +} + +// subjectIDAndHasCaveat is a tuple of a subject ID and whether it has a caveat. +type subjectIDAndHasCaveat struct { + // objectID is the ID of the subject. + objectID string + + // hasIncomingCaveats is true if the subject has a caveat. + hasIncomingCaveats bool +} + +// resourceIDAndCaveat is a tuple of a resource ID and a caveat. +type resourceIDAndCaveat struct { + // resourceID is the ID of the resource. + resourceID string + + // caveat is the caveat that applies to the relationship between the subject and the resource. + // May be nil. + caveat *core.ContextualizedCaveat +} + +// newCheckDispatchSet creates and returns a new checkDispatchSet. +func newCheckDispatchSet() *checkDispatchSet { + return &checkDispatchSet{ + bySubjectType: map[tuple.RelationReference]map[string]bool{}, + bySubject: mapz.NewMultiMap[tuple.ObjectAndRelation, resourceIDAndCaveat](), + } +} + +// Add adds the specified ObjectAndRelation to the set. +func (s *checkDispatchSet) addForRelationship(rel tuple.Relationship) { + // Add an entry for the subject pointing to the resource ID and caveat for the subject. + riac := resourceIDAndCaveat{ + resourceID: rel.Resource.ObjectID, + caveat: rel.OptionalCaveat, + } + s.bySubject.Add(rel.Subject, riac) + + // Add the subject ID to the map of subjects for the type of subject. + siac := subjectIDAndHasCaveat{ + objectID: rel.Subject.ObjectID, + hasIncomingCaveats: rel.OptionalCaveat != nil && rel.OptionalCaveat.CaveatName != "", + } + + subjectIDsForType, ok := s.bySubjectType[rel.Subject.RelationReference()] + if !ok { + subjectIDsForType = make(map[string]bool) + s.bySubjectType[rel.Subject.RelationReference()] = subjectIDsForType + } + + // If a caveat exists for the subject ID in any branch, the whole branch is considered caveated. + subjectIDsForType[rel.Subject.ObjectID] = siac.hasIncomingCaveats || subjectIDsForType[rel.Subject.ObjectID] +} + +func (s *checkDispatchSet) dispatchChunks(dispatchChunkSize uint16) []checkDispatchChunk { + // Start with an estimate of one chunk per type, plus one for the remainder. + expectedNumberOfChunks := len(s.bySubjectType) + 1 + toDispatch := make([]checkDispatchChunk, 0, expectedNumberOfChunks) + + // For each type of subject, create chunks of the IDs over which to dispatch. + for subjectType, subjectIDsAndHasCaveats := range s.bySubjectType { + entries := make([]subjectIDAndHasCaveat, 0, len(subjectIDsAndHasCaveats)) + for objectID, hasIncomingCaveats := range subjectIDsAndHasCaveats { + entries = append(entries, subjectIDAndHasCaveat{objectID: objectID, hasIncomingCaveats: hasIncomingCaveats}) + } + + // Sort the list of subject IDs by whether they have caveats and then the ID itself. + sort.Slice(entries, func(i, j int) bool { + iHasCaveat := entries[i].hasIncomingCaveats + jHasCaveat := entries[j].hasIncomingCaveats + if iHasCaveat == jHasCaveat { + return entries[i].objectID < entries[j].objectID + } + return iHasCaveat && !jHasCaveat + }) + + chunkCount := 0.0 + slicez.ForEachChunk(entries, dispatchChunkSize, func(subjectIdChunk []subjectIDAndHasCaveat) { + chunkCount++ + + subjectIDsToDispatch := make([]string, 0, len(subjectIdChunk)) + hasIncomingCaveats := false + for _, entry := range subjectIdChunk { + subjectIDsToDispatch = append(subjectIDsToDispatch, entry.objectID) + hasIncomingCaveats = hasIncomingCaveats || entry.hasIncomingCaveats + } + + toDispatch = append(toDispatch, checkDispatchChunk{ + resourceType: subjectType, + resourceIds: subjectIDsToDispatch, + hasIncomingCaveats: hasIncomingCaveats, + }) + }) + dispatchChunkCountHistogram.Observe(chunkCount) + } + + return toDispatch +} + +// mappingsForSubject returns the mappings that apply to the relationship between the specified +// subject and any of its resources. The returned caveats include the resource ID of the resource +// that the subject has a relationship with. +func (s *checkDispatchSet) mappingsForSubject(subjectType string, subjectObjectID string, subjectRelation string) []resourceIDAndCaveat { + results, ok := s.bySubject.Get(tuple.ONR(subjectType, subjectObjectID, subjectRelation)) + spiceerrors.DebugAssert(func() bool { return ok }, "no caveats found for subject %s:%s:%s", subjectType, subjectObjectID, subjectRelation) + return results +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go b/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go new file mode 100644 index 0000000..0bf20b5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/computed/computecheck.go @@ -0,0 +1,205 @@ +package computed + +import ( + "context" + + cexpr "github.com/authzed/spicedb/internal/caveats" + "github.com/authzed/spicedb/internal/dispatch" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/slicez" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// DebugOption defines the various debug level options for Checks. +type DebugOption int + +const ( + // NoDebugging indicates that debug information should be retained + // while performing the Check. + NoDebugging DebugOption = 0 + + // BasicDebuggingEnabled indicates that basic debug information, such + // as which steps were taken, should be retained while performing the + // Check and returned to the caller. + // + // NOTE: This has a minor performance impact. + BasicDebuggingEnabled DebugOption = 1 + + // TraceDebuggingEnabled indicates that the Check is being issued for + // tracing the exact calls made for debugging, which means that not only + // should debug information be recorded and returned, but that optimizations + // such as batching should be disabled. + // + // WARNING: This has a fairly significant performance impact and should only + // be used in tooling! + TraceDebuggingEnabled DebugOption = 2 +) + +// CheckParameters are the parameters for the ComputeCheck call. *All* are required. +type CheckParameters struct { + ResourceType tuple.RelationReference + Subject tuple.ObjectAndRelation + CaveatContext map[string]any + AtRevision datastore.Revision + MaximumDepth uint32 + DebugOption DebugOption + CheckHints []*v1.CheckHint +} + +// ComputeCheck computes a check result for the given resource and subject, computing any +// caveat expressions found. +func ComputeCheck( + ctx context.Context, + d dispatch.Check, + ts *caveattypes.TypeSet, + params CheckParameters, + resourceID string, + dispatchChunkSize uint16, +) (*v1.ResourceCheckResult, *v1.ResponseMeta, error) { + resultsMap, meta, di, err := computeCheck(ctx, d, ts, params, []string{resourceID}, dispatchChunkSize) + if err != nil { + return nil, meta, err + } + + spiceerrors.DebugAssert(func() bool { + return (len(di) == 0 && meta.DebugInfo == nil) || (len(di) == 1 && meta.DebugInfo != nil) + }, "mismatch in debug information returned from computeCheck") + + return resultsMap[resourceID], meta, err +} + +// ComputeBulkCheck computes a check result for the given resources and subject, computing any +// caveat expressions found. +func ComputeBulkCheck( + ctx context.Context, + d dispatch.Check, + ts *caveattypes.TypeSet, + params CheckParameters, + resourceIDs []string, + dispatchChunkSize uint16, +) (map[string]*v1.ResourceCheckResult, *v1.ResponseMeta, []*v1.DebugInformation, error) { + return computeCheck(ctx, d, ts, params, resourceIDs, dispatchChunkSize) +} + +func computeCheck(ctx context.Context, + d dispatch.Check, + ts *caveattypes.TypeSet, + params CheckParameters, + resourceIDs []string, + dispatchChunkSize uint16, +) (map[string]*v1.ResourceCheckResult, *v1.ResponseMeta, []*v1.DebugInformation, error) { + debugging := v1.DispatchCheckRequest_NO_DEBUG + if params.DebugOption == BasicDebuggingEnabled { + debugging = v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING + } else if params.DebugOption == TraceDebuggingEnabled { + debugging = v1.DispatchCheckRequest_ENABLE_TRACE_DEBUGGING + } + + setting := v1.DispatchCheckRequest_REQUIRE_ALL_RESULTS + if len(resourceIDs) == 1 { + setting = v1.DispatchCheckRequest_ALLOW_SINGLE_RESULT + } + + // Ensure that the number of resources IDs given to each dispatch call is not in excess of the maximum. + results := make(map[string]*v1.ResourceCheckResult, len(resourceIDs)) + metadata := &v1.ResponseMeta{} + + bf, err := v1.NewTraversalBloomFilter(uint(params.MaximumDepth)) + if err != nil { + return nil, nil, nil, spiceerrors.MustBugf("failed to create new traversal bloom filter") + } + + caveatRunner := cexpr.NewCaveatRunner(ts) + + // TODO(jschorr): Should we make this run in parallel via the preloadedTaskRunner? + debugInfo := make([]*v1.DebugInformation, 0) + _, err = slicez.ForEachChunkUntil(resourceIDs, dispatchChunkSize, func(resourceIDsToCheck []string) (bool, error) { + checkResult, err := d.DispatchCheck(ctx, &v1.DispatchCheckRequest{ + ResourceRelation: params.ResourceType.ToCoreRR(), + ResourceIds: resourceIDsToCheck, + ResultsSetting: setting, + Subject: params.Subject.ToCoreONR(), + Metadata: &v1.ResolverMeta{ + AtRevision: params.AtRevision.String(), + DepthRemaining: params.MaximumDepth, + TraversalBloom: bf, + }, + Debug: debugging, + CheckHints: params.CheckHints, + }) + + if checkResult.Metadata.DebugInfo != nil { + debugInfo = append(debugInfo, checkResult.Metadata.DebugInfo) + } + + if len(resourceIDs) == 1 { + metadata = checkResult.Metadata + } else { + metadata = &v1.ResponseMeta{ + DispatchCount: metadata.DispatchCount + checkResult.Metadata.DispatchCount, + DepthRequired: max(metadata.DepthRequired, checkResult.Metadata.DepthRequired), + CachedDispatchCount: metadata.CachedDispatchCount + checkResult.Metadata.CachedDispatchCount, + DebugInfo: nil, + } + } + + if err != nil { + return false, err + } + + for _, resourceID := range resourceIDsToCheck { + computed, err := computeCaveatedCheckResult(ctx, caveatRunner, params, resourceID, checkResult) + if err != nil { + return false, err + } + results[resourceID] = computed + } + + return true, nil + }) + return results, metadata, debugInfo, err +} + +func computeCaveatedCheckResult(ctx context.Context, runner *cexpr.CaveatRunner, params CheckParameters, resourceID string, checkResult *v1.DispatchCheckResponse) (*v1.ResourceCheckResult, error) { + result, ok := checkResult.ResultsByResourceId[resourceID] + if !ok { + return &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_NOT_MEMBER, + }, nil + } + + if result.Membership == v1.ResourceCheckResult_MEMBER { + return result, nil + } + + ds := datastoremw.MustFromContext(ctx) + reader := ds.SnapshotReader(params.AtRevision) + + caveatResult, err := runner.RunCaveatExpression(ctx, result.Expression, params.CaveatContext, reader, cexpr.RunCaveatExpressionNoDebugging) + if err != nil { + return nil, err + } + + if caveatResult.IsPartial() { + missingFields, _ := caveatResult.MissingVarNames() + return &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_CAVEATED_MEMBER, + Expression: result.Expression, + MissingExprFields: missingFields, + }, nil + } + + if caveatResult.Value() { + return &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_MEMBER, + }, nil + } + + return &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_NOT_MEMBER, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/context.go b/vendor/github.com/authzed/spicedb/internal/graph/context.go new file mode 100644 index 0000000..1485fa0 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/context.go @@ -0,0 +1,33 @@ +package graph + +import ( + "context" + + "go.opentelemetry.io/otel/trace" + + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/pkg/middleware/requestid" +) + +// branchContext returns a context disconnected from the parent context, but populated with the datastore. +// Also returns a function for canceling the newly created context, without canceling the parent context. +func branchContext(ctx context.Context) (context.Context, func(cancelErr error)) { + // Add tracing to the context. + span := trace.SpanFromContext(ctx) + detachedContext := trace.ContextWithSpan(context.Background(), span) + + // Add datastore to the context. + ds := datastoremw.FromContext(ctx) + detachedContext = datastoremw.ContextWithDatastore(detachedContext, ds) + + // Add logging to the context. + loggerFromContext := log.Ctx(ctx) + if loggerFromContext != nil { + detachedContext = loggerFromContext.WithContext(detachedContext) + } + + detachedContext = requestid.PropagateIfExists(ctx, detachedContext) + + return context.WithCancelCause(detachedContext) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/cursors.go b/vendor/github.com/authzed/spicedb/internal/graph/cursors.go new file mode 100644 index 0000000..ad3d705 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/cursors.go @@ -0,0 +1,542 @@ +package graph + +import ( + "context" + "errors" + "strconv" + "sync" + + "github.com/ccoveille/go-safecast" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/taskrunner" + "github.com/authzed/spicedb/pkg/datastore/options" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// cursorInformation is a struct which holds information about the current incoming cursor (if any) +// and the sections to be added to the *outgoing* partial cursor. +type cursorInformation struct { + // currentCursor is the current incoming cursor. This may be nil. + currentCursor *v1.Cursor + + // outgoingCursorSections are the sections to be added to the outgoing *partial* cursor. + // It is the responsibility of the *caller* to append together the incoming cursors to form + // the final cursor. + // + // A `section` is a portion of the cursor, representing a section of code that was + // executed to produce the section of the cursor. + outgoingCursorSections []string + + // limits is the limits tracker for the call over which the cursor is being used. + limits *limitTracker + + // dispatchCursorVersion is the version of the dispatch to be stored in the cursor. + dispatchCursorVersion uint32 +} + +// newCursorInformation constructs a new cursorInformation struct from the incoming cursor (which +// may be nil) +func newCursorInformation(incomingCursor *v1.Cursor, limits *limitTracker, dispatchCursorVersion uint32) (cursorInformation, error) { + if incomingCursor != nil && incomingCursor.DispatchVersion != dispatchCursorVersion { + return cursorInformation{}, NewInvalidCursorErr(dispatchCursorVersion, incomingCursor) + } + + if dispatchCursorVersion == 0 { + return cursorInformation{}, spiceerrors.MustBugf("invalid dispatch cursor version") + } + + return cursorInformation{ + currentCursor: incomingCursor, + outgoingCursorSections: nil, + limits: limits, + dispatchCursorVersion: dispatchCursorVersion, + }, nil +} + +// responsePartialCursor is the *partial* cursor to return in a response. +func (ci cursorInformation) responsePartialCursor() *v1.Cursor { + return &v1.Cursor{ + DispatchVersion: ci.dispatchCursorVersion, + Sections: ci.outgoingCursorSections, + } +} + +// withClonedLimits returns the cursor, but with its limits tracker cloned. +func (ci cursorInformation) withClonedLimits() cursorInformation { + return cursorInformation{ + currentCursor: ci.currentCursor, + outgoingCursorSections: ci.outgoingCursorSections, + limits: ci.limits.clone(), + dispatchCursorVersion: ci.dispatchCursorVersion, + } +} + +// headSectionValue returns the string value found at the head of the incoming cursor. +// If the incoming cursor is empty, returns empty. +func (ci cursorInformation) headSectionValue() (string, bool) { + if ci.currentCursor == nil || len(ci.currentCursor.Sections) < 1 { + return "", false + } + + return ci.currentCursor.Sections[0], true +} + +// integerSectionValue returns the *integer* found at the head of the incoming cursor. +// If the incoming cursor is empty, returns 0. If the incoming cursor does not start with an +// int value, fails with an error. +func (ci cursorInformation) integerSectionValue() (int, error) { + valueStr, hasValue := ci.headSectionValue() + if !hasValue { + return 0, nil + } + + if valueStr == "" { + return 0, nil + } + + return strconv.Atoi(valueStr) +} + +// withOutgoingSection returns cursorInformation updated with the given optional +// value appended to the outgoingCursorSections for the current cursor. If the current +// cursor already begins with any values, those values are replaced. +func (ci cursorInformation) withOutgoingSection(value string) (cursorInformation, error) { + ocs := make([]string, 0, len(ci.outgoingCursorSections)+1) + ocs = append(ocs, ci.outgoingCursorSections...) + ocs = append(ocs, value) + + if ci.currentCursor != nil && len(ci.currentCursor.Sections) > 0 { + // If the cursor already has values, replace them with those specified. + return cursorInformation{ + currentCursor: &v1.Cursor{ + DispatchVersion: ci.dispatchCursorVersion, + Sections: ci.currentCursor.Sections[1:], + }, + outgoingCursorSections: ocs, + limits: ci.limits, + dispatchCursorVersion: ci.dispatchCursorVersion, + }, nil + } + + return cursorInformation{ + currentCursor: nil, + outgoingCursorSections: ocs, + limits: ci.limits, + dispatchCursorVersion: ci.dispatchCursorVersion, + }, nil +} + +func (ci cursorInformation) clearIncoming() cursorInformation { + return cursorInformation{ + currentCursor: nil, + outgoingCursorSections: ci.outgoingCursorSections, + limits: ci.limits, + dispatchCursorVersion: ci.dispatchCursorVersion, + } +} + +type cursorHandler func(c cursorInformation) error + +// itemAndPostCursor represents an item and the cursor to be used for all items after it. +type itemAndPostCursor[T any] struct { + item T + cursor options.Cursor +} + +// withDatastoreCursorInCursor executes the given lookup function to retrieve items from the datastore, +// and then executes the handler on each of the produced items *in parallel*, streaming the results +// in the correct order to the parent stream. +func withDatastoreCursorInCursor[T any, Q any]( + ctx context.Context, + ci cursorInformation, + parentStream dispatch.Stream[Q], + concurrencyLimit uint16, + lookup func(queryCursor options.Cursor) ([]itemAndPostCursor[T], error), + handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error, +) error { + // Retrieve the *datastore* cursor, if one is found at the head of the incoming cursor. + var datastoreCursor options.Cursor + datastoreCursorString, _ := ci.headSectionValue() + if datastoreCursorString != "" { + datastoreCursor = options.ToCursor(tuple.MustParse(datastoreCursorString)) + } + + if ci.limits.hasExhaustedLimit() { + return nil + } + + // Execute the lookup to call the database and find items for processing. + itemsToBeProcessed, err := lookup(datastoreCursor) + if err != nil { + return err + } + + if len(itemsToBeProcessed) == 0 { + return nil + } + + itemsToRun := make([]T, 0, len(itemsToBeProcessed)) + for _, itemAndCursor := range itemsToBeProcessed { + itemsToRun = append(itemsToRun, itemAndCursor.item) + } + + getItemCursor := func(taskIndex int) (cursorInformation, error) { + // Create an updated cursor referencing the current item's cursor, so that any items returned know to resume from this point. + cursorRel := options.ToRelationship(itemsToBeProcessed[taskIndex].cursor) + cursorSection := "" + if cursorRel != nil { + cursorSection = tuple.StringWithoutCaveatOrExpiration(*cursorRel) + } + + currentCursor, err := ci.withOutgoingSection(cursorSection) + if err != nil { + return currentCursor, err + } + + // If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top + // of the cursor. + if taskIndex > 0 { + currentCursor = currentCursor.clearIncoming() + } + + return currentCursor, nil + } + + return withInternalParallelizedStreamingIterableInCursor( + ctx, + ci, + itemsToRun, + parentStream, + concurrencyLimit, + getItemCursor, + handler, + ) +} + +type afterResponseCursor func(nextOffset int) *v1.Cursor + +// withSubsetInCursor executes the given handler with the offset index found at the beginning of the +// cursor. If the offset is not found, executes with 0. The handler is given the current offset as +// well as a callback to mint the cursor with the next offset. +func withSubsetInCursor( + ci cursorInformation, + handler func(currentOffset int, nextCursorWith afterResponseCursor) error, + next cursorHandler, +) error { + if ci.limits.hasExhaustedLimit() { + return nil + } + + afterIndex, err := ci.integerSectionValue() + if err != nil { + return err + } + + if afterIndex >= 0 { + var foundCerr error + err = handler(afterIndex, func(nextOffset int) *v1.Cursor { + cursor, cerr := ci.withOutgoingSection(strconv.Itoa(nextOffset)) + foundCerr = cerr + if cerr != nil { + return nil + } + + return cursor.responsePartialCursor() + }) + if err != nil { + return err + } + if foundCerr != nil { + return foundCerr + } + } + + if ci.limits.hasExhaustedLimit() { + return nil + } + + // -1 means that the handler has been completed. + uci, err := ci.withOutgoingSection("-1") + if err != nil { + return err + } + return next(uci) +} + +// combineCursors combines the given cursors into one resulting cursor. +func combineCursors(cursor *v1.Cursor, toAdd *v1.Cursor) (*v1.Cursor, error) { + if toAdd == nil || len(toAdd.Sections) == 0 { + return nil, spiceerrors.MustBugf("supplied toAdd cursor was nil or empty") + } + + if cursor == nil || len(cursor.Sections) == 0 { + return toAdd, nil + } + + sections := make([]string, 0, len(cursor.Sections)+len(toAdd.Sections)) + sections = append(sections, cursor.Sections...) + sections = append(sections, toAdd.Sections...) + + return &v1.Cursor{ + DispatchVersion: toAdd.DispatchVersion, + Sections: sections, + }, nil +} + +// withParallelizedStreamingIterableInCursor executes the given handler for each item in the items list, skipping any +// items marked as completed at the head of the cursor and injecting a cursor representing the current +// item. +// +// For example, if items contains 3 items, and the cursor returned was within the handler for item +// index #1, then item index #0 will be skipped on subsequent invocation. +// +// The next index is executed in parallel with the current index, with its results stored in a CollectingStream +// until the next iteration. +func withParallelizedStreamingIterableInCursor[T any, Q any]( + ctx context.Context, + ci cursorInformation, + items []T, + parentStream dispatch.Stream[Q], + concurrencyLimit uint16, + handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error, +) error { + // Check the cursor for a starting index, before which any items will be skipped. + startingIndex, err := ci.integerSectionValue() + if err != nil { + return err + } + + if startingIndex < 0 || startingIndex > len(items) { + return spiceerrors.MustBugf("invalid cursor in withParallelizedStreamingIterableInCursor: found starting index %d for items %v", startingIndex, items) + } + + itemsToRun := items[startingIndex:] + if len(itemsToRun) == 0 { + return nil + } + + getItemCursor := func(taskIndex int) (cursorInformation, error) { + // Create an updated cursor referencing the current item's index, so that any items returned know to resume from this point. + currentCursor, err := ci.withOutgoingSection(strconv.Itoa(taskIndex + startingIndex)) + if err != nil { + return currentCursor, err + } + + // If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top + // of the cursor. + if taskIndex > 0 { + currentCursor = currentCursor.clearIncoming() + } + + return currentCursor, nil + } + + return withInternalParallelizedStreamingIterableInCursor( + ctx, + ci, + itemsToRun, + parentStream, + concurrencyLimit, + getItemCursor, + handler, + ) +} + +func withInternalParallelizedStreamingIterableInCursor[T any, Q any]( + ctx context.Context, + ci cursorInformation, + itemsToRun []T, + parentStream dispatch.Stream[Q], + concurrencyLimit uint16, + getItemCursor func(taskIndex int) (cursorInformation, error), + handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error, +) error { + // Queue up each iteration's worth of items to be run by the task runner. + tr := taskrunner.NewPreloadedTaskRunner(ctx, concurrencyLimit, len(itemsToRun)) + stream, err := newParallelLimitedIndexedStream(ctx, ci, parentStream, len(itemsToRun)) + if err != nil { + return err + } + + // Schedule a task to be invoked for each item to be run. + for taskIndex, item := range itemsToRun { + taskIndex := taskIndex + item := item + tr.Add(func(ctx context.Context) error { + stream.lock.Lock() + if ci.limits.hasExhaustedLimit() { + stream.lock.Unlock() + return nil + } + stream.lock.Unlock() + + ici, err := getItemCursor(taskIndex) + if err != nil { + return err + } + + // Invoke the handler with the current item's index in the outgoing cursor, indicating that + // subsequent invocations should jump right to this item. + ictx, istream, icursor := stream.forTaskIndex(ctx, taskIndex, ici) + + err = handler(ictx, icursor, item, istream) + if err != nil { + // If the branch was canceled explicitly by *this* streaming iterable because other branches have fulfilled + // the configured limit, then we can safely ignore this error. + if errors.Is(context.Cause(ictx), stream.errCanceledBecauseFulfilled) { + return nil + } + return err + } + + return stream.completedTaskIndex(taskIndex) + }) + } + + err = tr.StartAndWait() + if err != nil { + return err + } + return nil +} + +// parallelLimitedIndexedStream is a specialization of a dispatch.Stream that collects results from multiple +// tasks running in parallel, and emits them in the order of the tasks. The first task's results are directly +// emitted to the parent stream, while subsequent tasks' results are emitted in the defined order of the tasks +// to ensure cursors and limits work as expected. +type parallelLimitedIndexedStream[Q any] struct { + lock sync.Mutex + + ctx context.Context + ci cursorInformation + parentStream dispatch.Stream[Q] + + streamCount int + toPublishTaskIndex int + countingStream *dispatch.CountingDispatchStream[Q] // GUARDED_BY(lock) + childStreams map[int]*dispatch.CollectingDispatchStream[Q] // GUARDED_BY(lock) + childContextCancels map[int]func(cause error) // GUARDED_BY(lock) + completedTaskIndexes map[int]bool // GUARDED_BY(lock) + errCanceledBecauseFulfilled error +} + +func newParallelLimitedIndexedStream[Q any]( + ctx context.Context, + ci cursorInformation, + parentStream dispatch.Stream[Q], + streamCount int, +) (*parallelLimitedIndexedStream[Q], error) { + if streamCount <= 0 { + return nil, spiceerrors.MustBugf("got invalid stream count") + } + + return ¶llelLimitedIndexedStream[Q]{ + ctx: ctx, + ci: ci, + parentStream: parentStream, + countingStream: nil, + childStreams: map[int]*dispatch.CollectingDispatchStream[Q]{}, + childContextCancels: map[int]func(cause error){}, + completedTaskIndexes: map[int]bool{}, + toPublishTaskIndex: 0, + streamCount: streamCount, + + // NOTE: we mint a new error here to ensure that we only skip cancelations from this very instance. + errCanceledBecauseFulfilled: errors.New("canceled because other branches fulfilled limit"), + }, nil +} + +// forTaskIndex returns a new context, stream and cursor for invoking the task at the specific index and publishing its results. +func (ls *parallelLimitedIndexedStream[Q]) forTaskIndex(ctx context.Context, index int, currentCursor cursorInformation) (context.Context, dispatch.Stream[Q], cursorInformation) { + ls.lock.Lock() + defer ls.lock.Unlock() + + // Create a new cursor with cloned limits, because each child task which executes (in parallel) will need its own + // limit tracking. The overall limit on the original cursor is managed in completedTaskIndex. + childCI := currentCursor.withClonedLimits() + childContext, cancelDispatch := branchContext(ctx) + + ls.childContextCancels[index] = cancelDispatch + + // If executing for the first index, it can stream directly to the parent stream, but we need to count the number + // of items streamed to adjust the overall limits. + if index == 0 { + countingStream := dispatch.NewCountingDispatchStream(ls.parentStream) + ls.countingStream = countingStream + return childContext, countingStream, childCI + } + + // Otherwise, create a child stream with an adjusted limits on the cursor. We have to clone the cursor's + // limits here to ensure that the child's publishing doesn't affect the first branch. + childStream := dispatch.NewCollectingDispatchStream[Q](childContext) + ls.childStreams[index] = childStream + + return childContext, childStream, childCI +} + +// cancelRemainingDispatches cancels the contexts for each dispatched branch, indicating that no additional results +// are necessary. +func (ls *parallelLimitedIndexedStream[Q]) cancelRemainingDispatches() { + for _, cancel := range ls.childContextCancels { + cancel(ls.errCanceledBecauseFulfilled) + } +} + +// completedTaskIndex indicates the the task at the specific index has completed successfully and that its collected +// results should be published to the parent stream, so long as all previous tasks have been completed and published as well. +func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error { + ls.lock.Lock() + defer ls.lock.Unlock() + + // Mark the task as completed, but not yet published. + ls.completedTaskIndexes[index] = true + + // If the overall limit has been reached, nothing more to do. + if ls.ci.limits.hasExhaustedLimit() { + ls.cancelRemainingDispatches() + return nil + } + + // Otherwise, publish any results from previous completed tasks up, and including, this task. This loop ensures + // that the collected results for each task are published to the parent stream in the correct order. + for { + if !ls.completedTaskIndexes[ls.toPublishTaskIndex] { + return nil + } + + if ls.toPublishTaskIndex == 0 { + // Remove the already emitted data from the overall limits. + publishedCount, err := safecast.ToUint32(ls.countingStream.PublishedCount()) + if err != nil { + return spiceerrors.MustBugf("cannot cast published count to uint32: %v", err) + } + if err := ls.ci.limits.markAlreadyPublished(publishedCount); err != nil { + return err + } + + if ls.ci.limits.hasExhaustedLimit() { + ls.cancelRemainingDispatches() + } + } else { + // Publish, to the parent stream, the results produced by the task and stored in the child stream. + childStream := ls.childStreams[ls.toPublishTaskIndex] + for _, result := range childStream.Results() { + if !ls.ci.limits.prepareForPublishing() { + ls.cancelRemainingDispatches() + return nil + } + + err := ls.parentStream.Publish(result) + if err != nil { + return err + } + } + ls.childStreams[ls.toPublishTaskIndex] = nil + } + + ls.toPublishTaskIndex++ + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/doc.go b/vendor/github.com/authzed/spicedb/internal/graph/doc.go new file mode 100644 index 0000000..904b216 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/doc.go @@ -0,0 +1,2 @@ +// Package graph contains the code to traverse a relationship graph to solve requests like Checks, Expansions and Lookups. +package graph diff --git a/vendor/github.com/authzed/spicedb/internal/graph/errors.go b/vendor/github.com/authzed/spicedb/internal/graph/errors.go new file mode 100644 index 0000000..31577d9 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/errors.go @@ -0,0 +1,213 @@ +package graph + +import ( + "errors" + "fmt" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/sharederrors" + dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// CheckFailureError occurs when check failed in some manner. Note this should not apply to +// namespaces and relations not being found. +type CheckFailureError struct { + error +} + +func (e CheckFailureError) Unwrap() error { + return e.error +} + +// NewCheckFailureErr constructs a new check failed error. +func NewCheckFailureErr(baseErr error) error { + return CheckFailureError{ + error: fmt.Errorf("error performing check: %w", baseErr), + } +} + +// ExpansionFailureError occurs when expansion failed in some manner. Note this should not apply to +// namespaces and relations not being found. +type ExpansionFailureError struct { + error +} + +func (e ExpansionFailureError) Unwrap() error { + return e.error +} + +// NewExpansionFailureErr constructs a new expansion failed error. +func NewExpansionFailureErr(baseErr error) error { + return ExpansionFailureError{ + error: fmt.Errorf("error performing expand: %w", baseErr), + } +} + +// AlwaysFailError is returned when an internal error leads to an operation +// guaranteed to fail. +type AlwaysFailError struct { + error +} + +// NewAlwaysFailErr constructs a new always fail error. +func NewAlwaysFailErr() error { + return AlwaysFailError{ + error: errors.New("always fail"), + } +} + +// RelationNotFoundError occurs when a relation was not found under a namespace. +type RelationNotFoundError struct { + error + namespaceName string + relationName string +} + +// NamespaceName returns the name of the namespace in which the relation was not found. +func (err RelationNotFoundError) NamespaceName() string { + return err.namespaceName +} + +// NotFoundRelationName returns the name of the relation not found. +func (err RelationNotFoundError) NotFoundRelationName() string { + return err.relationName +} + +func (err RelationNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err RelationNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + "relation_or_permission_name": err.relationName, + } +} + +// NewRelationNotFoundErr constructs a new relation not found error. +func NewRelationNotFoundErr(nsName string, relationName string) error { + return RelationNotFoundError{ + error: fmt.Errorf("relation/permission `%s` not found under definition `%s`", relationName, nsName), + namespaceName: nsName, + relationName: relationName, + } +} + +var _ sharederrors.UnknownRelationError = RelationNotFoundError{} + +// RelationMissingTypeInfoError defines an error for when type information is missing from a relation +// during a lookup. +type RelationMissingTypeInfoError struct { + error + namespaceName string + relationName string +} + +// NamespaceName returns the name of the namespace in which the relation was found. +func (err RelationMissingTypeInfoError) NamespaceName() string { + return err.namespaceName +} + +// RelationName returns the name of the relation missing type information. +func (err RelationMissingTypeInfoError) RelationName() string { + return err.relationName +} + +func (err RelationMissingTypeInfoError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("namespace", err.namespaceName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err RelationMissingTypeInfoError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.namespaceName, + "relation_name": err.relationName, + } +} + +// NewRelationMissingTypeInfoErr constructs a new relation not missing type information error. +func NewRelationMissingTypeInfoErr(nsName string, relationName string) error { + return RelationMissingTypeInfoError{ + error: fmt.Errorf("relation/permission `%s` under definition `%s` is missing type information", relationName, nsName), + namespaceName: nsName, + relationName: relationName, + } +} + +// WildcardNotAllowedError occurs when a request sent has an invalid wildcard argument. +type WildcardNotAllowedError struct { + error + + fieldName string +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err WildcardNotAllowedError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_WILDCARD_NOT_ALLOWED, + map[string]string{ + "field": err.fieldName, + }, + ), + ) +} + +// NewWildcardNotAllowedErr constructs an error indicating that a wildcard was not allowed. +func NewWildcardNotAllowedErr(message string, fieldName string) error { + return WildcardNotAllowedError{ + error: fmt.Errorf("invalid argument: %s", message), + fieldName: fieldName, + } +} + +// UnimplementedError is returned when some functionality is not yet supported. +type UnimplementedError struct { + error +} + +// NewUnimplementedErr constructs a new unimplemented error. +func NewUnimplementedErr(baseErr error) error { + return UnimplementedError{ + error: baseErr, + } +} + +func (e UnimplementedError) Unwrap() error { + return e.error +} + +// InvalidCursorError is returned when a cursor is no longer valid. +type InvalidCursorError struct { + error +} + +// NewInvalidCursorErr constructs a new unimplemented error. +func NewInvalidCursorErr(dispatchCursorVersion uint32, cursor *dispatch.Cursor) error { + return InvalidCursorError{ + error: fmt.Errorf("the supplied cursor is no longer valid: found version %d, expected version %d", cursor.DispatchVersion, dispatchCursorVersion), + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err InvalidCursorError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_INVALID_CURSOR, + map[string]string{ + "details": "cursor was used against an incompatible version of SpiceDB", + }, + ), + ) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/expand.go b/vendor/github.com/authzed/spicedb/internal/graph/expand.go new file mode 100644 index 0000000..9418bec --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/expand.go @@ -0,0 +1,436 @@ +package graph + +import ( + "context" + "errors" + + "github.com/authzed/spicedb/internal/caveats" + + "github.com/authzed/spicedb/internal/dispatch" + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// NewConcurrentExpander creates an instance of ConcurrentExpander +func NewConcurrentExpander(d dispatch.Expand) *ConcurrentExpander { + return &ConcurrentExpander{d: d} +} + +// ConcurrentExpander exposes a method to perform Expand requests, and delegates subproblems to the +// provided dispatch.Expand instance. +type ConcurrentExpander struct { + d dispatch.Expand +} + +// ValidatedExpandRequest represents a request after it has been validated and parsed for internal +// consumption. +type ValidatedExpandRequest struct { + *v1.DispatchExpandRequest + Revision datastore.Revision +} + +// Expand performs an expand request with the provided request and context. +func (ce *ConcurrentExpander) Expand(ctx context.Context, req ValidatedExpandRequest, relation *core.Relation) (*v1.DispatchExpandResponse, error) { + log.Ctx(ctx).Trace().Object("expand", req).Send() + + var directFunc ReduceableExpandFunc + if relation.UsersetRewrite == nil { + directFunc = ce.expandDirect(ctx, req) + } else { + directFunc = ce.expandUsersetRewrite(ctx, req, relation.UsersetRewrite) + } + + resolved := expandOne(ctx, directFunc) + resolved.Resp.Metadata = addCallToResponseMetadata(resolved.Resp.Metadata) + return resolved.Resp, resolved.Err +} + +func (ce *ConcurrentExpander) expandDirect( + ctx context.Context, + req ValidatedExpandRequest, +) ReduceableExpandFunc { + log.Ctx(ctx).Trace().Object("direct", req).Send() + return func(ctx context.Context, resultChan chan<- ExpandResult) { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: req.ResourceAndRelation.Namespace, + OptionalResourceIds: []string{req.ResourceAndRelation.ObjectId}, + OptionalResourceRelation: req.ResourceAndRelation.Relation, + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata) + return + } + + var foundNonTerminalUsersets []*core.DirectSubject + var foundTerminalUsersets []*core.DirectSubject + for rel, err := range it { + if err != nil { + resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata) + return + } + + ds := &core.DirectSubject{ + Subject: rel.Subject.ToCoreONR(), + CaveatExpression: caveats.CaveatAsExpr(rel.OptionalCaveat), + } + + if rel.Subject.Relation == Ellipsis { + foundTerminalUsersets = append(foundTerminalUsersets, ds) + } else { + foundNonTerminalUsersets = append(foundNonTerminalUsersets, ds) + } + } + + // If only shallow expansion was required, or there are no non-terminal subjects found, + // nothing more to do. + if req.ExpansionMode == v1.DispatchExpandRequest_SHALLOW || len(foundNonTerminalUsersets) == 0 { + resultChan <- expandResult( + &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_LeafNode{ + LeafNode: &core.DirectSubjects{ + Subjects: append(foundTerminalUsersets, foundNonTerminalUsersets...), + }, + }, + Expanded: req.ResourceAndRelation, + }, + emptyMetadata, + ) + return + } + + // Otherwise, recursively issue expansion and collect the results from that, plus the + // found terminals together. + var requestsToDispatch []ReduceableExpandFunc + for _, nonTerminalUser := range foundNonTerminalUsersets { + toDispatch := ce.dispatch(ValidatedExpandRequest{ + &v1.DispatchExpandRequest{ + ResourceAndRelation: nonTerminalUser.Subject, + Metadata: decrementDepth(req.Metadata), + ExpansionMode: req.ExpansionMode, + }, + req.Revision, + }) + + requestsToDispatch = append(requestsToDispatch, decorateWithCaveatIfNecessary(toDispatch, nonTerminalUser.CaveatExpression)) + } + + result := expandAny(ctx, req.ResourceAndRelation, requestsToDispatch) + if result.Err != nil { + resultChan <- result + return + } + + unionNode := result.Resp.TreeNode.GetIntermediateNode() + unionNode.ChildNodes = append(unionNode.ChildNodes, &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_LeafNode{ + LeafNode: &core.DirectSubjects{ + Subjects: append(foundTerminalUsersets, foundNonTerminalUsersets...), + }, + }, + Expanded: req.ResourceAndRelation, + }) + resultChan <- result + } +} + +func decorateWithCaveatIfNecessary(toDispatch ReduceableExpandFunc, caveatExpr *core.CaveatExpression) ReduceableExpandFunc { + // If no caveat expression, simply return the func unmodified. + if caveatExpr == nil { + return toDispatch + } + + // Otherwise return a wrapped function that expands the underlying func to be dispatched, and then decorates + // the resulting node with the caveat expression. + // + // TODO(jschorr): This will generate a lot of function closures, so we should change Expand to avoid them + // like we did in Check. + return func(ctx context.Context, resultChan chan<- ExpandResult) { + result := expandOne(ctx, toDispatch) + if result.Err != nil { + resultChan <- result + return + } + + result.Resp.TreeNode.CaveatExpression = caveatExpr + resultChan <- result + } +} + +func (ce *ConcurrentExpander) expandUsersetRewrite(ctx context.Context, req ValidatedExpandRequest, usr *core.UsersetRewrite) ReduceableExpandFunc { + switch rw := usr.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + log.Ctx(ctx).Trace().Msg("union") + return ce.expandSetOperation(ctx, req, rw.Union, expandAny) + case *core.UsersetRewrite_Intersection: + log.Ctx(ctx).Trace().Msg("intersection") + return ce.expandSetOperation(ctx, req, rw.Intersection, expandAll) + case *core.UsersetRewrite_Exclusion: + log.Ctx(ctx).Trace().Msg("exclusion") + return ce.expandSetOperation(ctx, req, rw.Exclusion, expandDifference) + default: + return alwaysFailExpand + } +} + +func (ce *ConcurrentExpander) expandSetOperation(ctx context.Context, req ValidatedExpandRequest, so *core.SetOperation, reducer ExpandReducer) ReduceableExpandFunc { + var requests []ReduceableExpandFunc + for _, childOneof := range so.Child { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_XThis: + return expandError(errors.New("use of _this is unsupported; please rewrite your schema")) + case *core.SetOperation_Child_ComputedUserset: + requests = append(requests, ce.expandComputedUserset(ctx, req, child.ComputedUserset, nil)) + case *core.SetOperation_Child_UsersetRewrite: + requests = append(requests, ce.expandUsersetRewrite(ctx, req, child.UsersetRewrite)) + case *core.SetOperation_Child_TupleToUserset: + requests = append(requests, expandTupleToUserset(ctx, ce, req, child.TupleToUserset, expandAny)) + case *core.SetOperation_Child_FunctionedTupleToUserset: + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + requests = append(requests, expandTupleToUserset(ctx, ce, req, child.FunctionedTupleToUserset, expandAny)) + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + requests = append(requests, expandTupleToUserset(ctx, ce, req, child.FunctionedTupleToUserset, expandAll)) + + default: + return expandError(spiceerrors.MustBugf("unknown function `%s` in expand", child.FunctionedTupleToUserset.Function)) + } + case *core.SetOperation_Child_XNil: + requests = append(requests, emptyExpansion(req.ResourceAndRelation)) + default: + return expandError(spiceerrors.MustBugf("unknown set operation child `%T` in expand", child)) + } + } + return func(ctx context.Context, resultChan chan<- ExpandResult) { + resultChan <- reducer(ctx, req.ResourceAndRelation, requests) + } +} + +func (ce *ConcurrentExpander) dispatch(req ValidatedExpandRequest) ReduceableExpandFunc { + return func(ctx context.Context, resultChan chan<- ExpandResult) { + log.Ctx(ctx).Trace().Object("dispatchExpand", req).Send() + result, err := ce.d.DispatchExpand(ctx, req.DispatchExpandRequest) + resultChan <- ExpandResult{result, err} + } +} + +func (ce *ConcurrentExpander) expandComputedUserset(ctx context.Context, req ValidatedExpandRequest, cu *core.ComputedUserset, rel *tuple.Relationship) ReduceableExpandFunc { + log.Ctx(ctx).Trace().Str("relation", cu.Relation).Msg("computed userset") + var start tuple.ObjectAndRelation + if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT { + if rel == nil { + return expandError(spiceerrors.MustBugf("computed userset for tupleset without tuple")) + } + + start = rel.Subject + } else if cu.Object == core.ComputedUserset_TUPLE_OBJECT { + if rel != nil { + start = rel.Resource + } else { + start = tuple.FromCoreObjectAndRelation(req.ResourceAndRelation) + } + } + + // Check if the target relation exists. If not, return nothing. + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + err := namespace.CheckNamespaceAndRelation(ctx, start.ObjectType, cu.Relation, true, ds) + if err != nil { + if errors.As(err, &namespace.RelationNotFoundError{}) { + return emptyExpansion(req.ResourceAndRelation) + } + + return expandError(err) + } + + return ce.dispatch(ValidatedExpandRequest{ + &v1.DispatchExpandRequest{ + ResourceAndRelation: &core.ObjectAndRelation{ + Namespace: start.ObjectType, + ObjectId: start.ObjectID, + Relation: cu.Relation, + }, + Metadata: decrementDepth(req.Metadata), + ExpansionMode: req.ExpansionMode, + }, + req.Revision, + }) +} + +type expandFunc func(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult + +func expandTupleToUserset[T relation]( + _ context.Context, + ce *ConcurrentExpander, + req ValidatedExpandRequest, + ttu ttu[T], + expandFunc expandFunc, +) ReduceableExpandFunc { + return func(ctx context.Context, resultChan chan<- ExpandResult) { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision) + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: req.ResourceAndRelation.Namespace, + OptionalResourceIds: []string{req.ResourceAndRelation.ObjectId}, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata) + return + } + + var requestsToDispatch []ReduceableExpandFunc + for rel, err := range it { + if err != nil { + resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata) + return + } + + toDispatch := ce.expandComputedUserset(ctx, req, ttu.GetComputedUserset(), &rel) + requestsToDispatch = append(requestsToDispatch, decorateWithCaveatIfNecessary(toDispatch, caveats.CaveatAsExpr(rel.OptionalCaveat))) + } + + resultChan <- expandFunc(ctx, req.ResourceAndRelation, requestsToDispatch) + } +} + +func setResult( + op core.SetOperationUserset_Operation, + start *core.ObjectAndRelation, + children []*core.RelationTupleTreeNode, + metadata *v1.ResponseMeta, +) ExpandResult { + return expandResult( + &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_IntermediateNode{ + IntermediateNode: &core.SetOperationUserset{ + Operation: op, + ChildNodes: children, + }, + }, + Expanded: start, + }, + metadata, + ) +} + +func expandSetOperation( + ctx context.Context, + start *core.ObjectAndRelation, + requests []ReduceableExpandFunc, + op core.SetOperationUserset_Operation, +) ExpandResult { + children := make([]*core.RelationTupleTreeNode, 0, len(requests)) + + if len(requests) == 0 { + return setResult(op, start, children, emptyMetadata) + } + + childCtx, cancelFn := context.WithCancel(ctx) + defer cancelFn() + + resultChans := make([]chan ExpandResult, 0, len(requests)) + for _, req := range requests { + resultChan := make(chan ExpandResult, 1) + resultChans = append(resultChans, resultChan) + go req(childCtx, resultChan) + } + + responseMetadata := emptyMetadata + for _, resultChan := range resultChans { + select { + case result := <-resultChan: + responseMetadata = combineResponseMetadata(ctx, responseMetadata, result.Resp.Metadata) + if result.Err != nil { + return expandResultError(result.Err, responseMetadata) + } + children = append(children, result.Resp.TreeNode) + case <-ctx.Done(): + return expandResultError(context.Canceled, responseMetadata) + } + } + + return setResult(op, start, children, responseMetadata) +} + +// emptyExpansion returns an empty expansion. +func emptyExpansion(start *core.ObjectAndRelation) ReduceableExpandFunc { + return func(ctx context.Context, resultChan chan<- ExpandResult) { + resultChan <- expandResult(&core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_LeafNode{ + LeafNode: &core.DirectSubjects{}, + }, + Expanded: start, + }, emptyMetadata) + } +} + +// expandError returns the error. +func expandError(err error) ReduceableExpandFunc { + return func(ctx context.Context, resultChan chan<- ExpandResult) { + resultChan <- expandResultError(err, emptyMetadata) + } +} + +// expandAll returns a tree with all of the children and an intersection node type. +func expandAll(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult { + return expandSetOperation(ctx, start, requests, core.SetOperationUserset_INTERSECTION) +} + +// expandAny returns a tree with all of the children and a union node type. +func expandAny(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult { + return expandSetOperation(ctx, start, requests, core.SetOperationUserset_UNION) +} + +// expandDifference returns a tree with all of the children and an exclusion node type. +func expandDifference(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult { + return expandSetOperation(ctx, start, requests, core.SetOperationUserset_EXCLUSION) +} + +// expandOne waits for exactly one response +func expandOne(ctx context.Context, request ReduceableExpandFunc) ExpandResult { + resultChan := make(chan ExpandResult, 1) + go request(ctx, resultChan) + + select { + case result := <-resultChan: + if result.Err != nil { + return result + } + return result + case <-ctx.Done(): + return expandResultError(context.Canceled, emptyMetadata) + } +} + +var errAlwaysFailExpand = errors.New("always fail") + +func alwaysFailExpand(_ context.Context, resultChan chan<- ExpandResult) { + resultChan <- expandResultError(errAlwaysFailExpand, emptyMetadata) +} + +func expandResult(treeNode *core.RelationTupleTreeNode, subProblemMetadata *v1.ResponseMeta) ExpandResult { + return ExpandResult{ + &v1.DispatchExpandResponse{ + Metadata: ensureMetadata(subProblemMetadata), + TreeNode: treeNode, + }, + nil, + } +} + +func expandResultError(err error, subProblemMetadata *v1.ResponseMeta) ExpandResult { + return ExpandResult{ + &v1.DispatchExpandResponse{ + Metadata: ensureMetadata(subProblemMetadata), + }, + err, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/graph.go b/vendor/github.com/authzed/spicedb/internal/graph/graph.go new file mode 100644 index 0000000..2a44189 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/graph.go @@ -0,0 +1,89 @@ +package graph + +import ( + "context" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" +) + +// Ellipsis relation is used to signify a semantic-free relationship. +const Ellipsis = "..." + +// CheckResult is the data that is returned by a single check or sub-check. +type CheckResult struct { + Resp *v1.DispatchCheckResponse + Err error +} + +func (cr CheckResult) ResultError() error { + return cr.Err +} + +// ExpandResult is the data that is returned by a single expand or sub-expand. +type ExpandResult struct { + Resp *v1.DispatchExpandResponse + Err error +} + +func (er ExpandResult) ResultError() error { + return er.Err +} + +// ReduceableExpandFunc is a function that can be bound to a execution context. +type ReduceableExpandFunc func(ctx context.Context, resultChan chan<- ExpandResult) + +// AlwaysFailExpand is a ReduceableExpandFunc which will always fail when reduced. +func AlwaysFailExpand(_ context.Context, resultChan chan<- ExpandResult) { + resultChan <- expandResultError(NewAlwaysFailErr(), emptyMetadata) +} + +// ExpandReducer is a type for the functions Any and All which combine check results. +type ExpandReducer func( + ctx context.Context, + start *core.ObjectAndRelation, + requests []ReduceableExpandFunc, +) ExpandResult + +func decrementDepth(md *v1.ResolverMeta) *v1.ResolverMeta { + return &v1.ResolverMeta{ + AtRevision: md.AtRevision, + DepthRemaining: md.DepthRemaining - 1, + TraversalBloom: md.TraversalBloom, + } +} + +var emptyMetadata = &v1.ResponseMeta{} + +func ensureMetadata(subProblemMetadata *v1.ResponseMeta) *v1.ResponseMeta { + if subProblemMetadata == nil { + subProblemMetadata = emptyMetadata + } + + return &v1.ResponseMeta{ + DispatchCount: subProblemMetadata.DispatchCount, + DepthRequired: subProblemMetadata.DepthRequired, + CachedDispatchCount: subProblemMetadata.CachedDispatchCount, + DebugInfo: subProblemMetadata.DebugInfo, + } +} + +func addCallToResponseMetadata(metadata *v1.ResponseMeta) *v1.ResponseMeta { + // + 1 for the current call. + return &v1.ResponseMeta{ + DispatchCount: metadata.DispatchCount + 1, + DepthRequired: metadata.DepthRequired + 1, + CachedDispatchCount: metadata.CachedDispatchCount, + DebugInfo: metadata.DebugInfo, + } +} + +func addAdditionalDepthRequired(metadata *v1.ResponseMeta) *v1.ResponseMeta { + return &v1.ResponseMeta{ + DispatchCount: metadata.DispatchCount, + DepthRequired: metadata.DepthRequired + 1, + CachedDispatchCount: metadata.CachedDispatchCount, + DebugInfo: metadata.DebugInfo, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go b/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go new file mode 100644 index 0000000..485a3bf --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/hints/checkhints.go @@ -0,0 +1,96 @@ +package hints + +import ( + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// CheckHintForComputedUserset creates a CheckHint for a relation and a subject. +func CheckHintForComputedUserset(resourceType string, resourceID string, relation string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) *v1.CheckHint { + return &v1.CheckHint{ + Resource: &core.ObjectAndRelation{ + Namespace: resourceType, + ObjectId: resourceID, + Relation: relation, + }, + Subject: subject.ToCoreONR(), + Result: result, + } +} + +// CheckHintForArrow creates a CheckHint for an arrow and a subject. +func CheckHintForArrow(resourceType string, resourceID string, tuplesetRelation string, computedUsersetRelation string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) *v1.CheckHint { + return &v1.CheckHint{ + Resource: &core.ObjectAndRelation{ + Namespace: resourceType, + ObjectId: resourceID, + Relation: tuplesetRelation, + }, + TtuComputedUsersetRelation: computedUsersetRelation, + Subject: subject.ToCoreONR(), + Result: result, + } +} + +// AsCheckHintForComputedUserset returns the resourceID if the checkHint is for the given relation and subject. +func AsCheckHintForComputedUserset(checkHint *v1.CheckHint, resourceType string, relationName string, subject tuple.ObjectAndRelation) (string, bool) { + if checkHint.TtuComputedUsersetRelation != "" { + return "", false + } + + if checkHint.Resource.Namespace == resourceType && checkHint.Resource.Relation == relationName && checkHint.Subject.EqualVT(subject.ToCoreONR()) { + return checkHint.Resource.ObjectId, true + } + + return "", false +} + +// AsCheckHintForArrow returns the resourceID if the checkHint is for the given arrow and subject. +func AsCheckHintForArrow(checkHint *v1.CheckHint, resourceType string, tuplesetRelation string, computedUsersetRelation string, subject tuple.ObjectAndRelation) (string, bool) { + if checkHint.TtuComputedUsersetRelation != computedUsersetRelation { + return "", false + } + + if checkHint.Resource.Namespace == resourceType && checkHint.Resource.Relation == tuplesetRelation && checkHint.Subject.EqualVT(subject.ToCoreONR()) { + return checkHint.Resource.ObjectId, true + } + + return "", false +} + +// HintForEntrypoint returns a CheckHint for the given reachability graph entrypoint and associated subject and result. +func HintForEntrypoint(re schema.ReachabilityEntrypoint, resourceID string, subject tuple.ObjectAndRelation, result *v1.ResourceCheckResult) (*v1.CheckHint, error) { + switch re.EntrypointKind() { + case core.ReachabilityEntrypoint_RELATION_ENTRYPOINT: + return nil, spiceerrors.MustBugf("cannot call CheckHintForResource for kind %v", re.EntrypointKind()) + + case core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT: + namespace := re.TargetNamespace() + tuplesetRelation, err := re.TuplesetRelation() + if err != nil { + return nil, err + } + + computedUsersetRelation, err := re.ComputedUsersetRelation() + if err != nil { + return nil, err + } + + return CheckHintForArrow(namespace, resourceID, tuplesetRelation, computedUsersetRelation, subject, result), nil + + case core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT: + namespace := re.TargetNamespace() + relation, err := re.ComputedUsersetRelation() + if err != nil { + return nil, err + } + + return CheckHintForComputedUserset(namespace, resourceID, relation, subject, result), nil + + default: + return nil, spiceerrors.MustBugf("unknown relation entrypoint kind") + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/limits.go b/vendor/github.com/authzed/spicedb/internal/graph/limits.go new file mode 100644 index 0000000..6b2e2bd --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/limits.go @@ -0,0 +1,80 @@ +package graph + +import ( + "fmt" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +var ErrLimitReached = fmt.Errorf("limit has been reached") + +// limitTracker is a helper struct for tracking the limit requested by a caller and decrementing +// that limit as results are published. +type limitTracker struct { + hasLimit bool + currentLimit uint32 +} + +// newLimitTracker creates a new limit tracker, returning the tracker. +func newLimitTracker(optionalLimit uint32) *limitTracker { + return &limitTracker{ + currentLimit: optionalLimit, + hasLimit: optionalLimit > 0, + } +} + +// clone creates a copy of the limitTracker, inheriting the current limit. +func (lt *limitTracker) clone() *limitTracker { + return &limitTracker{ + currentLimit: lt.currentLimit, + hasLimit: lt.hasLimit, + } +} + +// prepareForPublishing asks the limit tracker to remove an element from the limit requested, +// returning whether that element can be published. +// +// Example usage: +// +// okay := limits.prepareForPublishing() +// if okay { ... publish ... } +func (lt *limitTracker) prepareForPublishing() bool { + // if there is no limit defined, then the count is always allowed. + if !lt.hasLimit { + return true + } + + // if the limit has been reached, allow no further items to be published. + if lt.currentLimit == 0 { + return false + } + + // otherwise, remove the element from the limit. + lt.currentLimit-- + return true +} + +// markAlreadyPublished marks that the given count of results has already been published. If the count is +// greater than the limit, returns a spiceerror. +func (lt *limitTracker) markAlreadyPublished(count uint32) error { + if !lt.hasLimit { + return nil + } + + if count > lt.currentLimit { + return spiceerrors.MustBugf("given published count of %d exceeds the remaining limit of %d", count, lt.currentLimit) + } + + lt.currentLimit -= count + if lt.currentLimit == 0 { + return nil + } + + return nil +} + +// hasExhaustedLimit returns true if the limit has been reached and all items allowable have been +// published. +func (lt *limitTracker) hasExhaustedLimit() bool { + return lt.hasLimit && lt.currentLimit == 0 +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go b/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go new file mode 100644 index 0000000..57acb49 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/lookupresources2.go @@ -0,0 +1,681 @@ +package graph + +import ( + "context" + "slices" + "sort" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/authzed/spicedb/internal/caveats" + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph/computed" + "github.com/authzed/spicedb/internal/graph/hints" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// dispatchVersion defines the "version" of this dispatcher. Must be incremented +// anytime an incompatible change is made to the dispatcher itself or its cursor +// production. +const dispatchVersion = 1 + +func NewCursoredLookupResources2(dl dispatch.LookupResources2, dc dispatch.Check, caveatTypeSet *caveattypes.TypeSet, concurrencyLimit uint16, dispatchChunkSize uint16) *CursoredLookupResources2 { + return &CursoredLookupResources2{dl, dc, caveatTypeSet, concurrencyLimit, dispatchChunkSize} +} + +type CursoredLookupResources2 struct { + dl dispatch.LookupResources2 + dc dispatch.Check + caveatTypeSet *caveattypes.TypeSet + concurrencyLimit uint16 + dispatchChunkSize uint16 +} + +type ValidatedLookupResources2Request struct { + *v1.DispatchLookupResources2Request + Revision datastore.Revision +} + +func (crr *CursoredLookupResources2) LookupResources2( + req ValidatedLookupResources2Request, + stream dispatch.LookupResources2Stream, +) error { + ctx, span := tracer.Start(stream.Context(), "lookupResources2") + defer span.End() + + if req.TerminalSubject == nil { + return spiceerrors.MustBugf("no terminal subject given to lookup resources dispatch") + } + + if slices.Contains(req.SubjectIds, tuple.PublicWildcard) { + return NewWildcardNotAllowedErr("cannot perform lookup resources on wildcard", "subject_id") + } + + if len(req.SubjectIds) == 0 { + return spiceerrors.MustBugf("no subjects ids given to lookup resources dispatch") + } + + // Sort for stability. + if len(req.SubjectIds) > 1 { + sort.Strings(req.SubjectIds) + } + + limits := newLimitTracker(req.OptionalLimit) + ci, err := newCursorInformation(req.OptionalCursor, limits, dispatchVersion) + if err != nil { + return err + } + + return withSubsetInCursor(ci, + func(currentOffset int, nextCursorWith afterResponseCursor) error { + // If the resource type matches the subject type, yield directly as a one-to-one result + // for each subjectID. + if req.SubjectRelation.Namespace == req.ResourceRelation.Namespace && + req.SubjectRelation.Relation == req.ResourceRelation.Relation { + for index, subjectID := range req.SubjectIds { + if index < currentOffset { + continue + } + + if !ci.limits.prepareForPublishing() { + return nil + } + + err := stream.Publish(&v1.DispatchLookupResources2Response{ + Resource: &v1.PossibleResource{ + ResourceId: subjectID, + ForSubjectIds: []string{subjectID}, + }, + Metadata: emptyMetadata, + AfterResponseCursor: nextCursorWith(index + 1), + }) + if err != nil { + return err + } + } + } + return nil + }, func(ci cursorInformation) error { + // Once done checking for the matching subject type, yield by dispatching over entrypoints. + return crr.afterSameType(ctx, ci, req, stream) + }) +} + +func (crr *CursoredLookupResources2) afterSameType( + ctx context.Context, + ci cursorInformation, + req ValidatedLookupResources2Request, + parentStream dispatch.LookupResources2Stream, +) error { + reachabilityForString := req.ResourceRelation.Namespace + "#" + req.ResourceRelation.Relation + ctx, span := tracer.Start(ctx, "reachability: "+reachabilityForString) + defer span.End() + + dispatched := NewSyncONRSet() + + // Load the type system and reachability graph to find the entrypoints for the reachability. + ds := datastoremw.MustFromContext(ctx) + reader := ds.SnapshotReader(req.Revision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(reader)) + vdef, err := ts.GetValidatedDefinition(ctx, req.ResourceRelation.Namespace) + if err != nil { + return err + } + + rg := vdef.Reachability() + entrypoints, err := rg.FirstEntrypointsForSubjectToResource(ctx, &core.RelationReference{ + Namespace: req.SubjectRelation.Namespace, + Relation: req.SubjectRelation.Relation, + }, req.ResourceRelation) + if err != nil { + return err + } + + // For each entrypoint, load the necessary data and re-dispatch if a subproblem was found. + return withParallelizedStreamingIterableInCursor(ctx, ci, entrypoints, parentStream, crr.concurrencyLimit, + func(ctx context.Context, ci cursorInformation, entrypoint schema.ReachabilityEntrypoint, stream dispatch.LookupResources2Stream) error { + ds, err := entrypoint.DebugString() + spiceerrors.DebugAssert(func() bool { + return err == nil + }, "Error in entrypoint.DebugString()") + ctx, span := tracer.Start(ctx, "entrypoint: "+ds, trace.WithAttributes()) + defer span.End() + + switch entrypoint.EntrypointKind() { + case core.ReachabilityEntrypoint_RELATION_ENTRYPOINT: + return crr.lookupRelationEntrypoint(ctx, ci, entrypoint, rg, ts, reader, req, stream, dispatched) + + case core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT: + containingRelation := entrypoint.ContainingRelationOrPermission() + rewrittenSubjectRelation := &core.RelationReference{ + Namespace: containingRelation.Namespace, + Relation: containingRelation.Relation, + } + + rsm := subjectIDsToResourcesMap2(rewrittenSubjectRelation, req.SubjectIds) + drsm := rsm.asReadOnly() + + return crr.redispatchOrReport( + ctx, + ci, + rewrittenSubjectRelation, + drsm, + rg, + entrypoint, + stream, + req, + dispatched, + ) + + case core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT: + return crr.lookupTTUEntrypoint(ctx, ci, entrypoint, rg, ts, reader, req, stream, dispatched) + + default: + return spiceerrors.MustBugf("Unknown kind of entrypoint: %v", entrypoint.EntrypointKind()) + } + }) +} + +func (crr *CursoredLookupResources2) lookupRelationEntrypoint( + ctx context.Context, + ci cursorInformation, + entrypoint schema.ReachabilityEntrypoint, + rg *schema.DefinitionReachability, + ts *schema.TypeSystem, + reader datastore.Reader, + req ValidatedLookupResources2Request, + stream dispatch.LookupResources2Stream, + dispatched *syncONRSet, +) error { + relationReference, err := entrypoint.DirectRelation() + if err != nil { + return err + } + + relDefinition, err := ts.GetValidatedDefinition(ctx, relationReference.Namespace) + if err != nil { + return err + } + + // Build the list of subjects to lookup based on the type information available. + isDirectAllowed, err := relDefinition.IsAllowedDirectRelation( + relationReference.Relation, + req.SubjectRelation.Namespace, + req.SubjectRelation.Relation, + ) + if err != nil { + return err + } + + subjectIds := make([]string, 0, len(req.SubjectIds)+1) + if isDirectAllowed == schema.DirectRelationValid { + subjectIds = append(subjectIds, req.SubjectIds...) + } + + if req.SubjectRelation.Relation == tuple.Ellipsis { + isWildcardAllowed, err := relDefinition.IsAllowedPublicNamespace(relationReference.Relation, req.SubjectRelation.Namespace) + if err != nil { + return err + } + + if isWildcardAllowed == schema.PublicSubjectAllowed { + subjectIds = append(subjectIds, "*") + } + } + + // Lookup the subjects and then redispatch/report results. + relationFilter := datastore.SubjectRelationFilter{ + NonEllipsisRelation: req.SubjectRelation.Relation, + } + + if req.SubjectRelation.Relation == tuple.Ellipsis { + relationFilter = datastore.SubjectRelationFilter{ + IncludeEllipsisRelation: true, + } + } + + subjectsFilter := datastore.SubjectsFilter{ + SubjectType: req.SubjectRelation.Namespace, + OptionalSubjectIds: subjectIds, + RelationFilter: relationFilter, + } + + return crr.redispatchOrReportOverDatabaseQuery( + ctx, + redispatchOverDatabaseConfig2{ + ci: ci, + ts: ts, + reader: reader, + subjectsFilter: subjectsFilter, + sourceResourceType: relationReference, + foundResourceType: relationReference, + entrypoint: entrypoint, + rg: rg, + concurrencyLimit: crr.concurrencyLimit, + parentStream: stream, + parentRequest: req, + dispatched: dispatched, + }, + ) +} + +type redispatchOverDatabaseConfig2 struct { + ci cursorInformation + + ts *schema.TypeSystem + + // Direct reader for reverse ReverseQueryRelationships + reader datastore.Reader + + subjectsFilter datastore.SubjectsFilter + sourceResourceType *core.RelationReference + foundResourceType *core.RelationReference + + entrypoint schema.ReachabilityEntrypoint + rg *schema.DefinitionReachability + + concurrencyLimit uint16 + parentStream dispatch.LookupResources2Stream + parentRequest ValidatedLookupResources2Request + dispatched *syncONRSet +} + +func (crr *CursoredLookupResources2) redispatchOrReportOverDatabaseQuery( + ctx context.Context, + config redispatchOverDatabaseConfig2, +) error { + ctx, span := tracer.Start(ctx, "datastorequery", trace.WithAttributes( + attribute.String("source-resource-type-namespace", config.sourceResourceType.Namespace), + attribute.String("source-resource-type-relation", config.sourceResourceType.Relation), + attribute.String("subjects-filter-subject-type", config.subjectsFilter.SubjectType), + attribute.Int("subjects-filter-subject-ids-count", len(config.subjectsFilter.OptionalSubjectIds)), + )) + defer span.End() + + return withDatastoreCursorInCursor(ctx, config.ci, config.parentStream, config.concurrencyLimit, + // Find the target resources for the subject. + func(queryCursor options.Cursor) ([]itemAndPostCursor[dispatchableResourcesSubjectMap2], error) { + it, err := config.reader.ReverseQueryRelationships( + ctx, + config.subjectsFilter, + options.WithResRelation(&options.ResourceRelation{ + Namespace: config.sourceResourceType.Namespace, + Relation: config.sourceResourceType.Relation, + }), + options.WithSortForReverse(options.BySubject), + options.WithAfterForReverse(queryCursor), + options.WithQueryShapeForReverse(queryshape.MatchingResourcesForSubject), + ) + if err != nil { + return nil, err + } + + // Chunk based on the FilterMaximumIDCount, to ensure we never send more than that amount of + // results to a downstream dispatch. + rsm := newResourcesSubjectMap2WithCapacity(config.sourceResourceType, uint32(crr.dispatchChunkSize)) + toBeHandled := make([]itemAndPostCursor[dispatchableResourcesSubjectMap2], 0) + currentCursor := queryCursor + caveatRunner := caveats.NewCaveatRunner(crr.caveatTypeSet) + + for rel, err := range it { + if err != nil { + return nil, err + } + + var missingContextParameters []string + + // If a caveat exists on the relationship, run it and filter the results, marking those that have missing context. + if rel.OptionalCaveat != nil && rel.OptionalCaveat.CaveatName != "" { + caveatExpr := caveats.CaveatAsExpr(rel.OptionalCaveat) + runResult, err := caveatRunner.RunCaveatExpression(ctx, caveatExpr, config.parentRequest.Context.AsMap(), config.reader, caveats.RunCaveatExpressionNoDebugging) + if err != nil { + return nil, err + } + + // If a partial result is returned, collect the missing context parameters. + if runResult.IsPartial() { + missingNames, err := runResult.MissingVarNames() + if err != nil { + return nil, err + } + + missingContextParameters = missingNames + } else if !runResult.Value() { + // If the run result shows the caveat does not apply, skip. This shears the tree of results early. + continue + } + } + + if err := rsm.addRelationship(rel, missingContextParameters); err != nil { + return nil, err + } + + if rsm.len() == int(crr.dispatchChunkSize) { + toBeHandled = append(toBeHandled, itemAndPostCursor[dispatchableResourcesSubjectMap2]{ + item: rsm.asReadOnly(), + cursor: currentCursor, + }) + rsm = newResourcesSubjectMap2WithCapacity(config.sourceResourceType, uint32(crr.dispatchChunkSize)) + currentCursor = options.ToCursor(rel) + } + } + + if rsm.len() > 0 { + toBeHandled = append(toBeHandled, itemAndPostCursor[dispatchableResourcesSubjectMap2]{ + item: rsm.asReadOnly(), + cursor: currentCursor, + }) + } + + return toBeHandled, nil + }, + + // Redispatch or report the results. + func( + ctx context.Context, + ci cursorInformation, + drsm dispatchableResourcesSubjectMap2, + currentStream dispatch.LookupResources2Stream, + ) error { + return crr.redispatchOrReport( + ctx, + ci, + config.foundResourceType, + drsm, + config.rg, + config.entrypoint, + currentStream, + config.parentRequest, + config.dispatched, + ) + }, + ) +} + +func (crr *CursoredLookupResources2) lookupTTUEntrypoint(ctx context.Context, + ci cursorInformation, + entrypoint schema.ReachabilityEntrypoint, + rg *schema.DefinitionReachability, + ts *schema.TypeSystem, + reader datastore.Reader, + req ValidatedLookupResources2Request, + stream dispatch.LookupResources2Stream, + dispatched *syncONRSet, +) error { + containingRelation := entrypoint.ContainingRelationOrPermission() + + ttuDef, err := ts.GetValidatedDefinition(ctx, containingRelation.Namespace) + if err != nil { + return err + } + + tuplesetRelation, err := entrypoint.TuplesetRelation() + if err != nil { + return err + } + + // Determine whether this TTU should be followed, which will be the case if the subject relation's namespace + // is allowed in any form on the relation; since arrows ignore the subject's relation (if any), we check + // for the subject namespace as a whole. + allowedRelations, err := ttuDef.GetAllowedDirectNamespaceSubjectRelations(tuplesetRelation, req.SubjectRelation.Namespace) + if err != nil { + return err + } + + if allowedRelations == nil { + return nil + } + + // Search for the resolved subjects in the tupleset of the TTU. + subjectsFilter := datastore.SubjectsFilter{ + SubjectType: req.SubjectRelation.Namespace, + OptionalSubjectIds: req.SubjectIds, + } + + // Optimization: if there is a single allowed relation, pass it as a subject relation filter to make things faster + // on querying. + if allowedRelations.Len() == 1 { + allowedRelationName := allowedRelations.AsSlice()[0] + subjectsFilter.RelationFilter = datastore.SubjectRelationFilter{}.WithRelation(allowedRelationName) + } + + tuplesetRelationReference := &core.RelationReference{ + Namespace: containingRelation.Namespace, + Relation: tuplesetRelation, + } + + return crr.redispatchOrReportOverDatabaseQuery( + ctx, + redispatchOverDatabaseConfig2{ + ci: ci, + ts: ts, + reader: reader, + subjectsFilter: subjectsFilter, + sourceResourceType: tuplesetRelationReference, + foundResourceType: containingRelation, + entrypoint: entrypoint, + rg: rg, + parentStream: stream, + parentRequest: req, + dispatched: dispatched, + }, + ) +} + +type possibleResourceAndIndex struct { + resource *v1.PossibleResource + index int +} + +// redispatchOrReport checks if further redispatching is necessary for the found resource +// type. If not, and the found resource type+relation matches the target resource type+relation, +// the resource is reported to the parent stream. +func (crr *CursoredLookupResources2) redispatchOrReport( + ctx context.Context, + ci cursorInformation, + foundResourceType *core.RelationReference, + foundResources dispatchableResourcesSubjectMap2, + rg *schema.DefinitionReachability, + entrypoint schema.ReachabilityEntrypoint, + parentStream dispatch.LookupResources2Stream, + parentRequest ValidatedLookupResources2Request, + dispatched *syncONRSet, +) error { + if foundResources.isEmpty() { + // Nothing more to do. + return nil + } + + ctx, span := tracer.Start(ctx, "redispatchOrReport", trace.WithAttributes( + attribute.Int("found-resources-count", foundResources.len()), + )) + defer span.End() + + // Check for entrypoints for the new found resource type. + hasResourceEntrypoints, err := rg.HasOptimizedEntrypointsForSubjectToResource(ctx, foundResourceType, parentRequest.ResourceRelation) + if err != nil { + return err + } + + return withSubsetInCursor(ci, + func(currentOffset int, nextCursorWith afterResponseCursor) error { + if !hasResourceEntrypoints { + // If the found resource matches the target resource type and relation, potentially yield the resource. + if foundResourceType.Namespace == parentRequest.ResourceRelation.Namespace && foundResourceType.Relation == parentRequest.ResourceRelation.Relation { + resources := foundResources.asPossibleResources() + if len(resources) == 0 { + return nil + } + + if currentOffset >= len(resources) { + return nil + } + + offsetted := resources[currentOffset:] + if len(offsetted) == 0 { + return nil + } + + filtered := make([]possibleResourceAndIndex, 0, len(offsetted)) + for index, resource := range offsetted { + filtered = append(filtered, possibleResourceAndIndex{ + resource: resource, + index: index, + }) + } + + metadata := emptyMetadata + + // If the entrypoint is not a direct result, issue a check to further filter the results on the intersection or exclusion. + if !entrypoint.IsDirectResult() { + resourceIDs := make([]string, 0, len(offsetted)) + checkHints := make([]*v1.CheckHint, 0, len(offsetted)) + for _, resource := range offsetted { + resourceIDs = append(resourceIDs, resource.ResourceId) + + checkHint, err := hints.HintForEntrypoint( + entrypoint, + resource.ResourceId, + tuple.FromCoreObjectAndRelation(parentRequest.TerminalSubject), + &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_MEMBER, + }) + if err != nil { + return err + } + checkHints = append(checkHints, checkHint) + } + + resultsByResourceID, checkMetadata, _, err := computed.ComputeBulkCheck(ctx, crr.dc, crr.caveatTypeSet, computed.CheckParameters{ + ResourceType: tuple.FromCoreRelationReference(parentRequest.ResourceRelation), + Subject: tuple.FromCoreObjectAndRelation(parentRequest.TerminalSubject), + CaveatContext: parentRequest.Context.AsMap(), + AtRevision: parentRequest.Revision, + MaximumDepth: parentRequest.Metadata.DepthRemaining - 1, + DebugOption: computed.NoDebugging, + CheckHints: checkHints, + }, resourceIDs, crr.dispatchChunkSize) + if err != nil { + return err + } + + metadata = addCallToResponseMetadata(checkMetadata) + + filtered = make([]possibleResourceAndIndex, 0, len(offsetted)) + for index, resource := range offsetted { + result, ok := resultsByResourceID[resource.ResourceId] + if !ok { + continue + } + + switch result.Membership { + case v1.ResourceCheckResult_MEMBER: + filtered = append(filtered, possibleResourceAndIndex{ + resource: resource, + index: index, + }) + + case v1.ResourceCheckResult_CAVEATED_MEMBER: + missingContextParams := mapz.NewSet(result.MissingExprFields...) + missingContextParams.Extend(resource.MissingContextParams) + + filtered = append(filtered, possibleResourceAndIndex{ + resource: &v1.PossibleResource{ + ResourceId: resource.ResourceId, + ForSubjectIds: resource.ForSubjectIds, + MissingContextParams: missingContextParams.AsSlice(), + }, + index: index, + }) + + case v1.ResourceCheckResult_NOT_MEMBER: + // Skip. + + default: + return spiceerrors.MustBugf("unexpected result from check: %v", result.Membership) + } + } + } + + for _, resourceAndIndex := range filtered { + if !ci.limits.prepareForPublishing() { + return nil + } + + err := parentStream.Publish(&v1.DispatchLookupResources2Response{ + Resource: resourceAndIndex.resource, + Metadata: metadata, + AfterResponseCursor: nextCursorWith(currentOffset + resourceAndIndex.index + 1), + }) + if err != nil { + return err + } + + metadata = emptyMetadata + } + return nil + } + } + return nil + }, func(ci cursorInformation) error { + if !hasResourceEntrypoints { + return nil + } + + // The new subject type for dispatching was the found type of the *resource*. + newSubjectType := foundResourceType + + // To avoid duplicate work, remove any subjects already dispatched. + filteredSubjectIDs := foundResources.filterSubjectIDsToDispatch(dispatched, newSubjectType) + if len(filteredSubjectIDs) == 0 { + return nil + } + + // If the entrypoint is a direct result then we can simply dispatch directly and map + // all found results, as no further filtering will be needed. + if entrypoint.IsDirectResult() { + stream := unfilteredLookupResourcesDispatchStreamForEntrypoint(ctx, foundResources, parentStream, ci) + return crr.dl.DispatchLookupResources2(&v1.DispatchLookupResources2Request{ + ResourceRelation: parentRequest.ResourceRelation, + SubjectRelation: newSubjectType, + SubjectIds: filteredSubjectIDs, + TerminalSubject: parentRequest.TerminalSubject, + Metadata: &v1.ResolverMeta{ + AtRevision: parentRequest.Revision.String(), + DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + }, + OptionalCursor: ci.currentCursor, + OptionalLimit: parentRequest.OptionalLimit, + Context: parentRequest.Context, + }, stream) + } + + // Otherwise, we need to filter results by batch checking along the way before dispatching. + return runCheckerAndDispatch( + ctx, + parentRequest, + foundResources, + ci, + parentStream, + newSubjectType, + filteredSubjectIDs, + entrypoint, + crr.dl, + crr.dc, + crr.caveatTypeSet, + crr.concurrencyLimit, + crr.dispatchChunkSize, + ) + }) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go b/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go new file mode 100644 index 0000000..7560847 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/lookupsubjects.go @@ -0,0 +1,803 @@ +package graph + +import ( + "context" + "errors" + "fmt" + "sync" + + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/authzed/spicedb/internal/datasets" + "github.com/authzed/spicedb/internal/dispatch" + log "github.com/authzed/spicedb/internal/logging" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/taskrunner" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/genutil/slicez" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ValidatedLookupSubjectsRequest represents a request after it has been validated and parsed for internal +// consumption. +type ValidatedLookupSubjectsRequest struct { + *v1.DispatchLookupSubjectsRequest + Revision datastore.Revision +} + +// NewConcurrentLookupSubjects creates an instance of ConcurrentLookupSubjects. +func NewConcurrentLookupSubjects(d dispatch.LookupSubjects, concurrencyLimit uint16, dispatchChunkSize uint16) *ConcurrentLookupSubjects { + return &ConcurrentLookupSubjects{d, concurrencyLimit, dispatchChunkSize} +} + +type ConcurrentLookupSubjects struct { + d dispatch.LookupSubjects + concurrencyLimit uint16 + dispatchChunkSize uint16 +} + +func (cl *ConcurrentLookupSubjects) LookupSubjects( + req ValidatedLookupSubjectsRequest, + stream dispatch.LookupSubjectsStream, +) error { + ctx := stream.Context() + + if len(req.ResourceIds) == 0 { + return fmt.Errorf("no resources ids given to lookupsubjects dispatch") + } + + // If the resource type matches the subject type, yield directly. + if req.SubjectRelation.Namespace == req.ResourceRelation.Namespace && + req.SubjectRelation.Relation == req.ResourceRelation.Relation { + if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: subjectsForConcreteIds(req.ResourceIds), + Metadata: emptyMetadata, + }); err != nil { + return err + } + } + + ds := datastoremw.MustFromContext(ctx) + reader := ds.SnapshotReader(req.Revision) + _, relation, err := namespace.ReadNamespaceAndRelation( + ctx, + req.ResourceRelation.Namespace, + req.ResourceRelation.Relation, + reader) + if err != nil { + return err + } + + if relation.UsersetRewrite == nil { + // Direct lookup of subjects. + return cl.lookupDirectSubjects(ctx, req, stream, relation, reader) + } + + return cl.lookupViaRewrite(ctx, req, stream, relation.UsersetRewrite) +} + +func subjectsForConcreteIds(subjectIds []string) map[string]*v1.FoundSubjects { + foundSubjects := make(map[string]*v1.FoundSubjects, len(subjectIds)) + for _, subjectID := range subjectIds { + foundSubjects[subjectID] = &v1.FoundSubjects{ + FoundSubjects: []*v1.FoundSubject{ + { + SubjectId: subjectID, + CaveatExpression: nil, // Explicitly nil since this is a concrete found subject. + }, + }, + } + } + return foundSubjects +} + +func (cl *ConcurrentLookupSubjects) lookupDirectSubjects( + ctx context.Context, + req ValidatedLookupSubjectsRequest, + stream dispatch.LookupSubjectsStream, + _ *core.Relation, + reader datastore.Reader, +) error { + // TODO(jschorr): use type information to skip subject relations that cannot reach the subject type. + + toDispatchByType := datasets.NewSubjectByTypeSet() + foundSubjectsByResourceID := datasets.NewSubjectSetByResourceID() + relationshipsBySubjectONR := mapz.NewMultiMap[tuple.ObjectAndRelation, tuple.Relationship]() + + it, err := reader.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: req.ResourceRelation.Namespace, + OptionalResourceRelation: req.ResourceRelation.Relation, + OptionalResourceIds: req.ResourceIds, + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + return err + } + + for rel, err := range it { + if err != nil { + return err + } + + if rel.Subject.ObjectType == req.SubjectRelation.Namespace && + rel.Subject.Relation == req.SubjectRelation.Relation { + if err := foundSubjectsByResourceID.AddFromRelationship(rel); err != nil { + return fmt.Errorf("failed to call AddFromRelationship in lookupDirectSubjects: %w", err) + } + } + + if rel.Subject.Relation != tuple.Ellipsis { + err := toDispatchByType.AddSubjectOf(rel) + if err != nil { + return err + } + + relationshipsBySubjectONR.Add(rel.Subject, rel) + } + } + + if !foundSubjectsByResourceID.IsEmpty() { + if err := stream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: foundSubjectsByResourceID.AsMap(), + Metadata: emptyMetadata, + }); err != nil { + return err + } + } + + return cl.dispatchTo(ctx, req, toDispatchByType, relationshipsBySubjectONR, stream) +} + +func (cl *ConcurrentLookupSubjects) lookupViaComputed( + ctx context.Context, + parentRequest ValidatedLookupSubjectsRequest, + parentStream dispatch.LookupSubjectsStream, + cu *core.ComputedUserset, +) error { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + if err := namespace.CheckNamespaceAndRelation(ctx, parentRequest.ResourceRelation.Namespace, cu.Relation, true, ds); err != nil { + if errors.As(err, &namespace.RelationNotFoundError{}) { + return nil + } + + return err + } + + stream := &dispatch.WrappedDispatchStream[*v1.DispatchLookupSubjectsResponse]{ + Stream: parentStream, + Ctx: ctx, + Processor: func(result *v1.DispatchLookupSubjectsResponse) (*v1.DispatchLookupSubjectsResponse, bool, error) { + return &v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: result.FoundSubjectsByResourceId, + Metadata: addCallToResponseMetadata(result.Metadata), + }, true, nil + }, + } + + return cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{ + ResourceRelation: &core.RelationReference{ + Namespace: parentRequest.ResourceRelation.Namespace, + Relation: cu.Relation, + }, + ResourceIds: parentRequest.ResourceIds, + SubjectRelation: parentRequest.SubjectRelation, + Metadata: &v1.ResolverMeta{ + AtRevision: parentRequest.Revision.String(), + DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + }, + }, stream) +} + +type resourceDispatchTracker struct { + ctx context.Context + cancelDispatch context.CancelFunc + resourceID string + + subjectsSet datasets.SubjectSet // GUARDED_BY(lock) + metadata *v1.ResponseMeta // GUARDED_BY(lock) + + isFirstUpdate bool // GUARDED_BY(lock) + wasCanceled bool // GUARDED_BY(lock) + + lock sync.Mutex +} + +func lookupViaIntersectionTupleToUserset( + ctx context.Context, + cl *ConcurrentLookupSubjects, + parentRequest ValidatedLookupSubjectsRequest, + parentStream dispatch.LookupSubjectsStream, + ttu *core.FunctionedTupleToUserset, +) error { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: parentRequest.ResourceRelation.Namespace, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + OptionalResourceIds: parentRequest.ResourceIds, + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + return err + } + + // TODO(jschorr): Find a means of doing this without dispatching per subject, per resource. Perhaps + // there is a way we can still dispatch to all the subjects at once, and then intersect the results + // afterwards. + resourceDispatchTrackerByResourceID := make(map[string]*resourceDispatchTracker) + + cancelCtx, checkCancel := context.WithCancel(ctx) + defer checkCancel() + + // For each found tuple, dispatch a lookup subjects request and collect its results. + // We need to intersect between *all* the found subjects for each resource ID. + var ttuCaveat *core.CaveatExpression + taskrunner := taskrunner.NewPreloadedTaskRunner(cancelCtx, cl.concurrencyLimit, 1) + for rel, err := range it { + if err != nil { + return err + } + + // If the relationship has a caveat, add it to the overall TTU caveat. Since this is an intersection + // of *all* branches, the caveat will be applied to all found subjects, so this is a safe approach. + if rel.OptionalCaveat != nil { + ttuCaveat = caveatAnd(ttuCaveat, wrapCaveat(rel.OptionalCaveat)) + } + + if err := namespace.CheckNamespaceAndRelation(ctx, rel.Subject.ObjectType, ttu.GetComputedUserset().Relation, false, ds); err != nil { + if !errors.As(err, &namespace.RelationNotFoundError{}) { + return err + } + + continue + } + + // Create a data structure to track the intersection of subjects for the particular resource. If the resource's subject set + // ends up empty anywhere along the way, the dispatches for *that resource* will be canceled early. + resourceID := rel.Resource.ObjectID + dispatchInfoForResource, ok := resourceDispatchTrackerByResourceID[resourceID] + if !ok { + dispatchCtx, cancelDispatch := context.WithCancel(cancelCtx) + dispatchInfoForResource = &resourceDispatchTracker{ + ctx: dispatchCtx, + cancelDispatch: cancelDispatch, + resourceID: resourceID, + subjectsSet: datasets.NewSubjectSet(), + metadata: emptyMetadata, + isFirstUpdate: true, + lock: sync.Mutex{}, + } + resourceDispatchTrackerByResourceID[resourceID] = dispatchInfoForResource + } + + rel := rel + taskrunner.Add(func(ctx context.Context) error { + //Â Collect all results for this branch of the resource ID. + // TODO(jschorr): once LS has cursoring (and thus, ordering), we can move to not collecting everything up before intersecting + // for this branch of the resource ID. + collectingStream := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](dispatchInfoForResource.ctx) + err := cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{ + ResourceRelation: &core.RelationReference{ + Namespace: rel.Subject.ObjectType, + Relation: ttu.GetComputedUserset().Relation, + }, + ResourceIds: []string{rel.Subject.ObjectID}, + SubjectRelation: parentRequest.SubjectRelation, + Metadata: &v1.ResolverMeta{ + AtRevision: parentRequest.Revision.String(), + DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + }, + }, collectingStream) + if err != nil { + // Check if the dispatches for the resource were canceled, and if so, return nil to stop the task. + dispatchInfoForResource.lock.Lock() + wasCanceled := dispatchInfoForResource.wasCanceled + dispatchInfoForResource.lock.Unlock() + + if wasCanceled { + if errors.Is(err, context.Canceled) { + return nil + } + + errStatus, ok := status.FromError(err) + if ok && errStatus.Code() == codes.Canceled { + return nil + } + } + + return err + } + + // Collect the results into a subject set. + results := datasets.NewSubjectSet() + collectedMetadata := emptyMetadata + for _, result := range collectingStream.Results() { + collectedMetadata = combineResponseMetadata(ctx, collectedMetadata, result.Metadata) + for _, foundSubjects := range result.FoundSubjectsByResourceId { + if err := results.UnionWith(foundSubjects.FoundSubjects); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsIntersection: %w", err) + } + } + } + + dispatchInfoForResource.lock.Lock() + defer dispatchInfoForResource.lock.Unlock() + + dispatchInfoForResource.metadata = combineResponseMetadata(ctx, dispatchInfoForResource.metadata, collectedMetadata) + + // If the first update for the resource, set the subjects set to the results. + if dispatchInfoForResource.isFirstUpdate { + dispatchInfoForResource.isFirstUpdate = false + dispatchInfoForResource.subjectsSet = results + } else { + // Otherwise, intersect the results with the existing subjects set. + err := dispatchInfoForResource.subjectsSet.IntersectionDifference(results) + if err != nil { + return err + } + } + + // If the subjects set is empty, cancel the dispatch for any further results for this resource ID. + if dispatchInfoForResource.subjectsSet.IsEmpty() { + dispatchInfoForResource.wasCanceled = true + dispatchInfoForResource.cancelDispatch() + } + + return nil + }) + } + + // Wait for all dispatched operations to complete. + if err := taskrunner.StartAndWait(); err != nil { + return err + } + + // For each resource ID, intersect the found subjects from each stream. + metadata := emptyMetadata + currentSubjectsByResourceID := map[string]*v1.FoundSubjects{} + + for incomingResourceID, tracker := range resourceDispatchTrackerByResourceID { + currentSubjects := tracker.subjectsSet + currentSubjects = currentSubjects.WithParentCaveatExpression(ttuCaveat) + currentSubjectsByResourceID[incomingResourceID] = currentSubjects.AsFoundSubjects() + + metadata = combineResponseMetadata(ctx, metadata, tracker.metadata) + } + + return parentStream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: currentSubjectsByResourceID, + Metadata: metadata, + }) +} + +func lookupViaTupleToUserset[T relation]( + ctx context.Context, + cl *ConcurrentLookupSubjects, + parentRequest ValidatedLookupSubjectsRequest, + parentStream dispatch.LookupSubjectsStream, + ttu ttu[T], +) error { + toDispatchByTuplesetType := datasets.NewSubjectByTypeSet() + relationshipsBySubjectONR := mapz.NewMultiMap[tuple.ObjectAndRelation, tuple.Relationship]() + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(parentRequest.Revision) + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ + OptionalResourceType: parentRequest.ResourceRelation.Namespace, + OptionalResourceRelation: ttu.GetTupleset().GetRelation(), + OptionalResourceIds: parentRequest.ResourceIds, + }, options.WithQueryShape(queryshape.AllSubjectsForResources)) + if err != nil { + return err + } + + for rel, err := range it { + if err != nil { + return err + } + + // Add the subject to be dispatched. + err := toDispatchByTuplesetType.AddSubjectOf(rel) + if err != nil { + return err + } + + // Add the *rewritten* subject to the relationships multimap for mapping back to the associated + // relationship, as we will be mapping from the computed relation, not the tupleset relation. + relationshipsBySubjectONR.Add(tuple.ONR(rel.Subject.ObjectType, rel.Subject.ObjectID, ttu.GetComputedUserset().Relation), rel) + } + + // Map the found subject types by the computed userset relation, so that we dispatch to it. + toDispatchByComputedRelationType, err := toDispatchByTuplesetType.Map(func(resourceType *core.RelationReference) (*core.RelationReference, error) { + if err := namespace.CheckNamespaceAndRelation(ctx, resourceType.Namespace, ttu.GetComputedUserset().Relation, false, ds); err != nil { + if errors.As(err, &namespace.RelationNotFoundError{}) { + return nil, nil + } + + return nil, err + } + + return &core.RelationReference{ + Namespace: resourceType.Namespace, + Relation: ttu.GetComputedUserset().Relation, + }, nil + }) + if err != nil { + return err + } + + return cl.dispatchTo(ctx, parentRequest, toDispatchByComputedRelationType, relationshipsBySubjectONR, parentStream) +} + +func (cl *ConcurrentLookupSubjects) lookupViaRewrite( + ctx context.Context, + req ValidatedLookupSubjectsRequest, + stream dispatch.LookupSubjectsStream, + usr *core.UsersetRewrite, +) error { + switch rw := usr.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + log.Ctx(ctx).Trace().Msg("union") + return cl.lookupSetOperation(ctx, req, rw.Union, newLookupSubjectsUnion(stream)) + case *core.UsersetRewrite_Intersection: + log.Ctx(ctx).Trace().Msg("intersection") + return cl.lookupSetOperation(ctx, req, rw.Intersection, newLookupSubjectsIntersection(stream)) + case *core.UsersetRewrite_Exclusion: + log.Ctx(ctx).Trace().Msg("exclusion") + return cl.lookupSetOperation(ctx, req, rw.Exclusion, newLookupSubjectsExclusion(stream)) + default: + return fmt.Errorf("unknown kind of rewrite in lookup subjects") + } +} + +func (cl *ConcurrentLookupSubjects) lookupSetOperation( + ctx context.Context, + req ValidatedLookupSubjectsRequest, + so *core.SetOperation, + reducer lookupSubjectsReducer, +) error { + cancelCtx, checkCancel := context.WithCancel(ctx) + defer checkCancel() + + g, subCtx := errgroup.WithContext(cancelCtx) + g.SetLimit(int(cl.concurrencyLimit)) + + for index, childOneof := range so.Child { + stream := reducer.ForIndex(subCtx, index) + + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_XThis: + return errors.New("use of _this is unsupported; please rewrite your schema") + + case *core.SetOperation_Child_ComputedUserset: + g.Go(func() error { + return cl.lookupViaComputed(subCtx, req, stream, child.ComputedUserset) + }) + + case *core.SetOperation_Child_UsersetRewrite: + g.Go(func() error { + return cl.lookupViaRewrite(subCtx, req, stream, child.UsersetRewrite) + }) + + case *core.SetOperation_Child_TupleToUserset: + g.Go(func() error { + return lookupViaTupleToUserset(subCtx, cl, req, stream, child.TupleToUserset) + }) + + case *core.SetOperation_Child_FunctionedTupleToUserset: + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + g.Go(func() error { + return lookupViaTupleToUserset(subCtx, cl, req, stream, child.FunctionedTupleToUserset) + }) + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + g.Go(func() error { + return lookupViaIntersectionTupleToUserset(subCtx, cl, req, stream, child.FunctionedTupleToUserset) + }) + + default: + return spiceerrors.MustBugf("unknown function in lookup subjects: %v", child.FunctionedTupleToUserset.Function) + } + + case *core.SetOperation_Child_XNil: + // Purposely do nothing. + continue + + default: + return spiceerrors.MustBugf("unknown set operation child `%T` in lookup subjects", child) + } + } + + // Wait for all dispatched operations to complete. + if err := g.Wait(); err != nil { + return err + } + + return reducer.CompletedChildOperations(ctx) +} + +func (cl *ConcurrentLookupSubjects) dispatchTo( + ctx context.Context, + parentRequest ValidatedLookupSubjectsRequest, + toDispatchByType *datasets.SubjectByTypeSet, + relationshipsBySubjectONR *mapz.MultiMap[tuple.ObjectAndRelation, tuple.Relationship], + parentStream dispatch.LookupSubjectsStream, +) error { + if toDispatchByType.IsEmpty() { + return nil + } + + cancelCtx, checkCancel := context.WithCancel(ctx) + defer checkCancel() + + g, subCtx := errgroup.WithContext(cancelCtx) + g.SetLimit(int(cl.concurrencyLimit)) + + toDispatchByType.ForEachType(func(resourceType *core.RelationReference, foundSubjects datasets.SubjectSet) { + slice := foundSubjects.AsSlice() + resourceIds := make([]string, 0, len(slice)) + for _, foundSubject := range slice { + resourceIds = append(resourceIds, foundSubject.SubjectId) + } + + stream := &dispatch.WrappedDispatchStream[*v1.DispatchLookupSubjectsResponse]{ + Stream: parentStream, + Ctx: subCtx, + Processor: func(result *v1.DispatchLookupSubjectsResponse) (*v1.DispatchLookupSubjectsResponse, bool, error) { + // For any found subjects, map them through their associated starting resources, to apply any caveats that were + // only those resources' relationships. + // + // For example, given relationships which formed the dispatch: + // - document:firstdoc#viewer@group:group1#member + // - document:firstdoc#viewer@group:group2#member[somecaveat] + // + // And results: + // - group1 => {user:tom, user:sarah} + // - group2 => {user:tom, user:fred} + // + // This will produce: + // - firstdoc => {user:tom, user:sarah, user:fred[somecaveat]} + // + mappedFoundSubjects := make(map[string]*v1.FoundSubjects) + for childResourceID, foundSubjects := range result.FoundSubjectsByResourceId { + subjectKey := tuple.ONR(resourceType.Namespace, childResourceID, resourceType.Relation) + relationships, _ := relationshipsBySubjectONR.Get(subjectKey) + if len(relationships) == 0 { + return nil, false, fmt.Errorf("missing relationships for subject key %v; please report this error", subjectKey) + } + + for _, relationship := range relationships { + existing := mappedFoundSubjects[relationship.Resource.ObjectID] + + // If the relationship has no caveat, simply map the resource ID. + if relationship.OptionalCaveat == nil { + combined, err := combineFoundSubjects(existing, foundSubjects) + if err != nil { + return nil, false, fmt.Errorf("could not combine caveat-less subjects: %w", err) + } + mappedFoundSubjects[relationship.Resource.ObjectID] = combined + continue + } + + // Otherwise, apply the caveat to all found subjects for that resource and map to the resource ID. + foundSubjectSet := datasets.NewSubjectSet() + err := foundSubjectSet.UnionWith(foundSubjects.FoundSubjects) + if err != nil { + return nil, false, fmt.Errorf("could not combine subject sets: %w", err) + } + + combined, err := combineFoundSubjects( + existing, + foundSubjectSet.WithParentCaveatExpression(wrapCaveat(relationship.OptionalCaveat)).AsFoundSubjects(), + ) + if err != nil { + return nil, false, fmt.Errorf("could not combine caveated subjects: %w", err) + } + + mappedFoundSubjects[relationship.Resource.ObjectID] = combined + } + } + + return &v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: mappedFoundSubjects, + Metadata: addCallToResponseMetadata(result.Metadata), + }, true, nil + }, + } + + // Dispatch the found subjects as the resources of the next step. + slicez.ForEachChunk(resourceIds, cl.dispatchChunkSize, func(resourceIdChunk []string) { + g.Go(func() error { + return cl.d.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{ + ResourceRelation: resourceType, + ResourceIds: resourceIdChunk, + SubjectRelation: parentRequest.SubjectRelation, + Metadata: &v1.ResolverMeta{ + AtRevision: parentRequest.Revision.String(), + DepthRemaining: parentRequest.Metadata.DepthRemaining - 1, + }, + }, stream) + }) + }) + }) + + return g.Wait() +} + +func combineFoundSubjects(existing *v1.FoundSubjects, toAdd *v1.FoundSubjects) (*v1.FoundSubjects, error) { + if existing == nil { + return toAdd, nil + } + + if toAdd == nil { + return nil, fmt.Errorf("toAdd FoundSubject cannot be nil") + } + + return &v1.FoundSubjects{ + FoundSubjects: append(existing.FoundSubjects, toAdd.FoundSubjects...), + }, nil +} + +type lookupSubjectsReducer interface { + ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream + CompletedChildOperations(ctx context.Context) error +} + +// Union +type lookupSubjectsUnion struct { + parentStream dispatch.LookupSubjectsStream + collectors map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse] +} + +func newLookupSubjectsUnion(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsUnion { + return &lookupSubjectsUnion{ + parentStream: parentStream, + collectors: map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{}, + } +} + +func (lsu *lookupSubjectsUnion) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream { + collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx) + lsu.collectors[setOperationIndex] = collector + return collector +} + +func (lsu *lookupSubjectsUnion) CompletedChildOperations(ctx context.Context) error { + foundSubjects := datasets.NewSubjectSetByResourceID() + metadata := emptyMetadata + + for index := 0; index < len(lsu.collectors); index++ { + collector, ok := lsu.collectors[index] + if !ok { + return fmt.Errorf("missing collector for index %d", index) + } + + for _, result := range collector.Results() { + metadata = combineResponseMetadata(ctx, metadata, result.Metadata) + if err := foundSubjects.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsUnion: %w", err) + } + } + } + + if foundSubjects.IsEmpty() { + return nil + } + + return lsu.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: foundSubjects.AsMap(), + Metadata: metadata, + }) +} + +// Intersection +type lookupSubjectsIntersection struct { + parentStream dispatch.LookupSubjectsStream + collectors map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse] +} + +func newLookupSubjectsIntersection(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsIntersection { + return &lookupSubjectsIntersection{ + parentStream: parentStream, + collectors: map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{}, + } +} + +func (lsi *lookupSubjectsIntersection) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream { + collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx) + lsi.collectors[setOperationIndex] = collector + return collector +} + +func (lsi *lookupSubjectsIntersection) CompletedChildOperations(ctx context.Context) error { + var foundSubjects datasets.SubjectSetByResourceID + metadata := emptyMetadata + + for index := 0; index < len(lsi.collectors); index++ { + collector, ok := lsi.collectors[index] + if !ok { + return fmt.Errorf("missing collector for index %d", index) + } + + results := datasets.NewSubjectSetByResourceID() + for _, result := range collector.Results() { + metadata = combineResponseMetadata(ctx, metadata, result.Metadata) + if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsIntersection: %w", err) + } + } + + if index == 0 { + foundSubjects = results + } else { + err := foundSubjects.IntersectionDifference(results) + if err != nil { + return err + } + + if foundSubjects.IsEmpty() { + return nil + } + } + } + + return lsi.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: foundSubjects.AsMap(), + Metadata: metadata, + }) +} + +// Exclusion +type lookupSubjectsExclusion struct { + parentStream dispatch.LookupSubjectsStream + collectors map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse] +} + +func newLookupSubjectsExclusion(parentStream dispatch.LookupSubjectsStream) *lookupSubjectsExclusion { + return &lookupSubjectsExclusion{ + parentStream: parentStream, + collectors: map[int]*dispatch.CollectingDispatchStream[*v1.DispatchLookupSubjectsResponse]{}, + } +} + +func (lse *lookupSubjectsExclusion) ForIndex(ctx context.Context, setOperationIndex int) dispatch.LookupSubjectsStream { + collector := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](ctx) + lse.collectors[setOperationIndex] = collector + return collector +} + +func (lse *lookupSubjectsExclusion) CompletedChildOperations(ctx context.Context) error { + var foundSubjects datasets.SubjectSetByResourceID + metadata := emptyMetadata + + for index := 0; index < len(lse.collectors); index++ { + collector := lse.collectors[index] + results := datasets.NewSubjectSetByResourceID() + for _, result := range collector.Results() { + metadata = combineResponseMetadata(ctx, metadata, result.Metadata) + if err := results.UnionWith(result.FoundSubjectsByResourceId); err != nil { + return fmt.Errorf("failed to UnionWith under lookupSubjectsExclusion: %w", err) + } + } + + if index == 0 { + foundSubjects = results + } else { + foundSubjects.SubtractAll(results) + if foundSubjects.IsEmpty() { + return nil + } + } + } + + return lse.parentStream.Publish(&v1.DispatchLookupSubjectsResponse{ + FoundSubjectsByResourceId: foundSubjects.AsMap(), + Metadata: metadata, + }) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go b/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go new file mode 100644 index 0000000..f04ee6f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/lr2streams.go @@ -0,0 +1,334 @@ +package graph + +import ( + "context" + "strconv" + "sync" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph/computed" + "github.com/authzed/spicedb/internal/graph/hints" + "github.com/authzed/spicedb/internal/taskrunner" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// runCheckerAndDispatch runs the dispatch and checker for a lookup resources call, and publishes +// the results to the parent stream. This function is responsible for handling checking the +// results to filter them, and then dispatching those found. +func runCheckerAndDispatch( + ctx context.Context, + parentReq ValidatedLookupResources2Request, + foundResources dispatchableResourcesSubjectMap2, + ci cursorInformation, + parentStream dispatch.LookupResources2Stream, + newSubjectType *core.RelationReference, + filteredSubjectIDs []string, + entrypoint schema.ReachabilityEntrypoint, + lrDispatcher dispatch.LookupResources2, + checkDispatcher dispatch.Check, + caveatTypeSet *caveattypes.TypeSet, + concurrencyLimit uint16, + dispatchChunkSize uint16, +) error { + // Only allow max one dispatcher and one checker to run concurrently. + concurrencyLimit = min(concurrencyLimit, 2) + + currentCheckIndex, err := ci.integerSectionValue() + if err != nil { + return err + } + + rdc := &checkAndDispatchRunner{ + parentRequest: parentReq, + foundResources: foundResources, + ci: ci, + parentStream: parentStream, + newSubjectType: newSubjectType, + filteredSubjectIDs: filteredSubjectIDs, + currentCheckIndex: currentCheckIndex, + entrypoint: entrypoint, + lrDispatcher: lrDispatcher, + checkDispatcher: checkDispatcher, + taskrunner: taskrunner.NewTaskRunner(ctx, concurrencyLimit), + lock: &sync.Mutex{}, + dispatchChunkSize: dispatchChunkSize, + caveatTypeSet: caveatTypeSet, + } + + return rdc.runAndWait() +} + +type checkAndDispatchRunner struct { + parentRequest ValidatedLookupResources2Request + foundResources dispatchableResourcesSubjectMap2 + parentStream dispatch.LookupResources2Stream + newSubjectType *core.RelationReference + entrypoint schema.ReachabilityEntrypoint + lrDispatcher dispatch.LookupResources2 + checkDispatcher dispatch.Check + dispatchChunkSize uint16 + caveatTypeSet *caveattypes.TypeSet + filteredSubjectIDs []string + + currentCheckIndex int + taskrunner *taskrunner.TaskRunner + + lock *sync.Mutex + ci cursorInformation // GUARDED_BY(lock) +} + +func (rdc *checkAndDispatchRunner) runAndWait() error { + // Kick off a check at the current cursor, to filter a portion of the initial results set. + rdc.taskrunner.Schedule(func(ctx context.Context) error { + return rdc.runChecker(ctx, rdc.currentCheckIndex) + }) + + return rdc.taskrunner.Wait() +} + +func (rdc *checkAndDispatchRunner) runChecker(ctx context.Context, startingIndex int) error { + rdc.lock.Lock() + if rdc.ci.limits.hasExhaustedLimit() { + rdc.lock.Unlock() + return nil + } + rdc.lock.Unlock() + + endingIndex := min(startingIndex+int(rdc.dispatchChunkSize), len(rdc.filteredSubjectIDs)) + resourceIDsToCheck := rdc.filteredSubjectIDs[startingIndex:endingIndex] + if len(resourceIDsToCheck) == 0 { + return nil + } + + ctx, span := tracer.Start(ctx, "lr2Check", trace.WithAttributes( + attribute.Int("resource-id-count", len(resourceIDsToCheck)), + )) + defer span.End() + + checkHints := make([]*v1.CheckHint, 0, len(resourceIDsToCheck)) + for _, resourceID := range resourceIDsToCheck { + checkHint, err := hints.HintForEntrypoint( + rdc.entrypoint, + resourceID, + tuple.FromCoreObjectAndRelation(rdc.parentRequest.TerminalSubject), + &v1.ResourceCheckResult{ + Membership: v1.ResourceCheckResult_MEMBER, + }) + if err != nil { + return err + } + checkHints = append(checkHints, checkHint) + } + + // NOTE: we are checking the containing permission here, *not* the target relation, as + // the goal is to shear for the containing permission. + resultsByResourceID, checkMetadata, _, err := computed.ComputeBulkCheck(ctx, rdc.checkDispatcher, rdc.caveatTypeSet, computed.CheckParameters{ + ResourceType: tuple.FromCoreRelationReference(rdc.newSubjectType), + Subject: tuple.FromCoreObjectAndRelation(rdc.parentRequest.TerminalSubject), + CaveatContext: rdc.parentRequest.Context.AsMap(), + AtRevision: rdc.parentRequest.Revision, + MaximumDepth: rdc.parentRequest.Metadata.DepthRemaining - 1, + DebugOption: computed.NoDebugging, + CheckHints: checkHints, + }, resourceIDsToCheck, rdc.dispatchChunkSize) + if err != nil { + return err + } + + adjustedResources := rdc.foundResources.cloneAsMutable() + + // Dispatch any resources that are visible. + resourceIDToDispatch := make([]string, 0, len(resourceIDsToCheck)) + for _, resourceID := range resourceIDsToCheck { + result, ok := resultsByResourceID[resourceID] + if !ok { + continue + } + + switch result.Membership { + case v1.ResourceCheckResult_MEMBER: + fallthrough + + case v1.ResourceCheckResult_CAVEATED_MEMBER: + // Record any additional caveats missing from the check. + adjustedResources.withAdditionalMissingContextForDispatchedResourceID(resourceID, result.MissingExprFields) + resourceIDToDispatch = append(resourceIDToDispatch, resourceID) + + case v1.ResourceCheckResult_NOT_MEMBER: + // Skip. + continue + + default: + return spiceerrors.MustBugf("unexpected result from check: %v", result.Membership) + } + } + + if len(resourceIDToDispatch) > 0 { + // Schedule a dispatch of those resources. + rdc.taskrunner.Schedule(func(ctx context.Context) error { + return rdc.runDispatch(ctx, resourceIDToDispatch, adjustedResources.asReadOnly(), checkMetadata, startingIndex) + }) + } + + // Start the next check chunk (if applicable). + nextIndex := startingIndex + len(resourceIDsToCheck) + if nextIndex < len(rdc.filteredSubjectIDs) { + rdc.taskrunner.Schedule(func(ctx context.Context) error { + return rdc.runChecker(ctx, nextIndex) + }) + } + + return nil +} + +func (rdc *checkAndDispatchRunner) runDispatch( + ctx context.Context, + resourceIDsToDispatch []string, + adjustedResources dispatchableResourcesSubjectMap2, + checkMetadata *v1.ResponseMeta, + startingIndex int, +) error { + rdc.lock.Lock() + if rdc.ci.limits.hasExhaustedLimit() { + rdc.lock.Unlock() + return nil + } + rdc.lock.Unlock() + + ctx, span := tracer.Start(ctx, "lr2Dispatch", trace.WithAttributes( + attribute.Int("resource-id-count", len(resourceIDsToDispatch)), + )) + defer span.End() + + // NOTE: Since we extracted a custom section from the cursor at the beginning of this run, we have to add + // the starting index to the cursor to ensure that the next run starts from the correct place, and we have + // to use the *updated* cursor below on the dispatch. + updatedCi, err := rdc.ci.withOutgoingSection(strconv.Itoa(startingIndex)) + if err != nil { + return err + } + responsePartialCursor := updatedCi.responsePartialCursor() + + // Dispatch to the parent resource type and publish any results found. + isFirstPublishCall := true + + wrappedStream := dispatch.NewHandlingDispatchStream(ctx, func(result *v1.DispatchLookupResources2Response) error { + if err := ctx.Err(); err != nil { + return err + } + + if err := publishResultToParentStream(ctx, result, rdc.ci, responsePartialCursor, adjustedResources, nil, isFirstPublishCall, checkMetadata, rdc.parentStream); err != nil { + return err + } + isFirstPublishCall = false + return nil + }) + + return rdc.lrDispatcher.DispatchLookupResources2(&v1.DispatchLookupResources2Request{ + ResourceRelation: rdc.parentRequest.ResourceRelation, + SubjectRelation: rdc.newSubjectType, + SubjectIds: resourceIDsToDispatch, + TerminalSubject: rdc.parentRequest.TerminalSubject, + Metadata: &v1.ResolverMeta{ + AtRevision: rdc.parentRequest.Revision.String(), + DepthRemaining: rdc.parentRequest.Metadata.DepthRemaining - 1, + }, + OptionalCursor: updatedCi.currentCursor, + OptionalLimit: rdc.ci.limits.currentLimit, + Context: rdc.parentRequest.Context, + }, wrappedStream) +} + +// unfilteredLookupResourcesDispatchStreamForEntrypoint creates a new dispatch stream that wraps +// the parent stream, and publishes the results of the lookup resources call to the parent stream, +// mapped via foundResources. +func unfilteredLookupResourcesDispatchStreamForEntrypoint( + ctx context.Context, + foundResources dispatchableResourcesSubjectMap2, + parentStream dispatch.LookupResources2Stream, + ci cursorInformation, +) dispatch.LookupResources2Stream { + isFirstPublishCall := true + + wrappedStream := dispatch.NewHandlingDispatchStream(ctx, func(result *v1.DispatchLookupResources2Response) error { + select { + case <-ctx.Done(): + return ctx.Err() + + default: + } + + if err := publishResultToParentStream(ctx, result, ci, ci.responsePartialCursor(), foundResources, nil, isFirstPublishCall, emptyMetadata, parentStream); err != nil { + return err + } + isFirstPublishCall = false + return nil + }) + + return wrappedStream +} + +// publishResultToParentStream publishes the result of a lookup resources call to the parent stream, +// mapped via foundResources. +func publishResultToParentStream( + ctx context.Context, + result *v1.DispatchLookupResources2Response, + ci cursorInformation, + responseCursor *v1.Cursor, + foundResources dispatchableResourcesSubjectMap2, + additionalMissingContext []string, + isFirstPublishCall bool, + additionalMetadata *v1.ResponseMeta, + parentStream dispatch.LookupResources2Stream, +) error { + // Map the found resources via the subject+resources used for dispatching, to determine + // if any need to be made conditional due to caveats. + mappedResource, err := foundResources.mapPossibleResource(result.Resource) + if err != nil { + return err + } + + if !ci.limits.prepareForPublishing() { + return nil + } + + // The cursor for the response is that of the parent response + the cursor from the result itself. + afterResponseCursor, err := combineCursors( + responseCursor, + result.AfterResponseCursor, + ) + if err != nil { + return err + } + + metadata := result.Metadata + if isFirstPublishCall { + metadata = addCallToResponseMetadata(metadata) + metadata = combineResponseMetadata(ctx, metadata, additionalMetadata) + } else { + metadata = addAdditionalDepthRequired(metadata) + } + + missingContextParameters := mapz.NewSet(mappedResource.MissingContextParams...) + missingContextParameters.Extend(result.Resource.MissingContextParams) + missingContextParameters.Extend(additionalMissingContext) + + mappedResource.MissingContextParams = missingContextParameters.AsSlice() + + resp := &v1.DispatchLookupResources2Response{ + Resource: mappedResource, + Metadata: metadata, + AfterResponseCursor: afterResponseCursor, + } + + return parentStream.Publish(resp) +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go b/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go new file mode 100644 index 0000000..8ab20f4 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/membershipset.go @@ -0,0 +1,243 @@ +package graph + +import ( + "github.com/authzed/spicedb/internal/caveats" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +var ( + caveatOr = caveats.Or + caveatAnd = caveats.And + caveatSub = caveats.Subtract + wrapCaveat = caveats.CaveatAsExpr +) + +// CheckResultsMap defines a type that is a map from resource ID to ResourceCheckResult. +// This must match that defined in the DispatchCheckResponse for the `results_by_resource_id` +// field. +type CheckResultsMap map[string]*v1.ResourceCheckResult + +// NewMembershipSet constructs a new helper set for tracking the membership found for a dispatched +// check request. +func NewMembershipSet() *MembershipSet { + return &MembershipSet{ + hasDeterminedMember: false, + membersByID: map[string]*core.CaveatExpression{}, + } +} + +func membershipSetFromMap(mp map[string]*core.CaveatExpression) *MembershipSet { + ms := NewMembershipSet() + for resourceID, result := range mp { + ms.addMember(resourceID, result) + } + return ms +} + +// MembershipSet is a helper set that trackes the membership results for a dispatched Check +// request, including tracking of the caveats associated with found resource IDs. +type MembershipSet struct { + membersByID map[string]*core.CaveatExpression + hasDeterminedMember bool +} + +// AddDirectMember adds a resource ID that was *directly* found for the dispatched check, with +// optional caveat found on the relationship. +func (ms *MembershipSet) AddDirectMember(resourceID string, caveat *core.ContextualizedCaveat) { + ms.addMember(resourceID, wrapCaveat(caveat)) +} + +// AddMemberViaRelationship adds a resource ID that was found via another relationship, such +// as the result of an arrow operation. The `parentRelationship` is the relationship that was +// followed before the resource itself was resolved. This method will properly apply the caveat(s) +// from both the parent relationship and the resource's result itself, assuming either have a caveat +// associated. +func (ms *MembershipSet) AddMemberViaRelationship( + resourceID string, + resourceCaveatExpression *core.CaveatExpression, + parentRelationship tuple.Relationship, +) { + ms.AddMemberWithParentCaveat(resourceID, resourceCaveatExpression, parentRelationship.OptionalCaveat) +} + +// AddMemberWithParentCaveat adds the given resource ID as a member with the parent caveat +// combined via intersection with the resource's caveat. The parent caveat may be nil. +func (ms *MembershipSet) AddMemberWithParentCaveat( + resourceID string, + resourceCaveatExpression *core.CaveatExpression, + parentCaveat *core.ContextualizedCaveat, +) { + intersection := caveatAnd(wrapCaveat(parentCaveat), resourceCaveatExpression) + ms.addMember(resourceID, intersection) +} + +// AddMemberWithOptionalCaveats adds the given resource ID as a member with the optional caveats combined +// via intersection. +func (ms *MembershipSet) AddMemberWithOptionalCaveats( + resourceID string, + caveats []*core.CaveatExpression, +) { + if len(caveats) == 0 { + ms.addMember(resourceID, nil) + return + } + + intersection := caveats[0] + for _, caveat := range caveats[1:] { + intersection = caveatAnd(intersection, caveat) + } + + ms.addMember(resourceID, intersection) +} + +func (ms *MembershipSet) addMember(resourceID string, caveatExpr *core.CaveatExpression) { + existing, ok := ms.membersByID[resourceID] + if !ok { + ms.hasDeterminedMember = ms.hasDeterminedMember || caveatExpr == nil + ms.membersByID[resourceID] = caveatExpr + return + } + + // If a determined membership result has already been found (i.e. there is no caveat), + // then nothing more to do. + if existing == nil { + return + } + + // If the new caveat expression is nil, then we are adding a determined result. + if caveatExpr == nil { + ms.hasDeterminedMember = true + ms.membersByID[resourceID] = nil + return + } + + // Otherwise, the caveats get unioned together. + ms.membersByID[resourceID] = caveatOr(existing, caveatExpr) +} + +// UnionWith combines the results found in the given map with the members of this set. +// The changes are made in-place. +func (ms *MembershipSet) UnionWith(resultsMap CheckResultsMap) { + for resourceID, details := range resultsMap { + if details.Membership != v1.ResourceCheckResult_NOT_MEMBER { + ms.addMember(resourceID, details.Expression) + } + } +} + +// IntersectWith intersects the results found in the given map with the members of this set. +// The changes are made in-place. +func (ms *MembershipSet) IntersectWith(resultsMap CheckResultsMap) { + for resourceID := range ms.membersByID { + if details, ok := resultsMap[resourceID]; !ok || details.Membership == v1.ResourceCheckResult_NOT_MEMBER { + delete(ms.membersByID, resourceID) + } + } + + ms.hasDeterminedMember = false + for resourceID, details := range resultsMap { + existing, ok := ms.membersByID[resourceID] + if !ok || details.Membership == v1.ResourceCheckResult_NOT_MEMBER { + continue + } + if existing == nil && details.Expression == nil { + ms.hasDeterminedMember = true + continue + } + + ms.membersByID[resourceID] = caveatAnd(existing, details.Expression) + } +} + +// Subtract subtracts the results found in the given map with the members of this set. +// The changes are made in-place. +func (ms *MembershipSet) Subtract(resultsMap CheckResultsMap) { + ms.hasDeterminedMember = false + for resourceID, expression := range ms.membersByID { + if details, ok := resultsMap[resourceID]; ok && details.Membership != v1.ResourceCheckResult_NOT_MEMBER { + // If the incoming member has no caveat, then this removal is absolute. + if details.Expression == nil { + delete(ms.membersByID, resourceID) + continue + } + + // Otherwise, the caveat expression gets combined with an intersection of the inversion + // of the expression. + ms.membersByID[resourceID] = caveatSub(expression, details.Expression) + } else { + if expression == nil { + ms.hasDeterminedMember = true + } + } + } +} + +// HasConcreteResourceID returns whether the resourceID was found in the set +// and has no caveat attached. +func (ms *MembershipSet) HasConcreteResourceID(resourceID string) bool { + if ms == nil { + return false + } + + found, ok := ms.membersByID[resourceID] + return ok && found == nil +} + +// GetResourceID returns a bool indicating whether the resource is found in the set and the +// associated caveat expression, if any. +func (ms *MembershipSet) GetResourceID(resourceID string) (bool, *core.CaveatExpression) { + if ms == nil { + return false, nil + } + + caveat, ok := ms.membersByID[resourceID] + return ok, caveat +} + +// Size returns the number of elements in the membership set. +func (ms *MembershipSet) Size() int { + if ms == nil { + return 0 + } + + return len(ms.membersByID) +} + +// IsEmpty returns true if the set is empty. +func (ms *MembershipSet) IsEmpty() bool { + if ms == nil { + return true + } + + return len(ms.membersByID) == 0 +} + +// HasDeterminedMember returns whether there exists at least one non-caveated member of the set. +func (ms *MembershipSet) HasDeterminedMember() bool { + if ms == nil { + return false + } + + return ms.hasDeterminedMember +} + +// AsCheckResultsMap converts the membership set back into a CheckResultsMap for placement into +// a DispatchCheckResult. +func (ms *MembershipSet) AsCheckResultsMap() CheckResultsMap { + resultsMap := make(CheckResultsMap, len(ms.membersByID)) + for resourceID, caveat := range ms.membersByID { + membership := v1.ResourceCheckResult_MEMBER + if caveat != nil { + membership = v1.ResourceCheckResult_CAVEATED_MEMBER + } + + resultsMap[resourceID] = &v1.ResourceCheckResult{ + Membership: membership, + Expression: caveat, + } + } + + return resultsMap +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go b/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go new file mode 100644 index 0000000..4e41955 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/resourcesubjectsmap2.go @@ -0,0 +1,248 @@ +package graph + +import ( + "sort" + "sync" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +type syncONRSet struct { + sync.Mutex + items map[string]struct{} // GUARDED_BY(Mutex) +} + +func (s *syncONRSet) Add(onr *core.ObjectAndRelation) bool { + key := tuple.StringONR(tuple.FromCoreObjectAndRelation(onr)) + s.Lock() + _, existed := s.items[key] + if !existed { + s.items[key] = struct{}{} + } + s.Unlock() + return !existed +} + +func NewSyncONRSet() *syncONRSet { + return &syncONRSet{items: make(map[string]struct{})} +} + +// resourcesSubjectMap2 is a multimap which tracks mappings from found resource IDs +// to the subject IDs (may be more than one) for each, as well as whether the mapping +// is conditional due to the use of a caveat on the relationship which formed the mapping. +type resourcesSubjectMap2 struct { + resourceType *core.RelationReference + resourcesAndSubjects *mapz.MultiMap[string, subjectInfo2] +} + +// subjectInfo2 is the information about a subject contained in a resourcesSubjectMap2. +type subjectInfo2 struct { + subjectID string + missingContextParameters []string +} + +func newResourcesSubjectMap2(resourceType *core.RelationReference) resourcesSubjectMap2 { + return resourcesSubjectMap2{ + resourceType: resourceType, + resourcesAndSubjects: mapz.NewMultiMap[string, subjectInfo2](), + } +} + +func newResourcesSubjectMap2WithCapacity(resourceType *core.RelationReference, capacity uint32) resourcesSubjectMap2 { + return resourcesSubjectMap2{ + resourceType: resourceType, + resourcesAndSubjects: mapz.NewMultiMapWithCap[string, subjectInfo2](capacity), + } +} + +func subjectIDsToResourcesMap2(resourceType *core.RelationReference, subjectIDs []string) resourcesSubjectMap2 { + rsm := newResourcesSubjectMap2(resourceType) + for _, subjectID := range subjectIDs { + rsm.addSubjectIDAsFoundResourceID(subjectID) + } + return rsm +} + +// addRelationship adds the relationship to the resource subject map, recording a mapping from +// the resource of the relationship to the subject, as well as whether the relationship was caveated. +func (rsm resourcesSubjectMap2) addRelationship(rel tuple.Relationship, missingContextParameters []string) error { + spiceerrors.DebugAssert(func() bool { + return rel.Resource.ObjectType == rsm.resourceType.Namespace && rel.Resource.Relation == rsm.resourceType.Relation + }, "invalid relationship for addRelationship. expected: %v, found: %v", rsm.resourceType, rel.Resource) + + spiceerrors.DebugAssert(func() bool { + return len(missingContextParameters) == 0 || rel.OptionalCaveat != nil + }, "missing context parameters must be empty if there is no caveat") + + rsm.resourcesAndSubjects.Add(rel.Resource.ObjectID, subjectInfo2{rel.Subject.ObjectID, missingContextParameters}) + return nil +} + +// withAdditionalMissingContextForDispatchedResourceID adds additional missing context parameters +// to the existing missing context parameters for the dispatched resource ID. +func (rsm resourcesSubjectMap2) withAdditionalMissingContextForDispatchedResourceID( + resourceID string, + additionalMissingContext []string, +) { + if len(additionalMissingContext) == 0 { + return + } + + subjectInfo2s, _ := rsm.resourcesAndSubjects.Get(resourceID) + updatedInfos := make([]subjectInfo2, 0, len(subjectInfo2s)) + for _, info := range subjectInfo2s { + info.missingContextParameters = append(info.missingContextParameters, additionalMissingContext...) + updatedInfos = append(updatedInfos, info) + } + rsm.resourcesAndSubjects.Set(resourceID, updatedInfos) +} + +// addSubjectIDAsFoundResourceID adds a subject ID directly as a found subject for itself as the resource, +// with no associated caveat. +func (rsm resourcesSubjectMap2) addSubjectIDAsFoundResourceID(subjectID string) { + rsm.resourcesAndSubjects.Add(subjectID, subjectInfo2{subjectID, nil}) +} + +// asReadOnly returns a read-only dispatchableResourcesSubjectMap2 for dispatching for the +// resources in this map (if any). +func (rsm resourcesSubjectMap2) asReadOnly() dispatchableResourcesSubjectMap2 { + return dispatchableResourcesSubjectMap2{rsm} +} + +func (rsm resourcesSubjectMap2) len() int { + return rsm.resourcesAndSubjects.Len() +} + +// dispatchableResourcesSubjectMap2 is a read-only, frozen version of the resourcesSubjectMap2 that +// can be used for mapping conditionals once calls have been dispatched. This is read-only due to +// its use by concurrent callers. +type dispatchableResourcesSubjectMap2 struct { + resourcesSubjectMap2 +} + +func (rsm dispatchableResourcesSubjectMap2) len() int { + return rsm.resourcesAndSubjects.Len() +} + +func (rsm dispatchableResourcesSubjectMap2) isEmpty() bool { + return rsm.resourcesAndSubjects.IsEmpty() +} + +func (rsm dispatchableResourcesSubjectMap2) resourceIDs() []string { + return rsm.resourcesAndSubjects.Keys() +} + +// filterSubjectIDsToDispatch returns the set of subject IDs that have not yet been +// dispatched, by adding them to the dispatched set. +func (rsm dispatchableResourcesSubjectMap2) filterSubjectIDsToDispatch(dispatched *syncONRSet, dispatchSubjectType *core.RelationReference) []string { + resourceIDs := rsm.resourceIDs() + filtered := make([]string, 0, len(resourceIDs)) + for _, resourceID := range resourceIDs { + if dispatched.Add(&core.ObjectAndRelation{ + Namespace: dispatchSubjectType.Namespace, + ObjectId: resourceID, + Relation: dispatchSubjectType.Relation, + }) { + filtered = append(filtered, resourceID) + } + } + + return filtered +} + +// cloneAsMutable returns a mutable clone of this dispatchableResourcesSubjectMap2. +func (rsm dispatchableResourcesSubjectMap2) cloneAsMutable() resourcesSubjectMap2 { + return resourcesSubjectMap2{ + resourceType: rsm.resourceType, + resourcesAndSubjects: rsm.resourcesAndSubjects.Clone(), + } +} + +func (rsm dispatchableResourcesSubjectMap2) asPossibleResources() []*v1.PossibleResource { + resources := make([]*v1.PossibleResource, 0, rsm.resourcesAndSubjects.Len()) + + // Sort for stability. + sortedResourceIds := rsm.resourcesAndSubjects.Keys() + sort.Strings(sortedResourceIds) + + for _, resourceID := range sortedResourceIds { + subjectInfo2s, _ := rsm.resourcesAndSubjects.Get(resourceID) + subjectIDs := make([]string, 0, len(subjectInfo2s)) + allCaveated := true + nonCaveatedSubjectIDs := make([]string, 0, len(subjectInfo2s)) + missingContextParameters := mapz.NewSet[string]() + + for _, info := range subjectInfo2s { + subjectIDs = append(subjectIDs, info.subjectID) + if len(info.missingContextParameters) == 0 { + allCaveated = false + nonCaveatedSubjectIDs = append(nonCaveatedSubjectIDs, info.subjectID) + } else { + missingContextParameters.Extend(info.missingContextParameters) + } + } + + // Sort for stability. + sort.Strings(subjectIDs) + + // If all the incoming edges are caveated, then the entire status has to be marked as a check + // is required. Otherwise, if there is at least *one* non-caveated incoming edge, then we can + // return the existing status as a short-circuit for those non-caveated found subjects. + if allCaveated { + resources = append(resources, &v1.PossibleResource{ + ResourceId: resourceID, + ForSubjectIds: subjectIDs, + MissingContextParams: missingContextParameters.AsSlice(), + }) + } else { + resources = append(resources, &v1.PossibleResource{ + ResourceId: resourceID, + ForSubjectIds: nonCaveatedSubjectIDs, + }) + } + } + return resources +} + +func (rsm dispatchableResourcesSubjectMap2) mapPossibleResource(foundResource *v1.PossibleResource) (*v1.PossibleResource, error) { + forSubjectIDs := mapz.NewSet[string]() + nonCaveatedSubjectIDs := mapz.NewSet[string]() + missingContextParameters := mapz.NewSet[string]() + + for _, forSubjectID := range foundResource.ForSubjectIds { + // Map from the incoming subject ID to the subject ID(s) that caused the dispatch. + infos, ok := rsm.resourcesAndSubjects.Get(forSubjectID) + if !ok { + return nil, spiceerrors.MustBugf("missing for subject ID") + } + + for _, info := range infos { + forSubjectIDs.Insert(info.subjectID) + if len(info.missingContextParameters) == 0 { + nonCaveatedSubjectIDs.Insert(info.subjectID) + } else { + missingContextParameters.Extend(info.missingContextParameters) + } + } + } + + // If there are some non-caveated IDs, return those and mark as the parent status. + if nonCaveatedSubjectIDs.Len() > 0 { + return &v1.PossibleResource{ + ResourceId: foundResource.ResourceId, + ForSubjectIds: nonCaveatedSubjectIDs.AsSlice(), + }, nil + } + + // Otherwise, everything is caveated, so return the full set of subject IDs and mark + // as a check is required. + return &v1.PossibleResource{ + ResourceId: foundResource.ResourceId, + ForSubjectIds: forSubjectIDs.AsSlice(), + MissingContextParams: missingContextParameters.AsSlice(), + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/graph/traceid.go b/vendor/github.com/authzed/spicedb/internal/graph/traceid.go new file mode 100644 index 0000000..e275bc5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/graph/traceid.go @@ -0,0 +1,13 @@ +package graph + +import ( + "github.com/google/uuid" +) + +// NewTraceID generates a new trace ID. The trace IDs will only be unique with +// a single dispatch request tree and should not be used for any other purpose. +// This function currently uses the UUID library to generate a new trace ID, +// which means it should not be invoked from performance-critical code paths. +func NewTraceID() string { + return uuid.NewString() +} |
