summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/zed/internal/cmd/validate.go
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/zed/internal/cmd/validate.go
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/zed/internal/cmd/validate.go')
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/validate.go414
1 files changed, 414 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/cmd/validate.go b/vendor/github.com/authzed/zed/internal/cmd/validate.go
new file mode 100644
index 0000000..ffca1c6
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/validate.go
@@ -0,0 +1,414 @@
+package cmd
+
+import (
+ "errors"
+ "fmt"
+ "net/url"
+ "os"
+ "strings"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/charmbracelet/lipgloss"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/muesli/termenv"
+ "github.com/spf13/cobra"
+
+ composable "github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/development"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/validationfile"
+
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/internal/decode"
+ "github.com/authzed/zed/internal/printers"
+)
+
+var (
+ // NOTE: these need to be set *after* the renderer has been set, otherwise
+ // the forceColor setting can't work, hence the thunking.
+ success = func() string {
+ return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("10")).Render("Success!")
+ }
+ complete = func() string { return lipgloss.NewStyle().Bold(true).Render("complete") }
+ errorPrefix = func() string { return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("9")).Render("error: ") }
+ warningPrefix = func() string {
+ return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("3")).Render("warning: ")
+ }
+ errorMessageStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true).Width(80) }
+ linePrefixStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("12")) }
+ highlightedSourceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) }
+ highlightedLineStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) }
+ codeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("8")) }
+ highlightedCodeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("15")) }
+ traceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true) }
+)
+
+func registerValidateCmd(cmd *cobra.Command) {
+ validateCmd.Flags().Bool("force-color", false, "force color code output even in non-tty environments")
+ validateCmd.Flags().String("schema-type", "", "force validation according to specific schema syntax (\"\", \"composable\", \"standard\")")
+ cmd.AddCommand(validateCmd)
+}
+
+var validateCmd = &cobra.Command{
+ Use: "validate <validation_file_or_schema_file>",
+ Short: "Validates the given validation file (.yaml, .zaml) or schema file (.zed)",
+ Example: `
+ From a local file (with prefix):
+ zed validate file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
+
+ From a local file (no prefix):
+ zed validate authzed-x7izWU8_2Gw3.yaml
+
+ From a gist:
+ zed validate https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6
+
+ From a playground link:
+ zed validate https://play.authzed.com/s/iksdFvCtvnkR/schema
+
+ From pastebin:
+ zed validate https://pastebin.com/8qU45rVK
+
+ From a devtools instance:
+ zed validate https://localhost:8443/download`,
+ Args: commands.ValidationWrapper(cobra.MinimumNArgs(1)),
+ ValidArgsFunction: commands.FileExtensionCompletions("zed", "yaml", "zaml"),
+ PreRunE: validatePreRunE,
+ RunE: func(cmd *cobra.Command, filenames []string) error {
+ result, shouldExit, err := validateCmdFunc(cmd, filenames)
+ if err != nil {
+ return err
+ }
+ console.Print(result)
+ if shouldExit {
+ os.Exit(1)
+ }
+ return nil
+ },
+
+ // A schema that causes the parser/compiler to error will halt execution
+ // of this command with an error. In that case, we want to just display the error,
+ // rather than showing usage for this command.
+ SilenceUsage: true,
+}
+
+var validSchemaTypes = []string{"", "standard", "composable"}
+
+func validatePreRunE(cmd *cobra.Command, _ []string) error {
+ // Override lipgloss's autodetection of whether it's in a terminal environment
+ // and display things in color anyway. This can be nice in CI environments that
+ // support it.
+ setForceColor := cobrautil.MustGetBool(cmd, "force-color")
+ if setForceColor {
+ lipgloss.SetColorProfile(termenv.ANSI256)
+ }
+
+ schemaType := cobrautil.MustGetString(cmd, "schema-type")
+ schemaTypeValid := false
+ for _, validType := range validSchemaTypes {
+ if schemaType == validType {
+ schemaTypeValid = true
+ }
+ }
+ if !schemaTypeValid {
+ return fmt.Errorf("schema-type must be one of \"\", \"standard\", \"composable\". received: %s", schemaType)
+ }
+
+ return nil
+}
+
+// validateCmdFunc returns the string to print to the user, whether to return a non-zero status code, and any errors.
+func validateCmdFunc(cmd *cobra.Command, filenames []string) (string, bool, error) {
+ // Initialize variables for multiple files
+ var (
+ totalFiles = len(filenames)
+ successfullyValidatedFiles = 0
+ shouldExit = false
+ toPrint = &strings.Builder{}
+ schemaType = cobrautil.MustGetString(cmd, "schema-type")
+ )
+
+ for _, filename := range filenames {
+ // If we're running over multiple files, print the filename for context/debugging purposes
+ if totalFiles > 1 {
+ toPrint.WriteString(filename + "\n")
+ }
+
+ u, err := url.Parse(filename)
+ if err != nil {
+ return "", false, err
+ }
+
+ decoder, err := decode.DecoderForURL(u)
+ if err != nil {
+ return "", false, err
+ }
+
+ var parsed validationfile.ValidationFile
+ validateContents, isOnlySchema, err := decoder(&parsed)
+ standardErrors, composableErrs, otherErrs := classifyErrors(err)
+
+ switch schemaType {
+ case "standard":
+ if standardErrors != nil {
+ var errWithSource spiceerrors.WithSourceError
+ if errors.As(standardErrors, &errWithSource) {
+ outputErrorWithSource(toPrint, validateContents, errWithSource)
+ shouldExit = true
+ }
+ return "", shouldExit, standardErrors
+ }
+ case "composable":
+ if composableErrs != nil {
+ var errWithSource spiceerrors.WithSourceError
+ if errors.As(composableErrs, &errWithSource) {
+ outputErrorWithSource(toPrint, validateContents, errWithSource)
+ shouldExit = true
+ }
+ return "", shouldExit, composableErrs
+ }
+ default:
+ // By default, validate will attempt to validate a schema first according to composable schema rules,
+ // then standard schema rules,
+ // and if both fail it will show the errors from composable schema.
+ if composableErrs != nil && standardErrors != nil {
+ var errWithSource spiceerrors.WithSourceError
+ if errors.As(composableErrs, &errWithSource) {
+ outputErrorWithSource(toPrint, validateContents, errWithSource)
+ shouldExit = true
+ }
+ return "", shouldExit, composableErrs
+ }
+ }
+
+ if otherErrs != nil {
+ return "", false, otherErrs
+ }
+
+ tuples := make([]*core.RelationTuple, 0)
+ totalAssertions := 0
+ totalRelationsValidated := 0
+
+ for _, rel := range parsed.Relationships.Relationships {
+ tuples = append(tuples, rel.ToCoreTuple())
+ }
+
+ // Create the development context for each run
+ ctx := cmd.Context()
+ devCtx, devErrs, err := development.NewDevContext(ctx, &devinterface.RequestContext{
+ Schema: parsed.Schema.Schema,
+ Relationships: tuples,
+ })
+ if err != nil {
+ return "", false, err
+ }
+ if devErrs != nil {
+ schemaOffset := parsed.Schema.SourcePosition.LineNumber
+ if isOnlySchema {
+ schemaOffset = 0
+ }
+
+ // Output errors
+ outputDeveloperErrorsWithLineOffset(toPrint, validateContents, devErrs.InputErrors, schemaOffset)
+ return toPrint.String(), true, nil
+ }
+ // Run assertions
+ adevErrs, aerr := development.RunAllAssertions(devCtx, &parsed.Assertions)
+ if aerr != nil {
+ return "", false, aerr
+ }
+ if adevErrs != nil {
+ outputDeveloperErrors(toPrint, validateContents, adevErrs)
+ return toPrint.String(), true, nil
+ }
+ successfullyValidatedFiles++
+
+ // Run expected relations for file
+ _, erDevErrs, rerr := development.RunValidation(devCtx, &parsed.ExpectedRelations)
+ if rerr != nil {
+ return "", false, rerr
+ }
+ if erDevErrs != nil {
+ outputDeveloperErrors(toPrint, validateContents, erDevErrs)
+ return toPrint.String(), true, nil
+ }
+ // Print out any warnings for file
+ warnings, err := development.GetWarnings(ctx, devCtx)
+ if err != nil {
+ return "", false, err
+ }
+
+ if len(warnings) > 0 {
+ for _, warning := range warnings {
+ fmt.Fprintf(toPrint, "%s%s\n", warningPrefix(), warning.Message)
+ outputForLine(toPrint, validateContents, uint64(warning.Line), warning.SourceCode, uint64(warning.Column)) // warning.LineNumber is 1-indexed
+ toPrint.WriteString("\n")
+ }
+
+ toPrint.WriteString(complete())
+ } else {
+ toPrint.WriteString(success())
+ }
+ totalAssertions += len(parsed.Assertions.AssertTrue) + len(parsed.Assertions.AssertFalse)
+ totalRelationsValidated += len(parsed.ExpectedRelations.ValidationMap)
+
+ fmt.Fprintf(toPrint, " - %d relationships loaded, %d assertions run, %d expected relations validated\n",
+ len(tuples),
+ totalAssertions,
+ totalRelationsValidated)
+ }
+
+ if totalFiles > 1 {
+ fmt.Fprintf(toPrint, "total files: %d, successfully validated files: %d\n", totalFiles, successfullyValidatedFiles)
+ }
+
+ return toPrint.String(), shouldExit, nil
+}
+
+func outputErrorWithSource(sb *strings.Builder, validateContents []byte, errWithSource spiceerrors.WithSourceError) {
+ fmt.Fprintf(sb, "%s%s\n", errorPrefix(), errorMessageStyle().Render(errWithSource.Error()))
+ outputForLine(sb, validateContents, errWithSource.LineNumber, errWithSource.SourceCodeString, 0) // errWithSource.LineNumber is 1-indexed
+}
+
+func outputForLine(sb *strings.Builder, validateContents []byte, oneIndexedLineNumber uint64, sourceCodeString string, oneIndexedColumnPosition uint64) {
+ lines := strings.Split(string(validateContents), "\n")
+ // These should be fine to be zero if the cast fails.
+ intLineNumber, _ := safecast.ToInt(oneIndexedLineNumber)
+ intColumnPosition, _ := safecast.ToInt(oneIndexedColumnPosition)
+ errorLineNumber := intLineNumber - 1
+ for i := errorLineNumber - 3; i < errorLineNumber+3; i++ {
+ if i == errorLineNumber {
+ renderLine(sb, lines, i, sourceCodeString, errorLineNumber, intColumnPosition-1)
+ } else {
+ renderLine(sb, lines, i, "", errorLineNumber, -1)
+ }
+ }
+}
+
+func outputDeveloperErrors(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError) {
+ outputDeveloperErrorsWithLineOffset(sb, validateContents, devErrors, 0)
+}
+
+func outputDeveloperErrorsWithLineOffset(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError, lineOffset int) {
+ lines := strings.Split(string(validateContents), "\n")
+
+ for _, devErr := range devErrors {
+ outputDeveloperError(sb, devErr, lines, lineOffset)
+ }
+}
+
+func outputDeveloperError(sb *strings.Builder, devError *devinterface.DeveloperError, lines []string, lineOffset int) {
+ fmt.Fprintf(sb, "%s %s\n", errorPrefix(), errorMessageStyle().Render(devError.Message))
+ errorLineNumber := int(devError.Line) - 1 + lineOffset // devError.Line is 1-indexed
+ for i := errorLineNumber - 3; i < errorLineNumber+3; i++ {
+ if i == errorLineNumber {
+ renderLine(sb, lines, i, devError.Context, errorLineNumber, -1)
+ } else {
+ renderLine(sb, lines, i, "", errorLineNumber, -1)
+ }
+ }
+
+ if devError.CheckResolvedDebugInformation != nil && devError.CheckResolvedDebugInformation.Check != nil {
+ fmt.Fprintf(sb, "\n %s\n", traceStyle().Render("Explanation:"))
+ tp := printers.NewTreePrinter()
+ printers.DisplayCheckTrace(devError.CheckResolvedDebugInformation.Check, tp, true)
+ sb.WriteString(tp.Indented())
+ }
+
+ sb.WriteString("\n\n")
+}
+
+func renderLine(sb *strings.Builder, lines []string, index int, highlight string, highlightLineIndex int, highlightStartingColumnIndex int) {
+ if index < 0 || index >= len(lines) {
+ return
+ }
+
+ lineNumberLength := len(fmt.Sprintf("%d", len(lines)))
+ lineContents := strings.ReplaceAll(lines[index], "\t", " ")
+ lineDelimiter := "|"
+
+ highlightLength := max(0, len(highlight)-1)
+ highlightColumnIndex := -1
+
+ // If the highlight string was provided, then we need to find the index of the highlight
+ // string in the line contents to determine where to place the caret.
+ if len(highlight) > 0 {
+ offset := 0
+ for {
+ foundRelativeIndex := strings.Index(lineContents[offset:], highlight)
+ foundIndex := foundRelativeIndex + offset
+ if foundIndex >= highlightStartingColumnIndex {
+ highlightColumnIndex = foundIndex
+ break
+ }
+
+ offset = foundIndex + 1
+ if foundRelativeIndex < 0 || foundIndex > len(lineContents) {
+ break
+ }
+ }
+ } else if highlightStartingColumnIndex >= 0 {
+ // Otherwise, just show a caret at the specified starting column, if any.
+ highlightColumnIndex = highlightStartingColumnIndex
+ }
+
+ lineNumberStr := fmt.Sprintf("%d", index+1)
+ noNumberSpaces := strings.Repeat(" ", lineNumberLength)
+
+ lineNumberStyle := linePrefixStyle
+ lineContentsStyle := codeStyle
+ if index == highlightLineIndex {
+ lineNumberStyle = highlightedLineStyle
+ lineContentsStyle = highlightedCodeStyle
+ lineDelimiter = ">"
+ }
+
+ lineNumberSpacer := strings.Repeat(" ", lineNumberLength-len(lineNumberStr))
+
+ if highlightColumnIndex < 0 {
+ fmt.Fprintf(sb, " %s%s %s %s\n", lineNumberSpacer, lineNumberStyle().Render(lineNumberStr), lineDelimiter, lineContentsStyle().Render(lineContents))
+ } else {
+ fmt.Fprintf(sb, " %s%s %s %s%s%s\n",
+ lineNumberSpacer,
+ lineNumberStyle().Render(lineNumberStr),
+ lineDelimiter,
+ lineContentsStyle().Render(lineContents[0:highlightColumnIndex]),
+ highlightedSourceStyle().Render(highlight),
+ lineContentsStyle().Render(lineContents[highlightColumnIndex+len(highlight):]))
+
+ fmt.Fprintf(sb, " %s %s %s%s%s\n",
+ noNumberSpaces,
+ lineDelimiter,
+ strings.Repeat(" ", highlightColumnIndex),
+ highlightedSourceStyle().Render("^"),
+ highlightedSourceStyle().Render(strings.Repeat("~", highlightLength)))
+ }
+}
+
+// classifyErrors returns errors from the composable DSL, the standard DSL, and any other parsing errors.
+func classifyErrors(err error) (error, error, error) {
+ if err == nil {
+ return nil, nil, nil
+ }
+ var standardErr compiler.BaseCompilerError
+ var composableErr composable.BaseCompilerError
+ var retStandard, retComposable, allOthers error
+
+ ok := errors.As(err, &standardErr)
+ if ok {
+ retStandard = standardErr
+ }
+ ok = errors.As(err, &composableErr)
+ if ok {
+ retComposable = composableErr
+ }
+
+ if retStandard == nil && retComposable == nil {
+ allOthers = err
+ }
+
+ return retStandard, retComposable, allOthers
+}