summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler')
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go194
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go142
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go2
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go53
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go180
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go32
-rw-r--r--vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go714
7 files changed, 1317 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go
new file mode 100644
index 0000000..d1e96ec
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go
@@ -0,0 +1,194 @@
+package compiler
+
+import (
+ "errors"
+ "fmt"
+
+ "google.golang.org/protobuf/proto"
+ "k8s.io/utils/strings/slices"
+
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+ "github.com/authzed/spicedb/pkg/schemadsl/parser"
+)
+
+// InputSchema defines the input for a Compile.
+type InputSchema struct {
+ // Source is the source of the schema being compiled.
+ Source input.Source
+
+ // Schema is the contents being compiled.
+ SchemaString string
+}
+
+// SchemaDefinition represents an object or caveat definition in a schema.
+type SchemaDefinition interface {
+ proto.Message
+
+ GetName() string
+}
+
+// CompiledSchema is the result of compiling a schema when there are no errors.
+type CompiledSchema struct {
+ // ObjectDefinitions holds the object definitions in the schema.
+ ObjectDefinitions []*core.NamespaceDefinition
+
+ // CaveatDefinitions holds the caveat definitions in the schema.
+ CaveatDefinitions []*core.CaveatDefinition
+
+ // OrderedDefinitions holds the object and caveat definitions in the schema, in the
+ // order in which they were found.
+ OrderedDefinitions []SchemaDefinition
+
+ rootNode *dslNode
+ mapper input.PositionMapper
+}
+
+// SourcePositionToRunePosition converts a source position to a rune position.
+func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, position input.Position) (int, error) {
+ return cs.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source)
+}
+
+type config struct {
+ skipValidation bool
+ objectTypePrefix *string
+ allowedFlags []string
+ caveatTypeSet *caveattypes.TypeSet
+}
+
+func SkipValidation() Option { return func(cfg *config) { cfg.skipValidation = true } }
+
+func ObjectTypePrefix(prefix string) ObjectPrefixOption {
+ return func(cfg *config) { cfg.objectTypePrefix = &prefix }
+}
+
+func RequirePrefixedObjectType() ObjectPrefixOption {
+ return func(cfg *config) { cfg.objectTypePrefix = nil }
+}
+
+func AllowUnprefixedObjectType() ObjectPrefixOption {
+ return func(cfg *config) { cfg.objectTypePrefix = new(string) }
+}
+
+func CaveatTypeSet(cts *caveattypes.TypeSet) Option {
+ return func(cfg *config) { cfg.caveatTypeSet = cts }
+}
+
+const expirationFlag = "expiration"
+
+func DisallowExpirationFlag() Option {
+ return func(cfg *config) {
+ cfg.allowedFlags = slices.Filter([]string{}, cfg.allowedFlags, func(s string) bool {
+ return s != expirationFlag
+ })
+ }
+}
+
+type Option func(*config)
+
+type ObjectPrefixOption func(*config)
+
+// Compile compilers the input schema into a set of namespace definition protos.
+func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) {
+ cfg := &config{
+ allowedFlags: make([]string, 0, 1),
+ }
+
+ // Enable `expiration` flag by default.
+ cfg.allowedFlags = append(cfg.allowedFlags, expirationFlag)
+
+ prefix(cfg) // required option
+
+ for _, fn := range opts {
+ fn(cfg)
+ }
+
+ mapper := newPositionMapper(schema)
+ root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode)
+ errs := root.FindAll(dslshape.NodeTypeError)
+ if len(errs) > 0 {
+ err := errorNodeToError(errs[0], mapper)
+ return nil, err
+ }
+
+ cts := caveattypes.TypeSetOrDefault(cfg.caveatTypeSet)
+ compiled, err := translate(&translationContext{
+ objectTypePrefix: cfg.objectTypePrefix,
+ mapper: mapper,
+ schemaString: schema.SchemaString,
+ skipValidate: cfg.skipValidation,
+ allowedFlags: cfg.allowedFlags,
+ caveatTypeSet: cts,
+ }, root)
+ if err != nil {
+ var withNodeError withNodeError
+ if errors.As(err, &withNodeError) {
+ err = toContextError(withNodeError.error.Error(), withNodeError.errorSourceCode, withNodeError.node, mapper)
+ }
+
+ return nil, err
+ }
+
+ return compiled, nil
+}
+
+func errorNodeToError(node *dslNode, mapper input.PositionMapper) error {
+ if node.GetType() != dslshape.NodeTypeError {
+ return fmt.Errorf("given none error node")
+ }
+
+ errMessage, err := node.GetString(dslshape.NodePredicateErrorMessage)
+ if err != nil {
+ return fmt.Errorf("could not get error message for error node: %w", err)
+ }
+
+ errorSourceCode := ""
+ if node.Has(dslshape.NodePredicateErrorSource) {
+ es, err := node.GetString(dslshape.NodePredicateErrorSource)
+ if err != nil {
+ return fmt.Errorf("could not get error source for error node: %w", err)
+ }
+
+ errorSourceCode = es
+ }
+
+ return toContextError(errMessage, errorSourceCode, node, mapper)
+}
+
+func toContextError(errMessage string, errorSourceCode string, node *dslNode, mapper input.PositionMapper) error {
+ sourceRange, err := node.Range(mapper)
+ if err != nil {
+ return fmt.Errorf("could not get range for error node: %w", err)
+ }
+
+ formattedRange, err := formatRange(sourceRange)
+ if err != nil {
+ return err
+ }
+
+ source, err := node.GetString(dslshape.NodePredicateSource)
+ if err != nil {
+ return fmt.Errorf("missing source for node: %w", err)
+ }
+
+ return WithContextError{
+ BaseCompilerError: BaseCompilerError{
+ error: fmt.Errorf("parse error in %s: %s", formattedRange, errMessage),
+ BaseMessage: errMessage,
+ },
+ SourceRange: sourceRange,
+ Source: input.Source(source),
+ ErrorSourceCode: errorSourceCode,
+ }
+}
+
+func formatRange(rnge input.SourceRange) (string, error) {
+ startLine, startCol, err := rnge.Start().LineAndColumn()
+ if err != nil {
+ return "", err
+ }
+
+ return fmt.Sprintf("`%s`, line %v, column %v", rnge.Source(), startLine+1, startCol+1), nil
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go
new file mode 100644
index 0000000..7c8e7c7
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go
@@ -0,0 +1,142 @@
+package compiler
+
+import (
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
+// DSLNode is a node in the DSL AST.
+type DSLNode interface {
+ GetType() dslshape.NodeType
+ GetString(predicateName string) (string, error)
+ GetInt(predicateName string) (int, error)
+ Lookup(predicateName string) (DSLNode, error)
+}
+
+// NodeChain is a chain of nodes in the DSL AST.
+type NodeChain struct {
+ nodes []DSLNode
+ runePosition int
+}
+
+// Head returns the head node of the chain.
+func (nc *NodeChain) Head() DSLNode {
+ return nc.nodes[0]
+}
+
+// HasHeadType returns true if the head node of the chain is of the given type.
+func (nc *NodeChain) HasHeadType(nodeType dslshape.NodeType) bool {
+ return nc.nodes[0].GetType() == nodeType
+}
+
+// ForRunePosition returns the rune position of the chain.
+func (nc *NodeChain) ForRunePosition() int {
+ return nc.runePosition
+}
+
+// FindNodeOfType returns the first node of the given type in the chain, if any.
+func (nc *NodeChain) FindNodeOfType(nodeType dslshape.NodeType) DSLNode {
+ for _, node := range nc.nodes {
+ if node.GetType() == nodeType {
+ return node
+ }
+ }
+
+ return nil
+}
+
+func (nc *NodeChain) String() string {
+ var out string
+ for _, node := range nc.nodes {
+ out += node.GetType().String() + " "
+ }
+ return out
+}
+
+// PositionToAstNodeChain returns the AST node, and its parents (if any), found at the given position in the source, if any.
+func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, position input.Position) (*NodeChain, error) {
+ rootSource, err := schema.rootNode.GetString(dslshape.NodePredicateSource)
+ if err != nil {
+ return nil, err
+ }
+
+ if rootSource != string(source) {
+ return nil, nil
+ }
+
+ // Map the position to a file rune.
+ runePosition, err := schema.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source)
+ if err != nil {
+ return nil, err
+ }
+
+ // Find the node at the rune position.
+ found, err := runePositionToAstNodeChain(schema.rootNode, runePosition)
+ if err != nil {
+ return nil, err
+ }
+
+ if found == nil {
+ return nil, nil
+ }
+
+ return &NodeChain{nodes: found, runePosition: runePosition}, nil
+}
+
+func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, error) {
+ if !node.Has(dslshape.NodePredicateStartRune) {
+ return nil, nil
+ }
+
+ startRune, err := node.GetInt(dslshape.NodePredicateStartRune)
+ if err != nil {
+ return nil, err
+ }
+
+ endRune, err := node.GetInt(dslshape.NodePredicateEndRune)
+ if err != nil {
+ return nil, err
+ }
+
+ if runePosition < startRune || runePosition > endRune {
+ return nil, nil
+ }
+
+ for _, child := range node.AllSubNodes() {
+ childChain, err := runePositionToAstNodeChain(child, runePosition)
+ if err != nil {
+ return nil, err
+ }
+
+ if childChain != nil {
+ return append(childChain, wrapper{node}), nil
+ }
+ }
+
+ return []DSLNode{wrapper{node}}, nil
+}
+
+type wrapper struct {
+ node *dslNode
+}
+
+func (w wrapper) GetType() dslshape.NodeType {
+ return w.node.GetType()
+}
+
+func (w wrapper) GetString(predicateName string) (string, error) {
+ return w.node.GetString(predicateName)
+}
+
+func (w wrapper) GetInt(predicateName string) (int, error) {
+ return w.node.GetInt(predicateName)
+}
+
+func (w wrapper) Lookup(predicateName string) (DSLNode, error) {
+ found, err := w.node.Lookup(predicateName)
+ if err != nil {
+ return nil, err
+ }
+
+ return wrapper{found}, nil
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go
new file mode 100644
index 0000000..fdc1735
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go
@@ -0,0 +1,2 @@
+// Package compiler knows how to build the Go representation of a SpiceDB schema text.
+package compiler
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go
new file mode 100644
index 0000000..2c33ba8
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go
@@ -0,0 +1,53 @@
+package compiler
+
+import (
+ "strconv"
+
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
+// BaseCompilerError defines an error with contains the base message of the issue
+// that occurred.
+type BaseCompilerError struct {
+ error
+ BaseMessage string
+}
+
+type withNodeError struct {
+ error
+ node *dslNode
+ errorSourceCode string
+}
+
+// WithContextError defines an error which contains contextual information.
+type WithContextError struct {
+ BaseCompilerError
+ SourceRange input.SourceRange
+ Source input.Source
+ ErrorSourceCode string
+}
+
+func (ewc WithContextError) Unwrap() error {
+ return ewc.BaseCompilerError
+}
+
+// DetailsMetadata returns the metadata for details for this error.
+func (ewc WithContextError) DetailsMetadata() map[string]string {
+ startLine, startCol, err := ewc.SourceRange.Start().LineAndColumn()
+ if err != nil {
+ return map[string]string{}
+ }
+
+ endLine, endCol, err := ewc.SourceRange.End().LineAndColumn()
+ if err != nil {
+ return map[string]string{}
+ }
+
+ return map[string]string{
+ "start_line_number": strconv.Itoa(startLine),
+ "start_column_position": strconv.Itoa(startCol),
+ "end_line_number": strconv.Itoa(endLine),
+ "end_column_position": strconv.Itoa(endCol),
+ "source_code": ewc.ErrorSourceCode,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go
new file mode 100644
index 0000000..b7e2a70
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go
@@ -0,0 +1,180 @@
+package compiler
+
+import (
+ "container/list"
+ "fmt"
+
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+ "github.com/authzed/spicedb/pkg/schemadsl/parser"
+)
+
+type dslNode struct {
+ nodeType dslshape.NodeType
+ properties map[string]interface{}
+ children map[string]*list.List
+}
+
+func createAstNode(_ input.Source, kind dslshape.NodeType) parser.AstNode {
+ return &dslNode{
+ nodeType: kind,
+ properties: make(map[string]interface{}),
+ children: make(map[string]*list.List),
+ }
+}
+
+func (tn *dslNode) GetType() dslshape.NodeType {
+ return tn.nodeType
+}
+
+func (tn *dslNode) Connect(predicate string, other parser.AstNode) {
+ if tn.children[predicate] == nil {
+ tn.children[predicate] = list.New()
+ }
+
+ tn.children[predicate].PushBack(other)
+}
+
+func (tn *dslNode) MustDecorate(property string, value string) parser.AstNode {
+ if _, ok := tn.properties[property]; ok {
+ panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
+ }
+
+ tn.properties[property] = value
+ return tn
+}
+
+func (tn *dslNode) MustDecorateWithInt(property string, value int) parser.AstNode {
+ if _, ok := tn.properties[property]; ok {
+ panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties))
+ }
+
+ tn.properties[property] = value
+ return tn
+}
+
+func (tn *dslNode) Range(mapper input.PositionMapper) (input.SourceRange, error) {
+ sourceStr, err := tn.GetString(dslshape.NodePredicateSource)
+ if err != nil {
+ return nil, err
+ }
+
+ source := input.Source(sourceStr)
+
+ startRune, err := tn.GetInt(dslshape.NodePredicateStartRune)
+ if err != nil {
+ return nil, err
+ }
+
+ endRune, err := tn.GetInt(dslshape.NodePredicateEndRune)
+ if err != nil {
+ return nil, err
+ }
+
+ return source.RangeForRunePositions(startRune, endRune, mapper), nil
+}
+
+func (tn *dslNode) Has(predicateName string) bool {
+ _, ok := tn.properties[predicateName]
+ return ok
+}
+
+func (tn *dslNode) GetInt(predicateName string) (int, error) {
+ predicate, ok := tn.properties[predicateName]
+ if !ok {
+ return 0, fmt.Errorf("unknown predicate %s", predicateName)
+ }
+
+ value, ok := predicate.(int)
+ if !ok {
+ return 0, fmt.Errorf("predicate %s is not an int", predicateName)
+ }
+
+ return value, nil
+}
+
+func (tn *dslNode) GetString(predicateName string) (string, error) {
+ predicate, ok := tn.properties[predicateName]
+ if !ok {
+ return "", fmt.Errorf("unknown predicate %s", predicateName)
+ }
+
+ value, ok := predicate.(string)
+ if !ok {
+ return "", fmt.Errorf("predicate %s is not a string", predicateName)
+ }
+
+ return value, nil
+}
+
+func (tn *dslNode) AllSubNodes() []*dslNode {
+ nodes := []*dslNode{}
+ for _, childList := range tn.children {
+ for e := childList.Front(); e != nil; e = e.Next() {
+ nodes = append(nodes, e.Value.(*dslNode))
+ }
+ }
+ return nodes
+}
+
+func (tn *dslNode) GetChildren() []*dslNode {
+ return tn.List(dslshape.NodePredicateChild)
+}
+
+func (tn *dslNode) FindAll(nodeType dslshape.NodeType) []*dslNode {
+ found := []*dslNode{}
+ if tn.nodeType == dslshape.NodeTypeError {
+ found = append(found, tn)
+ }
+
+ for _, childList := range tn.children {
+ for e := childList.Front(); e != nil; e = e.Next() {
+ childFound := e.Value.(*dslNode).FindAll(nodeType)
+ found = append(found, childFound...)
+ }
+ }
+ return found
+}
+
+func (tn *dslNode) List(predicateName string) []*dslNode {
+ children := []*dslNode{}
+ childList, ok := tn.children[predicateName]
+ if !ok {
+ return children
+ }
+
+ for e := childList.Front(); e != nil; e = e.Next() {
+ children = append(children, e.Value.(*dslNode))
+ }
+
+ return children
+}
+
+func (tn *dslNode) Lookup(predicateName string) (*dslNode, error) {
+ childList, ok := tn.children[predicateName]
+ if !ok {
+ return nil, fmt.Errorf("unknown predicate %s", predicateName)
+ }
+
+ for e := childList.Front(); e != nil; e = e.Next() {
+ return e.Value.(*dslNode), nil
+ }
+
+ return nil, fmt.Errorf("nothing in predicate %s", predicateName)
+}
+
+func (tn *dslNode) Errorf(message string, args ...interface{}) error {
+ return withNodeError{
+ error: fmt.Errorf(message, args...),
+ errorSourceCode: "",
+ node: tn,
+ }
+}
+
+func (tn *dslNode) WithSourceErrorf(sourceCode string, message string, args ...interface{}) error {
+ return withNodeError{
+ error: fmt.Errorf(message, args...),
+ errorSourceCode: sourceCode,
+ node: tn,
+ }
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go
new file mode 100644
index 0000000..aa33c43
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go
@@ -0,0 +1,32 @@
+package compiler
+
+import (
+ "strings"
+
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
+type positionMapper struct {
+ schema InputSchema
+ mapper input.SourcePositionMapper
+}
+
+func newPositionMapper(schema InputSchema) input.PositionMapper {
+ return &positionMapper{
+ schema: schema,
+ mapper: input.CreateSourcePositionMapper([]byte(schema.SchemaString)),
+ }
+}
+
+func (pm *positionMapper) RunePositionToLineAndCol(runePosition int, _ input.Source) (int, int, error) {
+ return pm.mapper.RunePositionToLineAndCol(runePosition)
+}
+
+func (pm *positionMapper) LineAndColToRunePosition(lineNumber int, colPosition int, _ input.Source) (int, error) {
+ return pm.mapper.LineAndColToRunePosition(lineNumber, colPosition)
+}
+
+func (pm *positionMapper) TextForLine(lineNumber int, _ input.Source) (string, error) {
+ lines := strings.Split(pm.schema.SchemaString, "\n")
+ return lines[lineNumber], nil
+}
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go
new file mode 100644
index 0000000..77877b0
--- /dev/null
+++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go
@@ -0,0 +1,714 @@
+package compiler
+
+import (
+ "bufio"
+ "fmt"
+ "slices"
+ "strings"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/jzelinskie/stringz"
+
+ "github.com/authzed/spicedb/pkg/caveats"
+ caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/genutil/mapz"
+ "github.com/authzed/spicedb/pkg/namespace"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/dslshape"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+)
+
+type translationContext struct {
+ objectTypePrefix *string
+ mapper input.PositionMapper
+ schemaString string
+ skipValidate bool
+ allowedFlags []string
+ enabledFlags []string
+ caveatTypeSet *caveattypes.TypeSet
+}
+
+func (tctx *translationContext) prefixedPath(definitionName string) (string, error) {
+ var prefix, name string
+ if err := stringz.SplitInto(definitionName, "/", &prefix, &name); err != nil {
+ if tctx.objectTypePrefix == nil {
+ return "", fmt.Errorf("found reference `%s` without prefix", definitionName)
+ }
+ prefix = *tctx.objectTypePrefix
+ name = definitionName
+ }
+
+ if prefix == "" {
+ return name, nil
+ }
+
+ return stringz.Join("/", prefix, name), nil
+}
+
+const Ellipsis = "..."
+
+func translate(tctx *translationContext, root *dslNode) (*CompiledSchema, error) {
+ orderedDefinitions := make([]SchemaDefinition, 0, len(root.GetChildren()))
+ var objectDefinitions []*core.NamespaceDefinition
+ var caveatDefinitions []*core.CaveatDefinition
+
+ names := mapz.NewSet[string]()
+
+ for _, definitionNode := range root.GetChildren() {
+ var definition SchemaDefinition
+
+ switch definitionNode.GetType() {
+ case dslshape.NodeTypeUseFlag:
+ err := translateUseFlag(tctx, definitionNode)
+ if err != nil {
+ return nil, err
+ }
+ continue
+
+ case dslshape.NodeTypeCaveatDefinition:
+ def, err := translateCaveatDefinition(tctx, definitionNode)
+ if err != nil {
+ return nil, err
+ }
+
+ definition = def
+ caveatDefinitions = append(caveatDefinitions, def)
+
+ case dslshape.NodeTypeDefinition:
+ def, err := translateObjectDefinition(tctx, definitionNode)
+ if err != nil {
+ return nil, err
+ }
+
+ definition = def
+ objectDefinitions = append(objectDefinitions, def)
+ }
+
+ if !names.Add(definition.GetName()) {
+ return nil, definitionNode.WithSourceErrorf(definition.GetName(), "found name reused between multiple definitions and/or caveats: %s", definition.GetName())
+ }
+
+ orderedDefinitions = append(orderedDefinitions, definition)
+ }
+
+ return &CompiledSchema{
+ CaveatDefinitions: caveatDefinitions,
+ ObjectDefinitions: objectDefinitions,
+ OrderedDefinitions: orderedDefinitions,
+ rootNode: root,
+ mapper: tctx.mapper,
+ }, nil
+}
+
+func translateCaveatDefinition(tctx *translationContext, defNode *dslNode) (*core.CaveatDefinition, error) {
+ definitionName, err := defNode.GetString(dslshape.NodeCaveatDefinitionPredicateName)
+ if err != nil {
+ return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err)
+ }
+
+ // parameters
+ paramNodes := defNode.List(dslshape.NodeCaveatDefinitionPredicateParameters)
+ if len(paramNodes) == 0 {
+ return nil, defNode.WithSourceErrorf(definitionName, "caveat `%s` must have at least one parameter defined", definitionName)
+ }
+
+ env := caveats.NewEnvironment()
+ parameters := make(map[string]caveattypes.VariableType, len(paramNodes))
+ for _, paramNode := range paramNodes {
+ paramName, err := paramNode.GetString(dslshape.NodeCaveatParameterPredicateName)
+ if err != nil {
+ return nil, paramNode.WithSourceErrorf(paramName, "invalid parameter name: %w", err)
+ }
+
+ if _, ok := parameters[paramName]; ok {
+ return nil, paramNode.WithSourceErrorf(paramName, "duplicate parameter `%s` defined on caveat `%s`", paramName, definitionName)
+ }
+
+ typeRefNode, err := paramNode.Lookup(dslshape.NodeCaveatParameterPredicateType)
+ if err != nil {
+ return nil, paramNode.WithSourceErrorf(paramName, "invalid type for parameter: %w", err)
+ }
+
+ translatedType, err := translateCaveatTypeReference(tctx, typeRefNode)
+ if err != nil {
+ return nil, paramNode.WithSourceErrorf(paramName, "invalid type for caveat parameter `%s` on caveat `%s`: %w", paramName, definitionName, err)
+ }
+
+ parameters[paramName] = *translatedType
+ err = env.AddVariable(paramName, *translatedType)
+ if err != nil {
+ return nil, paramNode.WithSourceErrorf(paramName, "invalid type for caveat parameter `%s` on caveat `%s`: %w", paramName, definitionName, err)
+ }
+ }
+
+ caveatPath, err := tctx.prefixedPath(definitionName)
+ if err != nil {
+ return nil, defNode.Errorf("%w", err)
+ }
+
+ // caveat expression.
+ expressionStringNode, err := defNode.Lookup(dslshape.NodeCaveatDefinitionPredicateExpession)
+ if err != nil {
+ return nil, defNode.WithSourceErrorf(definitionName, "invalid expression: %w", err)
+ }
+
+ expressionString, err := expressionStringNode.GetString(dslshape.NodeCaveatExpressionPredicateExpression)
+ if err != nil {
+ return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err)
+ }
+
+ rnge, err := expressionStringNode.Range(tctx.mapper)
+ if err != nil {
+ return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err)
+ }
+
+ source, err := caveats.NewSource(expressionString, caveatPath)
+ if err != nil {
+ return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err)
+ }
+
+ compiled, err := caveats.CompileCaveatWithSource(env, caveatPath, source, rnge.Start())
+ if err != nil {
+ return nil, expressionStringNode.WithSourceErrorf(expressionString, "invalid expression for caveat `%s`: %w", definitionName, err)
+ }
+
+ def, err := namespace.CompiledCaveatDefinition(env, caveatPath, compiled)
+ if err != nil {
+ return nil, err
+ }
+
+ def.Metadata = addComments(def.Metadata, defNode)
+ def.SourcePosition = getSourcePosition(defNode, tctx.mapper)
+ return def, nil
+}
+
+func translateCaveatTypeReference(tctx *translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) {
+ typeName, err := typeRefNode.GetString(dslshape.NodeCaveatTypeReferencePredicateType)
+ if err != nil {
+ return nil, typeRefNode.WithSourceErrorf(typeName, "invalid type name: %w", err)
+ }
+
+ childTypeNodes := typeRefNode.List(dslshape.NodeCaveatTypeReferencePredicateChildTypes)
+ childTypes := make([]caveattypes.VariableType, 0, len(childTypeNodes))
+ for _, childTypeNode := range childTypeNodes {
+ translated, err := translateCaveatTypeReference(tctx, childTypeNode)
+ if err != nil {
+ return nil, err
+ }
+ childTypes = append(childTypes, *translated)
+ }
+
+ constructedType, err := tctx.caveatTypeSet.BuildType(typeName, childTypes)
+ if err != nil {
+ return nil, typeRefNode.WithSourceErrorf(typeName, "%w", err)
+ }
+
+ return constructedType, nil
+}
+
+func translateObjectDefinition(tctx *translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) {
+ definitionName, err := defNode.GetString(dslshape.NodeDefinitionPredicateName)
+ if err != nil {
+ return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err)
+ }
+
+ relationsAndPermissions := []*core.Relation{}
+ for _, relationOrPermissionNode := range defNode.GetChildren() {
+ if relationOrPermissionNode.GetType() == dslshape.NodeTypeComment {
+ continue
+ }
+
+ relationOrPermission, err := translateRelationOrPermission(tctx, relationOrPermissionNode)
+ if err != nil {
+ return nil, err
+ }
+
+ relationsAndPermissions = append(relationsAndPermissions, relationOrPermission)
+ }
+
+ nspath, err := tctx.prefixedPath(definitionName)
+ if err != nil {
+ return nil, defNode.Errorf("%w", err)
+ }
+
+ if len(relationsAndPermissions) == 0 {
+ ns := namespace.Namespace(nspath)
+ ns.Metadata = addComments(ns.Metadata, defNode)
+ ns.SourcePosition = getSourcePosition(defNode, tctx.mapper)
+
+ if !tctx.skipValidate {
+ if err = ns.Validate(); err != nil {
+ return nil, defNode.Errorf("error in object definition %s: %w", nspath, err)
+ }
+ }
+
+ return ns, nil
+ }
+
+ ns := namespace.Namespace(nspath, relationsAndPermissions...)
+ ns.Metadata = addComments(ns.Metadata, defNode)
+ ns.SourcePosition = getSourcePosition(defNode, tctx.mapper)
+
+ if !tctx.skipValidate {
+ if err := ns.Validate(); err != nil {
+ return nil, defNode.Errorf("error in object definition %s: %w", nspath, err)
+ }
+ }
+
+ return ns, nil
+}
+
+func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.SourcePosition {
+ if !dslNode.Has(dslshape.NodePredicateStartRune) {
+ return nil
+ }
+
+ sourceRange, err := dslNode.Range(mapper)
+ if err != nil {
+ return nil
+ }
+
+ line, col, err := sourceRange.Start().LineAndColumn()
+ if err != nil {
+ return nil
+ }
+
+ // We're okay with these being zero if the cast fails.
+ uintLine, _ := safecast.ToUint64(line)
+ uintCol, _ := safecast.ToUint64(col)
+
+ return &core.SourcePosition{
+ ZeroIndexedLineNumber: uintLine,
+ ZeroIndexedColumnPosition: uintCol,
+ }
+}
+
+func addComments(mdmsg *core.Metadata, dslNode *dslNode) *core.Metadata {
+ for _, child := range dslNode.GetChildren() {
+ if child.GetType() == dslshape.NodeTypeComment {
+ value, err := child.GetString(dslshape.NodeCommentPredicateValue)
+ if err == nil {
+ mdmsg, _ = namespace.AddComment(mdmsg, normalizeComment(value))
+ }
+ }
+ }
+ return mdmsg
+}
+
+func normalizeComment(value string) string {
+ var lines []string
+ scanner := bufio.NewScanner(strings.NewReader(value))
+ for scanner.Scan() {
+ trimmed := strings.TrimSpace(scanner.Text())
+ lines = append(lines, trimmed)
+ }
+ return strings.Join(lines, "\n")
+}
+
+func translateRelationOrPermission(tctx *translationContext, relOrPermNode *dslNode) (*core.Relation, error) {
+ switch relOrPermNode.GetType() {
+ case dslshape.NodeTypeRelation:
+ rel, err := translateRelation(tctx, relOrPermNode)
+ if err != nil {
+ return nil, err
+ }
+ rel.Metadata = addComments(rel.Metadata, relOrPermNode)
+ rel.SourcePosition = getSourcePosition(relOrPermNode, tctx.mapper)
+ return rel, err
+
+ case dslshape.NodeTypePermission:
+ rel, err := translatePermission(tctx, relOrPermNode)
+ if err != nil {
+ return nil, err
+ }
+ rel.Metadata = addComments(rel.Metadata, relOrPermNode)
+ rel.SourcePosition = getSourcePosition(relOrPermNode, tctx.mapper)
+ return rel, err
+
+ default:
+ return nil, relOrPermNode.Errorf("unknown definition top-level node type %s", relOrPermNode.GetType())
+ }
+}
+
+func translateRelation(tctx *translationContext, relationNode *dslNode) (*core.Relation, error) {
+ relationName, err := relationNode.GetString(dslshape.NodePredicateName)
+ if err != nil {
+ return nil, relationNode.Errorf("invalid relation name: %w", err)
+ }
+
+ allowedDirectTypes := []*core.AllowedRelation{}
+ for _, typeRef := range relationNode.List(dslshape.NodeRelationPredicateAllowedTypes) {
+ allowedRelations, err := translateAllowedRelations(tctx, typeRef)
+ if err != nil {
+ return nil, err
+ }
+
+ allowedDirectTypes = append(allowedDirectTypes, allowedRelations...)
+ }
+
+ relation, err := namespace.Relation(relationName, nil, allowedDirectTypes...)
+ if err != nil {
+ return nil, err
+ }
+
+ if !tctx.skipValidate {
+ if err := relation.Validate(); err != nil {
+ return nil, relationNode.Errorf("error in relation %s: %w", relationName, err)
+ }
+ }
+
+ return relation, nil
+}
+
+func translatePermission(tctx *translationContext, permissionNode *dslNode) (*core.Relation, error) {
+ permissionName, err := permissionNode.GetString(dslshape.NodePredicateName)
+ if err != nil {
+ return nil, permissionNode.Errorf("invalid permission name: %w", err)
+ }
+
+ expressionNode, err := permissionNode.Lookup(dslshape.NodePermissionPredicateComputeExpression)
+ if err != nil {
+ return nil, permissionNode.Errorf("invalid permission expression: %w", err)
+ }
+
+ rewrite, err := translateExpression(tctx, expressionNode)
+ if err != nil {
+ return nil, err
+ }
+
+ permission, err := namespace.Relation(permissionName, rewrite)
+ if err != nil {
+ return nil, err
+ }
+
+ if !tctx.skipValidate {
+ if err := permission.Validate(); err != nil {
+ return nil, permissionNode.Errorf("error in permission %s: %w", permissionName, err)
+ }
+ }
+
+ return permission, nil
+}
+
+func translateBinary(tctx *translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) {
+ leftChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ rightChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateRightExpr)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ leftOperation, err := translateExpressionOperation(tctx, leftChild)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ rightOperation, err := translateExpressionOperation(tctx, rightChild)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return leftOperation, rightOperation, nil
+}
+
+func translateExpression(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) {
+ translated, err := translateExpressionDirect(tctx, expressionNode)
+ if err != nil {
+ return translated, err
+ }
+
+ translated.SourcePosition = getSourcePosition(expressionNode, tctx.mapper)
+ return translated, nil
+}
+
+func collapseOps(op *core.SetOperation_Child, handler func(rewrite *core.UsersetRewrite) *core.SetOperation) []*core.SetOperation_Child {
+ if op.GetUsersetRewrite() == nil {
+ return []*core.SetOperation_Child{op}
+ }
+
+ usersetRewrite := op.GetUsersetRewrite()
+ operation := handler(usersetRewrite)
+ if operation == nil {
+ return []*core.SetOperation_Child{op}
+ }
+
+ collapsed := make([]*core.SetOperation_Child, 0, len(operation.Child))
+ for _, child := range operation.Child {
+ collapsed = append(collapsed, collapseOps(child, handler)...)
+ }
+ return collapsed
+}
+
+func translateExpressionDirect(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) {
+ // For union and intersection, we collapse a tree of binary operations into a flat list containing child
+ // operations of the *same* type.
+ translate := func(
+ builder func(firstChild *core.SetOperation_Child, rest ...*core.SetOperation_Child) *core.UsersetRewrite,
+ lookup func(rewrite *core.UsersetRewrite) *core.SetOperation,
+ ) (*core.UsersetRewrite, error) {
+ leftOperation, rightOperation, err := translateBinary(tctx, expressionNode)
+ if err != nil {
+ return nil, err
+ }
+ leftOps := collapseOps(leftOperation, lookup)
+ rightOps := collapseOps(rightOperation, lookup)
+ ops := append(leftOps, rightOps...)
+ return builder(ops[0], ops[1:]...), nil
+ }
+
+ switch expressionNode.GetType() {
+ case dslshape.NodeTypeUnionExpression:
+ return translate(namespace.Union, func(rewrite *core.UsersetRewrite) *core.SetOperation {
+ return rewrite.GetUnion()
+ })
+
+ case dslshape.NodeTypeIntersectExpression:
+ return translate(namespace.Intersection, func(rewrite *core.UsersetRewrite) *core.SetOperation {
+ return rewrite.GetIntersection()
+ })
+
+ case dslshape.NodeTypeExclusionExpression:
+ // Order matters for exclusions, so do not perform the optimization.
+ leftOperation, rightOperation, err := translateBinary(tctx, expressionNode)
+ if err != nil {
+ return nil, err
+ }
+ return namespace.Exclusion(leftOperation, rightOperation), nil
+
+ default:
+ op, err := translateExpressionOperation(tctx, expressionNode)
+ if err != nil {
+ return nil, err
+ }
+
+ return namespace.Union(op), nil
+ }
+}
+
+func translateExpressionOperation(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) {
+ translated, err := translateExpressionOperationDirect(tctx, expressionOpNode)
+ if err != nil {
+ return translated, err
+ }
+
+ translated.SourcePosition = getSourcePosition(expressionOpNode, tctx.mapper)
+ return translated, nil
+}
+
+func translateExpressionOperationDirect(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) {
+ switch expressionOpNode.GetType() {
+ case dslshape.NodeTypeIdentifier:
+ referencedRelationName, err := expressionOpNode.GetString(dslshape.NodeIdentiferPredicateValue)
+ if err != nil {
+ return nil, err
+ }
+
+ return namespace.ComputedUserset(referencedRelationName), nil
+
+ case dslshape.NodeTypeNilExpression:
+ return namespace.Nil(), nil
+
+ case dslshape.NodeTypeArrowExpression:
+ leftChild, err := expressionOpNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr)
+ if err != nil {
+ return nil, err
+ }
+
+ rightChild, err := expressionOpNode.Lookup(dslshape.NodeExpressionPredicateRightExpr)
+ if err != nil {
+ return nil, err
+ }
+
+ if leftChild.GetType() != dslshape.NodeTypeIdentifier {
+ return nil, leftChild.Errorf("Nested arrows not yet supported")
+ }
+
+ tuplesetRelation, err := leftChild.GetString(dslshape.NodeIdentiferPredicateValue)
+ if err != nil {
+ return nil, err
+ }
+
+ usersetRelation, err := rightChild.GetString(dslshape.NodeIdentiferPredicateValue)
+ if err != nil {
+ return nil, err
+ }
+
+ if expressionOpNode.Has(dslshape.NodeArrowExpressionFunctionName) {
+ functionName, err := expressionOpNode.GetString(dslshape.NodeArrowExpressionFunctionName)
+ if err != nil {
+ return nil, err
+ }
+
+ return namespace.MustFunctionedTupleToUserset(tuplesetRelation, functionName, usersetRelation), nil
+ }
+
+ return namespace.TupleToUserset(tuplesetRelation, usersetRelation), nil
+
+ case dslshape.NodeTypeUnionExpression:
+ fallthrough
+
+ case dslshape.NodeTypeIntersectExpression:
+ fallthrough
+
+ case dslshape.NodeTypeExclusionExpression:
+ rewrite, err := translateExpression(tctx, expressionOpNode)
+ if err != nil {
+ return nil, err
+ }
+ return namespace.Rewrite(rewrite), nil
+
+ default:
+ return nil, expressionOpNode.Errorf("unknown expression node type %s", expressionOpNode.GetType())
+ }
+}
+
+func translateAllowedRelations(tctx *translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) {
+ switch typeRefNode.GetType() {
+ case dslshape.NodeTypeTypeReference:
+ references := []*core.AllowedRelation{}
+ for _, subRefNode := range typeRefNode.List(dslshape.NodeTypeReferencePredicateType) {
+ subReferences, err := translateAllowedRelations(tctx, subRefNode)
+ if err != nil {
+ return []*core.AllowedRelation{}, err
+ }
+
+ references = append(references, subReferences...)
+ }
+ return references, nil
+
+ case dslshape.NodeTypeSpecificTypeReference:
+ ref, err := translateSpecificTypeReference(tctx, typeRefNode)
+ if err != nil {
+ return []*core.AllowedRelation{}, err
+ }
+ return []*core.AllowedRelation{ref}, nil
+
+ default:
+ return nil, typeRefNode.Errorf("unknown type ref node type %s", typeRefNode.GetType())
+ }
+}
+
+func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) {
+ typePath, err := typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateType)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid type name: %w", err)
+ }
+
+ nspath, err := tctx.prefixedPath(typePath)
+ if err != nil {
+ return nil, typeRefNode.Errorf("%w", err)
+ }
+
+ if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateWildcard) {
+ ref := &core.AllowedRelation{
+ Namespace: nspath,
+ RelationOrWildcard: &core.AllowedRelation_PublicWildcard_{
+ PublicWildcard: &core.AllowedRelation_PublicWildcard{},
+ },
+ }
+
+ err = addWithCaveats(tctx, typeRefNode, ref)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid caveat: %w", err)
+ }
+
+ if !tctx.skipValidate {
+ if err := ref.Validate(); err != nil {
+ return nil, typeRefNode.Errorf("invalid type relation: %w", err)
+ }
+ }
+
+ ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
+ return ref, nil
+ }
+
+ relationName := Ellipsis
+ if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateRelation) {
+ relationName, err = typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateRelation)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid type relation: %w", err)
+ }
+ }
+
+ ref := &core.AllowedRelation{
+ Namespace: nspath,
+ RelationOrWildcard: &core.AllowedRelation_Relation{
+ Relation: relationName,
+ },
+ }
+
+ // Add the caveat(s), if any.
+ err = addWithCaveats(tctx, typeRefNode, ref)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid caveat: %w", err)
+ }
+
+ // Add the expiration trait, if any.
+ if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil {
+ traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait)
+ if err != nil {
+ return nil, typeRefNode.Errorf("invalid trait: %w", err)
+ }
+
+ if traitName != "expiration" {
+ return nil, typeRefNode.Errorf("invalid trait: %s", traitName)
+ }
+
+ if !slices.Contains(tctx.allowedFlags, "expiration") {
+ return nil, typeRefNode.Errorf("expiration trait is not allowed")
+ }
+
+ ref.RequiredExpiration = &core.ExpirationTrait{}
+ }
+
+ if !tctx.skipValidate {
+ if err := ref.Validate(); err != nil {
+ return nil, typeRefNode.Errorf("invalid type relation: %w", err)
+ }
+ }
+
+ ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper)
+ return ref, nil
+}
+
+func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error {
+ caveats := typeRefNode.List(dslshape.NodeSpecificReferencePredicateCaveat)
+ if len(caveats) == 0 {
+ return nil
+ }
+
+ if len(caveats) != 1 {
+ return fmt.Errorf("only one caveat is currently allowed per type reference")
+ }
+
+ name, err := caveats[0].GetString(dslshape.NodeCaveatPredicateCaveat)
+ if err != nil {
+ return err
+ }
+
+ nspath, err := tctx.prefixedPath(name)
+ if err != nil {
+ return err
+ }
+
+ ref.RequiredCaveat = &core.AllowedCaveat{
+ CaveatName: nspath,
+ }
+ return nil
+}
+
+// Translate use node and add flag to list of enabled flags
+func translateUseFlag(tctx *translationContext, useFlagNode *dslNode) error {
+ flagName, err := useFlagNode.GetString(dslshape.NodeUseFlagPredicateName)
+ if err != nil {
+ return err
+ }
+ if slices.Contains(tctx.enabledFlags, flagName) {
+ return useFlagNode.Errorf("found duplicate use flag: %s", flagName)
+ }
+ tctx.enabledFlags = append(tctx.enabledFlags, flagName)
+ return nil
+}