summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/pkg/schemadsl
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/schemadsl')
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go194
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go142
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go53
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go180
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go32
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go714
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/dslshape.go209
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/zz_generated.nodetype_string.go43
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator.go430
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator_impl.go83
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/input/inputsource.go224
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/input/sourcepositionmapper.go95
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flaggablelexer.go59
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flags.go26
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex.go231
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex_def.go351
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/tokentype_string.go64
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/nodestack.go35
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser.go658
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser_impl.go367
22 files changed, 4194 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go
new file mode 100644
index 0000000..d1e96ec
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go
@@ -0,0 +1,194 @@
+package compiler
+
+import (
+ "errors"
+ "fmt"
+
+ "google.golang.org/protobuf/proto"
+ "k8s.io/utils/strings/slices"
+
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+ "github.com/authzed/spicedb/pkg/schemadsl/parser"
+)
+
// InputSchema defines the input for a Compile.
type InputSchema struct {
	// Source is the source of the schema being compiled.
	Source input.Source

	// SchemaString is the contents being compiled.
	SchemaString string
}

// SchemaDefinition represents an object or caveat definition in a schema.
type SchemaDefinition interface {
	proto.Message

	// GetName returns the fully-qualified name of the definition.
	GetName() string
}

// CompiledSchema is the result of compiling a schema when there are no errors.
type CompiledSchema struct {
	// ObjectDefinitions holds the object definitions in the schema.
	ObjectDefinitions []*core.NamespaceDefinition

	// CaveatDefinitions holds the caveat definitions in the schema.
	CaveatDefinitions []*core.CaveatDefinition

	// OrderedDefinitions holds the object and caveat definitions in the schema, in the
	// order in which they were found.
	OrderedDefinitions []SchemaDefinition

	// rootNode is the root of the parsed AST, retained for position lookups.
	rootNode *dslNode

	// mapper converts between rune positions and line/column positions.
	mapper input.PositionMapper
}
+
+// SourcePositionToRunePosition converts a source position to a rune position.
+func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, position input.Position) (int, error) {
+ return cs.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source)
+}
+
// config holds the settings accumulated from Compile's option functions.
type config struct {
	// skipValidation disables proto validation of the translated definitions.
	skipValidation bool
	// objectTypePrefix is the default prefix applied to unprefixed definition
	// names; nil means every definition must carry an explicit prefix.
	objectTypePrefix *string
	// allowedFlags lists the `use` flags the schema is permitted to enable.
	allowedFlags []string
	// caveatTypeSet, when non-nil, overrides the default caveat type set.
	caveatTypeSet *caveattypes.TypeSet
}
+
+func SkipValidation() Option { return func(cfg *config) { cfg.skipValidation = true } }
+
+func ObjectTypePrefix(prefix string) ObjectPrefixOption {
+ return func(cfg *config) { cfg.objectTypePrefix = &prefix }
+}
+
+func RequirePrefixedObjectType() ObjectPrefixOption {
+ return func(cfg *config) { cfg.objectTypePrefix = nil }
+}
+
+func AllowUnprefixedObjectType() ObjectPrefixOption {
+ return func(cfg *config) { cfg.objectTypePrefix = new(string) }
+}
+
+func CaveatTypeSet(cts *caveattypes.TypeSet) Option {
+ return func(cfg *config) { cfg.caveatTypeSet = cts }
+}
+
+const expirationFlag = "expiration"
+
+func DisallowExpirationFlag() Option {
+ return func(cfg *config) {
+ cfg.allowedFlags = slices.Filter([]string{}, cfg.allowedFlags, func(s string) bool {
+ return s != expirationFlag
+ })
+ }
+}
+
+type Option func(*config)
+
+type ObjectPrefixOption func(*config)
+
+// Compile compilers the input schema into a set of namespace definition protos.
+func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
+ cfg := &config{
+ allowedFlags: make([]string, 0, 1),
+ }
+
+ // Enable `expiration` flag by default.
+ cfg.allowedFlags = append(cfg.allowedFlags, expirationFlag)
+
+ prefix(cfg) // required option
+
+ for _, fn := range opts {
+ fn(cfg)
+ }
+
+ mapper := newPositionMapper(schema)
+ root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode)
+ errs := root.FindAll(dslshape.NodeTypeError)
+ if len(errs) > 0 {
+ err := errorNodeToError(errs[0], mapper)
+ return nil, err
+ }
+
+ cts := caveattypes.TypeSetOrDefault(cfg.caveatTypeSet)
+ compiled, err := translate(&translationContext{
+ objectTypePrefix: cfg.objectTypePrefix,
+ mapper: mapper,
+ schemaString: schema.SchemaString,
+ skipValidate: cfg.skipValidation,
+ allowedFlags: cfg.allowedFlags,
+ caveatTypeSet: cts,
+ }, root)
+ if err != nil {
+ var withNodeError withNodeError
+ if errors.As(err, &withNodeError) {
+ err = toContextError(withNodeError.error.Error(), withNodeError.errorSourceCode, withNodeError.node, mapper)
+ }
+
+ return nil, err
+ }
+
+ return compiled, nil
+}
+
+func errorNodeToError(node *dslNode, mapper input.PositionMapper) error {
+ if node.GetType() != dslshape.NodeTypeError {
+ return fmt.Errorf("given none error node")
+ }
+
+ errMessage, err := node.GetString(dslshape.NodePredicateErrorMessage)
+ if err != nil {
+ return fmt.Errorf("could not get error message for error node: %w", err)
+ }
+
+ errorSourceCode := ""
+ if node.Has(dslshape.NodePredicateErrorSource) {
+ es, err := node.GetString(dslshape.NodePredicateErrorSource)
+ if err != nil {
+ return fmt.Errorf("could not get error source for error node: %w", err)
+ }
+
+ errorSourceCode = es
+ }
+
+ return toContextError(errMessage, errorSourceCode, node, mapper)
+}
+
// toContextError wraps an error message with its source range, source name
// and offending snippet, producing a WithContextError suitable for callers.
func toContextError(errMessage string, errorSourceCode string, node *dslNode, mapper input.PositionMapper) error {
	sourceRange, err := node.Range(mapper)
	if err != nil {
		return fmt.Errorf("could not get range for error node: %w", err)
	}

	formattedRange, err := formatRange(sourceRange)
	if err != nil {
		return err
	}

	source, err := node.GetString(dslshape.NodePredicateSource)
	if err != nil {
		return fmt.Errorf("missing source for node: %w", err)
	}

	return WithContextError{
		BaseCompilerError: BaseCompilerError{
			error:       fmt.Errorf("parse error in %s: %s", formattedRange, errMessage),
			BaseMessage: errMessage,
		},
		SourceRange:     sourceRange,
		Source:          input.Source(source),
		ErrorSourceCode: errorSourceCode,
	}
}
+
+func formatRange(rnge input.SourceRange) (string, error) {
+ startLine, startCol, err := rnge.Start().LineAndColumn()
+ if err != nil {
+ return "", err
+ }
+
+ return fmt.Sprintf("`%s`, line %v, column %v", rnge.Source(), startLine+1, startCol+1), nil
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go
new file mode 100644
index 0000000..7c8e7c7
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go
@@ -0,0 +1,142 @@
+package compiler
+
+import (
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
// DSLNode is a node in the DSL AST.
type DSLNode interface {
	// GetType returns the type of the node.
	GetType() dslshape.NodeType
	// GetString returns the string value of the named predicate, if any.
	GetString(predicateName string) (string, error)
	// GetInt returns the int value of the named predicate, if any.
	GetInt(predicateName string) (int, error)
	// Lookup returns the first child node under the named predicate, if any.
	Lookup(predicateName string) (DSLNode, error)
}

// NodeChain is a chain of nodes in the DSL AST, ordered innermost node first.
type NodeChain struct {
	// nodes holds the chain, starting at the innermost (deepest) node.
	nodes []DSLNode
	// runePosition is the rune position for which this chain was built.
	runePosition int
}
+
+// Head returns the head node of the chain.
+func (nc *NodeChain) Head() DSLNode {
+ return nc.nodes[0]
+}
+
+// HasHeadType returns true if the head node of the chain is of the given type.
+func (nc *NodeChain) HasHeadType(nodeType dslshape.NodeType) bool {
+ return nc.nodes[0].GetType() == nodeType
+}
+
+// ForRunePosition returns the rune position of the chain.
+func (nc *NodeChain) ForRunePosition() int {
+ return nc.runePosition
+}
+
+// FindNodeOfType returns the first node of the given type in the chain, if any.
+func (nc *NodeChain) FindNodeOfType(nodeType dslshape.NodeType) DSLNode {
+ for _, node := range nc.nodes {
+ if node.GetType() == nodeType {
+ return node
+ }
+ }
+
+ return nil
+}
+
+func (nc *NodeChain) String() string {
+ var out string
+ for _, node := range nc.nodes {
+ out += node.GetType().String() + " "
+ }
+ return out
+}
+
// PositionToAstNodeChain returns the AST node, and its parents (if any), found at the given position in the source, if any.
func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, position input.Position) (*NodeChain, error) {
	rootSource, err := schema.rootNode.GetString(dslshape.NodePredicateSource)
	if err != nil {
		return nil, err
	}

	// Only positions within the compiled schema's own source can be resolved.
	if rootSource != string(source) {
		return nil, nil
	}

	// Map the position to a file rune.
	runePosition, err := schema.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source)
	if err != nil {
		return nil, err
	}

	// Find the node at the rune position.
	found, err := runePositionToAstNodeChain(schema.rootNode, runePosition)
	if err != nil {
		return nil, err
	}

	if found == nil {
		return nil, nil
	}

	return &NodeChain{nodes: found, runePosition: runePosition}, nil
}
+
// runePositionToAstNodeChain recursively finds the deepest node whose rune
// range contains runePosition, returning the chain from that node up to (and
// including) this node. Returns nil if the position lies outside this node.
func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, error) {
	// Nodes without position information cannot match.
	if !node.Has(dslshape.NodePredicateStartRune) {
		return nil, nil
	}

	startRune, err := node.GetInt(dslshape.NodePredicateStartRune)
	if err != nil {
		return nil, err
	}

	endRune, err := node.GetInt(dslshape.NodePredicateEndRune)
	if err != nil {
		return nil, err
	}

	// Range check is inclusive on both ends.
	if runePosition < startRune || runePosition > endRune {
		return nil, nil
	}

	// Prefer the first matching child: its chain is returned with this node
	// appended as the parent.
	for _, child := range node.AllSubNodes() {
		childChain, err := runePositionToAstNodeChain(child, runePosition)
		if err != nil {
			return nil, err
		}

		if childChain != nil {
			return append(childChain, wrapper{node}), nil
		}
	}

	// No child matched; this node is the innermost match.
	return []DSLNode{wrapper{node}}, nil
}
+
// wrapper adapts the internal *dslNode to the exported DSLNode interface.
type wrapper struct {
	node *dslNode
}

// GetType returns the type of the wrapped node.
func (w wrapper) GetType() dslshape.NodeType {
	return w.node.GetType()
}

// GetString returns the string value of the named predicate on the wrapped node.
func (w wrapper) GetString(predicateName string) (string, error) {
	return w.node.GetString(predicateName)
}

// GetInt returns the int value of the named predicate on the wrapped node.
func (w wrapper) GetInt(predicateName string) (int, error) {
	return w.node.GetInt(predicateName)
}

