diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/spicedb/internal/services/shared | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/spicedb/internal/services/shared')
3 files changed, 734 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go b/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go new file mode 100644 index 0000000..05b3907 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/errors.go @@ -0,0 +1,208 @@ +package shared + +import ( + "context" + "errors" + "fmt" + "strconv" + + "github.com/rs/zerolog" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/spicedb/internal/dispatch" + "github.com/authzed/spicedb/internal/graph" + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/sharederrors" + "github.com/authzed/spicedb/pkg/cursor" + "github.com/authzed/spicedb/pkg/datastore" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// ErrServiceReadOnly is an extended GRPC error returned when a service is in read-only mode. +var ErrServiceReadOnly = mustMakeStatusReadonly() + +func mustMakeStatusReadonly() error { + status, err := status.New(codes.Unavailable, "service read-only").WithDetails(&errdetails.ErrorInfo{ + Reason: v1.ErrorReason_name[int32(v1.ErrorReason_ERROR_REASON_SERVICE_READ_ONLY)], + Domain: spiceerrors.Domain, + }) + if err != nil { + panic("error constructing shared error type") + } + return status.Err() +} + +// NewSchemaWriteDataValidationError creates a new error representing that a schema write cannot be +// completed due to existing data that would be left unreferenced. +func NewSchemaWriteDataValidationError(message string, args ...any) SchemaWriteDataValidationError { + return SchemaWriteDataValidationError{ + error: fmt.Errorf(message, args...), + } +} + +// SchemaWriteDataValidationError occurs when a schema cannot be applied due to leaving data unreferenced. +type SchemaWriteDataValidationError struct { + error +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err SchemaWriteDataValidationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error) +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err SchemaWriteDataValidationError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.InvalidArgument, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_SCHEMA_TYPE_ERROR, + map[string]string{}, + ), + ) +} + +// MaxDepthExceededError is an error returned when the maximum depth for dispatching has been exceeded. +type MaxDepthExceededError struct { + *spiceerrors.WithAdditionalDetailsError + + // AllowedMaximumDepth is the configured allowed maximum depth. + AllowedMaximumDepth uint32 +} + +// GRPCStatus implements retrieving the gRPC status for the error. +func (err MaxDepthExceededError) GRPCStatus() *status.Status { + return spiceerrors.WithCodeAndDetails( + err, + codes.ResourceExhausted, + spiceerrors.ForReason( + v1.ErrorReason_ERROR_REASON_MAXIMUM_DEPTH_EXCEEDED, + err.AddToDetails(map[string]string{ + "maximum_depth_allowed": strconv.Itoa(int(err.AllowedMaximumDepth)), + }), + ), + ) +} + +// NewMaxDepthExceededError creates a new MaxDepthExceededError. +func NewMaxDepthExceededError(allowedMaximumDepth uint32, isCheckRequest bool) error { + if isCheckRequest { + return MaxDepthExceededError{ + spiceerrors.NewWithAdditionalDetailsError(fmt.Errorf("the check request has exceeded the allowable maximum depth of %d: this usually indicates a recursive or too deep data dependency. Try running zed with --explain to see the dependency. See: https://spicedb.dev/d/debug-max-depth-check", allowedMaximumDepth)), + allowedMaximumDepth, + } + } + + return MaxDepthExceededError{ + spiceerrors.NewWithAdditionalDetailsError(fmt.Errorf("the request has exceeded the allowable maximum depth of %d: this usually indicates a recursive or too deep data dependency. See: https://spicedb.dev/d/debug-max-depth", allowedMaximumDepth)), + allowedMaximumDepth, + } +} + +func AsValidationError(err error) *SchemaWriteDataValidationError { + var validationErr SchemaWriteDataValidationError + if errors.As(err, &validationErr) { + return &validationErr + } + return nil +} + +type ConfigForErrors struct { + MaximumAPIDepth uint32 + DebugTrace *v1.DebugInformation +} + +func RewriteErrorWithoutConfig(ctx context.Context, err error) error { + return rewriteError(ctx, err, nil) +} + +func RewriteError(ctx context.Context, err error, config *ConfigForErrors) error { + rerr := rewriteError(ctx, err, config) + if config != nil && config.DebugTrace != nil { + spiceerrors.WithAdditionalDetails(rerr, spiceerrors.DebugTraceErrorDetailsKey, config.DebugTrace.String()) + } + return rerr +} + +func rewriteError(ctx context.Context, err error, config *ConfigForErrors) error { + // Check if the error can be directly used. + if _, ok := status.FromError(err); ok { + return err + } + + // Otherwise, convert any graph/datastore errors. + var nsNotFoundError sharederrors.UnknownNamespaceError + var relationNotFoundError sharederrors.UnknownRelationError + + var compilerError compiler.BaseCompilerError + var sourceError spiceerrors.WithSourceError + var typeError schema.TypeError + var maxDepthError dispatch.MaxDepthExceededError + + switch { + case errors.As(err, &typeError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_SCHEMA_TYPE_ERROR) + case errors.As(err, &compilerError): + return spiceerrors.WithCodeAndReason(err, codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_SCHEMA_PARSE_ERROR) + case errors.As(err, &sourceError): + return spiceerrors.WithCodeAndReason(err, codes.InvalidArgument, v1.ErrorReason_ERROR_REASON_SCHEMA_PARSE_ERROR) + + case errors.Is(err, cursor.ErrHashMismatch): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_INVALID_CURSOR) + + case errors.As(err, &nsNotFoundError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_DEFINITION) + case errors.As(err, &relationNotFoundError): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_RELATION_OR_PERMISSION) + + case errors.As(err, &maxDepthError): + if config == nil { + return spiceerrors.MustBugf("missing config for API error") + } + + _, isCheckRequest := maxDepthError.Request.(*dispatchv1.DispatchCheckRequest) + return NewMaxDepthExceededError(config.MaximumAPIDepth, isCheckRequest) + + case errors.As(err, &datastore.ReadOnlyError{}): + return ErrServiceReadOnly + case errors.As(err, &datastore.InvalidRevisionError{}): + return status.Errorf(codes.OutOfRange, "invalid zedtoken: %s", err) + case errors.As(err, &datastore.CaveatNameNotFoundError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_UNKNOWN_CAVEAT) + case errors.As(err, &datastore.WatchDisabledError{}): + return status.Errorf(codes.FailedPrecondition, "%s", err) + case errors.As(err, &datastore.CounterAlreadyRegisteredError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_COUNTER_ALREADY_REGISTERED) + case errors.As(err, &datastore.CounterNotRegisteredError{}): + return spiceerrors.WithCodeAndReason(err, codes.FailedPrecondition, v1.ErrorReason_ERROR_REASON_COUNTER_NOT_REGISTERED) + + case errors.As(err, &graph.RelationMissingTypeInfoError{}): + return status.Errorf(codes.FailedPrecondition, "failed precondition: %s", err) + case errors.As(err, &graph.AlwaysFailError{}): + log.Ctx(ctx).Err(err).Msg("received internal error") + return status.Errorf(codes.Internal, "internal error: %s", err) + case errors.As(err, &graph.UnimplementedError{}): + return status.Errorf(codes.Unimplemented, "%s", err) + case errors.Is(err, context.DeadlineExceeded): + return status.Errorf(codes.DeadlineExceeded, "%s", err) + case errors.Is(err, context.Canceled): + err := context.Cause(ctx) + if err != nil { + if _, ok := status.FromError(err); ok { + return err + } + } + + return status.Errorf(codes.Canceled, "%s", err) + default: + log.Ctx(ctx).Err(err).Msg("received unexpected error") + return err + } +} diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go b/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go new file mode 100644 index 0000000..455de0a --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/interceptor.go @@ -0,0 +1,52 @@ +package shared + +import ( + "google.golang.org/grpc" + + "github.com/authzed/spicedb/internal/middleware/servicespecific" +) + +// WithUnaryServiceSpecificInterceptor is a helper to add a unary interceptor or interceptor +// chain to a service. +type WithUnaryServiceSpecificInterceptor struct { + Unary grpc.UnaryServerInterceptor +} + +// UnaryInterceptor implements servicespecific.ExtraUnaryInterceptor +func (wussi WithUnaryServiceSpecificInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor { + return wussi.Unary +} + +// WithStreamServiceSpecificInterceptor is a helper to add a stream interceptor or interceptor +// chain to a service. +type WithStreamServiceSpecificInterceptor struct { + Stream grpc.StreamServerInterceptor +} + +// StreamInterceptor implements servicespecific.ExtraStreamInterceptor +func (wsssi WithStreamServiceSpecificInterceptor) StreamInterceptor() grpc.StreamServerInterceptor { + return wsssi.Stream +} + +// WithServiceSpecificInterceptors is a helper to add both a unary and stream interceptor +// or interceptor chain to a service. +type WithServiceSpecificInterceptors struct { + Unary grpc.UnaryServerInterceptor + Stream grpc.StreamServerInterceptor +} + +// UnaryInterceptor implements servicespecific.ExtraUnaryInterceptor +func (wssi WithServiceSpecificInterceptors) UnaryInterceptor() grpc.UnaryServerInterceptor { + return wssi.Unary +} + +// StreamInterceptor implements servicespecific.ExtraStreamInterceptor +func (wssi WithServiceSpecificInterceptors) StreamInterceptor() grpc.StreamServerInterceptor { + return wssi.Stream +} + +var ( + _ servicespecific.ExtraUnaryInterceptor = WithUnaryServiceSpecificInterceptor{} + _ servicespecific.ExtraUnaryInterceptor = WithServiceSpecificInterceptors{} + _ servicespecific.ExtraStreamInterceptor = WithServiceSpecificInterceptors{} +) diff --git a/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go b/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go new file mode 100644 index 0000000..83accde --- /dev/null +++ b/vendor/github.com/authzed/spicedb/internal/services/shared/schema.go @@ -0,0 +1,474 @@ +package shared + +import ( + "context" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/internal/namespace" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/datastore/queryshape" + caveatdiff "github.com/authzed/spicedb/pkg/diff/caveats" + nsdiff "github.com/authzed/spicedb/pkg/diff/namespace" + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// ValidatedSchemaChanges is a set of validated schema changes that can be applied to the datastore. +type ValidatedSchemaChanges struct { + compiled *compiler.CompiledSchema + validatedTypeSystems map[string]*schema.ValidatedDefinition + newCaveatDefNames *mapz.Set[string] + newObjectDefNames *mapz.Set[string] + additiveOnly bool +} + +// ValidateSchemaChanges validates the schema found in the compiled schema and returns a +// ValidatedSchemaChanges, if fully validated. +func ValidateSchemaChanges(ctx context.Context, compiled *compiler.CompiledSchema, caveatTypeSet *caveattypes.TypeSet, additiveOnly bool) (*ValidatedSchemaChanges, error) { + // 1) Validate the caveats defined. + newCaveatDefNames := mapz.NewSet[string]() + for _, caveatDef := range compiled.CaveatDefinitions { + if err := namespace.ValidateCaveatDefinition(caveatTypeSet, caveatDef); err != nil { + return nil, err + } + + newCaveatDefNames.Insert(caveatDef.Name) + } + + // 2) Validate the namespaces defined. + newObjectDefNames := mapz.NewSet[string]() + validatedTypeSystems := make(map[string]*schema.ValidatedDefinition, len(compiled.ObjectDefinitions)) + res := schema.ResolverForPredefinedDefinitions(schema.PredefinedElements{ + Definitions: compiled.ObjectDefinitions, + Caveats: compiled.CaveatDefinitions, + }) + ts := schema.NewTypeSystem(res) + + for _, nsdef := range compiled.ObjectDefinitions { + vts, err := ts.GetValidatedDefinition(ctx, nsdef.GetName()) + if err != nil { + return nil, err + } + + validatedTypeSystems[nsdef.Name] = vts + newObjectDefNames.Insert(nsdef.Name) + } + + return &ValidatedSchemaChanges{ + compiled: compiled, + validatedTypeSystems: validatedTypeSystems, + newCaveatDefNames: newCaveatDefNames, + newObjectDefNames: newObjectDefNames, + additiveOnly: additiveOnly, + }, nil +} + +// AppliedSchemaChanges holds information about the applied schema changes. +type AppliedSchemaChanges struct { + // TotalOperationCount holds the total number of "dispatch" operations performed by the schema + // being applied. + TotalOperationCount int + + // NewObjectDefNames contains the names of the newly added object definitions. + NewObjectDefNames []string + + // RemovedObjectDefNames contains the names of the removed object definitions. + RemovedObjectDefNames []string + + // NewCaveatDefNames contains the names of the newly added caveat definitions. + NewCaveatDefNames []string + + // RemovedCaveatDefNames contains the names of the removed caveat definitions. + RemovedCaveatDefNames []string +} + +// ApplySchemaChanges applies schema changes found in the validated changes struct, via the specified +// ReadWriteTransaction. +func ApplySchemaChanges(ctx context.Context, rwt datastore.ReadWriteTransaction, caveatTypeSet *caveattypes.TypeSet, validated *ValidatedSchemaChanges) (*AppliedSchemaChanges, error) { + existingCaveats, err := rwt.ListAllCaveats(ctx) + if err != nil { + return nil, err + } + + existingObjectDefs, err := rwt.ListAllNamespaces(ctx) + if err != nil { + return nil, err + } + + return ApplySchemaChangesOverExisting(ctx, rwt, caveatTypeSet, validated, datastore.DefinitionsOf(existingCaveats), datastore.DefinitionsOf(existingObjectDefs)) +} + +// ApplySchemaChangesOverExisting applies schema changes found in the validated changes struct, against +// existing caveat and object definitions given. +func ApplySchemaChangesOverExisting( + ctx context.Context, + rwt datastore.ReadWriteTransaction, + caveatTypeSet *caveattypes.TypeSet, + validated *ValidatedSchemaChanges, + existingCaveats []*core.CaveatDefinition, + existingObjectDefs []*core.NamespaceDefinition, +) (*AppliedSchemaChanges, error) { + // Build a map of existing caveats to determine those being removed, if any. + existingCaveatDefMap := make(map[string]*core.CaveatDefinition, len(existingCaveats)) + existingCaveatDefNames := mapz.NewSet[string]() + + for _, existingCaveat := range existingCaveats { + existingCaveatDefMap[existingCaveat.Name] = existingCaveat + existingCaveatDefNames.Insert(existingCaveat.Name) + } + + // For each caveat definition, perform a diff and ensure the changes will not result in type errors. + caveatDefsWithChanges := make([]*core.CaveatDefinition, 0, len(validated.compiled.CaveatDefinitions)) + for _, caveatDef := range validated.compiled.CaveatDefinitions { + diff, err := sanityCheckCaveatChanges(ctx, rwt, caveatTypeSet, caveatDef, existingCaveatDefMap) + if err != nil { + return nil, err + } + + if len(diff.Deltas()) > 0 { + caveatDefsWithChanges = append(caveatDefsWithChanges, caveatDef) + } + } + + removedCaveatDefNames := existingCaveatDefNames.Subtract(validated.newCaveatDefNames) + + // Build a map of existing definitions to determine those being removed, if any. + existingObjectDefMap := make(map[string]*core.NamespaceDefinition, len(existingObjectDefs)) + existingObjectDefNames := mapz.NewSet[string]() + for _, existingDef := range existingObjectDefs { + existingObjectDefMap[existingDef.Name] = existingDef + existingObjectDefNames.Insert(existingDef.Name) + } + + // For each definition, perform a diff and ensure the changes will not result in any + // breaking changes. + objectDefsWithChanges := make([]*core.NamespaceDefinition, 0, len(validated.compiled.ObjectDefinitions)) + for _, nsdef := range validated.compiled.ObjectDefinitions { + diff, err := sanityCheckNamespaceChanges(ctx, rwt, nsdef, existingObjectDefMap) + if err != nil { + return nil, err + } + + if len(diff.Deltas()) > 0 { + objectDefsWithChanges = append(objectDefsWithChanges, nsdef) + + vts, ok := validated.validatedTypeSystems[nsdef.Name] + if !ok { + return nil, spiceerrors.MustBugf("validated type system not found for namespace `%s`", nsdef.Name) + } + + if err := namespace.AnnotateNamespace(vts); err != nil { + return nil, err + } + } + } + + log.Ctx(ctx). + Trace(). + Int("objectDefinitions", len(validated.compiled.ObjectDefinitions)). + Int("caveatDefinitions", len(validated.compiled.CaveatDefinitions)). + Int("objectDefsWithChanges", len(objectDefsWithChanges)). + Int("caveatDefsWithChanges", len(caveatDefsWithChanges)). + Msg("validated namespace definitions") + + // Ensure that deleting namespaces will not result in any relationships left without associated + // schema. + removedObjectDefNames := existingObjectDefNames.Subtract(validated.newObjectDefNames) + if !validated.additiveOnly { + if err := removedObjectDefNames.ForEach(func(nsdefName string) error { + return ensureNoRelationshipsExist(ctx, rwt, nsdefName) + }); err != nil { + return nil, err + } + } + + // Write the new/changes caveats. + if len(caveatDefsWithChanges) > 0 { + if err := rwt.WriteCaveats(ctx, caveatDefsWithChanges); err != nil { + return nil, err + } + } + + // Write the new/changed namespaces. + if len(objectDefsWithChanges) > 0 { + if err := rwt.WriteNamespaces(ctx, objectDefsWithChanges...); err != nil { + return nil, err + } + } + + if !validated.additiveOnly { + // Delete the removed namespaces. + if removedObjectDefNames.Len() > 0 { + if err := rwt.DeleteNamespaces(ctx, removedObjectDefNames.AsSlice()...); err != nil { + return nil, err + } + } + + // Delete the removed caveats. + if !removedCaveatDefNames.IsEmpty() { + if err := rwt.DeleteCaveats(ctx, removedCaveatDefNames.AsSlice()); err != nil { + return nil, err + } + } + } + + log.Ctx(ctx).Trace(). + Interface("objectDefinitions", validated.compiled.ObjectDefinitions). + Interface("caveatDefinitions", validated.compiled.CaveatDefinitions). + Object("addedOrChangedObjectDefinitions", validated.newObjectDefNames). + Object("removedObjectDefinitions", removedObjectDefNames). + Object("addedOrChangedCaveatDefinitions", validated.newCaveatDefNames). + Object("removedCaveatDefinitions", removedCaveatDefNames). + Msg("completed schema update") + + return &AppliedSchemaChanges{ + TotalOperationCount: len(validated.compiled.ObjectDefinitions) + len(validated.compiled.CaveatDefinitions) + removedObjectDefNames.Len() + removedCaveatDefNames.Len(), + NewObjectDefNames: validated.newObjectDefNames.Subtract(existingObjectDefNames).AsSlice(), + RemovedObjectDefNames: removedObjectDefNames.AsSlice(), + NewCaveatDefNames: validated.newCaveatDefNames.Subtract(existingCaveatDefNames).AsSlice(), + RemovedCaveatDefNames: removedCaveatDefNames.AsSlice(), + }, nil +} + +// sanityCheckCaveatChanges ensures that a caveat definition being written does not break +// the types of the parameters that may already exist on relationships. +func sanityCheckCaveatChanges( + _ context.Context, + _ datastore.ReadWriteTransaction, + caveatTypeSet *caveattypes.TypeSet, + caveatDef *core.CaveatDefinition, + existingDefs map[string]*core.CaveatDefinition, +) (*caveatdiff.Diff, error) { + // Ensure that the updated namespace does not break the existing tuple data. + existing := existingDefs[caveatDef.Name] + diff, err := caveatdiff.DiffCaveats(existing, caveatDef, caveatTypeSet) + if err != nil { + return nil, err + } + + for _, delta := range diff.Deltas() { + switch delta.Type { + case caveatdiff.RemovedParameter: + return diff, NewSchemaWriteDataValidationError("cannot remove parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name) + + case caveatdiff.ParameterTypeChanged: + return diff, NewSchemaWriteDataValidationError("cannot change the type of parameter `%s` on caveat `%s`", delta.ParameterName, caveatDef.Name) + } + } + + return diff, nil +} + +// ensureNoRelationshipsExist ensures that no relationships exist within the namespace with the given name. +func ensureNoRelationshipsExist(ctx context.Context, rwt datastore.ReadWriteTransaction, namespaceName string) error { + qy, qyErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{OptionalResourceType: namespaceName}, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceOfType), + ) + if err := errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete object definition `%s`, as a relationship exists under it", + namespaceName, + ); err != nil { + return err + } + + qy, qyErr = rwt.ReverseQueryRelationships( + ctx, + datastore.SubjectsFilter{ + SubjectType: namespaceName, + }, + options.WithLimitForReverse(options.LimitOne), + options.WithQueryShapeForReverse(queryshape.FindSubjectOfType), + ) + err := errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete object definition `%s`, as a relationship references it", + namespaceName, + ) + if err != nil { + return err + } + + return nil +} + +// sanityCheckNamespaceChanges ensures that a namespace definition being written does not result +// in breaking changes, such as relationships without associated defined schema object definitions +// and relations. +func sanityCheckNamespaceChanges( + ctx context.Context, + rwt datastore.ReadWriteTransaction, + nsdef *core.NamespaceDefinition, + existingDefs map[string]*core.NamespaceDefinition, +) (*nsdiff.Diff, error) { + // Ensure that the updated namespace does not break the existing tuple data. + existing := existingDefs[nsdef.Name] + diff, err := nsdiff.DiffNamespaces(existing, nsdef) + if err != nil { + return nil, err + } + + for _, delta := range diff.Deltas() { + switch delta.Type { + case nsdiff.RemovedRelation: + // NOTE: We add the subject filters here to ensure the reverse relationship index is used + // by the datastores. As there is no index that has {namespace, relation} directly, but there + // *is* an index that has {subject_namespace, subject_relation, namespace, relation}, we can + // force the datastore to use the reverse index by adding the subject filters. + var previousRelation *core.Relation + for _, relation := range existing.Relation { + if relation.Name == delta.RelationName { + previousRelation = relation + break + } + } + + if previousRelation == nil { + return nil, spiceerrors.MustBugf("relation `%s` not found in existing namespace definition", delta.RelationName) + } + + subjectSelectors := make([]datastore.SubjectsSelector, 0, len(previousRelation.TypeInformation.AllowedDirectRelations)) + for _, allowedType := range previousRelation.TypeInformation.AllowedDirectRelations { + if allowedType.GetRelation() == datastore.Ellipsis { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: allowedType.Namespace, + RelationFilter: datastore.SubjectRelationFilter{ + IncludeEllipsisRelation: true, + }, + }) + } else { + subjectSelectors = append(subjectSelectors, datastore.SubjectsSelector{ + OptionalSubjectType: allowedType.Namespace, + RelationFilter: datastore.SubjectRelationFilter{ + NonEllipsisRelation: allowedType.GetRelation(), + }, + }) + } + } + + qy, qyErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{ + OptionalResourceType: nsdef.Name, + OptionalResourceRelation: delta.RelationName, + OptionalSubjectsSelectors: subjectSelectors, + }, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceOfTypeAndRelation), + ) + + err = errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete relation `%s` in object definition `%s`, as a relationship exists under it", delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + + // Also check for right sides of tuples. + qy, qyErr = rwt.ReverseQueryRelationships( + ctx, + datastore.SubjectsFilter{ + SubjectType: nsdef.Name, + RelationFilter: datastore.SubjectRelationFilter{ + NonEllipsisRelation: delta.RelationName, + }, + }, + options.WithLimitForReverse(options.LimitOne), + options.WithQueryShapeForReverse(queryshape.FindSubjectOfTypeAndRelation), + ) + err = errorIfTupleIteratorReturnsTuples( + ctx, + qy, + qyErr, + "cannot delete relation `%s` in object definition `%s`, as a relationship references it", delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + + case nsdiff.RelationAllowedTypeRemoved: + var optionalSubjectIds []string + var relationFilter datastore.SubjectRelationFilter + var optionalCaveatNameFilter datastore.CaveatNameFilter + + if delta.AllowedType.GetPublicWildcard() != nil { + optionalSubjectIds = []string{tuple.PublicWildcard} + } else { + relationFilter = datastore.SubjectRelationFilter{ + NonEllipsisRelation: delta.AllowedType.GetRelation(), + } + } + + if delta.AllowedType.GetRequiredCaveat() != nil && delta.AllowedType.GetRequiredCaveat().CaveatName != "" { + optionalCaveatNameFilter = datastore.WithCaveatName(delta.AllowedType.GetRequiredCaveat().CaveatName) + } else { + optionalCaveatNameFilter = datastore.WithNoCaveat() + } + + expirationOption := datastore.ExpirationFilterOptionNoExpiration + if delta.AllowedType.RequiredExpiration != nil { + expirationOption = datastore.ExpirationFilterOptionHasExpiration + } + + qyr, qyrErr := rwt.QueryRelationships( + ctx, + datastore.RelationshipsFilter{ + OptionalResourceType: nsdef.Name, + OptionalResourceRelation: delta.RelationName, + OptionalSubjectsSelectors: []datastore.SubjectsSelector{ + { + OptionalSubjectType: delta.AllowedType.Namespace, + OptionalSubjectIds: optionalSubjectIds, + RelationFilter: relationFilter, + }, + }, + OptionalCaveatNameFilter: optionalCaveatNameFilter, + OptionalExpirationOption: expirationOption, + }, + options.WithLimit(options.LimitOne), + options.WithQueryShape(queryshape.FindResourceRelationForSubjectRelation), + ) + err = errorIfTupleIteratorReturnsTuples( + ctx, + qyr, + qyrErr, + "cannot remove allowed type `%s` from relation `%s` in object definition `%s`, as a relationship exists with it", + schema.SourceForAllowedRelation(delta.AllowedType), delta.RelationName, nsdef.Name) + if err != nil { + return diff, err + } + } + } + return diff, nil +} + +// errorIfTupleIteratorReturnsTuples takes a tuple iterator and any error that was generated +// when the original iterator was created, and returns an error if iterator contains any tuples. +func errorIfTupleIteratorReturnsTuples(_ context.Context, qy datastore.RelationshipIterator, qyErr error, message string, args ...interface{}) error { + if qyErr != nil { + return qyErr + } + + for _, err := range qy { + if err != nil { + return err + } + return NewSchemaWriteDataValidationError(message, args...) + } + + return nil +} |
