diff options
Diffstat (limited to 'vendor/github.com/authzed/spicedb/pkg/schemadsl')
22 files changed, 4194 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go new file mode 100644 index 0000000..d1e96ec --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/compiler.go @@ -0,0 +1,194 @@ +package compiler + +import ( + "errors" + "fmt" + + "google.golang.org/protobuf/proto" + "k8s.io/utils/strings/slices" + + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/schemadsl/parser" +) + +// InputSchema defines the input for a Compile. +type InputSchema struct { + // Source is the source of the schema being compiled. + Source input.Source + + // Schema is the contents being compiled. + SchemaString string +} + +// SchemaDefinition represents an object or caveat definition in a schema. +type SchemaDefinition interface { + proto.Message + + GetName() string +} + +// CompiledSchema is the result of compiling a schema when there are no errors. +type CompiledSchema struct { + // ObjectDefinitions holds the object definitions in the schema. + ObjectDefinitions []*core.NamespaceDefinition + + // CaveatDefinitions holds the caveat definitions in the schema. + CaveatDefinitions []*core.CaveatDefinition + + // OrderedDefinitions holds the object and caveat definitions in the schema, in the + // order in which they were found. + OrderedDefinitions []SchemaDefinition + + rootNode *dslNode + mapper input.PositionMapper +} + +// SourcePositionToRunePosition converts a source position to a rune position. +func (cs CompiledSchema) SourcePositionToRunePosition(source input.Source, position input.Position) (int, error) { + return cs.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source) +} + +type config struct { + skipValidation bool + objectTypePrefix *string + allowedFlags []string + caveatTypeSet *caveattypes.TypeSet +} + +func SkipValidation() Option { return func(cfg *config) { cfg.skipValidation = true } } + +func ObjectTypePrefix(prefix string) ObjectPrefixOption { + return func(cfg *config) { cfg.objectTypePrefix = &prefix } +} + +func RequirePrefixedObjectType() ObjectPrefixOption { + return func(cfg *config) { cfg.objectTypePrefix = nil } +} + +func AllowUnprefixedObjectType() ObjectPrefixOption { + return func(cfg *config) { cfg.objectTypePrefix = new(string) } +} + +func CaveatTypeSet(cts *caveattypes.TypeSet) Option { + return func(cfg *config) { cfg.caveatTypeSet = cts } +} + +const expirationFlag = "expiration" + +func DisallowExpirationFlag() Option { + return func(cfg *config) { + cfg.allowedFlags = slices.Filter([]string{}, cfg.allowedFlags, func(s string) bool { + return s != expirationFlag + }) + } +} + +type Option func(*config) + +type ObjectPrefixOption func(*config) + +// Compile compilers the input schema into a set of namespace definition protos. +func Compile(schema InputSchema, prefix ObjectPrefixOption, opts ...Option) (*CompiledSchema, error) { + cfg := &config{ + allowedFlags: make([]string, 0, 1), + } + + // Enable `expiration` flag by default. + cfg.allowedFlags = append(cfg.allowedFlags, expirationFlag) + + prefix(cfg) // required option + + for _, fn := range opts { + fn(cfg) + } + + mapper := newPositionMapper(schema) + root := parser.Parse(createAstNode, schema.Source, schema.SchemaString).(*dslNode) + errs := root.FindAll(dslshape.NodeTypeError) + if len(errs) > 0 { + err := errorNodeToError(errs[0], mapper) + return nil, err + } + + cts := caveattypes.TypeSetOrDefault(cfg.caveatTypeSet) + compiled, err := translate(&translationContext{ + objectTypePrefix: cfg.objectTypePrefix, + mapper: mapper, + schemaString: schema.SchemaString, + skipValidate: cfg.skipValidation, + allowedFlags: cfg.allowedFlags, + caveatTypeSet: cts, + }, root) + if err != nil { + var withNodeError withNodeError + if errors.As(err, &withNodeError) { + err = toContextError(withNodeError.error.Error(), withNodeError.errorSourceCode, withNodeError.node, mapper) + } + + return nil, err + } + + return compiled, nil +} + +func errorNodeToError(node *dslNode, mapper input.PositionMapper) error { + if node.GetType() != dslshape.NodeTypeError { + return fmt.Errorf("given none error node") + } + + errMessage, err := node.GetString(dslshape.NodePredicateErrorMessage) + if err != nil { + return fmt.Errorf("could not get error message for error node: %w", err) + } + + errorSourceCode := "" + if node.Has(dslshape.NodePredicateErrorSource) { + es, err := node.GetString(dslshape.NodePredicateErrorSource) + if err != nil { + return fmt.Errorf("could not get error source for error node: %w", err) + } + + errorSourceCode = es + } + + return toContextError(errMessage, errorSourceCode, node, mapper) +} + +func toContextError(errMessage string, errorSourceCode string, node *dslNode, mapper input.PositionMapper) error { + sourceRange, err := node.Range(mapper) + if err != nil { + return fmt.Errorf("could not get range for error node: %w", err) + } + + formattedRange, err := formatRange(sourceRange) + if err != nil { + return err + } + + source, err := node.GetString(dslshape.NodePredicateSource) + if err != nil { + return fmt.Errorf("missing source for node: %w", err) + } + + return WithContextError{ + BaseCompilerError: BaseCompilerError{ + error: fmt.Errorf("parse error in %s: %s", formattedRange, errMessage), + BaseMessage: errMessage, + }, + SourceRange: sourceRange, + Source: input.Source(source), + ErrorSourceCode: errorSourceCode, + } +} + +func formatRange(rnge input.SourceRange) (string, error) { + startLine, startCol, err := rnge.Start().LineAndColumn() + if err != nil { + return "", err + } + + return fmt.Sprintf("`%s`, line %v, column %v", rnge.Source(), startLine+1, startCol+1), nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go new file mode 100644 index 0000000..7c8e7c7 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/development.go @@ -0,0 +1,142 @@ +package compiler + +import ( + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +// DSLNode is a node in the DSL AST. +type DSLNode interface { + GetType() dslshape.NodeType + GetString(predicateName string) (string, error) + GetInt(predicateName string) (int, error) + Lookup(predicateName string) (DSLNode, error) +} + +// NodeChain is a chain of nodes in the DSL AST. +type NodeChain struct { + nodes []DSLNode + runePosition int +} + +// Head returns the head node of the chain. +func (nc *NodeChain) Head() DSLNode { + return nc.nodes[0] +} + +// HasHeadType returns true if the head node of the chain is of the given type. +func (nc *NodeChain) HasHeadType(nodeType dslshape.NodeType) bool { + return nc.nodes[0].GetType() == nodeType +} + +// ForRunePosition returns the rune position of the chain. +func (nc *NodeChain) ForRunePosition() int { + return nc.runePosition +} + +// FindNodeOfType returns the first node of the given type in the chain, if any. +func (nc *NodeChain) FindNodeOfType(nodeType dslshape.NodeType) DSLNode { + for _, node := range nc.nodes { + if node.GetType() == nodeType { + return node + } + } + + return nil +} + +func (nc *NodeChain) String() string { + var out string + for _, node := range nc.nodes { + out += node.GetType().String() + " " + } + return out +} + +// PositionToAstNodeChain returns the AST node, and its parents (if any), found at the given position in the source, if any. +func PositionToAstNodeChain(schema *CompiledSchema, source input.Source, position input.Position) (*NodeChain, error) { + rootSource, err := schema.rootNode.GetString(dslshape.NodePredicateSource) + if err != nil { + return nil, err + } + + if rootSource != string(source) { + return nil, nil + } + + // Map the position to a file rune. + runePosition, err := schema.mapper.LineAndColToRunePosition(position.LineNumber, position.ColumnPosition, source) + if err != nil { + return nil, err + } + + // Find the node at the rune position. + found, err := runePositionToAstNodeChain(schema.rootNode, runePosition) + if err != nil { + return nil, err + } + + if found == nil { + return nil, nil + } + + return &NodeChain{nodes: found, runePosition: runePosition}, nil +} + +func runePositionToAstNodeChain(node *dslNode, runePosition int) ([]DSLNode, error) { + if !node.Has(dslshape.NodePredicateStartRune) { + return nil, nil + } + + startRune, err := node.GetInt(dslshape.NodePredicateStartRune) + if err != nil { + return nil, err + } + + endRune, err := node.GetInt(dslshape.NodePredicateEndRune) + if err != nil { + return nil, err + } + + if runePosition < startRune || runePosition > endRune { + return nil, nil + } + + for _, child := range node.AllSubNodes() { + childChain, err := runePositionToAstNodeChain(child, runePosition) + if err != nil { + return nil, err + } + + if childChain != nil { + return append(childChain, wrapper{node}), nil + } + } + + return []DSLNode{wrapper{node}}, nil +} + +type wrapper struct { + node *dslNode +} + +func (w wrapper) GetType() dslshape.NodeType { + return w.node.GetType() +} + +func (w wrapper) GetString(predicateName string) (string, error) { + return w.node.GetString(predicateName) +} + +func (w wrapper) GetInt(predicateName string) (int, error) { + return w.node.GetInt(predicateName) +} + +func (w wrapper) Lookup(predicateName string) (DSLNode, error) { + found, err := w.node.Lookup(predicateName) + if err != nil { + return nil, err + } + + return wrapper{found}, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go new file mode 100644 index 0000000..fdc1735 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/doc.go @@ -0,0 +1,2 @@ +// Package compiler knows how to build the Go representation of a SpiceDB schema text. +package compiler diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go new file mode 100644 index 0000000..2c33ba8 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/errors.go @@ -0,0 +1,53 @@ +package compiler + +import ( + "strconv" + + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +// BaseCompilerError defines an error with contains the base message of the issue +// that occurred. +type BaseCompilerError struct { + error + BaseMessage string +} + +type withNodeError struct { + error + node *dslNode + errorSourceCode string +} + +// WithContextError defines an error which contains contextual information. +type WithContextError struct { + BaseCompilerError + SourceRange input.SourceRange + Source input.Source + ErrorSourceCode string +} + +func (ewc WithContextError) Unwrap() error { + return ewc.BaseCompilerError +} + +// DetailsMetadata returns the metadata for details for this error. +func (ewc WithContextError) DetailsMetadata() map[string]string { + startLine, startCol, err := ewc.SourceRange.Start().LineAndColumn() + if err != nil { + return map[string]string{} + } + + endLine, endCol, err := ewc.SourceRange.End().LineAndColumn() + if err != nil { + return map[string]string{} + } + + return map[string]string{ + "start_line_number": strconv.Itoa(startLine), + "start_column_position": strconv.Itoa(startCol), + "end_line_number": strconv.Itoa(endLine), + "end_column_position": strconv.Itoa(endCol), + "source_code": ewc.ErrorSourceCode, + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go new file mode 100644 index 0000000..b7e2a70 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/node.go @@ -0,0 +1,180 @@ +package compiler + +import ( + "container/list" + "fmt" + + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/schemadsl/parser" +) + +type dslNode struct { + nodeType dslshape.NodeType + properties map[string]interface{} + children map[string]*list.List +} + +func createAstNode(_ input.Source, kind dslshape.NodeType) parser.AstNode { + return &dslNode{ + nodeType: kind, + properties: make(map[string]interface{}), + children: make(map[string]*list.List), + } +} + +func (tn *dslNode) GetType() dslshape.NodeType { + return tn.nodeType +} + +func (tn *dslNode) Connect(predicate string, other parser.AstNode) { + if tn.children[predicate] == nil { + tn.children[predicate] = list.New() + } + + tn.children[predicate].PushBack(other) +} + +func (tn *dslNode) MustDecorate(property string, value string) parser.AstNode { + if _, ok := tn.properties[property]; ok { + panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties)) + } + + tn.properties[property] = value + return tn +} + +func (tn *dslNode) MustDecorateWithInt(property string, value int) parser.AstNode { + if _, ok := tn.properties[property]; ok { + panic(fmt.Sprintf("Existing key for property %s\n\tNode: %v", property, tn.properties)) + } + + tn.properties[property] = value + return tn +} + +func (tn *dslNode) Range(mapper input.PositionMapper) (input.SourceRange, error) { + sourceStr, err := tn.GetString(dslshape.NodePredicateSource) + if err != nil { + return nil, err + } + + source := input.Source(sourceStr) + + startRune, err := tn.GetInt(dslshape.NodePredicateStartRune) + if err != nil { + return nil, err + } + + endRune, err := tn.GetInt(dslshape.NodePredicateEndRune) + if err != nil { + return nil, err + } + + return source.RangeForRunePositions(startRune, endRune, mapper), nil +} + +func (tn *dslNode) Has(predicateName string) bool { + _, ok := tn.properties[predicateName] + return ok +} + +func (tn *dslNode) GetInt(predicateName string) (int, error) { + predicate, ok := tn.properties[predicateName] + if !ok { + return 0, fmt.Errorf("unknown predicate %s", predicateName) + } + + value, ok := predicate.(int) + if !ok { + return 0, fmt.Errorf("predicate %s is not an int", predicateName) + } + + return value, nil +} + +func (tn *dslNode) GetString(predicateName string) (string, error) { + predicate, ok := tn.properties[predicateName] + if !ok { + return "", fmt.Errorf("unknown predicate %s", predicateName) + } + + value, ok := predicate.(string) + if !ok { + return "", fmt.Errorf("predicate %s is not a string", predicateName) + } + + return value, nil +} + +func (tn *dslNode) AllSubNodes() []*dslNode { + nodes := []*dslNode{} + for _, childList := range tn.children { + for e := childList.Front(); e != nil; e = e.Next() { + nodes = append(nodes, e.Value.(*dslNode)) + } + } + return nodes +} + +func (tn *dslNode) GetChildren() []*dslNode { + return tn.List(dslshape.NodePredicateChild) +} + +func (tn *dslNode) FindAll(nodeType dslshape.NodeType) []*dslNode { + found := []*dslNode{} + if tn.nodeType == dslshape.NodeTypeError { + found = append(found, tn) + } + + for _, childList := range tn.children { + for e := childList.Front(); e != nil; e = e.Next() { + childFound := e.Value.(*dslNode).FindAll(nodeType) + found = append(found, childFound...) + } + } + return found +} + +func (tn *dslNode) List(predicateName string) []*dslNode { + children := []*dslNode{} + childList, ok := tn.children[predicateName] + if !ok { + return children + } + + for e := childList.Front(); e != nil; e = e.Next() { + children = append(children, e.Value.(*dslNode)) + } + + return children +} + +func (tn *dslNode) Lookup(predicateName string) (*dslNode, error) { + childList, ok := tn.children[predicateName] + if !ok { + return nil, fmt.Errorf("unknown predicate %s", predicateName) + } + + for e := childList.Front(); e != nil; e = e.Next() { + return e.Value.(*dslNode), nil + } + + return nil, fmt.Errorf("nothing in predicate %s", predicateName) +} + +func (tn *dslNode) Errorf(message string, args ...interface{}) error { + return withNodeError{ + error: fmt.Errorf(message, args...), + errorSourceCode: "", + node: tn, + } +} + +func (tn *dslNode) WithSourceErrorf(sourceCode string, message string, args ...interface{}) error { + return withNodeError{ + error: fmt.Errorf(message, args...), + errorSourceCode: sourceCode, + node: tn, + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go new file mode 100644 index 0000000..aa33c43 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/positionmapper.go @@ -0,0 +1,32 @@ +package compiler + +import ( + "strings" + + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +type positionMapper struct { + schema InputSchema + mapper input.SourcePositionMapper +} + +func newPositionMapper(schema InputSchema) input.PositionMapper { + return &positionMapper{ + schema: schema, + mapper: input.CreateSourcePositionMapper([]byte(schema.SchemaString)), + } +} + +func (pm *positionMapper) RunePositionToLineAndCol(runePosition int, _ input.Source) (int, int, error) { + return pm.mapper.RunePositionToLineAndCol(runePosition) +} + +func (pm *positionMapper) LineAndColToRunePosition(lineNumber int, colPosition int, _ input.Source) (int, error) { + return pm.mapper.LineAndColToRunePosition(lineNumber, colPosition) +} + +func (pm *positionMapper) TextForLine(lineNumber int, _ input.Source) (string, error) { + lines := strings.Split(pm.schema.SchemaString, "\n") + return lines[lineNumber], nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go new file mode 100644 index 0000000..77877b0 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/compiler/translator.go @@ -0,0 +1,714 @@ +package compiler + +import ( + "bufio" + "fmt" + "slices" + "strings" + + "github.com/ccoveille/go-safecast" + "github.com/jzelinskie/stringz" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +type translationContext struct { + objectTypePrefix *string + mapper input.PositionMapper + schemaString string + skipValidate bool + allowedFlags []string + enabledFlags []string + caveatTypeSet *caveattypes.TypeSet +} + +func (tctx *translationContext) prefixedPath(definitionName string) (string, error) { + var prefix, name string + if err := stringz.SplitInto(definitionName, "/", &prefix, &name); err != nil { + if tctx.objectTypePrefix == nil { + return "", fmt.Errorf("found reference `%s` without prefix", definitionName) + } + prefix = *tctx.objectTypePrefix + name = definitionName + } + + if prefix == "" { + return name, nil + } + + return stringz.Join("/", prefix, name), nil +} + +const Ellipsis = "..." + +func translate(tctx *translationContext, root *dslNode) (*CompiledSchema, error) { + orderedDefinitions := make([]SchemaDefinition, 0, len(root.GetChildren())) + var objectDefinitions []*core.NamespaceDefinition + var caveatDefinitions []*core.CaveatDefinition + + names := mapz.NewSet[string]() + + for _, definitionNode := range root.GetChildren() { + var definition SchemaDefinition + + switch definitionNode.GetType() { + case dslshape.NodeTypeUseFlag: + err := translateUseFlag(tctx, definitionNode) + if err != nil { + return nil, err + } + continue + + case dslshape.NodeTypeCaveatDefinition: + def, err := translateCaveatDefinition(tctx, definitionNode) + if err != nil { + return nil, err + } + + definition = def + caveatDefinitions = append(caveatDefinitions, def) + + case dslshape.NodeTypeDefinition: + def, err := translateObjectDefinition(tctx, definitionNode) + if err != nil { + return nil, err + } + + definition = def + objectDefinitions = append(objectDefinitions, def) + } + + if !names.Add(definition.GetName()) { + return nil, definitionNode.WithSourceErrorf(definition.GetName(), "found name reused between multiple definitions and/or caveats: %s", definition.GetName()) + } + + orderedDefinitions = append(orderedDefinitions, definition) + } + + return &CompiledSchema{ + CaveatDefinitions: caveatDefinitions, + ObjectDefinitions: objectDefinitions, + OrderedDefinitions: orderedDefinitions, + rootNode: root, + mapper: tctx.mapper, + }, nil +} + +func translateCaveatDefinition(tctx *translationContext, defNode *dslNode) (*core.CaveatDefinition, error) { + definitionName, err := defNode.GetString(dslshape.NodeCaveatDefinitionPredicateName) + if err != nil { + return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) + } + + // parameters + paramNodes := defNode.List(dslshape.NodeCaveatDefinitionPredicateParameters) + if len(paramNodes) == 0 { + return nil, defNode.WithSourceErrorf(definitionName, "caveat `%s` must have at least one parameter defined", definitionName) + } + + env := caveats.NewEnvironment() + parameters := make(map[string]caveattypes.VariableType, len(paramNodes)) + for _, paramNode := range paramNodes { + paramName, err := paramNode.GetString(dslshape.NodeCaveatParameterPredicateName) + if err != nil { + return nil, paramNode.WithSourceErrorf(paramName, "invalid parameter name: %w", err) + } + + if _, ok := parameters[paramName]; ok { + return nil, paramNode.WithSourceErrorf(paramName, "duplicate parameter `%s` defined on caveat `%s`", paramName, definitionName) + } + + typeRefNode, err := paramNode.Lookup(dslshape.NodeCaveatParameterPredicateType) + if err != nil { + return nil, paramNode.WithSourceErrorf(paramName, "invalid type for parameter: %w", err) + } + + translatedType, err := translateCaveatTypeReference(tctx, typeRefNode) + if err != nil { + return nil, paramNode.WithSourceErrorf(paramName, "invalid type for caveat parameter `%s` on caveat `%s`: %w", paramName, definitionName, err) + } + + parameters[paramName] = *translatedType + err = env.AddVariable(paramName, *translatedType) + if err != nil { + return nil, paramNode.WithSourceErrorf(paramName, "invalid type for caveat parameter `%s` on caveat `%s`: %w", paramName, definitionName, err) + } + } + + caveatPath, err := tctx.prefixedPath(definitionName) + if err != nil { + return nil, defNode.Errorf("%w", err) + } + + // caveat expression. + expressionStringNode, err := defNode.Lookup(dslshape.NodeCaveatDefinitionPredicateExpession) + if err != nil { + return nil, defNode.WithSourceErrorf(definitionName, "invalid expression: %w", err) + } + + expressionString, err := expressionStringNode.GetString(dslshape.NodeCaveatExpressionPredicateExpression) + if err != nil { + return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err) + } + + rnge, err := expressionStringNode.Range(tctx.mapper) + if err != nil { + return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err) + } + + source, err := caveats.NewSource(expressionString, caveatPath) + if err != nil { + return nil, defNode.WithSourceErrorf(expressionString, "invalid expression: %w", err) + } + + compiled, err := caveats.CompileCaveatWithSource(env, caveatPath, source, rnge.Start()) + if err != nil { + return nil, expressionStringNode.WithSourceErrorf(expressionString, "invalid expression for caveat `%s`: %w", definitionName, err) + } + + def, err := namespace.CompiledCaveatDefinition(env, caveatPath, compiled) + if err != nil { + return nil, err + } + + def.Metadata = addComments(def.Metadata, defNode) + def.SourcePosition = getSourcePosition(defNode, tctx.mapper) + return def, nil +} + +func translateCaveatTypeReference(tctx *translationContext, typeRefNode *dslNode) (*caveattypes.VariableType, error) { + typeName, err := typeRefNode.GetString(dslshape.NodeCaveatTypeReferencePredicateType) + if err != nil { + return nil, typeRefNode.WithSourceErrorf(typeName, "invalid type name: %w", err) + } + + childTypeNodes := typeRefNode.List(dslshape.NodeCaveatTypeReferencePredicateChildTypes) + childTypes := make([]caveattypes.VariableType, 0, len(childTypeNodes)) + for _, childTypeNode := range childTypeNodes { + translated, err := translateCaveatTypeReference(tctx, childTypeNode) + if err != nil { + return nil, err + } + childTypes = append(childTypes, *translated) + } + + constructedType, err := tctx.caveatTypeSet.BuildType(typeName, childTypes) + if err != nil { + return nil, typeRefNode.WithSourceErrorf(typeName, "%w", err) + } + + return constructedType, nil +} + +func translateObjectDefinition(tctx *translationContext, defNode *dslNode) (*core.NamespaceDefinition, error) { + definitionName, err := defNode.GetString(dslshape.NodeDefinitionPredicateName) + if err != nil { + return nil, defNode.WithSourceErrorf(definitionName, "invalid definition name: %w", err) + } + + relationsAndPermissions := []*core.Relation{} + for _, relationOrPermissionNode := range defNode.GetChildren() { + if relationOrPermissionNode.GetType() == dslshape.NodeTypeComment { + continue + } + + relationOrPermission, err := translateRelationOrPermission(tctx, relationOrPermissionNode) + if err != nil { + return nil, err + } + + relationsAndPermissions = append(relationsAndPermissions, relationOrPermission) + } + + nspath, err := tctx.prefixedPath(definitionName) + if err != nil { + return nil, defNode.Errorf("%w", err) + } + + if len(relationsAndPermissions) == 0 { + ns := namespace.Namespace(nspath) + ns.Metadata = addComments(ns.Metadata, defNode) + ns.SourcePosition = getSourcePosition(defNode, tctx.mapper) + + if !tctx.skipValidate { + if err = ns.Validate(); err != nil { + return nil, defNode.Errorf("error in object definition %s: %w", nspath, err) + } + } + + return ns, nil + } + + ns := namespace.Namespace(nspath, relationsAndPermissions...) + ns.Metadata = addComments(ns.Metadata, defNode) + ns.SourcePosition = getSourcePosition(defNode, tctx.mapper) + + if !tctx.skipValidate { + if err := ns.Validate(); err != nil { + return nil, defNode.Errorf("error in object definition %s: %w", nspath, err) + } + } + + return ns, nil +} + +func getSourcePosition(dslNode *dslNode, mapper input.PositionMapper) *core.SourcePosition { + if !dslNode.Has(dslshape.NodePredicateStartRune) { + return nil + } + + sourceRange, err := dslNode.Range(mapper) + if err != nil { + return nil + } + + line, col, err := sourceRange.Start().LineAndColumn() + if err != nil { + return nil + } + + // We're okay with these being zero if the cast fails. + uintLine, _ := safecast.ToUint64(line) + uintCol, _ := safecast.ToUint64(col) + + return &core.SourcePosition{ + ZeroIndexedLineNumber: uintLine, + ZeroIndexedColumnPosition: uintCol, + } +} + +func addComments(mdmsg *core.Metadata, dslNode *dslNode) *core.Metadata { + for _, child := range dslNode.GetChildren() { + if child.GetType() == dslshape.NodeTypeComment { + value, err := child.GetString(dslshape.NodeCommentPredicateValue) + if err == nil { + mdmsg, _ = namespace.AddComment(mdmsg, normalizeComment(value)) + } + } + } + return mdmsg +} + +func normalizeComment(value string) string { + var lines []string + scanner := bufio.NewScanner(strings.NewReader(value)) + for scanner.Scan() { + trimmed := strings.TrimSpace(scanner.Text()) + lines = append(lines, trimmed) + } + return strings.Join(lines, "\n") +} + +func translateRelationOrPermission(tctx *translationContext, relOrPermNode *dslNode) (*core.Relation, error) { + switch relOrPermNode.GetType() { + case dslshape.NodeTypeRelation: + rel, err := translateRelation(tctx, relOrPermNode) + if err != nil { + return nil, err + } + rel.Metadata = addComments(rel.Metadata, relOrPermNode) + rel.SourcePosition = getSourcePosition(relOrPermNode, tctx.mapper) + return rel, err + + case dslshape.NodeTypePermission: + rel, err := translatePermission(tctx, relOrPermNode) + if err != nil { + return nil, err + } + rel.Metadata = addComments(rel.Metadata, relOrPermNode) + rel.SourcePosition = getSourcePosition(relOrPermNode, tctx.mapper) + return rel, err + + default: + return nil, relOrPermNode.Errorf("unknown definition top-level node type %s", relOrPermNode.GetType()) + } +} + +func translateRelation(tctx *translationContext, relationNode *dslNode) (*core.Relation, error) { + relationName, err := relationNode.GetString(dslshape.NodePredicateName) + if err != nil { + return nil, relationNode.Errorf("invalid relation name: %w", err) + } + + allowedDirectTypes := []*core.AllowedRelation{} + for _, typeRef := range relationNode.List(dslshape.NodeRelationPredicateAllowedTypes) { + allowedRelations, err := translateAllowedRelations(tctx, typeRef) + if err != nil { + return nil, err + } + + allowedDirectTypes = append(allowedDirectTypes, allowedRelations...) + } + + relation, err := namespace.Relation(relationName, nil, allowedDirectTypes...) + if err != nil { + return nil, err + } + + if !tctx.skipValidate { + if err := relation.Validate(); err != nil { + return nil, relationNode.Errorf("error in relation %s: %w", relationName, err) + } + } + + return relation, nil +} + +func translatePermission(tctx *translationContext, permissionNode *dslNode) (*core.Relation, error) { + permissionName, err := permissionNode.GetString(dslshape.NodePredicateName) + if err != nil { + return nil, permissionNode.Errorf("invalid permission name: %w", err) + } + + expressionNode, err := permissionNode.Lookup(dslshape.NodePermissionPredicateComputeExpression) + if err != nil { + return nil, permissionNode.Errorf("invalid permission expression: %w", err) + } + + rewrite, err := translateExpression(tctx, expressionNode) + if err != nil { + return nil, err + } + + permission, err := namespace.Relation(permissionName, rewrite) + if err != nil { + return nil, err + } + + if !tctx.skipValidate { + if err := permission.Validate(); err != nil { + return nil, permissionNode.Errorf("error in permission %s: %w", permissionName, err) + } + } + + return permission, nil +} + +func translateBinary(tctx *translationContext, expressionNode *dslNode) (*core.SetOperation_Child, *core.SetOperation_Child, error) { + leftChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr) + if err != nil { + return nil, nil, err + } + + rightChild, err := expressionNode.Lookup(dslshape.NodeExpressionPredicateRightExpr) + if err != nil { + return nil, nil, err + } + + leftOperation, err := translateExpressionOperation(tctx, leftChild) + if err != nil { + return nil, nil, err + } + + rightOperation, err := translateExpressionOperation(tctx, rightChild) + if err != nil { + return nil, nil, err + } + + return leftOperation, rightOperation, nil +} + +func translateExpression(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { + translated, err := translateExpressionDirect(tctx, expressionNode) + if err != nil { + return translated, err + } + + translated.SourcePosition = getSourcePosition(expressionNode, tctx.mapper) + return translated, nil +} + +func collapseOps(op *core.SetOperation_Child, handler func(rewrite *core.UsersetRewrite) *core.SetOperation) []*core.SetOperation_Child { + if op.GetUsersetRewrite() == nil { + return []*core.SetOperation_Child{op} + } + + usersetRewrite := op.GetUsersetRewrite() + operation := handler(usersetRewrite) + if operation == nil { + return []*core.SetOperation_Child{op} + } + + collapsed := make([]*core.SetOperation_Child, 0, len(operation.Child)) + for _, child := range operation.Child { + collapsed = append(collapsed, collapseOps(child, handler)...) + } + return collapsed +} + +func translateExpressionDirect(tctx *translationContext, expressionNode *dslNode) (*core.UsersetRewrite, error) { + // For union and intersection, we collapse a tree of binary operations into a flat list containing child + // operations of the *same* type. + translate := func( + builder func(firstChild *core.SetOperation_Child, rest ...*core.SetOperation_Child) *core.UsersetRewrite, + lookup func(rewrite *core.UsersetRewrite) *core.SetOperation, + ) (*core.UsersetRewrite, error) { + leftOperation, rightOperation, err := translateBinary(tctx, expressionNode) + if err != nil { + return nil, err + } + leftOps := collapseOps(leftOperation, lookup) + rightOps := collapseOps(rightOperation, lookup) + ops := append(leftOps, rightOps...) + return builder(ops[0], ops[1:]...), nil + } + + switch expressionNode.GetType() { + case dslshape.NodeTypeUnionExpression: + return translate(namespace.Union, func(rewrite *core.UsersetRewrite) *core.SetOperation { + return rewrite.GetUnion() + }) + + case dslshape.NodeTypeIntersectExpression: + return translate(namespace.Intersection, func(rewrite *core.UsersetRewrite) *core.SetOperation { + return rewrite.GetIntersection() + }) + + case dslshape.NodeTypeExclusionExpression: + // Order matters for exclusions, so do not perform the optimization. + leftOperation, rightOperation, err := translateBinary(tctx, expressionNode) + if err != nil { + return nil, err + } + return namespace.Exclusion(leftOperation, rightOperation), nil + + default: + op, err := translateExpressionOperation(tctx, expressionNode) + if err != nil { + return nil, err + } + + return namespace.Union(op), nil + } +} + +func translateExpressionOperation(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { + translated, err := translateExpressionOperationDirect(tctx, expressionOpNode) + if err != nil { + return translated, err + } + + translated.SourcePosition = getSourcePosition(expressionOpNode, tctx.mapper) + return translated, nil +} + +func translateExpressionOperationDirect(tctx *translationContext, expressionOpNode *dslNode) (*core.SetOperation_Child, error) { + switch expressionOpNode.GetType() { + case dslshape.NodeTypeIdentifier: + referencedRelationName, err := expressionOpNode.GetString(dslshape.NodeIdentiferPredicateValue) + if err != nil { + return nil, err + } + + return namespace.ComputedUserset(referencedRelationName), nil + + case dslshape.NodeTypeNilExpression: + return namespace.Nil(), nil + + case dslshape.NodeTypeArrowExpression: + leftChild, err := expressionOpNode.Lookup(dslshape.NodeExpressionPredicateLeftExpr) + if err != nil { + return nil, err + } + + rightChild, err := expressionOpNode.Lookup(dslshape.NodeExpressionPredicateRightExpr) + if err != nil { + return nil, err + } + + if leftChild.GetType() != dslshape.NodeTypeIdentifier { + return nil, leftChild.Errorf("Nested arrows not yet supported") + } + + tuplesetRelation, err := leftChild.GetString(dslshape.NodeIdentiferPredicateValue) + if err != nil { + return nil, err + } + + usersetRelation, err := rightChild.GetString(dslshape.NodeIdentiferPredicateValue) + if err != nil { + return nil, err + } + + if expressionOpNode.Has(dslshape.NodeArrowExpressionFunctionName) { + functionName, err := expressionOpNode.GetString(dslshape.NodeArrowExpressionFunctionName) + if err != nil { + return nil, err + } + + return namespace.MustFunctionedTupleToUserset(tuplesetRelation, functionName, usersetRelation), nil + } + + return namespace.TupleToUserset(tuplesetRelation, usersetRelation), nil + + case dslshape.NodeTypeUnionExpression: + fallthrough + + case dslshape.NodeTypeIntersectExpression: + fallthrough + + case dslshape.NodeTypeExclusionExpression: + rewrite, err := translateExpression(tctx, expressionOpNode) + if err != nil { + return nil, err + } + return namespace.Rewrite(rewrite), nil + + default: + return nil, expressionOpNode.Errorf("unknown expression node type %s", expressionOpNode.GetType()) + } +} + +func translateAllowedRelations(tctx *translationContext, typeRefNode *dslNode) ([]*core.AllowedRelation, error) { + switch typeRefNode.GetType() { + case dslshape.NodeTypeTypeReference: + references := []*core.AllowedRelation{} + for _, subRefNode := range typeRefNode.List(dslshape.NodeTypeReferencePredicateType) { + subReferences, err := translateAllowedRelations(tctx, subRefNode) + if err != nil { + return []*core.AllowedRelation{}, err + } + + references = append(references, subReferences...) + } + return references, nil + + case dslshape.NodeTypeSpecificTypeReference: + ref, err := translateSpecificTypeReference(tctx, typeRefNode) + if err != nil { + return []*core.AllowedRelation{}, err + } + return []*core.AllowedRelation{ref}, nil + + default: + return nil, typeRefNode.Errorf("unknown type ref node type %s", typeRefNode.GetType()) + } +} + +func translateSpecificTypeReference(tctx *translationContext, typeRefNode *dslNode) (*core.AllowedRelation, error) { + typePath, err := typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateType) + if err != nil { + return nil, typeRefNode.Errorf("invalid type name: %w", err) + } + + nspath, err := tctx.prefixedPath(typePath) + if err != nil { + return nil, typeRefNode.Errorf("%w", err) + } + + if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateWildcard) { + ref := &core.AllowedRelation{ + Namespace: nspath, + RelationOrWildcard: &core.AllowedRelation_PublicWildcard_{ + PublicWildcard: &core.AllowedRelation_PublicWildcard{}, + }, + } + + err = addWithCaveats(tctx, typeRefNode, ref) + if err != nil { + return nil, typeRefNode.Errorf("invalid caveat: %w", err) + } + + if !tctx.skipValidate { + if err := ref.Validate(); err != nil { + return nil, typeRefNode.Errorf("invalid type relation: %w", err) + } + } + + ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper) + return ref, nil + } + + relationName := Ellipsis + if typeRefNode.Has(dslshape.NodeSpecificReferencePredicateRelation) { + relationName, err = typeRefNode.GetString(dslshape.NodeSpecificReferencePredicateRelation) + if err != nil { + return nil, typeRefNode.Errorf("invalid type relation: %w", err) + } + } + + ref := &core.AllowedRelation{ + Namespace: nspath, + RelationOrWildcard: &core.AllowedRelation_Relation{ + Relation: relationName, + }, + } + + // Add the caveat(s), if any. + err = addWithCaveats(tctx, typeRefNode, ref) + if err != nil { + return nil, typeRefNode.Errorf("invalid caveat: %w", err) + } + + // Add the expiration trait, if any. + if traitNode, err := typeRefNode.Lookup(dslshape.NodeSpecificReferencePredicateTrait); err == nil { + traitName, err := traitNode.GetString(dslshape.NodeTraitPredicateTrait) + if err != nil { + return nil, typeRefNode.Errorf("invalid trait: %w", err) + } + + if traitName != "expiration" { + return nil, typeRefNode.Errorf("invalid trait: %s", traitName) + } + + if !slices.Contains(tctx.allowedFlags, "expiration") { + return nil, typeRefNode.Errorf("expiration trait is not allowed") + } + + ref.RequiredExpiration = &core.ExpirationTrait{} + } + + if !tctx.skipValidate { + if err := ref.Validate(); err != nil { + return nil, typeRefNode.Errorf("invalid type relation: %w", err) + } + } + + ref.SourcePosition = getSourcePosition(typeRefNode, tctx.mapper) + return ref, nil +} + +func addWithCaveats(tctx *translationContext, typeRefNode *dslNode, ref *core.AllowedRelation) error { + caveats := typeRefNode.List(dslshape.NodeSpecificReferencePredicateCaveat) + if len(caveats) == 0 { + return nil + } + + if len(caveats) != 1 { + return fmt.Errorf("only one caveat is currently allowed per type reference") + } + + name, err := caveats[0].GetString(dslshape.NodeCaveatPredicateCaveat) + if err != nil { + return err + } + + nspath, err := tctx.prefixedPath(name) + if err != nil { + return err + } + + ref.RequiredCaveat = &core.AllowedCaveat{ + CaveatName: nspath, + } + return nil +} + +// Translate use node and add flag to list of enabled flags +func translateUseFlag(tctx *translationContext, useFlagNode *dslNode) error { + flagName, err := useFlagNode.GetString(dslshape.NodeUseFlagPredicateName) + if err != nil { + return err + } + if slices.Contains(tctx.enabledFlags, flagName) { + return useFlagNode.Errorf("found duplicate use flag: %s", flagName) + } + tctx.enabledFlags = append(tctx.enabledFlags, flagName) + return nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/doc.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/doc.go new file mode 100644 index 0000000..457bd0b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/doc.go @@ -0,0 +1,2 @@ +// Package dslshape defines the types representing the structure of schema DSL. +package dslshape diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/dslshape.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/dslshape.go new file mode 100644 index 0000000..c3a599f --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/dslshape.go @@ -0,0 +1,209 @@ +//go:generate go run golang.org/x/tools/cmd/stringer -type=NodeType -output zz_generated.nodetype_string.go + +package dslshape + +// NodeType identifies the type of AST node. +type NodeType int + +const ( + // Top-level + NodeTypeError NodeType = iota // error occurred; value is text of error + NodeTypeFile // The file root node + NodeTypeComment // A single or multiline comment + NodeTypeUseFlag // A use flag + + NodeTypeDefinition // A definition. + NodeTypeCaveatDefinition // A caveat definition. + + NodeTypeCaveatParameter // A caveat parameter. + NodeTypeCaveatExpression // A caveat expression. + + NodeTypeRelation // A relation + NodeTypePermission // A permission + + NodeTypeTypeReference // A type reference + NodeTypeSpecificTypeReference // A reference to a specific type. + NodeTypeCaveatReference // A caveat reference under a type. + NodeTypeTraitReference // A trait reference under a typr. + + NodeTypeUnionExpression + NodeTypeIntersectExpression + NodeTypeExclusionExpression + + NodeTypeArrowExpression // A TTU in arrow form. + + NodeTypeIdentifier // An identifier under an expression. + NodeTypeNilExpression // A nil keyword + + NodeTypeCaveatTypeReference // A type reference for a caveat parameter. +) + +const ( + // + // All nodes + // + // The source of this node. + NodePredicateSource = "input-source" + + // The rune position in the input string at which this node begins. + NodePredicateStartRune = "start-rune" + + // The rune position in the input string at which this node ends. + NodePredicateEndRune = "end-rune" + + // A direct child of this node. Implementations should handle the ordering + // automatically for this predicate. + NodePredicateChild = "child-node" + + // + // NodeTypeError + // + + // The message for the parsing error. + NodePredicateErrorMessage = "error-message" + + // The (optional) source to highlight for the parsing error. + NodePredicateErrorSource = "error-source" + + // + // NodeTypeComment + // + + // The value of the comment, including its delimeter(s) + NodeCommentPredicateValue = "comment-value" + + // + // NodeTypeUseFlag + // + + // The name of the use flag. + NodeUseFlagPredicateName = "use-flag-name" + + // + // NodeTypeDefinition + // + + // The name of the definition + NodeDefinitionPredicateName = "definition-name" + + // + // NodeTypeCaveatDefinition + // + + // The name of the definition + NodeCaveatDefinitionPredicateName = "caveat-definition-name" + + // The parameters for the definition. + NodeCaveatDefinitionPredicateParameters = "parameters" + + // The link to the expression for the definition. + NodeCaveatDefinitionPredicateExpession = "caveat-definition-expression" + + // + // NodeTypeCaveatExpression + // + + // The raw CEL expression, in string form. + NodeCaveatExpressionPredicateExpression = "caveat-expression-expressionstr" + + // + // NodeTypeCaveatParameter + // + + // The name of the parameter + NodeCaveatParameterPredicateName = "caveat-parameter-name" + + // The defined type of the caveat parameter. + NodeCaveatParameterPredicateType = "caveat-parameter-type" + + // + // NodeTypeCaveatTypeReference + // + + // The type for the caveat type reference. + NodeCaveatTypeReferencePredicateType = "type-name" + + // The child type(s) for the type reference. + NodeCaveatTypeReferencePredicateChildTypes = "child-types" + + // + // NodeTypeRelation + NodeTypePermission + // + + // The name of the relation/permission + NodePredicateName = "relation-name" + + // + // NodeTypeRelation + // + + // The allowed types for the relation. + NodeRelationPredicateAllowedTypes = "allowed-types" + + // + // NodeTypeTypeReference + // + + // A type under a type reference. + NodeTypeReferencePredicateType = "type-ref-type" + + // + // NodeTypeSpecificTypeReference + // + + // A type under a type reference. + NodeSpecificReferencePredicateType = "type-name" + + // A relation under a type reference. + NodeSpecificReferencePredicateRelation = "relation-name" + + // A wildcard under a type reference. + NodeSpecificReferencePredicateWildcard = "type-wildcard" + + // A caveat under a type reference. + NodeSpecificReferencePredicateCaveat = "caveat" + + // A trait under a type reference. + NodeSpecificReferencePredicateTrait = "trait" + + // + // NodeTypeCaveatReference + // + + // The caveat name under the caveat. + NodeCaveatPredicateCaveat = "caveat-name" + + // + // NodeTypeTraitReference + // + + // The trait name under the trait. + NodeTraitPredicateTrait = "trait-name" + + // + // NodeTypePermission + // + + // The expression to compute the permission. + NodePermissionPredicateComputeExpression = "compute-expression" + + // + // NodeTypeArrowExpression + // + + // The name of the function in the arrow expression. + NodeArrowExpressionFunctionName = "function-name" + + // + // NodeTypeIdentifer + // + + // The value of the identifier. + NodeIdentiferPredicateValue = "identifier-value" + + // + // NodeTypeUnionExpression + NodeTypeIntersectExpression + NodeTypeExclusionExpression + NodeTypeArrowExpression + // + NodeExpressionPredicateLeftExpr = "left-expr" + NodeExpressionPredicateRightExpr = "right-expr" +) diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/zz_generated.nodetype_string.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/zz_generated.nodetype_string.go new file mode 100644 index 0000000..4ef1e06 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/dslshape/zz_generated.nodetype_string.go @@ -0,0 +1,43 @@ +// Code generated by "stringer -type=NodeType -output zz_generated.nodetype_string.go"; DO NOT EDIT. + +package dslshape + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[NodeTypeError-0] + _ = x[NodeTypeFile-1] + _ = x[NodeTypeComment-2] + _ = x[NodeTypeUseFlag-3] + _ = x[NodeTypeDefinition-4] + _ = x[NodeTypeCaveatDefinition-5] + _ = x[NodeTypeCaveatParameter-6] + _ = x[NodeTypeCaveatExpression-7] + _ = x[NodeTypeRelation-8] + _ = x[NodeTypePermission-9] + _ = x[NodeTypeTypeReference-10] + _ = x[NodeTypeSpecificTypeReference-11] + _ = x[NodeTypeCaveatReference-12] + _ = x[NodeTypeTraitReference-13] + _ = x[NodeTypeUnionExpression-14] + _ = x[NodeTypeIntersectExpression-15] + _ = x[NodeTypeExclusionExpression-16] + _ = x[NodeTypeArrowExpression-17] + _ = x[NodeTypeIdentifier-18] + _ = x[NodeTypeNilExpression-19] + _ = x[NodeTypeCaveatTypeReference-20] +} + +const _NodeType_name = "NodeTypeErrorNodeTypeFileNodeTypeCommentNodeTypeUseFlagNodeTypeDefinitionNodeTypeCaveatDefinitionNodeTypeCaveatParameterNodeTypeCaveatExpressionNodeTypeRelationNodeTypePermissionNodeTypeTypeReferenceNodeTypeSpecificTypeReferenceNodeTypeCaveatReferenceNodeTypeTraitReferenceNodeTypeUnionExpressionNodeTypeIntersectExpressionNodeTypeExclusionExpressionNodeTypeArrowExpressionNodeTypeIdentifierNodeTypeNilExpressionNodeTypeCaveatTypeReference" + +var _NodeType_index = [...]uint16{0, 13, 25, 40, 55, 73, 97, 120, 144, 160, 178, 199, 228, 251, 273, 296, 323, 350, 373, 391, 412, 439} + +func (i NodeType) String() string { + if i < 0 || i >= NodeType(len(_NodeType_index)-1) { + return "NodeType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _NodeType_name[_NodeType_index[i]:_NodeType_index[i+1]] +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator.go new file mode 100644 index 0000000..3c8fdd9 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator.go @@ -0,0 +1,430 @@ +package generator + +import ( + "bufio" + "fmt" + "sort" + "strings" + + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/pkg/caveats" + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil/mapz" + "github.com/authzed/spicedb/pkg/graph" + "github.com/authzed/spicedb/pkg/namespace" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// Ellipsis is the relation name for terminal subjects. +const Ellipsis = "..." + +// MaxSingleLineCommentLength sets the maximum length for a comment to made single line. +const MaxSingleLineCommentLength = 70 // 80 - the comment parts and some padding + +func GenerateSchema(definitions []compiler.SchemaDefinition) (string, bool, error) { + return GenerateSchemaWithCaveatTypeSet(definitions, caveattypes.Default.TypeSet) +} + +// GenerateSchemaWithCaveatTypeSet generates a DSL view of the given schema. +func GenerateSchemaWithCaveatTypeSet(definitions []compiler.SchemaDefinition, caveatTypeSet *caveattypes.TypeSet) (string, bool, error) { + generated := make([]string, 0, len(definitions)) + flags := mapz.NewSet[string]() + + result := true + for _, definition := range definitions { + switch def := definition.(type) { + case *core.CaveatDefinition: + generatedCaveat, ok, err := GenerateCaveatSource(def, caveatTypeSet) + if err != nil { + return "", false, err + } + + result = result && ok + generated = append(generated, generatedCaveat) + + case *core.NamespaceDefinition: + generatedSchema, defFlags, ok, err := generateDefinitionSource(def, caveatTypeSet) + if err != nil { + return "", false, err + } + + result = result && ok + generated = append(generated, generatedSchema) + flags.Extend(defFlags) + + default: + return "", false, spiceerrors.MustBugf("unknown type of definition %T in GenerateSchema", def) + } + } + + if !flags.IsEmpty() { + flagsSlice := flags.AsSlice() + sort.Strings(flagsSlice) + + for _, flag := range flagsSlice { + generated = append([]string{"use " + flag}, generated...) + } + } + + return strings.Join(generated, "\n\n"), result, nil +} + +// GenerateCaveatSource generates a DSL view of the given caveat definition. +func GenerateCaveatSource(caveat *core.CaveatDefinition, caveatTypeSet *caveattypes.TypeSet) (string, bool, error) { + generator := &sourceGenerator{ + indentationLevel: 0, + hasNewline: true, + hasBlankline: true, + hasNewScope: true, + caveatTypeSet: caveatTypeSet, + } + + err := generator.emitCaveat(caveat) + if err != nil { + return "", false, err + } + + return generator.buf.String(), !generator.hasIssue, nil +} + +// GenerateSource generates a DSL view of the given namespace definition. +func GenerateSource(namespace *core.NamespaceDefinition, caveatTypeSet *caveattypes.TypeSet) (string, bool, error) { + source, _, ok, err := generateDefinitionSource(namespace, caveatTypeSet) + return source, ok, err +} + +func generateDefinitionSource(namespace *core.NamespaceDefinition, caveatTypeSet *caveattypes.TypeSet) (string, []string, bool, error) { + generator := &sourceGenerator{ + indentationLevel: 0, + hasNewline: true, + hasBlankline: true, + hasNewScope: true, + flags: mapz.NewSet[string](), + caveatTypeSet: caveatTypeSet, + } + + err := generator.emitNamespace(namespace) + if err != nil { + return "", nil, false, err + } + + return generator.buf.String(), generator.flags.AsSlice(), !generator.hasIssue, nil +} + +// GenerateRelationSource generates a DSL view of the given relation definition. +func GenerateRelationSource(relation *core.Relation, caveatTypeSet *caveattypes.TypeSet) (string, error) { + generator := &sourceGenerator{ + indentationLevel: 0, + hasNewline: true, + hasBlankline: true, + hasNewScope: true, + caveatTypeSet: caveatTypeSet, + } + + err := generator.emitRelation(relation) + if err != nil { + return "", err + } + + return generator.buf.String(), nil +} + +func (sg *sourceGenerator) emitCaveat(caveat *core.CaveatDefinition) error { + sg.emitComments(caveat.Metadata) + sg.append("caveat ") + sg.append(caveat.Name) + sg.append("(") + + parameterNames := maps.Keys(caveat.ParameterTypes) + sort.Strings(parameterNames) + + for index, paramName := range parameterNames { + if index > 0 { + sg.append(", ") + } + + decoded, err := caveattypes.DecodeParameterType(sg.caveatTypeSet, caveat.ParameterTypes[paramName]) + if err != nil { + return fmt.Errorf("invalid parameter type on caveat: %w", err) + } + + sg.append(paramName) + sg.append(" ") + sg.append(decoded.String()) + } + + sg.append(")") + + sg.append(" {") + sg.appendLine() + sg.indent() + sg.markNewScope() + + parameterTypes, err := caveattypes.DecodeParameterTypes(sg.caveatTypeSet, caveat.ParameterTypes) + if err != nil { + return fmt.Errorf("invalid caveat parameters: %w", err) + } + + deserializedExpression, err := caveats.DeserializeCaveatWithTypeSet(sg.caveatTypeSet, caveat.SerializedExpression, parameterTypes) + if err != nil { + return fmt.Errorf("invalid caveat expression bytes: %w", err) + } + + exprString, err := deserializedExpression.ExprString() + if err != nil { + return fmt.Errorf("invalid caveat expression: %w", err) + } + + sg.append(strings.TrimSpace(exprString)) + sg.appendLine() + + sg.dedent() + sg.append("}") + return nil +} + +func (sg *sourceGenerator) emitNamespace(namespace *core.NamespaceDefinition) error { + sg.emitComments(namespace.Metadata) + sg.append("definition ") + sg.append(namespace.Name) + + if len(namespace.Relation) == 0 { + sg.append(" {}") + return nil + } + + sg.append(" {") + sg.appendLine() + sg.indent() + sg.markNewScope() + + for _, relation := range namespace.Relation { + err := sg.emitRelation(relation) + if err != nil { + return err + } + } + + sg.dedent() + sg.append("}") + return nil +} + +func (sg *sourceGenerator) emitRelation(relation *core.Relation) error { + hasThis, err := graph.HasThis(relation.UsersetRewrite) + if err != nil { + return err + } + + isPermission := relation.UsersetRewrite != nil && !hasThis + + sg.emitComments(relation.Metadata) + if isPermission { + sg.append("permission ") + } else { + sg.append("relation ") + } + + sg.append(relation.Name) + + if !isPermission { + sg.append(": ") + if relation.TypeInformation == nil || relation.TypeInformation.AllowedDirectRelations == nil || len(relation.TypeInformation.AllowedDirectRelations) == 0 { + sg.appendIssue("missing allowed types") + } else { + for index, allowedRelation := range relation.TypeInformation.AllowedDirectRelations { + if index > 0 { + sg.append(" | ") + } + + sg.emitAllowedRelation(allowedRelation) + } + } + } + + if relation.UsersetRewrite != nil { + sg.append(" = ") + sg.mustEmitRewrite(relation.UsersetRewrite) + } + + sg.appendLine() + return nil +} + +func (sg *sourceGenerator) emitAllowedRelation(allowedRelation *core.AllowedRelation) { + sg.append(allowedRelation.Namespace) + if allowedRelation.GetRelation() != "" && allowedRelation.GetRelation() != Ellipsis { + sg.append("#") + sg.append(allowedRelation.GetRelation()) + } + if allowedRelation.GetPublicWildcard() != nil { + sg.append(":*") + } + + hasExpirationTrait := allowedRelation.GetRequiredExpiration() != nil + hasCaveat := allowedRelation.GetRequiredCaveat() != nil + + if hasExpirationTrait || hasCaveat { + sg.append(" with ") + if hasCaveat { + sg.append(allowedRelation.RequiredCaveat.CaveatName) + } + + if hasExpirationTrait { + sg.flags.Add("expiration") + + if hasCaveat { + sg.append(" and ") + } + + sg.append("expiration") + } + } +} + +func (sg *sourceGenerator) mustEmitRewrite(rewrite *core.UsersetRewrite) { + switch rw := rewrite.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + sg.emitRewriteOps(rw.Union, "+") + case *core.UsersetRewrite_Intersection: + sg.emitRewriteOps(rw.Intersection, "&") + case *core.UsersetRewrite_Exclusion: + sg.emitRewriteOps(rw.Exclusion, "-") + default: + panic(spiceerrors.MustBugf("unknown rewrite operation %T", rw)) + } +} + +func (sg *sourceGenerator) emitRewriteOps(setOp *core.SetOperation, op string) { + for index, child := range setOp.Child { + if index > 0 { + sg.append(" " + op + " ") + } + + sg.mustEmitSetOpChild(child) + } +} + +func (sg *sourceGenerator) isAllUnion(rewrite *core.UsersetRewrite) bool { + switch rw := rewrite.RewriteOperation.(type) { + case *core.UsersetRewrite_Union: + for _, setOpChild := range rw.Union.Child { + switch child := setOpChild.ChildType.(type) { + case *core.SetOperation_Child_UsersetRewrite: + if !sg.isAllUnion(child.UsersetRewrite) { + return false + } + default: + continue + } + } + return true + default: + return false + } +} + +func (sg *sourceGenerator) mustEmitSetOpChild(setOpChild *core.SetOperation_Child) { + switch child := setOpChild.ChildType.(type) { + case *core.SetOperation_Child_UsersetRewrite: + if sg.isAllUnion(child.UsersetRewrite) { + sg.mustEmitRewrite(child.UsersetRewrite) + break + } + + sg.append("(") + sg.mustEmitRewrite(child.UsersetRewrite) + sg.append(")") + + case *core.SetOperation_Child_XThis: + sg.appendIssue("_this unsupported here. Please rewrite into a relation and permission") + + case *core.SetOperation_Child_XNil: + sg.append("nil") + + case *core.SetOperation_Child_ComputedUserset: + sg.append(child.ComputedUserset.Relation) + + case *core.SetOperation_Child_TupleToUserset: + sg.append(child.TupleToUserset.Tupleset.Relation) + sg.append("->") + sg.append(child.TupleToUserset.ComputedUserset.Relation) + + case *core.SetOperation_Child_FunctionedTupleToUserset: + sg.append(child.FunctionedTupleToUserset.Tupleset.Relation) + sg.append(".") + + switch child.FunctionedTupleToUserset.Function { + case core.FunctionedTupleToUserset_FUNCTION_ALL: + sg.append("all") + + case core.FunctionedTupleToUserset_FUNCTION_ANY: + sg.append("any") + + default: + panic(spiceerrors.MustBugf("unknown function %v", child.FunctionedTupleToUserset.Function)) + } + + sg.append("(") + sg.append(child.FunctionedTupleToUserset.ComputedUserset.Relation) + sg.append(")") + + default: + panic(spiceerrors.MustBugf("unknown child type %T", child)) + } +} + +func (sg *sourceGenerator) emitComments(metadata *core.Metadata) { + if len(namespace.GetComments(metadata)) > 0 { + sg.ensureBlankLineOrNewScope() + } + + for _, comment := range namespace.GetComments(metadata) { + sg.appendComment(comment) + } +} + +func (sg *sourceGenerator) appendComment(comment string) { + switch { + case strings.HasPrefix(comment, "/*"): + stripped := strings.TrimSpace(comment) + + if strings.HasPrefix(stripped, "/**") { + stripped = strings.TrimPrefix(stripped, "/**") + sg.append("/**") + } else { + stripped = strings.TrimPrefix(stripped, "/*") + sg.append("/*") + } + + stripped = strings.TrimSuffix(stripped, "*/") + stripped = strings.TrimSpace(stripped) + + requireMultiline := len(stripped) > MaxSingleLineCommentLength || strings.ContainsRune(stripped, '\n') + + if requireMultiline { + sg.appendLine() + scanner := bufio.NewScanner(strings.NewReader(stripped)) + for scanner.Scan() { + sg.append(" * ") + sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(scanner.Text()), "*"))) + sg.appendLine() + } + sg.append(" */") + sg.appendLine() + } else { + sg.append(" ") + sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(stripped), "*"))) + sg.append(" */") + sg.appendLine() + } + + case strings.HasPrefix(comment, "//"): + sg.append("// ") + sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(comment), "//"))) + sg.appendLine() + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator_impl.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator_impl.go new file mode 100644 index 0000000..6251712 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/generator/generator_impl.go @@ -0,0 +1,83 @@ +package generator + +import ( + "strings" + + caveattypes "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/genutil/mapz" +) + +type sourceGenerator struct { + buf strings.Builder // The buffer for the new source code. + indentationLevel int // The current indentation level. + hasNewline bool // Whether there is a newline at the end of the buffer. + hasBlankline bool // Whether there is a blank line at the end of the buffer. + hasIssue bool // Whether there is a translation issue. + hasNewScope bool // Whether there is a new scope at the end of the buffer. + existingLineLength int // Length of the existing line. + flags *mapz.Set[string] // The flags added while generating. + caveatTypeSet *caveattypes.TypeSet +} + +// ensureBlankLineOrNewScope ensures that there is a blank line or new scope at the tail of the buffer. If not, +// a new line is added. +func (sg *sourceGenerator) ensureBlankLineOrNewScope() { + if !sg.hasBlankline && !sg.hasNewScope { + sg.appendLine() + } +} + +// indent increases the current indentation. +func (sg *sourceGenerator) indent() { + sg.indentationLevel = sg.indentationLevel + 1 +} + +// dedent decreases the current indentation. +func (sg *sourceGenerator) dedent() { + sg.indentationLevel = sg.indentationLevel - 1 +} + +// appendIssue adds an issue found in generation. +func (sg *sourceGenerator) appendIssue(description string) { + sg.append("/* ") + sg.append(description) + sg.append(" */") + sg.hasIssue = true +} + +// append adds the given value to the buffer, indenting as necessary. +func (sg *sourceGenerator) append(value string) { + for _, currentRune := range value { + if currentRune == '\n' { + if sg.hasNewline { + sg.hasBlankline = true + } + + sg.buf.WriteRune('\n') + sg.hasNewline = true + sg.existingLineLength = 0 + continue + } + + sg.hasBlankline = false + sg.hasNewScope = false + + if sg.hasNewline { + sg.buf.WriteString(strings.Repeat("\t", sg.indentationLevel)) + sg.hasNewline = false + sg.existingLineLength += sg.indentationLevel + } + + sg.existingLineLength++ + sg.buf.WriteRune(currentRune) + } +} + +// appendLine adds a newline. +func (sg *sourceGenerator) appendLine() { + sg.append("\n") +} + +func (sg *sourceGenerator) markNewScope() { + sg.hasNewScope = true +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/inputsource.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/inputsource.go new file mode 100644 index 0000000..3bd0437 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/inputsource.go @@ -0,0 +1,224 @@ +package input + +import ( + "fmt" +) + +// BytePosition represents the byte position in a piece of code. +type BytePosition int + +// Position represents a position in an arbitrary source file. +type Position struct { + // LineNumber is the 0-indexed line number. + LineNumber int + + // ColumnPosition is the 0-indexed column position on the line. + ColumnPosition int +} + +// Source represents the path of a source file. +type Source string + +// RangeForRunePosition returns a source range over this source file. +func (is Source) RangeForRunePosition(runePosition int, mapper PositionMapper) SourceRange { + return is.RangeForRunePositions(runePosition, runePosition, mapper) +} + +// PositionForRunePosition returns a source position over this source file. +func (is Source) PositionForRunePosition(runePosition int, mapper PositionMapper) SourcePosition { + return runeIndexedPosition{is, mapper, runePosition} +} + +// PositionFromLineAndColumn returns a source position at the given line and column in this source file. +func (is Source) PositionFromLineAndColumn(lineNumber int, columnPosition int, mapper PositionMapper) SourcePosition { + return lcIndexedPosition{is, mapper, Position{lineNumber, columnPosition}} +} + +// RangeForRunePositions returns a source range over this source file. +func (is Source) RangeForRunePositions(startRune int, endRune int, mapper PositionMapper) SourceRange { + return sourceRange{is, runeIndexedPosition{is, mapper, startRune}, runeIndexedPosition{is, mapper, endRune}} +} + +// RangeForLineAndColPositions returns a source range over this source file. +func (is Source) RangeForLineAndColPositions(start Position, end Position, mapper PositionMapper) SourceRange { + return sourceRange{is, lcIndexedPosition{is, mapper, start}, lcIndexedPosition{is, mapper, end}} +} + +// PositionMapper defines an interface for converting rune position <-> line+col position +// under source files. +type PositionMapper interface { + // RunePositionToLineAndCol converts the given 0-indexed rune position under the given source file + // into a 0-indexed line number and column position. + RunePositionToLineAndCol(runePosition int, path Source) (int, int, error) + + // LineAndColToRunePosition converts the given 0-indexed line number and column position under the + // given source file into a 0-indexed rune position. + LineAndColToRunePosition(lineNumber int, colPosition int, path Source) (int, error) + + // TextForLine returns the text for the specified line number. + TextForLine(lineNumber int, path Source) (string, error) +} + +// SourceRange represents a range inside a source file. +type SourceRange interface { + // Source is the input source for this range. + Source() Source + + // Start is the starting position of the source range. + Start() SourcePosition + + // End is the ending position (inclusive) of the source range. If the same as the Start, + // this range represents a single position. + End() SourcePosition + + // ContainsPosition returns true if the given range contains the given position. + ContainsPosition(position SourcePosition) (bool, error) + + // AtStartPosition returns a SourceRange located only at the starting position of this range. + AtStartPosition() SourceRange + + // String returns a (somewhat) human-readable form of the range. + String() string +} + +// SourcePosition represents a single position in a source file. +type SourcePosition interface { + // Source is the input source for this position. + Source() Source + + // RunePosition returns the 0-indexed rune position in the source file. + RunePosition() (int, error) + + // LineAndColumn returns the 0-indexed line number and column position in the source file. + LineAndColumn() (int, int, error) + + // LineText returns the text of the line for this position. + LineText() (string, error) + + // String returns a (somewhat) human-readable form of the position. + String() string +} + +// sourceRange implements the SourceRange interface. +type sourceRange struct { + source Source + start SourcePosition + end SourcePosition +} + +func (sr sourceRange) Source() Source { + return sr.source +} + +func (sr sourceRange) Start() SourcePosition { + return sr.start +} + +func (sr sourceRange) End() SourcePosition { + return sr.end +} + +func (sr sourceRange) AtStartPosition() SourceRange { + return sourceRange{sr.source, sr.start, sr.end} +} + +func (sr sourceRange) ContainsPosition(position SourcePosition) (bool, error) { + if position.Source() != sr.source { + return false, nil + } + + startRune, err := sr.start.RunePosition() + if err != nil { + return false, err + } + + endRune, err := sr.end.RunePosition() + if err != nil { + return false, err + } + + positionRune, err := position.RunePosition() + if err != nil { + return false, err + } + + return positionRune >= startRune && positionRune <= endRune, nil +} + +func (sr sourceRange) String() string { + return fmt.Sprintf("%v -> %v", sr.start, sr.end) +} + +// runeIndexedPosition implements the SourcePosition interface over a rune position. +type runeIndexedPosition struct { + source Source + mapper PositionMapper + runePosition int +} + +func (ris runeIndexedPosition) Source() Source { + return ris.source +} + +func (ris runeIndexedPosition) RunePosition() (int, error) { + return ris.runePosition, nil +} + +func (ris runeIndexedPosition) LineAndColumn() (int, int, error) { + if ris.runePosition == 0 { + return 0, 0, nil + } + if ris.mapper == nil { + return -1, -1, fmt.Errorf("nil mapper") + } + return ris.mapper.RunePositionToLineAndCol(ris.runePosition, ris.source) +} + +func (ris runeIndexedPosition) String() string { + return fmt.Sprintf("%s@%v (rune)", ris.source, ris.runePosition) +} + +func (ris runeIndexedPosition) LineText() (string, error) { + lineNumber, _, err := ris.LineAndColumn() + if err != nil { + return "", err + } + + return ris.mapper.TextForLine(lineNumber, ris.source) +} + +// lcIndexedPosition implements the SourcePosition interface over a line and colu,n position. +type lcIndexedPosition struct { + source Source + mapper PositionMapper + lcPosition Position +} + +func (lcip lcIndexedPosition) Source() Source { + return lcip.source +} + +func (lcip lcIndexedPosition) String() string { + return fmt.Sprintf("%s@%v:%v (line/col)", lcip.source, lcip.lcPosition.LineNumber, lcip.lcPosition.ColumnPosition) +} + +func (lcip lcIndexedPosition) RunePosition() (int, error) { + if lcip.lcPosition.LineNumber == 0 && lcip.lcPosition.ColumnPosition == 0 { + return 0, nil + } + if lcip.mapper == nil { + return -1, fmt.Errorf("nil mapper") + } + return lcip.mapper.LineAndColToRunePosition(lcip.lcPosition.LineNumber, lcip.lcPosition.ColumnPosition, lcip.source) +} + +func (lcip lcIndexedPosition) LineAndColumn() (int, int, error) { + return lcip.lcPosition.LineNumber, lcip.lcPosition.ColumnPosition, nil +} + +func (lcip lcIndexedPosition) LineText() (string, error) { + if lcip.mapper == nil { + return "", fmt.Errorf("nil mapper") + } + return lcip.mapper.TextForLine(lcip.lcPosition.LineNumber, lcip.source) +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/sourcepositionmapper.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/sourcepositionmapper.go new file mode 100644 index 0000000..1bca03c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/input/sourcepositionmapper.go @@ -0,0 +1,95 @@ +package input + +import ( + "fmt" + "strings" + + "github.com/emirpasic/gods/trees/redblacktree" +) + +// SourcePositionMapper defines a helper struct for cached, faster lookup of rune position <-> +// (line, column) for a specific source file. +type SourcePositionMapper struct { + // rangeTree holds a tree that maps from rune position to a line and start position. + rangeTree *redblacktree.Tree + + // lineMap holds a map from line number to rune positions for that line. + lineMap map[int]inclusiveRange +} + +// EmptySourcePositionMapper returns an empty source position mapper. +func EmptySourcePositionMapper() SourcePositionMapper { + rangeTree := redblacktree.NewWith(inclusiveComparator) + return SourcePositionMapper{rangeTree, map[int]inclusiveRange{}} +} + +// CreateSourcePositionMapper returns a source position mapper for the contents of a source file. +func CreateSourcePositionMapper(contents []byte) SourcePositionMapper { + lines := strings.Split(string(contents), "\n") + rangeTree := redblacktree.NewWith(inclusiveComparator) + lineMap := map[int]inclusiveRange{} + + currentStart := int(0) + for index, line := range lines { + lineEnd := currentStart + len(line) + rangeTree.Put(inclusiveRange{currentStart, lineEnd}, lineAndStart{index, currentStart}) + lineMap[index] = inclusiveRange{currentStart, lineEnd} + currentStart = lineEnd + 1 + } + + return SourcePositionMapper{rangeTree, lineMap} +} + +type inclusiveRange struct { + start int + end int +} + +type lineAndStart struct { + lineNumber int + startPosition int +} + +func inclusiveComparator(a, b interface{}) int { + i1 := a.(inclusiveRange) + i2 := b.(inclusiveRange) + + if i1.start >= i2.start && i1.end <= i2.end { + return 0 + } + + diff := int64(i1.start) - int64(i2.start) + + if diff < 0 { + return -1 + } + if diff > 0 { + return 1 + } + return 0 +} + +// RunePositionToLineAndCol returns the line number and column position of the rune position in source. +func (spm SourcePositionMapper) RunePositionToLineAndCol(runePosition int) (int, int, error) { + ls, found := spm.rangeTree.Get(inclusiveRange{runePosition, runePosition}) + if !found { + return 0, 0, fmt.Errorf("unknown rune position %v in source file", runePosition) + } + + las := ls.(lineAndStart) + return las.lineNumber, runePosition - las.startPosition, nil +} + +// LineAndColToRunePosition returns the rune position of the line number and column position in source. +func (spm SourcePositionMapper) LineAndColToRunePosition(lineNumber int, colPosition int) (int, error) { + lineRuneInfo, hasLine := spm.lineMap[lineNumber] + if !hasLine { + return 0, fmt.Errorf("unknown line %v in source file", lineNumber) + } + + if colPosition > lineRuneInfo.end-lineRuneInfo.start { + return 0, fmt.Errorf("column position %v not found on line %v in source file", colPosition, lineNumber) + } + + return lineRuneInfo.start + colPosition, nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flaggablelexer.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flaggablelexer.go new file mode 100644 index 0000000..1ae99af --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flaggablelexer.go @@ -0,0 +1,59 @@ +package lexer + +// FlaggableLexer wraps a lexer, automatically translating tokens based on flags, if any. +type FlaggableLexer struct { + lex *Lexer // a reference to the lexer used for tokenization + enabledFlags map[string]transformer // flags that are enabled + seenDefinition bool + afterUseIdentifier bool +} + +// NewFlaggableLexer returns a new FlaggableLexer for the given lexer. +func NewFlaggableLexer(lex *Lexer) *FlaggableLexer { + return &FlaggableLexer{ + lex: lex, + enabledFlags: map[string]transformer{}, + } +} + +// Close stops the lexer from running. +func (l *FlaggableLexer) Close() { + l.lex.Close() +} + +// NextToken returns the next token found in the lexer. +func (l *FlaggableLexer) NextToken() Lexeme { + nextToken := l.lex.nextToken() + + // Look for `use somefeature` + if nextToken.Kind == TokenTypeIdentifier { + // Only allowed until we've seen a definition of some kind. + if !l.seenDefinition { + if l.afterUseIdentifier { + if transformer, ok := Flags[nextToken.Value]; ok { + l.enabledFlags[nextToken.Value] = transformer + } + + l.afterUseIdentifier = false + } else { + l.afterUseIdentifier = nextToken.Value == "use" + } + } + } + + if nextToken.Kind == TokenTypeKeyword && nextToken.Value == "definition" { + l.seenDefinition = true + } + if nextToken.Kind == TokenTypeKeyword && nextToken.Value == "caveat" { + l.seenDefinition = true + } + + for _, handler := range l.enabledFlags { + updated, ok := handler(nextToken) + if ok { + return updated + } + } + + return nextToken +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flags.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flags.go new file mode 100644 index 0000000..3bfbbde --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/flags.go @@ -0,0 +1,26 @@ +package lexer + +// FlagExpiration indicates that `expiration` is supported as a first-class +// feature in the schema. +const FlagExpiration = "expiration" + +type transformer func(lexeme Lexeme) (Lexeme, bool) + +// Flags is a map of flag names to their corresponding transformers. +var Flags = map[string]transformer{ + FlagExpiration: func(lexeme Lexeme) (Lexeme, bool) { + // `expiration` becomes a keyword. + if lexeme.Kind == TokenTypeIdentifier && lexeme.Value == "expiration" { + lexeme.Kind = TokenTypeKeyword + return lexeme, true + } + + // `and` becomes a keyword. + if lexeme.Kind == TokenTypeIdentifier && lexeme.Value == "and" { + lexeme.Kind = TokenTypeKeyword + return lexeme, true + } + + return lexeme, false + }, +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex.go new file mode 100644 index 0000000..a09df50 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex.go @@ -0,0 +1,231 @@ +// Based on design first introduced in: http://blog.golang.org/two-go-talks-lexical-scanning-in-go-and +// Portions copied and modified from: https://github.com/golang/go/blob/master/src/text/template/parse/lex.go + +package lexer + +import ( + "fmt" + "strings" + "sync" + "unicode/utf8" + + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +const EOFRUNE = -1 + +// createLexer creates a new scanner for the input string. +func createLexer(source input.Source, input string) *Lexer { + l := &Lexer{ + source: source, + input: input, + tokens: make(chan Lexeme), + closed: make(chan struct{}), + } + go l.run() + return l +} + +// run runs the state machine for the lexer. +func (l *Lexer) run() { + defer func() { + close(l.tokens) + }() + l.withLock(func() { + l.state = lexSource + }) + var state stateFn + for { + l.withRLock(func() { + state = l.state + }) + if state == nil { + break + } + next := state(l) + l.withLock(func() { + l.state = next + }) + } +} + +// Close stops the lexer from running. +func (l *Lexer) Close() { + close(l.closed) + l.withLock(func() { + l.state = nil + }) +} + +// withLock runs f protected by l's lock +func (l *Lexer) withLock(f func()) { + l.Lock() + defer l.Unlock() + f() +} + +// withRLock runs f protected by l's read lock +func (l *Lexer) withRLock(f func()) { + l.RLock() + defer l.RUnlock() + f() +} + +// Lexeme represents a token returned from scanning the contents of a file. +type Lexeme struct { + Kind TokenType // The type of this lexeme. + Position input.BytePosition // The starting position of this token in the input string. + Value string // The textual value of this token. + Error string // The error associated with the lexeme, if any. +} + +// stateFn represents the state of the scanner as a function that returns the next state. +type stateFn func(*Lexer) stateFn + +// Lexer holds the state of the scanner. +type Lexer struct { + sync.RWMutex + state stateFn // the next lexing function to enter. GUARDED_BY(RWMutex) + + source input.Source // the name of the input; used only for error reports + input string // the string being scanned + pos input.BytePosition // current position in the input + start input.BytePosition // start position of this token + width input.BytePosition // width of last rune read from input + lastPos input.BytePosition // position of most recent token returned by nextToken + tokens chan Lexeme // channel of scanned lexemes + currentToken Lexeme // The current token if any + lastNonIgnoredToken Lexeme // The last token returned that is non-whitespace and non-comment + closed chan struct{} // Holds the closed channel +} + +// nextToken returns the next token from the input. +func (l *Lexer) nextToken() Lexeme { + token := <-l.tokens + l.lastPos = token.Position + return token +} + +// next returns the next rune in the input. +func (l *Lexer) next() rune { + if int(l.pos) >= len(l.input) { + l.width = 0 + return EOFRUNE + } + r, w := utf8.DecodeRuneInString(l.input[l.pos:]) + l.width = input.BytePosition(w) + l.pos += l.width + return r +} + +// peek returns but does not consume the next rune in the input. +func (l *Lexer) peek() rune { + r := l.next() + l.backup() + return r +} + +// backup steps back one rune. Can only be called once per call of next. +func (l *Lexer) backup() { + l.pos -= l.width +} + +// value returns the current value of the token in the lexer. +func (l *Lexer) value() string { + return l.input[l.start:l.pos] +} + +// emit passes an token back to the client. +func (l *Lexer) emit(t TokenType) { + currentToken := Lexeme{t, l.start, l.value(), ""} + + if t != TokenTypeWhitespace && t != TokenTypeMultilineComment && t != TokenTypeSinglelineComment { + l.lastNonIgnoredToken = currentToken + } + + select { + case l.tokens <- currentToken: + l.currentToken = currentToken + l.start = l.pos + + case <-l.closed: + return + } +} + +// errorf returns an error token and terminates the scan by passing +// back a nil pointer that will be the next state, terminating l.nexttoken. +func (l *Lexer) errorf(currentRune rune, format string, args ...interface{}) stateFn { + l.tokens <- Lexeme{TokenTypeError, l.start, string(currentRune), fmt.Sprintf(format, args...)} + return nil +} + +// peekValue looks forward for the given value string. If found, returns true. +func (l *Lexer) peekValue(value string) bool { + for index, runeValue := range value { + r := l.next() + if r != runeValue { + for j := 0; j <= index; j++ { + l.backup() + } + return false + } + } + + for i := 0; i < len(value); i++ { + l.backup() + } + + return true +} + +// accept consumes the next rune if it's from the valid set. +func (l *Lexer) accept(valid string) bool { + if nextRune := l.next(); strings.ContainsRune(valid, nextRune) { + return true + } + l.backup() + return false +} + +// acceptString consumes the full given string, if the next tokens in the stream. +func (l *Lexer) acceptString(value string) bool { + for index, runeValue := range value { + if l.next() != runeValue { + for i := 0; i <= index; i++ { + l.backup() + } + + return false + } + } + + return true +} + +// lexSource scans until EOFRUNE +func lexSource(l *Lexer) stateFn { + return lexerEntrypoint(l) +} + +// checkFn returns whether a rune matches for continue looping. +type checkFn func(r rune) (bool, error) + +func buildLexUntil(findType TokenType, checker checkFn) stateFn { + return func(l *Lexer) stateFn { + for { + r := l.next() + isValid, err := checker(r) + if err != nil { + return l.errorf(r, "%v", err) + } + if !isValid { + l.backup() + break + } + } + + l.emit(findType) + return lexSource + } +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex_def.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex_def.go new file mode 100644 index 0000000..d8cbe8c --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/lex_def.go @@ -0,0 +1,351 @@ +//go:generate go run golang.org/x/tools/cmd/stringer -type=TokenType + +package lexer + +import ( + "unicode" + + "github.com/authzed/spicedb/pkg/schemadsl/input" +) + +// Lex creates a new scanner for the input string. +func Lex(source input.Source, input string) *Lexer { + return createLexer(source, input) +} + +// TokenType identifies the type of lexer lexemes. +type TokenType int + +const ( + TokenTypeError TokenType = iota // error occurred; value is text of error + + // Synthetic semicolon + TokenTypeSyntheticSemicolon + + TokenTypeEOF + TokenTypeWhitespace + TokenTypeSinglelineComment + TokenTypeMultilineComment + TokenTypeNewline + + TokenTypeKeyword // interface + TokenTypeIdentifier // helloworld + TokenTypeNumber // 123 + + TokenTypeLeftBrace // { + TokenTypeRightBrace // } + TokenTypeLeftParen // ( + TokenTypeRightParen // ) + + TokenTypePipe // | + TokenTypePlus // + + TokenTypeMinus // - + TokenTypeAnd // & + TokenTypeDiv // / + + TokenTypeEquals // = + TokenTypeColon // : + TokenTypeSemicolon // ; + TokenTypeRightArrow // -> + TokenTypeHash // # + TokenTypeEllipsis // ... + TokenTypeStar // * + + // Additional tokens for CEL: https://github.com/google/cel-spec/blob/master/doc/langdef.md#syntax + TokenTypeQuestionMark // ? + TokenTypeConditionalOr // || + TokenTypeConditionalAnd // && + TokenTypeExclamationPoint // ! + TokenTypeLeftBracket // [ + TokenTypeRightBracket // ] + TokenTypePeriod // . + TokenTypeComma // , + TokenTypePercent // % + TokenTypeLessThan // < + TokenTypeGreaterThan // > + TokenTypeLessThanOrEqual // <= + TokenTypeGreaterThanOrEqual // >= + TokenTypeEqualEqual // == + TokenTypeNotEqual // != + TokenTypeString // "...", '...', """...""", '''...''' +) + +// keywords contains the full set of keywords supported. +var keywords = map[string]struct{}{ + "definition": {}, + "caveat": {}, + "relation": {}, + "permission": {}, + "nil": {}, + "with": {}, +} + +// IsKeyword returns whether the specified input string is a reserved keyword. +func IsKeyword(candidate string) bool { + _, ok := keywords[candidate] + return ok +} + +// syntheticPredecessors contains the full set of token types after which, if a newline is found, +// we emit a synthetic semicolon rather than a normal newline token. +var syntheticPredecessors = map[TokenType]bool{ + TokenTypeIdentifier: true, + TokenTypeKeyword: true, + + TokenTypeRightBrace: true, + TokenTypeRightParen: true, + + TokenTypeStar: true, +} + +// lexerEntrypoint scans until EOFRUNE +func lexerEntrypoint(l *Lexer) stateFn { +Loop: + for { + switch r := l.next(); { + case r == EOFRUNE: + break Loop + + case r == '{': + l.emit(TokenTypeLeftBrace) + + case r == '}': + l.emit(TokenTypeRightBrace) + + case r == '(': + l.emit(TokenTypeLeftParen) + + case r == ')': + l.emit(TokenTypeRightParen) + + case r == '+': + l.emit(TokenTypePlus) + + case r == '|': + if l.acceptString("|") { + l.emit(TokenTypeConditionalOr) + } else { + l.emit(TokenTypePipe) + } + + case r == '&': + if l.acceptString("&") { + l.emit(TokenTypeConditionalAnd) + } else { + l.emit(TokenTypeAnd) + } + + case r == '?': + l.emit(TokenTypeQuestionMark) + + case r == '!': + if l.acceptString("=") { + l.emit(TokenTypeNotEqual) + } else { + l.emit(TokenTypeExclamationPoint) + } + + case r == '[': + l.emit(TokenTypeLeftBracket) + + case r == ']': + l.emit(TokenTypeRightBracket) + + case r == '%': + l.emit(TokenTypePercent) + + case r == '<': + if l.acceptString("=") { + l.emit(TokenTypeLessThanOrEqual) + } else { + l.emit(TokenTypeLessThan) + } + + case r == '>': + if l.acceptString("=") { + l.emit(TokenTypeGreaterThanOrEqual) + } else { + l.emit(TokenTypeGreaterThan) + } + + case r == ',': + l.emit(TokenTypeComma) + + case r == '=': + if l.acceptString("=") { + l.emit(TokenTypeEqualEqual) + } else { + l.emit(TokenTypeEquals) + } + + case r == ':': + l.emit(TokenTypeColon) + + case r == ';': + l.emit(TokenTypeSemicolon) + + case r == '#': + l.emit(TokenTypeHash) + + case r == '*': + l.emit(TokenTypeStar) + + case r == '.': + if l.acceptString("..") { + l.emit(TokenTypeEllipsis) + } else { + l.emit(TokenTypePeriod) + } + + case r == '-': + if l.accept(">") { + l.emit(TokenTypeRightArrow) + } else { + l.emit(TokenTypeMinus) + } + + case isSpace(r): + l.emit(TokenTypeWhitespace) + + case isNewline(r): + // If the previous token matches the synthetic semicolon list, + // we emit a synthetic semicolon instead of a simple newline. + if _, ok := syntheticPredecessors[l.lastNonIgnoredToken.Kind]; ok { + l.emit(TokenTypeSyntheticSemicolon) + } else { + l.emit(TokenTypeNewline) + } + + case isAlphaNumeric(r): + l.backup() + return lexIdentifierOrKeyword + + case r == '\'' || r == '"': + l.backup() + return lexStringLiteral + + case r == '/': + // Check for comments. + if l.peekValue("/") { + l.backup() + return lexSinglelineComment + } + + if l.peekValue("*") { + l.backup() + return lexMultilineComment + } + + l.emit(TokenTypeDiv) + default: + return l.errorf(r, "unrecognized character at this location: %#U", r) + } + } + + l.emit(TokenTypeEOF) + return nil +} + +// lexStringLiteral scan until the close of the string literal or EOFRUNE +func lexStringLiteral(l *Lexer) stateFn { + allowNewlines := false + terminator := "" + + if l.acceptString(`"""`) { + terminator = `"""` + allowNewlines = true + } else if l.acceptString(`'''`) { + terminator = `"""` + allowNewlines = true + } else if l.acceptString(`"`) { + terminator = `"` + } else if l.acceptString(`'`) { + terminator = `'` + } + + for { + if l.peekValue(terminator) { + l.acceptString(terminator) + l.emit(TokenTypeString) + return lexSource + } + + // Otherwise, consume until we hit EOFRUNE. + r := l.next() + if !allowNewlines && isNewline(r) { + return l.errorf(r, "Unterminated string") + } + + if r == EOFRUNE { + return l.errorf(r, "Unterminated string") + } + } +} + +// lexSinglelineComment scans until newline or EOFRUNE +func lexSinglelineComment(l *Lexer) stateFn { + checker := func(r rune) (bool, error) { + result := r == EOFRUNE || isNewline(r) + return !result, nil + } + + l.acceptString("//") + return buildLexUntil(TokenTypeSinglelineComment, checker) +} + +// lexMultilineComment scans until the close of the multiline comment or EOFRUNE +func lexMultilineComment(l *Lexer) stateFn { + l.acceptString("/*") + for { + // Check for the end of the multiline comment. + if l.peekValue("*/") { + l.acceptString("*/") + l.emit(TokenTypeMultilineComment) + return lexSource + } + + // Otherwise, consume until we hit EOFRUNE. + r := l.next() + if r == EOFRUNE { + return l.errorf(r, "Unterminated multiline comment") + } + } +} + +// lexIdentifierOrKeyword searches for a keyword or literal identifier. +func lexIdentifierOrKeyword(l *Lexer) stateFn { + for { + if !isAlphaNumeric(l.peek()) { + break + } + + l.next() + } + + _, isKeyword := keywords[l.value()] + + switch { + case isKeyword: + l.emit(TokenTypeKeyword) + + default: + l.emit(TokenTypeIdentifier) + } + + return lexSource +} + +// isSpace reports whether r is a space character. +func isSpace(r rune) bool { + return r == ' ' || r == '\t' +} + +// isNewline reports whether r is a newline character. +func isNewline(r rune) bool { + return r == '\r' || r == '\n' +} + +// isAlphaNumeric reports whether r is an alphabetic, digit, or underscore. +func isAlphaNumeric(r rune) bool { + return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/tokentype_string.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/tokentype_string.go new file mode 100644 index 0000000..79f3585 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/lexer/tokentype_string.go @@ -0,0 +1,64 @@ +// Code generated by "stringer -type=TokenType"; DO NOT EDIT. + +package lexer + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[TokenTypeError-0] + _ = x[TokenTypeSyntheticSemicolon-1] + _ = x[TokenTypeEOF-2] + _ = x[TokenTypeWhitespace-3] + _ = x[TokenTypeSinglelineComment-4] + _ = x[TokenTypeMultilineComment-5] + _ = x[TokenTypeNewline-6] + _ = x[TokenTypeKeyword-7] + _ = x[TokenTypeIdentifier-8] + _ = x[TokenTypeNumber-9] + _ = x[TokenTypeLeftBrace-10] + _ = x[TokenTypeRightBrace-11] + _ = x[TokenTypeLeftParen-12] + _ = x[TokenTypeRightParen-13] + _ = x[TokenTypePipe-14] + _ = x[TokenTypePlus-15] + _ = x[TokenTypeMinus-16] + _ = x[TokenTypeAnd-17] + _ = x[TokenTypeDiv-18] + _ = x[TokenTypeEquals-19] + _ = x[TokenTypeColon-20] + _ = x[TokenTypeSemicolon-21] + _ = x[TokenTypeRightArrow-22] + _ = x[TokenTypeHash-23] + _ = x[TokenTypeEllipsis-24] + _ = x[TokenTypeStar-25] + _ = x[TokenTypeQuestionMark-26] + _ = x[TokenTypeConditionalOr-27] + _ = x[TokenTypeConditionalAnd-28] + _ = x[TokenTypeExclamationPoint-29] + _ = x[TokenTypeLeftBracket-30] + _ = x[TokenTypeRightBracket-31] + _ = x[TokenTypePeriod-32] + _ = x[TokenTypeComma-33] + _ = x[TokenTypePercent-34] + _ = x[TokenTypeLessThan-35] + _ = x[TokenTypeGreaterThan-36] + _ = x[TokenTypeLessThanOrEqual-37] + _ = x[TokenTypeGreaterThanOrEqual-38] + _ = x[TokenTypeEqualEqual-39] + _ = x[TokenTypeNotEqual-40] + _ = x[TokenTypeString-41] +} + +const _TokenType_name = "TokenTypeErrorTokenTypeSyntheticSemicolonTokenTypeEOFTokenTypeWhitespaceTokenTypeSinglelineCommentTokenTypeMultilineCommentTokenTypeNewlineTokenTypeKeywordTokenTypeIdentifierTokenTypeNumberTokenTypeLeftBraceTokenTypeRightBraceTokenTypeLeftParenTokenTypeRightParenTokenTypePipeTokenTypePlusTokenTypeMinusTokenTypeAndTokenTypeDivTokenTypeEqualsTokenTypeColonTokenTypeSemicolonTokenTypeRightArrowTokenTypeHashTokenTypeEllipsisTokenTypeStarTokenTypeQuestionMarkTokenTypeConditionalOrTokenTypeConditionalAndTokenTypeExclamationPointTokenTypeLeftBracketTokenTypeRightBracketTokenTypePeriodTokenTypeCommaTokenTypePercentTokenTypeLessThanTokenTypeGreaterThanTokenTypeLessThanOrEqualTokenTypeGreaterThanOrEqualTokenTypeEqualEqualTokenTypeNotEqualTokenTypeString" + +var _TokenType_index = [...]uint16{0, 14, 41, 53, 72, 98, 123, 139, 155, 174, 189, 207, 226, 244, 263, 276, 289, 303, 315, 327, 342, 356, 374, 393, 406, 423, 436, 457, 479, 502, 527, 547, 568, 583, 597, 613, 630, 650, 674, 701, 720, 737, 752} + +func (i TokenType) String() string { + if i < 0 || i >= TokenType(len(_TokenType_index)-1) { + return "TokenType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _TokenType_name[_TokenType_index[i]:_TokenType_index[i+1]] +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/nodestack.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/nodestack.go new file mode 100644 index 0000000..0d49e09 --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/nodestack.go @@ -0,0 +1,35 @@ +package parser + +type nodeStack struct { + top *element + size int +} + +type element struct { + value AstNode + next *element +} + +func (s *nodeStack) topValue() AstNode { + if s.size == 0 { + return nil + } + + return s.top.value +} + +// Push pushes a node onto the stack. +func (s *nodeStack) push(value AstNode) { + s.top = &element{value, s.top} + s.size++ +} + +// Pop removes the node from the stack and returns it. +func (s *nodeStack) pop() (value AstNode) { + if s.size > 0 { + value, s.top = s.top.value, s.top.next + s.size-- + return + } + return nil +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser.go new file mode 100644 index 0000000..100ad9b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser.go @@ -0,0 +1,658 @@ +// parser package defines the parser for the Authzed Schema DSL. +package parser + +import ( + "strings" + + "golang.org/x/exp/maps" + + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/schemadsl/lexer" +) + +// Parse parses the given Schema DSL source into a parse tree. +func Parse(builder NodeBuilder, source input.Source, input string) AstNode { + lx := lexer.Lex(source, input) + parser := buildParser(lx, builder, source, input) + defer parser.close() + return parser.consumeTopLevel() +} + +// ignoredTokenTypes are those tokens ignored when parsing. +var ignoredTokenTypes = map[lexer.TokenType]bool{ + lexer.TokenTypeWhitespace: true, + lexer.TokenTypeNewline: true, + lexer.TokenTypeSinglelineComment: true, + lexer.TokenTypeMultilineComment: true, +} + +// consumeTopLevel attempts to consume the top-level definitions. +func (p *sourceParser) consumeTopLevel() AstNode { + rootNode := p.startNode(dslshape.NodeTypeFile) + defer p.mustFinishNode() + + // Start at the first token. + p.consumeToken() + + if p.currentToken.Kind == lexer.TokenTypeError { + p.emitErrorf("%s", p.currentToken.Value) + return rootNode + } + + hasSeenDefinition := false + +Loop: + for { + if p.isToken(lexer.TokenTypeEOF) { + break Loop + } + + // Consume a statement terminator if one was found. + p.tryConsumeStatementTerminator() + + if p.isToken(lexer.TokenTypeEOF) { + break Loop + } + + // The top level of the DSL is a set of definitions and caveats: + // definition foobar { ... } + // caveat somecaveat (...) { ... } + + switch { + case p.isIdentifier("use"): + rootNode.Connect(dslshape.NodePredicateChild, p.consumeUseFlag(hasSeenDefinition)) + + case p.isKeyword("definition"): + hasSeenDefinition = true + rootNode.Connect(dslshape.NodePredicateChild, p.consumeDefinition()) + + case p.isKeyword("caveat"): + hasSeenDefinition = true + rootNode.Connect(dslshape.NodePredicateChild, p.consumeCaveat()) + + default: + p.emitErrorf("Unexpected token at root level: %v", p.currentToken.Kind) + break Loop + } + } + + return rootNode +} + +// consumeCaveat attempts to consume a single caveat definition. +// ```caveat somecaveat(param1 type, param2 type) { ... }``` +func (p *sourceParser) consumeCaveat() AstNode { + defNode := p.startNode(dslshape.NodeTypeCaveatDefinition) + defer p.mustFinishNode() + + // caveat ... + p.consumeKeyword("caveat") + caveatName, ok := p.consumeTypePath() + if !ok { + return defNode + } + + defNode.MustDecorate(dslshape.NodeCaveatDefinitionPredicateName, caveatName) + + // Parameters: + // ( + _, ok = p.consume(lexer.TokenTypeLeftParen) + if !ok { + return defNode + } + + for { + paramNode, ok := p.consumeCaveatParameter() + if !ok { + return defNode + } + + defNode.Connect(dslshape.NodeCaveatDefinitionPredicateParameters, paramNode) + if _, ok := p.tryConsume(lexer.TokenTypeComma); !ok { + break + } + } + + // ) + _, ok = p.consume(lexer.TokenTypeRightParen) + if !ok { + return defNode + } + + // { + _, ok = p.consume(lexer.TokenTypeLeftBrace) + if !ok { + return defNode + } + + exprNode, ok := p.consumeCaveatExpression() + if !ok { + return defNode + } + + defNode.Connect(dslshape.NodeCaveatDefinitionPredicateExpession, exprNode) + + // } + _, ok = p.consume(lexer.TokenTypeRightBrace) + if !ok { + return defNode + } + + return defNode +} + +func (p *sourceParser) consumeCaveatExpression() (AstNode, bool) { + exprNode := p.startNode(dslshape.NodeTypeCaveatExpression) + defer p.mustFinishNode() + + // Special Logic Note: Since CEL is its own language, we consume here until we have a matching + // close brace, and then pass ALL the found tokens to CEL's own parser to attach the expression + // here. + braceDepth := 1 // Starting at 1 from the open brace above + var startToken *commentedLexeme + var endToken *commentedLexeme +consumer: + for { + currentToken := p.currentToken + + switch currentToken.Kind { + case lexer.TokenTypeLeftBrace: + braceDepth++ + + case lexer.TokenTypeRightBrace: + if braceDepth == 1 { + break consumer + } + + braceDepth-- + + case lexer.TokenTypeError: + break consumer + + case lexer.TokenTypeEOF: + break consumer + } + + if startToken == nil { + startToken = ¤tToken + } + + endToken = ¤tToken + p.consumeToken() + } + + if startToken == nil { + p.emitErrorf("missing caveat expression") + return exprNode, false + } + + caveatExpression := p.input[startToken.Position : int(endToken.Position)+len(endToken.Value)] + exprNode.MustDecorate(dslshape.NodeCaveatExpressionPredicateExpression, caveatExpression) + return exprNode, true +} + +// consumeCaveatParameter attempts to consume a caveat parameter. +// ```(paramName paramtype)``` +func (p *sourceParser) consumeCaveatParameter() (AstNode, bool) { + paramNode := p.startNode(dslshape.NodeTypeCaveatParameter) + defer p.mustFinishNode() + + name, ok := p.consumeIdentifier() + if !ok { + return paramNode, false + } + + paramNode.MustDecorate(dslshape.NodeCaveatParameterPredicateName, name) + paramNode.Connect(dslshape.NodeCaveatParameterPredicateType, p.consumeCaveatTypeReference()) + return paramNode, true +} + +// consumeCaveatTypeReference attempts to consume a caveat type reference. +// ```typeName<childType>``` +func (p *sourceParser) consumeCaveatTypeReference() AstNode { + typeRefNode := p.startNode(dslshape.NodeTypeCaveatTypeReference) + defer p.mustFinishNode() + + name, ok := p.consumeIdentifier() + if !ok { + return typeRefNode + } + + typeRefNode.MustDecorate(dslshape.NodeCaveatTypeReferencePredicateType, name) + + // Check for child type(s). + // < + if _, ok := p.tryConsume(lexer.TokenTypeLessThan); !ok { + return typeRefNode + } + + for { + childTypeRef := p.consumeCaveatTypeReference() + typeRefNode.Connect(dslshape.NodeCaveatTypeReferencePredicateChildTypes, childTypeRef) + if _, ok := p.tryConsume(lexer.TokenTypeComma); !ok { + break + } + } + + // > + p.consume(lexer.TokenTypeGreaterThan) + return typeRefNode +} + +// consumeUseFlag attempts to consume a use flag. +// ``` use flagname ``` +func (p *sourceParser) consumeUseFlag(afterDefinition bool) AstNode { + useNode := p.startNode(dslshape.NodeTypeUseFlag) + defer p.mustFinishNode() + + // consume the `use` + p.consumeIdentifier() + + var useFlag string + if p.isToken(lexer.TokenTypeIdentifier) { + useFlag, _ = p.consumeIdentifier() + } else { + useName, ok := p.consumeVariableKeyword() + if !ok { + return useNode + } + useFlag = useName + } + + if _, ok := lexer.Flags[useFlag]; !ok { + p.emitErrorf("Unknown use flag: `%s`. Options are: %s", useFlag, strings.Join(maps.Keys(lexer.Flags), ", ")) + return useNode + } + + useNode.MustDecorate(dslshape.NodeUseFlagPredicateName, useFlag) + + // NOTE: we conduct this check in `consumeFlag` rather than at + // the callsite to keep the callsite clean. + // We also do the check after consumption to ensure that the parser continues + // moving past the use expression. + if afterDefinition { + p.emitErrorf("`use` expressions must be declared before any definition") + return useNode + } + + return useNode +} + +// consumeDefinition attempts to consume a single schema definition. +// ```definition somedef { ... }``` +func (p *sourceParser) consumeDefinition() AstNode { + defNode := p.startNode(dslshape.NodeTypeDefinition) + defer p.mustFinishNode() + + // definition ... + p.consumeKeyword("definition") + definitionName, ok := p.consumeTypePath() + if !ok { + return defNode + } + + defNode.MustDecorate(dslshape.NodeDefinitionPredicateName, definitionName) + + // { + _, ok = p.consume(lexer.TokenTypeLeftBrace) + if !ok { + return defNode + } + + // Relations and permissions. + for { + // } + if _, ok := p.tryConsume(lexer.TokenTypeRightBrace); ok { + break + } + + // relation ... + // permission ... + switch { + case p.isKeyword("relation"): + defNode.Connect(dslshape.NodePredicateChild, p.consumeRelation()) + + case p.isKeyword("permission"): + defNode.Connect(dslshape.NodePredicateChild, p.consumePermission()) + } + + ok := p.consumeStatementTerminator() + if !ok { + break + } + } + + return defNode +} + +// consumeRelation consumes a relation. +// ```relation foo: sometype``` +func (p *sourceParser) consumeRelation() AstNode { + relNode := p.startNode(dslshape.NodeTypeRelation) + defer p.mustFinishNode() + + // relation ... + p.consumeKeyword("relation") + relationName, ok := p.consumeIdentifier() + if !ok { + return relNode + } + + relNode.MustDecorate(dslshape.NodePredicateName, relationName) + + // : + _, ok = p.consume(lexer.TokenTypeColon) + if !ok { + return relNode + } + + // Relation allowed type(s). + relNode.Connect(dslshape.NodeRelationPredicateAllowedTypes, p.consumeTypeReference()) + + return relNode +} + +// consumeTypeReference consumes a reference to a type or types of relations. +// ```sometype | anothertype | anothertype:* ``` +func (p *sourceParser) consumeTypeReference() AstNode { + refNode := p.startNode(dslshape.NodeTypeTypeReference) + defer p.mustFinishNode() + + for { + refNode.Connect(dslshape.NodeTypeReferencePredicateType, p.consumeSpecificTypeWithCaveat()) + if _, ok := p.tryConsume(lexer.TokenTypePipe); !ok { + break + } + } + + return refNode +} + +// tryConsumeWithCaveat tries to consume a caveat `with` expression. +func (p *sourceParser) tryConsumeWithCaveat() (AstNode, bool) { + caveatNode := p.startNode(dslshape.NodeTypeCaveatReference) + defer p.mustFinishNode() + + consumed, ok := p.consumeTypePath() + if !ok { + return caveatNode, true + } + + caveatNode.MustDecorate(dslshape.NodeCaveatPredicateCaveat, consumed) + return caveatNode, true +} + +// consumeSpecificTypeWithCaveat consumes an identifier as a specific type reference, with optional caveat. +func (p *sourceParser) consumeSpecificTypeWithCaveat() AstNode { + specificNode := p.consumeSpecificTypeWithoutFinish() + defer p.mustFinishNode() + + // Check for a caveat and/or supported trait. + if !p.isKeyword("with") { + return specificNode + } + + p.consumeKeyword("with") + + if !p.isKeyword("expiration") { + caveatNode, ok := p.tryConsumeWithCaveat() + if ok { + specificNode.Connect(dslshape.NodeSpecificReferencePredicateCaveat, caveatNode) + } + + if !p.tryConsumeKeyword("and") { + return specificNode + } + } + + if p.isKeyword("expiration") { + // Check for expiration trait. + traitNode := p.consumeExpirationTrait() + + // Decorate with the expiration trait. + specificNode.Connect(dslshape.NodeSpecificReferencePredicateTrait, traitNode) + } + + return specificNode +} + +// consumeExpirationTrait consumes an expiration trait. +func (p *sourceParser) consumeExpirationTrait() AstNode { + expirationTraitNode := p.startNode(dslshape.NodeTypeTraitReference) + p.consumeKeyword("expiration") + + expirationTraitNode.MustDecorate(dslshape.NodeTraitPredicateTrait, "expiration") + defer p.mustFinishNode() + + return expirationTraitNode +} + +// consumeSpecificTypeOpen consumes an identifier as a specific type reference. +func (p *sourceParser) consumeSpecificTypeWithoutFinish() AstNode { + specificNode := p.startNode(dslshape.NodeTypeSpecificTypeReference) + + typeName, ok := p.consumeTypePath() + if !ok { + return specificNode + } + + specificNode.MustDecorate(dslshape.NodeSpecificReferencePredicateType, typeName) + + // Check for a wildcard + if _, ok := p.tryConsume(lexer.TokenTypeColon); ok { + _, ok := p.consume(lexer.TokenTypeStar) + if !ok { + return specificNode + } + + specificNode.MustDecorate(dslshape.NodeSpecificReferencePredicateWildcard, "true") + return specificNode + } + + // Check for a relation specified. + if _, ok := p.tryConsume(lexer.TokenTypeHash); !ok { + return specificNode + } + + // Consume an identifier or an ellipsis. + consumed, ok := p.consume(lexer.TokenTypeIdentifier, lexer.TokenTypeEllipsis) + if !ok { + return specificNode + } + + specificNode.MustDecorate(dslshape.NodeSpecificReferencePredicateRelation, consumed.Value) + return specificNode +} + +func (p *sourceParser) consumeTypePath() (string, bool) { + var segments []string + + for { + segment, ok := p.consumeIdentifier() + if !ok { + return "", false + } + + segments = append(segments, segment) + + _, ok = p.tryConsume(lexer.TokenTypeDiv) + if !ok { + break + } + } + + return strings.Join(segments, "/"), true +} + +// consumePermission consumes a permission. +// ```permission foo = bar + baz``` +func (p *sourceParser) consumePermission() AstNode { + permNode := p.startNode(dslshape.NodeTypePermission) + defer p.mustFinishNode() + + // permission ... + p.consumeKeyword("permission") + permissionName, ok := p.consumeIdentifier() + if !ok { + return permNode + } + + permNode.MustDecorate(dslshape.NodePredicateName, permissionName) + + // = + _, ok = p.consume(lexer.TokenTypeEquals) + if !ok { + return permNode + } + + permNode.Connect(dslshape.NodePermissionPredicateComputeExpression, p.consumeComputeExpression()) + return permNode +} + +// ComputeExpressionOperators defines the binary operators in precedence order. +var ComputeExpressionOperators = []binaryOpDefinition{ + {lexer.TokenTypeMinus, dslshape.NodeTypeExclusionExpression}, + {lexer.TokenTypeAnd, dslshape.NodeTypeIntersectExpression}, + {lexer.TokenTypePlus, dslshape.NodeTypeUnionExpression}, +} + +// consumeComputeExpression consumes an expression for computing a permission. +func (p *sourceParser) consumeComputeExpression() AstNode { + // Compute expressions consist of a set of binary operators, so build a tree with proper + // precedence. + binaryParser := p.buildBinaryOperatorExpressionFnTree(ComputeExpressionOperators) + found, ok := binaryParser() + if !ok { + return p.createErrorNodef("Expected compute expression for permission") + } + return found +} + +// tryConsumeComputeExpression attempts to consume a nested compute expression. +func (p *sourceParser) tryConsumeComputeExpression(subTryExprFn tryParserFn, binaryTokenType lexer.TokenType, nodeType dslshape.NodeType) (AstNode, bool) { + rightNodeBuilder := func(leftNode AstNode, operatorToken lexer.Lexeme) (AstNode, bool) { + rightNode, ok := subTryExprFn() + if !ok { + return nil, false + } + + // Create the expression node representing the binary expression. + exprNode := p.createNode(nodeType) + exprNode.Connect(dslshape.NodeExpressionPredicateLeftExpr, leftNode) + exprNode.Connect(dslshape.NodeExpressionPredicateRightExpr, rightNode) + return exprNode, true + } + return p.performLeftRecursiveParsing(subTryExprFn, rightNodeBuilder, nil, binaryTokenType) +} + +// tryConsumeArrowExpression attempts to consume an arrow expression. +// ```foo->bar->baz->meh``` +func (p *sourceParser) tryConsumeArrowExpression() (AstNode, bool) { + rightNodeBuilder := func(leftNode AstNode, operatorToken lexer.Lexeme) (AstNode, bool) { + // Check for an arrow function. + if operatorToken.Kind == lexer.TokenTypePeriod { + functionName, ok := p.consumeIdentifier() + if !ok { + return nil, false + } + + // TODO(jschorr): Change to keywords in schema v2. + if functionName != "any" && functionName != "all" { + p.emitErrorf("Expected 'any' or 'all' for arrow function, found: %s", functionName) + return nil, false + } + + if _, ok := p.consume(lexer.TokenTypeLeftParen); !ok { + return nil, false + } + + rightNode, ok := p.tryConsumeIdentifierLiteral() + if !ok { + return nil, false + } + + if _, ok := p.consume(lexer.TokenTypeRightParen); !ok { + return nil, false + } + + exprNode := p.createNode(dslshape.NodeTypeArrowExpression) + exprNode.Connect(dslshape.NodeExpressionPredicateLeftExpr, leftNode) + exprNode.Connect(dslshape.NodeExpressionPredicateRightExpr, rightNode) + exprNode.MustDecorate(dslshape.NodeArrowExpressionFunctionName, functionName) + return exprNode, true + } + + rightNode, ok := p.tryConsumeIdentifierLiteral() + if !ok { + return nil, false + } + + // Create the expression node representing the binary expression. + exprNode := p.createNode(dslshape.NodeTypeArrowExpression) + exprNode.Connect(dslshape.NodeExpressionPredicateLeftExpr, leftNode) + exprNode.Connect(dslshape.NodeExpressionPredicateRightExpr, rightNode) + return exprNode, true + } + return p.performLeftRecursiveParsing(p.tryConsumeIdentifierLiteral, rightNodeBuilder, nil, lexer.TokenTypeRightArrow, lexer.TokenTypePeriod) +} + +// tryConsumeBaseExpression attempts to consume base compute expressions (identifiers, parenthesis). +// ```(foo + bar)``` +// ```(foo)``` +// ```foo``` +// ```nil``` +func (p *sourceParser) tryConsumeBaseExpression() (AstNode, bool) { + switch { + // Nested expression. + case p.isToken(lexer.TokenTypeLeftParen): + comments := p.currentToken.comments + + p.consume(lexer.TokenTypeLeftParen) + exprNode := p.consumeComputeExpression() + p.consume(lexer.TokenTypeRightParen) + + // Attach any comments found to the consumed expression. + p.decorateComments(exprNode, comments) + + return exprNode, true + + // Nil expression. + case p.isKeyword("nil"): + return p.tryConsumeNilExpression() + + // Identifier. + case p.isToken(lexer.TokenTypeIdentifier): + return p.tryConsumeIdentifierLiteral() + } + + return nil, false +} + +// tryConsumeIdentifierLiteral attempts to consume an identifier as a literal +// expression. +// +// ```foo``` +func (p *sourceParser) tryConsumeIdentifierLiteral() (AstNode, bool) { + if !p.isToken(lexer.TokenTypeIdentifier) { + return nil, false + } + + identNode := p.startNode(dslshape.NodeTypeIdentifier) + defer p.mustFinishNode() + + identifier, _ := p.consumeIdentifier() + identNode.MustDecorate(dslshape.NodeIdentiferPredicateValue, identifier) + return identNode, true +} + +func (p *sourceParser) tryConsumeNilExpression() (AstNode, bool) { + if !p.isKeyword("nil") { + return nil, false + } + + node := p.startNode(dslshape.NodeTypeNilExpression) + p.consumeKeyword("nil") + defer p.mustFinishNode() + return node, true +} diff --git a/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser_impl.go b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser_impl.go new file mode 100644 index 0000000..1009b6b --- /dev/null +++ b/vendor/github.com/authzed/spicedb/pkg/schemadsl/parser/parser_impl.go @@ -0,0 +1,367 @@ +package parser + +import ( + "fmt" + + "github.com/authzed/spicedb/pkg/schemadsl/dslshape" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/schemadsl/lexer" +) + +// AstNode defines an interface for working with nodes created by this parser. +type AstNode interface { + // Connect connects this AstNode to another AstNode with the given predicate. + Connect(predicate string, other AstNode) + + // MustDecorate decorates this AstNode with the given property and string value, returning + // the same node. + MustDecorate(property string, value string) AstNode + + // MustDecorateWithInt decorates this AstNode with the given property and int value, returning + // the same node. + MustDecorateWithInt(property string, value int) AstNode +} + +// NodeBuilder is a function for building AST nodes. +type NodeBuilder func(source input.Source, kind dslshape.NodeType) AstNode + +// tryParserFn is a function that attempts to build an AST node. +type tryParserFn func() (AstNode, bool) + +// lookaheadParserFn is a function that performs lookahead. +type lookaheadParserFn func(currentToken lexer.Lexeme) bool + +// rightNodeConstructor is a function which takes in a left expr node and the +// token consumed for a left-recursive operator, and returns a newly constructed +// operator expression if a right expression could be found. +type rightNodeConstructor func(AstNode, lexer.Lexeme) (AstNode, bool) + +// commentedLexeme is a lexer.Lexeme with comments attached. +type commentedLexeme struct { + lexer.Lexeme + comments []string +} + +// sourceParser holds the state of the parser. +type sourceParser struct { + source input.Source // the name of the input; used only for error reports + input string // the input string itself + lex *lexer.FlaggableLexer // a reference to the lexer used for tokenization + builder NodeBuilder // the builder function for creating AstNode instances + nodes *nodeStack // the stack of the current nodes + currentToken commentedLexeme // the current token + previousToken commentedLexeme // the previous token +} + +// buildParser returns a new sourceParser instance. +func buildParser(lx *lexer.Lexer, builder NodeBuilder, source input.Source, input string) *sourceParser { + l := lexer.NewFlaggableLexer(lx) + return &sourceParser{ + source: source, + input: input, + lex: l, + builder: builder, + nodes: &nodeStack{}, + currentToken: commentedLexeme{lexer.Lexeme{Kind: lexer.TokenTypeEOF}, make([]string, 0)}, + previousToken: commentedLexeme{lexer.Lexeme{Kind: lexer.TokenTypeEOF}, make([]string, 0)}, + } +} + +func (p *sourceParser) close() { + p.lex.Close() +} + +// createNode creates a new AstNode and returns it. +func (p *sourceParser) createNode(kind dslshape.NodeType) AstNode { + return p.builder(p.source, kind) +} + +// createErrorNodef creates a new error node and returns it. +func (p *sourceParser) createErrorNodef(format string, args ...interface{}) AstNode { + message := fmt.Sprintf(format, args...) + node := p.startNode(dslshape.NodeTypeError).MustDecorate(dslshape.NodePredicateErrorMessage, message) + p.mustFinishNode() + return node +} + +// startNode creates a new node of the given type, decorates it with the current token's +// position as its start position, and pushes it onto the nodes stack. +func (p *sourceParser) startNode(kind dslshape.NodeType) AstNode { + node := p.createNode(kind) + p.decorateStartRuneAndComments(node, p.currentToken) + p.nodes.push(node) + return node +} + +// decorateStartRuneAndComments decorates the given node with the location of the given token as its +// starting rune, as well as any comments attached to the token. +func (p *sourceParser) decorateStartRuneAndComments(node AstNode, token commentedLexeme) { + node.MustDecorate(dslshape.NodePredicateSource, string(p.source)) + node.MustDecorateWithInt(dslshape.NodePredicateStartRune, int(token.Position)) + p.decorateComments(node, token.comments) +} + +// decorateComments decorates the given node with the specified comments. +func (p *sourceParser) decorateComments(node AstNode, comments []string) { + for _, comment := range comments { + commentNode := p.createNode(dslshape.NodeTypeComment) + commentNode.MustDecorate(dslshape.NodeCommentPredicateValue, comment) + node.Connect(dslshape.NodePredicateChild, commentNode) + } +} + +// decorateEndRune decorates the given node with the location of the given token as its +// ending rune. +func (p *sourceParser) decorateEndRune(node AstNode, token commentedLexeme) { + position := int(token.Position) + len(token.Value) - 1 + node.MustDecorateWithInt(dslshape.NodePredicateEndRune, position) +} + +// currentNode returns the node at the top of the stack. +func (p *sourceParser) currentNode() AstNode { + return p.nodes.topValue() +} + +// mustFinishNode pops the current node from the top of the stack and decorates it with +// the current token's end position as its end position. +func (p *sourceParser) mustFinishNode() { + if p.currentNode() == nil { + panic(fmt.Sprintf("No current node on stack. Token: %s", p.currentToken.Value)) + } + + p.decorateEndRune(p.currentNode(), p.previousToken) + p.nodes.pop() +} + +// consumeToken advances the lexer forward, returning the next token. +func (p *sourceParser) consumeToken() commentedLexeme { + comments := make([]string, 0) + + for { + token := p.lex.NextToken() + + if token.Kind == lexer.TokenTypeSinglelineComment || token.Kind == lexer.TokenTypeMultilineComment { + comments = append(comments, token.Value) + } + + if _, ok := ignoredTokenTypes[token.Kind]; !ok { + p.previousToken = p.currentToken + p.currentToken = commentedLexeme{token, comments} + return p.currentToken + } + } +} + +// isToken returns true if the current token matches one of the types given. +func (p *sourceParser) isToken(types ...lexer.TokenType) bool { + for _, kind := range types { + if p.currentToken.Kind == kind { + return true + } + } + + return false +} + +// isIdentifier returns true if the current token is an identifier matching that given. +func (p *sourceParser) isIdentifier(identifier string) bool { + return p.isToken(lexer.TokenTypeIdentifier) && p.currentToken.Value == identifier +} + +// isKeyword returns true if the current token is a keyword matching that given. +func (p *sourceParser) isKeyword(keyword string) bool { + return p.isToken(lexer.TokenTypeKeyword) && p.currentToken.Value == keyword +} + +// emitErrorf creates a new error node and attachs it as a child of the current +// node. +func (p *sourceParser) emitErrorf(format string, args ...interface{}) { + errorNode := p.createErrorNodef(format, args...) + if len(p.currentToken.Value) > 0 { + errorNode.MustDecorate(dslshape.NodePredicateErrorSource, p.currentToken.Value) + } + p.currentNode().Connect(dslshape.NodePredicateChild, errorNode) +} + +// consumeVariableKeyword consumes an expected keyword token or adds an error node. +func (p *sourceParser) consumeVariableKeyword() (string, bool) { + if !p.isToken(lexer.TokenTypeKeyword) { + p.emitErrorf("Expected keyword, found token %v", p.currentToken.Kind) + return "", false + } + + token := p.currentToken + p.consumeToken() + return token.Value, true +} + +// consumeKeyword consumes an expected keyword token or adds an error node. +func (p *sourceParser) consumeKeyword(keyword string) bool { + if !p.tryConsumeKeyword(keyword) { + p.emitErrorf("Expected keyword %s, found token %v", keyword, p.currentToken.Kind) + return false + } + return true +} + +// tryConsumeKeyword attempts to consume an expected keyword token. +func (p *sourceParser) tryConsumeKeyword(keyword string) bool { + if !p.isKeyword(keyword) { + return false + } + + p.consumeToken() + return true +} + +// cosumeIdentifier consumes an expected identifier token or adds an error node. +func (p *sourceParser) consumeIdentifier() (string, bool) { + token, ok := p.tryConsume(lexer.TokenTypeIdentifier) + if !ok { + p.emitErrorf("Expected identifier, found token %v", p.currentToken.Kind) + return "", false + } + return token.Value, true +} + +// consume performs consumption of the next token if it matches any of the given +// types and returns it. If no matching type is found, adds an error node. +func (p *sourceParser) consume(types ...lexer.TokenType) (lexer.Lexeme, bool) { + token, ok := p.tryConsume(types...) + if !ok { + p.emitErrorf("Expected one of: %v, found: %v", types, p.currentToken.Kind) + } + return token, ok +} + +// tryConsume performs consumption of the next token if it matches any of the given +// types and returns it. +func (p *sourceParser) tryConsume(types ...lexer.TokenType) (lexer.Lexeme, bool) { + token, found := p.tryConsumeWithComments(types...) + return token.Lexeme, found +} + +// tryConsume performs consumption of the next token if it matches any of the given +// types and returns it. +func (p *sourceParser) tryConsumeWithComments(types ...lexer.TokenType) (commentedLexeme, bool) { + if p.isToken(types...) { + token := p.currentToken + p.consumeToken() + return token, true + } + + return commentedLexeme{lexer.Lexeme{ + Kind: lexer.TokenTypeError, + }, make([]string, 0)}, false +} + +// performLeftRecursiveParsing performs left-recursive parsing of a set of operators. This method +// first performs the parsing via the subTryExprFn and then checks for one of the left-recursive +// operator token types found. If none found, the left expression is returned. Otherwise, the +// rightNodeBuilder is called to attempt to construct an operator expression. This method also +// properly handles decoration of the nodes with their proper start and end run locations and +// comments. +func (p *sourceParser) performLeftRecursiveParsing(subTryExprFn tryParserFn, rightNodeBuilder rightNodeConstructor, rightTokenTester lookaheadParserFn, operatorTokens ...lexer.TokenType) (AstNode, bool) { + leftMostToken := p.currentToken + + // Consume the left side of the expression. + leftNode, ok := subTryExprFn() + if !ok { + return nil, false + } + + // Check for an operator token. If none found, then we've found just the left side of the + // expression and so we return that node. + if !p.isToken(operatorTokens...) { + return leftNode, true + } + + // Keep consuming pairs of operators and child expressions until such + // time as no more can be consumed. We use this loop+custom build rather than recursion + // because these operators are *left* recursive, not right. + var currentLeftNode AstNode + currentLeftNode = leftNode + + for { + // Check for an operator. + if !p.isToken(operatorTokens...) { + break + } + + // If a lookahead function is defined, check the lookahead for the matched token. + if rightTokenTester != nil && !rightTokenTester(p.currentToken.Lexeme) { + break + } + + // Consume the operator. + operatorToken, ok := p.tryConsumeWithComments(operatorTokens...) + if !ok { + break + } + + // Consume the right hand expression and build an expression node (if applicable). + exprNode, ok := rightNodeBuilder(currentLeftNode, operatorToken.Lexeme) + if !ok { + p.emitErrorf("Expected right hand expression, found: %v", p.currentToken.Kind) + return currentLeftNode, true + } + + p.decorateStartRuneAndComments(exprNode, leftMostToken) + p.decorateEndRune(exprNode, p.previousToken) + + currentLeftNode = exprNode + } + + return currentLeftNode, true +} + +// tryConsumeStatementTerminator tries to consume a statement terminator. +func (p *sourceParser) tryConsumeStatementTerminator() (lexer.Lexeme, bool) { + return p.tryConsume(lexer.TokenTypeSyntheticSemicolon, lexer.TokenTypeSemicolon, lexer.TokenTypeEOF) +} + +// consumeStatementTerminator consume a statement terminator. +func (p *sourceParser) consumeStatementTerminator() bool { + _, ok := p.tryConsumeStatementTerminator() + if ok { + return true + } + + p.emitErrorf("Expected end of statement or definition, found: %s", p.currentToken.Kind) + return false +} + +// binaryOpDefinition represents information a binary operator token and its associated node type. +type binaryOpDefinition struct { + // The token representing the binary expression's operator. + BinaryOperatorToken lexer.TokenType + + // The type of node to create for this expression. + BinaryExpressionNodeType dslshape.NodeType +} + +// buildBinaryOperatorExpressionFnTree builds a tree of functions to try to consume a set of binary +// operator expressions. +func (p *sourceParser) buildBinaryOperatorExpressionFnTree(ops []binaryOpDefinition) tryParserFn { + // Start with a base expression function. + var currentParseFn tryParserFn + currentParseFn = func() (AstNode, bool) { + arrowExpr, ok := p.tryConsumeArrowExpression() + if !ok { + return p.tryConsumeBaseExpression() + } + + return arrowExpr, true + } + + for i := range ops { + // Note: We have to reverse this to ensure we have proper precedence. + currentParseFn = func(operatorInfo binaryOpDefinition, currentFn tryParserFn) tryParserFn { + return (func() (AstNode, bool) { + return p.tryConsumeComputeExpression(currentFn, operatorInfo.BinaryOperatorToken, operatorInfo.BinaryExpressionNodeType) + }) + }(ops[len(ops)-i-1], currentParseFn) + } + + return currentParseFn +} |