// Lookup returns the first child under the named predicate, wrapped as a DSLNode.
func (w wrapper) Lookup(predicateName string) (DSLNode, error) {
	found, err := w.node.Lookup(predicateName)
	if err != nil {
		return nil, err
	}

	return wrapper{found}, nil
}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go
new file mode 100644
index 0000000..fdc1735
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go
@@ -0,0 +1,2 @@
+// Package compiler knows how to build the Go representation of a SpiceDB schema text.
+package compiler
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go
new file mode 100644
index 0000000..2c33ba8
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go
@@ -0,0 +1,53 @@
+package compiler
+
+import (
+ "strconv"
+
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
// BaseCompilerError defines an error with contains the base message of the issue
// that occurred.
type BaseCompilerError struct {
	error
	BaseMessage string
}

// withNodeError wraps an error with the AST node on which it occurred and,
// optionally, the source snippet that triggered it. It is converted into a
// WithContextError before being returned to callers of Compile.
type withNodeError struct {
	error
	node            *dslNode
	errorSourceCode string
}

// WithContextError defines an error which contains contextual information.
type WithContextError struct {
	BaseCompilerError
	SourceRange     input.SourceRange
	Source          input.Source
	ErrorSourceCode string
}
+
// Unwrap returns the wrapped BaseCompilerError, supporting errors.Is/As.
func (ewc WithContextError) Unwrap() error {
	return ewc.BaseCompilerError
}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (ewc WithContextError) DetailsMetadata() map[string]string {
+ startLine, startCol, err := ewc.SourceRange.Start().LineAndColumn()
+ if err != nil {
+ return map[string]string{}
+ }
+
+ endLine, endCol, err := ewc.SourceRange.End().LineAndColumn()
+ if err != nil {
+ return map[string]string{}
+ }
+
+ return map[string]string{
+ "start_line_number": strconv.Itoa(startLine),
+ "start_column_position": strconv.Itoa(startCol),
+ "end_line_number": strconv.Itoa(endLine),
+ "end_column_position": strconv.Itoa(endCol),
+ "source_code": ewc.ErrorSourceCode,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go
new file mode 100644
index 0000000..b7e2a70
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go
@@ -0,0 +1,180 @@
+package compiler
+
+import (
+ "container/list"
+ "fmt"
+
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+ "github.com/authzed/spicedb/pkg/schemadsl/parser"
+)
+
// dslNode is the concrete AST node built by the parser: a node type plus
// scalar properties and predicate-keyed lists of child nodes.
type dslNode struct {
	nodeType   dslshape.NodeType
	properties map[string]interface{}
	children   map[string]*list.List
}
+
+func createAstNode(_ input.Source, kind dslshape.NodeType) parser.AstNode {
+ return &dslNode{
+ nodeType: kind,
+ properties: make(map[string]interface{}),
+ children: make(map[string]*list.List),
+ }
+}
+
// GetType returns the type of this AST node.
func (tn *dslNode) GetType() dslshape.NodeType {
	return tn.nodeType
}
+
+func (tn *dslNode) Connect(predicate string, other parser.AstNode) {
+ if tn.children[predicate] == nil {
+ tn.children[predicate] = list.New()
+ }
+
+ tn.children[predicate].PushBack(other)
+}
+
+func (tn *dslNode) MustDecorate(property string, value string) parser.AstNode {
+ if _, ok := tn.properties[property]; ok {
+ panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
+ }
+
+ tn.properties[property] = value
+ return tn
+}
+
+func (tn *dslNode) MustDecorateWithInt(property string, value int) parser.AstNode {
+ if _, ok := tn.properties[property]; ok {
+ panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
+ }
+
+ tn.properties[property] = value
+ return tn
+}
+
// Range returns the node's source range, built from its source name and
// start/end rune properties via the given position mapper.
func (tn *dslNode) Range(mapper input.PositionMapper) (input.SourceRange, error) {
	sourceStr, err := tn.GetString(dslshape.NodePredicateSource)
	if err != nil {
		return nil, err
	}

	source := input.Source(sourceStr)

	startRune, err := tn.GetInt(dslshape.NodePredicateStartRune)
	if err != nil {
		return nil, err
	}

	endRune, err := tn.GetInt(dslshape.NodePredicateEndRune)
	if err != nil {
		return nil, err
	}

	return source.RangeForRunePositions(startRune, endRune, mapper), nil
}
+
// Has returns true if the node carries a property with the given name.
func (tn *dslNode) Has(predicateName string) bool {
	_, ok := tn.properties[predicateName]
	return ok
}
+
+func (tn *dslNode) GetInt(predicateName string) (int, error) {
+ predicate, ok := tn.properties[predicateName]
+ if !ok {
+ return 0, fmt.Errorf("unknown predicate %s", predicateName)
+ }
+
+ value, ok := predicate.(int)
+ if !ok {
+ return 0, fmt.Errorf("predicate %s is not an int", predicateName)
+ }
+
+ return value, nil
+}
+
+func (tn *dslNode) GetString(predicateName string) (string, error) {
+ predicate, ok := tn.properties[predicateName]
+ if !ok {
+ return "", fmt.Errorf("unknown predicate %s", predicateName)
+ }
+
+ value, ok := predicate.(string)
+ if !ok {
+ return "", fmt.Errorf("predicate %s is not a string", predicateName)
+ }
+
+ return value, nil
+}
+
+func (tn *dslNode) AllSubNodes() []*dslNode {
+ nodes := []*dslNode{}
+ for _, childList := range tn.children {
+ for e := childList.Front(); e != nil; e = e.Next() {
+ nodes = append(nodes, e.Value.(*dslNode))
+ }
+ }
+ return nodes
+}
+
// GetChildren returns the nodes stored under the generic child predicate.
func (tn *dslNode) GetChildren() []*dslNode {
	return tn.List(dslshape.NodePredicateChild)
}
+
+func (tn *dslNode) FindAll(nodeType dslshape.NodeType) []*dslNode {
+ found := []*dslNode{}
+ if tn.nodeType == dslshape.NodeTypeError {
+ found = append(found, tn)
+ }
+
+ for _, childList := range tn.children {
+ for e := childList.Front(); e != nil; e = e.Next() {
+ childFound := e.Value.(*dslNode).FindAll(nodeType)
+ found = append(found, childFound...)
+ }
+ }
+ return found
+}
+
+func (tn *dslNode) List(predicateName string) []*dslNode {
+ children := []*dslNode{}
+ childList, ok := tn.children[predicateName]
+ if !ok {
+ return children
+ }
+
+ for e := childList.Front(); e != nil; e = e.Next() {
+ children = append(children, e.Value.(*dslNode))
+ }
+
+ return children
+}
+
+func (tn *dslNode) Lookup(predicateName string) (*dslNode, error) {
+ childList, ok := tn.children[predicateName]
+ if !ok {
+ return nil, fmt.Errorf("unknown predicate %s", predicateName)
+ }
+
+ for e := childList.Front(); e != nil; e = e.Next() {
+ return e.Value.(*dslNode), nil
+ }
+
+ return nil, fmt.Errorf("nothing in predicate %s", predicateName)
+}
+
+func (tn *dslNode) Errorf(message string, args ...interface{}) error {
+ return withNodeError{
+ error: fmt.Errorf(message, args...),
+ errorSourceCode: "",
+ node: tn,
+ }
+}
+
+func (tn *dslNode) WithSourceErrorf(sourceCode string, message string, args ...interface{}) error {
+ return withNodeError{
+ error: fmt.Errorf(message, args...),
+ errorSourceCode: sourceCode,
+ node: tn,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go
new file mode 100644
index 0000000..aa33c43
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go
@@ -0,0 +1,32 @@
+package compiler
+
import (
	"fmt"
	"strings"

	"github.com/authzed/spicedb/pkg/schemadsl/input"
)
+
// positionMapper implements input.PositionMapper over a single InputSchema.
type positionMapper struct {
	schema InputSchema
	mapper input.SourcePositionMapper
}

// newPositionMapper builds a position mapper over the schema's contents.
func newPositionMapper(schema InputSchema) input.PositionMapper {
	return &positionMapper{
		schema: schema,
		mapper: input.CreateSourcePositionMapper([]byte(schema.SchemaString)),
	}
}
+
// RunePositionToLineAndCol converts a rune position to a line and column.
// The source argument is ignored: this mapper covers a single schema string.
func (pm *positionMapper) RunePositionToLineAndCol(runePosition int, _ input.Source) (int, int, error) {
	return pm.mapper.RunePositionToLineAndCol(runePosition)
}

// LineAndColToRunePosition converts a line and column to a rune position.
// The source argument is ignored: this mapper covers a single schema string.
func (pm *positionMapper) LineAndColToRunePosition(lineNumber int, colPosition int, _ input.Source) (int, error) {
	return pm.mapper.LineAndColToRunePosition(lineNumber, colPosition)
}
+
+func (pm *positionMapper) TextForLine(lineNumber int, _ input.Source) (string, error) {
+ lines := strings.Split(pm.schema.SchemaString, "\n")
+ return lines[lineNumber], nil
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go
new file mode 100644
index 0000000..77877b0
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go
@@ -0,0 +1,714 @@
+package compiler
+
+import (
+ "bufio"
+ "fmt"
+ "slices"
+ "strings"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/jzelinskie/stringz"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/namespace"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
// translationContext carries the configuration and position mapping used
// while translating the parsed AST into definition protos.
type translationContext struct {
	// objectTypePrefix is the default prefix for unprefixed definition names;
	// nil means every definition must be explicitly prefixed.
	objectTypePrefix *string
	mapper           input.PositionMapper
	schemaString     string
	// skipValidate disables proto validation of translated definitions.
	skipValidate bool
	// allowedFlags lists the `use` flags the schema may enable.
	allowedFlags []string
	// enabledFlags: presumably the flags actually enabled via `use`
	// statements — populated by translateUseFlag (not shown here); verify.
	enabledFlags  []string
	caveatTypeSet *caveattypes.TypeSet
}
+
// prefixedPath returns the definition name with a prefix applied: names
// already containing a `/` separator keep their own prefix; otherwise the
// configured objectTypePrefix is used (erroring if none is configured).
func (tctx *translationContext) prefixedPath(definitionName string) (string, error) {
	var prefix, name string
	// NOTE(review): assumes SplitInto errors when no `/` is present — confirm
	// against the stringz package.
	if err := stringz.SplitInto(definitionName, "/", &prefix, &name); err != nil {
		if tctx.objectTypePrefix == nil {
			return "", fmt.Errorf("found reference `%s` without prefix", definitionName)
		}
		prefix = *tctx.objectTypePrefix
		name = definitionName
	}

	// An empty prefix (AllowUnprefixedObjectType) leaves the name bare.
	if prefix == "" {
		return name, nil
	}

	return stringz.Join("/", prefix, name), nil
}
+
// Ellipsis is the "..." literal used in relation references.
const Ellipsis = "..."
+
+func translate(tctx *translationContext, root *dslNode) (*CompiledSchema, error) {
+ orderedDefinitions := make([]SchemaDefinition, 0, len(root.GetChildren()))
+ var objectDefinitions []*core.NamespaceDefinition
+ var caveatDefinitions []*core.CaveatDefinition
+
+ names := mapz.NewSet[string]()
+
+ for _, definitionNode := range root.GetChildren() {
+ var definition SchemaDefinition
+
+ switch definitionNode.GetType() {
+ case dslshape.NodeTypeUseFlag:
+ err := translateUseFlag(tctx, definitionNode)
+ if err != nil {
+ return nil, err
+ }
+ continue
+
+ case dslshape.NodeTypeCaveatDefinition:
+ def, err := translateCaveatDefinition(tctx, definitionNode)
+ if err != nil {
+ return nil, err
+ }
+
+ definition = def
+ caveatDefinitions = append(caveatDefinitions, def)
+
+ case dslshape.NodeTypeDefinition:
+ def, err := translateObjectDefinition(tctx, definitionNode)
+ if err != nil {
+ return nil, err
+ }
+
+ definition = def
+ objectDefinitions = append(objectDefinitions, def)
+ }
+
+ if !names.Add(definition.GetName()) {
+ return nil, definitionNode.WithSourceErrorf(definition.GetName(), "found name reused between multiple definitions and/or caveats: %s", definition.GetName())
+ }
+
+ orderedDefinitions = append(orderedDefinitions, definition)
+ }
+
+ return &CompiledSchema{
+ CaveatDefinitions: caveatDefinitions,
+ ObjectDefinitions: objectDefinitions,
+ OrderedDefinitions: orderedDefinitions,
+ rootNode: root,
+ mapper: tctx.mapper,
+ }, nil
+}
+
// translateCaveatDefinition translates a parsed caveat definition node into a
// core.CaveatDefinition proto: it registers the declared parameters into a
// fresh caveat environment, compiles the expression against that environment,
// and attaches comment and source-position metadata.
func translateCaveatDefinition(tctx *translationContext, defNode *dslNode) (*core.CaveatDefinition, error) {
	definitionName, err := defNode.GetString(dslshape.NodeCaveatDefinitionPredicateName)
	if err != nil {
		// definitionName is empty here, so the error snippet will be blank.
		return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err)
	}

	// parameters
	paramNodes := defNode.List(dslshape.NodeCaveatDefinitionPredicateParameters)
	if len(paramNodes) == 0 {
		return nil, defNode.WithSourceErrorf(definitionName, "caveat `%s` must have at least one parameter defined", definitionName)
	}

	env := caveats.NewEnvironment()
	parameters := make(map[string]caveattypes.VariableType, len(paramNodes))
	for _, paramNode := range paramNodes {
		paramName, err := paramNode.GetString(dslshape.NodeCaveatParameterPredicateName)
		if err != nil {
			return nil, paramNode.WithSourceErrorf(paramName, "invalid parameter name: %w", err)
		}

		// Parameter names must be unique within a single caveat.
		if _, ok := parameters[paramName]; ok {
			return nil, paramNode.WithSourceErrorf(paramName, "duplicate parameter `%s` defined on caveat `%s`", paramName, definitionName)
		}

		typeRefNode, err := paramNode.Lookup(dslshape.NodeCaveatParameterPredicateType)
		if err != nil {
			return nil, paramNode.WithSourceErrorf(paramName, "invalid type for parameter: %w", err)
		}

		translatedType, err := translateCaveatTypeReference(tctx, typeRefNode)
		if err != nil {
			return nil, paramNode.WithSourceErrorf(paramName, "invalid type for caveat parameter `%s` on caveat `%s`: %w", paramName, definitionName, err)
		}

		// Record the parameter both in the result map and in the environment
		// used to compile the expression below.
		parameters[paramName] = *translatedType
		err = env.AddVariable(paramName, *translatedType)
		if err != nil {
			return nil, paramNode.WithSourceErrorf(paramName, "invalid type for caveat parameter `%s` on caveat `%s`: %w", paramName, definitionName, err)
		}
	}

	caveatPath, err := tctx.prefixedPath(definitionName)
	if err != nil {
		return nil, defNode.Errorf("%w", err)
	}

	// caveat expression.
	expressionStringNode, err := defNode.Lookup(dslshape.NodeCaveatDefinitionPredicateExpession)
	if err != nil {
		return nil, defNode.WithSourceErrorf(definitionName, "invalid expression: %w", err)
	}

	expressionString, err := expressionStringNode.GetString(dslshape.NodeCaveatExpressionPredicateExpression)
	if err != nil {
		return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err)
	}

	// The expression node's own range is passed to the compiler so compile
	// errors point at the expression, not the whole definition.
	rnge, err := expressionStringNode.Range(tctx.mapper)
	if err != nil {
		return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err)
	}

	source, err := caveats.NewSource(expressionString, caveatPath)
	if err != nil {
		return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err)
	}

	compiled, err := caveats.CompileCaveatWithSource(env, caveatPath, source, rnge.Start())
	if err != nil {
		return nil, expressionStringNode.WithSourceErrorf(expressionString, "invalid expression for caveat `%s`: %w", definitionName, err)
	}

	def, err := namespace.CompiledCaveatDefinition(env, caveatPath, compiled)
	if err != nil {
		return nil, err
	}

	def.Metadata = addComments(def.Metadata, defNode)
	def.SourcePosition = getSourcePosition(defNode, tctx.mapper)
	return def, nil
}
+
+func translateCaveatTypeReference(tctx *translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) {
+ typeName, err := typeRefNode.GetString(dslshape.NodeCaveatTypeReferencePredicateType)
+ if err != nil {
+ return nil, typeRefNode.WithSourceErrorf(typeName, "invalid type name: %w", err)
+ }
+
+ childTypeNodes := typeRefNode.List(dslshape.NodeCaveatTypeReferencePredicateChildTypes)
+ childTypes := make([]caveattypes.VariableType, 0, len(childTypeNodes))
+ for _, childTypeNode := range childTypeNodes {
+ translated, err := translateCaveatTypeReference(tctx, childTypeNode)
+ if err != nil {
+ return nil, err
+ }
+ childTypes = append(childTypes, *translated)
+ }
+
+ constructedType, err := tctx.caveatTypeSet.BuildType(typeName, childTypes)
+ if err != nil {
+ return nil, typeRefNode.WithSourceErrorf(typeName, "%w", err)
+ }
+
+ return constructedType, nil
+}
+
// translateObjectDefinition translates a parsed `definition` node into a
// core.NamespaceDefinition proto, translating each child relation and
// permission and validating the result unless validation is skipped.
func translateObjectDefinition(tctx *translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) {
	definitionName, err := defNode.GetString(dslshape.NodeDefinitionPredicateName)
	if err != nil {
		return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err)
	}

	relationsAndPermissions := []*core.Relation{}
	for _, relationOrPermissionNode := range defNode.GetChildren() {
		// Comments become metadata (via addComments below), not relations.
		if relationOrPermissionNode.GetType() == dslshape.NodeTypeComment {
			continue
		}

		relationOrPermission, err := translateRelationOrPermission(tctx, relationOrPermissionNode)
		if err != nil {
			return nil, err
		}

		relationsAndPermissions = append(relationsAndPermissions, relationOrPermission)
	}

	nspath, err := tctx.prefixedPath(definitionName)
	if err != nil {
		return nil, defNode.Errorf("%w", err)
	}

	// Empty definitions are constructed without the relations argument.
	if len(relationsAndPermissions) == 0 {
		ns := namespace.Namespace(nspath)
		ns.Metadata = addComments(ns.Metadata, defNode)
		ns.SourcePosition = getSourcePosition(defNode, tctx.mapper)

		if !tctx.skipValidate {
			if err = ns.Validate(); err != nil {
				return nil, defNode.Errorf("error in object definition %s: %w", nspath, err)
			}
		}

		return ns, nil
	}

	ns := namespace.Namespace(nspath, relationsAndPermissions...)
	ns.Metadata = addComments(ns.Metadata, defNode)
	ns.SourcePosition = getSourcePosition(defNode, tctx.mapper)

	if !tctx.skipValidate {
		if err := ns.Validate(); err != nil {
			return nil, defNode.Errorf("error in object definition %s: %w", nspath, err)
		}
	}

	return ns, nil
}
+
+func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.SourcePosition {
+ if !dslNode.Has(dslshape.NodePredicateStartRune) {
+ return nil
+ }
+
+ sourceRange, err := dslNode.Range(mapper)
+ if err != nil {
+ return nil
+ }
+
+ line, col, err := sourceRange.Start().LineAndColumn()
+ if err != nil {
+ return nil
+ }
+
+ // We're okay with these being zero if the cast fails.
+ uintLine, _ := safecast.ToUint64(line)
+ uintCol, _ := safecast.ToUint64(col)
+
+ return &core.SourcePosition{
+ ZeroIndexedLineNumber: uintLine,
+ ZeroIndexedColumnPosition: uintCol,
+ }
+}
+
+func addComments(mdmsg *core.Metadata, dslNode *dslNode) *core.Metadata {
+ for _, child := range dslNode.GetChildren() {
+ if child.GetType() == dslshape.NodeTypeComment {
+ value, err := child.GetString(dslshape.NodeCommentPredicateValue)
+ if err == nil {
+ mdmsg, _ = namespace.AddComment(mdmsg, normalizeComment(value))
+ }
+ }
+ }
+ return mdmsg
+}
+
// normalizeComment trims the surrounding whitespace from every line of the
// comment and rejoins the lines with single newlines. A trailing newline in
// the input does not produce a trailing empty line.
func normalizeComment(value string) string {
	var b strings.Builder
	scanner := bufio.NewScanner(strings.NewReader(value))
	for i := 0; scanner.Scan(); i++ {
		if i > 0 {
			b.WriteByte('\n')
		}
		b.WriteString(strings.TrimSpace(scanner.Text()))
	}
	return b.String()
}
+
+func translateRelationOrPermission(tctx *translationContext, relOrPermNode *dslNode) (*core.Relation, error) {
+ switch relOrPermNode.GetType() {
+ case dslshape.NodeTypeRelation:
+ rel, err := translateRelation(tctx, relOrPermNode)
+ if err != nil {
+ return nil, err
+ }
+ rel.Metadata = addComments(rel.Metadata, relOrPermNode)
+ rel.SourcePosition = getSourcePosition(relOrPermNode, tctx.mapper)
+ return rel, err
+
+ case dslshape.NodeTypePermission:
+ rel, err := translatePermission(tctx, relOrPermNode)
+ if err != nil {
+ return nil, err
+ }
+ rel.Metadata = addComments(rel.Metadata, relOrPermNode)
+ rel.SourcePosition = getSourcePosition(relOrPermNode, tctx.mapper)
+ return rel, err
+
+ default:
+ return nil, relOrPermNode.Errorf("unknown definition top-level node type %s", relOrPermNode.GetType())
+ }
+}
+
// translateRelation translates a `relation` node into a core.Relation
// carrying its allowed direct subject types.
func translateRelation(tctx *translationContext, relationNode *dslNode) (*core.Relation, error) {
	relationName, err := relationNode.GetString(dslshape.NodePredicateName)
	if err != nil {
		return nil, relationNode.Errorf("invalid relation name: %w", err)
	}

	// A single parsed type reference may expand into multiple allowed
	// relations, so the results are flattened.
	allowedDirectTypes := []*core.AllowedRelation{}
	for _, typeRef := range relationNode.List(dslshape.NodeRelationPredicateAllowedTypes) {
		allowedRelations, err := translateAllowedRelations(tctx, typeRef)
		if err != nil {
			return nil, err
		}

		allowedDirectTypes = append(allowedDirectTypes, allowedRelations...)
	}

	// Relations have no userset rewrite (nil), unlike permissions.
	relation, err := namespace.Relation(relationName, nil, allowedDirectTypes...)
	if err != nil {
		return nil, err
	}

	if !tctx.skipValidate {
		if err := relation.Validate(); err != nil {
			return nil, relationNode.Errorf("error in relation %s: %w", relationName, err)
		}
	}

	return relation, nil
}
+
// translatePermission translates a `permission` node into a core.Relation
// whose userset rewrite is the translated permission expression.
func translatePermission(tctx *translationContext, permissionNode *dslNode) (*core.Relation, error) {
	permissionName, err := permissionNode.GetString(dslshape.NodePredicateName)
	if err != nil {
		return nil, permissionNode.Errorf("invalid permission name: %w", err)
	}

	expressionNode, err := permissionNode.Lookup(dslshape.NodePermissionPredicateComputeExpression)
	if err != nil {
		return nil, permissionNode.Errorf("invalid permission expression: %w", err)
	}

	rewrite, err := translateExpression(tctx, expressionNode)
	if err != nil {
		return nil, err
	}

	permission, err := namespace.Relation(permissionName, rewrite)
	if err != nil {
		return nil, err
	}

	if !tctx.skipValidate {
		if err := permission.Validate(); err != nil {
			return nil, permissionNode.Errorf("error in permission %s: %w", permissionName, err)
		}
	}

	return permission, nil
}
+
+// translateBinary translates the left and right child expressions of a binary
+// expression node into their corresponding set-operation children.
+func translateBinary(tctx *translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) {
+ leftChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ rightChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateRightExpr)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ leftOperation, err := translateExpressionOperation(tctx, leftChild)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ rightOperation, err := translateExpressionOperation(tctx, rightChild)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return leftOperation, rightOperation, nil
+}
+
+// translateExpression translates an expression node into a userset rewrite and
+// stamps the node's source position onto the result.
+func translateExpression(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) {
+ translated, err := translateExpressionDirect(tctx, expressionNode)
+ if err != nil {
+ return translated, err
+ }
+
+ translated.SourcePosition = getSourcePosition(expressionNode, tctx.mapper)
+ return translated, nil
+}
+
+// collapseOps recursively flattens nested rewrites of the same operation kind
+// into a single flat list of children. The handler extracts the matching set
+// operation from a rewrite (e.g. its union), or returns nil when the rewrite is
+// of a different kind and must be kept as a single child.
+func collapseOps(op *core.SetOperation_Child, handler func(rewrite *core.UsersetRewrite) *core.SetOperation) []*core.SetOperation_Child {
+ // Leaf children (non-rewrites) cannot be collapsed further.
+ if op.GetUsersetRewrite() == nil {
+ return []*core.SetOperation_Child{op}
+ }
+
+ usersetRewrite := op.GetUsersetRewrite()
+ operation := handler(usersetRewrite)
+ if operation == nil {
+ return []*core.SetOperation_Child{op}
+ }
+
+ // Same-kind rewrite: recursively collapse each of its children.
+ collapsed := make([]*core.SetOperation_Child, 0, len(operation.Child))
+ for _, child := range operation.Child {
+ collapsed = append(collapsed, collapseOps(child, handler)...)
+ }
+ return collapsed
+}
+
+// translateExpressionDirect translates a union, intersection or exclusion
+// expression node into a userset rewrite; any other node type is treated as a
+// single operation wrapped in a union.
+func translateExpressionDirect(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) {
+ // For union and intersection, we collapse a tree of binary operations into a flat list containing child
+ // operations of the *same* type.
+ translate := func(
+ builder func(firstChild *core.SetOperation_Child, rest ...*core.SetOperation_Child) *core.UsersetRewrite,
+ lookup func(rewrite *core.UsersetRewrite) *core.SetOperation,
+ ) (*core.UsersetRewrite, error) {
+ leftOperation, rightOperation, err := translateBinary(tctx, expressionNode)
+ if err != nil {
+ return nil, err
+ }
+ leftOps := collapseOps(leftOperation, lookup)
+ rightOps := collapseOps(rightOperation, lookup)
+ ops := append(leftOps, rightOps...)
+ return builder(ops[0], ops[1:]...), nil
+ }
+
+ switch expressionNode.GetType() {
+ case dslshape.NodeTypeUnionExpression:
+ return translate(namespace.Union, func(rewrite *core.UsersetRewrite) *core.SetOperation {
+ return rewrite.GetUnion()
+ })
+
+ case dslshape.NodeTypeIntersectExpression:
+ return translate(namespace.Intersection, func(rewrite *core.UsersetRewrite) *core.SetOperation {
+ return rewrite.GetIntersection()
+ })
+
+ case dslshape.NodeTypeExclusionExpression:
+ // Order matters for exclusions, so do not perform the optimization.
+ leftOperation, rightOperation, err := translateBinary(tctx, expressionNode)
+ if err != nil {
+ return nil, err
+ }
+ return namespace.Exclusion(leftOperation, rightOperation), nil
+
+ default:
+ // Not a binary expression: translate as a single operation under a union.
+ op, err := translateExpressionOperation(tctx, expressionNode)
+ if err != nil {
+ return nil, err
+ }
+
+ return namespace.Union(op), nil
+ }
+}
+
+// translateExpressionOperation translates a single operation node and stamps
+// the node's source position onto the result.
+func translateExpressionOperation(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) {
+ translated, err := translateExpressionOperationDirect(tctx, expressionOpNode)
+ if err != nil {
+ return translated, err
+ }
+
+ translated.SourcePosition = getSourcePosition(expressionOpNode, tctx.mapper)
+ return translated, nil
+}
+
+// translateExpressionOperationDirect translates an identifier, nil keyword,
+// arrow expression or nested binary expression into a set-operation child.
+func translateExpressionOperationDirect(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) {
+ switch expressionOpNode.GetType() {
+ case dslshape.NodeTypeIdentifier:
+ // A bare identifier references another relation/permission.
+ referencedRelationName, err := expressionOpNode.GetString(dslshape.NodeIdentiferPredicateValue)
+ if err != nil {
+ return nil, err
+ }
+
+ return namespace.ComputedUserset(referencedRelationName), nil
+
+ case dslshape.NodeTypeNilExpression:
+ return namespace.Nil(), nil
+
+ case dslshape.NodeTypeArrowExpression:
+ // An arrow (`left->right` or `left.func(right)`) becomes a tuple-to-userset.
+ leftChild, err := expressionOpNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr)
+ if err != nil {
+ return nil, err
+ }
+
+ rightChild, err := expressionOpNode.Lookup(dslshape.NodeExpressionPredicateRightExpr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Only a plain identifier is permitted on the left side of an arrow.
+ if leftChild.GetType() != dslshape.NodeTypeIdentifier {
+ return nil, leftChild.Errorf("Nested arrows not yet supported")
+ }
+
+ tuplesetRelation, err := leftChild.GetString(dslshape.NodeIdentiferPredicateValue)
+ if err != nil {
+ return nil, err
+ }
+
+ usersetRelation, err := rightChild.GetString(dslshape.NodeIdentiferPredicateValue)
+ if err != nil {
+ return nil, err
+ }
+
+ // A functioned arrow (e.g. `.any(...)` / `.all(...)`) carries a function name.
+ if expressionOpNode.Has(dslshape.NodeArrowExpressionFunctionName) {
+ functionName, err := expressionOpNode.GetString(dslshape.NodeArrowExpressionFunctionName)
+ if err != nil {
+ return nil, err
+ }
+
+ return namespace.MustFunctionedTupleToUserset(tuplesetRelation, functionName, usersetRelation), nil
+ }
+
+ return namespace.TupleToUserset(tuplesetRelation, usersetRelation), nil
+
+ case dslshape.NodeTypeUnionExpression:
+ fallthrough
+
+ case dslshape.NodeTypeIntersectExpression:
+ fallthrough
+
+ case dslshape.NodeTypeExclusionExpression:
+ // A nested binary expression becomes a child rewrite.
+ rewrite, err := translateExpression(tctx, expressionOpNode)
+ if err != nil {
+ return nil, err
+ }
+ return namespace.Rewrite(rewrite), nil
+
+ default:
+ return nil, expressionOpNode.Errorf("unknown expression node type %s", expressionOpNode.GetType())
+ }
+}
+
+// translateAllowedRelations translates a type reference node into the list of
+// allowed relations it denotes, recursing through compound type references.
+func translateAllowedRelations(tctx *translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) {
+ switch typeRefNode.GetType() {
+ case dslshape.NodeTypeTypeReference:
+ // A compound reference: flatten each child type reference.
+ references := []*core.AllowedRelation{}
+ for _, subRefNode := range typeRefNode.List(dslshape.NodeTypeReferencePredicateType) {
+ subReferences, err := translateAllowedRelations(tctx, subRefNode)
+ if err != nil {
+ return []*core.AllowedRelation{}, err
+ }
+
+ references = append(references, subReferences...)
+ }
+ return references, nil
+
+ case dslshape.NodeTypeSpecificTypeReference:
+ ref, err := translateSpecificTypeReference(tctx, typeRefNode)
+ if err != nil {
+ return []*core.AllowedRelation{}, err
+ }
+ return []*core.AllowedRelation{ref}, nil
+
+ default:
+ return nil, typeRefNode.Errorf("unknown type ref node type %s", typeRefNode.GetType())
+ }
+}
+
+// translateSpecificTypeReference translates a single specific type reference
+// (e.g. `user`, `group#member`, `user:*`, `user with somecaveat`) into an
+// allowed relation, applying caveats and the expiration trait when present.
+func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) {
+ typePath, err := typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateType)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid type name: %w", err)
+ }
+
+ // Apply the schema prefix (if any) to the referenced type's path.
+ nspath, err := tctx.prefixedPath(typePath)
+ if err != nil {
+ return nil, typeRefNode.Errorf("%w", err)
+ }
+
+ // Wildcard references (`type:*`) short-circuit: they carry no relation name.
+ if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateWildcard) {
+ ref := &core.AllowedRelation{
+ Namespace: nspath,
+ RelationOrWildcard: &core.AllowedRelation_PublicWildcard_{
+ PublicWildcard: &core.AllowedRelation_PublicWildcard{},
+ },
+ }
+
+ err = addWithCaveats(tctx, typeRefNode, ref)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid caveat: %w", err)
+ }
+
+ if !tctx.skipValidate {
+ if err := ref.Validate(); err != nil {
+ return nil, typeRefNode.Errorf("invalid type relation: %w", err)
+ }
+ }
+
+ ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
+ return ref, nil
+ }
+
+ // Without an explicit `#relation`, the terminal subject relation is used.
+ relationName := Ellipsis
+ if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateRelation) {
+ relationName, err = typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateRelation)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid type relation: %w", err)
+ }
+ }
+
+ ref := &core.AllowedRelation{
+ Namespace: nspath,
+ RelationOrWildcard: &core.AllowedRelation_Relation{
+ Relation: relationName,
+ },
+ }
+
+ // Add the caveat(s), if any.
+ err = addWithCaveats(tctx, typeRefNode, ref)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid caveat: %w", err)
+ }
+
+ // Add the expiration trait, if any. `expiration` is the only supported trait
+ // and requires the matching `use expiration` flag to be enabled.
+ if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil {
+ traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid trait: %w", err)
+ }
+
+ if traitName != "expiration" {
+ return nil, typeRefNode.Errorf("invalid trait: %s", traitName)
+ }
+
+ if !slices.Contains(tctx.allowedFlags, "expiration") {
+ return nil, typeRefNode.Errorf("expiration trait is not allowed")
+ }
+
+ ref.RequiredExpiration = &core.ExpirationTrait{}
+ }
+
+ if !tctx.skipValidate {
+ if err := ref.Validate(); err != nil {
+ return nil, typeRefNode.Errorf("invalid type relation: %w", err)
+ }
+ }
+
+ ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
+ return ref, nil
+}
+
+// addWithCaveats applies the caveat referenced on the type reference node (if
+// any) to the allowed relation, prefixing the caveat name with the schema path.
+func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error {
+ caveats := typeRefNode.List(dslshape.NodeSpecificReferencePredicateCaveat)
+ if len(caveats) == 0 {
+ return nil
+ }
+
+ // Only a single caveat is supported per type reference today.
+ if len(caveats) != 1 {
+ return fmt.Errorf("only one caveat is currently allowed per type reference")
+ }
+
+ name, err := caveats[0].GetString(dslshape.NodeCaveatPredicateCaveat)
+ if err != nil {
+ return err
+ }
+
+ nspath, err := tctx.prefixedPath(name)
+ if err != nil {
+ return err
+ }
+
+ ref.RequiredCaveat = &core.AllowedCaveat{
+ CaveatName: nspath,
+ }
+ return nil
+}
+
+// translateUseFlag records the flag named by a `use` node on the translation
+// context's enabled flags, rejecting duplicate `use` statements.
+func translateUseFlag(tctx *translationContext, useFlagNode *dslNode) error {
+ flagName, err := useFlagNode.GetString(dslshape.NodeUseFlagPredicateName)
+ if err != nil {
+ return err
+ }
+ if slices.Contains(tctx.enabledFlags, flagName) {
+ return useFlagNode.Errorf("found duplicate use flag: %s", flagName)
+ }
+ tctx.enabledFlags = append(tctx.enabledFlags, flagName)
+ return nil
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/doc.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/doc.go
new file mode 100644
index 0000000..457bd0b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/doc.go
@@ -0,0 +1,2 @@
+// Package dslshape defines the types representing the structure of schema DSL.
+package dslshape
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/dslshape.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/dslshape.go
new file mode 100644
index 0000000..c3a599f
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/dslshape.go
@@ -0,0 +1,209 @@
+//go:generate go run golang.org/x/tools/cmd/stringer -type=NodeType -output zz_generated.nodetype_string.go
+
+package dslshape
+
+// NodeType identifies the type of AST node.
+type NodeType int
+
+const (
+ // Top-level
+ NodeTypeError NodeType = iota // error occurred; value is text of error
+ NodeTypeFile // The file root node
+ NodeTypeComment // A single or multiline comment
+ NodeTypeUseFlag // A use flag
+
+ NodeTypeDefinition // A definition.
+ NodeTypeCaveatDefinition // A caveat definition.
+
+ NodeTypeCaveatParameter // A caveat parameter.
+ NodeTypeCaveatExpression // A caveat expression.
+
+ NodeTypeRelation // A relation
+ NodeTypePermission // A permission
+
+ NodeTypeTypeReference // A type reference
+ NodeTypeSpecificTypeReference // A reference to a specific type.
+ NodeTypeCaveatReference // A caveat reference under a type.
+ NodeTypeTraitReference // A trait reference under a type.
+
+ NodeTypeUnionExpression // A union (`+`) expression.
+ NodeTypeIntersectExpression // An intersection (`&`) expression.
+ NodeTypeExclusionExpression // An exclusion (`-`) expression.
+
+ NodeTypeArrowExpression // A TTU in arrow form.
+
+ NodeTypeIdentifier // An identifier under an expression.
+ NodeTypeNilExpression // A nil keyword
+
+ NodeTypeCaveatTypeReference // A type reference for a caveat parameter.
+)
+
+// Predicate names used to store values and children on AST nodes, grouped by
+// the node type(s) to which they apply.
+const (
+ //
+ // All nodes
+ //
+ // The source of this node.
+ NodePredicateSource = "input-source"
+
+ // The rune position in the input string at which this node begins.
+ NodePredicateStartRune = "start-rune"
+
+ // The rune position in the input string at which this node ends.
+ NodePredicateEndRune = "end-rune"
+
+ // A direct child of this node. Implementations should handle the ordering
+ // automatically for this predicate.
+ NodePredicateChild = "child-node"
+
+ //
+ // NodeTypeError
+ //
+
+ // The message for the parsing error.
+ NodePredicateErrorMessage = "error-message"
+
+ // The (optional) source to highlight for the parsing error.
+ NodePredicateErrorSource = "error-source"
+
+ //
+ // NodeTypeComment
+ //
+
+ // The value of the comment, including its delimiter(s)
+ NodeCommentPredicateValue = "comment-value"
+
+ //
+ // NodeTypeUseFlag
+ //
+
+ // The name of the use flag.
+ NodeUseFlagPredicateName = "use-flag-name"
+
+ //
+ // NodeTypeDefinition
+ //
+
+ // The name of the definition
+ NodeDefinitionPredicateName = "definition-name"
+
+ //
+ // NodeTypeCaveatDefinition
+ //
+
+ // The name of the definition
+ NodeCaveatDefinitionPredicateName = "caveat-definition-name"
+
+ // The parameters for the definition.
+ NodeCaveatDefinitionPredicateParameters = "parameters"
+
+ // The link to the expression for the definition.
+ NodeCaveatDefinitionPredicateExpession = "caveat-definition-expression"
+
+ //
+ // NodeTypeCaveatExpression
+ //
+
+ // The raw CEL expression, in string form.
+ NodeCaveatExpressionPredicateExpression = "caveat-expression-expressionstr"
+
+ //
+ // NodeTypeCaveatParameter
+ //
+
+ // The name of the parameter
+ NodeCaveatParameterPredicateName = "caveat-parameter-name"
+
+ // The defined type of the caveat parameter.
+ NodeCaveatParameterPredicateType = "caveat-parameter-type"
+
+ //
+ // NodeTypeCaveatTypeReference
+ //
+
+ // The type for the caveat type reference.
+ // NOTE(review): shares the string value "type-name" with
+ // NodeSpecificReferencePredicateType below; the two apply to different node
+ // types, so lookups do not collide — confirm intended.
+ NodeCaveatTypeReferencePredicateType = "type-name"
+
+ // The child type(s) for the type reference.
+ NodeCaveatTypeReferencePredicateChildTypes = "child-types"
+
+ //
+ // NodeTypeRelation + NodeTypePermission
+ //
+
+ // The name of the relation/permission
+ NodePredicateName = "relation-name"
+
+ //
+ // NodeTypeRelation
+ //
+
+ // The allowed types for the relation.
+ NodeRelationPredicateAllowedTypes = "allowed-types"
+
+ //
+ // NodeTypeTypeReference
+ //
+
+ // A type under a type reference.
+ NodeTypeReferencePredicateType = "type-ref-type"
+
+ //
+ // NodeTypeSpecificTypeReference
+ //
+
+ // A type under a type reference.
+ NodeSpecificReferencePredicateType = "type-name"
+
+ // A relation under a type reference.
+ NodeSpecificReferencePredicateRelation = "relation-name"
+
+ // A wildcard under a type reference.
+ NodeSpecificReferencePredicateWildcard = "type-wildcard"
+
+ // A caveat under a type reference.
+ NodeSpecificReferencePredicateCaveat = "caveat"
+
+ // A trait under a type reference.
+ NodeSpecificReferencePredicateTrait = "trait"
+
+ //
+ // NodeTypeCaveatReference
+ //
+
+ // The caveat name under the caveat.
+ NodeCaveatPredicateCaveat = "caveat-name"
+
+ //
+ // NodeTypeTraitReference
+ //
+
+ // The trait name under the trait.
+ NodeTraitPredicateTrait = "trait-name"
+
+ //
+ // NodeTypePermission
+ //
+
+ // The expression to compute the permission.
+ NodePermissionPredicateComputeExpression = "compute-expression"
+
+ //
+ // NodeTypeArrowExpression
+ //
+
+ // The name of the function in the arrow expression.
+ NodeArrowExpressionFunctionName = "function-name"
+
+ //
+ // NodeTypeIdentifier
+ //
+
+ // The value of the identifier.
+ NodeIdentiferPredicateValue = "identifier-value"
+
+ //
+ // NodeTypeUnionExpression + NodeTypeIntersectExpression + NodeTypeExclusionExpression + NodeTypeArrowExpression
+ //
+ NodeExpressionPredicateLeftExpr = "left-expr"
+ NodeExpressionPredicateRightExpr = "right-expr"
+)
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/zz_generated.nodetype_string.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/zz_generated.nodetype_string.go
new file mode 100644
index 0000000..4ef1e06
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/zz_generated.nodetype_string.go
@@ -0,0 +1,43 @@
+// Code generated by "stringer -type=NodeType -output zz_generated.nodetype_string.go"; DO NOT EDIT.
+
+package dslshape
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[NodeTypeError-0]
+ _ = x[NodeTypeFile-1]
+ _ = x[NodeTypeComment-2]
+ _ = x[NodeTypeUseFlag-3]
+ _ = x[NodeTypeDefinition-4]
+ _ = x[NodeTypeCaveatDefinition-5]
+ _ = x[NodeTypeCaveatParameter-6]
+ _ = x[NodeTypeCaveatExpression-7]
+ _ = x[NodeTypeRelation-8]
+ _ = x[NodeTypePermission-9]
+ _ = x[NodeTypeTypeReference-10]
+ _ = x[NodeTypeSpecificTypeReference-11]
+ _ = x[NodeTypeCaveatReference-12]
+ _ = x[NodeTypeTraitReference-13]
+ _ = x[NodeTypeUnionExpression-14]
+ _ = x[NodeTypeIntersectExpression-15]
+ _ = x[NodeTypeExclusionExpression-16]
+ _ = x[NodeTypeArrowExpression-17]
+ _ = x[NodeTypeIdentifier-18]
+ _ = x[NodeTypeNilExpression-19]
+ _ = x[NodeTypeCaveatTypeReference-20]
+}
+
+const _NodeType_name = "NodeTypeErrorNodeTypeFileNodeTypeCommentNodeTypeUseFlagNodeTypeDefinitionNodeTypeCaveatDefinitionNodeTypeCaveatParameterNodeTypeCaveatExpressionNodeTypeRelationNodeTypePermissionNodeTypeTypeReferenceNodeTypeSpecificTypeReferenceNodeTypeCaveatReferenceNodeTypeTraitReferenceNodeTypeUnionExpressionNodeTypeIntersectExpressionNodeTypeExclusionExpressionNodeTypeArrowExpressionNodeTypeIdentifierNodeTypeNilExpressionNodeTypeCaveatTypeReference"
+
+var _NodeType_index = [...]uint16{0, 13, 25, 40, 55, 73, 97, 120, 144, 160, 178, 199, 228, 251, 273, 296, 323, 350, 373, 391, 412, 439}
+
+func (i NodeType) String() string {
+ if i < 0 || i >= NodeType(len(_NodeType_index)-1) {
+ return "NodeType(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _NodeType_name[_NodeType_index[i]:_NodeType_index[i+1]]
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator.go
new file mode 100644
index 0000000..3c8fdd9
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator.go
@@ -0,0 +1,430 @@
+package generator
+
+import (
+ "bufio"
+ "fmt"
+ "sort"
+ "strings"
+
+ "golang.org/x/exp/maps"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/graph"
+ "github.com/authzed/spicedb/pkg/namespace"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+)
+
+// Ellipsis is the relation name for terminal subjects.
+const Ellipsis = "..."
+
+// MaxSingleLineCommentLength sets the maximum length for a comment to made single line.
+const MaxSingleLineCommentLength = 70 // 80 - the comment parts and some padding
+
+// GenerateSchema generates a DSL view of the given schema definitions using
+// the default caveat type set.
+func GenerateSchema(definitions []compiler.SchemaDefinition) (string, bool, error) {
+ return GenerateSchemaWithCaveatTypeSet(definitions, caveattypes.Default.TypeSet)
+}
+
+// GenerateSchemaWithCaveatTypeSet generates a DSL view of the given schema.
+// The returned boolean is true only if every definition generated without an
+// issue comment being emitted.
+func GenerateSchemaWithCaveatTypeSet(definitions []compiler.SchemaDefinition, caveatTypeSet *caveattypes.TypeSet) (string, bool, error) {
+ generated := make([]string, 0, len(definitions))
+ flags := mapz.NewSet[string]()
+
+ result := true
+ for _, definition := range definitions {
+ switch def := definition.(type) {
+ case *core.CaveatDefinition:
+ generatedCaveat, ok, err := GenerateCaveatSource(def, caveatTypeSet)
+ if err != nil {
+ return "", false, err
+ }
+
+ result = result && ok
+ generated = append(generated, generatedCaveat)
+
+ case *core.NamespaceDefinition:
+ generatedSchema, defFlags, ok, err := generateDefinitionSource(def, caveatTypeSet)
+ if err != nil {
+ return "", false, err
+ }
+
+ result = result && ok
+ generated = append(generated, generatedSchema)
+ // Accumulate any `use` flags required by this definition.
+ flags.Extend(defFlags)
+
+ default:
+ return "", false, spiceerrors.MustBugf("unknown type of definition %T in GenerateSchema", def)
+ }
+ }
+
+ // Emit any required `use` flags at the top of the generated schema.
+ // NOTE(review): prepending one-by-one while iterating an ascending-sorted
+ // slice leaves the flags reverse-sorted at the top; benign while only a
+ // single flag exists — confirm intended.
+ if !flags.IsEmpty() {
+ flagsSlice := flags.AsSlice()
+ sort.Strings(flagsSlice)
+
+ for _, flag := range flagsSlice {
+ generated = append([]string{"use " + flag}, generated...)
+ }
+ }
+
+ return strings.Join(generated, "\n\n"), result, nil
+}
+
+// GenerateCaveatSource generates a DSL view of the given caveat definition.
+// The returned boolean is true if generation completed without issues.
+func GenerateCaveatSource(caveat *core.CaveatDefinition, caveatTypeSet *caveattypes.TypeSet) (string, bool, error) {
+ generator := &sourceGenerator{
+ indentationLevel: 0,
+ hasNewline: true,
+ hasBlankline: true,
+ hasNewScope: true,
+ caveatTypeSet: caveatTypeSet,
+ }
+
+ err := generator.emitCaveat(caveat)
+ if err != nil {
+ return "", false, err
+ }
+
+ return generator.buf.String(), !generator.hasIssue, nil
+}
+
+// GenerateSource generates a DSL view of the given namespace definition.
+func GenerateSource(namespace *core.NamespaceDefinition, caveatTypeSet *caveattypes.TypeSet) (string, bool, error) {
+ source, _, ok, err := generateDefinitionSource(namespace, caveatTypeSet)
+ return source, ok, err
+}
+
+// generateDefinitionSource generates the DSL for a single namespace definition
+// and additionally returns the `use` flags required by the generated source.
+func generateDefinitionSource(namespace *core.NamespaceDefinition, caveatTypeSet *caveattypes.TypeSet) (string, []string, bool, error) {
+ generator := &sourceGenerator{
+ indentationLevel: 0,
+ hasNewline: true,
+ hasBlankline: true,
+ hasNewScope: true,
+ flags: mapz.NewSet[string](),
+ caveatTypeSet: caveatTypeSet,
+ }
+
+ err := generator.emitNamespace(namespace)
+ if err != nil {
+ return "", nil, false, err
+ }
+
+ return generator.buf.String(), generator.flags.AsSlice(), !generator.hasIssue, nil
+}
+
+// GenerateRelationSource generates a DSL view of the given relation definition.
+func GenerateRelationSource(relation *core.Relation, caveatTypeSet *caveattypes.TypeSet) (string, error) {
+ generator := &sourceGenerator{
+ indentationLevel: 0,
+ hasNewline: true,
+ hasBlankline: true,
+ hasNewScope: true,
+ caveatTypeSet: caveatTypeSet,
+ }
+
+ err := generator.emitRelation(relation)
+ if err != nil {
+ return "", err
+ }
+
+ return generator.buf.String(), nil
+}
+
+// emitCaveat writes the DSL form of a caveat definition, including its
+// comments, sorted parameter list and deserialized CEL expression body.
+func (sg *sourceGenerator) emitCaveat(caveat *core.CaveatDefinition) error {
+ sg.emitComments(caveat.Metadata)
+ sg.append("caveat ")
+ sg.append(caveat.Name)
+ sg.append("(")
+
+ // Sort parameters by name so the output is deterministic.
+ parameterNames := maps.Keys(caveat.ParameterTypes)
+ sort.Strings(parameterNames)
+
+ for index, paramName := range parameterNames {
+ if index > 0 {
+ sg.append(", ")
+ }
+
+ decoded, err := caveattypes.DecodeParameterType(sg.caveatTypeSet, caveat.ParameterTypes[paramName])
+ if err != nil {
+ return fmt.Errorf("invalid parameter type on caveat: %w", err)
+ }
+
+ sg.append(paramName)
+ sg.append(" ")
+ sg.append(decoded.String())
+ }
+
+ sg.append(")")
+
+ sg.append(" {")
+ sg.appendLine()
+ sg.indent()
+ sg.markNewScope()
+
+ // Decode the serialized expression back into its CEL string form.
+ parameterTypes, err := caveattypes.DecodeParameterTypes(sg.caveatTypeSet, caveat.ParameterTypes)
+ if err != nil {
+ return fmt.Errorf("invalid caveat parameters: %w", err)
+ }
+
+ deserializedExpression, err := caveats.DeserializeCaveatWithTypeSet(sg.caveatTypeSet, caveat.SerializedExpression, parameterTypes)
+ if err != nil {
+ return fmt.Errorf("invalid caveat expression bytes: %w", err)
+ }
+
+ exprString, err := deserializedExpression.ExprString()
+ if err != nil {
+ return fmt.Errorf("invalid caveat expression: %w", err)
+ }
+
+ sg.append(strings.TrimSpace(exprString))
+ sg.appendLine()
+
+ sg.dedent()
+ sg.append("}")
+ return nil
+}
+
+// emitNamespace writes the DSL `definition` block for a namespace, emitting
+// `{}` for empty definitions and each relation/permission otherwise.
+func (sg *sourceGenerator) emitNamespace(namespace *core.NamespaceDefinition) error {
+ sg.emitComments(namespace.Metadata)
+ sg.append("definition ")
+ sg.append(namespace.Name)
+
+ // An empty definition is rendered inline as `{}`.
+ if len(namespace.Relation) == 0 {
+ sg.append(" {}")
+ return nil
+ }
+
+ sg.append(" {")
+ sg.appendLine()
+ sg.indent()
+ sg.markNewScope()
+
+ for _, relation := range namespace.Relation {
+ err := sg.emitRelation(relation)
+ if err != nil {
+ return err
+ }
+ }
+
+ sg.dedent()
+ sg.append("}")
+ return nil
+}
+
+// emitRelation writes the DSL for a single relation or permission. A relation
+// carrying a userset rewrite that contains no `_this` reference is rendered as
+// a `permission`; otherwise it is rendered as a `relation`.
+func (sg *sourceGenerator) emitRelation(relation *core.Relation) error {
+ hasThis, err := graph.HasThis(relation.UsersetRewrite)
+ if err != nil {
+ return err
+ }
+
+ isPermission := relation.UsersetRewrite != nil && !hasThis
+
+ sg.emitComments(relation.Metadata)
+ if isPermission {
+ sg.append("permission ")
+ } else {
+ sg.append("relation ")
+ }
+
+ sg.append(relation.Name)
+
+ // Relations must list their allowed subject types; flag an issue otherwise.
+ if !isPermission {
+ sg.append(": ")
+ if relation.TypeInformation == nil || relation.TypeInformation.AllowedDirectRelations == nil || len(relation.TypeInformation.AllowedDirectRelations) == 0 {
+ sg.appendIssue("missing allowed types")
+ } else {
+ for index, allowedRelation := range relation.TypeInformation.AllowedDirectRelations {
+ if index > 0 {
+ sg.append(" | ")
+ }
+
+ sg.emitAllowedRelation(allowedRelation)
+ }
+ }
+ }
+
+ if relation.UsersetRewrite != nil {
+ sg.append(" = ")
+ sg.mustEmitRewrite(relation.UsersetRewrite)
+ }
+
+ sg.appendLine()
+ return nil
+}
+
+// emitAllowedRelation writes a single allowed subject type, including its
+// optional `#relation`, `:*` wildcard, caveat and expiration trait.
+func (sg *sourceGenerator) emitAllowedRelation(allowedRelation *core.AllowedRelation) {
+ sg.append(allowedRelation.Namespace)
+ // The terminal subject relation (`...`) is implicit and never printed.
+ if allowedRelation.GetRelation() != "" && allowedRelation.GetRelation() != Ellipsis {
+ sg.append("#")
+ sg.append(allowedRelation.GetRelation())
+ }
+ if allowedRelation.GetPublicWildcard() != nil {
+ sg.append(":*")
+ }
+
+ hasExpirationTrait := allowedRelation.GetRequiredExpiration() != nil
+ hasCaveat := allowedRelation.GetRequiredCaveat() != nil
+
+ if hasExpirationTrait || hasCaveat {
+ sg.append(" with ")
+ if hasCaveat {
+ sg.append(allowedRelation.RequiredCaveat.CaveatName)
+ }
+
+ if hasExpirationTrait {
+ // Using expiration requires the `use expiration` flag in the schema.
+ sg.flags.Add("expiration")
+
+ if hasCaveat {
+ sg.append(" and ")
+ }
+
+ sg.append("expiration")
+ }
+ }
+}
+
+// mustEmitRewrite writes the DSL operator form of a userset rewrite, panicking
+// on an unknown rewrite operation (a bug, per the Must convention).
+func (sg *sourceGenerator) mustEmitRewrite(rewrite *core.UsersetRewrite) {
+ switch rw := rewrite.RewriteOperation.(type) {
+ case *core.UsersetRewrite_Union:
+ sg.emitRewriteOps(rw.Union, "+")
+ case *core.UsersetRewrite_Intersection:
+ sg.emitRewriteOps(rw.Intersection, "&")
+ case *core.UsersetRewrite_Exclusion:
+ sg.emitRewriteOps(rw.Exclusion, "-")
+ default:
+ panic(spiceerrors.MustBugf("unknown rewrite operation %T", rw))
+ }
+}
+
+// emitRewriteOps writes the children of a set operation joined by the given
+// operator string.
+func (sg *sourceGenerator) emitRewriteOps(setOp *core.SetOperation, op string) {
+ for index, child := range setOp.Child {
+ if index > 0 {
+ sg.append(" " + op + " ")
+ }
+
+ sg.mustEmitSetOpChild(child)
+ }
+}
+
+// isAllUnion returns true if the rewrite is a union whose nested rewrites (at
+// any depth) are all unions as well; such rewrites need no parentheses.
+func (sg *sourceGenerator) isAllUnion(rewrite *core.UsersetRewrite) bool {
+ switch rw := rewrite.RewriteOperation.(type) {
+ case *core.UsersetRewrite_Union:
+ for _, setOpChild := range rw.Union.Child {
+ switch child := setOpChild.ChildType.(type) {
+ case *core.SetOperation_Child_UsersetRewrite:
+ if !sg.isAllUnion(child.UsersetRewrite) {
+ return false
+ }
+ default:
+ continue
+ }
+ }
+ return true
+ default:
+ return false
+ }
+}
+
+// mustEmitSetOpChild writes a single set-operation child in DSL form,
+// panicking on an unknown child type (a bug, per the Must convention).
+func (sg *sourceGenerator) mustEmitSetOpChild(setOpChild *core.SetOperation_Child) {
+ switch child := setOpChild.ChildType.(type) {
+ case *core.SetOperation_Child_UsersetRewrite:
+ // Parentheses are elided when the nested rewrite is purely unions.
+ if sg.isAllUnion(child.UsersetRewrite) {
+ sg.mustEmitRewrite(child.UsersetRewrite)
+ break
+ }
+
+ sg.append("(")
+ sg.mustEmitRewrite(child.UsersetRewrite)
+ sg.append(")")
+
+ case *core.SetOperation_Child_XThis:
+ // `_this` has no DSL syntax; surface it as a generation issue.
+ sg.appendIssue("_this unsupported here. Please rewrite into a relation and permission")
+
+ case *core.SetOperation_Child_XNil:
+ sg.append("nil")
+
+ case *core.SetOperation_Child_ComputedUserset:
+ sg.append(child.ComputedUserset.Relation)
+
+ case *core.SetOperation_Child_TupleToUserset:
+ // Rendered in arrow form: `tupleset->computeduserset`.
+ sg.append(child.TupleToUserset.Tupleset.Relation)
+ sg.append("->")
+ sg.append(child.TupleToUserset.ComputedUserset.Relation)
+
+ case *core.SetOperation_Child_FunctionedTupleToUserset:
+ // Rendered in functioned form: `tupleset.func(computeduserset)`.
+ sg.append(child.FunctionedTupleToUserset.Tupleset.Relation)
+ sg.append(".")
+
+ switch child.FunctionedTupleToUserset.Function {
+ case core.FunctionedTupleToUserset_FUNCTION_ALL:
+ sg.append("all")
+
+ case core.FunctionedTupleToUserset_FUNCTION_ANY:
+ sg.append("any")
+
+ default:
+ panic(spiceerrors.MustBugf("unknown function %v", child.FunctionedTupleToUserset.Function))
+ }
+
+ sg.append("(")
+ sg.append(child.FunctionedTupleToUserset.ComputedUserset.Relation)
+ sg.append(")")
+
+ default:
+ panic(spiceerrors.MustBugf("unknown child type %T", child))
+ }
+}
+
+// emitComments writes any comments stored in the metadata, preceded by a blank
+// line unless the buffer already ends with one or a new scope was just opened.
+func (sg *sourceGenerator) emitComments(metadata *core.Metadata) {
+ if len(namespace.GetComments(metadata)) > 0 {
+ sg.ensureBlankLineOrNewScope()
+ }
+
+ for _, comment := range namespace.GetComments(metadata) {
+ sg.appendComment(comment)
+ }
+}
+
+// appendComment re-emits a stored comment, normalizing block comments (`/* */`
+// and `/** */`) to single- or multi-line form based on length and newlines,
+// and re-spacing line comments (`//`).
+func (sg *sourceGenerator) appendComment(comment string) {
+ switch {
+ case strings.HasPrefix(comment, "/*"):
+ stripped := strings.TrimSpace(comment)
+
+ // Preserve doc-comment style (`/**`) when present.
+ if strings.HasPrefix(stripped, "/**") {
+ stripped = strings.TrimPrefix(stripped, "/**")
+ sg.append("/**")
+ } else {
+ stripped = strings.TrimPrefix(stripped, "/*")
+ sg.append("/*")
+ }
+
+ stripped = strings.TrimSuffix(stripped, "*/")
+ stripped = strings.TrimSpace(stripped)
+
+ // Long or multi-line contents are rendered as a starred block.
+ requireMultiline := len(stripped) > MaxSingleLineCommentLength || strings.ContainsRune(stripped, '\n')
+
+ if requireMultiline {
+ sg.appendLine()
+ scanner := bufio.NewScanner(strings.NewReader(stripped))
+ for scanner.Scan() {
+ sg.append(" * ")
+ sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(scanner.Text()), "*")))
+ sg.appendLine()
+ }
+ sg.append(" */")
+ sg.appendLine()
+ } else {
+ sg.append(" ")
+ sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(stripped), "*")))
+ sg.append(" */")
+ sg.appendLine()
+ }
+
+ case strings.HasPrefix(comment, "//"):
+ sg.append("// ")
+ sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(comment), "//")))
+ sg.appendLine()
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator_impl.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator_impl.go
new file mode 100644
index 0000000..6251712
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator_impl.go
@@ -0,0 +1,83 @@
+package generator
+
+import (
+ "strings"
+
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+)
+
+// sourceGenerator accumulates generated DSL source, tracking indentation and
+// line state so emitters can write free-form fragments.
+type sourceGenerator struct {
+ buf strings.Builder // The buffer for the new source code.
+ indentationLevel int // The current indentation level.
+ hasNewline bool // Whether there is a newline at the end of the buffer.
+ hasBlankline bool // Whether there is a blank line at the end of the buffer.
+ hasIssue bool // Whether there is a translation issue.
+ hasNewScope bool // Whether there is a new scope at the end of the buffer.
+ existingLineLength int // Length of the existing line.
+ flags *mapz.Set[string] // The flags added while generating.
+ caveatTypeSet *caveattypes.TypeSet // The caveat type set used to decode parameter types.
+}
+
+// ensureBlankLineOrNewScope ensures that there is a blank line or new scope at the tail of the buffer. If not,
+// a new line is added.
+func (sg *sourceGenerator) ensureBlankLineOrNewScope() {
+ if !sg.hasBlankline && !sg.hasNewScope {
+ sg.appendLine()
+ }
+}
+
+// indent increases the current indentation.
+func (sg *sourceGenerator) indent() {
+ sg.indentationLevel = sg.indentationLevel + 1
+}
+
+// dedent decreases the current indentation.
+func (sg *sourceGenerator) dedent() {
+ sg.indentationLevel = sg.indentationLevel - 1
+}
+
+// appendIssue adds an issue found in generation, written inline as a block
+// comment and recorded on the generator.
+func (sg *sourceGenerator) appendIssue(description string) {
+ sg.append("/* ")
+ sg.append(description)
+ sg.append(" */")
+ sg.hasIssue = true
+}
+
+// append adds the given value to the buffer, indenting as necessary.
+func (sg *sourceGenerator) append(value string) {
+ for _, currentRune := range value {
+ if currentRune == '\n' {
+ // Two consecutive newlines form a blank line.
+ if sg.hasNewline {
+ sg.hasBlankline = true
+ }
+
+ sg.buf.WriteRune('\n')
+ sg.hasNewline = true
+ sg.existingLineLength = 0
+ continue
+ }
+
+ sg.hasBlankline = false
+ sg.hasNewScope = false
+
+ // First rune on a new line: write the pending indentation.
+ if sg.hasNewline {
+ sg.buf.WriteString(strings.Repeat("\t", sg.indentationLevel))
+ sg.hasNewline = false
+ sg.existingLineLength += sg.indentationLevel
+ }
+
+ sg.existingLineLength++
+ sg.buf.WriteRune(currentRune)
+ }
+}
+
+// appendLine adds a newline.
+func (sg *sourceGenerator) appendLine() {
+ sg.append("\n")
+}
+
+// markNewScope records that a new scope was just opened at the tail of the
+// buffer, suppressing the next blank-line requirement.
+func (sg *sourceGenerator) markNewScope() {
+ sg.hasNewScope = true
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/inputsource.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/inputsource.go
new file mode 100644
index 0000000..3bd0437
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/inputsource.go
@@ -0,0 +1,224 @@
+package input
+
+import (
+ "fmt"
+)
+
// BytePosition represents the byte position in a piece of code.
type BytePosition int

// Position represents a position in an arbitrary source file.
type Position struct {
	// LineNumber is the 0-indexed line number.
	LineNumber int

	// ColumnPosition is the 0-indexed column position on the line.
	ColumnPosition int
}

// Source represents the path of a source file. It is used as an opaque
// identifier when comparing positions and ranges (see ContainsPosition).
type Source string
+
+// RangeForRunePosition returns a source range over this source file.
+func (is Source) RangeForRunePosition(runePosition int, mapper PositionMapper) SourceRange {
+ return is.RangeForRunePositions(runePosition, runePosition, mapper)
+}
+
+// PositionForRunePosition returns a source position over this source file.
+func (is Source) PositionForRunePosition(runePosition int, mapper PositionMapper) SourcePosition {
+ return runeIndexedPosition{is, mapper, runePosition}
+}
+
+// PositionFromLineAndColumn returns a source position at the given line and column in this source file.
+func (is Source) PositionFromLineAndColumn(lineNumber int, columnPosition int, mapper PositionMapper) SourcePosition {
+ return lcIndexedPosition{is, mapper, Position{lineNumber, columnPosition}}
+}
+
+// RangeForRunePositions returns a source range over this source file.
+func (is Source) RangeForRunePositions(startRune int, endRune int, mapper PositionMapper) SourceRange {
+ return sourceRange{is, runeIndexedPosition{is, mapper, startRune}, runeIndexedPosition{is, mapper, endRune}}
+}
+
+// RangeForLineAndColPositions returns a source range over this source file.
+func (is Source) RangeForLineAndColPositions(start Position, end Position, mapper PositionMapper) SourceRange {
+ return sourceRange{is, lcIndexedPosition{is, mapper, start}, lcIndexedPosition{is, mapper, end}}
+}
+
// PositionMapper defines an interface for converting rune position <-> line+col position
// under source files. All indices are 0-based.
type PositionMapper interface {
	// RunePositionToLineAndCol converts the given 0-indexed rune position under the given source file
	// into a 0-indexed line number and column position.
	RunePositionToLineAndCol(runePosition int, path Source) (int, int, error)

	// LineAndColToRunePosition converts the given 0-indexed line number and column position under the
	// given source file into a 0-indexed rune position.
	LineAndColToRunePosition(lineNumber int, colPosition int, path Source) (int, error)

	// TextForLine returns the text for the specified line number.
	TextForLine(lineNumber int, path Source) (string, error)
}

// SourceRange represents a range inside a source file.
type SourceRange interface {
	// Source is the input source for this range.
	Source() Source

	// Start is the starting position of the source range.
	Start() SourcePosition

	// End is the ending position (inclusive) of the source range. If the same as the Start,
	// this range represents a single position.
	End() SourcePosition

	// ContainsPosition returns true if the given range contains the given position.
	ContainsPosition(position SourcePosition) (bool, error)

	// AtStartPosition returns a SourceRange located only at the starting position of this range.
	AtStartPosition() SourceRange

	// String returns a (somewhat) human-readable form of the range.
	String() string
}

// SourcePosition represents a single position in a source file.
// Accessors may return an error when resolution requires a PositionMapper
// and none is available.
type SourcePosition interface {
	// Source is the input source for this position.
	Source() Source

	// RunePosition returns the 0-indexed rune position in the source file.
	RunePosition() (int, error)

	// LineAndColumn returns the 0-indexed line number and column position in the source file.
	LineAndColumn() (int, int, error)

	// LineText returns the text of the line for this position.
	LineText() (string, error)

	// String returns a (somewhat) human-readable form of the position.
	String() string
}
+
// sourceRange implements the SourceRange interface over a pair of positions.
type sourceRange struct {
	source Source
	start  SourcePosition
	end    SourcePosition
}

// Source returns the input source for this range.
func (sr sourceRange) Source() Source {
	return sr.source
}

// Start returns the starting position of the range.
func (sr sourceRange) Start() SourcePosition {
	return sr.start
}

// End returns the (inclusive) ending position of the range.
func (sr sourceRange) End() SourcePosition {
	return sr.end
}
+
+func (sr sourceRange) AtStartPosition() SourceRange {
+ return sourceRange{sr.source, sr.start, sr.end}
+}
+
// ContainsPosition returns true if this range contains the given position.
// A position from a different source is never contained; otherwise the check
// is inclusive on both ends of the rune-position interval.
func (sr sourceRange) ContainsPosition(position SourcePosition) (bool, error) {
	if position.Source() != sr.source {
		return false, nil
	}

	startRune, err := sr.start.RunePosition()
	if err != nil {
		return false, err
	}

	endRune, err := sr.end.RunePosition()
	if err != nil {
		return false, err
	}

	positionRune, err := position.RunePosition()
	if err != nil {
		return false, err
	}

	return positionRune >= startRune && positionRune <= endRune, nil
}

// String returns a human-readable "start -> end" form of the range.
func (sr sourceRange) String() string {
	return fmt.Sprintf("%v -> %v", sr.start, sr.end)
}
+
// runeIndexedPosition implements the SourcePosition interface over a rune position,
// converting to line/column lazily via the mapper.
type runeIndexedPosition struct {
	source       Source
	mapper       PositionMapper
	runePosition int
}

// Source returns the input source for this position.
func (ris runeIndexedPosition) Source() Source {
	return ris.source
}

// RunePosition returns the stored 0-indexed rune position; it never errors.
func (ris runeIndexedPosition) RunePosition() (int, error) {
	return ris.runePosition, nil
}
+
// LineAndColumn returns the 0-indexed line number and column for this position.
// Rune position 0 is hard-mapped to (0, 0) without consulting the mapper; for
// any other position a nil mapper is an error.
func (ris runeIndexedPosition) LineAndColumn() (int, int, error) {
	if ris.runePosition == 0 {
		return 0, 0, nil
	}
	if ris.mapper == nil {
		return -1, -1, fmt.Errorf("nil mapper")
	}
	return ris.mapper.RunePositionToLineAndCol(ris.runePosition, ris.source)
}

// String returns a human-readable form of the position.
func (ris runeIndexedPosition) String() string {
	return fmt.Sprintf("%s@%v (rune)", ris.source, ris.runePosition)
}
+
+func (ris runeIndexedPosition) LineText() (string, error) {
+ lineNumber, _, err := ris.LineAndColumn()
+ if err != nil {
+ return "", err
+ }
+
+ return ris.mapper.TextForLine(lineNumber, ris.source)
+}
+
// lcIndexedPosition implements the SourcePosition interface over a line and column position,
// converting to a rune position lazily via the mapper.
type lcIndexedPosition struct {
	source     Source
	mapper     PositionMapper
	lcPosition Position
}

// Source returns the input source for this position.
func (lcip lcIndexedPosition) Source() Source {
	return lcip.source
}

// String returns a human-readable form of the position.
func (lcip lcIndexedPosition) String() string {
	return fmt.Sprintf("%s@%v:%v (line/col)", lcip.source, lcip.lcPosition.LineNumber, lcip.lcPosition.ColumnPosition)
}

// RunePosition returns the rune position for this line/column. Position (0, 0)
// is hard-mapped to rune 0 without consulting the mapper; any other position
// requires a non-nil mapper.
func (lcip lcIndexedPosition) RunePosition() (int, error) {
	if lcip.lcPosition.LineNumber == 0 && lcip.lcPosition.ColumnPosition == 0 {
		return 0, nil
	}
	if lcip.mapper == nil {
		return -1, fmt.Errorf("nil mapper")
	}
	return lcip.mapper.LineAndColToRunePosition(lcip.lcPosition.LineNumber, lcip.lcPosition.ColumnPosition, lcip.source)
}

// LineAndColumn returns the stored 0-indexed line and column; it never errors.
func (lcip lcIndexedPosition) LineAndColumn() (int, int, error) {
	return lcip.lcPosition.LineNumber, lcip.lcPosition.ColumnPosition, nil
}

// LineText returns the text of the line for this position; requires a mapper.
func (lcip lcIndexedPosition) LineText() (string, error) {
	if lcip.mapper == nil {
		return "", fmt.Errorf("nil mapper")
	}
	return lcip.mapper.TextForLine(lcip.lcPosition.LineNumber, lcip.source)
}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/sourcepositionmapper.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/sourcepositionmapper.go
new file mode 100644
index 0000000..1bca03c
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/sourcepositionmapper.go
@@ -0,0 +1,95 @@
+package input
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/emirpasic/gods/trees/redblacktree"
+)
+
// SourcePositionMapper defines a helper struct for cached, faster lookup of rune position <->
// (line, column) for a specific source file.
type SourcePositionMapper struct {
	// rangeTree holds a tree that maps from rune position to a line and start position.
	rangeTree *redblacktree.Tree

	// lineMap holds a map from line number to rune positions for that line.
	lineMap map[int]inclusiveRange
}

