summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go')
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go142
1 files changed, 142 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go
new file mode 100644
index 0000000..7c8e7c7
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go
@@ -0,0 +1,142 @@
+package compiler
+
+import (
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
+// DSLNode is a node in the DSL AST.
+type DSLNode interface {
+ GetType() dslshape.NodeType
+ GetString(predicateName string) (string, error)
+ GetInt(predicateName string) (int, error)
+ Lookup(predicateName string) (DSLNode, error)
+}
+
+// NodeChain is a chain of nodes in the DSL AST.
+type NodeChain struct {
+ nodes []DSLNode
+ runePosition int
+}
+
+// Head returns the head node of the chain.
+func (nc *NodeChain) Head() DSLNode {
+ return nc.nodes[0]
+}
+
+// HasHeadType returns true if the head node of the chain is of the given type.
+func (nc *NodeChain) HasHeadType(nodeType dslshape.NodeType) bool {
+ return nc.nodes[0].GetType() == nodeType
+}
+
+// ForRunePosition returns the rune position of the chain.
+func (nc *NodeChain) ForRunePosition() int {
+ return nc.runePosition
+}
+
+// FindNodeOfType returns the first node of the given type in the chain, if any.
+func (nc *NodeChain) FindNodeOfType(nodeType dslshape.NodeType) DSLNode {
+ for _, node := range nc.nodes {
+ if node.GetType() == nodeType {
+ return node
+ }
+ }
+
+ return nil
+}
+
+func (nc *NodeChain) String() string {
+ var out string
+ for _, node := range nc.nodes {
+ out += node.GetType().String() + " "
+ }
+ return out
+}
+
+// PositionToAstNodeChain returns the AST node, and its parents (if any), found at the given position in the source, if any.
+func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, position input.Position) (*NodeChain, error) {
+ rootSource, err := schema.rootNode.GetString(dslshape.NodePredicateSource)
+ if err != nil {
+ return nil, err
+ }
+
+ if rootSource != string(source) {
+ return nil, nil
+ }
+
+ // Map the position to a file rune.
+ runePosition, err := schema.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source)
+ if err != nil {
+ return nil, err
+ }
+
+ // Find the node at the rune position.
+ found, err := runePositionToAstNodeChain(schema.rootNode, runePosition)
+ if err != nil {
+ return nil, err
+ }
+
+ if found == nil {
+ return nil, nil
+ }
+
+ return &NodeChain{nodes: found, runePosition: runePosition}, nil
+}
+
+func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, error) {
+ if !node.Has(dslshape.NodePredicateStartRune) {
+ return nil, nil
+ }
+
+ startRune, err := node.GetInt(dslshape.NodePredicateStartRune)
+ if err != nil {
+ return nil, err
+ }
+
+ endRune, err := node.GetInt(dslshape.NodePredicateEndRune)
+ if err != nil {
+ return nil, err
+ }
+
+ if runePosition < startRune || runePosition > endRune {
+ return nil, nil
+ }
+
+ for _, child := range node.AllSubNodes() {
+ childChain, err := runePositionToAstNodeChain(child, runePosition)
+ if err != nil {
+ return nil, err
+ }
+
+ if childChain != nil {
+ return append(childChain, wrapper{node}), nil
+ }
+ }
+
+ return []DSLNode{wrapper{node}}, nil
+}
+
+type wrapper struct {
+ node *dslNode
+}
+
+func (w wrapper) GetType() dslshape.NodeType {
+ return w.node.GetType()
+}
+
+func (w wrapper) GetString(predicateName string) (string, error) {
+ return w.node.GetString(predicateName)
+}
+
+func (w wrapper) GetInt(predicateName string) (int, error) {
+ return w.node.GetInt(predicateName)
+}
+
+func (w wrapper) Lookup(predicateName string) (DSLNode, error) {
+ found, err := w.node.Lookup(predicateName)
+ if err != nil {
+ return nil, err
+ }
+
+ return wrapper{found}, nil
+}