diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/spicedb/internal/services | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal/services')
22 files changed, 6855 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go b/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go new file mode 100644 index 0000000..05b3907 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go @@ -0,0 +1,208 @@ +package shared + +import ( + "context" + "errors" + "fmt" + "strconv" + + "github.com/rs/zerolog" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph" + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/sharederrors" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// ErrServiceReadOnly is an extended GRPC error returned when a service is in read-only mode. +var ErrServiceReadOnly = mustMakeStatusReadonly() + +func mustMakeStatusReadonly() error { + status, err := status.New(codes.Unavailable, "service read-only").WithDetails(&errdetails.ErrorInfo{ + Reason: v1.ErrorReason_name[int32(v1.ErrorReason_ERROR_REASON_SERVICE_READ_ONLY)], + Domain: spiceerrors.Domain, + }) + if err != nil { + panic("error constructing shared error type") + } + return status.Err() +} + +// NewSchemaWriteDataValidationError creates a new error representing that a schema write cannot be +// completed due to existing data that would be left unreferenced. +func NewSchemaWriteDataValidationError(message string, args ...any) SchemaWriteDataValidationError { + return SchemaWriteDataValidationError{ + error: fmt.Errorf(message, args...), + } +} + +// SchemaWriteDataValidationError occurs when a schema cannot be applied due to leaving data unreferenced. +type SchemaWriteDataValidationError struct { + error +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err SchemaWriteDataValidationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err SchemaWriteDataValidationError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_SCHEMA_TYPE_ERROR, + map[string]string{}, + ), + ) +} + +// MaxDepthExceededError is an error returned when the maximum depth for dispatching has been exceeded. +type MaxDepthExceededError struct { + *spiceerrors.WithAdditionalDetailsError + + // AllowedMaximumDepth is the configured allowed maximum depth. + AllowedMaximumDepth uint32 +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err MaxDepthExceededError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.ResourceExhausted, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_MAXIMUM_DEPTH_EXCEEDED, + err.AddToDetails(map[string]string{ + "maximum_depth_allowed": strconv.Itoa(int(err.AllowedMaximumDepth)), + }), + ), + ) +} + +// NewMaxDepthExceededError creates a new MaxDepthExceededError. +func NewMaxDepthExceededError(allowedMaximumDepth uint32, isCheckRequest bool) error { + if isCheckRequest { + return MaxDepthExceededError{ + spiceerrors.NewWithAdditionalDetailsError(fmt.Errorf("the check request has exceeded the allowable maximum depth of %d: this usually indicates a recursive or too deep data dependency. Try running zed with --explain to see the dependency. See: https://spicedb.dev/d/debug-max-depth-check", allowedMaximumDepth)), + allowedMaximumDepth, + } + } + + return MaxDepthExceededError{ + spiceerrors.NewWithAdditionalDetailsError(fmt.Errorf("the request has exceeded the allowable maximum depth of %d: this usually indicates a recursive or too deep data dependency. See: https://spicedb.dev/d/debug-max-depth", allowedMaximumDepth)), + allowedMaximumDepth, + } +} + +func AsValidationError(err error) *SchemaWriteDataValidationError { + var validationErr SchemaWriteDataValidationError + if errors.As(err, &validationErr) { + return &validationErr + } + return nil +} + +type ConfigForErrors struct { + MaximumAPIDepth uint32 + DebugTrace *v1.DebugInformation +} + +func RewriteErrorWithoutConfig(ctx context.Context, err error) error { + return rewriteError(ctx, err, nil) +} + +func RewriteError(ctx context.Context, err error, config *ConfigForErrors) error { + rerr := rewriteError(ctx, err, config) + if config != nil && config.DebugTrace != nil { + spiceerrors.WithAdditionalDetails(rerr, spiceerrors.DebugTraceErrorDetailsKey, config.DebugTrace.String()) + } + return rerr +} + +func rewriteError(ctx context.Context, err error, config *ConfigForErrors) error { + // Check if the error can be directly used. + if _, ok := status.FromError(err); ok { + return err + } + + // Otherwise, convert any graph/datastore errors. + var nsNotFoundError sharederrors.UnknownNamespaceError + var relationNotFoundError sharederrors.UnknownRelationError + + var compilerError compiler.BaseCompilerError + var sourceError spiceerrors.WithSourceError + var typeError schema.TypeError + var maxDepthError dispatch.MaxDepthExceededError + + switch { + case errors.As(err, &typeError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_SCHEMA_TYPE_ERROR) + case errors.As(err, &compilerError): + return spiceerrors.WithCodeAndReason(err, codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_SCHEMA_PARSE_ERROR) + case errors.As(err, &sourceError): + return spiceerrors.WithCodeAndReason(err, codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_SCHEMA_PARSE_ERROR) + + case errors.Is(err, cursor.ErrHashMismatch): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_INVALID_CURSOR) + + case errors.As(err, &nsNotFoundError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_DEFINITION) + case errors.As(err, &relationNotFoundError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_RELATION_OR_PERMISSION) + + case errors.As(err, &maxDepthError): + if config == nil { + return spiceerrors.MustBugf("missing config for API error") + } + + _, isCheckRequest := maxDepthError.Request.(*dispatchv1.DispatchCheckRequest) + return NewMaxDepthExceededError(config.MaximumAPIDepth, isCheckRequest) + + case errors.As(err, &datastore.ReadOnlyError{}): + return ErrServiceReadOnly + case errors.As(err, &datastore.InvalidRevisionError{}): + return status.Errorf(codes.OutOfRange, "invalid zedtoken: %s", err) + case errors.As(err, &datastore.CaveatNameNotFoundError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_CAVEAT) + case errors.As(err, &datastore.WatchDisabledError{}): + return status.Errorf(codes.FailedPrecondition, "%s", err) + case errors.As(err, &datastore.CounterAlreadyRegisteredError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_COUNTER_ALREADY_REGISTERED) + case errors.As(err, &datastore.CounterNotRegisteredError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_COUNTER_NOT_REGISTERED) + + case errors.As(err, &graph.RelationMissingTypeInfoError{}): + return status.Errorf(codes.FailedPrecondition, "failed precondition: %s", err) + case errors.As(err, &graph.AlwaysFailError{}): + log.Ctx(ctx).Err(err).Msg("received internal error") + return status.Errorf(codes.Internal, "internal error: %s", err) + case errors.As(err, &graph.UnimplementedError{}): + return status.Errorf(codes.Unimplemented, "%s", err) + case errors.Is(err, context.DeadlineExceeded): + return status.Errorf(codes.DeadlineExceeded, "%s", err) + case errors.Is(err, context.Canceled): + err := context.Cause(ctx) + if err != nil { + if _, ok := status.FromError(err); ok { + return err + } + } + + return status.Errorf(codes.Canceled, "%s", err) + default: + log.Ctx(ctx).Err(err).Msg("received unexpected error") + return err + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go b/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go new file mode 100644 index 0000000..455de0a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go @@ -0,0 +1,52 @@ +package shared + +import ( + "google.golang.org/grpc" + + "github.com/authzed/spicedb/internal/middleware/servicespecific" +) + +// WithUnaryServiceSpecificInterceptor is a helper to add a unary interceptor or interceptor +// chain to a service. +type WithUnaryServiceSpecificInterceptor struct { + Unary grpc.UnaryServerInterceptor +} + +// UnaryInterceptor implements servicespecific.ExtraUnaryInterceptor +func (wussi WithUnaryServiceSpecificInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor { + return wussi.Unary +} + +// WithStreamServiceSpecificInterceptor is a helper to add a stream interceptor or interceptor +// chain to a service. +type WithStreamServiceSpecificInterceptor struct { + Stream grpc.StreamServerInterceptor +} + +// StreamInterceptor implements servicespecific.ExtraStreamInterceptor +func (wsssi WithStreamServiceSpecificInterceptor) StreamInterceptor() grpc.StreamServerInterceptor { + return wsssi.Stream +} + +// WithServiceSpecificInterceptors is a helper to add both a unary and stream interceptor +// or interceptor chain to a service. +type WithServiceSpecificInterceptors struct { + Unary grpc.UnaryServerInterceptor + Stream grpc.StreamServerInterceptor +} + +// UnaryInterceptor implements servicespecific.ExtraUnaryInterceptor +func (wssi WithServiceSpecificInterceptors) UnaryInterceptor() grpc.UnaryServerInterceptor { + return wssi.Unary +} + +// StreamInterceptor implements servicespecific.ExtraStreamInterceptor +func (wssi WithServiceSpecificInterceptors) StreamInterceptor() grpc.StreamServerInterceptor { + return wssi.Stream +} + +var ( + _ servicespecific.ExtraUnaryInterceptor = WithUnaryServiceSpecificInterceptor{} + _ servicespecific.ExtraUnaryInterceptor = WithServiceSpecificInterceptors{} + _ servicespecific.ExtraStreamInterceptor = WithServiceSpecificInterceptors{} +) diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go b/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go new file mode 100644 index 0000000..83accde --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go @@ -0,0 +1,474 @@ +package shared + +import ( + "context" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/namespace" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats" + nsdiff "github.com/authzed/spicedb/pkg/diff/namespace" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ValidatedSchemaChanges is a set of validated schema changes that can be applied to the datastore. +type ValidatedSchemaChanges struct { + compiled *compiler.CompiledSchema + validatedTypeSystems map[string]*schema.ValidatedDefinition + newCaveatDefNames *mapz.Set[string] + newObjectDefNames *mapz.Set[string] + additiveOnly bool +} + +// ValidateSchemaChanges validates the schema found in the compiled schema and returns a +// ValidatedSchemaChanges, if fully validated. +func ValidateSchemaChanges(ctx context.Context, compiled *compiler.CompiledSchema, caveatTypeSet *caveattypes.TypeSet, additiveOnly bool) (*ValidatedSchemaChanges, error) { + // 1) Validate the caveats defined. + newCaveatDefNames := mapz.NewSet[string]() + for _, caveatDef := range compiled.CaveatDefinitions { + if err := namespace.ValidateCaveatDefinition(caveatTypeSet, caveatDef); err != nil { + return nil, err + } + + newCaveatDefNames.Insert(caveatDef.Name) + } + + // 2) Validate the namespaces defined. + newObjectDefNames := mapz.NewSet[string]() + validatedTypeSystems := make(map[string]*schema.ValidatedDefinition, len(compiled.ObjectDefinitions)) + res := schema.ResolverForPredefinedDefinitions(schema.PredefinedElements{ + Definitions: compiled.ObjectDefinitions, + Caveats: compiled.CaveatDefinitions, + }) + ts := schema.NewTypeSystem(res) + + for _, nsdef := range compiled.ObjectDefinitions { + vts, err := ts.GetValidatedDefinition(ctx, nsdef.GetName()) + if err != nil { + return nil, err + } + + validatedTypeSystems[nsdef.Name] = vts + newObjectDefNames.Insert(nsdef.Name) + } + + return &ValidatedSchemaChanges{ + compiled: compiled, + validatedTypeSystems: validatedTypeSystems, + newCaveatDefNames: newCaveatDefNames, + newObjectDefNames: newObjectDefNames, + additiveOnly: additiveOnly, + }, nil +} + +// AppliedSchemaChanges holds information about the applied schema changes. +type AppliedSchemaChanges struct { + // TotalOperationCount holds the total number of "dispatch" operations performed by the schema + // being applied. + TotalOperationCount int + + // NewObjectDefNames contains the names of the newly added object definitions. + NewObjectDefNames []string + + // RemovedObjectDefNames contains the names of the removed object definitions. + RemovedObjectDefNames []string + + // NewCaveatDefNames contains the names of the newly added caveat definitions. + NewCaveatDefNames []string + + // RemovedCaveatDefNames contains the names of the removed caveat definitions. + RemovedCaveatDefNames []string +} + +// ApplySchemaChanges applies schema changes found in the validated changes struct, via the specified +// ReadWriteTransaction. +func ApplySchemaChanges(ctx context.Context, rwt datastore.ReadWriteTransaction, caveatTypeSet *caveattypes.TypeSet, validated *ValidatedSchemaChanges) (*AppliedSchemaChanges, error) { + existingCaveats, err := rwt.ListAllCaveats(ctx) + if err != nil { + return nil, err + } + + existingObjectDefs, err := rwt.ListAllNamespaces(ctx) + if err != nil { + return nil, err + } + + return ApplySchemaChangesOverExisting(ctx, rwt, caveatTypeSet, validated, datastore.DefinitionsOf(existingCaveats), datastore.DefinitionsOf(existingObjectDefs)) +} + +// ApplySchemaChangesOverExisting applies schema changes found in the validated changes struct, against +// existing caveat and object definitions given. +func ApplySchemaChangesOverExisting( + ctx context.Context, + rwt datastore.ReadWriteTransaction, + caveatTypeSet *caveattypes.TypeSet, + validated *ValidatedSchemaChanges, + existingCaveats []*core.CaveatDefinition, + existingObjectDefs []*core.NamespaceDefinition, +) (*AppliedSchemaChanges, error) { + // Build a map of existing caveats to determine those being removed, if any. + existingCaveatDefMap := make(map[string]*core.CaveatDefinition, len(existingCaveats)) + existingCaveatDefNames := mapz.NewSet[string]() + + for _, existingCaveat := range existingCaveats { + existingCaveatDefMap[existingCaveat.Name] = existingCaveat + existingCaveatDefNames.Insert(existingCaveat.Name) + } + + // For each caveat definition, perform a diff and ensure the changes will not result in type errors. + caveatDefsWithChanges := make([]*core.CaveatDefinition, 0, len(validated.compiled.CaveatDefinitions)) + for _, caveatDef := range validated.compiled.CaveatDefinitions { + diff, err := sanityCheckCaveatChanges(ctx, rwt, caveatTypeSet, caveatDef, existingCaveatDefMap) + if err != nil { + return nil, err + } + + if len(diff.Deltas()) > 0 { + caveatDefsWithChanges = append(caveatDefsWithChanges, caveatDef) + } + } + + removedCaveatDefNames := existingCaveatDefNames.Subtract(validated.newCaveatDefNames) + + // Build a map of existing definitions to determine those being removed, if any. + existingObjectDefMap := make(map[string]*core.NamespaceDefinition, len(existingObjectDefs)) + existingObjectDefNames := mapz.NewSet[string]() + for _, existingDef := range existingObjectDefs { + existingObjectDefMap[existingDef.Name] = existingDef + existingObjectDefNames.Insert(existingDef.Name) + } + + // For each definition, perform a diff and ensure the changes will not result in any + // breaking changes. + objectDefsWithChanges := make([]*core.NamespaceDefinition, 0, len(validated.compiled.ObjectDefinitions)) + for _, nsdef := range validated.compiled.ObjectDefinitions { + diff, err := sanityCheckNamespaceChanges(ctx, rwt, nsdef, existingObjectDefMap) + if err != nil { + return nil, err + } + + if len(diff.Deltas()) > 0 { + objectDefsWithChanges = append(objectDefsWithChanges, nsdef) + + vts, ok := validated.validatedTypeSystems[nsdef.Name] + if !ok { + return nil, spiceerrors.MustBugf("validated type system not found for namespace `%s`", nsdef.Name) + } + + if err := namespace.AnnotateNamespace(vts); err != nil { + return nil, err + } + } + } + + log.Ctx(ctx). + Trace(). + Int("objectDefinitions", len(validated.compiled.ObjectDefinitions)). + Int("caveatDefinitions", len(validated.compiled.CaveatDefinitions)). + Int("objectDefsWithChanges", len(objectDefsWithChanges)). + Int("caveatDefsWithChanges", len(caveatDefsWithChanges)). + Msg("validated namespace definitions") + + // Ensure that deleting namespaces will not result in any relationships left without associated + // schema. + removedObjectDefNames := existingObjectDefNames.Subtract(validated.newObjectDefNames) + if !validated.additiveOnly { + if err := removedObjectDefNames.ForEach(func(nsdefName string) error { + return ensureNoRelationshipsExist(ctx, rwt, nsdefName) + }); err != nil { + return nil, err + } + } + + // Write the new/changes caveats. + if len(caveatDefsWithChanges) > 0 { + if err := rwt.WriteCaveats(ctx, caveatDefsWithChanges); err != nil { + return nil, err + } + } + + // Write the new/changed namespaces. + if len(objectDefsWithChanges) > 0 { + if err := rwt.WriteNamespaces(ctx, objectDefsWithChanges...); err != nil { + return nil, err + } + } + + if !validated.additiveOnly { + // Delete the removed namespaces. + if removedObjectDefNames.Len() > 0 { + if err := rwt.DeleteNamespaces(ctx, removedObjectDefNames.AsSlice()...); err != nil { + return nil, err + } + } + + // Delete the removed caveats. + if !removedCaveatDefNames.IsEmpty() { + if err := rwt.DeleteCaveats(ctx, removedCaveatDefNames.AsSlice()); err != nil { + return nil, err + } + } + } + + log.Ctx(ctx).Trace(). + Interface("objectDefinitions", validated.compiled.ObjectDefinitions). + Interface("caveatDefinitions", validated.compiled.CaveatDefinitions). + Object("addedOrChangedObjectDefinitions", validated.newObjectDefNames). + Object("removedObjectDefinitions", removedObjectDefNames). + Object("addedOrChangedCaveatDefinitions", validated.newCaveatDefNames). + Object("removedCaveatDefinitions", removedCaveatDefNames). + Msg("completed schema update") + + return &AppliedSchemaChanges{ + TotalOperationCount: len(validated.compiled.ObjectDefinitions) + len(validated.compiled.CaveatDefinitions) + removedObjectDefNames.Len() + removedCaveatDefNames.Len(), + NewObjectDefNames: validated.newObjectDefNames.Subtract(existingObjectDefNames).AsSlice(), + RemovedObjectDefNames: removedObjectDefNames.AsSlice(), + NewCaveatDefNames: validated.newCaveatDefNames.Subtract(existingCaveatDefNames).AsSlice(), + RemovedCaveatDefNames: removedCaveatDefNames.AsSlice(), + }, nil +} + +// sanityCheckCaveatChanges ensures that a caveat definition being written does not break +// the types of the parameters that may already exist on relationships. +func sanityCheckCaveatChanges( + _ context.Context, + _ datastore.ReadWriteTransaction, + caveatTypeSet *caveattypes.TypeSet, + caveatDef *core.CaveatDefinition, + existingDefs map[string]*core.CaveatDefinition, +) (*caveatdiff.Diff, error) { + // Ensure that the updated namespace does not break the existing tuple data. + existing := existingDefs[caveatDef.Name] + diff, err := caveatdiff.DiffCaveats(existing, caveatDef, caveatTypeSet) + if err != nil { + return nil, err + } + + for _, delta := range diff.Deltas() { + switch delta.Type { + case caveatdiff.RemovedParameter: + return diff, NewSchemaWriteDataValidationError("cannot remove parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name) + + case caveatdiff.ParameterTypeChanged: + return diff, NewSchemaWriteDataValidationError("cannot change the type of parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name) + } + } + + return diff, nil +} + +// ensureNoRelationshipsExist ensures that no relationships exist within the namespace with the given name. +func ensureNoRelationshipsExist(ctx context.Context, rwt datastore.ReadWriteTransaction, namespaceName string) error { + qy, qyErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{OptionalResourceType: namespaceName}, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceOfType), + ) + if err := errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete object definition `%s`, as a relationship exists under it", + namespaceName, + ); err != nil { + return err + } + + qy, qyErr = rwt.ReverseQueryRelationships( + ctx, + datastore.SubjectsFilter{ + SubjectType: namespaceName, + }, + options.WithLimitForReverse(options.LimitOne), + options.WithQueryShapeForReverse(queryshape.FindSubjectOfType), + ) + err := errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete object definition `%s`, as a relationship references it", + namespaceName, + ) + if err != nil { + return err + } + + return nil +} + +// sanityCheckNamespaceChanges ensures that a namespace definition being written does not result +// in breaking changes, such as relationships without associated defined schema object definitions +// and relations. +func sanityCheckNamespaceChanges( + ctx context.Context, + rwt datastore.ReadWriteTransaction, + nsdef *core.NamespaceDefinition, + existingDefs map[string]*core.NamespaceDefinition, +) (*nsdiff.Diff, error) { + // Ensure that the updated namespace does not break the existing tuple data. + existing := existingDefs[nsdef.Name] + diff, err := nsdiff.DiffNamespaces(existing, nsdef) + if err != nil { + return nil, err + } + + for _, delta := range diff.Deltas() { + switch delta.Type { + case nsdiff.RemovedRelation: + // NOTE: We add the subject filters here to ensure the reverse relationship index is used + // by the datastores. As there is no index that has {namespace, relation} directly, but there + // *is* an index that has {subject_namespace, subject_relation, namespace, relation}, we can + // force the datastore to use the reverse index by adding the subject filters. + var previousRelation *core.Relation + for _, relation := range existing.Relation { + if relation.Name == delta.RelationName { + previousRelation = relation + break + } + } + + if previousRelation == nil { + return nil, spiceerrors.MustBugf("relation `%s` not found in existing namespace definition", delta.RelationName) + } + + subjectSelectors := make([]datastore.SubjectsSelector, 0, len(previousRelation.TypeInformation.AllowedDirectRelations)) + for _, allowedType := range previousRelation.TypeInformation.AllowedDirectRelations { + if allowedType.GetRelation() == datastore.Ellipsis { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: allowedType.Namespace, + RelationFilter: datastore.SubjectRelationFilter{ + IncludeEllipsisRelation: true, + }, + }) + } else { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: allowedType.Namespace, + RelationFilter: datastore.SubjectRelationFilter{ + NonEllipsisRelation: allowedType.GetRelation(), + }, + }) + } + } + + qy, qyErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{ + OptionalResourceType: nsdef.Name, + OptionalResourceRelation: delta.RelationName, + OptionalSubjectsSelectors: subjectSelectors, + }, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceOfTypeAndRelation), + ) + + err = errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete relation `%s` in object definition `%s`, as a relationship exists under it", delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + + // Also check for right sides of tuples. + qy, qyErr = rwt.ReverseQueryRelationships( + ctx, + datastore.SubjectsFilter{ + SubjectType: nsdef.Name, + RelationFilter: datastore.SubjectRelationFilter{ + NonEllipsisRelation: delta.RelationName, + }, + }, + options.WithLimitForReverse(options.LimitOne), + options.WithQueryShapeForReverse(queryshape.FindSubjectOfTypeAndRelation), + ) + err = errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete relation `%s` in object definition `%s`, as a relationship references it", delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + + case nsdiff.RelationAllowedTypeRemoved: + var optionalSubjectIds []string + var relationFilter datastore.SubjectRelationFilter + var optionalCaveatNameFilter datastore.CaveatNameFilter + + if delta.AllowedType.GetPublicWildcard() != nil { + optionalSubjectIds = []string{tuple.PublicWildcard} + } else { + relationFilter = datastore.SubjectRelationFilter{ + NonEllipsisRelation: delta.AllowedType.GetRelation(), + } + } + + if delta.AllowedType.GetRequiredCaveat() != nil && delta.AllowedType.GetRequiredCaveat().CaveatName != "" { + optionalCaveatNameFilter = datastore.WithCaveatName(delta.AllowedType.GetRequiredCaveat().CaveatName) + } else { + optionalCaveatNameFilter = datastore.WithNoCaveat() + } + + expirationOption := datastore.ExpirationFilterOptionNoExpiration + if delta.AllowedType.RequiredExpiration != nil { + expirationOption = datastore.ExpirationFilterOptionHasExpiration + } + + qyr, qyrErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{ + OptionalResourceType: nsdef.Name, + OptionalResourceRelation: delta.RelationName, + OptionalSubjectsSelectors: []datastore.SubjectsSelector{ + { + OptionalSubjectType: delta.AllowedType.Namespace, + OptionalSubjectIds: optionalSubjectIds, + RelationFilter: relationFilter, + }, + }, + OptionalCaveatNameFilter: optionalCaveatNameFilter, + OptionalExpirationOption: expirationOption, + }, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceRelationForSubjectRelation), + ) + err = errorIfTupleIteratorReturnsTuples( + ctx, + qyr, + qyrErr, + "cannot remove allowed type `%s` from relation `%s` in object definition `%s`, as a relationship exists with it", + schema.SourceForAllowedRelation(delta.AllowedType), delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + } + } + return diff, nil +} + +// errorIfTupleIteratorReturnsTuples takes a tuple iterator and any error that was generated +// when the original iterator was created, and returns an error if iterator contains any tuples. +func errorIfTupleIteratorReturnsTuples(_ context.Context, qy datastore.RelationshipIterator, qyErr error, message string, args ...interface{}) error { + if qyErr != nil { + return qyErr + } + + for _, err := range qy { + if err != nil { + return err + } + return NewSchemaWriteDataValidationError(message, args...) + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go b/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go new file mode 100644 index 0000000..819452e --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/bulkcheck.go @@ -0,0 +1,332 @@ +package v1 + +import ( + "context" + "slices" + "sync" + "time" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/jzelinskie/stringz" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph" + "github.com/authzed/spicedb/internal/graph/computed" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/internal/taskrunner" + "github.com/authzed/spicedb/internal/telemetry" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/genutil/slicez" + "github.com/authzed/spicedb/pkg/middleware/consistency" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// bulkChecker contains the logic to allow ExperimentalService/BulkCheckPermission and +// PermissionsService/CheckBulkPermissions to share the same implementation. +type bulkChecker struct { + maxAPIDepth uint32 + maxCaveatContextSize int + maxConcurrency uint16 + caveatTypeSet *caveattypes.TypeSet + + dispatch dispatch.Dispatcher + dispatchChunkSize uint16 +} + +const maxBulkCheckCount = 10000 + +func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) { + telemetry.RecordLogicalChecks(uint64(len(req.Items))) + + atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, err + } + + if len(req.Items) > maxBulkCheckCount { + return nil, NewExceedsMaximumChecksErr(uint64(len(req.Items)), maxBulkCheckCount) + } + + // Compute a hash for each requested item and record its index(es) for the items, to be used for sorting of results. + itemCount, err := genutil.EnsureUInt32(len(req.Items)) + if err != nil { + return nil, err + } + + itemIndexByHash := mapz.NewMultiMapWithCap[string, int](itemCount) + for index, item := range req.Items { + itemHash, err := computeCheckBulkPermissionsItemHash(item) + if err != nil { + return nil, err + } + + itemIndexByHash.Add(itemHash, index) + } + + // Identify checks with same permission+subject over different resources and group them. This is doable because + // the dispatching system already internally supports this kind of batching for performance. + groupedItems, err := groupItems(ctx, groupingParameters{ + atRevision: atRevision, + maxCaveatContextSize: bc.maxCaveatContextSize, + maximumAPIDepth: bc.maxAPIDepth, + withTracing: req.WithTracing, + }, req.Items) + if err != nil { + return nil, err + } + + bulkResponseMutex := sync.Mutex{} + + spiceerrors.DebugAssert(func() bool { + return bc.maxConcurrency > 0 + }, "max concurrency must be greater than 0 in bulk check") + + tr := taskrunner.NewPreloadedTaskRunner(ctx, bc.maxConcurrency, len(groupedItems)) + + respMetadata := &dispatchv1.ResponseMeta{ + DispatchCount: 1, + CachedDispatchCount: 0, + DepthRequired: 1, + DebugInfo: nil, + } + usagemetrics.SetInContext(ctx, respMetadata) + + orderedPairs := make([]*v1.CheckBulkPermissionsPair, len(req.Items)) + + addPair := func(pair *v1.CheckBulkPermissionsPair) error { + pairItemHash, err := computeCheckBulkPermissionsItemHash(pair.Request) + if err != nil { + return err + } + + found, ok := itemIndexByHash.Get(pairItemHash) + if !ok { + return spiceerrors.MustBugf("missing expected item hash") + } + + for _, index := range found { + orderedPairs[index] = pair + } + + return nil + } + + appendResultsForError := func(params *computed.CheckParameters, resourceIDs []string, err error) error { + rewritten := shared.RewriteError(ctx, err, &shared.ConfigForErrors{ + MaximumAPIDepth: bc.maxAPIDepth, + }) + statusResp, ok := status.FromError(rewritten) + if !ok { + // If error is not a gRPC Status, fail the entire bulk check request. + return err + } + + bulkResponseMutex.Lock() + defer bulkResponseMutex.Unlock() + + for _, resourceID := range resourceIDs { + reqItem, err := requestItemFromResourceAndParameters(params, resourceID) + if err != nil { + return err + } + + if err := addPair(&v1.CheckBulkPermissionsPair{ + Request: reqItem, + Response: &v1.CheckBulkPermissionsPair_Error{ + Error: statusResp.Proto(), + }, + }); err != nil { + return err + } + } + + return nil + } + + appendResultsForCheck := func( + params *computed.CheckParameters, + resourceIDs []string, + metadata *dispatchv1.ResponseMeta, + debugInfos []*dispatchv1.DebugInformation, + results map[string]*dispatchv1.ResourceCheckResult, + ) error { + bulkResponseMutex.Lock() + defer bulkResponseMutex.Unlock() + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + schemaText := "" + if len(debugInfos) > 0 { + schema, err := getFullSchema(ctx, ds) + if err != nil { + return err + } + schemaText = schema + } + + for _, resourceID := range resourceIDs { + var debugTrace *v1.DebugInformation + if len(debugInfos) > 0 { + // Find the debug info that matches the resource ID. + var debugInfo *dispatchv1.DebugInformation + for _, di := range debugInfos { + if slices.Contains(di.Check.Request.ResourceIds, resourceID) { + debugInfo = di + break + } + } + + if debugInfo != nil { + // Synthesize a new debug information with a trace "wrapping" the (potentially batched) + // trace. + localResults := make(map[string]*dispatchv1.ResourceCheckResult, 1) + if result, ok := results[resourceID]; ok { + localResults[resourceID] = result + } + wrappedDebugInfo := &dispatchv1.DebugInformation{ + Check: &dispatchv1.CheckDebugTrace{ + Request: &dispatchv1.DispatchCheckRequest{ + ResourceRelation: debugInfo.Check.Request.ResourceRelation, + ResourceIds: []string{resourceID}, + Subject: debugInfo.Check.Request.Subject, + ResultsSetting: debugInfo.Check.Request.ResultsSetting, + Debug: debugInfo.Check.Request.Debug, + }, + ResourceRelationType: debugInfo.Check.ResourceRelationType, + IsCachedResult: false, + SubProblems: []*dispatchv1.CheckDebugTrace{ + debugInfo.Check, + }, + Results: localResults, + Duration: durationpb.New(time.Duration(0)), + TraceId: graph.NewTraceID(), + SourceId: debugInfo.Check.SourceId, + }, + } + + // Convert to debug information. + dt, err := convertCheckDispatchDebugInformationWithSchema(ctx, params.CaveatContext, wrappedDebugInfo, ds, bc.caveatTypeSet, schemaText) + if err != nil { + return err + } + debugTrace = dt + } + } + + reqItem, err := requestItemFromResourceAndParameters(params, resourceID) + if err != nil { + return err + } + + if err := addPair(&v1.CheckBulkPermissionsPair{ + Request: reqItem, + Response: pairItemFromCheckResult(results[resourceID], debugTrace), + }); err != nil { + return err + } + } + + respMetadata.DispatchCount += metadata.DispatchCount + respMetadata.CachedDispatchCount += metadata.CachedDispatchCount + return nil + } + + for _, group := range groupedItems { + group := group + + slicez.ForEachChunk(group.resourceIDs, bc.dispatchChunkSize, func(resourceIDs []string) { + tr.Add(func(ctx context.Context) error { + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + // Ensure the check namespaces and relations are valid. + err := namespace.CheckNamespaceAndRelations(ctx, + []namespace.TypeAndRelationToCheck{ + { + NamespaceName: group.params.ResourceType.ObjectType, + RelationName: group.params.ResourceType.Relation, + AllowEllipsis: false, + }, + { + NamespaceName: group.params.Subject.ObjectType, + RelationName: stringz.DefaultEmpty(group.params.Subject.Relation, graph.Ellipsis), + AllowEllipsis: true, + }, + }, ds) + if err != nil { + return appendResultsForError(group.params, resourceIDs, err) + } + + // Call bulk check to compute the check result(s) for the resource ID(s). + rcr, metadata, debugInfos, err := computed.ComputeBulkCheck(ctx, bc.dispatch, bc.caveatTypeSet, *group.params, resourceIDs, bc.dispatchChunkSize) + if err != nil { + return appendResultsForError(group.params, resourceIDs, err) + } + + return appendResultsForCheck(group.params, resourceIDs, metadata, debugInfos, rcr) + }) + }) + } + + // Run the checks in parallel. + if err := tr.StartAndWait(); err != nil { + return nil, err + } + + return &v1.CheckBulkPermissionsResponse{CheckedAt: checkedAt, Pairs: orderedPairs}, nil +} + +func toCheckBulkPermissionsRequest(req *v1.BulkCheckPermissionRequest) *v1.CheckBulkPermissionsRequest { + items := make([]*v1.CheckBulkPermissionsRequestItem, len(req.Items)) + for i, item := range req.Items { + items[i] = &v1.CheckBulkPermissionsRequestItem{ + Resource: item.Resource, + Permission: item.Permission, + Subject: item.Subject, + Context: item.Context, + } + } + + return &v1.CheckBulkPermissionsRequest{Items: items} +} + +func toBulkCheckPermissionResponse(resp *v1.CheckBulkPermissionsResponse) *v1.BulkCheckPermissionResponse { + pairs := make([]*v1.BulkCheckPermissionPair, len(resp.Pairs)) + for i, pair := range resp.Pairs { + pairs[i] = &v1.BulkCheckPermissionPair{} + pairs[i].Request = &v1.BulkCheckPermissionRequestItem{ + Resource: pair.Request.Resource, + Permission: pair.Request.Permission, + Subject: pair.Request.Subject, + Context: pair.Request.Context, + } + + switch t := pair.Response.(type) { + case *v1.CheckBulkPermissionsPair_Item: + pairs[i].Response = &v1.BulkCheckPermissionPair_Item{ + Item: &v1.BulkCheckPermissionResponseItem{ + Permissionship: t.Item.Permissionship, + PartialCaveatInfo: t.Item.PartialCaveatInfo, + }, + } + case *v1.CheckBulkPermissionsPair_Error: + pairs[i].Response = &v1.BulkCheckPermissionPair_Error{ + Error: t.Error, + } + default: + panic("unknown CheckBulkPermissionResponse pair response type") + } + } + + return &v1.BulkCheckPermissionResponse{ + CheckedAt: resp.CheckedAt, + Pairs: pairs, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go b/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go new file mode 100644 index 0000000..712f9ec --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/debug.go @@ -0,0 +1,238 @@ +package v1 + +import ( + "cmp" + "context" + "slices" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + cexpr "github.com/authzed/spicedb/internal/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ConvertCheckDispatchDebugInformation converts dispatch debug information found in the response metadata +// into DebugInformation returnable to the API. +func ConvertCheckDispatchDebugInformation( + ctx context.Context, + caveatTypeSet *caveattypes.TypeSet, + caveatContext map[string]any, + debugInfo *dispatch.DebugInformation, + reader datastore.Reader, +) (*v1.DebugInformation, error) { + if debugInfo == nil { + return nil, nil + } + + schema, err := getFullSchema(ctx, reader) + if err != nil { + return nil, err + } + + return convertCheckDispatchDebugInformationWithSchema(ctx, caveatContext, debugInfo, reader, caveatTypeSet, schema) +} + +// getFullSchema returns the full schema from the reader. +func getFullSchema(ctx context.Context, reader datastore.Reader) (string, error) { + caveats, err := reader.ListAllCaveats(ctx) + if err != nil { + return "", err + } + + namespaces, err := reader.ListAllNamespaces(ctx) + if err != nil { + return "", err + } + + defs := make([]compiler.SchemaDefinition, 0, len(namespaces)+len(caveats)) + for _, caveat := range caveats { + defs = append(defs, caveat.Definition) + } + for _, ns := range namespaces { + defs = append(defs, ns.Definition) + } + + schema, _, err := generator.GenerateSchema(defs) + if err != nil { + return "", err + } + + return schema, nil +} + +func convertCheckDispatchDebugInformationWithSchema( + ctx context.Context, + caveatContext map[string]any, + debugInfo *dispatch.DebugInformation, + reader datastore.Reader, + caveatTypeSet *caveattypes.TypeSet, + schema string, +) (*v1.DebugInformation, error) { + converted, err := convertCheckTrace(ctx, caveatContext, debugInfo.Check, reader, caveatTypeSet) + if err != nil { + return nil, err + } + + return &v1.DebugInformation{ + Check: converted, + SchemaUsed: strings.TrimSpace(schema), + }, nil +} + +func convertCheckTrace(ctx context.Context, caveatContext map[string]any, ct *dispatch.CheckDebugTrace, reader datastore.Reader, caveatTypeSet *caveattypes.TypeSet) (*v1.CheckDebugTrace, error) { + permissionType := v1.CheckDebugTrace_PERMISSION_TYPE_UNSPECIFIED + if ct.ResourceRelationType == dispatch.CheckDebugTrace_PERMISSION { + permissionType = v1.CheckDebugTrace_PERMISSION_TYPE_PERMISSION + } else if ct.ResourceRelationType == dispatch.CheckDebugTrace_RELATION { + permissionType = v1.CheckDebugTrace_PERMISSION_TYPE_RELATION + } + + subRelation := ct.Request.Subject.Relation + if subRelation == tuple.Ellipsis { + subRelation = "" + } + + permissionship := v1.CheckDebugTrace_PERMISSIONSHIP_NO_PERMISSION + var partialResults []*dispatch.ResourceCheckResult + for _, checkResult := range ct.Results { + if checkResult.Membership == dispatch.ResourceCheckResult_MEMBER { + permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_HAS_PERMISSION + break + } + + if checkResult.Membership == dispatch.ResourceCheckResult_CAVEATED_MEMBER && permissionship != v1.CheckDebugTrace_PERMISSIONSHIP_HAS_PERMISSION { + permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION + partialResults = append(partialResults, checkResult) + } + } + + var caveatEvalInfo *v1.CaveatEvalInfo + + // NOTE: Bulk check gives the *fully resolved* results, rather than the result pre-caveat + // evaluation. In that case, we skip re-evaluating here. + // TODO(jschorr): Add support for evaluating *each* result distinctly. + if permissionship == v1.CheckDebugTrace_PERMISSIONSHIP_CONDITIONAL_PERMISSION && len(partialResults) == 1 && + len(partialResults[0].MissingExprFields) == 0 { + partialCheckResult := partialResults[0] + spiceerrors.DebugAssertNotNil(partialCheckResult.Expression, "got nil caveat expression") + + computedResult, err := cexpr.RunSingleCaveatExpression(ctx, caveatTypeSet, partialCheckResult.Expression, caveatContext, reader, cexpr.RunCaveatExpressionWithDebugInformation) + if err != nil { + return nil, err + } + + var partialCaveatInfo *v1.PartialCaveatInfo + caveatResult := v1.CaveatEvalInfo_RESULT_FALSE + if computedResult.Value() { + caveatResult = v1.CaveatEvalInfo_RESULT_TRUE + } else if computedResult.IsPartial() { + caveatResult = v1.CaveatEvalInfo_RESULT_MISSING_SOME_CONTEXT + missingNames, _ := computedResult.MissingVarNames() + partialCaveatInfo = &v1.PartialCaveatInfo{ + MissingRequiredContext: missingNames, + } + } + + exprString, contextStruct, err := cexpr.BuildDebugInformation(computedResult) + if err != nil { + return nil, err + } + + caveatName := "" + if partialCheckResult.Expression.GetCaveat() != nil { + caveatName = partialCheckResult.Expression.GetCaveat().CaveatName + } + + caveatEvalInfo = &v1.CaveatEvalInfo{ + Expression: exprString, + Result: caveatResult, + Context: contextStruct, + PartialCaveatInfo: partialCaveatInfo, + CaveatName: caveatName, + } + } + + // If there is more than a single result, mark the overall permissionship + // as unspecified if *all* results needed to be true and at least one is not. + if len(ct.Request.ResourceIds) > 1 && ct.Request.ResultsSetting == dispatch.DispatchCheckRequest_REQUIRE_ALL_RESULTS { + for _, resourceID := range ct.Request.ResourceIds { + if result, ok := ct.Results[resourceID]; !ok || result.Membership != dispatch.ResourceCheckResult_MEMBER { + permissionship = v1.CheckDebugTrace_PERMISSIONSHIP_UNSPECIFIED + break + } + } + } + + if len(ct.SubProblems) > 0 { + subProblems := make([]*v1.CheckDebugTrace, 0, len(ct.SubProblems)) + for _, subProblem := range ct.SubProblems { + converted, err := convertCheckTrace(ctx, caveatContext, subProblem, reader, caveatTypeSet) + if err != nil { + return nil, err + } + + subProblems = append(subProblems, converted) + } + + slices.SortFunc(subProblems, func(a, b *v1.CheckDebugTrace) int { + return cmp.Compare(tuple.V1StringObjectRef(a.Resource), tuple.V1StringObjectRef(a.Resource)) + }) + + return &v1.CheckDebugTrace{ + TraceOperationId: ct.TraceId, + Resource: &v1.ObjectReference{ + ObjectType: ct.Request.ResourceRelation.Namespace, + ObjectId: strings.Join(ct.Request.ResourceIds, ","), + }, + Permission: ct.Request.ResourceRelation.Relation, + PermissionType: permissionType, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: ct.Request.Subject.Namespace, + ObjectId: ct.Request.Subject.ObjectId, + }, + OptionalRelation: subRelation, + }, + CaveatEvaluationInfo: caveatEvalInfo, + Result: permissionship, + Resolution: &v1.CheckDebugTrace_SubProblems_{ + SubProblems: &v1.CheckDebugTrace_SubProblems{ + Traces: subProblems, + }, + }, + Duration: ct.Duration, + Source: ct.SourceId, + }, nil + } + + return &v1.CheckDebugTrace{ + TraceOperationId: ct.TraceId, + Resource: &v1.ObjectReference{ + ObjectType: ct.Request.ResourceRelation.Namespace, + ObjectId: strings.Join(ct.Request.ResourceIds, ","), + }, + Permission: ct.Request.ResourceRelation.Relation, + PermissionType: permissionType, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: ct.Request.Subject.Namespace, + ObjectId: ct.Request.Subject.ObjectId, + }, + OptionalRelation: subRelation, + }, + CaveatEvaluationInfo: caveatEvalInfo, + Result: permissionship, + Resolution: &v1.CheckDebugTrace_WasCachedResult{ + WasCachedResult: ct.IsCachedResult, + }, + Duration: ct.Duration, + Source: ct.SourceId, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go b/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go new file mode 100644 index 0000000..6de6749 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/errors.go @@ -0,0 +1,511 @@ +package v1 + +import ( + "fmt" + "strconv" + + "github.com/rs/zerolog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ExceedsMaximumLimitError occurs when a limit that is too large is given to a call. +type ExceedsMaximumLimitError struct { + error + providedLimit uint64 + maxLimitAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ExceedsMaximumLimitError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("providedLimit", err.providedLimit).Uint64("maxLimitAllowed", err.maxLimitAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ExceedsMaximumLimitError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_EXCEEDS_MAXIMUM_ALLOWABLE_LIMIT, + map[string]string{ + "limit_provided": strconv.FormatUint(err.providedLimit, 10), + "maximum_limit_allowed": strconv.FormatUint(err.maxLimitAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumLimitErr creates a new error representing that the limit specified was too large. +func NewExceedsMaximumLimitErr(providedLimit uint64, maxLimitAllowed uint64) ExceedsMaximumLimitError { + return ExceedsMaximumLimitError{ + error: fmt.Errorf("provided limit %d is greater than maximum allowed of %d", providedLimit, maxLimitAllowed), + providedLimit: providedLimit, + maxLimitAllowed: maxLimitAllowed, + } +} + +// ExceedsMaximumChecksError occurs when too many checks are given to a call. +type ExceedsMaximumChecksError struct { + error + checkCount uint64 + maxCountAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ExceedsMaximumChecksError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("checkCount", err.checkCount).Uint64("maxCountAllowed", err.maxCountAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ExceedsMaximumChecksError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{ + "check_count": strconv.FormatUint(err.checkCount, 10), + "maximum_checks_allowed": strconv.FormatUint(err.maxCountAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumChecksErr creates a new error representing that too many updates were given to a BulkCheckPermissions call. +func NewExceedsMaximumChecksErr(checkCount uint64, maxCountAllowed uint64) ExceedsMaximumChecksError { + return ExceedsMaximumChecksError{ + error: fmt.Errorf("check count of %d is greater than maximum allowed of %d", checkCount, maxCountAllowed), + checkCount: checkCount, + maxCountAllowed: maxCountAllowed, + } +} + +// ExceedsMaximumUpdatesError occurs when too many updates are given to a call. +type ExceedsMaximumUpdatesError struct { + error + updateCount uint64 + maxCountAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ExceedsMaximumUpdatesError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("updateCount", err.updateCount).Uint64("maxCountAllowed", err.maxCountAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ExceedsMaximumUpdatesError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_TOO_MANY_UPDATES_IN_REQUEST, + map[string]string{ + "update_count": strconv.FormatUint(err.updateCount, 10), + "maximum_updates_allowed": strconv.FormatUint(err.maxCountAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumUpdatesErr creates a new error representing that too many updates were given to a WriteRelationships call. +func NewExceedsMaximumUpdatesErr(updateCount uint64, maxCountAllowed uint64) ExceedsMaximumUpdatesError { + return ExceedsMaximumUpdatesError{ + error: fmt.Errorf("update count of %d is greater than maximum allowed of %d", updateCount, maxCountAllowed), + updateCount: updateCount, + maxCountAllowed: maxCountAllowed, + } +} + +// ExceedsMaximumPreconditionsError occurs when too many preconditions are given to a call. +type ExceedsMaximumPreconditionsError struct { + error + preconditionCount uint64 + maxCountAllowed uint64 +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err ExceedsMaximumPreconditionsError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Uint64("preconditionCount", err.preconditionCount).Uint64("maxCountAllowed", err.maxCountAllowed) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ExceedsMaximumPreconditionsError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_TOO_MANY_PRECONDITIONS_IN_REQUEST, + map[string]string{ + "precondition_count": strconv.FormatUint(err.preconditionCount, 10), + "maximum_updates_allowed": strconv.FormatUint(err.maxCountAllowed, 10), + }, + ), + ) +} + +// NewExceedsMaximumPreconditionsErr creates a new error representing that too many preconditions were given to a call. +func NewExceedsMaximumPreconditionsErr(preconditionCount uint64, maxCountAllowed uint64) ExceedsMaximumPreconditionsError { + return ExceedsMaximumPreconditionsError{ + error: fmt.Errorf( + "precondition count of %d is greater than maximum allowed of %d", + preconditionCount, + maxCountAllowed), + preconditionCount: preconditionCount, + maxCountAllowed: maxCountAllowed, + } +} + +// PreconditionFailedError occurs when the precondition to a write tuple call does not match. +type PreconditionFailedError struct { + error + precondition *v1.Precondition +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err PreconditionFailedError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Interface("precondition", err.precondition) +} + +// NewPreconditionFailedErr constructs a new precondition failed error. +func NewPreconditionFailedErr(precondition *v1.Precondition) error { + return PreconditionFailedError{ + error: fmt.Errorf("unable to satisfy write precondition `%s`", precondition), + precondition: precondition, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err PreconditionFailedError) GRPCStatus() *status.Status { + metadata := map[string]string{ + "precondition_operation": v1.Precondition_Operation_name[int32(err.precondition.Operation)], + } + + if err.precondition.Filter.ResourceType != "" { + metadata["precondition_resource_type"] = err.precondition.Filter.ResourceType + } + + if err.precondition.Filter.OptionalResourceId != "" { + metadata["precondition_resource_id"] = err.precondition.Filter.OptionalResourceId + } + + if err.precondition.Filter.OptionalResourceIdPrefix != "" { + metadata["precondition_resource_id_prefix"] = err.precondition.Filter.OptionalResourceIdPrefix + } + + if err.precondition.Filter.OptionalRelation != "" { + metadata["precondition_relation"] = err.precondition.Filter.OptionalRelation + } + + if err.precondition.Filter.OptionalSubjectFilter != nil { + metadata["precondition_subject_type"] = err.precondition.Filter.OptionalSubjectFilter.SubjectType + + if err.precondition.Filter.OptionalSubjectFilter.OptionalSubjectId != "" { + metadata["precondition_subject_id"] = err.precondition.Filter.OptionalSubjectFilter.OptionalSubjectId + } + + if err.precondition.Filter.OptionalSubjectFilter.OptionalRelation != nil { + metadata["precondition_subject_relation"] = err.precondition.Filter.OptionalSubjectFilter.OptionalRelation.Relation + } + } + + return spiceerrors.WithCodeAndDetails( + err, + codes.FailedPrecondition, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_WRITE_OR_DELETE_PRECONDITION_FAILURE, + metadata, + ), + ) +} + +// DuplicateRelationErrorshipError indicates that an update was attempted on the same relationship. +type DuplicateRelationErrorshipError struct { + error + update *v1.RelationshipUpdate +} + +// NewDuplicateRelationshipErr constructs a new invalid subject error. +func NewDuplicateRelationshipErr(update *v1.RelationshipUpdate) DuplicateRelationErrorshipError { + return DuplicateRelationErrorshipError{ + error: fmt.Errorf( + "found more than one update with relationship `%s` in this request; a relationship can only be specified in an update once per overall WriteRelationships request", + tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship), + ), + update: update, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err DuplicateRelationErrorshipError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UPDATES_ON_SAME_RELATIONSHIP, + map[string]string{ + "definition_name": err.update.Relationship.Resource.ObjectType, + "relationship": tuple.MustV1StringRelationship(err.update.Relationship), + }, + ), + ) +} + +// ErrMaxRelationshipContextError indicates an attempt to write a relationship that exceeded the maximum +// configured context size. +type ErrMaxRelationshipContextError struct { + error + update *v1.RelationshipUpdate + maxAllowedSize int +} + +// NewMaxRelationshipContextError constructs a new max relationship context error. +func NewMaxRelationshipContextError(update *v1.RelationshipUpdate, maxAllowedSize int) ErrMaxRelationshipContextError { + return ErrMaxRelationshipContextError{ + error: fmt.Errorf( + "provided relationship `%s` exceeded maximum allowed caveat size of %d", + tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship), + maxAllowedSize, + ), + update: update, + maxAllowedSize: maxAllowedSize, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err ErrMaxRelationshipContextError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_MAX_RELATIONSHIP_CONTEXT_SIZE, + map[string]string{ + "relationship": tuple.V1StringRelationshipWithoutCaveatOrExpiration(err.update.Relationship), + "max_allowed_size": strconv.Itoa(err.maxAllowedSize), + "context_size": strconv.Itoa(proto.Size(err.update.Relationship)), + }, + ), + ) +} + +// CouldNotTransactionallyDeleteError indicates that a deletion could not occur transactionally. +type CouldNotTransactionallyDeleteError struct { + error + limit uint32 + filter *v1.RelationshipFilter +} + +// NewCouldNotTransactionallyDeleteErr constructs a new could not transactionally deleter error. +func NewCouldNotTransactionallyDeleteErr(filter *v1.RelationshipFilter, limit uint32) CouldNotTransactionallyDeleteError { + return CouldNotTransactionallyDeleteError{ + error: fmt.Errorf( + "found more than %d relationships to be deleted and partial deletion was not requested", + limit, + ), + limit: limit, + filter: filter, + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err CouldNotTransactionallyDeleteError) GRPCStatus() *status.Status { + metadata := map[string]string{ + "limit": strconv.Itoa(int(err.limit)), + "filter_resource_type": err.filter.ResourceType, + } + + if err.filter.OptionalResourceId != "" { + metadata["filter_resource_id"] = err.filter.OptionalResourceId + } + + if err.filter.OptionalRelation != "" { + metadata["filter_relation"] = err.filter.OptionalRelation + } + + if err.filter.OptionalSubjectFilter != nil { + metadata["filter_subject_type"] = err.filter.OptionalSubjectFilter.SubjectType + + if err.filter.OptionalSubjectFilter.OptionalSubjectId != "" { + metadata["filter_subject_id"] = err.filter.OptionalSubjectFilter.OptionalSubjectId + } + + if err.filter.OptionalSubjectFilter.OptionalRelation != nil { + metadata["filter_subject_relation"] = err.filter.OptionalSubjectFilter.OptionalRelation.Relation + } + } + + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_TOO_MANY_RELATIONSHIPS_FOR_TRANSACTIONAL_DELETE, + metadata, + ), + ) +} + +// InvalidCursorError indicates that an invalid cursor was found. +type InvalidCursorError struct { + error + reason string +} + +// NewInvalidCursorErr constructs a new invalid cursor error. +func NewInvalidCursorErr(reason string) InvalidCursorError { + return InvalidCursorError{ + error: fmt.Errorf( + "the cursor provided is not valid: %s", + reason, + ), + } +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err InvalidCursorError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.FailedPrecondition, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_INVALID_CURSOR, + map[string]string{ + "reason": err.reason, + }, + ), + ) +} + +// InvalidFilterError indicates the specified relationship filter was invalid. +type InvalidFilterError struct { + error + + filter string +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err InvalidFilterError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_INVALID_FILTER, + map[string]string{ + "filter": err.filter, + }, + ), + ) +} + +// NewInvalidFilterErr constructs a new invalid filter error. +func NewInvalidFilterErr(reason string, filter string) InvalidFilterError { + return InvalidFilterError{ + error: fmt.Errorf( + "the relationship filter provided is not valid: %s", reason, + ), + filter: filter, + } +} + +// NewEmptyPreconditionErr constructs a new empty precondition error. +func NewEmptyPreconditionErr() EmptyPreconditionError { + return EmptyPreconditionError{ + error: fmt.Errorf( + "one of the specified preconditions is empty", + ), + } +} + +// EmptyPreconditionError indicates an empty precondition was found. +type EmptyPreconditionError struct { + error +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err EmptyPreconditionError) GRPCStatus() *status.Status { + // TODO(jschorr): Put a proper error reason in here. + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNSPECIFIED, + map[string]string{}, + ), + ) +} + +// NewNotAPermissionError constructs a new not a permission error. +func NewNotAPermissionError(relationName string) NotAPermissionError { + return NotAPermissionError{ + error: fmt.Errorf( + "the relation `%s` is not a permission", relationName, + ), + relationName: relationName, + } +} + +// NotAPermissionError indicates that the relation is not a permission. +type NotAPermissionError struct { + error + relationName string +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err NotAPermissionError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_UNKNOWN_RELATION_OR_PERMISSION, + map[string]string{ + "relationName": err.relationName, + }, + ), + ) +} + +func defaultIfZero[T comparable](value T, defaultValue T) T { + var zero T + if value == zero { + return defaultValue + } + return value +} + +// TransactionMetadataTooLargeError indicates that the metadata for a transaction is too large. +type TransactionMetadataTooLargeError struct { + error + metadataSize int + maxSize int +} + +// NewTransactionMetadataTooLargeErr constructs a new transaction metadata too large error. +func NewTransactionMetadataTooLargeErr(metadataSize int, maxSize int) TransactionMetadataTooLargeError { + return TransactionMetadataTooLargeError{ + error: fmt.Errorf("metadata size of %d is greater than maximum allowed of %d", metadataSize, maxSize), + metadataSize: metadataSize, + maxSize: maxSize, + } +} + +func (err TransactionMetadataTooLargeError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Int("metadataSize", err.metadataSize).Int("maxSize", err.maxSize) +} + +func (err TransactionMetadataTooLargeError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_TRANSACTION_METADATA_TOO_LARGE, + map[string]string{ + "metadata_byte_size": strconv.Itoa(err.metadataSize), + "maximum_allowed_metadata_byte_size": strconv.Itoa(err.maxSize), + }, + ), + ) +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go b/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go new file mode 100644 index 0000000..0e4b4a7 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/experimental.go @@ -0,0 +1,824 @@ +package v1 + +import ( + "context" + "errors" + "fmt" + "io" + "slices" + "sort" + "strings" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/ccoveille/go-safecast" + "github.com/jzelinskie/stringz" + + "github.com/authzed/spicedb/internal/dispatch" + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/middleware" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/handwrittenvalidation" + "github.com/authzed/spicedb/internal/middleware/streamtimeout" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/relationships" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/internal/services/v1/options" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "github.com/samber/lo" +) + +const ( + defaultExportBatchSizeFallback = 1_000 + maxExportBatchSizeFallback = 10_000 + streamReadTimeoutFallbackSeconds = 600 +) + +// NewExperimentalServer creates a ExperimentalServiceServer instance. +func NewExperimentalServer(dispatch dispatch.Dispatcher, permServerConfig PermissionsServerConfig, opts ...options.ExperimentalServerOptionsOption) v1.ExperimentalServiceServer { + config := options.NewExperimentalServerOptionsWithOptionsAndDefaults(opts...) + if config.DefaultExportBatchSize == 0 { + log. + Warn(). + Uint32("specified", config.DefaultExportBatchSize). + Uint32("fallback", defaultExportBatchSizeFallback). + Msg("experimental server config specified invalid DefaultExportBatchSize, setting to fallback") + config.DefaultExportBatchSize = defaultExportBatchSizeFallback + } + if config.MaxExportBatchSize == 0 { + fallback := permServerConfig.MaxBulkExportRelationshipsLimit + if fallback == 0 { + fallback = maxExportBatchSizeFallback + } + + log. + Warn(). + Uint32("specified", config.MaxExportBatchSize). + Uint32("fallback", fallback). + Msg("experimental server config specified invalid MaxExportBatchSize, setting to fallback") + config.MaxExportBatchSize = fallback + } + if config.StreamReadTimeout == 0 { + log. + Warn(). + Stringer("specified", config.StreamReadTimeout). + Stringer("fallback", streamReadTimeoutFallbackSeconds*time.Second). + Msg("experimental server config specified invalid StreamReadTimeout, setting to fallback") + config.StreamReadTimeout = streamReadTimeoutFallbackSeconds * time.Second + } + + chunkSize := permServerConfig.DispatchChunkSize + if chunkSize == 0 { + log. + Warn(). + Msg("experimental server config specified invalid DispatchChunkSize, defaulting to 100") + chunkSize = 100 + } + + return &experimentalServer{ + WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{ + Unary: middleware.ChainUnaryServer( + grpcvalidate.UnaryServerInterceptor(), + handwrittenvalidation.UnaryServerInterceptor, + usagemetrics.UnaryServerInterceptor(), + ), + Stream: middleware.ChainStreamServer( + grpcvalidate.StreamServerInterceptor(), + handwrittenvalidation.StreamServerInterceptor, + usagemetrics.StreamServerInterceptor(), + streamtimeout.MustStreamServerInterceptor(config.StreamReadTimeout), + ), + }, + maxBatchSize: uint64(config.MaxExportBatchSize), + caveatTypeSet: caveattypes.TypeSetOrDefault(permServerConfig.CaveatTypeSet), + bulkChecker: &bulkChecker{ + maxAPIDepth: permServerConfig.MaximumAPIDepth, + maxCaveatContextSize: permServerConfig.MaxCaveatContextSize, + maxConcurrency: config.BulkCheckMaxConcurrency, + dispatch: dispatch, + dispatchChunkSize: chunkSize, + caveatTypeSet: caveattypes.TypeSetOrDefault(permServerConfig.CaveatTypeSet), + }, + } +} + +type experimentalServer struct { + v1.UnimplementedExperimentalServiceServer + shared.WithServiceSpecificInterceptors + + maxBatchSize uint64 + + bulkChecker *bulkChecker + caveatTypeSet *caveattypes.TypeSet +} + +type bulkLoadAdapter struct { + stream v1.ExperimentalService_BulkImportRelationshipsServer + referencedNamespaceMap map[string]*schema.Definition + referencedCaveatMap map[string]*core.CaveatDefinition + current tuple.Relationship + caveat core.ContextualizedCaveat + caveatTypeSet *caveattypes.TypeSet + + awaitingNamespaces []string + awaitingCaveats []string + + currentBatch []*v1.Relationship + numSent int + err error +} + +func (a *bulkLoadAdapter) Next(_ context.Context) (*tuple.Relationship, error) { + for a.err == nil && a.numSent == len(a.currentBatch) { + // Load a new batch + batch, err := a.stream.Recv() + if err != nil { + a.err = err + if errors.Is(a.err, io.EOF) { + return nil, nil + } + return nil, a.err + } + + a.currentBatch = batch.Relationships + a.numSent = 0 + + a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats( + a.currentBatch, + a.referencedNamespaceMap, + a.referencedCaveatMap, + ) + } + + if len(a.awaitingNamespaces) > 0 || len(a.awaitingCaveats) > 0 { + // Shut down the stream to give our caller a chance to fill in this information + return nil, nil + } + + a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType + a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId + a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation + a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType + a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId + a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis) + + if a.currentBatch[a.numSent].OptionalCaveat != nil { + a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName + a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context + a.current.OptionalCaveat = &a.caveat + } else { + a.current.OptionalCaveat = nil + } + + if a.currentBatch[a.numSent].OptionalExpiresAt != nil { + t := a.currentBatch[a.numSent].OptionalExpiresAt.AsTime() + a.current.OptionalExpiration = &t + } else { + a.current.OptionalExpiration = nil + } + + a.current.OptionalIntegrity = nil + + if err := relationships.ValidateOneRelationship( + a.referencedNamespaceMap, + a.referencedCaveatMap, + a.caveatTypeSet, + a.current, + relationships.ValidateRelationshipForCreateOrTouch, + ); err != nil { + return nil, err + } + + a.numSent++ + return &a.current, nil +} + +func extractBatchNewReferencedNamespacesAndCaveats( + batch []*v1.Relationship, + existingNamespaces map[string]*schema.Definition, + existingCaveats map[string]*core.CaveatDefinition, +) ([]string, []string) { + newNamespaces := make(map[string]struct{}, 2) + newCaveats := make(map[string]struct{}, 0) + for _, rel := range batch { + if _, ok := existingNamespaces[rel.Resource.ObjectType]; !ok { + newNamespaces[rel.Resource.ObjectType] = struct{}{} + } + if _, ok := existingNamespaces[rel.Subject.Object.ObjectType]; !ok { + newNamespaces[rel.Subject.Object.ObjectType] = struct{}{} + } + if rel.OptionalCaveat != nil { + if _, ok := existingCaveats[rel.OptionalCaveat.CaveatName]; !ok { + newCaveats[rel.OptionalCaveat.CaveatName] = struct{}{} + } + } + } + + return lo.Keys(newNamespaces), lo.Keys(newCaveats) +} + +// TODO: this is now duplicate code with ImportBulkRelationships +func (es *experimentalServer) BulkImportRelationships(stream v1.ExperimentalService_BulkImportRelationshipsServer) error { + ds := datastoremw.MustFromContext(stream.Context()) + + var numWritten uint64 + if _, err := ds.ReadWriteTx(stream.Context(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + loadedNamespaces := make(map[string]*schema.Definition, 2) + loadedCaveats := make(map[string]*core.CaveatDefinition, 0) + + adapter := &bulkLoadAdapter{ + stream: stream, + referencedNamespaceMap: loadedNamespaces, + referencedCaveatMap: loadedCaveats, + current: tuple.Relationship{}, + caveat: core.ContextualizedCaveat{}, + caveatTypeSet: es.caveatTypeSet, + } + resolver := schema.ResolverForDatastoreReader(rwt) + ts := schema.NewTypeSystem(resolver) + + var streamWritten uint64 + var err error + for ; adapter.err == nil && err == nil; streamWritten, err = rwt.BulkLoad(stream.Context(), adapter) { + numWritten += streamWritten + + // The stream has terminated because we're awaiting namespace and/or caveat information + if len(adapter.awaitingNamespaces) > 0 { + nsDefs, err := rwt.LookupNamespacesWithNames(stream.Context(), adapter.awaitingNamespaces) + if err != nil { + return err + } + + for _, nsDef := range nsDefs { + newDef, err := schema.NewDefinition(ts, nsDef.Definition) + if err != nil { + return err + } + + loadedNamespaces[nsDef.Definition.Name] = newDef + } + adapter.awaitingNamespaces = nil + } + + if len(adapter.awaitingCaveats) > 0 { + caveats, err := rwt.LookupCaveatsWithNames(stream.Context(), adapter.awaitingCaveats) + if err != nil { + return err + } + + for _, caveat := range caveats { + loadedCaveats[caveat.Definition.Name] = caveat.Definition + } + adapter.awaitingCaveats = nil + } + } + numWritten += streamWritten + + return err + }, dsoptions.WithDisableRetries(true)); err != nil { + return shared.RewriteErrorWithoutConfig(stream.Context(), err) + } + + usagemetrics.SetInContext(stream.Context(), &dispatchv1.ResponseMeta{ + // One request for the whole load + DispatchCount: 1, + }) + + return stream.SendAndClose(&v1.BulkImportRelationshipsResponse{ + NumLoaded: numWritten, + }) +} + +// TODO: this is now duplicate code with ExportBulkRelationships +func (es *experimentalServer) BulkExportRelationships( + req *v1.BulkExportRelationshipsRequest, + resp grpc.ServerStreamingServer[v1.BulkExportRelationshipsResponse], +) error { + ctx := resp.Context() + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + return BulkExport(ctx, datastoremw.MustFromContext(ctx), es.maxBatchSize, req, atRevision, resp.Send) +} + +// BulkExport implements the BulkExportRelationships API functionality. Given a datastore.Datastore, it will +// export stream via the sender all relationships matched by the incoming request. +// If no cursor is provided, it will fallback to the provided revision. +func BulkExport(ctx context.Context, ds datastore.ReadOnlyDatastore, batchSize uint64, req *v1.BulkExportRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.BulkExportRelationshipsResponse) error) error { + if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { + return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) + } + + atRevision := fallbackRevision + var curNamespace string + var cur dsoptions.Cursor + if req.OptionalCursor != nil { + var err error + atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + } + + reader := ds.SnapshotReader(atRevision) + + namespaces, err := reader.ListAllNamespaces(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Make sure the namespaces are always in a stable order + slices.SortFunc(namespaces, func( + lhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + rhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + ) int { + return strings.Compare(lhs.Definition.Name, rhs.Definition.Name) + }) + + // Skip the namespaces that are already fully returned + for cur != nil && len(namespaces) > 0 && namespaces[0].Definition.Name < curNamespace { + namespaces = namespaces[1:] + } + + limit := batchSize + if req.OptionalLimit > 0 { + limit = uint64(req.OptionalLimit) + } + + // Pre-allocate all of the relationships that we might need in order to + // make export easier and faster for the garbage collector. + relsArray := make([]v1.Relationship, limit) + objArray := make([]v1.ObjectReference, limit) + subArray := make([]v1.SubjectReference, limit) + subObjArray := make([]v1.ObjectReference, limit) + caveatArray := make([]v1.ContextualizedCaveat, limit) + for i := range relsArray { + relsArray[i].Resource = &objArray[i] + relsArray[i].Subject = &subArray[i] + relsArray[i].Subject.Object = &subObjArray[i] + } + + emptyRels := make([]*v1.Relationship, limit) + for _, ns := range namespaces { + rels := emptyRels + + // Reset the cursor between namespaces. + if ns.Definition.Name != curNamespace { + cur = nil + } + + // Skip this namespace if a resource type filter was specified. + if req.OptionalRelationshipFilter != nil && req.OptionalRelationshipFilter.ResourceType != "" { + if ns.Definition.Name != req.OptionalRelationshipFilter.ResourceType { + continue + } + } + + // Setup the filter to use for the relationships. + relationshipFilter := datastore.RelationshipsFilter{OptionalResourceType: ns.Definition.Name} + if req.OptionalRelationshipFilter != nil { + rf, err := datastore.RelationshipsFilterFromPublicFilter(req.OptionalRelationshipFilter) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Overload the namespace name with the one from the request, because each iteration is for a different namespace. + rf.OptionalResourceType = ns.Definition.Name + relationshipFilter = rf + } + + // We want to keep iterating as long as we're sending full batches. + // To bootstrap this loop, we enter the first time with a full rels + // slice of dummy rels that were never sent. + for uint64(len(rels)) == limit { + // Lop off any rels we've already sent + rels = rels[:0] + + relFn := func(rel tuple.Relationship) { + offset := len(rels) + rels = append(rels, &relsArray[offset]) // nozero + + v1Rel := &relsArray[offset] + v1Rel.Resource.ObjectType = rel.RelationshipReference.Resource.ObjectType + v1Rel.Resource.ObjectId = rel.RelationshipReference.Resource.ObjectID + v1Rel.Relation = rel.RelationshipReference.Resource.Relation + v1Rel.Subject.Object.ObjectType = rel.RelationshipReference.Subject.ObjectType + v1Rel.Subject.Object.ObjectId = rel.RelationshipReference.Subject.ObjectID + v1Rel.Subject.OptionalRelation = denormalizeSubjectRelation(rel.RelationshipReference.Subject.Relation) + + if rel.OptionalCaveat != nil { + caveatArray[offset].CaveatName = rel.OptionalCaveat.CaveatName + caveatArray[offset].Context = rel.OptionalCaveat.Context + v1Rel.OptionalCaveat = &caveatArray[offset] + } else { + v1Rel.OptionalCaveat = nil + } + + if rel.OptionalExpiration != nil { + v1Rel.OptionalExpiresAt = timestamppb.New(*rel.OptionalExpiration) + } else { + v1Rel.OptionalExpiresAt = nil + } + } + + cur, err = queryForEach( + ctx, + reader, + relationshipFilter, + relFn, + dsoptions.WithLimit(&limit), + dsoptions.WithAfter(cur), + dsoptions.WithSort(dsoptions.ByResource), + dsoptions.WithQueryShape(queryshape.Varying), + ) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + if len(rels) == 0 { + continue + } + + encoded, err := cursor.Encode(&implv1.DecodedCursor{ + VersionOneof: &implv1.DecodedCursor_V1{ + V1: &implv1.V1Cursor{ + Revision: atRevision.String(), + Sections: []string{ + ns.Definition.Name, + tuple.MustString(*dsoptions.ToRelationship(cur)), + }, + }, + }, + }) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + if err := sender(&v1.BulkExportRelationshipsResponse{ + AfterResultCursor: encoded, + Relationships: rels, + }); err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + } + } + return nil +} + +func (es *experimentalServer) BulkCheckPermission(ctx context.Context, req *v1.BulkCheckPermissionRequest) (*v1.BulkCheckPermissionResponse, error) { + convertedReq := toCheckBulkPermissionsRequest(req) + res, err := es.bulkChecker.checkBulkPermissions(ctx, convertedReq) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return toBulkCheckPermissionResponse(res), nil +} + +func (es *experimentalServer) ExperimentalReflectSchema(ctx context.Context, req *v1.ExperimentalReflectSchemaRequest) (*v1.ExperimentalReflectSchemaResponse, error) { + // Get the current schema. + schema, atRevision, err := loadCurrentSchema(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + filters, err := newexpSchemaFilters(req.OptionalFilters) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + definitions := make([]*v1.ExpDefinition, 0, len(schema.ObjectDefinitions)) + if filters.HasNamespaces() { + for _, ns := range schema.ObjectDefinitions { + def, err := expNamespaceAPIRepr(ns, filters) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + if def != nil { + definitions = append(definitions, def) + } + } + } + + caveats := make([]*v1.ExpCaveat, 0, len(schema.CaveatDefinitions)) + if filters.HasCaveats() { + for _, cd := range schema.CaveatDefinitions { + caveat, err := expCaveatAPIRepr(cd, filters, es.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + if caveat != nil { + caveats = append(caveats, caveat) + } + } + } + + return &v1.ExperimentalReflectSchemaResponse{ + Definitions: definitions, + Caveats: caveats, + ReadAt: zedtoken.MustNewFromRevision(atRevision), + }, nil +} + +func (es *experimentalServer) ExperimentalDiffSchema(ctx context.Context, req *v1.ExperimentalDiffSchemaRequest) (*v1.ExperimentalDiffSchemaResponse, error) { + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, err + } + + diff, existingSchema, comparisonSchema, err := schemaDiff(ctx, req.ComparisonSchema, es.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + resp, err := expConvertDiff(diff, existingSchema, comparisonSchema, atRevision, es.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return resp, nil +} + +func (es *experimentalServer) ExperimentalComputablePermissions(ctx context.Context, req *v1.ExperimentalComputablePermissionsRequest) (*v1.ExperimentalComputablePermissionsResponse, error) { + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) + vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relationName := req.RelationName + if relationName == "" { + relationName = tuple.Ellipsis + } else { + if _, ok := vdef.GetRelation(relationName); !ok { + return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, relationName)) + } + } + + allNamespaces, err := ds.ListAllNamespaces(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + allDefinitions := make([]*core.NamespaceDefinition, 0, len(allNamespaces)) + for _, ns := range allNamespaces { + allDefinitions = append(allDefinitions, ns.Definition) + } + + rg := vdef.Reachability() + rr, err := rg.RelationsEncounteredForSubject(ctx, allDefinitions, &core.RelationReference{ + Namespace: req.DefinitionName, + Relation: relationName, + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations := make([]*v1.ExpRelationReference, 0, len(rr)) + for _, r := range rr { + if r.Namespace == req.DefinitionName && r.Relation == req.RelationName { + continue + } + + if req.OptionalDefinitionNameFilter != "" && !strings.HasPrefix(r.Namespace, req.OptionalDefinitionNameFilter) { + continue + } + + def, err := ts.GetValidatedDefinition(ctx, r.Namespace) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations = append(relations, &v1.ExpRelationReference{ + DefinitionName: r.Namespace, + RelationName: r.Relation, + IsPermission: def.IsPermission(r.Relation), + }) + } + + sort.Slice(relations, func(i, j int) bool { + if relations[i].DefinitionName == relations[j].DefinitionName { + return relations[i].RelationName < relations[j].RelationName + } + return relations[i].DefinitionName < relations[j].DefinitionName + }) + + return &v1.ExperimentalComputablePermissionsResponse{ + Permissions: relations, + ReadAt: revisionReadAt, + }, nil +} + +func (es *experimentalServer) ExperimentalDependentRelations(ctx context.Context, req *v1.ExperimentalDependentRelationsRequest) (*v1.ExperimentalDependentRelationsResponse, error) { + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) + vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + _, ok := vdef.GetRelation(req.PermissionName) + if !ok { + return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, req.PermissionName)) + } + + if !vdef.IsPermission(req.PermissionName) { + return nil, shared.RewriteErrorWithoutConfig(ctx, NewNotAPermissionError(req.PermissionName)) + } + + rg := vdef.Reachability() + rr, err := rg.RelationsEncounteredForResource(ctx, &core.RelationReference{ + Namespace: req.DefinitionName, + Relation: req.PermissionName, + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations := make([]*v1.ExpRelationReference, 0, len(rr)) + for _, r := range rr { + if r.Namespace == req.DefinitionName && r.Relation == req.PermissionName { + continue + } + + ts, err := ts.GetDefinition(ctx, r.Namespace) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations = append(relations, &v1.ExpRelationReference{ + DefinitionName: r.Namespace, + RelationName: r.Relation, + IsPermission: ts.IsPermission(r.Relation), + }) + } + + sort.Slice(relations, func(i, j int) bool { + if relations[i].DefinitionName == relations[j].DefinitionName { + return relations[i].RelationName < relations[j].RelationName + } + + return relations[i].DefinitionName < relations[j].DefinitionName + }) + + return &v1.ExperimentalDependentRelationsResponse{ + Relations: relations, + ReadAt: revisionReadAt, + }, nil +} + +func (es *experimentalServer) ExperimentalRegisterRelationshipCounter(ctx context.Context, req *v1.ExperimentalRegisterRelationshipCounterRequest) (*v1.ExperimentalRegisterRelationshipCounterResponse, error) { + ds := datastoremw.MustFromContext(ctx) + + if req.Name == "" { + return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED)) + } + + _, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, rwt); err != nil { + return err + } + + coreFilter := datastore.CoreFilterFromRelationshipFilter(req.RelationshipFilter) + return rwt.RegisterCounter(ctx, req.Name, coreFilter) + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return &v1.ExperimentalRegisterRelationshipCounterResponse{}, nil +} + +func (es *experimentalServer) ExperimentalUnregisterRelationshipCounter(ctx context.Context, req *v1.ExperimentalUnregisterRelationshipCounterRequest) (*v1.ExperimentalUnregisterRelationshipCounterResponse, error) { + ds := datastoremw.MustFromContext(ctx) + + if req.Name == "" { + return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED)) + } + + _, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.UnregisterCounter(ctx, req.Name) + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return &v1.ExperimentalUnregisterRelationshipCounterResponse{}, nil +} + +func (es *experimentalServer) ExperimentalCountRelationships(ctx context.Context, req *v1.ExperimentalCountRelationshipsRequest) (*v1.ExperimentalCountRelationshipsResponse, error) { + if req.Name == "" { + return nil, shared.RewriteErrorWithoutConfig(ctx, spiceerrors.WithCodeAndReason(errors.New("name must be provided"), codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_UNSPECIFIED)) + } + + ds := datastoremw.MustFromContext(ctx) + headRev, err := ds.HeadRevision(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + snapshotReader := ds.SnapshotReader(headRev) + count, err := snapshotReader.CountRelationships(ctx, req.Name) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + uintCount, err := safecast.ToUint64(count) + if err != nil { + return nil, spiceerrors.MustBugf("count should not be negative") + } + + return &v1.ExperimentalCountRelationshipsResponse{ + CounterResult: &v1.ExperimentalCountRelationshipsResponse_ReadCounterValue{ + ReadCounterValue: &v1.ReadCounterValue{ + RelationshipCount: uintCount, + ReadAt: zedtoken.MustNewFromRevision(headRev), + }, + }, + }, nil +} + +func queryForEach( + ctx context.Context, + reader datastore.Reader, + filter datastore.RelationshipsFilter, + fn func(rel tuple.Relationship), + opts ...dsoptions.QueryOptionsOption, +) (dsoptions.Cursor, error) { + iter, err := reader.QueryRelationships(ctx, filter, opts...) + if err != nil { + return nil, err + } + + var cursor dsoptions.Cursor + for rel, err := range iter { + if err != nil { + return nil, err + } + + fn(rel) + cursor = dsoptions.ToCursor(rel) + } + return cursor, nil +} + +func decodeCursor(ds datastore.ReadOnlyDatastore, encoded *v1.Cursor) (datastore.Revision, string, dsoptions.Cursor, error) { + decoded, err := cursor.Decode(encoded) + if err != nil { + return datastore.NoRevision, "", nil, err + } + + if decoded.GetV1() == nil { + return datastore.NoRevision, "", nil, errors.New("malformed cursor: no V1 in OneOf") + } + + if len(decoded.GetV1().Sections) != 2 { + return datastore.NoRevision, "", nil, errors.New("malformed cursor: wrong number of components") + } + + atRevision, err := ds.RevisionFromString(decoded.GetV1().Revision) + if err != nil { + return datastore.NoRevision, "", nil, err + } + + cur, err := tuple.Parse(decoded.GetV1().GetSections()[1]) + if err != nil { + return datastore.NoRevision, "", nil, fmt.Errorf("malformed cursor: invalid encoded relation tuple: %w", err) + } + + // Returns the current namespace and the cursor. + return atRevision, decoded.GetV1().GetSections()[0], dsoptions.ToCursor(cur), nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go b/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go new file mode 100644 index 0000000..8ef6c25 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/expreflection.go @@ -0,0 +1,720 @@ +package v1 + +import ( + "sort" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/diff" + caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats" + nsdiff "github.com/authzed/spicedb/pkg/diff/namespace" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +type expSchemaFilters struct { + filters []*v1.ExpSchemaFilter +} + +func newexpSchemaFilters(filters []*v1.ExpSchemaFilter) (*expSchemaFilters, error) { + for _, filter := range filters { + if filter.OptionalDefinitionNameFilter != "" { + if filter.OptionalCaveatNameFilter != "" { + return nil, NewInvalidFilterErr("cannot filter by both definition and caveat name", filter.String()) + } + } + + if filter.OptionalRelationNameFilter != "" { + if filter.OptionalDefinitionNameFilter == "" { + return nil, NewInvalidFilterErr("relation name match requires definition name match", filter.String()) + } + + if filter.OptionalPermissionNameFilter != "" { + return nil, NewInvalidFilterErr("cannot filter by both relation and permission name", filter.String()) + } + } + + if filter.OptionalPermissionNameFilter != "" { + if filter.OptionalDefinitionNameFilter == "" { + return nil, NewInvalidFilterErr("permission name match requires definition name match", filter.String()) + } + } + } + + return &expSchemaFilters{filters: filters}, nil +} + +func (sf *expSchemaFilters) HasNamespaces() bool { + if len(sf.filters) == 0 { + return true + } + + for _, filter := range sf.filters { + if filter.OptionalDefinitionNameFilter != "" { + return true + } + } + + return false +} + +func (sf *expSchemaFilters) HasCaveats() bool { + if len(sf.filters) == 0 { + return true + } + + for _, filter := range sf.filters { + if filter.OptionalCaveatNameFilter != "" { + return true + } + } + + return false +} + +func (sf *expSchemaFilters) HasNamespace(namespaceName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasDefinitionFilter := false + for _, filter := range sf.filters { + if filter.OptionalDefinitionNameFilter == "" { + continue + } + + hasDefinitionFilter = true + isMatch := strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasDefinitionFilter +} + +func (sf *expSchemaFilters) HasCaveat(caveatName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasCaveatFilter := false + for _, filter := range sf.filters { + if filter.OptionalCaveatNameFilter == "" { + continue + } + + hasCaveatFilter = true + isMatch := strings.HasPrefix(caveatName, filter.OptionalCaveatNameFilter) + if isMatch { + return true + } + } + + return !hasCaveatFilter +} + +func (sf *expSchemaFilters) HasRelation(namespaceName, relationName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasRelationFilter := false + for _, filter := range sf.filters { + if filter.OptionalRelationNameFilter == "" { + continue + } + + hasRelationFilter = true + isMatch := strings.HasPrefix(relationName, filter.OptionalRelationNameFilter) + if !isMatch { + continue + } + + isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasRelationFilter +} + +func (sf *expSchemaFilters) HasPermission(namespaceName, permissionName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasPermissionFilter := false + for _, filter := range sf.filters { + if filter.OptionalPermissionNameFilter == "" { + continue + } + + hasPermissionFilter = true + isMatch := strings.HasPrefix(permissionName, filter.OptionalPermissionNameFilter) + if !isMatch { + continue + } + + isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasPermissionFilter +} + +// expConvertDiff converts a schema diff into an API response. +func expConvertDiff( + diff *diff.SchemaDiff, + existingSchema *diff.DiffableSchema, + comparisonSchema *diff.DiffableSchema, + atRevision datastore.Revision, + caveatTypeSet *caveattypes.TypeSet, +) (*v1.ExperimentalDiffSchemaResponse, error) { + size := len(diff.AddedNamespaces) + len(diff.RemovedNamespaces) + len(diff.AddedCaveats) + len(diff.RemovedCaveats) + len(diff.ChangedNamespaces) + len(diff.ChangedCaveats) + diffs := make([]*v1.ExpSchemaDiff, 0, size) + + // Add/remove namespaces. + for _, ns := range diff.AddedNamespaces { + nsDef, err := expNamespaceAPIReprForName(ns, comparisonSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_DefinitionAdded{ + DefinitionAdded: nsDef, + }, + }) + } + + for _, ns := range diff.RemovedNamespaces { + nsDef, err := expNamespaceAPIReprForName(ns, existingSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_DefinitionRemoved{ + DefinitionRemoved: nsDef, + }, + }) + } + + // Add/remove caveats. + for _, caveat := range diff.AddedCaveats { + caveatDef, err := expCaveatAPIReprForName(caveat, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatAdded{ + CaveatAdded: caveatDef, + }, + }) + } + + for _, caveat := range diff.RemovedCaveats { + caveatDef, err := expCaveatAPIReprForName(caveat, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatRemoved{ + CaveatRemoved: caveatDef, + }, + }) + } + + // Changed namespaces. + for nsName, nsDiff := range diff.ChangedNamespaces { + for _, delta := range nsDiff.Deltas() { + switch delta.Type { + case nsdiff.AddedPermission: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := expPermissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_PermissionAdded{ + PermissionAdded: perm, + }, + }) + + case nsdiff.AddedRelation: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationAdded{ + RelationAdded: rel, + }, + }) + + case nsdiff.ChangedPermissionComment: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := expPermissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_PermissionDocCommentChanged{ + PermissionDocCommentChanged: perm, + }, + }) + + case nsdiff.ChangedPermissionImpl: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := expPermissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_PermissionExprChanged{ + PermissionExprChanged: perm, + }, + }) + + case nsdiff.ChangedRelationComment: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationDocCommentChanged{ + RelationDocCommentChanged: rel, + }, + }) + + case nsdiff.LegacyChangedRelationImpl: + return nil, spiceerrors.MustBugf("legacy relation implementation changes are not supported") + + case nsdiff.NamespaceCommentsChanged: + def, err := expNamespaceAPIReprForName(nsName, comparisonSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_DefinitionDocCommentChanged{ + DefinitionDocCommentChanged: def, + }, + }) + + case nsdiff.RelationAllowedTypeRemoved: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationSubjectTypeRemoved{ + RelationSubjectTypeRemoved: &v1.ExpRelationSubjectTypeChange{ + Relation: rel, + ChangedSubjectType: expTypeAPIRepr(delta.AllowedType), + }, + }, + }) + + case nsdiff.RelationAllowedTypeAdded: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationSubjectTypeAdded{ + RelationSubjectTypeAdded: &v1.ExpRelationSubjectTypeChange{ + Relation: rel, + ChangedSubjectType: expTypeAPIRepr(delta.AllowedType), + }, + }, + }) + + case nsdiff.RemovedPermission: + permission, ok := existingSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := expPermissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_PermissionRemoved{ + PermissionRemoved: perm, + }, + }) + + case nsdiff.RemovedRelation: + relation, ok := existingSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := expRelationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_RelationRemoved{ + RelationRemoved: rel, + }, + }) + + case nsdiff.NamespaceAdded: + return nil, spiceerrors.MustBugf("should be handled above") + + case nsdiff.NamespaceRemoved: + return nil, spiceerrors.MustBugf("should be handled above") + + default: + return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type) + } + } + } + + // Changed caveats. + for caveatName, caveatDiff := range diff.ChangedCaveats { + for _, delta := range caveatDiff.Deltas() { + switch delta.Type { + case caveatdiff.CaveatCommentsChanged: + caveat, err := expCaveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatDocCommentChanged{ + CaveatDocCommentChanged: caveat, + }, + }) + + case caveatdiff.AddedParameter: + paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatParameterAdded{ + CaveatParameterAdded: paramDef, + }, + }) + + case caveatdiff.RemovedParameter: + paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatParameterRemoved{ + CaveatParameterRemoved: paramDef, + }, + }) + + case caveatdiff.ParameterTypeChanged: + previousParamDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + paramDef, err := expCaveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatParameterTypeChanged{ + CaveatParameterTypeChanged: &v1.ExpCaveatParameterTypeChange{ + Parameter: paramDef, + PreviousType: previousParamDef.Type, + }, + }, + }) + + case caveatdiff.CaveatExpressionChanged: + caveat, err := expCaveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ExpSchemaDiff{ + Diff: &v1.ExpSchemaDiff_CaveatExprChanged{ + CaveatExprChanged: caveat, + }, + }) + + case caveatdiff.CaveatAdded: + return nil, spiceerrors.MustBugf("should be handled above") + + case caveatdiff.CaveatRemoved: + return nil, spiceerrors.MustBugf("should be handled above") + + default: + return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type) + } + } + } + + return &v1.ExperimentalDiffSchemaResponse{ + Diffs: diffs, + ReadAt: zedtoken.MustNewFromRevision(atRevision), + }, nil +} + +// expNamespaceAPIReprForName builds an API representation of a namespace. +func expNamespaceAPIReprForName(namespaceName string, schema *diff.DiffableSchema) (*v1.ExpDefinition, error) { + nsDef, ok := schema.GetNamespace(namespaceName) + if !ok { + return nil, spiceerrors.MustBugf("namespace %q not found in schema", namespaceName) + } + + return expNamespaceAPIRepr(nsDef, nil) +} + +func expNamespaceAPIRepr(nsDef *core.NamespaceDefinition, expSchemaFilters *expSchemaFilters) (*v1.ExpDefinition, error) { + if expSchemaFilters != nil && !expSchemaFilters.HasNamespace(nsDef.Name) { + return nil, nil + } + + relations := make([]*v1.ExpRelation, 0, len(nsDef.Relation)) + permissions := make([]*v1.ExpPermission, 0, len(nsDef.Relation)) + + for _, rel := range nsDef.Relation { + if namespace.GetRelationKind(rel) == iv1.RelationMetadata_PERMISSION { + permission, err := expPermissionAPIRepr(rel, nsDef.Name, expSchemaFilters) + if err != nil { + return nil, err + } + + if permission != nil { + permissions = append(permissions, permission) + } + continue + } + + relation, err := expRelationAPIRepr(rel, nsDef.Name, expSchemaFilters) + if err != nil { + return nil, err + } + + if relation != nil { + relations = append(relations, relation) + } + } + + comments := namespace.GetComments(nsDef.Metadata) + return &v1.ExpDefinition{ + Name: nsDef.Name, + Comment: strings.Join(comments, "\n"), + Relations: relations, + Permissions: permissions, + }, nil +} + +// expPermissionAPIRepr builds an API representation of a permission. +func expPermissionAPIRepr(relation *core.Relation, parentDefName string, expSchemaFilters *expSchemaFilters) (*v1.ExpPermission, error) { + if expSchemaFilters != nil && !expSchemaFilters.HasPermission(parentDefName, relation.Name) { + return nil, nil + } + + comments := namespace.GetComments(relation.Metadata) + return &v1.ExpPermission{ + Name: relation.Name, + Comment: strings.Join(comments, "\n"), + ParentDefinitionName: parentDefName, + }, nil +} + +// expRelationAPIRepr builds an API representation of a relation. +func expRelationAPIRepr(relation *core.Relation, parentDefName string, expSchemaFilters *expSchemaFilters) (*v1.ExpRelation, error) { + if expSchemaFilters != nil && !expSchemaFilters.HasRelation(parentDefName, relation.Name) { + return nil, nil + } + + comments := namespace.GetComments(relation.Metadata) + + var subjectTypes []*v1.ExpTypeReference + if relation.TypeInformation != nil { + subjectTypes = make([]*v1.ExpTypeReference, 0, len(relation.TypeInformation.AllowedDirectRelations)) + for _, subjectType := range relation.TypeInformation.AllowedDirectRelations { + typeref := expTypeAPIRepr(subjectType) + subjectTypes = append(subjectTypes, typeref) + } + } + + return &v1.ExpRelation{ + Name: relation.Name, + Comment: strings.Join(comments, "\n"), + ParentDefinitionName: parentDefName, + SubjectTypes: subjectTypes, + }, nil +} + +// expTypeAPIRepr builds an API representation of a type. +func expTypeAPIRepr(subjectType *core.AllowedRelation) *v1.ExpTypeReference { + typeref := &v1.ExpTypeReference{ + SubjectDefinitionName: subjectType.Namespace, + Typeref: &v1.ExpTypeReference_IsTerminalSubject{}, + } + + if subjectType.GetRelation() != tuple.Ellipsis && subjectType.GetRelation() != "" { + typeref.Typeref = &v1.ExpTypeReference_OptionalRelationName{ + OptionalRelationName: subjectType.GetRelation(), + } + } else if subjectType.GetPublicWildcard() != nil { + typeref.Typeref = &v1.ExpTypeReference_IsPublicWildcard{ + IsPublicWildcard: true, + } + } + + if subjectType.GetRequiredCaveat() != nil { + typeref.OptionalCaveatName = subjectType.GetRequiredCaveat().CaveatName + } + + return typeref +} + +// expCaveatAPIReprForName builds an API representation of a caveat. +func expCaveatAPIReprForName(caveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveat, error) { + caveatDef, ok := schema.GetCaveat(caveatName) + if !ok { + return nil, spiceerrors.MustBugf("caveat %q not found in schema", caveatName) + } + + return expCaveatAPIRepr(caveatDef, nil, caveatTypeSet) +} + +// expCaveatAPIRepr builds an API representation of a caveat. +func expCaveatAPIRepr(caveatDef *core.CaveatDefinition, expSchemaFilters *expSchemaFilters, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveat, error) { + if expSchemaFilters != nil && !expSchemaFilters.HasCaveat(caveatDef.Name) { + return nil, nil + } + + parameters := make([]*v1.ExpCaveatParameter, 0, len(caveatDef.ParameterTypes)) + paramNames := maps.Keys(caveatDef.ParameterTypes) + sort.Strings(paramNames) + + for _, paramName := range paramNames { + paramType, ok := caveatDef.ParameterTypes[paramName] + if !ok { + return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, caveatDef.Name) + } + + decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType) + if err != nil { + return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err) + } + + parameters = append(parameters, &v1.ExpCaveatParameter{ + Name: paramName, + Type: decoded.String(), + ParentCaveatName: caveatDef.Name, + }) + } + + parameterTypes, err := caveattypes.DecodeParameterTypes(caveatTypeSet, caveatDef.ParameterTypes) + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat parameters: %v", err) + } + + deserializedExpression, err := caveats.DeserializeCaveatWithTypeSet(caveatTypeSet, caveatDef.SerializedExpression, parameterTypes) + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat expression bytes: %v", err) + } + + exprString, err := deserializedExpression.ExprString() + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat expression: %v", err) + } + + comments := namespace.GetComments(caveatDef.Metadata) + return &v1.ExpCaveat{ + Name: caveatDef.Name, + Comment: strings.Join(comments, "\n"), + Parameters: parameters, + Expression: exprString, + }, nil +} + +// expCaveatAPIParamRepr builds an API representation of a caveat parameter. +func expCaveatAPIParamRepr(paramName, parentCaveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ExpCaveatParameter, error) { + caveatDef, ok := schema.GetCaveat(parentCaveatName) + if !ok { + return nil, spiceerrors.MustBugf("caveat %q not found in schema", parentCaveatName) + } + + paramType, ok := caveatDef.ParameterTypes[paramName] + if !ok { + return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, parentCaveatName) + } + + decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType) + if err != nil { + return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err) + } + + return &v1.ExpCaveatParameter{ + Name: paramName, + Type: decoded.String(), + ParentCaveatName: parentCaveatName, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go b/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go new file mode 100644 index 0000000..99b681d --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/grouping.go @@ -0,0 +1,72 @@ +package v1 + +import ( + "context" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/graph/computed" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/tuple" +) + +type groupedCheckParameters struct { + params *computed.CheckParameters + resourceIDs []string +} + +type groupingParameters struct { + atRevision datastore.Revision + maximumAPIDepth uint32 + maxCaveatContextSize int + withTracing bool +} + +// groupItems takes a slice of CheckBulkPermissionsRequestItem and groups them based +// on using the same permission, subject type, subject id, and caveat. +func groupItems(ctx context.Context, params groupingParameters, items []*v1.CheckBulkPermissionsRequestItem) (map[string]*groupedCheckParameters, error) { + res := make(map[string]*groupedCheckParameters) + + for _, item := range items { + hash, err := computeCheckBulkPermissionsItemHashWithoutResourceID(item) + if err != nil { + return nil, err + } + + if _, ok := res[hash]; !ok { + caveatContext, err := GetCaveatContext(ctx, item.Context, params.maxCaveatContextSize) + if err != nil { + return nil, err + } + + res[hash] = &groupedCheckParameters{ + params: checkParametersFromCheckBulkPermissionsRequestItem(item, params, caveatContext), + resourceIDs: []string{item.Resource.ObjectId}, + } + } else { + res[hash].resourceIDs = append(res[hash].resourceIDs, item.Resource.ObjectId) + } + } + + return res, nil +} + +func checkParametersFromCheckBulkPermissionsRequestItem( + bc *v1.CheckBulkPermissionsRequestItem, + params groupingParameters, + caveatContext map[string]any, +) *computed.CheckParameters { + debugOption := computed.NoDebugging + if params.withTracing { + debugOption = computed.BasicDebuggingEnabled + } + + return &computed.CheckParameters{ + ResourceType: tuple.RR(bc.Resource.ObjectType, bc.Permission), + Subject: tuple.ONR(bc.Subject.Object.ObjectType, bc.Subject.Object.ObjectId, normalizeSubjectRelation(bc.Subject)), + CaveatContext: caveatContext, + AtRevision: params.atRevision, + MaximumDepth: params.maximumAPIDepth, + DebugOption: debugOption, + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go new file mode 100644 index 0000000..1754669 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash.go @@ -0,0 +1,110 @@ +package v1 + +import ( + "strconv" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/spicedb/pkg/caveats" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +func computeCheckBulkPermissionsItemHashWithoutResourceID(req *v1.CheckBulkPermissionsRequestItem) (string, error) { + return computeCallHash("v1.checkbulkpermissionsrequestitem", nil, map[string]any{ + "resource-type": req.Resource.ObjectType, + "permission": req.Permission, + "subject-type": req.Subject.Object.ObjectType, + "subject-id": req.Subject.Object.ObjectId, + "subject-relation": req.Subject.OptionalRelation, + "context": req.Context, + }) +} + +func computeCheckBulkPermissionsItemHash(req *v1.CheckBulkPermissionsRequestItem) (string, error) { + return computeCallHash("v1.checkbulkpermissionsrequestitem", nil, map[string]any{ + "resource-type": req.Resource.ObjectType, + "resource-id": req.Resource.ObjectId, + "permission": req.Permission, + "subject-type": req.Subject.Object.ObjectType, + "subject-id": req.Subject.Object.ObjectId, + "subject-relation": req.Subject.OptionalRelation, + "context": req.Context, + }) +} + +func computeReadRelationshipsRequestHash(req *v1.ReadRelationshipsRequest) (string, error) { + osf := req.RelationshipFilter.OptionalSubjectFilter + if osf == nil { + osf = &v1.SubjectFilter{} + } + + srf := "(none)" + if osf.OptionalRelation != nil { + srf = osf.OptionalRelation.Relation + } + + return computeCallHash("v1.readrelationships", req.Consistency, map[string]any{ + "filter-resource-type": req.RelationshipFilter.ResourceType, + "filter-relation": req.RelationshipFilter.OptionalRelation, + "filter-resource-id": req.RelationshipFilter.OptionalResourceId, + "subject-type": osf.SubjectType, + "subject-relation": srf, + "subject-resource-id": osf.OptionalSubjectId, + "limit": req.OptionalLimit, + }) +} + +func computeLRRequestHash(req *v1.LookupResourcesRequest) (string, error) { + return computeCallHash("v1.lookupresources", req.Consistency, map[string]any{ + "resource-type": req.ResourceObjectType, + "permission": req.Permission, + "subject": tuple.V1StringSubjectRef(req.Subject), + "limit": req.OptionalLimit, + "context": req.Context, + }) +} + +func computeCallHash(apiName string, consistency *v1.Consistency, arguments map[string]any) (string, error) { + stringArguments := make(map[string]string, len(arguments)+1) + + if consistency == nil { + consistency = &v1.Consistency{ + Requirement: &v1.Consistency_MinimizeLatency{ + MinimizeLatency: true, + }, + } + } + + consistencyBytes, err := consistency.MarshalVT() + if err != nil { + return "", err + } + + stringArguments["consistency"] = string(consistencyBytes) + + for argName, argValue := range arguments { + if argName == "consistency" { + return "", spiceerrors.MustBugf("cannot specify consistency in the arguments") + } + + switch v := argValue.(type) { + case string: + stringArguments[argName] = v + + case int: + stringArguments[argName] = strconv.Itoa(v) + + case uint32: + stringArguments[argName] = strconv.Itoa(int(v)) + + case *structpb.Struct: + stringArguments[argName] = caveats.StableContextStringForHashing(v) + + default: + return "", spiceerrors.MustBugf("unknown argument type in compute call hash") + } + } + return computeAPICallHash(apiName, stringArguments) +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go new file mode 100644 index 0000000..fad4a40 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_nonwasm.go @@ -0,0 +1,52 @@ +//go:build !wasm +// +build !wasm + +package v1 + +import ( + "fmt" + "sort" + + "github.com/cespare/xxhash/v2" + "golang.org/x/exp/maps" +) + +func computeAPICallHash(apiName string, arguments map[string]string) (string, error) { + hasher := xxhash.New() + _, err := hasher.WriteString(apiName) + if err != nil { + return "", err + } + + _, err = hasher.WriteString(":") + if err != nil { + return "", err + } + + keys := maps.Keys(arguments) + sort.Strings(keys) + + for _, key := range keys { + _, err = hasher.WriteString(key) + if err != nil { + return "", err + } + + _, err = hasher.WriteString(":") + if err != nil { + return "", err + } + + _, err = hasher.WriteString(arguments[key]) + if err != nil { + return "", err + } + + _, err = hasher.WriteString(";") + if err != nil { + return "", err + } + } + + return fmt.Sprintf("%x", hasher.Sum(nil)), nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go new file mode 100644 index 0000000..4c75aa0 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/hash_wasm.go @@ -0,0 +1,50 @@ +package v1 + +import ( + "crypto/sha256" + "fmt" + "sort" + + "golang.org/x/exp/maps" +) + +func computeAPICallHash(apiName string, arguments map[string]string) (string, error) { + h := sha256.New() + + _, err := h.Write([]byte(apiName)) + if err != nil { + return "", err + } + + _, err = h.Write([]byte(":")) + if err != nil { + return "", err + } + + keys := maps.Keys(arguments) + sort.Strings(keys) + + for _, key := range keys { + _, err = h.Write([]byte(key)) + if err != nil { + return "", err + } + + _, err = h.Write([]byte(":")) + if err != nil { + return "", err + } + + _, err = h.Write([]byte(arguments[key])) + if err != nil { + return "", err + } + + _, err = h.Write([]byte(";")) + if err != nil { + return "", err + } + } + + return fmt.Sprintf("%x", h.Sum(nil)), nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go b/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go new file mode 100644 index 0000000..d309c3b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/options/options.go @@ -0,0 +1,12 @@ +package options + +import "time" + +//go:generate go run github.com/ecordell/optgen -output zz_generated.query_options.go . ExperimentalServerOptions + +type ExperimentalServerOptions struct { + StreamReadTimeout time.Duration `debugmap:"visible" default:"600s"` + DefaultExportBatchSize uint32 `debugmap:"visible" default:"1_000"` + MaxExportBatchSize uint32 `debugmap:"visible" default:"100_000"` + BulkCheckMaxConcurrency uint16 `debugmap:"visible" default:"50"` +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go b/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go new file mode 100644 index 0000000..5b75b5f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/options/zz_generated.query_options.go @@ -0,0 +1,93 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package options + +import ( + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" + "time" +) + +type ExperimentalServerOptionsOption func(e *ExperimentalServerOptions) + +// NewExperimentalServerOptionsWithOptions creates a new ExperimentalServerOptions with the passed in options set +func NewExperimentalServerOptionsWithOptions(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions { + e := &ExperimentalServerOptions{} + for _, o := range opts { + o(e) + } + return e +} + +// NewExperimentalServerOptionsWithOptionsAndDefaults creates a new ExperimentalServerOptions with the passed in options set starting from the defaults +func NewExperimentalServerOptionsWithOptionsAndDefaults(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions { + e := &ExperimentalServerOptions{} + defaults.MustSet(e) + for _, o := range opts { + o(e) + } + return e +} + +// ToOption returns a new ExperimentalServerOptionsOption that sets the values from the passed in ExperimentalServerOptions +func (e *ExperimentalServerOptions) ToOption() ExperimentalServerOptionsOption { + return func(to *ExperimentalServerOptions) { + to.StreamReadTimeout = e.StreamReadTimeout + to.DefaultExportBatchSize = e.DefaultExportBatchSize + to.MaxExportBatchSize = e.MaxExportBatchSize + to.BulkCheckMaxConcurrency = e.BulkCheckMaxConcurrency + } +} + +// DebugMap returns a map form of ExperimentalServerOptions for debugging +func (e ExperimentalServerOptions) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["StreamReadTimeout"] = helpers.DebugValue(e.StreamReadTimeout, false) + debugMap["DefaultExportBatchSize"] = helpers.DebugValue(e.DefaultExportBatchSize, false) + debugMap["MaxExportBatchSize"] = helpers.DebugValue(e.MaxExportBatchSize, false) + debugMap["BulkCheckMaxConcurrency"] = helpers.DebugValue(e.BulkCheckMaxConcurrency, false) + return debugMap +} + +// ExperimentalServerOptionsWithOptions configures an existing ExperimentalServerOptions with the passed in options set +func ExperimentalServerOptionsWithOptions(e *ExperimentalServerOptions, opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions { + for _, o := range opts { + o(e) + } + return e +} + +// WithOptions configures the receiver ExperimentalServerOptions with the passed in options set +func (e *ExperimentalServerOptions) WithOptions(opts ...ExperimentalServerOptionsOption) *ExperimentalServerOptions { + for _, o := range opts { + o(e) + } + return e +} + +// WithStreamReadTimeout returns an option that can set StreamReadTimeout on a ExperimentalServerOptions +func WithStreamReadTimeout(streamReadTimeout time.Duration) ExperimentalServerOptionsOption { + return func(e *ExperimentalServerOptions) { + e.StreamReadTimeout = streamReadTimeout + } +} + +// WithDefaultExportBatchSize returns an option that can set DefaultExportBatchSize on a ExperimentalServerOptions +func WithDefaultExportBatchSize(defaultExportBatchSize uint32) ExperimentalServerOptionsOption { + return func(e *ExperimentalServerOptions) { + e.DefaultExportBatchSize = defaultExportBatchSize + } +} + +// WithMaxExportBatchSize returns an option that can set MaxExportBatchSize on a ExperimentalServerOptions +func WithMaxExportBatchSize(maxExportBatchSize uint32) ExperimentalServerOptionsOption { + return func(e *ExperimentalServerOptions) { + e.MaxExportBatchSize = maxExportBatchSize + } +} + +// WithBulkCheckMaxConcurrency returns an option that can set BulkCheckMaxConcurrency on a ExperimentalServerOptions +func WithBulkCheckMaxConcurrency(bulkCheckMaxConcurrency uint16) ExperimentalServerOptionsOption { + return func(e *ExperimentalServerOptions) { + e.BulkCheckMaxConcurrency = bulkCheckMaxConcurrency + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go b/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go new file mode 100644 index 0000000..da6dd18 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/permissions.go @@ -0,0 +1,1094 @@ +package v1 + +import ( + "context" + "errors" + "fmt" + "io" + "slices" + "strings" + + "github.com/authzed/authzed-go/pkg/requestmeta" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/jzelinskie/stringz" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + + cexpr "github.com/authzed/spicedb/internal/caveats" + dispatchpkg "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph" + "github.com/authzed/spicedb/internal/graph/computed" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/relationships" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/internal/telemetry" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + dsoptions "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +func (ps *permissionServer) rewriteError(ctx context.Context, err error) error { + return shared.RewriteError(ctx, err, &shared.ConfigForErrors{ + MaximumAPIDepth: ps.config.MaximumAPIDepth, + }) +} + +func (ps *permissionServer) rewriteErrorWithOptionalDebugTrace(ctx context.Context, err error, debugTrace *v1.DebugInformation) error { + return shared.RewriteError(ctx, err, &shared.ConfigForErrors{ + MaximumAPIDepth: ps.config.MaximumAPIDepth, + DebugTrace: debugTrace, + }) +} + +func (ps *permissionServer) CheckPermission(ctx context.Context, req *v1.CheckPermissionRequest) (*v1.CheckPermissionResponse, error) { + telemetry.RecordLogicalChecks(1) + + atRevision, checkedAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + if err := namespace.CheckNamespaceAndRelations(ctx, + []namespace.TypeAndRelationToCheck{ + { + NamespaceName: req.Resource.ObjectType, + RelationName: req.Permission, + AllowEllipsis: false, + }, + { + NamespaceName: req.Subject.Object.ObjectType, + RelationName: normalizeSubjectRelation(req.Subject), + AllowEllipsis: true, + }, + }, ds); err != nil { + return nil, ps.rewriteError(ctx, err) + } + + debugOption := computed.NoDebugging + + if md, ok := metadata.FromIncomingContext(ctx); ok { + _, isDebuggingEnabled := md[string(requestmeta.RequestDebugInformation)] + if isDebuggingEnabled { + debugOption = computed.BasicDebuggingEnabled + } + } + + if req.WithTracing { + debugOption = computed.BasicDebuggingEnabled + } + + cr, metadata, err := computed.ComputeCheck(ctx, ps.dispatch, + ps.config.CaveatTypeSet, + computed.CheckParameters{ + ResourceType: tuple.RR(req.Resource.ObjectType, req.Permission), + Subject: tuple.ONR(req.Subject.Object.ObjectType, req.Subject.Object.ObjectId, normalizeSubjectRelation(req.Subject)), + CaveatContext: caveatContext, + AtRevision: atRevision, + MaximumDepth: ps.config.MaximumAPIDepth, + DebugOption: debugOption, + }, + req.Resource.ObjectId, + ps.config.DispatchChunkSize, + ) + usagemetrics.SetInContext(ctx, metadata) + + var debugTrace *v1.DebugInformation + if debugOption != computed.NoDebugging && metadata.DebugInfo != nil { + // Convert the dispatch debug information into API debug information. + converted, cerr := ConvertCheckDispatchDebugInformation(ctx, ps.config.CaveatTypeSet, caveatContext, metadata.DebugInfo, ds) + if cerr != nil { + return nil, ps.rewriteError(ctx, cerr) + } + debugTrace = converted + } + + if err != nil { + // If the error already contains debug information, rewrite it. This can happen if + // a dispatch error occurs and debug was requested. + if dispatchDebugInfo, ok := spiceerrors.GetDetails[*dispatch.DebugInformation](err); ok { + // Convert the dispatch debug information into API debug information. + converted, cerr := ConvertCheckDispatchDebugInformation(ctx, ps.config.CaveatTypeSet, caveatContext, dispatchDebugInfo, ds) + if cerr != nil { + return nil, ps.rewriteError(ctx, cerr) + } + + if converted != nil { + return nil, spiceerrors.AppendDetailsMetadata(err, spiceerrors.DebugTraceErrorDetailsKey, converted.String()) + } + } + + return nil, ps.rewriteErrorWithOptionalDebugTrace(ctx, err, debugTrace) + } + + permissionship, partialCaveat := checkResultToAPITypes(cr) + + return &v1.CheckPermissionResponse{ + CheckedAt: checkedAt, + Permissionship: permissionship, + PartialCaveatInfo: partialCaveat, + DebugTrace: debugTrace, + }, nil +} + +func checkResultToAPITypes(cr *dispatch.ResourceCheckResult) (v1.CheckPermissionResponse_Permissionship, *v1.PartialCaveatInfo) { + var partialCaveat *v1.PartialCaveatInfo + permissionship := v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION + if cr.Membership == dispatch.ResourceCheckResult_MEMBER { + permissionship = v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION + } else if cr.Membership == dispatch.ResourceCheckResult_CAVEATED_MEMBER { + permissionship = v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION + partialCaveat = &v1.PartialCaveatInfo{ + MissingRequiredContext: cr.MissingExprFields, + } + } + return permissionship, partialCaveat +} + +func (ps *permissionServer) CheckBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) { + res, err := ps.bulkChecker.checkBulkPermissions(ctx, req) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + return res, nil +} + +func pairItemFromCheckResult(checkResult *dispatch.ResourceCheckResult, debugTrace *v1.DebugInformation) *v1.CheckBulkPermissionsPair_Item { + permissionship, partialCaveat := checkResultToAPITypes(checkResult) + return &v1.CheckBulkPermissionsPair_Item{ + Item: &v1.CheckBulkPermissionsResponseItem{ + Permissionship: permissionship, + PartialCaveatInfo: partialCaveat, + DebugTrace: debugTrace, + }, + } +} + +func requestItemFromResourceAndParameters(params *computed.CheckParameters, resourceID string) (*v1.CheckBulkPermissionsRequestItem, error) { + item := &v1.CheckBulkPermissionsRequestItem{ + Resource: &v1.ObjectReference{ + ObjectType: params.ResourceType.ObjectType, + ObjectId: resourceID, + }, + Permission: params.ResourceType.Relation, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: params.Subject.ObjectType, + ObjectId: params.Subject.ObjectID, + }, + OptionalRelation: denormalizeSubjectRelation(params.Subject.Relation), + }, + } + if len(params.CaveatContext) > 0 { + var err error + item.Context, err = structpb.NewStruct(params.CaveatContext) + if err != nil { + return nil, fmt.Errorf("caveat context wasn't properly validated: %w", err) + } + } + return item, nil +} + +func (ps *permissionServer) ExpandPermissionTree(ctx context.Context, req *v1.ExpandPermissionTreeRequest) (*v1.ExpandPermissionTreeResponse, error) { + telemetry.RecordLogicalChecks(1) + + atRevision, expandedAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + err = namespace.CheckNamespaceAndRelation(ctx, req.Resource.ObjectType, req.Permission, false, ds) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth)) + if err != nil { + return nil, err + } + + resp, err := ps.dispatch.DispatchExpand(ctx, &dispatch.DispatchExpandRequest{ + Metadata: &dispatch.ResolverMeta{ + AtRevision: atRevision.String(), + DepthRemaining: ps.config.MaximumAPIDepth, + TraversalBloom: bf, + }, + ResourceAndRelation: &core.ObjectAndRelation{ + Namespace: req.Resource.ObjectType, + ObjectId: req.Resource.ObjectId, + Relation: req.Permission, + }, + ExpansionMode: dispatch.DispatchExpandRequest_SHALLOW, + }) + usagemetrics.SetInContext(ctx, resp.Metadata) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + // TODO(jschorr): Change to either using shared interfaces for nodes, or switch the internal + // dispatched expand to return V1 node types. + return &v1.ExpandPermissionTreeResponse{ + TreeRoot: TranslateExpansionTree(resp.TreeNode), + ExpandedAt: expandedAt, + }, nil +} + +// TranslateRelationshipTree translates a V1 PermissionRelationshipTree into a RelationTupleTreeNode. +func TranslateRelationshipTree(tree *v1.PermissionRelationshipTree) *core.RelationTupleTreeNode { + var expanded *core.ObjectAndRelation + if tree.ExpandedObject != nil { + expanded = &core.ObjectAndRelation{ + Namespace: tree.ExpandedObject.ObjectType, + ObjectId: tree.ExpandedObject.ObjectId, + Relation: tree.ExpandedRelation, + } + } + + switch t := tree.TreeType.(type) { + case *v1.PermissionRelationshipTree_Intermediate: + var operation core.SetOperationUserset_Operation + switch t.Intermediate.Operation { + case v1.AlgebraicSubjectSet_OPERATION_EXCLUSION: + operation = core.SetOperationUserset_EXCLUSION + case v1.AlgebraicSubjectSet_OPERATION_INTERSECTION: + operation = core.SetOperationUserset_INTERSECTION + case v1.AlgebraicSubjectSet_OPERATION_UNION: + operation = core.SetOperationUserset_UNION + default: + panic("unknown set operation") + } + + children := []*core.RelationTupleTreeNode{} + for _, child := range t.Intermediate.Children { + children = append(children, TranslateRelationshipTree(child)) + } + + return &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_IntermediateNode{ + IntermediateNode: &core.SetOperationUserset{ + Operation: operation, + ChildNodes: children, + }, + }, + Expanded: expanded, + } + + case *v1.PermissionRelationshipTree_Leaf: + var subjects []*core.DirectSubject + for _, subj := range t.Leaf.Subjects { + subjects = append(subjects, &core.DirectSubject{ + Subject: &core.ObjectAndRelation{ + Namespace: subj.Object.ObjectType, + ObjectId: subj.Object.ObjectId, + Relation: stringz.DefaultEmpty(subj.OptionalRelation, graph.Ellipsis), + }, + }) + } + + return &core.RelationTupleTreeNode{ + NodeType: &core.RelationTupleTreeNode_LeafNode{ + LeafNode: &core.DirectSubjects{Subjects: subjects}, + }, + Expanded: expanded, + } + + default: + panic("unknown type of expansion tree node") + } +} + +func TranslateExpansionTree(node *core.RelationTupleTreeNode) *v1.PermissionRelationshipTree { + switch t := node.NodeType.(type) { + case *core.RelationTupleTreeNode_IntermediateNode: + var operation v1.AlgebraicSubjectSet_Operation + switch t.IntermediateNode.Operation { + case core.SetOperationUserset_EXCLUSION: + operation = v1.AlgebraicSubjectSet_OPERATION_EXCLUSION + case core.SetOperationUserset_INTERSECTION: + operation = v1.AlgebraicSubjectSet_OPERATION_INTERSECTION + case core.SetOperationUserset_UNION: + operation = v1.AlgebraicSubjectSet_OPERATION_UNION + default: + panic("unknown set operation") + } + + var children []*v1.PermissionRelationshipTree + for _, child := range node.GetIntermediateNode().ChildNodes { + children = append(children, TranslateExpansionTree(child)) + } + + var objRef *v1.ObjectReference + var objRel string + if node.Expanded != nil { + objRef = &v1.ObjectReference{ + ObjectType: node.Expanded.Namespace, + ObjectId: node.Expanded.ObjectId, + } + objRel = node.Expanded.Relation + } + + return &v1.PermissionRelationshipTree{ + TreeType: &v1.PermissionRelationshipTree_Intermediate{ + Intermediate: &v1.AlgebraicSubjectSet{ + Operation: operation, + Children: children, + }, + }, + ExpandedObject: objRef, + ExpandedRelation: objRel, + } + + case *core.RelationTupleTreeNode_LeafNode: + var subjects []*v1.SubjectReference + for _, found := range t.LeafNode.Subjects { + subjects = append(subjects, &v1.SubjectReference{ + Object: &v1.ObjectReference{ + ObjectType: found.Subject.Namespace, + ObjectId: found.Subject.ObjectId, + }, + OptionalRelation: denormalizeSubjectRelation(found.Subject.Relation), + }) + } + + if node.Expanded == nil { + return &v1.PermissionRelationshipTree{ + TreeType: &v1.PermissionRelationshipTree_Leaf{ + Leaf: &v1.DirectSubjectSet{ + Subjects: subjects, + }, + }, + } + } + + return &v1.PermissionRelationshipTree{ + TreeType: &v1.PermissionRelationshipTree_Leaf{ + Leaf: &v1.DirectSubjectSet{ + Subjects: subjects, + }, + }, + ExpandedObject: &v1.ObjectReference{ + ObjectType: node.Expanded.Namespace, + ObjectId: node.Expanded.ObjectId, + }, + ExpandedRelation: node.Expanded.Relation, + } + + default: + panic("unknown type of expansion tree node") + } +} + +const lrv2CursorFlag = "lrv2" + +func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp v1.PermissionsService_LookupResourcesServer) error { + // NOTE: LRv2 is the only valid option, and we'll expect that all cursors include that flag. + // This is to preserve backward-compatibility in the meantime. + if req.OptionalCursor != nil { + _, _, err := cursor.GetCursorFlag(req.OptionalCursor, lrv2CursorFlag) + if err != nil { + return ps.rewriteError(resp.Context(), err) + } + } + + if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxLookupResourcesLimit { + return ps.rewriteError(resp.Context(), NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxLookupResourcesLimit))) + } + + ctx := resp.Context() + + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + if err := namespace.CheckNamespaceAndRelations(ctx, + []namespace.TypeAndRelationToCheck{ + { + NamespaceName: req.ResourceObjectType, + RelationName: req.Permission, + AllowEllipsis: false, + }, + { + NamespaceName: req.Subject.Object.ObjectType, + RelationName: normalizeSubjectRelation(req.Subject), + AllowEllipsis: true, + }, + }, ds); err != nil { + return ps.rewriteError(ctx, err) + } + + respMetadata := &dispatch.ResponseMeta{ + DispatchCount: 1, + CachedDispatchCount: 0, + DepthRequired: 1, + DebugInfo: nil, + } + usagemetrics.SetInContext(ctx, respMetadata) + + var currentCursor *dispatch.Cursor + + lrRequestHash, err := computeLRRequestHash(req) + if err != nil { + return ps.rewriteError(ctx, err) + } + + if req.OptionalCursor != nil { + decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash) + if err != nil { + return ps.rewriteError(ctx, err) + } + currentCursor = decodedCursor + } + + alreadyPublishedPermissionedResourceIds := map[string]struct{}{} + var totalCountPublished uint64 + defer func() { + telemetry.RecordLogicalChecks(totalCountPublished) + }() + + stream := dispatchpkg.NewHandlingDispatchStream(ctx, func(result *dispatch.DispatchLookupResources2Response) error { + found := result.Resource + + dispatchpkg.AddResponseMetadata(respMetadata, result.Metadata) + currentCursor = result.AfterResponseCursor + + var partial *v1.PartialCaveatInfo + permissionship := v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION + if len(found.MissingContextParams) > 0 { + permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION + partial = &v1.PartialCaveatInfo{ + MissingRequiredContext: found.MissingContextParams, + } + } else if req.OptionalLimit == 0 { + if _, ok := alreadyPublishedPermissionedResourceIds[found.ResourceId]; ok { + // Skip publishing the duplicate. + return nil + } + + // TODO(jschorr): Investigate something like a Trie here for better memory efficiency. + alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{} + } + + encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, map[string]string{ + lrv2CursorFlag: "1", + }) + if err != nil { + return ps.rewriteError(ctx, err) + } + + err = resp.Send(&v1.LookupResourcesResponse{ + LookedUpAt: revisionReadAt, + ResourceObjectId: found.ResourceId, + Permissionship: permissionship, + PartialCaveatInfo: partial, + AfterResultCursor: encodedCursor, + }) + if err != nil { + return err + } + + totalCountPublished++ + return nil + }) + + bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth)) + if err != nil { + return err + } + + err = ps.dispatch.DispatchLookupResources2( + &dispatch.DispatchLookupResources2Request{ + Metadata: &dispatch.ResolverMeta{ + AtRevision: atRevision.String(), + DepthRemaining: ps.config.MaximumAPIDepth, + TraversalBloom: bf, + }, + ResourceRelation: &core.RelationReference{ + Namespace: req.ResourceObjectType, + Relation: req.Permission, + }, + SubjectRelation: &core.RelationReference{ + Namespace: req.Subject.Object.ObjectType, + Relation: normalizeSubjectRelation(req.Subject), + }, + SubjectIds: []string{req.Subject.Object.ObjectId}, + TerminalSubject: &core.ObjectAndRelation{ + Namespace: req.Subject.Object.ObjectType, + ObjectId: req.Subject.Object.ObjectId, + Relation: normalizeSubjectRelation(req.Subject), + }, + Context: req.Context, + OptionalCursor: currentCursor, + OptionalLimit: req.OptionalLimit, + }, + stream) + if err != nil { + return ps.rewriteError(ctx, err) + } + + return nil +} + +func (ps *permissionServer) LookupSubjects(req *v1.LookupSubjectsRequest, resp v1.PermissionsService_LookupSubjectsServer) error { + ctx := resp.Context() + + if req.OptionalConcreteLimit != 0 { + return ps.rewriteError(ctx, status.Errorf(codes.Unimplemented, "concrete limit is not yet supported")) + } + + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + caveatContext, err := GetCaveatContext(ctx, req.Context, ps.config.MaxCaveatContextSize) + if err != nil { + return ps.rewriteError(ctx, err) + } + + if err := namespace.CheckNamespaceAndRelations(ctx, + []namespace.TypeAndRelationToCheck{ + { + NamespaceName: req.Resource.ObjectType, + RelationName: req.Permission, + AllowEllipsis: false, + }, + { + NamespaceName: req.SubjectObjectType, + RelationName: stringz.DefaultEmpty(req.OptionalSubjectRelation, tuple.Ellipsis), + AllowEllipsis: true, + }, + }, ds); err != nil { + return ps.rewriteError(ctx, err) + } + + respMetadata := &dispatch.ResponseMeta{ + DispatchCount: 0, + CachedDispatchCount: 0, + DepthRequired: 0, + DebugInfo: nil, + } + usagemetrics.SetInContext(ctx, respMetadata) + + var totalCountPublished uint64 + defer func() { + telemetry.RecordLogicalChecks(totalCountPublished) + }() + + stream := dispatchpkg.NewHandlingDispatchStream(ctx, func(result *dispatch.DispatchLookupSubjectsResponse) error { + foundSubjects, ok := result.FoundSubjectsByResourceId[req.Resource.ObjectId] + if !ok { + return fmt.Errorf("missing resource ID in returned LS") + } + + for _, foundSubject := range foundSubjects.FoundSubjects { + excludedSubjectIDs := make([]string, 0, len(foundSubject.ExcludedSubjects)) + for _, excludedSubject := range foundSubject.ExcludedSubjects { + excludedSubjectIDs = append(excludedSubjectIDs, excludedSubject.SubjectId) + } + + excludedSubjects := make([]*v1.ResolvedSubject, 0, len(foundSubject.ExcludedSubjects)) + for _, excludedSubject := range foundSubject.ExcludedSubjects { + resolvedExcludedSubject, err := foundSubjectToResolvedSubject(ctx, excludedSubject, caveatContext, ds, ps.config.CaveatTypeSet) + if err != nil { + return err + } + + if resolvedExcludedSubject == nil { + continue + } + + excludedSubjects = append(excludedSubjects, resolvedExcludedSubject) + } + + subject, err := foundSubjectToResolvedSubject(ctx, foundSubject, caveatContext, ds, ps.config.CaveatTypeSet) + if err != nil { + return err + } + if subject == nil { + continue + } + + err = resp.Send(&v1.LookupSubjectsResponse{ + Subject: subject, + ExcludedSubjects: excludedSubjects, + LookedUpAt: revisionReadAt, + SubjectObjectId: foundSubject.SubjectId, // Deprecated + ExcludedSubjectIds: excludedSubjectIDs, // Deprecated + Permissionship: subject.Permissionship, // Deprecated + PartialCaveatInfo: subject.PartialCaveatInfo, // Deprecated + }) + if err != nil { + return err + } + } + + totalCountPublished++ + dispatchpkg.AddResponseMetadata(respMetadata, result.Metadata) + return nil + }) + + bf, err := dispatch.NewTraversalBloomFilter(uint(ps.config.MaximumAPIDepth)) + if err != nil { + return err + } + + err = ps.dispatch.DispatchLookupSubjects( + &dispatch.DispatchLookupSubjectsRequest{ + Metadata: &dispatch.ResolverMeta{ + AtRevision: atRevision.String(), + DepthRemaining: ps.config.MaximumAPIDepth, + TraversalBloom: bf, + }, + ResourceRelation: &core.RelationReference{ + Namespace: req.Resource.ObjectType, + Relation: req.Permission, + }, + ResourceIds: []string{req.Resource.ObjectId}, + SubjectRelation: &core.RelationReference{ + Namespace: req.SubjectObjectType, + Relation: stringz.DefaultEmpty(req.OptionalSubjectRelation, tuple.Ellipsis), + }, + }, + stream) + if err != nil { + return ps.rewriteError(ctx, err) + } + + return nil +} + +func foundSubjectToResolvedSubject(ctx context.Context, foundSubject *dispatch.FoundSubject, caveatContext map[string]any, ds datastore.CaveatReader, caveatTypeSet *caveattypes.TypeSet) (*v1.ResolvedSubject, error) { + var partialCaveat *v1.PartialCaveatInfo + permissionship := v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION + if foundSubject.GetCaveatExpression() != nil { + permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_CONDITIONAL_PERMISSION + + cr, err := cexpr.RunSingleCaveatExpression(ctx, caveatTypeSet, foundSubject.GetCaveatExpression(), caveatContext, ds, cexpr.RunCaveatExpressionNoDebugging) + if err != nil { + return nil, err + } + + if cr.Value() { + permissionship = v1.LookupPermissionship_LOOKUP_PERMISSIONSHIP_HAS_PERMISSION + } else if cr.IsPartial() { + missingFields, _ := cr.MissingVarNames() + partialCaveat = &v1.PartialCaveatInfo{ + MissingRequiredContext: missingFields, + } + } else { + // Skip this found subject. + return nil, nil + } + } + + return &v1.ResolvedSubject{ + SubjectObjectId: foundSubject.SubjectId, + Permissionship: permissionship, + PartialCaveatInfo: partialCaveat, + }, nil +} + +func normalizeSubjectRelation(sub *v1.SubjectReference) string { + if sub.OptionalRelation == "" { + return graph.Ellipsis + } + return sub.OptionalRelation +} + +func denormalizeSubjectRelation(relation string) string { + if relation == graph.Ellipsis { + return "" + } + return relation +} + +func GetCaveatContext(ctx context.Context, caveatCtx *structpb.Struct, maxCaveatContextSize int) (map[string]any, error) { + var caveatContext map[string]any + if caveatCtx != nil { + if size := proto.Size(caveatCtx); maxCaveatContextSize > 0 && size > maxCaveatContextSize { + return nil, shared.RewriteError( + ctx, + status.Errorf( + codes.InvalidArgument, + "request caveat context should have less than %d bytes but had %d", + maxCaveatContextSize, + size, + ), + nil, + ) + } + caveatContext = caveatCtx.AsMap() + } + return caveatContext, nil +} + +type loadBulkAdapter struct { + stream grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse] + referencedNamespaceMap map[string]*schema.Definition + referencedCaveatMap map[string]*core.CaveatDefinition + current tuple.Relationship + caveat core.ContextualizedCaveat + caveatTypeSet *caveattypes.TypeSet + + awaitingNamespaces []string + awaitingCaveats []string + + currentBatch []*v1.Relationship + numSent int + err error +} + +func (a *loadBulkAdapter) Next(_ context.Context) (*tuple.Relationship, error) { + for a.err == nil && a.numSent == len(a.currentBatch) { + // Load a new batch + batch, err := a.stream.Recv() + if err != nil { + a.err = err + if errors.Is(a.err, io.EOF) { + return nil, nil + } + return nil, a.err + } + + a.currentBatch = batch.Relationships + a.numSent = 0 + + a.awaitingNamespaces, a.awaitingCaveats = extractBatchNewReferencedNamespacesAndCaveats( + a.currentBatch, + a.referencedNamespaceMap, + a.referencedCaveatMap, + ) + } + + if len(a.awaitingNamespaces) > 0 || len(a.awaitingCaveats) > 0 { + // Shut down the stream to give our caller a chance to fill in this information + return nil, nil + } + + a.current.RelationshipReference.Resource.ObjectType = a.currentBatch[a.numSent].Resource.ObjectType + a.current.RelationshipReference.Resource.ObjectID = a.currentBatch[a.numSent].Resource.ObjectId + a.current.RelationshipReference.Resource.Relation = a.currentBatch[a.numSent].Relation + a.current.Subject.ObjectType = a.currentBatch[a.numSent].Subject.Object.ObjectType + a.current.Subject.ObjectID = a.currentBatch[a.numSent].Subject.Object.ObjectId + a.current.Subject.Relation = stringz.DefaultEmpty(a.currentBatch[a.numSent].Subject.OptionalRelation, tuple.Ellipsis) + + if a.currentBatch[a.numSent].OptionalCaveat != nil { + a.caveat.CaveatName = a.currentBatch[a.numSent].OptionalCaveat.CaveatName + a.caveat.Context = a.currentBatch[a.numSent].OptionalCaveat.Context + a.current.OptionalCaveat = &a.caveat + } else { + a.current.OptionalCaveat = nil + } + + if a.currentBatch[a.numSent].OptionalExpiresAt != nil { + t := a.currentBatch[a.numSent].OptionalExpiresAt.AsTime() + a.current.OptionalExpiration = &t + } else { + a.current.OptionalExpiration = nil + } + + a.current.OptionalIntegrity = nil + + if err := relationships.ValidateOneRelationship( + a.referencedNamespaceMap, + a.referencedCaveatMap, + a.caveatTypeSet, + a.current, + relationships.ValidateRelationshipForCreateOrTouch, + ); err != nil { + return nil, err + } + + a.numSent++ + return &a.current, nil +} + +func (ps *permissionServer) ImportBulkRelationships(stream grpc.ClientStreamingServer[v1.ImportBulkRelationshipsRequest, v1.ImportBulkRelationshipsResponse]) error { + ds := datastoremw.MustFromContext(stream.Context()) + + var numWritten uint64 + if _, err := ds.ReadWriteTx(stream.Context(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + loadedNamespaces := make(map[string]*schema.Definition, 2) + loadedCaveats := make(map[string]*core.CaveatDefinition, 0) + + adapter := &loadBulkAdapter{ + stream: stream, + referencedNamespaceMap: loadedNamespaces, + referencedCaveatMap: loadedCaveats, + caveat: core.ContextualizedCaveat{}, + caveatTypeSet: ps.config.CaveatTypeSet, + } + resolver := schema.ResolverForDatastoreReader(rwt) + ts := schema.NewTypeSystem(resolver) + + var streamWritten uint64 + var err error + for ; adapter.err == nil && err == nil; streamWritten, err = rwt.BulkLoad(stream.Context(), adapter) { + numWritten += streamWritten + + // The stream has terminated because we're awaiting namespace and/or caveat information + if len(adapter.awaitingNamespaces) > 0 { + nsDefs, err := rwt.LookupNamespacesWithNames(stream.Context(), adapter.awaitingNamespaces) + if err != nil { + return err + } + + for _, nsDef := range nsDefs { + newDef, err := schema.NewDefinition(ts, nsDef.Definition) + if err != nil { + return err + } + + loadedNamespaces[nsDef.Definition.Name] = newDef + } + adapter.awaitingNamespaces = nil + } + + if len(adapter.awaitingCaveats) > 0 { + caveats, err := rwt.LookupCaveatsWithNames(stream.Context(), adapter.awaitingCaveats) + if err != nil { + return err + } + + for _, caveat := range caveats { + loadedCaveats[caveat.Definition.Name] = caveat.Definition + } + adapter.awaitingCaveats = nil + } + } + numWritten += streamWritten + + return err + }, dsoptions.WithDisableRetries(true)); err != nil { + return shared.RewriteErrorWithoutConfig(stream.Context(), err) + } + + usagemetrics.SetInContext(stream.Context(), &dispatch.ResponseMeta{ + // One request for the whole load + DispatchCount: 1, + }) + + return stream.SendAndClose(&v1.ImportBulkRelationshipsResponse{ + NumLoaded: numWritten, + }) +} + +func (ps *permissionServer) ExportBulkRelationships( + req *v1.ExportBulkRelationshipsRequest, + resp grpc.ServerStreamingServer[v1.ExportBulkRelationshipsResponse], +) error { + ctx := resp.Context() + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + return ExportBulk(ctx, datastoremw.MustFromContext(ctx), uint64(ps.config.MaxBulkExportRelationshipsLimit), req, atRevision, resp.Send) +} + +// ExportBulk implements the ExportBulkRelationships API functionality. Given a datastore.Datastore, it will +// export stream via the sender all relationships matched by the incoming request. +// If no cursor is provided, it will fallback to the provided revision. +func ExportBulk(ctx context.Context, ds datastore.Datastore, batchSize uint64, req *v1.ExportBulkRelationshipsRequest, fallbackRevision datastore.Revision, sender func(response *v1.ExportBulkRelationshipsResponse) error) error { + if req.OptionalLimit > 0 && uint64(req.OptionalLimit) > batchSize { + return shared.RewriteErrorWithoutConfig(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), batchSize)) + } + + atRevision := fallbackRevision + var curNamespace string + var cur dsoptions.Cursor + if req.OptionalCursor != nil { + var err error + atRevision, curNamespace, cur, err = decodeCursor(ds, req.OptionalCursor) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + } + + reader := ds.SnapshotReader(atRevision) + + namespaces, err := reader.ListAllNamespaces(ctx) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Make sure the namespaces are always in a stable order + slices.SortFunc(namespaces, func( + lhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + rhs datastore.RevisionedDefinition[*core.NamespaceDefinition], + ) int { + return strings.Compare(lhs.Definition.Name, rhs.Definition.Name) + }) + + // Skip the namespaces that are already fully returned + for cur != nil && len(namespaces) > 0 && namespaces[0].Definition.Name < curNamespace { + namespaces = namespaces[1:] + } + + limit := batchSize + if req.OptionalLimit > 0 { + limit = uint64(req.OptionalLimit) + } + + // Pre-allocate all of the relationships that we might need in order to + // make export easier and faster for the garbage collector. + relsArray := make([]v1.Relationship, limit) + objArray := make([]v1.ObjectReference, limit) + subArray := make([]v1.SubjectReference, limit) + subObjArray := make([]v1.ObjectReference, limit) + caveatArray := make([]v1.ContextualizedCaveat, limit) + for i := range relsArray { + relsArray[i].Resource = &objArray[i] + relsArray[i].Subject = &subArray[i] + relsArray[i].Subject.Object = &subObjArray[i] + } + + emptyRels := make([]*v1.Relationship, limit) + // The number of batches/dispatches for the purpose of usage metrics + var batches uint32 + for _, ns := range namespaces { + rels := emptyRels + + // Reset the cursor between namespaces. + if ns.Definition.Name != curNamespace { + cur = nil + } + + // Skip this namespace if a resource type filter was specified. + if req.OptionalRelationshipFilter != nil && req.OptionalRelationshipFilter.ResourceType != "" { + if ns.Definition.Name != req.OptionalRelationshipFilter.ResourceType { + continue + } + } + + // Setup the filter to use for the relationships. + relationshipFilter := datastore.RelationshipsFilter{OptionalResourceType: ns.Definition.Name} + if req.OptionalRelationshipFilter != nil { + rf, err := datastore.RelationshipsFilterFromPublicFilter(req.OptionalRelationshipFilter) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + // Overload the namespace name with the one from the request, because each iteration is for a different namespace. + rf.OptionalResourceType = ns.Definition.Name + relationshipFilter = rf + } + + // We want to keep iterating as long as we're sending full batches. + // To bootstrap this loop, we enter the first time with a full rels + // slice of dummy rels that were never sent. + for uint64(len(rels)) == limit { + // Lop off any rels we've already sent + rels = rels[:0] + + relFn := func(rel tuple.Relationship) { + offset := len(rels) + rels = append(rels, &relsArray[offset]) // nozero + + v1Rel := &relsArray[offset] + v1Rel.Resource.ObjectType = rel.RelationshipReference.Resource.ObjectType + v1Rel.Resource.ObjectId = rel.RelationshipReference.Resource.ObjectID + v1Rel.Relation = rel.RelationshipReference.Resource.Relation + v1Rel.Subject.Object.ObjectType = rel.RelationshipReference.Subject.ObjectType + v1Rel.Subject.Object.ObjectId = rel.RelationshipReference.Subject.ObjectID + v1Rel.Subject.OptionalRelation = denormalizeSubjectRelation(rel.RelationshipReference.Subject.Relation) + + if rel.OptionalCaveat != nil { + caveatArray[offset].CaveatName = rel.OptionalCaveat.CaveatName + caveatArray[offset].Context = rel.OptionalCaveat.Context + v1Rel.OptionalCaveat = &caveatArray[offset] + } else { + caveatArray[offset].CaveatName = "" + caveatArray[offset].Context = nil + v1Rel.OptionalCaveat = nil + } + + if rel.OptionalExpiration != nil { + v1Rel.OptionalExpiresAt = timestamppb.New(*rel.OptionalExpiration) + } else { + v1Rel.OptionalExpiresAt = nil + } + } + + cur, err = queryForEach( + ctx, + reader, + relationshipFilter, + relFn, + dsoptions.WithLimit(&limit), + dsoptions.WithAfter(cur), + dsoptions.WithSort(dsoptions.ByResource), + dsoptions.WithQueryShape(queryshape.Varying), + ) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + if len(rels) == 0 { + continue + } + + encoded, err := cursor.Encode(&implv1.DecodedCursor{ + VersionOneof: &implv1.DecodedCursor_V1{ + V1: &implv1.V1Cursor{ + Revision: atRevision.String(), + Sections: []string{ + ns.Definition.Name, + tuple.MustString(*dsoptions.ToRelationship(cur)), + }, + }, + }, + }) + if err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + + if err := sender(&v1.ExportBulkRelationshipsResponse{ + AfterResultCursor: encoded, + Relationships: rels, + }); err != nil { + return shared.RewriteErrorWithoutConfig(ctx, err) + } + // Increment batches for usagemetrics + batches++ + } + } + + // Record usage metrics + respMetadata := &dispatch.ResponseMeta{ + DispatchCount: batches, + } + usagemetrics.SetInContext(ctx, respMetadata) + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go b/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go new file mode 100644 index 0000000..c34d5d5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/preconditions.go @@ -0,0 +1,54 @@ +package v1 + +import ( + "context" + "fmt" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" +) + +var limitOne uint64 = 1 + +// checkPreconditions checks whether the preconditions are met in the context of a datastore +// read-write transaction, and returns an error if they are not met. +func checkPreconditions( + ctx context.Context, + rwt datastore.ReadWriteTransaction, + preconditions []*v1.Precondition, +) error { + for _, precond := range preconditions { + dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(precond.Filter) + if err != nil { + return fmt.Errorf("error converting filter: %w", err) + } + + iter, err := rwt.QueryRelationships(ctx, dsFilter, options.WithLimit(&limitOne), options.WithQueryShape(queryshape.Varying)) + if err != nil { + return fmt.Errorf("error reading relationships: %w", err) + } + + _, ok, err := datastore.FirstRelationshipIn(iter) + if err != nil { + return fmt.Errorf("error reading relationships from iterator: %w", err) + } + + switch precond.Operation { + case v1.Precondition_OPERATION_MUST_NOT_MATCH: + if ok { + return NewPreconditionFailedErr(precond) + } + case v1.Precondition_OPERATION_MUST_MATCH: + if !ok { + return NewPreconditionFailedErr(precond) + } + default: + return fmt.Errorf("unspecified precondition operation: %s", precond.Operation) + } + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go new file mode 100644 index 0000000..723a8d3 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionapi.go @@ -0,0 +1,720 @@ +package v1 + +import ( + "sort" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/diff" + caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats" + nsdiff "github.com/authzed/spicedb/pkg/diff/namespace" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +type schemaFilters struct { + filters []*v1.ReflectionSchemaFilter +} + +func newSchemaFilters(filters []*v1.ReflectionSchemaFilter) (*schemaFilters, error) { + for _, filter := range filters { + if filter.OptionalDefinitionNameFilter != "" { + if filter.OptionalCaveatNameFilter != "" { + return nil, NewInvalidFilterErr("cannot filter by both definition and caveat name", filter.String()) + } + } + + if filter.OptionalRelationNameFilter != "" { + if filter.OptionalDefinitionNameFilter == "" { + return nil, NewInvalidFilterErr("relation name match requires definition name match", filter.String()) + } + + if filter.OptionalPermissionNameFilter != "" { + return nil, NewInvalidFilterErr("cannot filter by both relation and permission name", filter.String()) + } + } + + if filter.OptionalPermissionNameFilter != "" { + if filter.OptionalDefinitionNameFilter == "" { + return nil, NewInvalidFilterErr("permission name match requires definition name match", filter.String()) + } + } + } + + return &schemaFilters{filters: filters}, nil +} + +func (sf *schemaFilters) HasNamespaces() bool { + if len(sf.filters) == 0 { + return true + } + + for _, filter := range sf.filters { + if filter.OptionalDefinitionNameFilter != "" { + return true + } + } + + return false +} + +func (sf *schemaFilters) HasCaveats() bool { + if len(sf.filters) == 0 { + return true + } + + for _, filter := range sf.filters { + if filter.OptionalCaveatNameFilter != "" { + return true + } + } + + return false +} + +func (sf *schemaFilters) HasNamespace(namespaceName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasDefinitionFilter := false + for _, filter := range sf.filters { + if filter.OptionalDefinitionNameFilter == "" { + continue + } + + hasDefinitionFilter = true + isMatch := strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasDefinitionFilter +} + +func (sf *schemaFilters) HasCaveat(caveatName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasCaveatFilter := false + for _, filter := range sf.filters { + if filter.OptionalCaveatNameFilter == "" { + continue + } + + hasCaveatFilter = true + isMatch := strings.HasPrefix(caveatName, filter.OptionalCaveatNameFilter) + if isMatch { + return true + } + } + + return !hasCaveatFilter +} + +func (sf *schemaFilters) HasRelation(namespaceName, relationName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasRelationFilter := false + for _, filter := range sf.filters { + if filter.OptionalRelationNameFilter == "" { + continue + } + + hasRelationFilter = true + isMatch := strings.HasPrefix(relationName, filter.OptionalRelationNameFilter) + if !isMatch { + continue + } + + isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasRelationFilter +} + +func (sf *schemaFilters) HasPermission(namespaceName, permissionName string) bool { + if len(sf.filters) == 0 { + return true + } + + hasPermissionFilter := false + for _, filter := range sf.filters { + if filter.OptionalPermissionNameFilter == "" { + continue + } + + hasPermissionFilter = true + isMatch := strings.HasPrefix(permissionName, filter.OptionalPermissionNameFilter) + if !isMatch { + continue + } + + isMatch = strings.HasPrefix(namespaceName, filter.OptionalDefinitionNameFilter) + if isMatch { + return true + } + } + + return !hasPermissionFilter +} + +// convertDiff converts a schema diff into an API response. +func convertDiff( + diff *diff.SchemaDiff, + existingSchema *diff.DiffableSchema, + comparisonSchema *diff.DiffableSchema, + atRevision datastore.Revision, + caveatTypeSet *caveattypes.TypeSet, +) (*v1.DiffSchemaResponse, error) { + size := len(diff.AddedNamespaces) + len(diff.RemovedNamespaces) + len(diff.AddedCaveats) + len(diff.RemovedCaveats) + len(diff.ChangedNamespaces) + len(diff.ChangedCaveats) + diffs := make([]*v1.ReflectionSchemaDiff, 0, size) + + // Add/remove namespaces. + for _, ns := range diff.AddedNamespaces { + nsDef, err := namespaceAPIReprForName(ns, comparisonSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_DefinitionAdded{ + DefinitionAdded: nsDef, + }, + }) + } + + for _, ns := range diff.RemovedNamespaces { + nsDef, err := namespaceAPIReprForName(ns, existingSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_DefinitionRemoved{ + DefinitionRemoved: nsDef, + }, + }) + } + + // Add/remove caveats. + for _, caveat := range diff.AddedCaveats { + caveatDef, err := caveatAPIReprForName(caveat, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatAdded{ + CaveatAdded: caveatDef, + }, + }) + } + + for _, caveat := range diff.RemovedCaveats { + caveatDef, err := caveatAPIReprForName(caveat, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatRemoved{ + CaveatRemoved: caveatDef, + }, + }) + } + + // Changed namespaces. + for nsName, nsDiff := range diff.ChangedNamespaces { + for _, delta := range nsDiff.Deltas() { + switch delta.Type { + case nsdiff.AddedPermission: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := permissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_PermissionAdded{ + PermissionAdded: perm, + }, + }) + + case nsdiff.AddedRelation: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationAdded{ + RelationAdded: rel, + }, + }) + + case nsdiff.ChangedPermissionComment: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := permissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_PermissionDocCommentChanged{ + PermissionDocCommentChanged: perm, + }, + }) + + case nsdiff.ChangedPermissionImpl: + permission, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("permission %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := permissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_PermissionExprChanged{ + PermissionExprChanged: perm, + }, + }) + + case nsdiff.ChangedRelationComment: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationDocCommentChanged{ + RelationDocCommentChanged: rel, + }, + }) + + case nsdiff.LegacyChangedRelationImpl: + return nil, spiceerrors.MustBugf("legacy relation implementation changes are not supported") + + case nsdiff.NamespaceCommentsChanged: + def, err := namespaceAPIReprForName(nsName, comparisonSchema) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_DefinitionDocCommentChanged{ + DefinitionDocCommentChanged: def, + }, + }) + + case nsdiff.RelationAllowedTypeRemoved: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationSubjectTypeRemoved{ + RelationSubjectTypeRemoved: &v1.ReflectionRelationSubjectTypeChange{ + Relation: rel, + ChangedSubjectType: typeAPIRepr(delta.AllowedType), + }, + }, + }) + + case nsdiff.RelationAllowedTypeAdded: + relation, ok := comparisonSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationSubjectTypeAdded{ + RelationSubjectTypeAdded: &v1.ReflectionRelationSubjectTypeChange{ + Relation: rel, + ChangedSubjectType: typeAPIRepr(delta.AllowedType), + }, + }, + }) + + case nsdiff.RemovedPermission: + permission, ok := existingSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + perm, err := permissionAPIRepr(permission, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_PermissionRemoved{ + PermissionRemoved: perm, + }, + }) + + case nsdiff.RemovedRelation: + relation, ok := existingSchema.GetRelation(nsName, delta.RelationName) + if !ok { + return nil, spiceerrors.MustBugf("relation %q not found in namespace %q", delta.RelationName, nsName) + } + + rel, err := relationAPIRepr(relation, nsName, nil) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_RelationRemoved{ + RelationRemoved: rel, + }, + }) + + case nsdiff.NamespaceAdded: + return nil, spiceerrors.MustBugf("should be handled above") + + case nsdiff.NamespaceRemoved: + return nil, spiceerrors.MustBugf("should be handled above") + + default: + return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type) + } + } + } + + // Changed caveats. + for caveatName, caveatDiff := range diff.ChangedCaveats { + for _, delta := range caveatDiff.Deltas() { + switch delta.Type { + case caveatdiff.CaveatCommentsChanged: + caveat, err := caveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatDocCommentChanged{ + CaveatDocCommentChanged: caveat, + }, + }) + + case caveatdiff.AddedParameter: + paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatParameterAdded{ + CaveatParameterAdded: paramDef, + }, + }) + + case caveatdiff.RemovedParameter: + paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatParameterRemoved{ + CaveatParameterRemoved: paramDef, + }, + }) + + case caveatdiff.ParameterTypeChanged: + previousParamDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, existingSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + paramDef, err := caveatAPIParamRepr(delta.ParameterName, caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatParameterTypeChanged{ + CaveatParameterTypeChanged: &v1.ReflectionCaveatParameterTypeChange{ + Parameter: paramDef, + PreviousType: previousParamDef.Type, + }, + }, + }) + + case caveatdiff.CaveatExpressionChanged: + caveat, err := caveatAPIReprForName(caveatName, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, err + } + + diffs = append(diffs, &v1.ReflectionSchemaDiff{ + Diff: &v1.ReflectionSchemaDiff_CaveatExprChanged{ + CaveatExprChanged: caveat, + }, + }) + + case caveatdiff.CaveatAdded: + return nil, spiceerrors.MustBugf("should be handled above") + + case caveatdiff.CaveatRemoved: + return nil, spiceerrors.MustBugf("should be handled above") + + default: + return nil, spiceerrors.MustBugf("unexpected delta type %v", delta.Type) + } + } + } + + return &v1.DiffSchemaResponse{ + Diffs: diffs, + ReadAt: zedtoken.MustNewFromRevision(atRevision), + }, nil +} + +// namespaceAPIReprForName builds an API representation of a namespace. +func namespaceAPIReprForName(namespaceName string, schema *diff.DiffableSchema) (*v1.ReflectionDefinition, error) { + nsDef, ok := schema.GetNamespace(namespaceName) + if !ok { + return nil, spiceerrors.MustBugf("namespace %q not found in schema", namespaceName) + } + + return namespaceAPIRepr(nsDef, nil) +} + +func namespaceAPIRepr(nsDef *core.NamespaceDefinition, schemaFilters *schemaFilters) (*v1.ReflectionDefinition, error) { + if schemaFilters != nil && !schemaFilters.HasNamespace(nsDef.Name) { + return nil, nil + } + + relations := make([]*v1.ReflectionRelation, 0, len(nsDef.Relation)) + permissions := make([]*v1.ReflectionPermission, 0, len(nsDef.Relation)) + + for _, rel := range nsDef.Relation { + if namespace.GetRelationKind(rel) == iv1.RelationMetadata_PERMISSION { + permission, err := permissionAPIRepr(rel, nsDef.Name, schemaFilters) + if err != nil { + return nil, err + } + + if permission != nil { + permissions = append(permissions, permission) + } + continue + } + + relation, err := relationAPIRepr(rel, nsDef.Name, schemaFilters) + if err != nil { + return nil, err + } + + if relation != nil { + relations = append(relations, relation) + } + } + + comments := namespace.GetComments(nsDef.Metadata) + return &v1.ReflectionDefinition{ + Name: nsDef.Name, + Comment: strings.Join(comments, "\n"), + Relations: relations, + Permissions: permissions, + }, nil +} + +// permissionAPIRepr builds an API representation of a permission. +func permissionAPIRepr(relation *core.Relation, parentDefName string, schemaFilters *schemaFilters) (*v1.ReflectionPermission, error) { + if schemaFilters != nil && !schemaFilters.HasPermission(parentDefName, relation.Name) { + return nil, nil + } + + comments := namespace.GetComments(relation.Metadata) + return &v1.ReflectionPermission{ + Name: relation.Name, + Comment: strings.Join(comments, "\n"), + ParentDefinitionName: parentDefName, + }, nil +} + +// relationAPIRepresentation builds an API representation of a relation. +func relationAPIRepr(relation *core.Relation, parentDefName string, schemaFilters *schemaFilters) (*v1.ReflectionRelation, error) { + if schemaFilters != nil && !schemaFilters.HasRelation(parentDefName, relation.Name) { + return nil, nil + } + + comments := namespace.GetComments(relation.Metadata) + + var subjectTypes []*v1.ReflectionTypeReference + if relation.TypeInformation != nil { + subjectTypes = make([]*v1.ReflectionTypeReference, 0, len(relation.TypeInformation.AllowedDirectRelations)) + for _, subjectType := range relation.TypeInformation.AllowedDirectRelations { + typeref := typeAPIRepr(subjectType) + subjectTypes = append(subjectTypes, typeref) + } + } + + return &v1.ReflectionRelation{ + Name: relation.Name, + Comment: strings.Join(comments, "\n"), + ParentDefinitionName: parentDefName, + SubjectTypes: subjectTypes, + }, nil +} + +// typeAPIRepr builds an API representation of a type. +func typeAPIRepr(subjectType *core.AllowedRelation) *v1.ReflectionTypeReference { + typeref := &v1.ReflectionTypeReference{ + SubjectDefinitionName: subjectType.Namespace, + Typeref: &v1.ReflectionTypeReference_IsTerminalSubject{}, + } + + if subjectType.GetRelation() != tuple.Ellipsis && subjectType.GetRelation() != "" { + typeref.Typeref = &v1.ReflectionTypeReference_OptionalRelationName{ + OptionalRelationName: subjectType.GetRelation(), + } + } else if subjectType.GetPublicWildcard() != nil { + typeref.Typeref = &v1.ReflectionTypeReference_IsPublicWildcard{ + IsPublicWildcard: true, + } + } + + if subjectType.GetRequiredCaveat() != nil { + typeref.OptionalCaveatName = subjectType.GetRequiredCaveat().CaveatName + } + + return typeref +} + +// caveatAPIReprForName builds an API representation of a caveat. +func caveatAPIReprForName(caveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveat, error) { + caveatDef, ok := schema.GetCaveat(caveatName) + if !ok { + return nil, spiceerrors.MustBugf("caveat %q not found in schema", caveatName) + } + + return caveatAPIRepr(caveatDef, nil, caveatTypeSet) +} + +// caveatAPIRepr builds an API representation of a caveat. +func caveatAPIRepr(caveatDef *core.CaveatDefinition, schemaFilters *schemaFilters, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveat, error) { + if schemaFilters != nil && !schemaFilters.HasCaveat(caveatDef.Name) { + return nil, nil + } + + parameters := make([]*v1.ReflectionCaveatParameter, 0, len(caveatDef.ParameterTypes)) + paramNames := maps.Keys(caveatDef.ParameterTypes) + sort.Strings(paramNames) + + for _, paramName := range paramNames { + paramType, ok := caveatDef.ParameterTypes[paramName] + if !ok { + return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, caveatDef.Name) + } + + decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType) + if err != nil { + return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err) + } + + parameters = append(parameters, &v1.ReflectionCaveatParameter{ + Name: paramName, + Type: decoded.String(), + ParentCaveatName: caveatDef.Name, + }) + } + + parameterTypes, err := caveattypes.DecodeParameterTypes(caveatTypeSet, caveatDef.ParameterTypes) + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat parameters: %v", err) + } + + deserializedReflectionression, err := caveats.DeserializeCaveatWithTypeSet(caveatTypeSet, caveatDef.SerializedExpression, parameterTypes) + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat expression bytes: %v", err) + } + + exprString, err := deserializedReflectionression.ExprString() + if err != nil { + return nil, spiceerrors.MustBugf("invalid caveat expression: %v", err) + } + + comments := namespace.GetComments(caveatDef.Metadata) + return &v1.ReflectionCaveat{ + Name: caveatDef.Name, + Comment: strings.Join(comments, "\n"), + Parameters: parameters, + Expression: exprString, + }, nil +} + +// caveatAPIParamRepresentation builds an API representation of a caveat parameter. +func caveatAPIParamRepr(paramName, parentCaveatName string, schema *diff.DiffableSchema, caveatTypeSet *caveattypes.TypeSet) (*v1.ReflectionCaveatParameter, error) { + caveatDef, ok := schema.GetCaveat(parentCaveatName) + if !ok { + return nil, spiceerrors.MustBugf("caveat %q not found in schema", parentCaveatName) + } + + paramType, ok := caveatDef.ParameterTypes[paramName] + if !ok { + return nil, spiceerrors.MustBugf("parameter %q not found in caveat %q", paramName, parentCaveatName) + } + + decoded, err := caveattypes.DecodeParameterType(caveatTypeSet, paramType) + if err != nil { + return nil, spiceerrors.MustBugf("invalid parameter type on caveat: %v", err) + } + + return &v1.ReflectionCaveatParameter{ + Name: paramName, + Type: decoded.String(), + ParentCaveatName: parentCaveatName, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go new file mode 100644 index 0000000..a572216 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/reflectionutil.go @@ -0,0 +1,76 @@ +package v1 + +import ( + "context" + + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/diff" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +func loadCurrentSchema(ctx context.Context) (*diff.DiffableSchema, datastore.Revision, error) { + ds := datastoremw.MustFromContext(ctx) + + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, nil, err + } + + reader := ds.SnapshotReader(atRevision) + + namespacesAndRevs, err := reader.ListAllNamespaces(ctx) + if err != nil { + return nil, atRevision, err + } + + caveatsAndRevs, err := reader.ListAllCaveats(ctx) + if err != nil { + return nil, atRevision, err + } + + namespaces := make([]*core.NamespaceDefinition, 0, len(namespacesAndRevs)) + for _, namespaceAndRev := range namespacesAndRevs { + namespaces = append(namespaces, namespaceAndRev.Definition) + } + + caveats := make([]*core.CaveatDefinition, 0, len(caveatsAndRevs)) + for _, caveatAndRev := range caveatsAndRevs { + caveats = append(caveats, caveatAndRev.Definition) + } + + return &diff.DiffableSchema{ + ObjectDefinitions: namespaces, + CaveatDefinitions: caveats, + }, atRevision, nil +} + +func schemaDiff(ctx context.Context, comparisonSchemaString string, caveatTypeSet *caveattypes.TypeSet) (*diff.SchemaDiff, *diff.DiffableSchema, *diff.DiffableSchema, error) { + existingSchema, _, err := loadCurrentSchema(ctx) + if err != nil { + return nil, nil, nil, err + } + + // Compile the comparison schema. + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: comparisonSchemaString, + }, compiler.AllowUnprefixedObjectType(), compiler.CaveatTypeSet(caveatTypeSet)) + if err != nil { + return nil, nil, nil, err + } + + comparisonSchema := diff.NewDiffableSchemaFromCompiledSchema(compiled) + + diff, err := diff.DiffSchemas(*existingSchema, comparisonSchema, caveatTypeSet) + if err != nil { + return nil, nil, nil, err + } + + // Return the diff. + return diff, existingSchema, &comparisonSchema, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go b/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go new file mode 100644 index 0000000..f0b2138 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/relationships.go @@ -0,0 +1,576 @@ +package v1 + +import ( + "context" + "fmt" + "time" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "github.com/jzelinskie/stringz" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/middleware" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/handwrittenvalidation" + "github.com/authzed/spicedb/internal/middleware/streamtimeout" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/namespace" + "github.com/authzed/spicedb/internal/relationships" + "github.com/authzed/spicedb/internal/services/shared" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/pagination" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + "github.com/authzed/spicedb/pkg/genutil" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/middleware/consistency" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +var writeUpdateCounter = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "spicedb", + Subsystem: "v1", + Name: "write_relationships_updates", + Help: "The update counts for the WriteRelationships calls", + Buckets: []float64{0, 1, 2, 5, 10, 15, 25, 50, 100, 250, 500, 1000}, +}, []string{"kind"}) + +const MaximumTransactionMetadataSize = 65000 // bytes. Limited by the BLOB size used in MySQL driver + +// PermissionsServerConfig is configuration for the permissions server. +type PermissionsServerConfig struct { + // MaxUpdatesPerWrite holds the maximum number of updates allowed per + // WriteRelationships call. + MaxUpdatesPerWrite uint16 + + // MaxPreconditionsCount holds the maximum number of preconditions allowed + // on a WriteRelationships or DeleteRelationships call. + MaxPreconditionsCount uint16 + + // MaximumAPIDepth is the default/starting depth remaining for API calls made + // to the permissions server. + MaximumAPIDepth uint32 + + // DispatchChunkSize is the maximum number of elements to dispach in a dispatch call + DispatchChunkSize uint16 + + // StreamingAPITimeout is the timeout for streaming APIs when no response has been + // recently received. + StreamingAPITimeout time.Duration + + // MaxCaveatContextSize defines the maximum length of the request caveat context in bytes + MaxCaveatContextSize int + + // MaxRelationshipContextSize defines the maximum length of a relationship's context in bytes + MaxRelationshipContextSize int + + // MaxDatastoreReadPageSize defines the maximum number of relationships loaded from the + // datastore in one query. + MaxDatastoreReadPageSize uint64 + + // MaxCheckBulkConcurrency defines the maximum number of concurrent checks that can be + // made in a single CheckBulkPermissions call. + MaxCheckBulkConcurrency uint16 + + // MaxReadRelationshipsLimit defines the maximum number of relationships that can be read + // in a single ReadRelationships call. + MaxReadRelationshipsLimit uint32 + + // MaxDeleteRelationshipsLimit defines the maximum number of relationships that can be deleted + // in a single DeleteRelationships call. + MaxDeleteRelationshipsLimit uint32 + + // MaxLookupResourcesLimit defines the maximum number of resources that can be looked up in a + // single LookupResources call. + MaxLookupResourcesLimit uint32 + + // MaxBulkExportRelationshipsLimit defines the maximum number of relationships that can be + // exported in a single BulkExportRelationships call. + MaxBulkExportRelationshipsLimit uint32 + + // ExpiringRelationshipsEnabled defines whether or not expiring relationships are enabled. + ExpiringRelationshipsEnabled bool + + // CaveatTypeSet is the set of caveat types to use for caveats. If not specified, + // the default type set is used. + CaveatTypeSet *caveattypes.TypeSet +} + +// NewPermissionsServer creates a PermissionsServiceServer instance. +func NewPermissionsServer( + dispatch dispatch.Dispatcher, + config PermissionsServerConfig, +) v1.PermissionsServiceServer { + configWithDefaults := PermissionsServerConfig{ + MaxPreconditionsCount: defaultIfZero(config.MaxPreconditionsCount, 1000), + MaxUpdatesPerWrite: defaultIfZero(config.MaxUpdatesPerWrite, 1000), + MaximumAPIDepth: defaultIfZero(config.MaximumAPIDepth, 50), + StreamingAPITimeout: defaultIfZero(config.StreamingAPITimeout, 30*time.Second), + MaxCaveatContextSize: defaultIfZero(config.MaxCaveatContextSize, 4096), + MaxRelationshipContextSize: defaultIfZero(config.MaxRelationshipContextSize, 25_000), + MaxDatastoreReadPageSize: defaultIfZero(config.MaxDatastoreReadPageSize, 1_000), + MaxReadRelationshipsLimit: defaultIfZero(config.MaxReadRelationshipsLimit, 1_000), + MaxDeleteRelationshipsLimit: defaultIfZero(config.MaxDeleteRelationshipsLimit, 1_000), + MaxLookupResourcesLimit: defaultIfZero(config.MaxLookupResourcesLimit, 1_000), + MaxBulkExportRelationshipsLimit: defaultIfZero(config.MaxBulkExportRelationshipsLimit, 100_000), + DispatchChunkSize: defaultIfZero(config.DispatchChunkSize, 100), + MaxCheckBulkConcurrency: defaultIfZero(config.MaxCheckBulkConcurrency, 50), + CaveatTypeSet: caveattypes.TypeSetOrDefault(config.CaveatTypeSet), + ExpiringRelationshipsEnabled: true, + } + + return &permissionServer{ + dispatch: dispatch, + config: configWithDefaults, + WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{ + Unary: middleware.ChainUnaryServer( + grpcvalidate.UnaryServerInterceptor(), + handwrittenvalidation.UnaryServerInterceptor, + usagemetrics.UnaryServerInterceptor(), + ), + Stream: middleware.ChainStreamServer( + grpcvalidate.StreamServerInterceptor(), + handwrittenvalidation.StreamServerInterceptor, + usagemetrics.StreamServerInterceptor(), + streamtimeout.MustStreamServerInterceptor(configWithDefaults.StreamingAPITimeout), + ), + }, + bulkChecker: &bulkChecker{ + maxAPIDepth: configWithDefaults.MaximumAPIDepth, + maxCaveatContextSize: configWithDefaults.MaxCaveatContextSize, + maxConcurrency: configWithDefaults.MaxCheckBulkConcurrency, + dispatch: dispatch, + dispatchChunkSize: configWithDefaults.DispatchChunkSize, + caveatTypeSet: configWithDefaults.CaveatTypeSet, + }, + } +} + +type permissionServer struct { + v1.UnimplementedPermissionsServiceServer + shared.WithServiceSpecificInterceptors + + dispatch dispatch.Dispatcher + config PermissionsServerConfig + + bulkChecker *bulkChecker +} + +func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest, resp v1.PermissionsService_ReadRelationshipsServer) error { + if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxReadRelationshipsLimit { + return ps.rewriteError(resp.Context(), NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxReadRelationshipsLimit))) + } + + ctx := resp.Context() + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + + if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, ds); err != nil { + return ps.rewriteError(ctx, err) + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + DispatchCount: 1, + }) + + limit := uint64(0) + var startCursor options.Cursor + + rrRequestHash, err := computeReadRelationshipsRequestHash(req) + if err != nil { + return ps.rewriteError(ctx, err) + } + + if req.OptionalCursor != nil { + decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, rrRequestHash) + if err != nil { + return ps.rewriteError(ctx, err) + } + + if len(decodedCursor.Sections) != 1 { + return ps.rewriteError(ctx, NewInvalidCursorErr("did not find expected resume relationship")) + } + + parsed, err := tuple.Parse(decodedCursor.Sections[0]) + if err != nil { + return ps.rewriteError(ctx, NewInvalidCursorErr("could not parse resume relationship")) + } + + startCursor = options.ToCursor(parsed) + } + + pageSize := ps.config.MaxDatastoreReadPageSize + if req.OptionalLimit > 0 { + limit = uint64(req.OptionalLimit) + if limit < pageSize { + pageSize = limit + } + } + + dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(req.RelationshipFilter) + if err != nil { + return ps.rewriteError(ctx, fmt.Errorf("error filtering: %w", err)) + } + + it, err := pagination.NewPaginatedIterator( + ctx, + ds, + dsFilter, + pageSize, + options.ByResource, + startCursor, + queryshape.Varying, + ) + if err != nil { + return ps.rewriteError(ctx, err) + } + + response := &v1.ReadRelationshipsResponse{ + ReadAt: revisionReadAt, + Relationship: &v1.Relationship{ + Resource: &v1.ObjectReference{}, + Subject: &v1.SubjectReference{ + Object: &v1.ObjectReference{}, + }, + }, + } + + dispatchCursor := &dispatchv1.Cursor{ + DispatchVersion: 1, + Sections: []string{""}, + } + + var returnedCount uint64 + for rel, err := range it { + if err != nil { + return ps.rewriteError(ctx, fmt.Errorf("error when reading tuples: %w", err)) + } + + if limit > 0 && returnedCount >= limit { + break + } + + dispatchCursor.Sections[0] = tuple.StringWithoutCaveatOrExpiration(rel) + encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, nil) + if err != nil { + return ps.rewriteError(ctx, err) + } + + tuple.CopyToV1Relationship(rel, response.Relationship) + response.AfterResultCursor = encodedCursor + + err = resp.Send(response) + if err != nil { + return ps.rewriteError(ctx, fmt.Errorf("error when streaming tuple: %w", err)) + } + returnedCount++ + } + return nil +} + +func (ps *permissionServer) WriteRelationships(ctx context.Context, req *v1.WriteRelationshipsRequest) (*v1.WriteRelationshipsResponse, error) { + if err := ps.validateTransactionMetadata(req.OptionalTransactionMetadata); err != nil { + return nil, ps.rewriteError(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx) + + span := trace.SpanFromContext(ctx) + span.AddEvent("validating mutations") + // Ensure that the updates and preconditions are not over the configured limits. + if len(req.Updates) > int(ps.config.MaxUpdatesPerWrite) { + return nil, ps.rewriteError( + ctx, + NewExceedsMaximumUpdatesErr(uint64(len(req.Updates)), uint64(ps.config.MaxUpdatesPerWrite)), + ) + } + + if len(req.OptionalPreconditions) > int(ps.config.MaxPreconditionsCount) { + return nil, ps.rewriteError( + ctx, + NewExceedsMaximumPreconditionsErr(uint64(len(req.OptionalPreconditions)), uint64(ps.config.MaxPreconditionsCount)), + ) + } + + // Check for duplicate updates and create the set of caveat names to load. + updateRelationshipSet := mapz.NewSet[string]() + for _, update := range req.Updates { + // TODO(jschorr): Change to struct-based keys. + tupleStr := tuple.V1StringRelationshipWithoutCaveatOrExpiration(update.Relationship) + if !updateRelationshipSet.Add(tupleStr) { + return nil, ps.rewriteError( + ctx, + NewDuplicateRelationshipErr(update), + ) + } + if proto.Size(update.Relationship.OptionalCaveat) > ps.config.MaxRelationshipContextSize { + return nil, ps.rewriteError( + ctx, + NewMaxRelationshipContextError(update, ps.config.MaxRelationshipContextSize), + ) + } + + if !ps.config.ExpiringRelationshipsEnabled && update.Relationship.OptionalExpiresAt != nil { + return nil, ps.rewriteError( + ctx, + fmt.Errorf("support for expiring relationships is not enabled"), + ) + } + } + + // Execute the write operation(s). + span.AddEvent("read write transaction") + relUpdates, err := tuple.UpdatesFromV1RelationshipUpdates(req.Updates) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + span.AddEvent("preconditions") + + // Validate the preconditions. + for _, precond := range req.OptionalPreconditions { + if err := validatePrecondition(ctx, precond, rwt); err != nil { + return err + } + } + + // Validate the updates. + span.AddEvent("validate updates") + err := relationships.ValidateRelationshipUpdates(ctx, rwt, ps.config.CaveatTypeSet, relUpdates) + if err != nil { + return ps.rewriteError(ctx, err) + } + + dispatchCount, err := genutil.EnsureUInt32(len(req.OptionalPreconditions) + 1) + if err != nil { + return ps.rewriteError(ctx, err) + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + // One request per precondition and one request for the actual writes. + DispatchCount: dispatchCount, + }) + + span.AddEvent("preconditions") + if err := checkPreconditions(ctx, rwt, req.OptionalPreconditions); err != nil { + return err + } + + span.AddEvent("write relationships") + return rwt.WriteRelationships(ctx, relUpdates) + }, options.WithMetadata(req.OptionalTransactionMetadata)) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + // Log a metric of the counts of the different kinds of update operations. + updateCountByOperation := make(map[v1.RelationshipUpdate_Operation]int, 0) + for _, update := range req.Updates { + updateCountByOperation[update.Operation]++ + } + + for kind, count := range updateCountByOperation { + writeUpdateCounter.WithLabelValues(v1.RelationshipUpdate_Operation_name[int32(kind)]).Observe(float64(count)) + } + + return &v1.WriteRelationshipsResponse{ + WrittenAt: zedtoken.MustNewFromRevision(revision), + }, nil +} + +func (ps *permissionServer) validateTransactionMetadata(metadata *structpb.Struct) error { + if metadata == nil { + return nil + } + + b, err := metadata.MarshalJSON() + if err != nil { + return err + } + + if len(b) > MaximumTransactionMetadataSize { + return NewTransactionMetadataTooLargeErr(len(b), MaximumTransactionMetadataSize) + } + + return nil +} + +func (ps *permissionServer) DeleteRelationships(ctx context.Context, req *v1.DeleteRelationshipsRequest) (*v1.DeleteRelationshipsResponse, error) { + if err := ps.validateTransactionMetadata(req.OptionalTransactionMetadata); err != nil { + return nil, ps.rewriteError(ctx, err) + } + + if len(req.OptionalPreconditions) > int(ps.config.MaxPreconditionsCount) { + return nil, ps.rewriteError( + ctx, + NewExceedsMaximumPreconditionsErr(uint64(len(req.OptionalPreconditions)), uint64(ps.config.MaxPreconditionsCount)), + ) + } + + if req.OptionalLimit > 0 && req.OptionalLimit > ps.config.MaxDeleteRelationshipsLimit { + return nil, ps.rewriteError(ctx, NewExceedsMaximumLimitErr(uint64(req.OptionalLimit), uint64(ps.config.MaxDeleteRelationshipsLimit))) + } + + ds := datastoremw.MustFromContext(ctx) + deletionProgress := v1.DeleteRelationshipsResponse_DELETION_PROGRESS_COMPLETE + + var deletedRelationshipCount uint64 + revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + if err := validateRelationshipsFilter(ctx, req.RelationshipFilter, rwt); err != nil { + return err + } + + dispatchCount, err := genutil.EnsureUInt32(len(req.OptionalPreconditions) + 1) + if err != nil { + return ps.rewriteError(ctx, err) + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + // One request per precondition and one request for the actual delete. + DispatchCount: dispatchCount, + }) + + for _, precond := range req.OptionalPreconditions { + if err := validatePrecondition(ctx, precond, rwt); err != nil { + return err + } + } + + if err := checkPreconditions(ctx, rwt, req.OptionalPreconditions); err != nil { + return err + } + + // If a limit was specified but partial deletion is not allowed, we need to check if the + // number of relationships to be deleted exceeds the limit. + if req.OptionalLimit > 0 && !req.OptionalAllowPartialDeletions { + limit := uint64(req.OptionalLimit) + limitPlusOne := limit + 1 + filter, err := datastore.RelationshipsFilterFromPublicFilter(req.RelationshipFilter) + if err != nil { + return ps.rewriteError(ctx, err) + } + + it, err := rwt.QueryRelationships(ctx, filter, options.WithLimit(&limitPlusOne), options.WithQueryShape(queryshape.Varying)) + if err != nil { + return ps.rewriteError(ctx, err) + } + + counter := uint64(0) + for _, err := range it { + if err != nil { + return ps.rewriteError(ctx, err) + } + + if counter == limit { + return ps.rewriteError(ctx, NewCouldNotTransactionallyDeleteErr(req.RelationshipFilter, req.OptionalLimit)) + } + + counter++ + } + } + + // Delete with the specified limit. + if req.OptionalLimit > 0 { + deleteLimit := uint64(req.OptionalLimit) + drc, reachedLimit, err := rwt.DeleteRelationships(ctx, req.RelationshipFilter, options.WithDeleteLimit(&deleteLimit)) + if err != nil { + return err + } + + if reachedLimit { + deletionProgress = v1.DeleteRelationshipsResponse_DELETION_PROGRESS_PARTIAL + } + + deletedRelationshipCount = drc + return nil + } + + // Otherwise, kick off an unlimited deletion. + deletedRelationshipCount, _, err = rwt.DeleteRelationships(ctx, req.RelationshipFilter) + return err + }, options.WithMetadata(req.OptionalTransactionMetadata)) + if err != nil { + return nil, ps.rewriteError(ctx, err) + } + + return &v1.DeleteRelationshipsResponse{ + DeletedAt: zedtoken.MustNewFromRevision(revision), + DeletionProgress: deletionProgress, + RelationshipsDeletedCount: deletedRelationshipCount, + }, nil +} + +var emptyPrecondition = &v1.Precondition{} + +func validatePrecondition(ctx context.Context, precond *v1.Precondition, reader datastore.Reader) error { + if precond.EqualVT(emptyPrecondition) || precond.Filter == nil { + return NewEmptyPreconditionErr() + } + + return validateRelationshipsFilter(ctx, precond.Filter, reader) +} + +func checkFilterComponent(ctx context.Context, objectType, optionalRelation string, ds datastore.Reader) error { + if objectType == "" { + return nil + } + + relationToTest := stringz.DefaultEmpty(optionalRelation, datastore.Ellipsis) + allowEllipsis := optionalRelation == "" + return namespace.CheckNamespaceAndRelation(ctx, objectType, relationToTest, allowEllipsis, ds) +} + +func validateRelationshipsFilter(ctx context.Context, filter *v1.RelationshipFilter, ds datastore.Reader) error { + // ResourceType is optional, so only check the relation if it is specified. + if filter.ResourceType != "" { + if err := checkFilterComponent(ctx, filter.ResourceType, filter.OptionalRelation, ds); err != nil { + return err + } + } + + // SubjectFilter is optional, so only check if it is specified. + if subjectFilter := filter.OptionalSubjectFilter; subjectFilter != nil { + subjectRelation := "" + if subjectFilter.OptionalRelation != nil { + subjectRelation = subjectFilter.OptionalRelation.Relation + } + if err := checkFilterComponent(ctx, subjectFilter.SubjectType, subjectRelation, ds); err != nil { + return err + } + } + + // Ensure the resource ID and the resource ID prefix are not set at the same time. + if filter.OptionalResourceId != "" && filter.OptionalResourceIdPrefix != "" { + return NewInvalidFilterErr("resource_id and resource_id_prefix cannot be set at the same time", filter.String()) + } + + // Ensure that at least one field is set. + return checkIfFilterIsEmpty(filter) +} + +func checkIfFilterIsEmpty(filter *v1.RelationshipFilter) error { + if filter.ResourceType == "" && + filter.OptionalResourceId == "" && + filter.OptionalResourceIdPrefix == "" && + filter.OptionalRelation == "" && + filter.OptionalSubjectFilter == nil { + return NewInvalidFilterErr("at least one field must be set", filter.String()) + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go b/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go new file mode 100644 index 0000000..14faf3d --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/schema.go @@ -0,0 +1,375 @@ +package v1 + +import ( + "context" + "sort" + "strings" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/middleware" + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/services/shared" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil" + "github.com/authzed/spicedb/pkg/middleware/consistency" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +// NewSchemaServer creates a SchemaServiceServer instance. +func NewSchemaServer(caveatTypeSet *caveattypes.TypeSet, additiveOnly bool, expiringRelsEnabled bool) v1.SchemaServiceServer { + cts := caveattypes.TypeSetOrDefault(caveatTypeSet) + return &schemaServer{ + WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{ + Unary: middleware.ChainUnaryServer( + grpcvalidate.UnaryServerInterceptor(), + usagemetrics.UnaryServerInterceptor(), + ), + Stream: middleware.ChainStreamServer( + grpcvalidate.StreamServerInterceptor(), + usagemetrics.StreamServerInterceptor(), + ), + }, + additiveOnly: additiveOnly, + expiringRelsEnabled: expiringRelsEnabled, + caveatTypeSet: cts, + } +} + +type schemaServer struct { + v1.UnimplementedSchemaServiceServer + shared.WithServiceSpecificInterceptors + + caveatTypeSet *caveattypes.TypeSet + additiveOnly bool + expiringRelsEnabled bool +} + +func (ss *schemaServer) rewriteError(ctx context.Context, err error) error { + return shared.RewriteError(ctx, err, nil) +} + +func (ss *schemaServer) ReadSchema(ctx context.Context, _ *v1.ReadSchemaRequest) (*v1.ReadSchemaResponse, error) { + // Schema is always read from the head revision. + ds := datastoremw.MustFromContext(ctx) + headRevision, err := ds.HeadRevision(ctx) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + reader := ds.SnapshotReader(headRevision) + + nsDefs, err := reader.ListAllNamespaces(ctx) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + caveatDefs, err := reader.ListAllCaveats(ctx) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + if len(nsDefs) == 0 { + return nil, status.Errorf(codes.NotFound, "No schema has been defined; please call WriteSchema to start") + } + + schemaDefinitions := make([]compiler.SchemaDefinition, 0, len(nsDefs)+len(caveatDefs)) + for _, caveatDef := range caveatDefs { + schemaDefinitions = append(schemaDefinitions, caveatDef.Definition) + } + + for _, nsDef := range nsDefs { + schemaDefinitions = append(schemaDefinitions, nsDef.Definition) + } + + schemaText, _, err := generator.GenerateSchema(schemaDefinitions) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + dispatchCount, err := genutil.EnsureUInt32(len(nsDefs) + len(caveatDefs)) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + DispatchCount: dispatchCount, + }) + + return &v1.ReadSchemaResponse{ + SchemaText: schemaText, + ReadAt: zedtoken.MustNewFromRevision(headRevision), + }, nil +} + +func (ss *schemaServer) WriteSchema(ctx context.Context, in *v1.WriteSchemaRequest) (*v1.WriteSchemaResponse, error) { + log.Ctx(ctx).Trace().Str("schema", in.GetSchema()).Msg("requested Schema to be written") + + ds := datastoremw.MustFromContext(ctx) + + // Compile the schema into the namespace definitions. + opts := make([]compiler.Option, 0, 3) + if !ss.expiringRelsEnabled { + opts = append(opts, compiler.DisallowExpirationFlag()) + } + + opts = append(opts, compiler.CaveatTypeSet(ss.caveatTypeSet)) + + compiled, err := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), + SchemaString: in.GetSchema(), + }, compiler.AllowUnprefixedObjectType(), opts...) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + log.Ctx(ctx).Trace().Int("objectDefinitions", len(compiled.ObjectDefinitions)).Int("caveatDefinitions", len(compiled.CaveatDefinitions)).Msg("compiled namespace definitions") + + // Do as much validation as we can before talking to the datastore. + validated, err := shared.ValidateSchemaChanges(ctx, compiled, ss.caveatTypeSet, ss.additiveOnly) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + // Update the schema. + revision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + applied, err := shared.ApplySchemaChanges(ctx, rwt, ss.caveatTypeSet, validated) + if err != nil { + return err + } + + dispatchCount, err := genutil.EnsureUInt32(applied.TotalOperationCount) + if err != nil { + return err + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + DispatchCount: dispatchCount, + }) + return nil + }) + if err != nil { + return nil, ss.rewriteError(ctx, err) + } + + return &v1.WriteSchemaResponse{ + WrittenAt: zedtoken.MustNewFromRevision(revision), + }, nil +} + +func (ss *schemaServer) ReflectSchema(ctx context.Context, req *v1.ReflectSchemaRequest) (*v1.ReflectSchemaResponse, error) { + // Get the current schema. + schema, atRevision, err := loadCurrentSchema(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + filters, err := newSchemaFilters(req.OptionalFilters) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + definitions := make([]*v1.ReflectionDefinition, 0, len(schema.ObjectDefinitions)) + if filters.HasNamespaces() { + for _, ns := range schema.ObjectDefinitions { + def, err := namespaceAPIRepr(ns, filters) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + if def != nil { + definitions = append(definitions, def) + } + } + } + + caveats := make([]*v1.ReflectionCaveat, 0, len(schema.CaveatDefinitions)) + if filters.HasCaveats() { + for _, cd := range schema.CaveatDefinitions { + caveat, err := caveatAPIRepr(cd, filters, ss.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + if caveat != nil { + caveats = append(caveats, caveat) + } + } + } + + return &v1.ReflectSchemaResponse{ + Definitions: definitions, + Caveats: caveats, + ReadAt: zedtoken.MustNewFromRevision(atRevision), + }, nil +} + +func (ss *schemaServer) DiffSchema(ctx context.Context, req *v1.DiffSchemaRequest) (*v1.DiffSchemaResponse, error) { + atRevision, _, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, err + } + + diff, existingSchema, comparisonSchema, err := schemaDiff(ctx, req.ComparisonSchema, ss.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + resp, err := convertDiff(diff, existingSchema, comparisonSchema, atRevision, ss.caveatTypeSet) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + return resp, nil +} + +func (ss *schemaServer) ComputablePermissions(ctx context.Context, req *v1.ComputablePermissionsRequest) (*v1.ComputablePermissionsResponse, error) { + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) + vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relationName := req.RelationName + if relationName == "" { + relationName = tuple.Ellipsis + } else { + if _, ok := vdef.GetRelation(relationName); !ok { + return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, relationName)) + } + } + + allNamespaces, err := ds.ListAllNamespaces(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + allDefinitions := make([]*core.NamespaceDefinition, 0, len(allNamespaces)) + for _, ns := range allNamespaces { + allDefinitions = append(allDefinitions, ns.Definition) + } + + rg := vdef.Reachability() + rr, err := rg.RelationsEncounteredForSubject(ctx, allDefinitions, &core.RelationReference{ + Namespace: req.DefinitionName, + Relation: relationName, + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations := make([]*v1.ReflectionRelationReference, 0, len(rr)) + for _, r := range rr { + if r.Namespace == req.DefinitionName && r.Relation == req.RelationName { + continue + } + + if req.OptionalDefinitionNameFilter != "" && !strings.HasPrefix(r.Namespace, req.OptionalDefinitionNameFilter) { + continue + } + + ts, err := ts.GetDefinition(ctx, r.Namespace) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations = append(relations, &v1.ReflectionRelationReference{ + DefinitionName: r.Namespace, + RelationName: r.Relation, + IsPermission: ts.IsPermission(r.Relation), + }) + } + + sort.Slice(relations, func(i, j int) bool { + if relations[i].DefinitionName == relations[j].DefinitionName { + return relations[i].RelationName < relations[j].RelationName + } + return relations[i].DefinitionName < relations[j].DefinitionName + }) + + return &v1.ComputablePermissionsResponse{ + Permissions: relations, + ReadAt: revisionReadAt, + }, nil +} + +func (ss *schemaServer) DependentRelations(ctx context.Context, req *v1.DependentRelationsRequest) (*v1.DependentRelationsResponse, error) { + atRevision, revisionReadAt, err := consistency.RevisionFromContext(ctx) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision) + ts := schema.NewTypeSystem(schema.ResolverForDatastoreReader(ds)) + vdef, err := ts.GetValidatedDefinition(ctx, req.DefinitionName) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + _, ok := vdef.GetRelation(req.PermissionName) + if !ok { + return nil, shared.RewriteErrorWithoutConfig(ctx, schema.NewRelationNotFoundErr(req.DefinitionName, req.PermissionName)) + } + + if !vdef.IsPermission(req.PermissionName) { + return nil, shared.RewriteErrorWithoutConfig(ctx, NewNotAPermissionError(req.PermissionName)) + } + + rg := vdef.Reachability() + rr, err := rg.RelationsEncounteredForResource(ctx, &core.RelationReference{ + Namespace: req.DefinitionName, + Relation: req.PermissionName, + }) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations := make([]*v1.ReflectionRelationReference, 0, len(rr)) + for _, r := range rr { + if r.Namespace == req.DefinitionName && r.Relation == req.PermissionName { + continue + } + + ts, err := ts.GetDefinition(ctx, r.Namespace) + if err != nil { + return nil, shared.RewriteErrorWithoutConfig(ctx, err) + } + + relations = append(relations, &v1.ReflectionRelationReference{ + DefinitionName: r.Namespace, + RelationName: r.Relation, + IsPermission: ts.IsPermission(r.Relation), + }) + } + + sort.Slice(relations, func(i, j int) bool { + if relations[i].DefinitionName == relations[j].DefinitionName { + return relations[i].RelationName < relations[j].RelationName + } + + return relations[i].DefinitionName < relations[j].DefinitionName + }) + + return &v1.DependentRelationsResponse{ + Relations: relations, + ReadAt: revisionReadAt, + }, nil +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go b/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go new file mode 100644 index 0000000..ef13a26 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/watch.go @@ -0,0 +1,190 @@ +package v1 + +import ( + "context" + "errors" + "slices" + "time" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" + "github.com/authzed/spicedb/internal/middleware/usagemetrics" + "github.com/authzed/spicedb/internal/services/shared" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/genutil/mapz" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/zedtoken" +) + +type watchServer struct { + v1.UnimplementedWatchServiceServer + shared.WithStreamServiceSpecificInterceptor + + heartbeatDuration time.Duration +} + +// NewWatchServer creates an instance of the watch server. +func NewWatchServer(heartbeatDuration time.Duration) v1.WatchServiceServer { + s := &watchServer{ + WithStreamServiceSpecificInterceptor: shared.WithStreamServiceSpecificInterceptor{ + Stream: grpcvalidate.StreamServerInterceptor(), + }, + heartbeatDuration: heartbeatDuration, + } + return s +} + +func (ws *watchServer) Watch(req *v1.WatchRequest, stream v1.WatchService_WatchServer) error { + if len(req.GetOptionalUpdateKinds()) == 0 || + slices.Contains(req.GetOptionalUpdateKinds(), v1.WatchKind_WATCH_KIND_UNSPECIFIED) || + slices.Contains(req.GetOptionalUpdateKinds(), v1.WatchKind_WATCH_KIND_INCLUDE_RELATIONSHIP_UPDATES) { + if len(req.GetOptionalObjectTypes()) > 0 && len(req.OptionalRelationshipFilters) > 0 { + return status.Errorf(codes.InvalidArgument, "cannot specify both object types and relationship filters") + } + } + + objectTypes := mapz.NewSet[string](req.GetOptionalObjectTypes()...) + + ctx := stream.Context() + ds := datastoremw.MustFromContext(ctx) + + var afterRevision datastore.Revision + if req.OptionalStartCursor != nil && req.OptionalStartCursor.Token != "" { + decodedRevision, err := zedtoken.DecodeRevision(req.OptionalStartCursor, ds) + if err != nil { + return status.Errorf(codes.InvalidArgument, "failed to decode start revision: %s", err) + } + + afterRevision = decodedRevision + } else { + var err error + afterRevision, err = ds.OptimizedRevision(ctx) + if err != nil { + return status.Errorf(codes.Unavailable, "failed to start watch: %s", err) + } + } + + reader := ds.SnapshotReader(afterRevision) + + filters, err := buildRelationshipFilters(req, stream, reader, ws, ctx) + if err != nil { + return err + } + + usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{ + DispatchCount: 1, + }) + + updates, errchan := ds.Watch(ctx, afterRevision, datastore.WatchOptions{ + Content: convertWatchKindToContent(req.OptionalUpdateKinds), + CheckpointInterval: ws.heartbeatDuration, + }) + for { + select { + case update, ok := <-updates: + if ok { + filteredRelationshipUpdates := filterRelationshipUpdates(objectTypes, filters, update.RelationshipChanges) + if len(filteredRelationshipUpdates) > 0 { + converted, err := tuple.UpdatesToV1RelationshipUpdates(filteredRelationshipUpdates) + if err != nil { + return status.Errorf(codes.Internal, "failed to convert updates: %s", err) + } + + if err := stream.Send(&v1.WatchResponse{ + Updates: converted, + ChangesThrough: zedtoken.MustNewFromRevision(update.Revision), + OptionalTransactionMetadata: update.Metadata, + }); err != nil { + return status.Errorf(codes.Canceled, "watch canceled by user: %s", err) + } + } + if len(update.ChangedDefinitions) > 0 || len(update.DeletedCaveats) > 0 || len(update.DeletedNamespaces) > 0 { + if err := stream.Send(&v1.WatchResponse{ + SchemaUpdated: true, + ChangesThrough: zedtoken.MustNewFromRevision(update.Revision), + OptionalTransactionMetadata: update.Metadata, + }); err != nil { + return status.Errorf(codes.Canceled, "watch canceled by user: %s", err) + } + } + if update.IsCheckpoint { + if err := stream.Send(&v1.WatchResponse{ + IsCheckpoint: update.IsCheckpoint, + ChangesThrough: zedtoken.MustNewFromRevision(update.Revision), + OptionalTransactionMetadata: update.Metadata, + }); err != nil { + return status.Errorf(codes.Canceled, "watch canceled by user: %s", err) + } + } + } + case err := <-errchan: + switch { + case errors.As(err, &datastore.WatchCanceledError{}): + return status.Errorf(codes.Canceled, "watch canceled by user: %s", err) + case errors.As(err, &datastore.WatchDisconnectedError{}): + return status.Errorf(codes.ResourceExhausted, "watch disconnected: %s", err) + default: + return status.Errorf(codes.Internal, "watch error: %s", err) + } + } + } +} + +func buildRelationshipFilters(req *v1.WatchRequest, stream v1.WatchService_WatchServer, reader datastore.Reader, ws *watchServer, ctx context.Context) ([]datastore.RelationshipsFilter, error) { + filters := make([]datastore.RelationshipsFilter, 0, len(req.OptionalRelationshipFilters)) + for _, filter := range req.OptionalRelationshipFilters { + if err := validateRelationshipsFilter(stream.Context(), filter, reader); err != nil { + return nil, ws.rewriteError(ctx, err) + } + + dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "failed to parse relationship filter: %s", err) + } + + filters = append(filters, dsFilter) + } + return filters, nil +} + +func (ws *watchServer) rewriteError(ctx context.Context, err error) error { + return shared.RewriteError(ctx, err, &shared.ConfigForErrors{}) +} + +func filterRelationshipUpdates(objectTypes *mapz.Set[string], filters []datastore.RelationshipsFilter, updates []tuple.RelationshipUpdate) []tuple.RelationshipUpdate { + if objectTypes.IsEmpty() && len(filters) == 0 { + return updates + } + + filtered := make([]tuple.RelationshipUpdate, 0, len(updates)) + for _, update := range updates { + objectType := update.Relationship.Resource.ObjectType + if !objectTypes.IsEmpty() && !objectTypes.Has(objectType) { + continue + } + + if len(filters) > 0 { + // If there are filters, we need to check if the update matches any of them. + matched := false + for _, filter := range filters { + if filter.Test(update.Relationship) { + matched = true + break + } + } + + if !matched { + continue + } + } + + filtered = append(filtered, update) + } + + return filtered +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go b/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go new file mode 100644 index 0000000..08910a1 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/v1/watchutil.go @@ -0,0 +1,22 @@ +package v1 + +import ( + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/pkg/datastore" +) + +func convertWatchKindToContent(kinds []v1.WatchKind) datastore.WatchContent { + res := datastore.WatchRelationships + for _, kind := range kinds { + switch kind { + case v1.WatchKind_WATCH_KIND_INCLUDE_RELATIONSHIP_UPDATES: + res |= datastore.WatchRelationships + case v1.WatchKind_WATCH_KIND_INCLUDE_SCHEMA_UPDATES: + res |= datastore.WatchSchema + case v1.WatchKind_WATCH_KIND_INCLUDE_CHECKPOINTS: + res |= datastore.WatchCheckpoints + } + } + return res +} |