// EmptySourcePositionMapper returns an empty source position mapper; all
// lookups against it fail with "unknown ..." errors.
func EmptySourcePositionMapper() SourcePositionMapper {
	rangeTree := redblacktree.NewWith(inclusiveComparator)
	return SourcePositionMapper{rangeTree, map[int]inclusiveRange{}}
}
+
// CreateSourcePositionMapper returns a source position mapper for the contents of a source file.
// Each line is indexed by the inclusive range [start, start+len(line)], so the slot
// occupied by the line's trailing newline is attributed to that line.
// NOTE(review): offsets here are byte-based (len of the line string) despite the
// "rune" naming elsewhere — exact for ASCII sources; confirm for multi-byte input.
func CreateSourcePositionMapper(contents []byte) SourcePositionMapper {
	lines := strings.Split(string(contents), "\n")
	rangeTree := redblacktree.NewWith(inclusiveComparator)
	lineMap := map[int]inclusiveRange{}

	currentStart := int(0)
	for index, line := range lines {
		lineEnd := currentStart + len(line)
		rangeTree.Put(inclusiveRange{currentStart, lineEnd}, lineAndStart{index, currentStart})
		lineMap[index] = inclusiveRange{currentStart, lineEnd}
		currentStart = lineEnd + 1
	}

	return SourcePositionMapper{rangeTree, lineMap}
}
+
// inclusiveRange is an inclusive [start, end] span of positions.
type inclusiveRange struct {
	start int
	end   int
}

