diff options
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/development/resolver.go')
| -rw-r--r-- | vendor/github.com/authzed/spicedb/pkg/development/resolver.go | 426 |
1 files changed, 426 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/development/resolver.go b/vendor/github.com/authzed/spicedb/pkg/development/resolver.go new file mode 100644 index 0000000..ddb8a67 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/development/resolver.go @@ -0,0 +1,426 @@ +package development + +import ( + "context" + "fmt" + "strings" + + "github.com/ccoveille/go-safecast" + + log "github.com/authzed/spicedb/internal/logging" + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +// ReferenceType is the type of reference. +type ReferenceType int + +const ( + ReferenceTypeUnknown ReferenceType = iota + ReferenceTypeDefinition + ReferenceTypeCaveat + ReferenceTypeRelation + ReferenceTypePermission + ReferenceTypeCaveatParameter +) + +// SchemaReference represents a reference to a schema node. +type SchemaReference struct { + // Source is the source of the reference. + Source input.Source + + // Position is the position of the reference in the source. + Position input.Position + + // Text is the text of the reference. + Text string + + // ReferenceType is the type of reference. + ReferenceType ReferenceType + + // ReferenceMarkdown is the markdown representation of the reference. + ReferenceMarkdown string + + // TargetSource is the source of the target node, if any. + TargetSource *input.Source + + // TargetPosition is the position of the target node, if any. + TargetPosition *input.Position + + // TargetSourceCode is the source code representation of the target, if any. + TargetSourceCode string + + // TargetNamePositionOffset is the offset from the target position from where the + // *name* of the target is found. + TargetNamePositionOffset int +} + +// Resolver resolves references to schema nodes from source positions. +type Resolver struct { + schema *compiler.CompiledSchema + typeSystem *schema.TypeSystem +} + +// NewResolver creates a new resolver for the given schema. +func NewResolver(compiledSchema *compiler.CompiledSchema) (*Resolver, error) { + resolver := schema.ResolverForCompiledSchema(*compiledSchema) + ts := schema.NewTypeSystem(resolver) + return &Resolver{schema: compiledSchema, typeSystem: ts}, nil +} + +// ReferenceAtPosition returns the reference to the schema node at the given position in the source, if any. +func (r *Resolver) ReferenceAtPosition(source input.Source, position input.Position) (*SchemaReference, error) { + nodeChain, err := compiler.PositionToAstNodeChain(r.schema, source, position) + if err != nil { + return nil, err + } + + if nodeChain == nil { + return nil, nil + } + + relationReference := func(relation *core.Relation, def *schema.Definition) (*SchemaReference, error) { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToInt(relation.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToInt(relation.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + relationPosition := input.Position{ + LineNumber: lineNumber, + ColumnPosition: columnPosition, + } + + targetSourceCode, err := generator.GenerateRelationSource(relation, caveattypes.Default.TypeSet) + if err != nil { + return nil, err + } + + if def.IsPermission(relation.Name) { + return &SchemaReference{ + Source: source, + Position: position, + Text: relation.Name, + + ReferenceType: ReferenceTypePermission, + ReferenceMarkdown: fmt.Sprintf("permission %s", relation.Name), + + TargetSource: &source, + TargetPosition: &relationPosition, + TargetSourceCode: targetSourceCode, + TargetNamePositionOffset: len("permission "), + }, nil + } + + return &SchemaReference{ + Source: source, + Position: position, + Text: relation.Name, + + ReferenceType: ReferenceTypeRelation, + ReferenceMarkdown: fmt.Sprintf("relation %s", relation.Name), + + TargetSource: &source, + TargetPosition: &relationPosition, + TargetSourceCode: targetSourceCode, + TargetNamePositionOffset: len("relation "), + }, nil + } + + // Type reference. + if ts, relation, ok := r.typeReferenceChain(nodeChain); ok { + if relation != nil { + return relationReference(relation, ts) + } + + def := ts.Namespace() + + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToInt(def.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToInt(def.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + + defPosition := input.Position{ + LineNumber: lineNumber, + ColumnPosition: columnPosition, + } + + docComment := "" + comments := namespace.GetComments(def.Metadata) + if len(comments) > 0 { + docComment = strings.Join(comments, "\n") + "\n" + } + + targetSourceCode := fmt.Sprintf("%sdefinition %s {\n\t// ...\n}", docComment, def.Name) + if len(def.Relation) == 0 { + targetSourceCode = fmt.Sprintf("%sdefinition %s {}", docComment, def.Name) + } + + return &SchemaReference{ + Source: source, + Position: position, + Text: def.Name, + + ReferenceType: ReferenceTypeDefinition, + ReferenceMarkdown: fmt.Sprintf("definition %s", def.Name), + + TargetSource: &source, + TargetPosition: &defPosition, + TargetSourceCode: targetSourceCode, + TargetNamePositionOffset: len("definition "), + }, nil + } + + // Caveat Type reference. + if caveatDef, ok := r.caveatTypeReferenceChain(nodeChain); ok { + // NOTE: zeroes are fine here to mean "unknown" + lineNumber, err := safecast.ToInt(caveatDef.SourcePosition.ZeroIndexedLineNumber) + if err != nil { + log.Err(err).Msg("could not cast lineNumber to uint32") + } + columnPosition, err := safecast.ToInt(caveatDef.SourcePosition.ZeroIndexedColumnPosition) + if err != nil { + log.Err(err).Msg("could not cast columnPosition to uint32") + } + + defPosition := input.Position{ + LineNumber: lineNumber, + ColumnPosition: columnPosition, + } + + var caveatSourceCode strings.Builder + caveatSourceCode.WriteString(fmt.Sprintf("caveat %s(", caveatDef.Name)) + index := 0 + for paramName, paramType := range caveatDef.ParameterTypes { + if index > 0 { + caveatSourceCode.WriteString(", ") + } + + caveatSourceCode.WriteString(fmt.Sprintf("%s %s", paramName, caveats.ParameterTypeString(paramType))) + index++ + } + caveatSourceCode.WriteString(") {\n\t// ...\n}") + + return &SchemaReference{ + Source: source, + Position: position, + Text: caveatDef.Name, + + ReferenceType: ReferenceTypeCaveat, + ReferenceMarkdown: fmt.Sprintf("caveat %s", caveatDef.Name), + + TargetSource: &source, + TargetPosition: &defPosition, + TargetSourceCode: caveatSourceCode.String(), + TargetNamePositionOffset: len("caveat "), + }, nil + } + + // Relation reference. + if relation, ts, ok := r.relationReferenceChain(nodeChain); ok { + return relationReference(relation, ts) + } + + // Caveat parameter used in expression. + if caveatParamName, caveatDef, ok := r.caveatParamChain(nodeChain, source, position); ok { + targetSourceCode := fmt.Sprintf("%s %s", caveatParamName, caveats.ParameterTypeString(caveatDef.ParameterTypes[caveatParamName])) + + return &SchemaReference{ + Source: source, + Position: position, + Text: caveatParamName, + + ReferenceType: ReferenceTypeCaveatParameter, + ReferenceMarkdown: targetSourceCode, + + TargetSource: &source, + TargetSourceCode: targetSourceCode, + }, nil + } + + return nil, nil +} + +func (r *Resolver) lookupCaveat(caveatName string) (*core.CaveatDefinition, bool) { + for _, caveatDef := range r.schema.CaveatDefinitions { + if caveatDef.Name == caveatName { + return caveatDef, true + } + } + + return nil, false +} + +func (r *Resolver) lookupRelation(defName, relationName string) (*core.Relation, *schema.Definition, bool) { + ts, err := r.typeSystem.GetDefinition(context.Background(), defName) + if err != nil { + return nil, nil, false + } + + rel, ok := ts.GetRelation(relationName) + if !ok { + return nil, nil, false + } + + return rel, ts, true +} + +func (r *Resolver) caveatParamChain(nodeChain *compiler.NodeChain, source input.Source, position input.Position) (string, *core.CaveatDefinition, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypeCaveatExpression) { + return "", nil, false + } + + caveatDefNode := nodeChain.FindNodeOfType(dslshape.NodeTypeCaveatDefinition) + if caveatDefNode == nil { + return "", nil, false + } + + caveatName, err := caveatDefNode.GetString(dslshape.NodeCaveatDefinitionPredicateName) + if err != nil { + return "", nil, false + } + + caveatDef, ok := r.lookupCaveat(caveatName) + if !ok { + return "", nil, false + } + + runePosition, err := r.schema.SourcePositionToRunePosition(source, position) + if err != nil { + return "", nil, false + } + + exprRunePosition, err := nodeChain.Head().GetInt(dslshape.NodePredicateStartRune) + if err != nil { + return "", nil, false + } + + if exprRunePosition > runePosition { + return "", nil, false + } + + relationRunePosition := runePosition - exprRunePosition + + caveatExpr, err := nodeChain.Head().GetString(dslshape.NodeCaveatExpressionPredicateExpression) + if err != nil { + return "", nil, false + } + + // Split the expression into tokens and find the associated token. + tokens := strings.FieldsFunc(caveatExpr, splitCELToken) + currentIndex := 0 + for _, token := range tokens { + if currentIndex <= relationRunePosition && currentIndex+len(token) >= relationRunePosition { + if _, ok := caveatDef.ParameterTypes[token]; ok { + return token, caveatDef, true + } + } + } + + return "", caveatDef, true +} + +func splitCELToken(r rune) bool { + return r == ' ' || r == '(' || r == ')' || r == '.' || r == ',' || r == '[' || r == ']' || r == '{' || r == '}' || r == ':' || r == '=' +} + +func (r *Resolver) caveatTypeReferenceChain(nodeChain *compiler.NodeChain) (*core.CaveatDefinition, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypeCaveatReference) { + return nil, false + } + + caveatName, err := nodeChain.Head().GetString(dslshape.NodeCaveatPredicateCaveat) + if err != nil { + return nil, false + } + + return r.lookupCaveat(caveatName) +} + +func (r *Resolver) typeReferenceChain(nodeChain *compiler.NodeChain) (*schema.Definition, *core.Relation, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypeSpecificTypeReference) { + return nil, nil, false + } + + defName, err := nodeChain.Head().GetString(dslshape.NodeSpecificReferencePredicateType) + if err != nil { + return nil, nil, false + } + + def, err := r.typeSystem.GetDefinition(context.Background(), defName) + if err != nil { + return nil, nil, false + } + + relationName, err := nodeChain.Head().GetString(dslshape.NodeSpecificReferencePredicateRelation) + if err != nil { + return def, nil, true + } + + startingRune, err := nodeChain.Head().GetInt(dslshape.NodePredicateStartRune) + if err != nil { + return def, nil, true + } + + // If hover over the definition name, return the definition. + if nodeChain.ForRunePosition() < startingRune+len(defName) { + return def, nil, true + } + + relation, ok := def.GetRelation(relationName) + if !ok { + return nil, nil, false + } + + return def, relation, true +} + +func (r *Resolver) relationReferenceChain(nodeChain *compiler.NodeChain) (*core.Relation, *schema.Definition, bool) { + if !nodeChain.HasHeadType(dslshape.NodeTypeIdentifier) { + return nil, nil, false + } + + if arrowExpr := nodeChain.FindNodeOfType(dslshape.NodeTypeArrowExpression); arrowExpr != nil { + // Ensure this on the left side of the arrow. + rightExpr, err := arrowExpr.Lookup(dslshape.NodeExpressionPredicateRightExpr) + if err != nil { + return nil, nil, false + } + + if rightExpr == nodeChain.Head() { + return nil, nil, false + } + } + + relationName, err := nodeChain.Head().GetString(dslshape.NodeIdentiferPredicateValue) + if err != nil { + return nil, nil, false + } + + parentDefNode := nodeChain.FindNodeOfType(dslshape.NodeTypeDefinition) + if parentDefNode == nil { + return nil, nil, false + } + + defName, err := parentDefNode.GetString(dslshape.NodeDefinitionPredicateName) + if err != nil { + return nil, nil, false + } + + return r.lookupRelation(defName, relationName) +} |
