diff options
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler')
7 files changed, 1317 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go new file mode 100644 index 0000000..d1e96ec --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go @@ -0,0 +1,194 @@ +package compiler + +import ( + "errors" + "fmt" + + "google.golang.org/protobuf/proto" + "k8s.io/utils/strings/slices" + + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/schemadsl/parser" +) + +// InputSchema defines the input for a Compile. +type InputSchema struct { + // Source is the source of the schema being compiled. + Source input.Source + + // Schema is the contents being compiled. + SchemaString string +} + +// SchemaDefinition represents an object or caveat definition in a schema. +type SchemaDefinition interface { + proto.Message + + GetName() string +} + +// CompiledSchema is the result of compiling a schema when there are no errors. +type CompiledSchema struct { + // ObjectDefinitions holds the object definitions in the schema. + ObjectDefinitions []*core.NamespaceDefinition + + // CaveatDefinitions holds the caveat definitions in the schema. + CaveatDefinitions []*core.CaveatDefinition + + // OrderedDefinitions holds the object and caveat definitions in the schema, in the + // order in which they were found. + OrderedDefinitions []SchemaDefinition + + rootNode *dslNode + mapper input.PositionMapper +} + +// SourcePositionToRunePosition converts a source position to a rune position. +func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, position input.Position) (int, error) { + return cs.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source) +} + +type config struct { + skipValidation bool + objectTypePrefix *string + allowedFlags []string + caveatTypeSet *caveattypes.TypeSet +} + +func SkipValidation() Option { return func(cfg *config) { cfg.skipValidation = true } } + +func ObjectTypePrefix(prefix string) ObjectPrefixOption { + return func(cfg *config) { cfg.objectTypePrefix = &prefix } +} + +func RequirePrefixedObjectType() ObjectPrefixOption { + return func(cfg *config) { cfg.objectTypePrefix = nil } +} + +func AllowUnprefixedObjectType() ObjectPrefixOption { + return func(cfg *config) { cfg.objectTypePrefix = new(string) } +} + +func CaveatTypeSet(cts *caveattypes.TypeSet) Option { + return func(cfg *config) { cfg.caveatTypeSet = cts } +} + +const expirationFlag = "expiration" + +func DisallowExpirationFlag() Option { + return func(cfg *config) { + cfg.allowedFlags = slices.Filter([]string{}, cfg.allowedFlags, func(s string) bool { + return s != expirationFlag + }) + } +} + +type Option func(*config) + +type ObjectPrefixOption func(*config) + +// Compile compilers the input schema into a set of namespace definition protos. +func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) { + cfg := &config{ + allowedFlags: make([]string, 0, 1), + } + + // Enable `expiration` flag by default. + cfg.allowedFlags = append(cfg.allowedFlags, expirationFlag) + + prefix(cfg) // required option + + for _, fn := range opts { + fn(cfg) + } + + mapper := newPositionMapper(schema) + root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode) + errs := root.FindAll(dslshape.NodeTypeError) + if len(errs) > 0 { + err := errorNodeToError(errs[0], mapper) + return nil, err + } + + cts := caveattypes.TypeSetOrDefault(cfg.caveatTypeSet) + compiled, err := translate(&translationContext{ + objectTypePrefix: cfg.objectTypePrefix, + mapper: mapper, + schemaString: schema.SchemaString, + skipValidate: cfg.skipValidation, + allowedFlags: cfg.allowedFlags, + caveatTypeSet: cts, + }, root) + if err != nil { + var withNodeError withNodeError + if errors.As(err, &withNodeError) { + err = toContextError(withNodeError.error.Error(), withNodeError.errorSourceCode, withNodeError.node, mapper) + } + + return nil, err + } + + return compiled, nil +} + +func errorNodeToError(node *dslNode, mapper input.PositionMapper) error { + if node.GetType() != dslshape.NodeTypeError { + return fmt.Errorf("given none error node") + } + + errMessage, err := node.GetString(dslshape.NodePredicateErrorMessage) + if err != nil { + return fmt.Errorf("could not get error message for error node: %w", err) + } + + errorSourceCode := "" + if node.Has(dslshape.NodePredicateErrorSource) { + es, err := node.GetString(dslshape.NodePredicateErrorSource) + if err != nil { + return fmt.Errorf("could not get error source for error node: %w", err) + } + + errorSourceCode = es + } + + return toContextError(errMessage, errorSourceCode, node, mapper) +} + +func toContextError(errMessage string, errorSourceCode string, node *dslNode, mapper input.PositionMapper) error { + sourceRange, err := node.Range(mapper) + if err != nil { + return fmt.Errorf("could not get range for error node: %w", err) + } + + formattedRange, err := formatRange(sourceRange) + if err != nil { + return err + } + + source, err := node.GetString(dslshape.NodePredicateSource) + if err != nil { + return fmt.Errorf("missing source for node: %w", err) + } + + return WithContextError{ + BaseCompilerError: BaseCompilerError{ + error: fmt.Errorf("parse error in %s: %s", formattedRange, errMessage), + BaseMessage: errMessage, + }, + SourceRange: sourceRange, + Source: input.Source(source), + ErrorSourceCode: errorSourceCode, + } +} + +func formatRange(rnge input.SourceRange) (string, error) { + startLine, startCol, err := rnge.Start().LineAndColumn() + if err != nil { + return "", err + } + + return fmt.Sprintf("`%s`, line %v, column %v", rnge.Source(), startLine+1, startCol+1), nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go new file mode 100644 index 0000000..7c8e7c7 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go @@ -0,0 +1,142 @@ +package compiler + +import ( + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +// DSLNode is a node in the DSL AST. +type DSLNode interface { + GetType() dslshape.NodeType + GetString(predicateName string) (string, error) + GetInt(predicateName string) (int, error) + Lookup(predicateName string) (DSLNode, error) +} + +// NodeChain is a chain of nodes in the DSL AST. +type NodeChain struct { + nodes []DSLNode + runePosition int +} + +// Head returns the head node of the chain. +func (nc *NodeChain) Head() DSLNode { + return nc.nodes[0] +} + +// HasHeadType returns true if the head node of the chain is of the given type. +func (nc *NodeChain) HasHeadType(nodeType dslshape.NodeType) bool { + return nc.nodes[0].GetType() == nodeType +} + +// ForRunePosition returns the rune position of the chain. +func (nc *NodeChain) ForRunePosition() int { + return nc.runePosition +} + +// FindNodeOfType returns the first node of the given type in the chain, if any. +func (nc *NodeChain) FindNodeOfType(nodeType dslshape.NodeType) DSLNode { + for _, node := range nc.nodes { + if node.GetType() == nodeType { + return node + } + } + + return nil +} + +func (nc *NodeChain) String() string { + var out string + for _, node := range nc.nodes { + out += node.GetType().String() + " " + } + return out +} + +// PositionToAstNodeChain returns the AST node, and its parents (if any), found at the given position in the source, if any. +func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, position input.Position) (*NodeChain, error) { + rootSource, err := schema.rootNode.GetString(dslshape.NodePredicateSource) + if err != nil { + return nil, err + } + + if rootSource != string(source) { + return nil, nil + } + + // Map the position to a file rune. + runePosition, err := schema.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source) + if err != nil { + return nil, err + } + + // Find the node at the rune position. + found, err := runePositionToAstNodeChain(schema.rootNode, runePosition) + if err != nil { + return nil, err + } + + if found == nil { + return nil, nil + } + + return &NodeChain{nodes: found, runePosition: runePosition}, nil +} + +func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, error) { + if !node.Has(dslshape.NodePredicateStartRune) { + return nil, nil + } + + startRune, err := node.GetInt(dslshape.NodePredicateStartRune) + if err != nil { + return nil, err + } + + endRune, err := node.GetInt(dslshape.NodePredicateEndRune) + if err != nil { + return nil, err + } + + if runePosition < startRune || runePosition > endRune { + return nil, nil + } + + for _, child := range node.AllSubNodes() { + childChain, err := runePositionToAstNodeChain(child, runePosition) + if err != nil { + return nil, err + } + + if childChain != nil { + return append(childChain, wrapper{node}), nil + } + } + + return []DSLNode{wrapper{node}}, nil +} + +type wrapper struct { + node *dslNode +} + +func (w wrapper) GetType() dslshape.NodeType { + return w.node.GetType() +} + +func (w wrapper) GetString(predicateName string) (string, error) { + return w.node.GetString(predicateName) +} + +func (w wrapper) GetInt(predicateName string) (int, error) { + return w.node.GetInt(predicateName) +} + +func (w wrapper) Lookup(predicateName string) (DSLNode, error) { + found, err := w.node.Lookup(predicateName) + if err != nil { + return nil, err + } + + return wrapper{found}, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go new file mode 100644 index 0000000..fdc1735 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go @@ -0,0 +1,2 @@ +// Package compiler knows how to build the Go representation of a SpiceDB schema text. +package compiler diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go new file mode 100644 index 0000000..2c33ba8 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go @@ -0,0 +1,53 @@ +package compiler + +import ( + "strconv" + + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +// BaseCompilerError defines an error with contains the base message of the issue +// that occurred. +type BaseCompilerError struct { + error + BaseMessage string +} + +type withNodeError struct { + error + node *dslNode + errorSourceCode string +} + +// WithContextError defines an error which contains contextual information. +type WithContextError struct { + BaseCompilerError + SourceRange input.SourceRange + Source input.Source + ErrorSourceCode string +} + +func (ewc WithContextError) Unwrap() error { + return ewc.BaseCompilerError +} + +// DetailsMetadata returns the metadata for details for this error. +func (ewc WithContextError) DetailsMetadata() map[string]string { + startLine, startCol, err := ewc.SourceRange.Start().LineAndColumn() + if err != nil { + return map[string]string{} + } + + endLine, endCol, err := ewc.SourceRange.End().LineAndColumn() + if err != nil { + return map[string]string{} + } + + return map[string]string{ + "start_line_number": strconv.Itoa(startLine), + "start_column_position": strconv.Itoa(startCol), + "end_line_number": strconv.Itoa(endLine), + "end_column_position": strconv.Itoa(endCol), + "source_code": ewc.ErrorSourceCode, + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go new file mode 100644 index 0000000..b7e2a70 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go @@ -0,0 +1,180 @@ +package compiler + +import ( + "container/list" + "fmt" + + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/schemadsl/parser" +) + +type dslNode struct { + nodeType dslshape.NodeType + properties map[string]interface{} + children map[string]*list.List +} + +func createAstNode(_ input.Source, kind dslshape.NodeType) parser.AstNode { + return &dslNode{ + nodeType: kind, + properties: make(map[string]interface{}), + children: make(map[string]*list.List), + } +} + +func (tn *dslNode) GetType() dslshape.NodeType { + return tn.nodeType +} + +func (tn *dslNode) Connect(predicate string, other parser.AstNode) { + if tn.children[predicate] == nil { + tn.children[predicate] = list.New() + } + + tn.children[predicate].PushBack(other) +} + +func (tn *dslNode) MustDecorate(property string, value string) parser.AstNode { + if _, ok := tn.properties[property]; ok { + panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties)) + } + + tn.properties[property] = value + return tn +} + +func (tn *dslNode) MustDecorateWithInt(property string, value int) parser.AstNode { + if _, ok := tn.properties[property]; ok { + panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties)) + } + + tn.properties[property] = value + return tn +} + +func (tn *dslNode) Range(mapper input.PositionMapper) (input.SourceRange, error) { + sourceStr, err := tn.GetString(dslshape.NodePredicateSource) + if err != nil { + return nil, err + } + + source := input.Source(sourceStr) + + startRune, err := tn.GetInt(dslshape.NodePredicateStartRune) + if err != nil { + return nil, err + } + + endRune, err := tn.GetInt(dslshape.NodePredicateEndRune) + if err != nil { + return nil, err + } + + return source.RangeForRunePositions(startRune, endRune, mapper), nil +} + +func (tn *dslNode) Has(predicateName string) bool { + _, ok := tn.properties[predicateName] + return ok +} + +func (tn *dslNode) GetInt(predicateName string) (int, error) { + predicate, ok := tn.properties[predicateName] + if !ok { + return 0, fmt.Errorf("unknown predicate %s", predicateName) + } + + value, ok := predicate.(int) + if !ok { + return 0, fmt.Errorf("predicate %s is not an int", predicateName) + } + + return value, nil +} + +func (tn *dslNode) GetString(predicateName string) (string, error) { + predicate, ok := tn.properties[predicateName] + if !ok { + return "", fmt.Errorf("unknown predicate %s", predicateName) + } + + value, ok := predicate.(string) + if !ok { + return "", fmt.Errorf("predicate %s is not a string", predicateName) + } + + return value, nil +} + +func (tn *dslNode) AllSubNodes() []*dslNode { + nodes := []*dslNode{} + for _, childList := range tn.children { + for e := childList.Front(); e != nil; e = e.Next() { + nodes = append(nodes, e.Value.(*dslNode)) + } + } + return nodes +} + +func (tn *dslNode) GetChildren() []*dslNode { + return tn.List(dslshape.NodePredicateChild) +} + +func (tn *dslNode) FindAll(nodeType dslshape.NodeType) []*dslNode { + found := []*dslNode{} + if tn.nodeType == dslshape.NodeTypeError { + found = append(found, tn) + } + + for _, childList := range tn.children { + for e := childList.Front(); e != nil; e = e.Next() { + childFound := e.Value.(*dslNode).FindAll(nodeType) + found = append(found, childFound...) + } + } + return found +} + +func (tn *dslNode) List(predicateName string) []*dslNode { + children := []*dslNode{} + childList, ok := tn.children[predicateName] + if !ok { + return children + } + + for e := childList.Front(); e != nil; e = e.Next() { + children = append(children, e.Value.(*dslNode)) + } + + return children +} + +func (tn *dslNode) Lookup(predicateName string) (*dslNode, error) { + childList, ok := tn.children[predicateName] + if !ok { + return nil, fmt.Errorf("unknown predicate %s", predicateName) + } + + for e := childList.Front(); e != nil; e = e.Next() { + return e.Value.(*dslNode), nil + } + + return nil, fmt.Errorf("nothing in predicate %s", predicateName) +} + +func (tn *dslNode) Errorf(message string, args ...interface{}) error { + return withNodeError{ + error: fmt.Errorf(message, args...), + errorSourceCode: "", + node: tn, + } +} + +func (tn *dslNode) WithSourceErrorf(sourceCode string, message string, args ...interface{}) error { + return withNodeError{ + error: fmt.Errorf(message, args...), + errorSourceCode: sourceCode, + node: tn, + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go new file mode 100644 index 0000000..aa33c43 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go @@ -0,0 +1,32 @@ +package compiler + +import ( + "strings" + + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +type positionMapper struct { + schema InputSchema + mapper input.SourcePositionMapper +} + +func newPositionMapper(schema InputSchema) input.PositionMapper { + return &positionMapper{ + schema: schema, + mapper: input.CreateSourcePositionMapper([]byte(schema.SchemaString)), + } +} + +func (pm *positionMapper) RunePositionToLineAndCol(runePosition int, _ input.Source) (int, int, error) { + return pm.mapper.RunePositionToLineAndCol(runePosition) +} + +func (pm *positionMapper) LineAndColToRunePosition(lineNumber int, colPosition int, _ input.Source) (int, error) { + return pm.mapper.LineAndColToRunePosition(lineNumber, colPosition) +} + +func (pm *positionMapper) TextForLine(lineNumber int, _ input.Source) (string, error) { + lines := strings.Split(pm.schema.SchemaString, "\n") + return lines[lineNumber], nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go new file mode 100644 index 0000000..77877b0 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go @@ -0,0 +1,714 @@ +package compiler + +import ( + "bufio" + "fmt" + "slices" + "strings" + + "github.com/ccoveille/go-safecast" + "github.com/jzelinskie/stringz" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +type translationContext struct { + objectTypePrefix *string + mapper input.PositionMapper + schemaString string + skipValidate bool + allowedFlags []string + enabledFlags []string + caveatTypeSet *caveattypes.TypeSet +} + +func (tctx *translationContext) prefixedPath(definitionName string) (string, error) { + var prefix, name string + if err := stringz.SplitInto(definitionName, "/", &prefix, &name); err != nil { + if tctx.objectTypePrefix == nil { + return "", fmt.Errorf("found reference `%s` without prefix", definitionName) + } + prefix = *tctx.objectTypePrefix + name = definitionName + } + + if prefix == "" { + return name, nil + } + + return stringz.Join("/", prefix, name), nil +} + +const Ellipsis = "..." + +func translate(tctx *translationContext, root *dslNode) (*CompiledSchema, error) { + orderedDefinitions := make([]SchemaDefinition, 0, len(root.GetChildren())) + var objectDefinitions []*core.NamespaceDefinition + var caveatDefinitions []*core.CaveatDefinition + + names := mapz.NewSet[string]() + + for _, definitionNode := range root.GetChildren() { + var definition SchemaDefinition + + switch definitionNode.GetType() { + case dslshape.NodeTypeUseFlag: + err := translateUseFlag(tctx, definitionNode) + if err != nil { + return nil, err + } + continue + + case dslshape.NodeTypeCaveatDefinition: + def, err := translateCaveatDefinition(tctx, definitionNode) + if err != nil { + return nil, err + } + + definition = def + caveatDefinitions = append(caveatDefinitions, def) + + case dslshape.NodeTypeDefinition: + def, err := translateObjectDefinition(tctx, definitionNode) + if err != nil { + return nil, err + } + + definition = def + objectDefinitions = append(objectDefinitions, def) + } + + if !names.Add(definition.GetName()) { + return nil, definitionNode.WithSourceErrorf(definition.GetName(), "found name reused between multiple definitions and/or caveats: %s", definition.GetName()) + } + + orderedDefinitions = append(orderedDefinitions, definition) + } + + return &CompiledSchema{ + CaveatDefinitions: caveatDefinitions, + ObjectDefinitions: objectDefinitions, + OrderedDefinitions: orderedDefinitions, + rootNode: root, + mapper: tctx.mapper, + }, nil +} + +func translateCaveatDefinition(tctx *translationContext, defNode *dslNode) (*core.CaveatDefinition, error) { + definitionName, err := defNode.GetString(dslshape.NodeCaveatDefinitionPredicateName) + if err != nil { + return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) + } + + // parameters + paramNodes := defNode.List(dslshape.NodeCaveatDefinitionPredicateParameters) + if len(paramNodes) == 0 { + return nil, defNode.WithSourceErrorf(definitionName, "caveat `%s` must have at least one parameter defined", definitionName) + } + + env := caveats.NewEnvironment() + parameters := make(map[string]caveattypes.VariableType, len(paramNodes)) + for _, paramNode := range paramNodes { + paramName, err := paramNode.GetString(dslshape.NodeCaveatParameterPredicateName) + if err != nil { + return nil, paramNode.WithSourceErrorf(paramName, "invalid parameter name: %w", err) + } + + if _, ok := parameters[paramName]; ok { + return nil, paramNode.WithSourceErrorf(paramName, "duplicate parameter `%s` defined on caveat `%s`", paramName, definitionName) + } + + typeRefNode, err := paramNode.Lookup(dslshape.NodeCaveatParameterPredicateType) + if err != nil { + return nil, paramNode.WithSourceErrorf(paramName, "invalid type for parameter: %w", err) + } + + translatedType, err := translateCaveatTypeReference(tctx, typeRefNode) + if err != nil { + return nil, paramNode.WithSourceErrorf(paramName, "invalid type for caveat parameter `%s` on caveat `%s`: %w", paramName, definitionName, err) + } + + parameters[paramName] = *translatedType + err = env.AddVariable(paramName, *translatedType) + if err != nil { + return nil, paramNode.WithSourceErrorf(paramName, "invalid type for caveat parameter `%s` on caveat `%s`: %w", paramName, definitionName, err) + } + } + + caveatPath, err := tctx.prefixedPath(definitionName) + if err != nil { + return nil, defNode.Errorf("%w", err) + } + + // caveat expression. + expressionStringNode, err := defNode.Lookup(dslshape.NodeCaveatDefinitionPredicateExpession) + if err != nil { + return nil, defNode.WithSourceErrorf(definitionName, "invalid expression: %w", err) + } + + expressionString, err := expressionStringNode.GetString(dslshape.NodeCaveatExpressionPredicateExpression) + if err != nil { + return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err) + } + + rnge, err := expressionStringNode.Range(tctx.mapper) + if err != nil { + return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err) + } + + source, err := caveats.NewSource(expressionString, caveatPath) + if err != nil { + return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err) + } + + compiled, err := caveats.CompileCaveatWithSource(env, caveatPath, source, rnge.Start()) + if err != nil { + return nil, expressionStringNode.WithSourceErrorf(expressionString, "invalid expression for caveat `%s`: %w", definitionName, err) + } + + def, err := namespace.CompiledCaveatDefinition(env, caveatPath, compiled) + if err != nil { + return nil, err + } + + def.Metadata = addComments(def.Metadata, defNode) + def.SourcePosition = getSourcePosition(defNode, tctx.mapper) + return def, nil +} + +func translateCaveatTypeReference(tctx *translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) { + typeName, err := typeRefNode.GetString(dslshape.NodeCaveatTypeReferencePredicateType) + if err != nil { + return nil, typeRefNode.WithSourceErrorf(typeName, "invalid type name: %w", err) + } + + childTypeNodes := typeRefNode.List(dslshape.NodeCaveatTypeReferencePredicateChildTypes) + childTypes := make([]caveattypes.VariableType, 0, len(childTypeNodes)) + for _, childTypeNode := range childTypeNodes { + translated, err := translateCaveatTypeReference(tctx, childTypeNode) + if err != nil { + return nil, err + } + childTypes = append(childTypes, *translated) + } + + constructedType, err := tctx.caveatTypeSet.BuildType(typeName, childTypes) + if err != nil { + return nil, typeRefNode.WithSourceErrorf(typeName, "%w", err) + } + + return constructedType, nil +} + +func translateObjectDefinition(tctx *translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) { + definitionName, err := defNode.GetString(dslshape.NodeDefinitionPredicateName) + if err != nil { + return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) + } + + relationsAndPermissions := []*core.Relation{} + for _, relationOrPermissionNode := range defNode.GetChildren() { + if relationOrPermissionNode.GetType() == dslshape.NodeTypeComment { + continue + } + + relationOrPermission, err := translateRelationOrPermission(tctx, relationOrPermissionNode) + if err != nil { + return nil, err + } + + relationsAndPermissions = append(relationsAndPermissions, relationOrPermission) + } + + nspath, err := tctx.prefixedPath(definitionName) + if err != nil { + return nil, defNode.Errorf("%w", err) + } + + if len(relationsAndPermissions) == 0 { + ns := namespace.Namespace(nspath) + ns.Metadata = addComments(ns.Metadata, defNode) + ns.SourcePosition = getSourcePosition(defNode, tctx.mapper) + + if !tctx.skipValidate { + if err = ns.Validate(); err != nil { + return nil, defNode.Errorf("error in object definition %s: %w", nspath, err) + } + } + + return ns, nil + } + + ns := namespace.Namespace(nspath, relationsAndPermissions...) + ns.Metadata = addComments(ns.Metadata, defNode) + ns.SourcePosition = getSourcePosition(defNode, tctx.mapper) + + if !tctx.skipValidate { + if err := ns.Validate(); err != nil { + return nil, defNode.Errorf("error in object definition %s: %w", nspath, err) + } + } + + return ns, nil +} + +func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.SourcePosition { + if !dslNode.Has(dslshape.NodePredicateStartRune) { + return nil + } + + sourceRange, err := dslNode.Range(mapper) + if err != nil { + return nil + } + + line, col, err := sourceRange.Start().LineAndColumn() + if err != nil { + return nil + } + + // We're okay with these being zero if the cast fails. + uintLine, _ := safecast.ToUint64(line) + uintCol, _ := safecast.ToUint64(col) + + return &core.SourcePosition{ + ZeroIndexedLineNumber: uintLine, + ZeroIndexedColumnPosition: uintCol, + } +} + +func addComments(mdmsg *core.Metadata, dslNode *dslNode) *core.Metadata { + for _, child := range dslNode.GetChildren() { + if child.GetType() == dslshape.NodeTypeComment { + value, err := child.GetString(dslshape.NodeCommentPredicateValue) + if err == nil { + mdmsg, _ = namespace.AddComment(mdmsg, normalizeComment(value)) + } + } + } + return mdmsg +} + +func normalizeComment(value string) string { + var lines []string + scanner := bufio.NewScanner(strings.NewReader(value)) + for scanner.Scan() { + trimmed := strings.TrimSpace(scanner.Text()) + lines = append(lines, trimmed) + } + return strings.Join(lines, "\n") +} + +func translateRelationOrPermission(tctx *translationContext, relOrPermNode *dslNode) (*core.Relation, error) { + switch relOrPermNode.GetType() { + case dslshape.NodeTypeRelation: + rel, err := translateRelation(tctx, relOrPermNode) + if err != nil { + return nil, err + } + rel.Metadata = addComments(rel.Metadata, relOrPermNode) + rel.SourcePosition = getSourcePosition(relOrPermNode, tctx.mapper) + return rel, err + + case dslshape.NodeTypePermission: + rel, err := translatePermission(tctx, relOrPermNode) + if err != nil { + return nil, err + } + rel.Metadata = addComments(rel.Metadata, relOrPermNode) + rel.SourcePosition = getSourcePosition(relOrPermNode, tctx.mapper) + return rel, err + + default: + return nil, relOrPermNode.Errorf("unknown definition top-level node type %s", relOrPermNode.GetType()) + } +} + +func translateRelation(tctx *translationContext, relationNode *dslNode) (*core.Relation, error) { + relationName, err := relationNode.GetString(dslshape.NodePredicateName) + if err != nil { + return nil, relationNode.Errorf("invalid relation name: %w", err) + } + + allowedDirectTypes := []*core.AllowedRelation{} + for _, typeRef := range relationNode.List(dslshape.NodeRelationPredicateAllowedTypes) { + allowedRelations, err := translateAllowedRelations(tctx, typeRef) + if err != nil { + return nil, err + } + + allowedDirectTypes = append(allowedDirectTypes, allowedRelations...) + } + + relation, err := namespace.Relation(relationName, nil, allowedDirectTypes...) + if err != nil { + return nil, err + } + + if !tctx.skipValidate { + if err := relation.Validate(); err != nil { + return nil, relationNode.Errorf("error in relation %s: %w", relationName, err) + } + } + + return relation, nil +} + +func translatePermission(tctx *translationContext, permissionNode *dslNode) (*core.Relation, error) { + permissionName, err := permissionNode.GetString(dslshape.NodePredicateName) + if err != nil { + return nil, permissionNode.Errorf("invalid permission name: %w", err) + } + + expressionNode, err := permissionNode.Lookup(dslshape.NodePermissionPredicateComputeExpression) + if err != nil { + return nil, permissionNode.Errorf("invalid permission expression: %w", err) + } + + rewrite, err := translateExpression(tctx, expressionNode) + if err != nil { + return nil, err + } + + permission, err := namespace.Relation(permissionName, rewrite) + if err != nil { + return nil, err + } + + if !tctx.skipValidate { + if err := permission.Validate(); err != nil { + return nil, permissionNode.Errorf("error in permission %s: %w", permissionName, err) + } + } + + return permission, nil +} + +func translateBinary(tctx *translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) { + leftChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr) + if err != nil { + return nil, nil, err + } + + rightChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateRightExpr) + if err != nil { + return nil, nil, err + } + + leftOperation, err := translateExpressionOperation(tctx, leftChild) + if err != nil { + return nil, nil, err + } + + rightOperation, err := translateExpressionOperation(tctx, rightChild) + if err != nil { + return nil, nil, err + } + + return leftOperation, rightOperation, nil +} + +func translateExpression(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { + translated, err := translateExpressionDirect(tctx, expressionNode) + if err != nil { + return translated, err + } + + translated.SourcePosition = getSourcePosition(expressionNode, tctx.mapper) + return translated, nil +} + +func collapseOps(op *core.SetOperation_Child, handler func(rewrite *core.UsersetRewrite) *core.SetOperation) []*core.SetOperation_Child { + if op.GetUsersetRewrite() == nil { + return []*core.SetOperation_Child{op} + } + + usersetRewrite := op.GetUsersetRewrite() + operation := handler(usersetRewrite) + if operation == nil { + return []*core.SetOperation_Child{op} + } + + collapsed := make([]*core.SetOperation_Child, 0, len(operation.Child)) + for _, child := range operation.Child { + collapsed = append(collapsed, collapseOps(child, handler)...) + } + return collapsed +} + +func translateExpressionDirect(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { + // For union and intersection, we collapse a tree of binary operations into a flat list containing child + // operations of the *same* type. + translate := func( + builder func(firstChild *core.SetOperation_Child, rest ...*core.SetOperation_Child) *core.UsersetRewrite, + lookup func(rewrite *core.UsersetRewrite) *core.SetOperation, + ) (*core.UsersetRewrite, error) { + leftOperation, rightOperation, err := translateBinary(tctx, expressionNode) + if err != nil { + return nil, err + } + leftOps := collapseOps(leftOperation, lookup) + rightOps := collapseOps(rightOperation, lookup) + ops := append(leftOps, rightOps...) + return builder(ops[0], ops[1:]...), nil + } + + switch expressionNode.GetType() { + case dslshape.NodeTypeUnionExpression: + return translate(namespace.Union, func(rewrite *core.UsersetRewrite) *core.SetOperation { + return rewrite.GetUnion() + }) + + case dslshape.NodeTypeIntersectExpression: + return translate(namespace.Intersection, func(rewrite *core.UsersetRewrite) *core.SetOperation { + return rewrite.GetIntersection() + }) + + case dslshape.NodeTypeExclusionExpression: + // Order matters for exclusions, so do not perform the optimization. + leftOperation, rightOperation, err := translateBinary(tctx, expressionNode) + if err != nil { + return nil, err + } + return namespace.Exclusion(leftOperation, rightOperation), nil + + default: + op, err := translateExpressionOperation(tctx, expressionNode) + if err != nil { + return nil, err + } + + return namespace.Union(op), nil + } +} + +func translateExpressionOperation(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { + translated, err := translateExpressionOperationDirect(tctx, expressionOpNode) + if err != nil { + return translated, err + } + + translated.SourcePosition = getSourcePosition(expressionOpNode, tctx.mapper) + return translated, nil +} + +func translateExpressionOperationDirect(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { + switch expressionOpNode.GetType() { + case dslshape.NodeTypeIdentifier: + referencedRelationName, err := expressionOpNode.GetString(dslshape.NodeIdentiferPredicateValue) + if err != nil { + return nil, err + } + + return namespace.ComputedUserset(referencedRelationName), nil + + case dslshape.NodeTypeNilExpression: + return namespace.Nil(), nil + + case dslshape.NodeTypeArrowExpression: + leftChild, err := expressionOpNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr) + if err != nil { + return nil, err + } + + rightChild, err := expressionOpNode.Lookup(dslshape.NodeExpressionPredicateRightExpr) + if err != nil { + return nil, err + } + + if leftChild.GetType() != dslshape.NodeTypeIdentifier { + return nil, leftChild.Errorf("Nested arrows not yet supported") + } + + tuplesetRelation, err := leftChild.GetString(dslshape.NodeIdentiferPredicateValue) + if err != nil { + return nil, err + } + + usersetRelation, err := rightChild.GetString(dslshape.NodeIdentiferPredicateValue) + if err != nil { + return nil, err + } + + if expressionOpNode.Has(dslshape.NodeArrowExpressionFunctionName) { + functionName, err := expressionOpNode.GetString(dslshape.NodeArrowExpressionFunctionName) + if err != nil { + return nil, err + } + + return namespace.MustFunctionedTupleToUserset(tuplesetRelation, functionName, usersetRelation), nil + } + + return namespace.TupleToUserset(tuplesetRelation, usersetRelation), nil + + case dslshape.NodeTypeUnionExpression: + fallthrough + + case dslshape.NodeTypeIntersectExpression: + fallthrough + + case dslshape.NodeTypeExclusionExpression: + rewrite, err := translateExpression(tctx, expressionOpNode) + if err != nil { + return nil, err + } + return namespace.Rewrite(rewrite), nil + + default: + return nil, expressionOpNode.Errorf("unknown expression node type %s", expressionOpNode.GetType()) + } +} + +func translateAllowedRelations(tctx *translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) { + switch typeRefNode.GetType() { + case dslshape.NodeTypeTypeReference: + references := []*core.AllowedRelation{} + for _, subRefNode := range typeRefNode.List(dslshape.NodeTypeReferencePredicateType) { + subReferences, err := translateAllowedRelations(tctx, subRefNode) + if err != nil { + return []*core.AllowedRelation{}, err + } + + references = append(references, subReferences...) + } + return references, nil + + case dslshape.NodeTypeSpecificTypeReference: + ref, err := translateSpecificTypeReference(tctx, typeRefNode) + if err != nil { + return []*core.AllowedRelation{}, err + } + return []*core.AllowedRelation{ref}, nil + + default: + return nil, typeRefNode.Errorf("unknown type ref node type %s", typeRefNode.GetType()) + } +} + +func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) { + typePath, err := typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateType) + if err != nil { + return nil, typeRefNode.Errorf("invalid type name: %w", err) + } + + nspath, err := tctx.prefixedPath(typePath) + if err != nil { + return nil, typeRefNode.Errorf("%w", err) + } + + if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateWildcard) { + ref := &core.AllowedRelation{ + Namespace: nspath, + RelationOrWildcard: &core.AllowedRelation_PublicWildcard_{ + PublicWildcard: &core.AllowedRelation_PublicWildcard{}, + }, + } + + err = addWithCaveats(tctx, typeRefNode, ref) + if err != nil { + return nil, typeRefNode.Errorf("invalid caveat: %w", err) + } + + if !tctx.skipValidate { + if err := ref.Validate(); err != nil { + return nil, typeRefNode.Errorf("invalid type relation: %w", err) + } + } + + ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper) + return ref, nil + } + + relationName := Ellipsis + if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateRelation) { + relationName, err = typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateRelation) + if err != nil { + return nil, typeRefNode.Errorf("invalid type relation: %w", err) + } + } + + ref := &core.AllowedRelation{ + Namespace: nspath, + RelationOrWildcard: &core.AllowedRelation_Relation{ + Relation: relationName, + }, + } + + // Add the caveat(s), if any. + err = addWithCaveats(tctx, typeRefNode, ref) + if err != nil { + return nil, typeRefNode.Errorf("invalid caveat: %w", err) + } + + // Add the expiration trait, if any. + if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil { + traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait) + if err != nil { + return nil, typeRefNode.Errorf("invalid trait: %w", err) + } + + if traitName != "expiration" { + return nil, typeRefNode.Errorf("invalid trait: %s", traitName) + } + + if !slices.Contains(tctx.allowedFlags, "expiration") { + return nil, typeRefNode.Errorf("expiration trait is not allowed") + } + + ref.RequiredExpiration = &core.ExpirationTrait{} + } + + if !tctx.skipValidate { + if err := ref.Validate(); err != nil { + return nil, typeRefNode.Errorf("invalid type relation: %w", err) + } + } + + ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper) + return ref, nil +} + +func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error { + caveats := typeRefNode.List(dslshape.NodeSpecificReferencePredicateCaveat) + if len(caveats) == 0 { + return nil + } + + if len(caveats) != 1 { + return fmt.Errorf("only one caveat is currently allowed per type reference") + } + + name, err := caveats[0].GetString(dslshape.NodeCaveatPredicateCaveat) + if err != nil { + return err + } + + nspath, err := tctx.prefixedPath(name) + if err != nil { + return err + } + + ref.RequiredCaveat = &core.AllowedCaveat{ + CaveatName: nspath, + } + return nil +} + +// Translate use node and add flag to list of enabled flags +func translateUseFlag(tctx *translationContext, useFlagNode *dslNode) error { + flagName, err := useFlagNode.GetString(dslshape.NodeUseFlagPredicateName) + if err != nil { + return err + } + if slices.Contains(tctx.enabledFlags, flagName) { + return useFlagNode.Errorf("found duplicate use flag: %s", flagName) + } + tctx.enabledFlags = append(tctx.enabledFlags, flagName) + return nil +} |