// lineAndStart records a line number and the position at which that line starts.
type lineAndStart struct {
	lineNumber    int
	startPosition int
}

// inclusiveComparator orders inclusiveRanges by start position, with one
// deliberate twist: a first range fully contained within the second compares
// as equal. This lets a single-position probe [p, p] match the line range that
// contains p during tree lookup.
// NOTE(review): the containment test is asymmetric (only a-within-b), so it
// relies on the probe being passed as the first argument — confirm against the
// redblacktree Get implementation.
func inclusiveComparator(a, b interface{}) int {
	i1 := a.(inclusiveRange)
	i2 := b.(inclusiveRange)

	if i1.start >= i2.start && i1.end <= i2.end {
		return 0
	}

	// Widen to int64 before subtracting to avoid any overflow concern.
	diff := int64(i1.start) - int64(i2.start)

	if diff < 0 {
		return -1
	}
	if diff > 0 {
		return 1
	}
	return 0
}
+
// RunePositionToLineAndCol returns the 0-indexed line number and column position
// of the rune position in source, via a containment lookup in the range tree.
func (spm SourcePositionMapper) RunePositionToLineAndCol(runePosition int) (int, int, error) {
	ls, found := spm.rangeTree.Get(inclusiveRange{runePosition, runePosition})
	if !found {
		return 0, 0, fmt.Errorf("unknown rune position %v in source file", runePosition)
	}

	las := ls.(lineAndStart)
	return las.lineNumber, runePosition - las.startPosition, nil
}

