diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/spicedb/pkg/schema | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/schema')
11 files changed, 2505 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/arrows.go b/vendor/github.com/authzed/spicedb/pkg/schema/arrows.go new file mode 100644 index 0000000..021f282 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/arrows.go @@ -0,0 +1,157 @@ +package schema + +import ( + "context" + "errors" + "fmt" + "strconv" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// ArrowInformation holds information about an arrow (TupleToUserset) in the schema. +type ArrowInformation struct { + Arrow *core.TupleToUserset + Path string + ParentRelationName string +} + +// ArrowSet represents a set of all the arrows (TupleToUserset's) found in the schema. +type ArrowSet struct { + res FullSchemaResolver + ts *TypeSystem + arrowsByFullTuplesetRelation *mapz.MultiMap[string, ArrowInformation] + arrowsByComputedUsersetNamespaceAndRelation *mapz.MultiMap[string, ArrowInformation] + reachableComputedUsersetRelationsByTuplesetRelation *mapz.MultiMap[string, string] +} + +// buildArrowSet builds a new set of all arrows found in the given schema. +func buildArrowSet(ctx context.Context, res FullSchemaResolver) (*ArrowSet, error) { + arrowSet := &ArrowSet{ + res: res, + ts: NewTypeSystem(res), + arrowsByFullTuplesetRelation: mapz.NewMultiMap[string, ArrowInformation](), + arrowsByComputedUsersetNamespaceAndRelation: mapz.NewMultiMap[string, ArrowInformation](), + reachableComputedUsersetRelationsByTuplesetRelation: mapz.NewMultiMap[string, string](), + } + if err := arrowSet.compute(ctx); err != nil { + return nil, err + } + return arrowSet, nil +} + +// AllReachableRelations returns all relations reachable through arrows, including tupleset relations +// and computed userset relations. +func (as *ArrowSet) AllReachableRelations() *mapz.Set[string] { + c := mapz.NewSet(as.reachableComputedUsersetRelationsByTuplesetRelation.Values()...) + c.Extend(as.arrowsByFullTuplesetRelation.Keys()) + return c +} + +// HasPossibleArrowWithComputedUserset returns true if there is a *possible* arrow with the given relation name +// as the arrow's computed userset/for a subject type that has the given namespace. +func (as *ArrowSet) HasPossibleArrowWithComputedUserset(namespaceName string, relationName string) bool { + return as.arrowsByComputedUsersetNamespaceAndRelation.Has(namespaceName + "#" + relationName) +} + +// LookupTuplesetArrows finds all arrows with the given namespace and relation name as the arrows' tupleset. +func (as *ArrowSet) LookupTuplesetArrows(namespaceName string, relationName string) []ArrowInformation { + key := namespaceName + "#" + relationName + found, _ := as.arrowsByFullTuplesetRelation.Get(key) + return found +} + +func (as *ArrowSet) compute(ctx context.Context) error { + for _, name := range as.res.AllDefinitionNames() { + def, err := as.ts.GetValidatedDefinition(ctx, name) + if err != nil { + return err + } + for _, relation := range def.nsDef.Relation { + if err := as.collectArrowInformationForRelation(ctx, def, relation.Name); err != nil { + return err + } + } + } + return nil +} + +func (as *ArrowSet) add(ttu *core.TupleToUserset, path string, namespaceName string, relationName string) { + tsKey := namespaceName + "#" + ttu.Tupleset.Relation + as.arrowsByFullTuplesetRelation.Add(tsKey, ArrowInformation{Path: path, Arrow: ttu, ParentRelationName: relationName}) +} + +func (as *ArrowSet) collectArrowInformationForRelation(ctx context.Context, def *ValidatedDefinition, relationName string) error { + if !def.IsPermission(relationName) { + return nil + } + + relation, ok := def.GetRelation(relationName) + if !ok { + return asTypeError(NewRelationNotFoundErr(def.Namespace().GetName(), relationName)) + } + return as.collectArrowInformationForRewrite(ctx, relation.UsersetRewrite, def, relation, relationName) +} + +func (as *ArrowSet) collectArrowInformationForRewrite(ctx context.Context, rewrite *core.UsersetRewrite, def *ValidatedDefinition, relation *core.Relation, path string) error { + switch rw := rewrite.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + return as.collectArrowInformationForSetOperation(ctx, rw.Union, def, relation, path) + case *core.UsersetRewrite_Intersection: + return as.collectArrowInformationForSetOperation(ctx, rw.Intersection, def, relation, path) + case *core.UsersetRewrite_Exclusion: + return as.collectArrowInformationForSetOperation(ctx, rw.Exclusion, def, relation, path) + default: + return errors.New("userset rewrite operation not implemented in addArrowRelationsForRewrite") + } +} + +func (as *ArrowSet) collectArrowInformationForSetOperation(ctx context.Context, so *core.SetOperation, def *ValidatedDefinition, relation *core.Relation, path string) error { + for index, childOneof := range so.Child { + updatedPath := path + "." + strconv.Itoa(index) + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_ComputedUserset: + // Nothing to do + + case *core.SetOperation_Child_UsersetRewrite: + err := as.collectArrowInformationForRewrite(ctx, child.UsersetRewrite, def, relation, updatedPath) + if err != nil { + return err + } + + case *core.SetOperation_Child_TupleToUserset: + as.add(child.TupleToUserset, updatedPath, def.Namespace().Name, relation.Name) + + allowedSubjectTypes, err := def.AllowedSubjectRelations(child.TupleToUserset.Tupleset.Relation) + if err != nil { + return err + } + + for _, ast := range allowedSubjectTypes { + def, err := as.ts.GetValidatedDefinition(ctx, ast.Namespace) + if err != nil { + return err + } + + // NOTE: this is explicitly added to the arrowsByComputedUsersetNamespaceAndRelation without + // checking if the relation/permission exists, because its needed for schema diff tracking. + as.arrowsByComputedUsersetNamespaceAndRelation.Add(ast.Namespace+"#"+child.TupleToUserset.ComputedUserset.Relation, ArrowInformation{Path: path, Arrow: child.TupleToUserset, ParentRelationName: relation.Name}) + if def.HasRelation(child.TupleToUserset.ComputedUserset.Relation) { + as.reachableComputedUsersetRelationsByTuplesetRelation.Add(ast.Namespace+"#"+child.TupleToUserset.Tupleset.Relation, ast.Namespace+"#"+child.TupleToUserset.ComputedUserset.Relation) + } + } + + case *core.SetOperation_Child_XThis: + // Nothing to do + + case *core.SetOperation_Child_XNil: + // Nothing to do + + default: + return fmt.Errorf("unknown set operation child `%T` in addArrowRelationsInSetOperation", child) + } + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/compiled_schema_resolver.go b/vendor/github.com/authzed/spicedb/pkg/schema/compiled_schema_resolver.go new file mode 100644 index 0000000..1bf40cc --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/compiled_schema_resolver.go @@ -0,0 +1,60 @@ +package schema + +import ( + "context" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" +) + +// FullSchemaResolver is a superset of a resolver that knows how to retrieve all definitions +// from its source by name (by having a complete list of names). +type FullSchemaResolver interface { + Resolver + AllDefinitionNames() []string +} + +// CompiledSchemaResolver is a resolver for a fully compiled schema. It implements FullSchemaResolver, +// as it has the full context of the schema. +type CompiledSchemaResolver struct { + schema compiler.CompiledSchema +} + +// ResolverForCompiledSchema builds a resolver from a compiled schema. +func ResolverForCompiledSchema(schema compiler.CompiledSchema) *CompiledSchemaResolver { + return &CompiledSchemaResolver{ + schema: schema, + } +} + +var _ FullSchemaResolver = &CompiledSchemaResolver{} + +// LookupDefinition lookups up a namespace, also returning whether it was pre-validated. +func (c CompiledSchemaResolver) LookupDefinition(ctx context.Context, name string) (*core.NamespaceDefinition, bool, error) { + for _, o := range c.schema.ObjectDefinitions { + if o.GetName() == name { + return o, false, nil + } + } + return nil, false, asTypeError(NewDefinitionNotFoundErr(name)) +} + +// LookupCaveat lookups up a caveat. +func (c CompiledSchemaResolver) LookupCaveat(ctx context.Context, name string) (*Caveat, error) { + for _, v := range c.schema.CaveatDefinitions { + if v.GetName() == name { + return v, nil + } + } + return nil, asTypeError(NewCaveatNotFoundErr(name)) +} + +// AllDefinitionNames returns a list of all the names of defined namespaces for this resolved schema. +// Every definition is a valid parameter for LookupDefinition +func (c CompiledSchemaResolver) AllDefinitionNames() []string { + out := make([]string, len(c.schema.ObjectDefinitions)) + for i, o := range c.schema.ObjectDefinitions { + out[i] = o.GetName() + } + return out +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/definition.go b/vendor/github.com/authzed/spicedb/pkg/schema/definition.go new file mode 100644 index 0000000..83ad73d --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/definition.go @@ -0,0 +1,386 @@ +package schema + +import ( + "fmt" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + nspkg "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// AllowedDirectRelation indicates whether a relation is allowed on the right side of another relation. +type AllowedDirectRelation int + +const ( + // UnknownIfRelationAllowed indicates that no type information is defined for + // this relation. + UnknownIfRelationAllowed AllowedDirectRelation = iota + + // DirectRelationValid indicates that the specified subject relation is valid as + // part of a *direct* tuple on the relation. + DirectRelationValid + + // DirectRelationNotValid indicates that the specified subject relation is not + // valid as part of a *direct* tuple on the relation. + DirectRelationNotValid +) + +// AllowedPublicSubject indicates whether a public subject of a particular kind is allowed on the right side of another relation. +type AllowedPublicSubject int + +const ( + // UnknownIfPublicAllowed indicates that no type information is defined for + // this relation. + UnknownIfPublicAllowed AllowedPublicSubject = iota + + // PublicSubjectAllowed indicates that the specified subject wildcard is valid as + // part of a *direct* tuple on the relation. + PublicSubjectAllowed + + // PublicSubjectNotAllowed indicates that the specified subject wildcard is not + // valid as part of a *direct* tuple on the relation. + PublicSubjectNotAllowed +) + +// AllowedRelationOption indicates whether an allowed relation of a particular kind is allowed on the right side of another relation. +type AllowedRelationOption int + +const ( + // UnknownIfAllowed indicates that no type information is defined for + // this relation. + UnknownIfAllowed AllowedRelationOption = iota + + // AllowedRelationValid indicates that the specified subject relation is valid. + AllowedRelationValid + + // AllowedRelationNotValid indicates that the specified subject relation is not valid. + AllowedRelationNotValid +) + +// AllowedDefinitionOption indicates whether an allowed definition of a particular kind is allowed on the right side of another relation. +type AllowedDefinitionOption int + +const ( + // UnknownIfAllowedDefinition indicates that no type information is defined for + // this relation. + UnknownIfAllowedDefinition AllowedDefinitionOption = iota + + // AllowedDefinitionValid indicates that the specified subject definition is valid. + AllowedDefinitionValid + + // AllowedDefinitionNotValid indicates that the specified subject definition is not valid. + AllowedDefinitionNotValid +) + +// NewDefinition returns a new type definition for the given definition proto. +func NewDefinition(ts *TypeSystem, nsDef *core.NamespaceDefinition) (*Definition, error) { + relationMap := make(map[string]*core.Relation, len(nsDef.GetRelation())) + for _, relation := range nsDef.GetRelation() { + _, existing := relationMap[relation.Name] + if existing { + return nil, NewTypeWithSourceError( + NewDuplicateRelationError(nsDef.Name, relation.Name), + relation, + relation.Name, + ) + } + + relationMap[relation.Name] = relation + } + + return &Definition{ + ts: ts, + nsDef: nsDef, + relationMap: relationMap, + }, nil +} + +// Definition represents typing information found in a definition. +// It also provides better ergonomic accessors to the defintion's type information. +type Definition struct { + ts *TypeSystem + nsDef *core.NamespaceDefinition + relationMap map[string]*core.Relation +} + +// Namespace is the NamespaceDefinition for which the type system was constructed. +func (def *Definition) Namespace() *core.NamespaceDefinition { + return def.nsDef +} + +// TypeSystem returns the typesystem for this definition +func (def *Definition) TypeSystem() *TypeSystem { + return def.ts +} + +// HasTypeInformation returns true if the relation with the given name exists and has type +// information defined. +func (def *Definition) HasTypeInformation(relationName string) bool { + rel, ok := def.relationMap[relationName] + return ok && rel.GetTypeInformation() != nil +} + +// HasRelation returns true if the definition has the given relation defined. +func (def *Definition) HasRelation(relationName string) bool { + _, ok := def.relationMap[relationName] + return ok +} + +// GetRelation returns the relation that's defined with the give name in the type system or returns false. +func (def *Definition) GetRelation(relationName string) (*core.Relation, bool) { + rel, ok := def.relationMap[relationName] + return rel, ok +} + +// IsPermission returns true if the definition has the given relation defined and it is +// a permission. +func (def *Definition) IsPermission(relationName string) bool { + found, ok := def.relationMap[relationName] + if !ok { + return false + } + + return nspkg.GetRelationKind(found) == iv1.RelationMetadata_PERMISSION +} + +// GetAllowedDirectNamespaceSubjectRelations returns the subject relations for the target definition, if it is defined as appearing +// somewhere on the right side of a relation (except wildcards). Returns nil if there is no type information or it is not allowed. +func (def *Definition) GetAllowedDirectNamespaceSubjectRelations(sourceRelationName string, targetNamespaceName string) (*mapz.Set[string], error) { + found, ok := def.relationMap[sourceRelationName] + if !ok { + return nil, asTypeError(NewRelationNotFoundErr(def.nsDef.Name, sourceRelationName)) + } + + typeInfo := found.GetTypeInformation() + if typeInfo == nil { + return nil, nil + } + + allowedRelations := typeInfo.GetAllowedDirectRelations() + allowedSubjectRelations := mapz.NewSet[string]() + for _, allowedRelation := range allowedRelations { + if allowedRelation.GetNamespace() == targetNamespaceName && allowedRelation.GetPublicWildcard() == nil { + allowedSubjectRelations.Add(allowedRelation.GetRelation()) + } + } + + return allowedSubjectRelations, nil +} + +// IsAllowedDirectNamespace returns whether the target definition is defined as appearing somewhere on the +// right side of a relation (except public). +func (def *Definition) IsAllowedDirectNamespace(sourceRelationName string, targetNamespaceName string) (AllowedDefinitionOption, error) { + found, ok := def.relationMap[sourceRelationName] + if !ok { + return UnknownIfAllowedDefinition, asTypeError(NewRelationNotFoundErr(def.nsDef.Name, sourceRelationName)) + } + + typeInfo := found.GetTypeInformation() + if typeInfo == nil { + return UnknownIfAllowedDefinition, nil + } + + allowedRelations := typeInfo.GetAllowedDirectRelations() + for _, allowedRelation := range allowedRelations { + if allowedRelation.GetNamespace() == targetNamespaceName && allowedRelation.GetPublicWildcard() == nil { + return AllowedDefinitionValid, nil + } + } + + return AllowedDefinitionNotValid, nil +} + +// IsAllowedPublicNamespace returns whether the target definition is defined as public on the source relation. +func (def *Definition) IsAllowedPublicNamespace(sourceRelationName string, targetNamespaceName string) (AllowedPublicSubject, error) { + found, ok := def.relationMap[sourceRelationName] + if !ok { + return UnknownIfPublicAllowed, asTypeError(NewRelationNotFoundErr(def.nsDef.Name, sourceRelationName)) + } + + typeInfo := found.GetTypeInformation() + if typeInfo == nil { + return UnknownIfPublicAllowed, nil + } + + allowedRelations := typeInfo.GetAllowedDirectRelations() + for _, allowedRelation := range allowedRelations { + if allowedRelation.GetNamespace() == targetNamespaceName && allowedRelation.GetPublicWildcard() != nil { + return PublicSubjectAllowed, nil + } + } + + return PublicSubjectNotAllowed, nil +} + +// IsAllowedDirectRelation returns whether the subject relation is allowed to appear on the right +// hand side of a tuple placed in the source relation with the given name. +func (def *Definition) IsAllowedDirectRelation(sourceRelationName string, targetNamespaceName string, targetRelationName string) (AllowedDirectRelation, error) { + found, ok := def.relationMap[sourceRelationName] + if !ok { + return UnknownIfRelationAllowed, asTypeError(NewRelationNotFoundErr(def.nsDef.Name, sourceRelationName)) + } + + typeInfo := found.GetTypeInformation() + if typeInfo == nil { + return UnknownIfRelationAllowed, nil + } + + allowedRelations := typeInfo.GetAllowedDirectRelations() + for _, allowedRelation := range allowedRelations { + if allowedRelation.GetNamespace() == targetNamespaceName && allowedRelation.GetRelation() == targetRelationName { + return DirectRelationValid, nil + } + } + + return DirectRelationNotValid, nil +} + +// HasAllowedRelation returns whether the source relation has the given allowed relation. +func (def *Definition) HasAllowedRelation(sourceRelationName string, toCheck *core.AllowedRelation) (AllowedRelationOption, error) { + found, ok := def.relationMap[sourceRelationName] + if !ok { + return UnknownIfAllowed, asTypeError(NewRelationNotFoundErr(def.nsDef.Name, sourceRelationName)) + } + + typeInfo := found.GetTypeInformation() + if typeInfo == nil { + return UnknownIfAllowed, nil + } + + allowedRelations := typeInfo.GetAllowedDirectRelations() + for _, allowedRelation := range allowedRelations { + if SourceForAllowedRelation(allowedRelation) == SourceForAllowedRelation(toCheck) { + return AllowedRelationValid, nil + } + } + + return AllowedRelationNotValid, nil +} + +// AllowedDirectRelationsAndWildcards returns the allowed subject relations for a source relation. +// Note that this function will return wildcards. +func (def *Definition) AllowedDirectRelationsAndWildcards(sourceRelationName string) ([]*core.AllowedRelation, error) { + found, ok := def.relationMap[sourceRelationName] + if !ok { + return []*core.AllowedRelation{}, asTypeError(NewRelationNotFoundErr(def.nsDef.Name, sourceRelationName)) + } + + typeInfo := found.GetTypeInformation() + if typeInfo == nil { + return []*core.AllowedRelation{}, nil + } + + return typeInfo.GetAllowedDirectRelations(), nil +} + +// AllowedSubjectRelations returns the allowed subject relations for a source relation. Note that this function will *not* +// return wildcards, and returns without the marked caveats and expiration. +func (def *Definition) AllowedSubjectRelations(sourceRelationName string) ([]*core.RelationReference, error) { + allowedDirect, err := def.AllowedDirectRelationsAndWildcards(sourceRelationName) + if err != nil { + return []*core.RelationReference{}, asTypeError(err) + } + + filtered := make([]*core.RelationReference, 0, len(allowedDirect)) + for _, allowed := range allowedDirect { + if allowed.GetPublicWildcard() != nil { + continue + } + + if allowed.GetRelation() == "" { + return nil, spiceerrors.MustBugf("got an empty relation for a non-wildcard type definition under namespace") + } + + filtered = append(filtered, &core.RelationReference{ + Namespace: allowed.GetNamespace(), + Relation: allowed.GetRelation(), + }) + } + return filtered, nil +} + +// WildcardTypeReference represents a relation that references a wildcard type. +type WildcardTypeReference struct { + // ReferencingRelation is the relation referencing the wildcard type. + ReferencingRelation *core.RelationReference + + // WildcardType is the wildcard type referenced. + WildcardType *core.AllowedRelation +} + +// SourceForAllowedRelation returns the source code representation of an allowed relation. +func SourceForAllowedRelation(allowedRelation *core.AllowedRelation) string { + caveatAndTraitsStr := "" + + hasCaveat := allowedRelation.GetRequiredCaveat() != nil + hasExpirationTrait := allowedRelation.GetRequiredExpiration() != nil + hasTraits := hasCaveat || hasExpirationTrait + + if hasTraits { + caveatAndTraitsStr = " with " + if hasCaveat { + caveatAndTraitsStr += allowedRelation.RequiredCaveat.CaveatName + } + + if hasCaveat && hasExpirationTrait { + caveatAndTraitsStr += " and " + } + + if hasExpirationTrait { + caveatAndTraitsStr += "expiration" + } + } + + if allowedRelation.GetPublicWildcard() != nil { + return tuple.JoinObjectRef(allowedRelation.Namespace, "*") + caveatAndTraitsStr + } + + if rel := allowedRelation.GetRelation(); rel != tuple.Ellipsis { + return tuple.JoinRelRef(allowedRelation.Namespace, rel) + caveatAndTraitsStr + } + + return allowedRelation.Namespace + caveatAndTraitsStr +} + +// RelationDoesNotAllowCaveatsOrTraitsForSubject returns true if and only if it can be conclusively determined that +// the given subject type does not accept any caveats or traits on the given relation. If the relation does not have type information, +// returns an error. +func (def *Definition) RelationDoesNotAllowCaveatsOrTraitsForSubject(relationName string, subjectTypeName string) (bool, error) { + relation, ok := def.relationMap[relationName] + if !ok { + return false, NewRelationNotFoundErr(def.nsDef.Name, relationName) + } + + typeInfo := relation.GetTypeInformation() + if typeInfo == nil { + return false, NewTypeWithSourceError( + fmt.Errorf("relation `%s` does not have type information", relationName), + relation, relationName, + ) + } + + foundSubjectType := false + for _, allowedRelation := range typeInfo.GetAllowedDirectRelations() { + if allowedRelation.GetNamespace() == subjectTypeName { + foundSubjectType = true + if allowedRelation.GetRequiredCaveat() != nil && allowedRelation.GetRequiredCaveat().CaveatName != "" { + return false, nil + } + if allowedRelation.GetRequiredExpiration() != nil { + return false, nil + } + } + } + + if !foundSubjectType { + return false, NewTypeWithSourceError( + fmt.Errorf("relation `%s` does not allow subject type `%s`", relationName, subjectTypeName), + relation, relationName, + ) + } + + return true, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/doc.go b/vendor/github.com/authzed/spicedb/pkg/schema/doc.go new file mode 100644 index 0000000..b2cdb65 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/doc.go @@ -0,0 +1,2 @@ +// Package schema contains code that manipulates a schema and knows how to traverse it. +package schema diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/errors.go b/vendor/github.com/authzed/spicedb/pkg/schema/errors.go new file mode 100644 index 0000000..e81f315 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/errors.go @@ -0,0 +1,406 @@ +package schema + +import ( + "errors" + "fmt" + "strings" + + "github.com/rs/zerolog" + + "github.com/authzed/spicedb/internal/sharederrors" + nspkg "github.com/authzed/spicedb/pkg/namespace" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// DefinitionNotFoundError occurs when a definition was not found. +type DefinitionNotFoundError struct { + error + definitionName string +} + +// NotFoundNamespaceName is the name of the definition not found. +func (err DefinitionNotFoundError) NotFoundNamespaceName() string { + return err.definitionName +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err DefinitionNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err DefinitionNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + } +} + +// RelationNotFoundError occurs when a relation was not found under a definition. +type RelationNotFoundError struct { + error + definitionName string + relationName string +} + +// NamespaceName returns the name of the definition in which the relation was not found. +func (err RelationNotFoundError) NamespaceName() string { + return err.definitionName +} + +// NotFoundRelationName returns the name of the relation not found. +func (err RelationNotFoundError) NotFoundRelationName() string { + return err.relationName +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err RelationNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err RelationNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + "relation_or_permission_name": err.relationName, + } +} + +// CaveatNotFoundError occurs when a caveat was not found. +type CaveatNotFoundError struct { + error + caveatName string +} + +// CaveatName returns the name of the caveat not found. +func (err CaveatNotFoundError) CaveatName() string { + return err.caveatName +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err CaveatNotFoundError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("caveat", err.caveatName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err CaveatNotFoundError) DetailsMetadata() map[string]string { + return map[string]string{ + "caveat_name": err.caveatName, + } +} + +// DuplicateRelationError occurs when a duplicate relation was found inside a definition. +type DuplicateRelationError struct { + error + definitionName string + relationName string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err DuplicateRelationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err DuplicateRelationError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + "relation_or_permission_name": err.relationName, + } +} + +// PermissionUsedOnLeftOfArrowError occurs when a permission is used on the left side of an arrow +// expression. +type PermissionUsedOnLeftOfArrowError struct { + error + definitionName string + parentPermissionName string + foundPermissionName string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err PermissionUsedOnLeftOfArrowError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName).Str("permission", err.parentPermissionName).Str("usedPermission", err.foundPermissionName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err PermissionUsedOnLeftOfArrowError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + "permission_name": err.parentPermissionName, + "used_permission_name": err.foundPermissionName, + } +} + +// WildcardUsedInArrowError occurs when an arrow operates over a relation that contains a wildcard. +type WildcardUsedInArrowError struct { + error + definitionName string + parentPermissionName string + accessedRelationName string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err WildcardUsedInArrowError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName).Str("parentPermissionName", err.parentPermissionName).Str("accessedRelationName", err.accessedRelationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err WildcardUsedInArrowError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + "permission_name": err.parentPermissionName, + "accessed_relation_name": err.accessedRelationName, + } +} + +// MissingAllowedRelationsError occurs when a relation is defined without any type information. +type MissingAllowedRelationsError struct { + error + definitionName string + relationName string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err MissingAllowedRelationsError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err MissingAllowedRelationsError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + "relation_name": err.relationName, + } +} + +// TransitiveWildcardError occurs when a wildcard relation in turn references another wildcard +// relation. +type TransitiveWildcardError struct { + error + definitionName string + relationName string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err TransitiveWildcardError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName).Str("relation", err.relationName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err TransitiveWildcardError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + "relation_name": err.relationName, + } +} + +// PermissionsCycleError occurs when a cycle exists within permissions. +type PermissionsCycleError struct { + error + definitionName string + permissionNames []string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err PermissionsCycleError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName).Str("permissions", strings.Join(err.permissionNames, ", ")) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err PermissionsCycleError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + "permission_names": strings.Join(err.permissionNames, ","), + } +} + +// DuplicateAllowedRelationError indicates that an allowed relation was redefined on a relation. +type DuplicateAllowedRelationError struct { + error + definitionName string + relationName string + allowedRelationSource string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err DuplicateAllowedRelationError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("definition", err.definitionName).Str("relation", err.relationName).Str("allowed-relation", err.allowedRelationSource) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err DuplicateAllowedRelationError) DetailsMetadata() map[string]string { + return map[string]string{ + "definition_name": err.definitionName, + "relation_name": err.relationName, + "allowed_relation": err.allowedRelationSource, + } +} + +// UnusedCaveatParameterError indicates that a caveat parameter is unused in the caveat expression. +type UnusedCaveatParameterError struct { + error + caveatName string + paramName string +} + +// MarshalZerologObject implements zerolog object marshalling. +func (err UnusedCaveatParameterError) MarshalZerologObject(e *zerolog.Event) { + e.Err(err.error).Str("caveat", err.caveatName).Str("param", err.paramName) +} + +// DetailsMetadata returns the metadata for details for this error. +func (err UnusedCaveatParameterError) DetailsMetadata() map[string]string { + return map[string]string{ + "caveat_name": err.caveatName, + "parameter_name": err.paramName, + } +} + +// NewDefinitionNotFoundErr constructs a new definition not found error. +func NewDefinitionNotFoundErr(nsName string) error { + return DefinitionNotFoundError{ + error: fmt.Errorf("object definition `%s` not found", nsName), + definitionName: nsName, + } +} + +// NewRelationNotFoundErr constructs a new relation not found error. +func NewRelationNotFoundErr(nsName string, relationName string) error { + return RelationNotFoundError{ + error: fmt.Errorf("relation/permission `%s` not found under definition `%s`", relationName, nsName), + definitionName: nsName, + relationName: relationName, + } +} + +// NewCaveatNotFoundErr constructs a new caveat not found error. +func NewCaveatNotFoundErr(caveatName string) error { + return CaveatNotFoundError{ + error: fmt.Errorf("caveat `%s` not found", caveatName), + caveatName: caveatName, + } +} + +// NewDuplicateRelationError constructs an error indicating that a relation was defined more than once in a definition. +func NewDuplicateRelationError(nsName string, relationName string) error { + return DuplicateRelationError{ + error: fmt.Errorf("found duplicate relation/permission name `%s` under definition `%s`", relationName, nsName), + definitionName: nsName, + relationName: relationName, + } +} + +// NewDuplicateAllowedRelationErr constructs an error indicating that an allowed relation was defined more than once for a relation. +func NewDuplicateAllowedRelationErr(nsName string, relationName string, allowedRelationSource string) error { + return DuplicateAllowedRelationError{ + error: fmt.Errorf("found duplicate allowed subject type `%s` on relation `%s` under definition `%s`", allowedRelationSource, relationName, nsName), + definitionName: nsName, + relationName: relationName, + allowedRelationSource: allowedRelationSource, + } +} + +// NewPermissionUsedOnLeftOfArrowErr constructs an error indicating that a permission was used on the left side of an arrow. +func NewPermissionUsedOnLeftOfArrowErr(nsName string, parentPermissionName string, foundPermissionName string) error { + return PermissionUsedOnLeftOfArrowError{ + error: fmt.Errorf("under permission `%s` under definition `%s`: permissions cannot be used on the left hand side of an arrow (found `%s`)", parentPermissionName, nsName, foundPermissionName), + definitionName: nsName, + parentPermissionName: parentPermissionName, + foundPermissionName: foundPermissionName, + } +} + +// NewWildcardUsedInArrowErr constructs an error indicating that an arrow operated over a relation with a wildcard type. +func NewWildcardUsedInArrowErr(nsName string, parentPermissionName string, foundRelationName string, wildcardTypeName string, wildcardRelationName string) error { + return WildcardUsedInArrowError{ + error: fmt.Errorf("for arrow under permission `%s`: relation `%s#%s` includes wildcard type `%s` via relation `%s`: wildcard relations cannot be used on the left side of arrows", parentPermissionName, nsName, foundRelationName, wildcardTypeName, wildcardRelationName), + definitionName: nsName, + parentPermissionName: parentPermissionName, + accessedRelationName: foundRelationName, + } +} + +// NewMissingAllowedRelationsErr constructs an error indicating that type information is missing for a relation. +func NewMissingAllowedRelationsErr(nsName string, relationName string) error { + return MissingAllowedRelationsError{ + error: fmt.Errorf("at least one allowed relation/permission is required to be defined for relation `%s`", relationName), + definitionName: nsName, + relationName: relationName, + } +} + +// NewTransitiveWildcardErr constructs an error indicating that a transitive wildcard exists. +func NewTransitiveWildcardErr(nsName string, relationName string, foundRelationNamespace string, foundRelationName string, wildcardTypeName string, wildcardRelationReference string) error { + return TransitiveWildcardError{ + error: fmt.Errorf("for relation `%s`: relation/permission `%s#%s` includes wildcard type `%s` via relation `%s`: wildcard relations cannot be transitively included", relationName, foundRelationNamespace, foundRelationName, wildcardTypeName, wildcardRelationReference), + definitionName: nsName, + relationName: relationName, + } +} + +// NewPermissionsCycleErr constructs an error indicating that a cycle exists amongst permissions. +func NewPermissionsCycleErr(nsName string, permissionNames []string) error { + return PermissionsCycleError{ + error: fmt.Errorf("under definition `%s`, there exists a cycle in permissions: %s", nsName, strings.Join(permissionNames, ", ")), + definitionName: nsName, + permissionNames: permissionNames, + } +} + +// NewUnusedCaveatParameterErr constructs indicating that a parameter was unused in a caveat expression. +func NewUnusedCaveatParameterErr(caveatName string, paramName string) error { + return UnusedCaveatParameterError{ + error: fmt.Errorf("parameter `%s` for caveat `%s` is unused", paramName, caveatName), + caveatName: caveatName, + paramName: paramName, + } +} + +// asTypeError wraps another error in a type error. +func asTypeError(wrapped error) error { + if wrapped == nil { + return nil + } + + var te TypeError + if errors.As(wrapped, &te) { + return wrapped + } + + return TypeError{wrapped} +} + +// TypeError wraps another error as a type error. +type TypeError struct { + error +} + +func (err TypeError) Unwrap() error { + return err.error +} + +var ( + _ sharederrors.UnknownNamespaceError = DefinitionNotFoundError{} + _ sharederrors.UnknownRelationError = RelationNotFoundError{} +) + +// NewTypeWithSourceError creates a new type error at the specific position and with source code, wrapping the underlying +// error. +func NewTypeWithSourceError(wrapped error, withSource nspkg.WithSourcePosition, sourceCodeString string) error { + sourcePosition := withSource.GetSourcePosition() + if sourcePosition != nil { + return asTypeError(spiceerrors.NewWithSourceError( + wrapped, + sourceCodeString, + sourcePosition.ZeroIndexedLineNumber+1, // +1 to make 1-indexed + sourcePosition.ZeroIndexedColumnPosition+1, // +1 to make 1-indexed + )) + } + + return asTypeError(spiceerrors.NewWithSourceError( + wrapped, + sourceCodeString, + 0, + 0, + )) +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/full_reachability.go b/vendor/github.com/authzed/spicedb/pkg/schema/full_reachability.go new file mode 100644 index 0000000..a6143b5 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/full_reachability.go @@ -0,0 +1,226 @@ +package schema + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +// Graph is a struct holding reachability information. +type Graph struct { + arrowSet *ArrowSet + ts *TypeSystem + referenceInfoMap map[nsAndRel][]RelationReferenceInfo +} + +// BuildGraph builds the graph of all reachable information in the schema. +func BuildGraph(ctx context.Context, r *CompiledSchemaResolver) (*Graph, error) { + arrowSet, err := buildArrowSet(ctx, r) + if err != nil { + return nil, err + } + + ts := NewTypeSystem(r) + referenceInfoMap, err := preComputeRelationReferenceInfo(ctx, arrowSet, r, ts) + if err != nil { + return nil, err + } + + return &Graph{ + ts: ts, + arrowSet: arrowSet, + referenceInfoMap: referenceInfoMap, + }, nil +} + +// Arrows returns the set of arrows found in the reachability graph. +func (g *Graph) Arrows() *ArrowSet { + return g.arrowSet +} + +// RelationsReferencing returns all relations/permissions in the schema that reference the specified +// relation in some manner. +func (g *Graph) RelationsReferencing(namespaceName string, relationName string) []RelationReferenceInfo { + return g.referenceInfoMap[nsAndRel{ + Namespace: namespaceName, + Relation: relationName, + }] +} + +// ReferenceType is an enum describing what kind of relation reference we hold in a RelationReferenceInfo. +type ReferenceType int + +const ( + RelationInExpression ReferenceType = iota + RelationIsSubjectType + RelationIsTuplesetForArrow + RelationIsComputedUsersetForArrow +) + +// RelationReferenceInfo holds the relation and metadata for a relation found in the full reachability graph. +type RelationReferenceInfo struct { + Relation *core.RelationReference + Type ReferenceType + Arrow *ArrowInformation +} + +func relationsReferencing(ctx context.Context, arrowSet *ArrowSet, res FullSchemaResolver, ts *TypeSystem, namespaceName string, relationName string) ([]RelationReferenceInfo, error) { + foundReferences := make([]RelationReferenceInfo, 0) + + for _, name := range res.AllDefinitionNames() { + def, err := ts.GetValidatedDefinition(ctx, name) + if err != nil { + return nil, err + } + for _, relation := range def.Namespace().Relation { + // Check for the use of the relation directly as part of any permissions in the same namespace. + if def.IsPermission(relation.Name) && name == namespaceName { + hasReference, err := expressionReferencesRelation(ctx, relation.GetUsersetRewrite(), relationName) + if err != nil { + return nil, err + } + + if hasReference { + foundReferences = append(foundReferences, RelationReferenceInfo{ + Relation: &core.RelationReference{ + Namespace: name, + Relation: relation.Name, + }, + Type: RelationInExpression, + }) + } + continue + } + + // Check for the use of the relation as a subject type on any relation in the entire schema. + isAllowed, err := def.IsAllowedDirectRelation(relation.Name, namespaceName, relationName) + if err != nil { + return nil, err + } + + if isAllowed == DirectRelationValid { + foundReferences = append(foundReferences, RelationReferenceInfo{ + Relation: &core.RelationReference{ + Namespace: name, + Relation: relation.Name, + }, + Type: RelationIsSubjectType, + }) + } + } + } + + // Add any arrow references. + key := namespaceName + "#" + relationName + foundArrows, _ := arrowSet.arrowsByFullTuplesetRelation.Get(key) + for _, arrow := range foundArrows { + arrow := arrow + foundReferences = append(foundReferences, RelationReferenceInfo{ + Relation: &core.RelationReference{ + Namespace: namespaceName, + Relation: arrow.ParentRelationName, + }, + Type: RelationIsTuplesetForArrow, + Arrow: &arrow, + }) + } + + for _, tuplesetRelationKey := range arrowSet.reachableComputedUsersetRelationsByTuplesetRelation.Keys() { + values, ok := arrowSet.reachableComputedUsersetRelationsByTuplesetRelation.Get(tuplesetRelationKey) + if !ok { + continue + } + + if slices.Contains(values, key) { + pieces := strings.Split(tuplesetRelationKey, "#") + foundReferences = append(foundReferences, RelationReferenceInfo{ + Relation: &core.RelationReference{ + Namespace: pieces[0], + Relation: pieces[1], + }, + Type: RelationIsComputedUsersetForArrow, + }) + } + } + + return foundReferences, nil +} + +type nsAndRel struct { + Namespace string + Relation string +} + +func preComputeRelationReferenceInfo(ctx context.Context, arrowSet *ArrowSet, res FullSchemaResolver, ts *TypeSystem) (map[nsAndRel][]RelationReferenceInfo, error) { + nsAndRelToInfo := make(map[nsAndRel][]RelationReferenceInfo) + + for _, namespaceName := range res.AllDefinitionNames() { + outerTS, err := ts.GetValidatedDefinition(ctx, namespaceName) + if err != nil { + return nil, err + } + for _, outerRelation := range outerTS.Namespace().Relation { + referenceInfos, err := relationsReferencing(ctx, arrowSet, res, ts, namespaceName, outerRelation.Name) + if err != nil { + return nil, err + } + + nsAndRel := nsAndRel{ + Namespace: namespaceName, + Relation: outerRelation.Name, + } + nsAndRelToInfo[nsAndRel] = referenceInfos + } + } + + return nsAndRelToInfo, nil +} + +func expressionReferencesRelation(ctx context.Context, rewrite *core.UsersetRewrite, relationName string) (bool, error) { + // TODO(jschorr): Precompute this and maybe create a visitor pattern to stop repeating this everywhere + switch rw := rewrite.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + return setOperationReferencesRelation(ctx, rw.Union, relationName) + case *core.UsersetRewrite_Intersection: + return setOperationReferencesRelation(ctx, rw.Intersection, relationName) + case *core.UsersetRewrite_Exclusion: + return setOperationReferencesRelation(ctx, rw.Exclusion, relationName) + default: + return false, errors.New("userset rewrite operation not implemented in expressionReferencesRelation") + } +} + +func setOperationReferencesRelation(ctx context.Context, so *core.SetOperation, relationName string) (bool, error) { + for _, childOneof := range so.Child { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_ComputedUserset: + if child.ComputedUserset.Relation == relationName { + return true, nil + } + + case *core.SetOperation_Child_UsersetRewrite: + result, err := expressionReferencesRelation(ctx, child.UsersetRewrite, relationName) + if result || err != nil { + return result, err + } + + case *core.SetOperation_Child_TupleToUserset: + // Nothing to do, handled above via arrow set + + case *core.SetOperation_Child_XThis: + // Nothing to do + + case *core.SetOperation_Child_XNil: + // Nothing to do + + default: + return false, fmt.Errorf("unknown set operation child `%T` in setOperationReferencesRelation", child) + } + } + + return false, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/reachabilitygraph.go b/vendor/github.com/authzed/spicedb/pkg/schema/reachabilitygraph.go new file mode 100644 index 0000000..dd252e2 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/reachabilitygraph.go @@ -0,0 +1,452 @@ +package schema + +import ( + "context" + "fmt" + "sort" + "strconv" + "sync" + + "github.com/cespare/xxhash/v2" + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" +) + +// DefinitionReachability is a helper struct that provides an easy way to determine all entrypoints +// for a subject of a particular type into a schema, for the purpose of walking from the subject +// to a specific resource relation. +type DefinitionReachability struct { + def *Definition + cachedGraphs sync.Map + hasOptimizedEntrypointCache sync.Map +} + +// Reachability returns a reachability graph for the given namespace. +func (def *Definition) Reachability() *DefinitionReachability { + return &DefinitionReachability{def, sync.Map{}, sync.Map{}} +} + +// RelationsEncounteredForResource returns all relations that are encountered when walking outward from a resource+relation. +func (rg *DefinitionReachability) RelationsEncounteredForResource( + ctx context.Context, + resourceType *core.RelationReference, +) ([]*core.RelationReference, error) { + _, relationNames, err := rg.computeEntrypoints(ctx, resourceType, nil /* include all entrypoints */, reachabilityFull, entrypointLookupFindAll) + if err != nil { + return nil, err + } + + relationRefs := make([]*core.RelationReference, 0, len(relationNames)) + for _, relationName := range relationNames { + namespace, relation := tuple.MustSplitRelRef(relationName) + relationRefs = append(relationRefs, &core.RelationReference{ + Namespace: namespace, + Relation: relation, + }) + } + return relationRefs, nil +} + +// RelationsEncounteredForSubject returns all relations that are encountered when walking outward from a subject+relation. +func (rg *DefinitionReachability) RelationsEncounteredForSubject( + ctx context.Context, + allDefinitions []*core.NamespaceDefinition, + startingSubjectType *core.RelationReference, +) ([]*core.RelationReference, error) { + if startingSubjectType.Namespace != rg.def.nsDef.Name { + return nil, spiceerrors.MustBugf("gave mismatching namespace name for subject type to reachability graph") + } + + allRelationNames := mapz.NewSet[string]() + + subjectTypesToCheck := []*core.RelationReference{startingSubjectType} + + // TODO(jschorr): optimize this to not require walking over all types recursively. + added := mapz.NewSet[string]() + for { + if len(subjectTypesToCheck) == 0 { + break + } + + collected := &[]ReachabilityEntrypoint{} + for _, nsDef := range allDefinitions { + nts, err := rg.def.ts.GetDefinition(ctx, nsDef.Name) + if err != nil { + return nil, err + } + + nrg := nts.Reachability() + + for _, relation := range nsDef.Relation { + for _, subjectType := range subjectTypesToCheck { + if subjectType.Namespace == nsDef.Name && subjectType.Relation == relation.Name { + continue + } + + encounteredRelations := map[string]struct{}{} + err := nrg.collectEntrypoints(ctx, &core.RelationReference{ + Namespace: nsDef.Name, + Relation: relation.Name, + }, subjectType, collected, encounteredRelations, reachabilityFull, entrypointLookupFindAll) + if err != nil { + return nil, err + } + } + } + } + + subjectTypesToCheck = make([]*core.RelationReference, 0, len(*collected)) + + for _, entrypoint := range *collected { + st := tuple.JoinRelRef(entrypoint.re.TargetRelation.Namespace, entrypoint.re.TargetRelation.Relation) + if !added.Add(st) { + continue + } + + allRelationNames.Add(st) + subjectTypesToCheck = append(subjectTypesToCheck, entrypoint.re.TargetRelation) + } + } + + relationRefs := make([]*core.RelationReference, 0, allRelationNames.Len()) + for _, relationName := range allRelationNames.AsSlice() { + namespace, relation := tuple.MustSplitRelRef(relationName) + relationRefs = append(relationRefs, &core.RelationReference{ + Namespace: namespace, + Relation: relation, + }) + } + return relationRefs, nil +} + +// AllEntrypointsForSubjectToResource returns the entrypoints into the reachability graph, starting +// at the given subject type and walking to the given resource type. +func (rg *DefinitionReachability) AllEntrypointsForSubjectToResource( + ctx context.Context, + subjectType *core.RelationReference, + resourceType *core.RelationReference, +) ([]ReachabilityEntrypoint, error) { + entrypoints, _, err := rg.computeEntrypoints(ctx, resourceType, subjectType, reachabilityFull, entrypointLookupFindAll) + return entrypoints, err +} + +// FirstEntrypointsForSubjectToResource returns the *optimized* set of entrypoints into the +// reachability graph, starting at the given subject type and walking to the given resource type. +// +// It does this by limiting the number of entrypoints (and checking the alternatives) and so simply returns the first entrypoint in an +// intersection or exclusion branch. +func (rg *DefinitionReachability) FirstEntrypointsForSubjectToResource( + ctx context.Context, + subjectType *core.RelationReference, + resourceType *core.RelationReference, +) ([]ReachabilityEntrypoint, error) { + entrypoints, _, err := rg.computeEntrypoints(ctx, resourceType, subjectType, reachabilityFirst, entrypointLookupFindAll) + return entrypoints, err +} + +// HasOptimizedEntrypointsForSubjectToResource returns whether there exists any *optimized* +// entrypoints into the reachability graph, starting at the given subject type and walking +// to the given resource type. +// +// The optimized set will skip branches on intersections and exclusions in an attempt to minimize +// the number of entrypoints. +func (rg *DefinitionReachability) HasOptimizedEntrypointsForSubjectToResource( + ctx context.Context, + subjectType *core.RelationReference, + resourceType *core.RelationReference, +) (bool, error) { + // TODO(jschorr): Change this to be indexed by a struct + cacheKey := tuple.StringCoreRR(subjectType) + "=>" + tuple.StringCoreRR(resourceType) + if result, ok := rg.hasOptimizedEntrypointCache.Load(cacheKey); ok { + return result.(bool), nil + } + + // TODO(jzelinskie): measure to see if it's worth singleflighting this + found, _, err := rg.computeEntrypoints(ctx, resourceType, subjectType, reachabilityFirst, entrypointLookupFindOne) + if err != nil { + return false, err + } + + result := len(found) > 0 + rg.hasOptimizedEntrypointCache.Store(cacheKey, result) + return result, nil +} + +type entrypointLookupOption int + +const ( + entrypointLookupFindAll entrypointLookupOption = iota + entrypointLookupFindOne +) + +func (rg *DefinitionReachability) computeEntrypoints( + ctx context.Context, + resourceType *core.RelationReference, + optionalSubjectType *core.RelationReference, + reachabilityOption reachabilityOption, + entrypointLookupOption entrypointLookupOption, +) ([]ReachabilityEntrypoint, []string, error) { + if resourceType.Namespace != rg.def.nsDef.Name { + return nil, nil, fmt.Errorf("gave mismatching namespace name for resource type to reachability graph") + } + + collected := &[]ReachabilityEntrypoint{} + encounteredRelations := map[string]struct{}{} + err := rg.collectEntrypoints(ctx, resourceType, optionalSubjectType, collected, encounteredRelations, reachabilityOption, entrypointLookupOption) + if err != nil { + return nil, maps.Keys(encounteredRelations), err + } + + collectedEntrypoints := *collected + + // Deduplicate any entrypoints found. An example that can cause a duplicate is a relation which references + // the same subject type multiple times due to caveats: + // + // relation somerel: user | user with somecaveat + // + // This will produce two entrypoints (one per user reference), but as entrypoints themselves are not caveated, + // one is spurious. + entrypointMap := make(map[uint64]ReachabilityEntrypoint, len(collectedEntrypoints)) + uniqueEntrypoints := make([]ReachabilityEntrypoint, 0, len(collectedEntrypoints)) + for _, entrypoint := range collectedEntrypoints { + hash, err := entrypoint.Hash() + if err != nil { + return nil, maps.Keys(encounteredRelations), err + } + + if _, ok := entrypointMap[hash]; !ok { + entrypointMap[hash] = entrypoint + uniqueEntrypoints = append(uniqueEntrypoints, entrypoint) + } + } + + return uniqueEntrypoints, maps.Keys(encounteredRelations), nil +} + +func (rg *DefinitionReachability) getOrBuildGraph(ctx context.Context, resourceType *core.RelationReference, reachabilityOption reachabilityOption) (*core.ReachabilityGraph, error) { + // Check the cache. + // TODO(jschorr): Change to be indexed by a struct. + cacheKey := tuple.StringCoreRR(resourceType) + "-" + strconv.Itoa(int(reachabilityOption)) + if cached, ok := rg.cachedGraphs.Load(cacheKey); ok { + return cached.(*core.ReachabilityGraph), nil + } + + // Load the type system for the target resource relation. + tdef, err := rg.def.ts.GetDefinition(ctx, resourceType.Namespace) + if err != nil { + return nil, err + } + + rrg, err := computeReachability(ctx, tdef, resourceType.Relation, reachabilityOption) + if err != nil { + return nil, err + } + + rg.cachedGraphs.Store(cacheKey, rrg) + return rrg, err +} + +func (rg *DefinitionReachability) collectEntrypoints( + ctx context.Context, + resourceType *core.RelationReference, + optionalSubjectType *core.RelationReference, + collected *[]ReachabilityEntrypoint, + encounteredRelations map[string]struct{}, + reachabilityOption reachabilityOption, + entrypointLookupOption entrypointLookupOption, +) error { + // Ensure that we only process each relation once. + key := tuple.JoinRelRef(resourceType.Namespace, resourceType.Relation) + if _, ok := encounteredRelations[key]; ok { + return nil + } + + encounteredRelations[key] = struct{}{} + + rrg, err := rg.getOrBuildGraph(ctx, resourceType, reachabilityOption) + if err != nil { + return err + } + + if optionalSubjectType != nil { + // Add subject type entrypoints. + subjectTypeEntrypoints, ok := rrg.EntrypointsBySubjectType[optionalSubjectType.Namespace] + if ok { + addEntrypoints(subjectTypeEntrypoints, resourceType, collected, encounteredRelations) + } + + if entrypointLookupOption == entrypointLookupFindOne && len(*collected) > 0 { + return nil + } + + // Add subject relation entrypoints. + subjectRelationEntrypoints, ok := rrg.EntrypointsBySubjectRelation[tuple.JoinRelRef(optionalSubjectType.Namespace, optionalSubjectType.Relation)] + if ok { + addEntrypoints(subjectRelationEntrypoints, resourceType, collected, encounteredRelations) + } + + if entrypointLookupOption == entrypointLookupFindOne && len(*collected) > 0 { + return nil + } + } else { + // Add all entrypoints. + for _, entrypoints := range rrg.EntrypointsBySubjectType { + addEntrypoints(entrypoints, resourceType, collected, encounteredRelations) + } + + for _, entrypoints := range rrg.EntrypointsBySubjectRelation { + addEntrypoints(entrypoints, resourceType, collected, encounteredRelations) + } + } + + // Sort the keys to ensure a stable graph is produced. + keys := maps.Keys(rrg.EntrypointsBySubjectRelation) + sort.Strings(keys) + + // Recursively collect over any reachability graphs for subjects with non-ellipsis relations. + for _, entrypointSetKey := range keys { + entrypointSet := rrg.EntrypointsBySubjectRelation[entrypointSetKey] + if entrypointSet.SubjectRelation != nil && entrypointSet.SubjectRelation.Relation != tuple.Ellipsis { + err := rg.collectEntrypoints(ctx, entrypointSet.SubjectRelation, optionalSubjectType, collected, encounteredRelations, reachabilityOption, entrypointLookupOption) + if err != nil { + return err + } + + if entrypointLookupOption == entrypointLookupFindOne && len(*collected) > 0 { + return nil + } + } + } + + return nil +} + +func addEntrypoints(entrypoints *core.ReachabilityEntrypoints, parentRelation *core.RelationReference, collected *[]ReachabilityEntrypoint, encounteredRelations map[string]struct{}) { + for _, entrypoint := range entrypoints.Entrypoints { + if entrypoint.TuplesetRelation != "" { + key := tuple.JoinRelRef(entrypoint.TargetRelation.Namespace, entrypoint.TuplesetRelation) + encounteredRelations[key] = struct{}{} + } + + *collected = append(*collected, ReachabilityEntrypoint{entrypoint, parentRelation}) + } +} + +// ReachabilityEntrypoint is an entrypoint into the reachability graph for a subject of particular +// type. +type ReachabilityEntrypoint struct { + re *core.ReachabilityEntrypoint + parentRelation *core.RelationReference +} + +// Hash returns a hash representing the data in the entrypoint, for comparison to other entrypoints. +// This is ONLY stable within a single version of SpiceDB and should NEVER be stored for later +// comparison outside of the process. +func (re ReachabilityEntrypoint) Hash() (uint64, error) { + size := re.re.SizeVT() + if re.parentRelation != nil { + size += re.parentRelation.SizeVT() + } + + hashData := make([]byte, 0, size) + + data, err := re.re.MarshalVT() + if err != nil { + return 0, err + } + + hashData = append(hashData, data...) + + if re.parentRelation != nil { + data, err := re.parentRelation.MarshalVT() + if err != nil { + return 0, err + } + + hashData = append(hashData, data...) + } + + return xxhash.Sum64(hashData), nil +} + +// EntrypointKind is the kind of the entrypoint. +func (re ReachabilityEntrypoint) EntrypointKind() core.ReachabilityEntrypoint_ReachabilityEntrypointKind { + return re.re.Kind +} + +// ComputedUsersetRelation returns the tupleset relation of the computed userset, if any. +func (re ReachabilityEntrypoint) ComputedUsersetRelation() (string, error) { + if re.EntrypointKind() == core.ReachabilityEntrypoint_RELATION_ENTRYPOINT { + return "", fmt.Errorf("cannot call ComputedUsersetRelation for kind %v", re.EntrypointKind()) + } + return re.re.ComputedUsersetRelation, nil +} + +// TuplesetRelation returns the tupleset relation of the TTU, if a TUPLESET_TO_USERSET_ENTRYPOINT. +func (re ReachabilityEntrypoint) TuplesetRelation() (string, error) { + if re.EntrypointKind() != core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT { + return "", fmt.Errorf("cannot call TupleToUserset for kind %v", re.EntrypointKind()) + } + + return re.re.TuplesetRelation, nil +} + +// DirectRelation is the relation that this entrypoint represents, if a RELATION_ENTRYPOINT. +func (re ReachabilityEntrypoint) DirectRelation() (*core.RelationReference, error) { + if re.EntrypointKind() != core.ReachabilityEntrypoint_RELATION_ENTRYPOINT { + return nil, fmt.Errorf("cannot call DirectRelation for kind %v", re.EntrypointKind()) + } + + return re.re.TargetRelation, nil +} + +// TargetNamespace returns the namespace for the entrypoint's target relation. +func (re ReachabilityEntrypoint) TargetNamespace() string { + return re.re.TargetRelation.Namespace +} + +// ContainingRelationOrPermission is the relation or permission containing this entrypoint. +func (re ReachabilityEntrypoint) ContainingRelationOrPermission() *core.RelationReference { + return re.parentRelation +} + +// IsDirectResult returns whether the entrypoint, when evaluated, becomes a direct result of +// the parent relation/permission. A direct result only exists if the entrypoint is not contained +// under an intersection or exclusion, which makes the entrypoint's object merely conditionally +// reachable. +func (re ReachabilityEntrypoint) IsDirectResult() bool { + return re.re.ResultStatus == core.ReachabilityEntrypoint_DIRECT_OPERATION_RESULT +} + +func (re ReachabilityEntrypoint) String() string { + return re.MustDebugString() +} + +func (re ReachabilityEntrypoint) MustDebugString() string { + ds, err := re.DebugString() + if err != nil { + panic(err) + } + + return ds +} + +func (re ReachabilityEntrypoint) DebugString() (string, error) { + switch re.EntrypointKind() { + case core.ReachabilityEntrypoint_RELATION_ENTRYPOINT: + return "relation-entrypoint: " + re.re.TargetRelation.Namespace + "#" + re.re.TargetRelation.Relation, nil + + case core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT: + return "ttu-entrypoint: " + re.re.TuplesetRelation + " -> " + re.re.TargetRelation.Namespace + "#" + re.re.TargetRelation.Relation, nil + + case core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT: + return "computed-userset-entrypoint: " + re.re.TargetRelation.Namespace + "#" + re.re.TargetRelation.Relation, nil + + default: + return "", fmt.Errorf("unknown entrypoint kind %v", re.EntrypointKind()) + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/reachabilitygraphbuilder.go b/vendor/github.com/authzed/spicedb/pkg/schema/reachabilitygraphbuilder.go new file mode 100644 index 0000000..e8f8608 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/reachabilitygraphbuilder.go @@ -0,0 +1,275 @@ +package schema + +import ( + "context" + "fmt" + + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/tuple" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +type reachabilityOption int + +const ( + reachabilityFull reachabilityOption = iota + reachabilityFirst +) + +func computeReachability(ctx context.Context, def *Definition, relationName string, option reachabilityOption) (*core.ReachabilityGraph, error) { + targetRelation, ok := def.relationMap[relationName] + if !ok { + return nil, fmt.Errorf("relation `%s` not found under type `%s` missing when computing reachability", relationName, def.nsDef.Name) + } + + if !def.HasTypeInformation(relationName) && targetRelation.GetUsersetRewrite() == nil { + return nil, fmt.Errorf("relation `%s` missing type information when computing reachability for namespace `%s`", relationName, def.nsDef.Name) + } + + graph := &core.ReachabilityGraph{ + EntrypointsBySubjectType: map[string]*core.ReachabilityEntrypoints{}, + EntrypointsBySubjectRelation: map[string]*core.ReachabilityEntrypoints{}, + } + + usersetRewrite := targetRelation.GetUsersetRewrite() + if usersetRewrite != nil { + return graph, computeRewriteReachability(ctx, graph, usersetRewrite, core.ReachabilityEntrypoint_DIRECT_OPERATION_RESULT, targetRelation, def, option) + } + + // If there is no userRewrite, then we have a relation and its entrypoints will all be + // relation entrypoints. + return graph, addSubjectLinks(graph, core.ReachabilityEntrypoint_DIRECT_OPERATION_RESULT, targetRelation, def) +} + +func computeRewriteReachability(ctx context.Context, graph *core.ReachabilityGraph, rewrite *core.UsersetRewrite, operationResultState core.ReachabilityEntrypoint_EntrypointResultStatus, targetRelation *core.Relation, def *Definition, option reachabilityOption) error { + switch rw := rewrite.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + return computeRewriteOpReachability(ctx, rw.Union.Child, operationResultState, graph, targetRelation, def, option) + + case *core.UsersetRewrite_Intersection: + // If optimized mode is set, only return the first child of the intersection. + if option == reachabilityFirst { + return computeRewriteOpReachability(ctx, rw.Intersection.Child[0:1], core.ReachabilityEntrypoint_REACHABLE_CONDITIONAL_RESULT, graph, targetRelation, def, option) + } + + return computeRewriteOpReachability(ctx, rw.Intersection.Child, core.ReachabilityEntrypoint_REACHABLE_CONDITIONAL_RESULT, graph, targetRelation, def, option) + + case *core.UsersetRewrite_Exclusion: + // If optimized mode is set, only return the first child of the exclusion. + if option == reachabilityFirst { + return computeRewriteOpReachability(ctx, rw.Exclusion.Child[0:1], core.ReachabilityEntrypoint_REACHABLE_CONDITIONAL_RESULT, graph, targetRelation, def, option) + } + + return computeRewriteOpReachability(ctx, rw.Exclusion.Child, core.ReachabilityEntrypoint_REACHABLE_CONDITIONAL_RESULT, graph, targetRelation, def, option) + + default: + return fmt.Errorf("unknown kind of userset rewrite in reachability computation: %T", rw) + } +} + +func computeRewriteOpReachability(ctx context.Context, children []*core.SetOperation_Child, operationResultState core.ReachabilityEntrypoint_EntrypointResultStatus, graph *core.ReachabilityGraph, targetRelation *core.Relation, def *Definition, option reachabilityOption) error { + rr := &core.RelationReference{ + Namespace: def.nsDef.Name, + Relation: targetRelation.Name, + } + + for _, childOneof := range children { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_XThis: + return fmt.Errorf("use of _this is unsupported; please rewrite your schema") + + case *core.SetOperation_Child_ComputedUserset: + // A computed userset adds an entrypoint indicating that the relation is rewritten. + err := addSubjectEntrypoint(graph, def.nsDef.Name, child.ComputedUserset.Relation, &core.ReachabilityEntrypoint{ + Kind: core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT, + TargetRelation: rr, + ComputedUsersetRelation: child.ComputedUserset.Relation, + ResultStatus: operationResultState, + }) + if err != nil { + return err + } + + case *core.SetOperation_Child_UsersetRewrite: + err := computeRewriteReachability(ctx, graph, child.UsersetRewrite, operationResultState, targetRelation, def, option) + if err != nil { + return err + } + + case *core.SetOperation_Child_TupleToUserset: + tuplesetRelation := child.TupleToUserset.Tupleset.Relation + computedUsersetRelation := child.TupleToUserset.ComputedUserset.Relation + if err := computeTTUReachability(ctx, graph, tuplesetRelation, computedUsersetRelation, operationResultState, rr, def); err != nil { + return err + } + + case *core.SetOperation_Child_FunctionedTupleToUserset: + tuplesetRelation := child.FunctionedTupleToUserset.Tupleset.Relation + computedUsersetRelation := child.FunctionedTupleToUserset.ComputedUserset.Relation + + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ANY: + // Nothing to change. + + case core.FunctionedTupleToUserset_FUNCTION_ALL: + // Mark as a conditional result. + operationResultState = core.ReachabilityEntrypoint_REACHABLE_CONDITIONAL_RESULT + + default: + return spiceerrors.MustBugf("unknown function type `%T` in reachability graph building", child.FunctionedTupleToUserset.Function) + } + + if err := computeTTUReachability(ctx, graph, tuplesetRelation, computedUsersetRelation, operationResultState, rr, def); err != nil { + return err + } + + case *core.SetOperation_Child_XNil: + // nil has no entrypoints. + return nil + + default: + return spiceerrors.MustBugf("unknown set operation child `%T` in reachability graph building", child) + } + } + + return nil +} + +func computeTTUReachability( + ctx context.Context, + graph *core.ReachabilityGraph, + tuplesetRelation string, + computedUsersetRelation string, + operationResultState core.ReachabilityEntrypoint_EntrypointResultStatus, + rr *core.RelationReference, + def *Definition, +) error { + directRelationTypes, err := def.AllowedDirectRelationsAndWildcards(tuplesetRelation) + if err != nil { + return err + } + + for _, allowedRelationType := range directRelationTypes { + // For each namespace allowed to be found on the right hand side of the + // tupleset relation, include the *computed userset* relation as an entrypoint. + // + // For example, given a schema: + // + // ``` + // definition user {} + // + // definition parent1 { + // relation somerel: user + // } + // + // definition parent2 { + // relation somerel: user + // } + // + // definition child { + // relation parent: parent1 | parent2 + // permission someperm = parent->somerel + // } + // ``` + // + // We will add an entrypoint for the arrow itself, keyed to the relation type + // included from the computed userset. + // + // Using the above example, this will add entrypoints for `parent1#somerel` + // and `parent2#somerel`, which are the subjects reached after resolving the + // right side of the arrow. + + // Check if the relation does exist on the allowed type, and only add the entrypoint if present. + relDef, err := def.ts.GetDefinition(ctx, allowedRelationType.Namespace) + if err != nil { + return err + } + + if relDef.HasRelation(computedUsersetRelation) { + err := addSubjectEntrypoint(graph, allowedRelationType.Namespace, computedUsersetRelation, &core.ReachabilityEntrypoint{ + Kind: core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT, + TargetRelation: rr, + ResultStatus: operationResultState, + ComputedUsersetRelation: computedUsersetRelation, + TuplesetRelation: tuplesetRelation, + }) + if err != nil { + return err + } + } + } + + return nil +} + +func addSubjectEntrypoint(graph *core.ReachabilityGraph, namespaceName string, relationName string, entrypoint *core.ReachabilityEntrypoint) error { + key := tuple.JoinRelRef(namespaceName, relationName) + if relationName == "" { + return spiceerrors.MustBugf("found empty relation name for subject entrypoint") + } + + if graph.EntrypointsBySubjectRelation[key] == nil { + graph.EntrypointsBySubjectRelation[key] = &core.ReachabilityEntrypoints{ + Entrypoints: []*core.ReachabilityEntrypoint{}, + SubjectRelation: &core.RelationReference{ + Namespace: namespaceName, + Relation: relationName, + }, + } + } + + graph.EntrypointsBySubjectRelation[key].Entrypoints = append( + graph.EntrypointsBySubjectRelation[key].Entrypoints, + entrypoint, + ) + + return nil +} + +func addSubjectLinks(graph *core.ReachabilityGraph, operationResultState core.ReachabilityEntrypoint_EntrypointResultStatus, relation *core.Relation, def *Definition) error { + typeInfo := relation.GetTypeInformation() + if typeInfo == nil { + return fmt.Errorf("missing type information for relation %s#%s", def.nsDef.Name, relation.Name) + } + + rr := &core.RelationReference{ + Namespace: def.nsDef.Name, + Relation: relation.Name, + } + + allowedDirectRelations := typeInfo.GetAllowedDirectRelations() + for _, directRelation := range allowedDirectRelations { + // If the allowed relation is a wildcard, add it as a subject *type* entrypoint, rather than + // a subject relation. + if directRelation.GetPublicWildcard() != nil { + if graph.EntrypointsBySubjectType[directRelation.Namespace] == nil { + graph.EntrypointsBySubjectType[directRelation.Namespace] = &core.ReachabilityEntrypoints{ + Entrypoints: []*core.ReachabilityEntrypoint{}, + SubjectType: directRelation.Namespace, + } + } + + graph.EntrypointsBySubjectType[directRelation.Namespace].Entrypoints = append( + graph.EntrypointsBySubjectType[directRelation.Namespace].Entrypoints, + &core.ReachabilityEntrypoint{ + Kind: core.ReachabilityEntrypoint_RELATION_ENTRYPOINT, + TargetRelation: rr, + ResultStatus: operationResultState, + }, + ) + continue + } + + err := addSubjectEntrypoint(graph, directRelation.Namespace, directRelation.GetRelation(), &core.ReachabilityEntrypoint{ + Kind: core.ReachabilityEntrypoint_RELATION_ENTRYPOINT, + TargetRelation: rr, + ResultStatus: operationResultState, + }) + if err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/resolver.go b/vendor/github.com/authzed/spicedb/pkg/schema/resolver.go new file mode 100644 index 0000000..4610886 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/resolver.go @@ -0,0 +1,114 @@ +package schema + +import ( + "context" + "fmt" + "slices" + + "github.com/authzed/spicedb/pkg/datastore" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" +) + +// Resolver is an interface defined for resolving referenced namespaces and caveats when constructing +// and validating a type system. +type Resolver interface { + // LookupDefinition lookups up a namespace definition, also returning whether it was pre-validated. + LookupDefinition(ctx context.Context, name string) (*core.NamespaceDefinition, bool, error) + + // LookupCaveat lookups up a caveat. + LookupCaveat(ctx context.Context, name string) (*Caveat, error) +} + +// ResolverForDatastoreReader returns a Resolver for a datastore reader. +func ResolverForDatastoreReader(ds datastore.Reader) *DatastoreResolver { + return &DatastoreResolver{ + ds: ds, + } +} + +// PredefinedElements are predefined namespaces and/or caveats to give to a resolver. +type PredefinedElements struct { + Definitions []*core.NamespaceDefinition + Caveats []*Caveat +} + +func (pe PredefinedElements) combineWith(other PredefinedElements) PredefinedElements { + return PredefinedElements{ + Definitions: append(slices.Clone(pe.Definitions), other.Definitions...), + Caveats: append(slices.Clone(pe.Caveats), other.Caveats...), + } +} + +// ResolverForPredefinedDefinitions returns a resolver for predefined namespaces and caveats. +func ResolverForPredefinedDefinitions(predefined PredefinedElements) Resolver { + return &DatastoreResolver{ + predefined: predefined, + } +} + +// ResolverForSchema returns a resolver for a schema. +func ResolverForSchema(schema compiler.CompiledSchema) Resolver { + return ResolverForPredefinedDefinitions( + PredefinedElements{ + Definitions: schema.ObjectDefinitions, + Caveats: schema.CaveatDefinitions, + }, + ) +} + +// DatastoreResolver is a resolver implementation for a datastore, to look up schema stored in the underlying storage. +type DatastoreResolver struct { + ds datastore.Reader + predefined PredefinedElements +} + +// LookupDefinition lookups up a namespace definition, also returning whether it was pre-validated. +func (r *DatastoreResolver) LookupDefinition(ctx context.Context, name string) (*core.NamespaceDefinition, bool, error) { + if len(r.predefined.Definitions) > 0 { + for _, def := range r.predefined.Definitions { + if def.Name == name { + return def, false, nil + } + } + } + + if r.ds == nil { + return nil, false, asTypeError(NewDefinitionNotFoundErr(name)) + } + + ns, _, err := r.ds.ReadNamespaceByName(ctx, name) + return ns, true, err +} + +// WithPredefinedElements adds elements (definitions and caveats) that will be used as a local overlay +// for the datastore, often for validation. +func (r *DatastoreResolver) WithPredefinedElements(predefined PredefinedElements) Resolver { + return &DatastoreResolver{ + ds: r.ds, + predefined: predefined.combineWith(r.predefined), + } +} + +// LookupCaveat lookups up a caveat. +func (r *DatastoreResolver) LookupCaveat(ctx context.Context, name string) (*Caveat, error) { + if len(r.predefined.Caveats) > 0 { + for _, caveat := range r.predefined.Caveats { + if caveat.Name == name { + return caveat, nil + } + } + } + + if r.ds == nil { + return nil, asTypeError(NewCaveatNotFoundErr(name)) + } + + cr, ok := r.ds.(datastore.CaveatReader) + if !ok { + return nil, fmt.Errorf("caveats are not supported on this datastore type") + } + + caveatDef, _, err := cr.ReadCaveatByName(ctx, name) + return caveatDef, err +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/typesystem.go b/vendor/github.com/authzed/spicedb/pkg/schema/typesystem.go new file mode 100644 index 0000000..cc989e9 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/typesystem.go @@ -0,0 +1,66 @@ +package schema + +import ( + "context" + "sync" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" +) + +type ( + // Caveat is an alias for a core.CaveatDefinition proto + Caveat = core.CaveatDefinition + // Relation is an alias for a core.Relation proto + Relation = core.Relation +) + +// TypeSystem is a cache and view into an entire combined schema of type definitions and caveats. +// It also provides accessors to build reachability graphs for the underlying types. +type TypeSystem struct { + sync.Mutex + validatedDefinitions map[string]*ValidatedDefinition // GUARDED_BY(Mutex) + resolver Resolver + wildcardCheckCache map[string]*WildcardTypeReference +} + +// NewTypeSystem builds a TypeSystem object from a resolver, which can look up the definitions. +func NewTypeSystem(resolver Resolver) *TypeSystem { + return &TypeSystem{ + validatedDefinitions: make(map[string]*ValidatedDefinition), + resolver: resolver, + wildcardCheckCache: nil, + } +} + +// GetDefinition looks up and returns a definition struct. +func (ts *TypeSystem) GetDefinition(ctx context.Context, definition string) (*Definition, error) { + v, _, err := ts.getDefinition(ctx, definition) + return v, err +} + +// getDefinition is an internal helper for GetDefinition and GetValidatedDefinition +func (ts *TypeSystem) getDefinition(ctx context.Context, definition string) (*Definition, bool, error) { + ts.Lock() + v, ok := ts.validatedDefinitions[definition] + ts.Unlock() + if ok { + return v.Definition, true, nil + } + + ns, prevalidated, err := ts.resolver.LookupDefinition(ctx, definition) + if err != nil { + return nil, false, err + } + d, err := NewDefinition(ts, ns) + if err != nil { + return nil, false, err + } + if prevalidated { + ts.Lock() + if _, ok := ts.validatedDefinitions[definition]; !ok { + ts.validatedDefinitions[definition] = &ValidatedDefinition{Definition: d} + } + ts.Unlock() + } + return d, prevalidated, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schema/typesystem_validation.go b/vendor/github.com/authzed/spicedb/pkg/schema/typesystem_validation.go new file mode 100644 index 0000000..fa9315e --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schema/typesystem_validation.go @@ -0,0 +1,361 @@ +package schema + +import ( + "context" + "fmt" + + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/graph" + nspkg "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + iv1 "github.com/authzed/spicedb/pkg/proto/impl/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +// GetValidatedDefinition runs validation on the type system for the definition to ensure it is consistent. +func (ts *TypeSystem) GetValidatedDefinition(ctx context.Context, definition string) (*ValidatedDefinition, error) { + def, validated, err := ts.getDefinition(ctx, definition) + if err != nil { + return nil, err + } + if validated { + return &ValidatedDefinition{Definition: def}, nil + } + vdef, err := def.Validate(ctx) + if err != nil { + return nil, err + } + ts.Lock() + defer ts.Unlock() + if _, ok := ts.validatedDefinitions[definition]; !ok { + ts.validatedDefinitions[definition] = vdef + } + return vdef, nil +} + +func (def *Definition) Validate(ctx context.Context) (*ValidatedDefinition, error) { + for _, relation := range def.relationMap { + relation := relation + + // Validate the usersets's. + usersetRewrite := relation.GetUsersetRewrite() + rerr, err := graph.WalkRewrite(usersetRewrite, func(childOneof *core.SetOperation_Child) (any, error) { + switch child := childOneof.ChildType.(type) { + case *core.SetOperation_Child_ComputedUserset: + relationName := child.ComputedUserset.GetRelation() + _, ok := def.relationMap[relationName] + if !ok { + return NewTypeWithSourceError( + NewRelationNotFoundErr(def.nsDef.Name, relationName), + childOneof, + relationName, + ), nil + } + + case *core.SetOperation_Child_TupleToUserset: + ttu := child.TupleToUserset + if ttu == nil { + return nil, nil + } + + tupleset := ttu.GetTupleset() + if tupleset == nil { + return nil, nil + } + + relationName := tupleset.GetRelation() + found, ok := def.relationMap[relationName] + if !ok { + return NewTypeWithSourceError( + NewRelationNotFoundErr(def.nsDef.Name, relationName), + childOneof, + relationName, + ), nil + } + + if nspkg.GetRelationKind(found) == iv1.RelationMetadata_PERMISSION { + return NewTypeWithSourceError( + NewPermissionUsedOnLeftOfArrowErr(def.nsDef.Name, relation.Name, relationName), + childOneof, relationName), nil + } + + // Ensure the tupleset relation doesn't itself import wildcard. + referencedWildcard, err := def.TypeSystem().referencesWildcardType(ctx, def, relationName) + if err != nil { + return err, nil + } + + if referencedWildcard != nil { + return NewTypeWithSourceError( + NewWildcardUsedInArrowErr( + def.nsDef.Name, + relation.Name, + relationName, + referencedWildcard.WildcardType.GetNamespace(), + tuple.StringCoreRR(referencedWildcard.ReferencingRelation), + ), + childOneof, relationName, + ), nil + } + + case *core.SetOperation_Child_FunctionedTupleToUserset: + ttu := child.FunctionedTupleToUserset + if ttu == nil { + return nil, nil + } + + tupleset := ttu.GetTupleset() + if tupleset == nil { + return nil, nil + } + + relationName := tupleset.GetRelation() + found, ok := def.relationMap[relationName] + if !ok { + return NewTypeWithSourceError( + NewRelationNotFoundErr(def.nsDef.Name, relationName), + childOneof, + relationName, + ), nil + } + + if nspkg.GetRelationKind(found) == iv1.RelationMetadata_PERMISSION { + return NewTypeWithSourceError( + NewPermissionUsedOnLeftOfArrowErr(def.nsDef.Name, relation.Name, relationName), + childOneof, relationName), nil + } + + // Ensure the tupleset relation doesn't itself import wildcard. + referencedWildcard, err := def.TypeSystem().referencesWildcardType(ctx, def, relationName) + if err != nil { + return err, nil + } + + if referencedWildcard != nil { + return NewTypeWithSourceError( + NewWildcardUsedInArrowErr( + def.nsDef.Name, + relation.Name, + relationName, + referencedWildcard.WildcardType.GetNamespace(), + tuple.StringCoreRR(referencedWildcard.ReferencingRelation), + ), + childOneof, relationName, + ), nil + } + } + return nil, nil + }) + if rerr != nil { + return nil, asTypeError(rerr.(error)) + } + if err != nil { + return nil, err + } + + // Validate type information. + typeInfo := relation.TypeInformation + if typeInfo == nil { + continue + } + + allowedDirectRelations := typeInfo.GetAllowedDirectRelations() + + // Check for a _this or the lack of a userset_rewrite. If either is found, + // then the allowed list must have at least one type. + hasThis, err := graph.HasThis(usersetRewrite) + if err != nil { + return nil, err + } + + if usersetRewrite == nil || hasThis { + if len(allowedDirectRelations) == 0 { + return nil, NewTypeWithSourceError( + NewMissingAllowedRelationsErr(def.nsDef.Name, relation.Name), + relation, relation.Name, + ) + } + } else { + if len(allowedDirectRelations) != 0 { + // NOTE: This is a legacy error and should never really occur with schema. + return nil, NewTypeWithSourceError( + fmt.Errorf("direct relations are not allowed under relation `%s`", relation.Name), + relation, relation.Name) + } + } + + // Allowed relations verification: + // 1) that all allowed relations are not this very relation + // 2) that they exist within the referenced namespace + // 3) that they are not duplicated in any way + // 4) that if they have a caveat reference, the caveat is valid + encountered := mapz.NewSet[string]() + + for _, allowedRelation := range allowedDirectRelations { + source := SourceForAllowedRelation(allowedRelation) + if !encountered.Add(source) { + return nil, NewTypeWithSourceError( + NewDuplicateAllowedRelationErr(def.nsDef.Name, relation.Name, source), + allowedRelation, + source, + ) + } + + // Check the namespace. + if allowedRelation.GetNamespace() == def.nsDef.Name { + if allowedRelation.GetPublicWildcard() == nil && allowedRelation.GetRelation() != tuple.Ellipsis { + _, ok := def.relationMap[allowedRelation.GetRelation()] + if !ok { + return nil, NewTypeWithSourceError( + NewRelationNotFoundErr(allowedRelation.GetNamespace(), allowedRelation.GetRelation()), + allowedRelation, + allowedRelation.GetRelation(), + ) + } + } + } else { + subjectTS, err := def.TypeSystem().GetDefinition(ctx, allowedRelation.GetNamespace()) + if err != nil { + return nil, NewTypeWithSourceError( + fmt.Errorf("could not lookup definition `%s` for relation `%s`: %w", allowedRelation.GetNamespace(), relation.Name, err), + allowedRelation, + allowedRelation.GetNamespace(), + ) + } + + // Check for relations. + if allowedRelation.GetPublicWildcard() == nil && allowedRelation.GetRelation() != tuple.Ellipsis { + // Ensure the relation exists. + ok := subjectTS.HasRelation(allowedRelation.GetRelation()) + if !ok { + return nil, NewTypeWithSourceError( + NewRelationNotFoundErr(allowedRelation.GetNamespace(), allowedRelation.GetRelation()), + allowedRelation, + allowedRelation.GetRelation(), + ) + } + + // Ensure the relation doesn't itself import wildcard. + referencedWildcard, err := def.TypeSystem().referencesWildcardType(ctx, subjectTS, allowedRelation.GetRelation()) + if err != nil { + return nil, err + } + + if referencedWildcard != nil { + return nil, NewTypeWithSourceError( + NewTransitiveWildcardErr( + def.nsDef.Name, + relation.GetName(), + allowedRelation.Namespace, + allowedRelation.GetRelation(), + referencedWildcard.WildcardType.GetNamespace(), + tuple.StringCoreRR(referencedWildcard.ReferencingRelation), + ), + allowedRelation, + tuple.JoinRelRef(allowedRelation.GetNamespace(), allowedRelation.GetRelation()), + ) + } + } + } + + // Check the caveat, if any. + if allowedRelation.GetRequiredCaveat() != nil { + _, err := def.TypeSystem().resolver.LookupCaveat(ctx, allowedRelation.GetRequiredCaveat().CaveatName) + if err != nil { + return nil, NewTypeWithSourceError( + fmt.Errorf("could not lookup caveat `%s` for relation `%s`: %w", allowedRelation.GetRequiredCaveat().CaveatName, relation.Name, err), + allowedRelation, + source, + ) + } + } + } + } + + return &ValidatedDefinition{def}, nil +} + +// referencesWildcardType returns true if the relation references a wildcard type, either directly or via +// another relation. +func (ts *TypeSystem) referencesWildcardType(ctx context.Context, def *Definition, relationName string) (*WildcardTypeReference, error) { + return ts.referencesWildcardTypeWithEncountered(ctx, def, relationName, map[string]bool{}) +} + +func (ts *TypeSystem) referencesWildcardTypeWithEncountered(ctx context.Context, def *Definition, relationName string, encountered map[string]bool) (*WildcardTypeReference, error) { + if ts.wildcardCheckCache == nil { + ts.wildcardCheckCache = make(map[string]*WildcardTypeReference, 1) + } + + cached, isCached := ts.wildcardCheckCache[relationName] + if isCached { + return cached, nil + } + + computed, err := ts.computeReferencesWildcardType(ctx, def, relationName, encountered) + if err != nil { + return nil, err + } + + ts.wildcardCheckCache[relationName] = computed + return computed, nil +} + +func (ts *TypeSystem) computeReferencesWildcardType(ctx context.Context, def *Definition, relationName string, encountered map[string]bool) (*WildcardTypeReference, error) { + relString := tuple.JoinRelRef(def.nsDef.Name, relationName) + if _, ok := encountered[relString]; ok { + return nil, nil + } + encountered[relString] = true + + allowedRels, err := def.AllowedDirectRelationsAndWildcards(relationName) + if err != nil { + return nil, asTypeError(err) + } + + for _, allowedRelation := range allowedRels { + if allowedRelation.GetPublicWildcard() != nil { + return &WildcardTypeReference{ + ReferencingRelation: &core.RelationReference{ + Namespace: def.nsDef.Name, + Relation: relationName, + }, + WildcardType: allowedRelation, + }, nil + } + + if allowedRelation.GetRelation() != tuple.Ellipsis { + if allowedRelation.GetNamespace() == def.nsDef.Name { + found, err := ts.referencesWildcardTypeWithEncountered(ctx, def, allowedRelation.GetRelation(), encountered) + if err != nil { + return nil, asTypeError(err) + } + + if found != nil { + return found, nil + } + continue + } + + subjectTS, err := ts.GetDefinition(ctx, allowedRelation.GetNamespace()) + if err != nil { + return nil, asTypeError(err) + } + + found, err := ts.referencesWildcardTypeWithEncountered(ctx, subjectTS, allowedRelation.GetRelation(), encountered) + if err != nil { + return nil, asTypeError(err) + } + + if found != nil { + return found, nil + } + } + } + + return nil, nil +} + +// ValidatedDefinition is a typesafe reference to a definition that has been validated. +type ValidatedDefinition struct { + *Definition +} |