// LineAndColToRunePosition returns the rune position of the 0-indexed line number
// and column position in source, validating that the column fits on the line.
func (spm SourcePositionMapper) LineAndColToRunePosition(lineNumber int, colPosition int) (int, error) {
	lineRuneInfo, hasLine := spm.lineMap[lineNumber]
	if !hasLine {
		return 0, fmt.Errorf("unknown line %v in source file", lineNumber)
	}

	// end == start+len(line), so a column of exactly len(line) — the slot just
	// past the text — is accepted.
	if colPosition > lineRuneInfo.end-lineRuneInfo.start {
		return 0, fmt.Errorf("column position %v not found on line %v in source file", colPosition, lineNumber)
	}

	return lineRuneInfo.start + colPosition, nil
}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flaggablelexer.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flaggablelexer.go
new file mode 100644
index 0000000..1ae99af
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flaggablelexer.go
@@ -0,0 +1,59 @@
+package lexer
+
+// FlaggableLexer wraps a lexer, automatically translating tokens based on flags, if any.
+type FlaggableLexer struct {
+ lex *Lexer // a reference to the lexer used for tokenization
+ enabledFlags map[string]transformer // flags that are enabled
+ seenDefinition bool
+ afterUseIdentifier bool
+}
+
+// NewFlaggableLexer returns a new FlaggableLexer for the given lexer.
+func NewFlaggableLexer(lex *Lexer) *FlaggableLexer {
+ return &FlaggableLexer{
+ lex: lex,
+ enabledFlags: map[string]transformer{},
+ }
+}
+
+// Close stops the lexer from running.
+func (l *FlaggableLexer) Close() {
+ l.lex.Close()
+}
+
// NextToken returns the next token found in the lexer, first updating the
// `use`-flag tracking state and then applying any enabled flag transformers.
func (l *FlaggableLexer) NextToken() Lexeme {
	nextToken := l.lex.nextToken()

	// Look for `use somefeature`
	if nextToken.Kind == TokenTypeIdentifier {
		// Only allowed until we've seen a definition of some kind.
		if !l.seenDefinition {
			if l.afterUseIdentifier {
				// The identifier after `use` names a flag; enable it if known.
				// Unknown flag names are silently ignored here.
				if transformer, ok := Flags[nextToken.Value]; ok {
					l.enabledFlags[nextToken.Value] = transformer
				}

				l.afterUseIdentifier = false
			} else {
				l.afterUseIdentifier = nextToken.Value == "use"
			}
		}
	}

	// Any definition or caveat ends the `use` prologue.
	if nextToken.Kind == TokenTypeKeyword && nextToken.Value == "definition" {
		l.seenDefinition = true
	}
	if nextToken.Kind == TokenTypeKeyword && nextToken.Value == "caveat" {
		l.seenDefinition = true
	}

	// Apply enabled flag transformers; the first that rewrites the token wins.
	// NOTE(review): map iteration order is nondeterministic — harmless while
	// transformers handle disjoint token values, but worth confirming.
	for _, handler := range l.enabledFlags {
		updated, ok := handler(nextToken)
		if ok {
			return updated
		}
	}

	return nextToken
}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flags.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flags.go
new file mode 100644
index 0000000..3bfbbde
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flags.go
@@ -0,0 +1,26 @@
+package lexer
+
+// FlagExpiration indicates that `expiration` is supported as a first-class
+// feature in the schema.
+const FlagExpiration = "expiration"
+
+type transformer func(lexeme Lexeme) (Lexeme, bool)
+
+// Flags is a map of flag names to their corresponding transformers.
+var Flags = map[string]transformer{
+ FlagExpiration: func(lexeme Lexeme) (Lexeme, bool) {
+ // `expiration` becomes a keyword.
+ if lexeme.Kind == TokenTypeIdentifier && lexeme.Value == "expiration" {
+ lexeme.Kind = TokenTypeKeyword
+ return lexeme, true
+ }
+
+ // `and` becomes a keyword.
+ if lexeme.Kind == TokenTypeIdentifier && lexeme.Value == "and" {
+ lexeme.Kind = TokenTypeKeyword
+ return lexeme, true
+ }
+
+ return lexeme, false
+ },
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex.go
new file mode 100644
index 0000000..a09df50
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex.go
@@ -0,0 +1,231 @@
+// Based on design first introduced in: http://blog.golang.org/two-go-talks-lexical-scanning-in-go-and
+// Portions copied and modified from: https://github.com/golang/go/blob/master/src/text/template/parse/lex.go
+
+package lexer
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+ "unicode/utf8"
+
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
// EOFRUNE is the sentinel rune returned by next() once the input is exhausted.
const EOFRUNE = -1

// createLexer creates a new scanner for the input string and spawns its
// tokenization goroutine; tokens are delivered over the tokens channel.
func createLexer(source input.Source, input string) *Lexer {
	l := &Lexer{
		source: source,
		input:  input,
		tokens: make(chan Lexeme),
		closed: make(chan struct{}),
	}
	go l.run()
	return l
}
+
// run runs the state machine for the lexer: it repeatedly invokes the current
// state function until one returns nil, then closes the tokens channel. The
// state field is lock-guarded because Close may nil it from another goroutine.
func (l *Lexer) run() {
	defer func() {
		close(l.tokens)
	}()
	l.withLock(func() {
		l.state = lexSource
	})
	var state stateFn
	for {
		l.withRLock(func() {
			state = l.state
		})
		if state == nil {
			break
		}
		next := state(l)
		l.withLock(func() {
			l.state = next
		})
	}
}

// Close stops the lexer from running by signalling the closed channel and
// clearing the current state.
// NOTE(review): a second Close would panic on close(l.closed) — confirm
// callers invoke Close at most once.
func (l *Lexer) Close() {
	close(l.closed)
	l.withLock(func() {
		l.state = nil
	})
}
+
// withLock runs f protected by l's write lock.
func (l *Lexer) withLock(f func()) {
	l.Lock()
	defer l.Unlock()
	f()
}

// withRLock runs f protected by l's read lock.
func (l *Lexer) withRLock(f func()) {
	l.RLock()
	defer l.RUnlock()
	f()
}
+
// Lexeme represents a token returned from scanning the contents of a file.
type Lexeme struct {
	Kind     TokenType          // The type of this lexeme.
	Position input.BytePosition // The starting position (byte offset) of this token in the input string.
	Value    string             // The textual value of this token.
	Error    string             // The error associated with the lexeme, if any.
}

// stateFn represents the state of the scanner as a function that returns the next state.
type stateFn func(*Lexer) stateFn

// Lexer holds the state of the scanner.
type Lexer struct {
	sync.RWMutex
	state stateFn // the next lexing function to enter. GUARDED_BY(RWMutex)

	source              input.Source       // the name of the input; used only for error reports
	input               string             // the string being scanned
	pos                 input.BytePosition // current position in the input
	start               input.BytePosition // start position of this token
	width               input.BytePosition // width of last rune read from input
	lastPos             input.BytePosition // position of most recent token returned by nextToken
	tokens              chan Lexeme        // channel of scanned lexemes
	currentToken        Lexeme             // The current token if any
	lastNonIgnoredToken Lexeme             // The last token returned that is non-whitespace and non-comment
	closed              chan struct{}      // Holds the closed channel
}
+
// nextToken returns the next token from the input, blocking until the lexing
// goroutine delivers one.
func (l *Lexer) nextToken() Lexeme {
	token := <-l.tokens
	l.lastPos = token.Position
	return token
}

// next returns the next rune in the input, advancing the position, or EOFRUNE
// when the input is exhausted.
func (l *Lexer) next() rune {
	if int(l.pos) >= len(l.input) {
		l.width = 0
		return EOFRUNE
	}
	r, w := utf8.DecodeRuneInString(l.input[l.pos:])
	l.width = input.BytePosition(w)
	l.pos += l.width
	return r
}

// peek returns but does not consume the next rune in the input.
func (l *Lexer) peek() rune {
	r := l.next()
	l.backup()
	return r
}

// backup steps back one rune. Can only be called once per call of next,
// because it rewinds by the width of the most recently read rune.
func (l *Lexer) backup() {
	l.pos -= l.width
}

// value returns the text of the token currently being scanned (from start to
// the current position).
func (l *Lexer) value() string {
	return l.input[l.start:l.pos]
}
+
// emit passes a token back to the client over the tokens channel, or aborts if
// the lexer has been closed.
func (l *Lexer) emit(t TokenType) {
	currentToken := Lexeme{t, l.start, l.value(), ""}

	// Track the last meaningful token; used by the synthetic-semicolon rule.
	if t != TokenTypeWhitespace && t != TokenTypeMultilineComment && t != TokenTypeSinglelineComment {
		l.lastNonIgnoredToken = currentToken
	}

	select {
	case l.tokens <- currentToken:
		l.currentToken = currentToken
		l.start = l.pos

	case <-l.closed:
		return
	}
}

// errorf returns an error token and terminates the scan by passing
// back a nil pointer that will be the next state, terminating l.nextToken.
// NOTE(review): unlike emit, this send does not select on l.closed, so it
// could block if the lexer is closed before the error is read — confirm that
// callers always drain the token channel.
func (l *Lexer) errorf(currentRune rune, format string, args ...interface{}) stateFn {
	l.tokens <- Lexeme{TokenTypeError, l.start, string(currentRune), fmt.Sprintf(format, args...)}
	return nil
}
+
+// peekValue looks forward for the given value string. If found, returns true.
+func (l *Lexer) peekValue(value string) bool {
+ for index, runeValue := range value {
+ r := l.next()
+ if r != runeValue {
+ for j := 0; j <= index; j++ {
+ l.backup()
+ }
+ return false
+ }
+ }
+
+ for i := 0; i < len(value); i++ {
+ l.backup()
+ }
+
+ return true
+}
+
// accept consumes the next rune if it's from the valid set, returning whether
// it was consumed; on a miss the rune is un-read.
func (l *Lexer) accept(valid string) bool {
	if nextRune := l.next(); strings.ContainsRune(valid, nextRune) {
		return true
	}
	l.backup()
	return false
}
+
+// acceptString consumes the full given string, if the next tokens in the stream.
+func (l *Lexer) acceptString(value string) bool {
+ for index, runeValue := range value {
+ if l.next() != runeValue {
+ for i := 0; i <= index; i++ {
+ l.backup()
+ }
+
+ return false
+ }
+ }
+
+ return true
+}
+
// lexSource scans until EOFRUNE by delegating to the main entrypoint loop
// (lexerEntrypoint, defined alongside the token definitions).
func lexSource(l *Lexer) stateFn {
	return lexerEntrypoint(l)
}

// checkFn returns whether a rune matches for continue looping.
type checkFn func(r rune) (bool, error)

// buildLexUntil returns a state function that consumes runes while checker
// approves them, emits the accumulated text as a token of findType, and
// returns to lexSource. The first rejected rune is left unconsumed.
func buildLexUntil(findType TokenType, checker checkFn) stateFn {
	return func(l *Lexer) stateFn {
		for {
			r := l.next()
			isValid, err := checker(r)
			if err != nil {
				return l.errorf(r, "%v", err)
			}
			if !isValid {
				l.backup()
				break
			}
		}

		l.emit(findType)
		return lexSource
	}
}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex_def.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex_def.go
new file mode 100644
index 0000000..d8cbe8c
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex_def.go
@@ -0,0 +1,351 @@
+//go:generate go run golang.org/x/tools/cmd/stringer -type=TokenType
+
+package lexer
+
+import (
+ "unicode"
+
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
// Lex creates a new scanner for the input string.
// NOTE(review): the `input` parameter shadows the imported `input` package
// within the body; harmless here, but worth renaming upstream.
func Lex(source input.Source, input string) *Lexer {
	return createLexer(source, input)
}

// TokenType identifies the type of lexer lexemes.
type TokenType int
+
// NOTE: values are positional (iota) — tokentype_string.go must be regenerated
// via `go generate` whenever entries are added, removed, or reordered.
const (
	TokenTypeError TokenType = iota // error occurred; value is text of error

	// Synthetic semicolon
	TokenTypeSyntheticSemicolon

	TokenTypeEOF
	TokenTypeWhitespace
	TokenTypeSinglelineComment
	TokenTypeMultilineComment
	TokenTypeNewline

	TokenTypeKeyword    // interface
	TokenTypeIdentifier // helloworld
	TokenTypeNumber     // 123

	TokenTypeLeftBrace  // {
	TokenTypeRightBrace // }
	TokenTypeLeftParen  // (
	TokenTypeRightParen // )

	TokenTypePipe  // |
	TokenTypePlus  // +
	TokenTypeMinus // -
	TokenTypeAnd   // &
	TokenTypeDiv   // /

	TokenTypeEquals     // =
	TokenTypeColon      // :
	TokenTypeSemicolon  // ;
	TokenTypeRightArrow // ->
	TokenTypeHash       // #
	TokenTypeEllipsis   // ...
	TokenTypeStar       // *

	// Additional tokens for CEL: https://github.com/google/cel-spec/blob/master/doc/langdef.md#syntax
	TokenTypeQuestionMark       // ?
	TokenTypeConditionalOr      // ||
	TokenTypeConditionalAnd     // &&
	TokenTypeExclamationPoint   // !
	TokenTypeLeftBracket        // [
	TokenTypeRightBracket       // ]
	TokenTypePeriod             // .
	TokenTypeComma              // ,
	TokenTypePercent            // %
	TokenTypeLessThan           // <
	TokenTypeGreaterThan        // >
	TokenTypeLessThanOrEqual    // <=
	TokenTypeGreaterThanOrEqual // >=
	TokenTypeEqualEqual         // ==
	TokenTypeNotEqual           // !=
	TokenTypeString             // "...", '...', """...""", '''...'''
)
+
// keywords contains the full set of reserved keywords supported by the DSL.
var keywords = map[string]struct{}{
	"definition": {},
	"caveat":     {},
	"relation":   {},
	"permission": {},
	"nil":        {},
	"with":       {},
}

// IsKeyword returns whether the specified input string is a reserved keyword.
func IsKeyword(candidate string) bool {
	if _, found := keywords[candidate]; found {
		return true
	}
	return false
}
+
// syntheticPredecessors contains the full set of token types after which, if a newline is found,
// we emit a synthetic semicolon rather than a normal newline token (optional
// statement termination, analogous to Go's semicolon insertion rule).
var syntheticPredecessors = map[TokenType]bool{
	TokenTypeIdentifier: true,
	TokenTypeKeyword:    true,

	TokenTypeRightBrace: true,
	TokenTypeRightParen: true,

	TokenTypeStar: true,
}
+
// lexerEntrypoint is the main lexing loop: it scans runes until EOFRUNE,
// emitting one token per lexeme and dispatching to sub-lexers for identifiers,
// string literals, and comments. Two-character operators are recognized
// greedily via acceptString/accept before falling back to their single-rune
// forms.
func lexerEntrypoint(l *Lexer) stateFn {
Loop:
	for {
		switch r := l.next(); {
		case r == EOFRUNE:
			break Loop

		case r == '{':
			l.emit(TokenTypeLeftBrace)

		case r == '}':
			l.emit(TokenTypeRightBrace)

		case r == '(':
			l.emit(TokenTypeLeftParen)

		case r == ')':
			l.emit(TokenTypeRightParen)

		case r == '+':
			l.emit(TokenTypePlus)

		case r == '|':
			if l.acceptString("|") {
				l.emit(TokenTypeConditionalOr)
			} else {
				l.emit(TokenTypePipe)
			}

		case r == '&':
			if l.acceptString("&") {
				l.emit(TokenTypeConditionalAnd)
			} else {
				l.emit(TokenTypeAnd)
			}

		case r == '?':
			l.emit(TokenTypeQuestionMark)

		case r == '!':
			if l.acceptString("=") {
				l.emit(TokenTypeNotEqual)
			} else {
				l.emit(TokenTypeExclamationPoint)
			}

		case r == '[':
			l.emit(TokenTypeLeftBracket)

		case r == ']':
			l.emit(TokenTypeRightBracket)

		case r == '%':
			l.emit(TokenTypePercent)

		case r == '<':
			if l.acceptString("=") {
				l.emit(TokenTypeLessThanOrEqual)
			} else {
				l.emit(TokenTypeLessThan)
			}

		case r == '>':
			if l.acceptString("=") {
				l.emit(TokenTypeGreaterThanOrEqual)
			} else {
				l.emit(TokenTypeGreaterThan)
			}

		case r == ',':
			l.emit(TokenTypeComma)

		case r == '=':
			if l.acceptString("=") {
				l.emit(TokenTypeEqualEqual)
			} else {
				l.emit(TokenTypeEquals)
			}

		case r == ':':
			l.emit(TokenTypeColon)

		case r == ';':
			l.emit(TokenTypeSemicolon)

		case r == '#':
			l.emit(TokenTypeHash)

		case r == '*':
			l.emit(TokenTypeStar)

		case r == '.':
			// `...` (ellipsis) vs a single period.
			if l.acceptString("..") {
				l.emit(TokenTypeEllipsis)
			} else {
				l.emit(TokenTypePeriod)
			}

		case r == '-':
			// `->` (arrow) vs minus.
			if l.accept(">") {
				l.emit(TokenTypeRightArrow)
			} else {
				l.emit(TokenTypeMinus)
			}

		case isSpace(r):
			l.emit(TokenTypeWhitespace)

		case isNewline(r):
			// If the previous token matches the synthetic semicolon list,
			// we emit a synthetic semicolon instead of a simple newline.
			if _, ok := syntheticPredecessors[l.lastNonIgnoredToken.Kind]; ok {
				l.emit(TokenTypeSyntheticSemicolon)
			} else {
				l.emit(TokenTypeNewline)
			}

		case isAlphaNumeric(r):
			l.backup()
			return lexIdentifierOrKeyword

		case r == '\'' || r == '"':
			l.backup()
			return lexStringLiteral

		case r == '/':
			// Check for comments.
			if l.peekValue("/") {
				l.backup()
				return lexSinglelineComment
			}

			if l.peekValue("*") {
				l.backup()
				return lexMultilineComment
			}

			l.emit(TokenTypeDiv)
		default:
			return l.errorf(r, "unrecognized character at this location: %#U", r)
		}
	}

	l.emit(TokenTypeEOF)
	return nil
}
+
+// lexStringLiteral scan until the close of the string literal or EOFRUNE
+func lexStringLiteral(l *Lexer) stateFn {
+ allowNewlines := false
+ terminator := ""
+
+ if l.acceptString(`"""`) {
+ terminator = `"""`
+ allowNewlines = true
+ } else if l.acceptString(`'''`) {
+ terminator = `"""`
+ allowNewlines = true
+ } else if l.acceptString(`"`) {
+ terminator = `"`
+ } else if l.acceptString(`'`) {
+ terminator = `'`
+ }
+
+ for {
+ if l.peekValue(terminator) {
+ l.acceptString(terminator)
+ l.emit(TokenTypeString)
+ return lexSource
+ }
+
+ // Otherwise, consume until we hit EOFRUNE.
+ r := l.next()
+ if !allowNewlines && isNewline(r) {
+ return l.errorf(r, "Unterminated string")
+ }
+
+ if r == EOFRUNE {
+ return l.errorf(r, "Unterminated string")
+ }
+ }
+}
+
// lexSinglelineComment scans a `//` comment, consuming until (but not
// including) the next newline or EOFRUNE.
func lexSinglelineComment(l *Lexer) stateFn {
	checker := func(r rune) (bool, error) {
		result := r == EOFRUNE || isNewline(r)
		return !result, nil
	}

	l.acceptString("//")
	return buildLexUntil(TokenTypeSinglelineComment, checker)
}
+
// lexMultilineComment scans a `/* ... */` comment until its closing `*/`, or
// errors at EOFRUNE if the comment is unterminated.
func lexMultilineComment(l *Lexer) stateFn {
	l.acceptString("/*")
	for {
		// Check for the end of the multiline comment.
		if l.peekValue("*/") {
			l.acceptString("*/")
			l.emit(TokenTypeMultilineComment)
			return lexSource
		}

		// Otherwise, consume until we hit EOFRUNE.
		r := l.next()
		if r == EOFRUNE {
			return l.errorf(r, "Unterminated multiline comment")
		}
	}
}
+
// lexIdentifierOrKeyword scans an identifier, emitting it as a keyword token
// when the accumulated text is in the reserved keywords set.
func lexIdentifierOrKeyword(l *Lexer) stateFn {
	// Consume alphanumeric/underscore runes; the first non-matching rune is
	// left unconsumed (peek does not advance).
	for {
		if !isAlphaNumeric(l.peek()) {
			break
		}

		l.next()
	}

	_, isKeyword := keywords[l.value()]

	switch {
	case isKeyword:
		l.emit(TokenTypeKeyword)

	default:
		l.emit(TokenTypeIdentifier)
	}

	return lexSource
}
+
// isSpace reports whether r is a space character (space or tab).
func isSpace(r rune) bool {
	switch r {
	case ' ', '\t':
		return true
	}
	return false
}

// isNewline reports whether r is a newline character (CR or LF).
func isNewline(r rune) bool {
	switch r {
	case '\r', '\n':
		return true
	}
	return false
}

// isAlphaNumeric reports whether r is an alphabetic, digit, or underscore rune.
func isAlphaNumeric(r rune) bool {
	return unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_'
}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/tokentype_string.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/tokentype_string.go
new file mode 100644
index 0000000..79f3585
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/tokentype_string.go
@@ -0,0 +1,64 @@
+// Code generated by "stringer -type=TokenType"; DO NOT EDIT.
+
+package lexer
+
+import "strconv"
+
// NOTE(review): this file is generated by stringer — regenerate via
// `go generate` rather than hand-editing if TokenType changes.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[TokenTypeError-0]
	_ = x[TokenTypeSyntheticSemicolon-1]
	_ = x[TokenTypeEOF-2]
	_ = x[TokenTypeWhitespace-3]
	_ = x[TokenTypeSinglelineComment-4]
	_ = x[TokenTypeMultilineComment-5]
	_ = x[TokenTypeNewline-6]
	_ = x[TokenTypeKeyword-7]
	_ = x[TokenTypeIdentifier-8]
	_ = x[TokenTypeNumber-9]
	_ = x[TokenTypeLeftBrace-10]
	_ = x[TokenTypeRightBrace-11]
	_ = x[TokenTypeLeftParen-12]
	_ = x[TokenTypeRightParen-13]
	_ = x[TokenTypePipe-14]
	_ = x[TokenTypePlus-15]
	_ = x[TokenTypeMinus-16]
	_ = x[TokenTypeAnd-17]
	_ = x[TokenTypeDiv-18]
	_ = x[TokenTypeEquals-19]
	_ = x[TokenTypeColon-20]
	_ = x[TokenTypeSemicolon-21]
	_ = x[TokenTypeRightArrow-22]
	_ = x[TokenTypeHash-23]
	_ = x[TokenTypeEllipsis-24]
	_ = x[TokenTypeStar-25]
	_ = x[TokenTypeQuestionMark-26]
	_ = x[TokenTypeConditionalOr-27]
	_ = x[TokenTypeConditionalAnd-28]
	_ = x[TokenTypeExclamationPoint-29]
	_ = x[TokenTypeLeftBracket-30]
	_ = x[TokenTypeRightBracket-31]
	_ = x[TokenTypePeriod-32]
	_ = x[TokenTypeComma-33]
	_ = x[TokenTypePercent-34]
	_ = x[TokenTypeLessThan-35]
	_ = x[TokenTypeGreaterThan-36]
	_ = x[TokenTypeLessThanOrEqual-37]
	_ = x[TokenTypeGreaterThanOrEqual-38]
	_ = x[TokenTypeEqualEqual-39]
	_ = x[TokenTypeNotEqual-40]
	_ = x[TokenTypeString-41]
}

const _TokenType_name = "TokenTypeErrorTokenTypeSyntheticSemicolonTokenTypeEOFTokenTypeWhitespaceTokenTypeSinglelineCommentTokenTypeMultilineCommentTokenTypeNewlineTokenTypeKeywordTokenTypeIdentifierTokenTypeNumberTokenTypeLeftBraceTokenTypeRightBraceTokenTypeLeftParenTokenTypeRightParenTokenTypePipeTokenTypePlusTokenTypeMinusTokenTypeAndTokenTypeDivTokenTypeEqualsTokenTypeColonTokenTypeSemicolonTokenTypeRightArrowTokenTypeHashTokenTypeEllipsisTokenTypeStarTokenTypeQuestionMarkTokenTypeConditionalOrTokenTypeConditionalAndTokenTypeExclamationPointTokenTypeLeftBracketTokenTypeRightBracketTokenTypePeriodTokenTypeCommaTokenTypePercentTokenTypeLessThanTokenTypeGreaterThanTokenTypeLessThanOrEqualTokenTypeGreaterThanOrEqualTokenTypeEqualEqualTokenTypeNotEqualTokenTypeString"

var _TokenType_index = [...]uint16{0, 14, 41, 53, 72, 98, 123, 139, 155, 174, 189, 207, 226, 244, 263, 276, 289, 303, 315, 327, 342, 356, 374, 393, 406, 423, 436, 457, 479, 502, 527, 547, 568, 583, 597, 613, 630, 650, 674, 701, 720, 737, 752}

// String returns the generated name of the TokenType, or a numeric fallback
// for out-of-range values.
func (i TokenType) String() string {
	if i < 0 || i >= TokenType(len(_TokenType_index)-1) {
		return "TokenType(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _TokenType_name[_TokenType_index[i]:_TokenType_index[i+1]]
}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/nodestack.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/nodestack.go
new file mode 100644
index 0000000..0d49e09
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/nodestack.go
@@ -0,0 +1,35 @@
+package parser
+
// nodeStack is a simple singly linked LIFO stack of AST nodes.
type nodeStack struct {
	top  *element
	size int
}

// element is a single linked-list entry of the stack.
type element struct {
	value AstNode
	next  *element
}

// topValue returns the value on top of the stack without removing it, or nil
// if the stack is empty.
func (s *nodeStack) topValue() AstNode {
	if s.size == 0 {
		return nil
	}

	return s.top.value
}

// Push pushes a node onto the stack.
func (s *nodeStack) push(value AstNode) {
	s.top = &element{value, s.top}
	s.size++
}

// Pop removes the node from the stack and returns it, or returns nil if the
// stack is empty.
func (s *nodeStack) pop() (value AstNode) {
	if s.size > 0 {
		value, s.top = s.top.value, s.top.next
		s.size--
		return
	}
	return nil
}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser.go
new file mode 100644
index 0000000..100ad9b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser.go
@@ -0,0 +1,658 @@
+// parser package defines the parser for the Authzed Schema DSL.
+package parser
+
+import (
+	"sort"
+	"strings"
+
+	"golang.org/x/exp/maps"
+
+	"github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+	"github.com/authzed/spicedb/pkg/schemadsl/input"
+	"github.com/authzed/spicedb/pkg/schemadsl/lexer"
+)
+
+// Parse parses the given Schema DSL source into a parse tree.
+//
+// builder constructs the AST nodes, source names the input for error
+// reporting, and input is the schema text itself. Parse errors are attached
+// to the returned tree as error nodes rather than returned as a Go error.
+func Parse(builder NodeBuilder, source input.Source, input string) AstNode {
+	lx := lexer.Lex(source, input)
+	parser := buildParser(lx, builder, source, input)
+	// Release lexer resources once parsing completes.
+	defer parser.close()
+	return parser.consumeTopLevel()
+}
+
+// ignoredTokenTypes are those tokens ignored when parsing. Comment tokens
+// are skipped here but still collected by consumeToken and attached to the
+// next significant token.
+var ignoredTokenTypes = map[lexer.TokenType]bool{
+	lexer.TokenTypeWhitespace:        true,
+	lexer.TokenTypeNewline:           true,
+	lexer.TokenTypeSinglelineComment: true,
+	lexer.TokenTypeMultilineComment:  true,
+}
+
+// consumeTopLevel attempts to consume the top-level definitions.
+//
+// The returned root node is a NodeTypeFile whose children are the parsed
+// `use` flags, definitions and caveats found in the schema.
+func (p *sourceParser) consumeTopLevel() AstNode {
+	rootNode := p.startNode(dslshape.NodeTypeFile)
+	defer p.mustFinishNode()
+
+	// Start at the first token.
+	p.consumeToken()
+
+	// A lexing failure surfaces as an error token; report it and stop.
+	if p.currentToken.Kind == lexer.TokenTypeError {
+		p.emitErrorf("%s", p.currentToken.Value)
+		return rootNode
+	}
+
+	// Tracks whether any definition/caveat has been consumed, since `use`
+	// flags must appear before all definitions.
+	hasSeenDefinition := false
+
+Loop:
+	for {
+		if p.isToken(lexer.TokenTypeEOF) {
+			break Loop
+		}
+
+		// Consume a statement terminator if one was found.
+		p.tryConsumeStatementTerminator()
+
+		if p.isToken(lexer.TokenTypeEOF) {
+			break Loop
+		}
+
+		// The top level of the DSL is a set of definitions and caveats:
+		// definition foobar { ... }
+		// caveat somecaveat (...) { ... }
+
+		switch {
+		case p.isIdentifier("use"):
+			rootNode.Connect(dslshape.NodePredicateChild, p.consumeUseFlag(hasSeenDefinition))
+
+		case p.isKeyword("definition"):
+			hasSeenDefinition = true
+			rootNode.Connect(dslshape.NodePredicateChild, p.consumeDefinition())
+
+		case p.isKeyword("caveat"):
+			hasSeenDefinition = true
+			rootNode.Connect(dslshape.NodePredicateChild, p.consumeCaveat())
+
+		default:
+			// Unknown construct: emit an error and abort the parse loop.
+			p.emitErrorf("Unexpected token at root level: %v", p.currentToken.Kind)
+			break Loop
+		}
+	}
+
+	return rootNode
+}
+
+// consumeCaveat attempts to consume a single caveat definition.
+// ```caveat somecaveat(param1 type, param2 type) { ... }```
+//
+// On any consumption failure the partially-built node is returned as-is;
+// the error has already been attached by the consume helper.
+func (p *sourceParser) consumeCaveat() AstNode {
+	defNode := p.startNode(dslshape.NodeTypeCaveatDefinition)
+	defer p.mustFinishNode()
+
+	// caveat ...
+	p.consumeKeyword("caveat")
+	caveatName, ok := p.consumeTypePath()
+	if !ok {
+		return defNode
+	}
+
+	defNode.MustDecorate(dslshape.NodeCaveatDefinitionPredicateName, caveatName)
+
+	// Parameters:
+	// (
+	_, ok = p.consume(lexer.TokenTypeLeftParen)
+	if !ok {
+		return defNode
+	}
+
+	// One or more comma-separated `name type` parameters.
+	for {
+		paramNode, ok := p.consumeCaveatParameter()
+		if !ok {
+			return defNode
+		}
+
+		defNode.Connect(dslshape.NodeCaveatDefinitionPredicateParameters, paramNode)
+		if _, ok := p.tryConsume(lexer.TokenTypeComma); !ok {
+			break
+		}
+	}
+
+	// )
+	_, ok = p.consume(lexer.TokenTypeRightParen)
+	if !ok {
+		return defNode
+	}
+
+	// {
+	_, ok = p.consume(lexer.TokenTypeLeftBrace)
+	if !ok {
+		return defNode
+	}
+
+	// The caveat body is a single CEL expression.
+	exprNode, ok := p.consumeCaveatExpression()
+	if !ok {
+		return defNode
+	}
+
+	defNode.Connect(dslshape.NodeCaveatDefinitionPredicateExpession, exprNode)
+
+	// }
+	_, ok = p.consume(lexer.TokenTypeRightBrace)
+	if !ok {
+		return defNode
+	}
+
+	return defNode
+}
+
+// consumeCaveatExpression consumes the raw text of a caveat's CEL expression,
+// up to (but not including) the caveat's closing brace. The text is stored on
+// the node as a string for CEL's own parser to process later.
+func (p *sourceParser) consumeCaveatExpression() (AstNode, bool) {
+	exprNode := p.startNode(dslshape.NodeTypeCaveatExpression)
+	defer p.mustFinishNode()
+
+	// Special Logic Note: Since CEL is its own language, we consume here until we have a matching
+	// close brace, and then pass ALL the found tokens to CEL's own parser to attach the expression
+	// here.
+	braceDepth := 1 // Starting at 1 from the open brace above
+	// startToken/endToken delimit the first and last tokens of the expression;
+	// endToken is non-nil whenever startToken is, since they are set together.
+	var startToken *commentedLexeme
+	var endToken *commentedLexeme
+consumer:
+	for {
+		currentToken := p.currentToken
+
+		switch currentToken.Kind {
+		case lexer.TokenTypeLeftBrace:
+			braceDepth++
+
+		case lexer.TokenTypeRightBrace:
+			// The brace matching the caveat's opening brace ends the expression;
+			// it is left unconsumed for the caller.
+			if braceDepth == 1 {
+				break consumer
+			}
+
+			braceDepth--
+
+		case lexer.TokenTypeError:
+			break consumer
+
+		case lexer.TokenTypeEOF:
+			break consumer
+		}
+
+		if startToken == nil {
+			startToken = &currentToken
+		}
+
+		endToken = &currentToken
+		p.consumeToken()
+	}
+
+	if startToken == nil {
+		p.emitErrorf("missing caveat expression")
+		return exprNode, false
+	}
+
+	// Slice the original input between the first and last tokens so the
+	// expression text is preserved verbatim (including interior whitespace).
+	caveatExpression := p.input[startToken.Position : int(endToken.Position)+len(endToken.Value)]
+	exprNode.MustDecorate(dslshape.NodeCaveatExpressionPredicateExpression, caveatExpression)
+	return exprNode, true
+}
+
+// consumeCaveatParameter attempts to consume a caveat parameter.
+// ```(paramName paramtype)```
+//
+// Returns the parameter node and whether the parameter name was consumed.
+func (p *sourceParser) consumeCaveatParameter() (AstNode, bool) {
+	paramNode := p.startNode(dslshape.NodeTypeCaveatParameter)
+	defer p.mustFinishNode()
+
+	name, ok := p.consumeIdentifier()
+	if !ok {
+		return paramNode, false
+	}
+
+	paramNode.MustDecorate(dslshape.NodeCaveatParameterPredicateName, name)
+	paramNode.Connect(dslshape.NodeCaveatParameterPredicateType, p.consumeCaveatTypeReference())
+	return paramNode, true
+}
+
+// consumeCaveatTypeReference attempts to consume a caveat type reference.
+// ```typeName<childType>```
+//
+// Child types are themselves type references, so generics nest recursively
+// (e.g. `map<list<int>>`).
+func (p *sourceParser) consumeCaveatTypeReference() AstNode {
+	typeRefNode := p.startNode(dslshape.NodeTypeCaveatTypeReference)
+	defer p.mustFinishNode()
+
+	name, ok := p.consumeIdentifier()
+	if !ok {
+		return typeRefNode
+	}
+
+	typeRefNode.MustDecorate(dslshape.NodeCaveatTypeReferencePredicateType, name)
+
+	// Check for child type(s).
+	// <
+	if _, ok := p.tryConsume(lexer.TokenTypeLessThan); !ok {
+		return typeRefNode
+	}
+
+	// Comma-separated child type references.
+	for {
+		childTypeRef := p.consumeCaveatTypeReference()
+		typeRefNode.Connect(dslshape.NodeCaveatTypeReferencePredicateChildTypes, childTypeRef)
+		if _, ok := p.tryConsume(lexer.TokenTypeComma); !ok {
+			break
+		}
+	}
+
+	// >
+	p.consume(lexer.TokenTypeGreaterThan)
+	return typeRefNode
+}
+
+// consumeUseFlag attempts to consume a use flag.
+// ``` use flagname ```
+//
+// afterDefinition indicates whether a definition or caveat has already been
+// seen; `use` must precede all definitions, so an error is emitted in that
+// case (after consuming the expression, so parsing continues cleanly).
+func (p *sourceParser) consumeUseFlag(afterDefinition bool) AstNode {
+	useNode := p.startNode(dslshape.NodeTypeUseFlag)
+	defer p.mustFinishNode()
+
+	// consume the `use`
+	p.consumeIdentifier()
+
+	// The flag name may lex as either an identifier or a keyword.
+	var useFlag string
+	if p.isToken(lexer.TokenTypeIdentifier) {
+		useFlag, _ = p.consumeIdentifier()
+	} else {
+		useName, ok := p.consumeVariableKeyword()
+		if !ok {
+			return useNode
+		}
+		useFlag = useName
+	}
+
+	if _, ok := lexer.Flags[useFlag]; !ok {
+		// Sort the known flags so the error message is deterministic
+		// (Go map iteration order is randomized).
+		knownFlags := maps.Keys(lexer.Flags)
+		sort.Strings(knownFlags)
+		p.emitErrorf("Unknown use flag: `%s`. Options are: %s", useFlag, strings.Join(knownFlags, ", "))
+		return useNode
+	}
+
+	useNode.MustDecorate(dslshape.NodeUseFlagPredicateName, useFlag)
+
+	// NOTE: we conduct this check in `consumeFlag` rather than at
+	// the callsite to keep the callsite clean.
+	// We also do the check after consumption to ensure that the parser continues
+	// moving past the use expression.
+	if afterDefinition {
+		p.emitErrorf("`use` expressions must be declared before any definition")
+		return useNode
+	}
+
+	return useNode
+}
+
+// consumeDefinition attempts to consume a single schema definition.
+// ```definition somedef { ... }```
+//
+// The body may contain any mix of relations and permissions, each ended by
+// a statement terminator (semicolon or synthetic newline semicolon).
+func (p *sourceParser) consumeDefinition() AstNode {
+	defNode := p.startNode(dslshape.NodeTypeDefinition)
+	defer p.mustFinishNode()
+
+	// definition ...
+	p.consumeKeyword("definition")
+	definitionName, ok := p.consumeTypePath()
+	if !ok {
+		return defNode
+	}
+
+	defNode.MustDecorate(dslshape.NodeDefinitionPredicateName, definitionName)
+
+	// {
+	_, ok = p.consume(lexer.TokenTypeLeftBrace)
+	if !ok {
+		return defNode
+	}
+
+	// Relations and permissions.
+	for {
+		// }
+		if _, ok := p.tryConsume(lexer.TokenTypeRightBrace); ok {
+			break
+		}
+
+		// relation ...
+		// permission ...
+		switch {
+		case p.isKeyword("relation"):
+			defNode.Connect(dslshape.NodePredicateChild, p.consumeRelation())
+
+		case p.isKeyword("permission"):
+			defNode.Connect(dslshape.NodePredicateChild, p.consumePermission())
+		}
+
+		// Every member (or unrecognized token) must be followed by a
+		// statement terminator; otherwise abort the definition body.
+		ok := p.consumeStatementTerminator()
+		if !ok {
+			break
+		}
+	}
+
+	return defNode
+}
+
+// consumeRelation consumes a relation.
+// ```relation foo: sometype```
+//
+// The allowed-types clause after the colon is a pipe-delimited type
+// reference list consumed by consumeTypeReference.
+func (p *sourceParser) consumeRelation() AstNode {
+	relNode := p.startNode(dslshape.NodeTypeRelation)
+	defer p.mustFinishNode()
+
+	// relation ...
+	p.consumeKeyword("relation")
+	relationName, ok := p.consumeIdentifier()
+	if !ok {
+		return relNode
+	}
+
+	relNode.MustDecorate(dslshape.NodePredicateName, relationName)
+
+	// :
+	_, ok = p.consume(lexer.TokenTypeColon)
+	if !ok {
+		return relNode
+	}
+
+	// Relation allowed type(s).
+	relNode.Connect(dslshape.NodeRelationPredicateAllowedTypes, p.consumeTypeReference())
+
+	return relNode
+}
+
+// consumeTypeReference consumes a reference to a type or types of relations.
+// ```sometype | anothertype | anothertype:* ```
+func (p *sourceParser) consumeTypeReference() AstNode {
+	refNode := p.startNode(dslshape.NodeTypeTypeReference)
+	defer p.mustFinishNode()
+
+	// One or more pipe-delimited specific type references.
+	for {
+		refNode.Connect(dslshape.NodeTypeReferencePredicateType, p.consumeSpecificTypeWithCaveat())
+		if _, ok := p.tryConsume(lexer.TokenTypePipe); !ok {
+			break
+		}
+	}
+
+	return refNode
+}
+
+// tryConsumeWithCaveat tries to consume a caveat `with` expression.
+//
+// NOTE: this always reports true; when the caveat name fails to parse the
+// node is returned undecorated and the error has already been attached by
+// consumeTypePath.
+func (p *sourceParser) tryConsumeWithCaveat() (AstNode, bool) {
+	caveatNode := p.startNode(dslshape.NodeTypeCaveatReference)
+	defer p.mustFinishNode()
+
+	consumed, ok := p.consumeTypePath()
+	if !ok {
+		return caveatNode, true
+	}
+
+	caveatNode.MustDecorate(dslshape.NodeCaveatPredicateCaveat, consumed)
+	return caveatNode, true
+}
+
+// consumeSpecificTypeWithCaveat consumes an identifier as a specific type reference, with optional caveat.
+// ```sometype with somecaveat and expiration```
+//
+// consumeSpecificTypeWithoutFinish leaves its node open on the stack; the
+// deferred mustFinishNode here is what closes it, so the node's end position
+// covers any trailing `with ...` clause.
+func (p *sourceParser) consumeSpecificTypeWithCaveat() AstNode {
+	specificNode := p.consumeSpecificTypeWithoutFinish()
+	defer p.mustFinishNode()
+
+	// Check for a caveat and/or supported trait.
+	if !p.isKeyword("with") {
+		return specificNode
+	}
+
+	p.consumeKeyword("with")
+
+	// `with expiration` alone skips the caveat clause entirely.
+	if !p.isKeyword("expiration") {
+		caveatNode, ok := p.tryConsumeWithCaveat()
+		if ok {
+			specificNode.Connect(dslshape.NodeSpecificReferencePredicateCaveat, caveatNode)
+		}
+
+		// `with somecaveat and expiration` continues below.
+		if !p.tryConsumeKeyword("and") {
+			return specificNode
+		}
+	}
+
+	if p.isKeyword("expiration") {
+		// Check for expiration trait.
+		traitNode := p.consumeExpirationTrait()
+
+		// Decorate with the expiration trait.
+		specificNode.Connect(dslshape.NodeSpecificReferencePredicateTrait, traitNode)
+	}
+
+	return specificNode
+}
+
+// consumeExpirationTrait consumes the `expiration` keyword and produces a
+// trait-reference node decorated with the "expiration" trait name.
+func (p *sourceParser) consumeExpirationTrait() AstNode {
+	traitNode := p.startNode(dslshape.NodeTypeTraitReference)
+	defer p.mustFinishNode()
+
+	p.consumeKeyword("expiration")
+	traitNode.MustDecorate(dslshape.NodeTraitPredicateTrait, "expiration")
+	return traitNode
+}
+
+// consumeSpecificTypeWithoutFinish consumes an identifier as a specific type
+// reference, e.g. `sometype`, `sometype:*` or `sometype#relation`.
+//
+// NOTE: deliberately does NOT call mustFinishNode — the caller
+// (consumeSpecificTypeWithCaveat) finishes the node so that an optional
+// trailing `with ...` clause is included in its source range.
+func (p *sourceParser) consumeSpecificTypeWithoutFinish() AstNode {
+	specificNode := p.startNode(dslshape.NodeTypeSpecificTypeReference)
+
+	typeName, ok := p.consumeTypePath()
+	if !ok {
+		return specificNode
+	}
+
+	specificNode.MustDecorate(dslshape.NodeSpecificReferencePredicateType, typeName)
+
+	// Check for a wildcard
+	// ```sometype:*```
+	if _, ok := p.tryConsume(lexer.TokenTypeColon); ok {
+		_, ok := p.consume(lexer.TokenTypeStar)
+		if !ok {
+			return specificNode
+		}
+
+		specificNode.MustDecorate(dslshape.NodeSpecificReferencePredicateWildcard, "true")
+		return specificNode
+	}
+
+	// Check for a relation specified.
+	// ```sometype#relation```
+	if _, ok := p.tryConsume(lexer.TokenTypeHash); !ok {
+		return specificNode
+	}
+
+	// Consume an identifier or an ellipsis (`...` refers to the subject itself).
+	consumed, ok := p.consume(lexer.TokenTypeIdentifier, lexer.TokenTypeEllipsis)
+	if !ok {
+		return specificNode
+	}
+
+	specificNode.MustDecorate(dslshape.NodeSpecificReferencePredicateRelation, consumed.Value)
+	return specificNode
+}
+
+func (p *sourceParser) consumeTypePath() (string, bool) {
+ var segments []string
+
+ for {
+ segment, ok := p.consumeIdentifier()
+ if !ok {
+ return "", false
+ }
+
+ segments = append(segments, segment)
+
+ _, ok = p.tryConsume(lexer.TokenTypeDiv)
+ if !ok {
+ break
+ }
+ }
+
+ return strings.Join(segments, "/"), true
+}
+
+// consumePermission consumes a permission.
+// ```permission foo = bar + baz```
+//
+// The right-hand side is a full compute expression tree built with proper
+// operator precedence.
+func (p *sourceParser) consumePermission() AstNode {
+	permNode := p.startNode(dslshape.NodeTypePermission)
+	defer p.mustFinishNode()
+
+	// permission ...
+	p.consumeKeyword("permission")
+	permissionName, ok := p.consumeIdentifier()
+	if !ok {
+		return permNode
+	}
+
+	permNode.MustDecorate(dslshape.NodePredicateName, permissionName)
+
+	// =
+	_, ok = p.consume(lexer.TokenTypeEquals)
+	if !ok {
+		return permNode
+	}
+
+	permNode.Connect(dslshape.NodePermissionPredicateComputeExpression, p.consumeComputeExpression())
+	return permNode
+}
+
+// ComputeExpressionOperators defines the binary operators in precedence order.
+// buildBinaryOperatorExpressionFnTree wraps these in reverse, so the first
+// entry (`-`) binds tightest and the last (`+`) binds loosest.
+var ComputeExpressionOperators = []binaryOpDefinition{
+	{lexer.TokenTypeMinus, dslshape.NodeTypeExclusionExpression},
+	{lexer.TokenTypeAnd, dslshape.NodeTypeIntersectExpression},
+	{lexer.TokenTypePlus, dslshape.NodeTypeUnionExpression},
+}
+
+// consumeComputeExpression consumes an expression for computing a permission.
+// On failure, an error node is returned in place of the expression.
+func (p *sourceParser) consumeComputeExpression() AstNode {
+	// Compute expressions consist of a set of binary operators, so build a tree with proper
+	// precedence.
+	binaryParser := p.buildBinaryOperatorExpressionFnTree(ComputeExpressionOperators)
+	found, ok := binaryParser()
+	if !ok {
+		return p.createErrorNodef("Expected compute expression for permission")
+	}
+	return found
+}
+
+// tryConsumeComputeExpression attempts to consume a nested compute expression
+// for a single binary operator level: `subTryExprFn` parses each operand and
+// binaryTokenType is the operator joining them left-associatively.
+func (p *sourceParser) tryConsumeComputeExpression(subTryExprFn tryParserFn, binaryTokenType lexer.TokenType, nodeType dslshape.NodeType) (AstNode, bool) {
+	rightNodeBuilder := func(leftNode AstNode, operatorToken lexer.Lexeme) (AstNode, bool) {
+		rightNode, ok := subTryExprFn()
+		if !ok {
+			return nil, false
+		}
+
+		// Create the expression node representing the binary expression.
+		exprNode := p.createNode(nodeType)
+		exprNode.Connect(dslshape.NodeExpressionPredicateLeftExpr, leftNode)
+		exprNode.Connect(dslshape.NodeExpressionPredicateRightExpr, rightNode)
+		return exprNode, true
+	}
+	return p.performLeftRecursiveParsing(subTryExprFn, rightNodeBuilder, nil, binaryTokenType)
+}
+
+// tryConsumeArrowExpression attempts to consume an arrow expression.
+// ```foo->bar->baz->meh```
+//
+// Two operator forms are supported: `->` (plain walk) and `.` which must be
+// followed by an arrow function call, e.g. `foo.any(bar)` or `foo.all(bar)`.
+func (p *sourceParser) tryConsumeArrowExpression() (AstNode, bool) {
+	rightNodeBuilder := func(leftNode AstNode, operatorToken lexer.Lexeme) (AstNode, bool) {
+		// Check for an arrow function.
+		// ```foo.any(bar)``` / ```foo.all(bar)```
+		if operatorToken.Kind == lexer.TokenTypePeriod {
+			functionName, ok := p.consumeIdentifier()
+			if !ok {
+				return nil, false
+			}
+
+			// TODO(jschorr): Change to keywords in schema v2.
+			if functionName != "any" && functionName != "all" {
+				p.emitErrorf("Expected 'any' or 'all' for arrow function, found: %s", functionName)
+				return nil, false
+			}
+
+			if _, ok := p.consume(lexer.TokenTypeLeftParen); !ok {
+				return nil, false
+			}
+
+			rightNode, ok := p.tryConsumeIdentifierLiteral()
+			if !ok {
+				return nil, false
+			}
+
+			if _, ok := p.consume(lexer.TokenTypeRightParen); !ok {
+				return nil, false
+			}
+
+			// Arrow-function node carries the function name as a decoration.
+			exprNode := p.createNode(dslshape.NodeTypeArrowExpression)
+			exprNode.Connect(dslshape.NodeExpressionPredicateLeftExpr, leftNode)
+			exprNode.Connect(dslshape.NodeExpressionPredicateRightExpr, rightNode)
+			exprNode.MustDecorate(dslshape.NodeArrowExpressionFunctionName, functionName)
+			return exprNode, true
+		}
+
+		// Plain `->` walk: right side is a bare identifier.
+		rightNode, ok := p.tryConsumeIdentifierLiteral()
+		if !ok {
+			return nil, false
+		}
+
+		// Create the expression node representing the binary expression.
+		exprNode := p.createNode(dslshape.NodeTypeArrowExpression)
+		exprNode.Connect(dslshape.NodeExpressionPredicateLeftExpr, leftNode)
+		exprNode.Connect(dslshape.NodeExpressionPredicateRightExpr, rightNode)
+		return exprNode, true
+	}
+	return p.performLeftRecursiveParsing(p.tryConsumeIdentifierLiteral, rightNodeBuilder, nil, lexer.TokenTypeRightArrow, lexer.TokenTypePeriod)
+}
+
+// tryConsumeBaseExpression attempts to consume base compute expressions (identifiers, parenthesis).
+// ```(foo + bar)```
+// ```(foo)```
+// ```foo```
+// ```nil```
+//
+// Returns false when the current token cannot begin a base expression.
+func (p *sourceParser) tryConsumeBaseExpression() (AstNode, bool) {
+	switch {
+	// Nested expression.
+	case p.isToken(lexer.TokenTypeLeftParen):
+		// Capture the open paren's comments before it is consumed, so they
+		// can be re-attached to the inner expression below.
+		comments := p.currentToken.comments
+
+		p.consume(lexer.TokenTypeLeftParen)
+		exprNode := p.consumeComputeExpression()
+		p.consume(lexer.TokenTypeRightParen)
+
+		// Attach any comments found to the consumed expression.
+		p.decorateComments(exprNode, comments)
+
+		return exprNode, true
+
+	// Nil expression.
+	case p.isKeyword("nil"):
+		return p.tryConsumeNilExpression()
+
+	// Identifier.
+	case p.isToken(lexer.TokenTypeIdentifier):
+		return p.tryConsumeIdentifierLiteral()
+	}
+
+	return nil, false
+}
+
+// tryConsumeIdentifierLiteral attempts to consume an identifier as a literal
+// expression.
+//
+// ```foo```
+func (p *sourceParser) tryConsumeIdentifierLiteral() (AstNode, bool) {
+	if !p.isToken(lexer.TokenTypeIdentifier) {
+		return nil, false
+	}
+
+	identNode := p.startNode(dslshape.NodeTypeIdentifier)
+	defer p.mustFinishNode()
+
+	// Consumption cannot fail here: the token kind was checked above.
+	identifier, _ := p.consumeIdentifier()
+	identNode.MustDecorate(dslshape.NodeIdentiferPredicateValue, identifier)
+	return identNode, true
+}
+
+// tryConsumeNilExpression attempts to consume the `nil` keyword as an
+// expression node. Returns false when the current token is not `nil`.
+func (p *sourceParser) tryConsumeNilExpression() (AstNode, bool) {
+	if !p.isKeyword("nil") {
+		return nil, false
+	}
+
+	node := p.startNode(dslshape.NodeTypeNilExpression)
+	p.consumeKeyword("nil")
+	defer p.mustFinishNode()
+	return node, true
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser_impl.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser_impl.go
new file mode 100644
index 0000000..1009b6b
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser_impl.go
@@ -0,0 +1,367 @@
+package parser
+
+import (
+ "fmt"
+
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+ "github.com/authzed/spicedb/pkg/schemadsl/lexer"
+)
+
+// AstNode defines an interface for working with nodes created by this parser.
+type AstNode interface {
+	// Connect connects this AstNode to another AstNode with the given predicate.
+	Connect(predicate string, other AstNode)
+
+	// MustDecorate decorates this AstNode with the given property and string value, returning
+	// the same node.
+	MustDecorate(property string, value string) AstNode
+
+	// MustDecorateWithInt decorates this AstNode with the given property and int value, returning
+	// the same node.
+	MustDecorateWithInt(property string, value int) AstNode
+}
+
+// NodeBuilder is a function for building AST nodes.
+type NodeBuilder func(source input.Source, kind dslshape.NodeType) AstNode
+
+// tryParserFn is a function that attempts to build an AST node, reporting
+// whether it succeeded.
+type tryParserFn func() (AstNode, bool)
+
+// lookaheadParserFn is a function that performs lookahead on the given token,
+// reporting whether left-recursive parsing should continue.
+type lookaheadParserFn func(currentToken lexer.Lexeme) bool
+
+// rightNodeConstructor is a function which takes in a left expr node and the
+// token consumed for a left-recursive operator, and returns a newly constructed
+// operator expression if a right expression could be found.
+type rightNodeConstructor func(AstNode, lexer.Lexeme) (AstNode, bool)
+
+// commentedLexeme is a lexer.Lexeme with comments attached.
+type commentedLexeme struct {
+	lexer.Lexeme
+	comments []string
+}
+
+// sourceParser holds the state of the parser.
+type sourceParser struct {
+	source        input.Source          // the name of the input; used only for error reports
+	input         string                // the input string itself
+	lex           *lexer.FlaggableLexer // a reference to the lexer used for tokenization
+	builder       NodeBuilder           // the builder function for creating AstNode instances
+	nodes         *nodeStack            // the stack of the current nodes
+	currentToken  commentedLexeme       // the current token
+	previousToken commentedLexeme       // the previous token; used to mark node end positions
+}
+
+// buildParser returns a new sourceParser instance wrapping the given lexer.
+// Both token slots start as EOF sentinels until consumeToken is first called.
+func buildParser(lx *lexer.Lexer, builder NodeBuilder, source input.Source, input string) *sourceParser {
+	l := lexer.NewFlaggableLexer(lx)
+	return &sourceParser{
+		source:        source,
+		input:         input,
+		lex:           l,
+		builder:       builder,
+		nodes:         &nodeStack{},
+		currentToken:  commentedLexeme{lexer.Lexeme{Kind: lexer.TokenTypeEOF}, make([]string, 0)},
+		previousToken: commentedLexeme{lexer.Lexeme{Kind: lexer.TokenTypeEOF}, make([]string, 0)},
+	}
+}
+
+// close releases the underlying lexer's resources.
+func (p *sourceParser) close() {
+	p.lex.Close()
+}
+
+// createNode creates a new AstNode and returns it.
+func (p *sourceParser) createNode(kind dslshape.NodeType) AstNode {
+	return p.builder(p.source, kind)
+}
+
+// createErrorNodef creates a new error node, decorated with the formatted
+// message and the current source position, and returns it.
+func (p *sourceParser) createErrorNodef(format string, args ...interface{}) AstNode {
+	message := fmt.Sprintf(format, args...)
+	node := p.startNode(dslshape.NodeTypeError).MustDecorate(dslshape.NodePredicateErrorMessage, message)
+	p.mustFinishNode()
+	return node
+}
+
+// startNode creates a new node of the given type, decorates it with the current token's
+// position as its start position, and pushes it onto the nodes stack.
+// Every startNode must be paired with a mustFinishNode (usually deferred).
+func (p *sourceParser) startNode(kind dslshape.NodeType) AstNode {
+	node := p.createNode(kind)
+	p.decorateStartRuneAndComments(node, p.currentToken)
+	p.nodes.push(node)
+	return node
+}
+
+// decorateStartRuneAndComments decorates the given node with the location of the given token as its
+// starting rune, as well as any comments attached to the token.
+func (p *sourceParser) decorateStartRuneAndComments(node AstNode, token commentedLexeme) {
+	node.MustDecorate(dslshape.NodePredicateSource, string(p.source))
+	node.MustDecorateWithInt(dslshape.NodePredicateStartRune, int(token.Position))
+	p.decorateComments(node, token.comments)
+}
+
+// decorateComments decorates the given node with the specified comments,
+// each attached as a child comment node.
+func (p *sourceParser) decorateComments(node AstNode, comments []string) {
+	for _, comment := range comments {
+		commentNode := p.createNode(dslshape.NodeTypeComment)
+		commentNode.MustDecorate(dslshape.NodeCommentPredicateValue, comment)
+		node.Connect(dslshape.NodePredicateChild, commentNode)
+	}
+}
+
+// decorateEndRune decorates the given node with the location of the given token as its
+// ending rune (the position of the token's final character, inclusive).
+func (p *sourceParser) decorateEndRune(node AstNode, token commentedLexeme) {
+	position := int(token.Position) + len(token.Value) - 1
+	node.MustDecorateWithInt(dslshape.NodePredicateEndRune, position)
+}
+
+// currentNode returns the node at the top of the stack, or nil if the stack is empty.
+func (p *sourceParser) currentNode() AstNode {
+	return p.nodes.topValue()
+}
+
+// mustFinishNode pops the current node from the top of the stack and decorates it with
+// the current token's end position as its end position. Panics if no node is
+// open, since that indicates an unbalanced startNode/mustFinishNode pairing.
+func (p *sourceParser) mustFinishNode() {
+	if p.currentNode() == nil {
+		panic(fmt.Sprintf("No current node on stack. Token: %s", p.currentToken.Value))
+	}
+
+	// The previous token is the last *consumed* token, which marks where the node ends.
+	p.decorateEndRune(p.currentNode(), p.previousToken)
+	p.nodes.pop()
+}
+
+// consumeToken advances the lexer forward, returning the next non-ignored token.
+// Comment tokens encountered along the way are collected and attached to the
+// returned token.
+func (p *sourceParser) consumeToken() commentedLexeme {
+	comments := make([]string, 0)
+
+	for {
+		token := p.lex.NextToken()
+
+		if token.Kind == lexer.TokenTypeSinglelineComment || token.Kind == lexer.TokenTypeMultilineComment {
+			comments = append(comments, token.Value)
+		}
+
+		// Skip whitespace, newlines and comments; everything else becomes the current token.
+		if _, ok := ignoredTokenTypes[token.Kind]; !ok {
+			p.previousToken = p.currentToken
+			p.currentToken = commentedLexeme{token, comments}
+			return p.currentToken
+		}
+	}
+}
+
+// isToken reports whether the current token's kind is any of the given types.
+func (p *sourceParser) isToken(types ...lexer.TokenType) bool {
+	current := p.currentToken.Kind
+	for _, candidate := range types {
+		if current == candidate {
+			return true
+		}
+	}
+	return false
+}
+
+// isIdentifier reports whether the current token is an identifier whose
+// value exactly matches that given.
+func (p *sourceParser) isIdentifier(identifier string) bool {
+	return p.currentToken.Value == identifier && p.isToken(lexer.TokenTypeIdentifier)
+}
+
+// isKeyword reports whether the current token is a keyword whose value
+// exactly matches that given.
+func (p *sourceParser) isKeyword(keyword string) bool {
+	return p.currentToken.Value == keyword && p.isToken(lexer.TokenTypeKeyword)
+}
+
+// emitErrorf creates a new error node and attaches it as a child of the current
+// node, recording the offending token's text as the error source when available.
+func (p *sourceParser) emitErrorf(format string, args ...interface{}) {
+	errorNode := p.createErrorNodef(format, args...)
+	if len(p.currentToken.Value) > 0 {
+		errorNode.MustDecorate(dslshape.NodePredicateErrorSource, p.currentToken.Value)
+	}
+	p.currentNode().Connect(dslshape.NodePredicateChild, errorNode)
+}
+
+// consumeVariableKeyword consumes any keyword token (regardless of its value),
+// returning its text, or adds an error node on failure.
+func (p *sourceParser) consumeVariableKeyword() (string, bool) {
+	if !p.isToken(lexer.TokenTypeKeyword) {
+		p.emitErrorf("Expected keyword, found token %v", p.currentToken.Kind)
+		return "", false
+	}
+
+	token := p.currentToken
+	p.consumeToken()
+	return token.Value, true
+}
+
+// consumeKeyword consumes an expected keyword token or adds an error node.
+func (p *sourceParser) consumeKeyword(keyword string) bool {
+	if !p.tryConsumeKeyword(keyword) {
+		p.emitErrorf("Expected keyword %s, found token %v", keyword, p.currentToken.Kind)
+		return false
+	}
+	return true
+}
+
+// tryConsumeKeyword attempts to consume an expected keyword token, reporting
+// whether it was present (no error is emitted when absent).
+func (p *sourceParser) tryConsumeKeyword(keyword string) bool {
+	if !p.isKeyword(keyword) {
+		return false
+	}
+
+	p.consumeToken()
+	return true
+}
+
+// consumeIdentifier consumes an expected identifier token, returning its
+// text, or adds an error node on failure.
+func (p *sourceParser) consumeIdentifier() (string, bool) {
+	token, ok := p.tryConsume(lexer.TokenTypeIdentifier)
+	if !ok {
+		p.emitErrorf("Expected identifier, found token %v", p.currentToken.Kind)
+		return "", false
+	}
+	return token.Value, true
+}
+
+// consume performs consumption of the next token if it matches any of the given
+// types and returns it. If no matching type is found, adds an error node.
+func (p *sourceParser) consume(types ...lexer.TokenType) (lexer.Lexeme, bool) {
+	token, ok := p.tryConsume(types...)
+	if !ok {
+		p.emitErrorf("Expected one of: %v, found: %v", types, p.currentToken.Kind)
+	}
+	return token, ok
+}
+
+// tryConsume performs consumption of the next token if it matches any of the given
+// types and returns it. No error is emitted when no match is found.
+func (p *sourceParser) tryConsume(types ...lexer.TokenType) (lexer.Lexeme, bool) {
+	token, found := p.tryConsumeWithComments(types...)
+	return token.Lexeme, found
+}
+
+// tryConsumeWithComments performs consumption of the next token if it matches
+// any of the given types, returning it along with its attached comments.
+func (p *sourceParser) tryConsumeWithComments(types ...lexer.TokenType) (commentedLexeme, bool) {
+	if p.isToken(types...) {
+		token := p.currentToken
+		p.consumeToken()
+		return token, true
+	}
+
+	// Sentinel error lexeme; callers must check the boolean, not the token kind.
+	return commentedLexeme{lexer.Lexeme{
+		Kind: lexer.TokenTypeError,
+	}, make([]string, 0)}, false
+}
+
+// performLeftRecursiveParsing performs left-recursive parsing of a set of operators. This method
+// first performs the parsing via the subTryExprFn and then checks for one of the left-recursive
+// operator token types found. If none found, the left expression is returned. Otherwise, the
+// rightNodeBuilder is called to attempt to construct an operator expression. This method also
+// properly handles decoration of the nodes with their proper start and end run locations and
+// comments.
+func (p *sourceParser) performLeftRecursiveParsing(subTryExprFn tryParserFn, rightNodeBuilder rightNodeConstructor, rightTokenTester lookaheadParserFn, operatorTokens ...lexer.TokenType) (AstNode, bool) {
+ leftMostToken := p.currentToken
+
+ // Consume the left side of the expression.
+ leftNode, ok := subTryExprFn()
+ if !ok {
+ return nil, false
+ }
+
+ // Check for an operator token. If none found, then we've found just the left side of the
+ // expression and so we return that node.
+ if !p.isToken(operatorTokens...) {
+ return leftNode, true
+ }
+
+ // Keep consuming pairs of operators and child expressions until such
+ // time as no more can be consumed. We use this loop+custom build rather than recursion
+ // because these operators are *left* recursive, not right.
+ var currentLeftNode AstNode
+ currentLeftNode = leftNode
+
+ for {
+ // Check for an operator.
+ if !p.isToken(operatorTokens...) {
+ break
+ }
+
+ // If a lookahead function is defined, check the lookahead for the matched token.
+ if rightTokenTester != nil && !rightTokenTester(p.currentToken.Lexeme) {
+ break
+ }
+
+ // Consume the operator.
+ operatorToken, ok := p.tryConsumeWithComments(operatorTokens...)
+ if !ok {
+ break
+ }
+
+ // Consume the right hand expression and build an expression node (if applicable).
+ exprNode, ok := rightNodeBuilder(currentLeftNode, operatorToken.Lexeme)
+ if !ok {
+ p.emitErrorf("Expected right hand expression, found: %v", p.currentToken.Kind)
+ return currentLeftNode, true
+ }
+
+ p.decorateStartRuneAndComments(exprNode, leftMostToken)
+ p.decorateEndRune(exprNode, p.previousToken)
+
+ currentLeftNode = exprNode
+ }
+
+ return currentLeftNode, true
+}
+
+// tryConsumeStatementTerminator tries to consume a statement terminator:
+// an explicit semicolon, a lexer-synthesized semicolon (at a newline), or EOF.
+func (p *sourceParser) tryConsumeStatementTerminator() (lexer.Lexeme, bool) {
+	return p.tryConsume(lexer.TokenTypeSyntheticSemicolon, lexer.TokenTypeSemicolon, lexer.TokenTypeEOF)
+}
+
+// consumeStatementTerminator consumes a statement terminator, adding an
+// error node when none is present.
+func (p *sourceParser) consumeStatementTerminator() bool {
+	_, ok := p.tryConsumeStatementTerminator()
+	if ok {
+		return true
+	}
+
+	p.emitErrorf("Expected end of statement or definition, found: %s", p.currentToken.Kind)
+	return false
+}
+
+// binaryOpDefinition represents information a binary operator token and its associated node type.
+type binaryOpDefinition struct {
+	// The token representing the binary expression's operator.
+	BinaryOperatorToken lexer.TokenType
+
+	// The type of node to create for this expression.
+	BinaryExpressionNodeType dslshape.NodeType
+}
+
+// buildBinaryOperatorExpressionFnTree builds a tree of functions to try to consume a set of binary
+// operator expressions. ops is given tightest-binding first; the wrapping below
+// reverses it so the loosest-binding operator ends up outermost.
+func (p *sourceParser) buildBinaryOperatorExpressionFnTree(ops []binaryOpDefinition) tryParserFn {
+	// Start with a base expression function: an arrow expression if one can be
+	// consumed, otherwise a base (identifier/parenthesized/nil) expression.
+	var currentParseFn tryParserFn
+	currentParseFn = func() (AstNode, bool) {
+		arrowExpr, ok := p.tryConsumeArrowExpression()
+		if !ok {
+			return p.tryConsumeBaseExpression()
+		}
+
+		return arrowExpr, true
+	}
+
+	for i := range ops {
+		// Note: We have to reverse this to ensure we have proper precedence.
+		// The immediately-invoked closure captures per-iteration copies of the
+		// operator definition and the inner parse function.
+		currentParseFn = func(operatorInfo binaryOpDefinition, currentFn tryParserFn) tryParserFn {
+			return (func() (AstNode, bool) {
+				return p.tryConsumeComputeExpression(currentFn, operatorInfo.BinaryOperatorToken, operatorInfo.BinaryExpressionNodeType)
+			})
+		}(ops[len(ops)-i-1], currentParseFn)
+	}
+
+	return currentParseFn
+}