diff options
Diffstat (limited to 'vendor/github.com/authzed/zed/internal/cmd/validate.go')
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/validate.go | 414 |
1 files changed, 414 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/cmd/validate.go b/vendor/github.com/authzed/zed/internal/cmd/validate.go new file mode 100644 index 0000000..ffca1c6 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/validate.go @@ -0,0 +1,414 @@ +package cmd + +import ( + "errors" + "fmt" + "net/url" + "os" + "strings" + + "github.com/ccoveille/go-safecast" + "github.com/charmbracelet/lipgloss" + "github.com/jzelinskie/cobrautil/v2" + "github.com/muesli/termenv" + "github.com/spf13/cobra" + + composable "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" + "github.com/authzed/spicedb/pkg/development" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/validationfile" + + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/internal/decode" + "github.com/authzed/zed/internal/printers" +) + +var ( + // NOTE: these need to be set *after* the renderer has been set, otherwise + // the forceColor setting can't work, hence the thunking. + success = func() string { + return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("10")).Render("Success!") + } + complete = func() string { return lipgloss.NewStyle().Bold(true).Render("complete") } + errorPrefix = func() string { return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("9")).Render("error: ") } + warningPrefix = func() string { + return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("3")).Render("warning: ") + } + errorMessageStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true).Width(80) } + linePrefixStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("12")) } + highlightedSourceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) } + highlightedLineStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) } + codeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("8")) } + highlightedCodeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("15")) } + traceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true) } +) + +func registerValidateCmd(cmd *cobra.Command) { + validateCmd.Flags().Bool("force-color", false, "force color code output even in non-tty environments") + validateCmd.Flags().String("schema-type", "", "force validation according to specific schema syntax (\"\", \"composable\", \"standard\")") + cmd.AddCommand(validateCmd) +} + +var validateCmd = &cobra.Command{ + Use: "validate <validation_file_or_schema_file>", + Short: "Validates the given validation file (.yaml, .zaml) or schema file (.zed)", + Example: ` + From a local file (with prefix): + zed validate file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + From a local file (no prefix): + zed validate authzed-x7izWU8_2Gw3.yaml + + From a gist: + zed validate https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6 + + From a playground link: + zed validate https://play.authzed.com/s/iksdFvCtvnkR/schema + + From pastebin: + zed validate https://pastebin.com/8qU45rVK + + From a devtools instance: + zed validate https://localhost:8443/download`, + Args: commands.ValidationWrapper(cobra.MinimumNArgs(1)), + ValidArgsFunction: commands.FileExtensionCompletions("zed", "yaml", "zaml"), + PreRunE: validatePreRunE, + RunE: func(cmd *cobra.Command, filenames []string) error { + result, shouldExit, err := validateCmdFunc(cmd, filenames) + if err != nil { + return err + } + console.Print(result) + if shouldExit { + os.Exit(1) + } + return nil + }, + + // A schema that causes the parser/compiler to error will halt execution + // of this command with an error. In that case, we want to just display the error, + // rather than showing usage for this command. + SilenceUsage: true, +} + +var validSchemaTypes = []string{"", "standard", "composable"} + +func validatePreRunE(cmd *cobra.Command, _ []string) error { + // Override lipgloss's autodetection of whether it's in a terminal environment + // and display things in color anyway. This can be nice in CI environments that + // support it. + setForceColor := cobrautil.MustGetBool(cmd, "force-color") + if setForceColor { + lipgloss.SetColorProfile(termenv.ANSI256) + } + + schemaType := cobrautil.MustGetString(cmd, "schema-type") + schemaTypeValid := false + for _, validType := range validSchemaTypes { + if schemaType == validType { + schemaTypeValid = true + } + } + if !schemaTypeValid { + return fmt.Errorf("schema-type must be one of \"\", \"standard\", \"composable\". received: %s", schemaType) + } + + return nil +} + +// validateCmdFunc returns the string to print to the user, whether to return a non-zero status code, and any errors. +func validateCmdFunc(cmd *cobra.Command, filenames []string) (string, bool, error) { + // Initialize variables for multiple files + var ( + totalFiles = len(filenames) + successfullyValidatedFiles = 0 + shouldExit = false + toPrint = &strings.Builder{} + schemaType = cobrautil.MustGetString(cmd, "schema-type") + ) + + for _, filename := range filenames { + // If we're running over multiple files, print the filename for context/debugging purposes + if totalFiles > 1 { + toPrint.WriteString(filename + "\n") + } + + u, err := url.Parse(filename) + if err != nil { + return "", false, err + } + + decoder, err := decode.DecoderForURL(u) + if err != nil { + return "", false, err + } + + var parsed validationfile.ValidationFile + validateContents, isOnlySchema, err := decoder(&parsed) + standardErrors, composableErrs, otherErrs := classifyErrors(err) + + switch schemaType { + case "standard": + if standardErrors != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(standardErrors, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, standardErrors + } + case "composable": + if composableErrs != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(composableErrs, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, composableErrs + } + default: + // By default, validate will attempt to validate a schema first according to composable schema rules, + // then standard schema rules, + // and if both fail it will show the errors from composable schema. + if composableErrs != nil && standardErrors != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(composableErrs, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, composableErrs + } + } + + if otherErrs != nil { + return "", false, otherErrs + } + + tuples := make([]*core.RelationTuple, 0) + totalAssertions := 0 + totalRelationsValidated := 0 + + for _, rel := range parsed.Relationships.Relationships { + tuples = append(tuples, rel.ToCoreTuple()) + } + + // Create the development context for each run + ctx := cmd.Context() + devCtx, devErrs, err := development.NewDevContext(ctx, &devinterface.RequestContext{ + Schema: parsed.Schema.Schema, + Relationships: tuples, + }) + if err != nil { + return "", false, err + } + if devErrs != nil { + schemaOffset := parsed.Schema.SourcePosition.LineNumber + if isOnlySchema { + schemaOffset = 0 + } + + // Output errors + outputDeveloperErrorsWithLineOffset(toPrint, validateContents, devErrs.InputErrors, schemaOffset) + return toPrint.String(), true, nil + } + // Run assertions + adevErrs, aerr := development.RunAllAssertions(devCtx, &parsed.Assertions) + if aerr != nil { + return "", false, aerr + } + if adevErrs != nil { + outputDeveloperErrors(toPrint, validateContents, adevErrs) + return toPrint.String(), true, nil + } + successfullyValidatedFiles++ + + // Run expected relations for file + _, erDevErrs, rerr := development.RunValidation(devCtx, &parsed.ExpectedRelations) + if rerr != nil { + return "", false, rerr + } + if erDevErrs != nil { + outputDeveloperErrors(toPrint, validateContents, erDevErrs) + return toPrint.String(), true, nil + } + // Print out any warnings for file + warnings, err := development.GetWarnings(ctx, devCtx) + if err != nil { + return "", false, err + } + + if len(warnings) > 0 { + for _, warning := range warnings { + fmt.Fprintf(toPrint, "%s%s\n", warningPrefix(), warning.Message) + outputForLine(toPrint, validateContents, uint64(warning.Line), warning.SourceCode, uint64(warning.Column)) // warning.LineNumber is 1-indexed + toPrint.WriteString("\n") + } + + toPrint.WriteString(complete()) + } else { + toPrint.WriteString(success()) + } + totalAssertions += len(parsed.Assertions.AssertTrue) + len(parsed.Assertions.AssertFalse) + totalRelationsValidated += len(parsed.ExpectedRelations.ValidationMap) + + fmt.Fprintf(toPrint, " - %d relationships loaded, %d assertions run, %d expected relations validated\n", + len(tuples), + totalAssertions, + totalRelationsValidated) + } + + if totalFiles > 1 { + fmt.Fprintf(toPrint, "total files: %d, successfully validated files: %d\n", totalFiles, successfullyValidatedFiles) + } + + return toPrint.String(), shouldExit, nil +} + +func outputErrorWithSource(sb *strings.Builder, validateContents []byte, errWithSource spiceerrors.WithSourceError) { + fmt.Fprintf(sb, "%s%s\n", errorPrefix(), errorMessageStyle().Render(errWithSource.Error())) + outputForLine(sb, validateContents, errWithSource.LineNumber, errWithSource.SourceCodeString, 0) // errWithSource.LineNumber is 1-indexed +} + +func outputForLine(sb *strings.Builder, validateContents []byte, oneIndexedLineNumber uint64, sourceCodeString string, oneIndexedColumnPosition uint64) { + lines := strings.Split(string(validateContents), "\n") + // These should be fine to be zero if the cast fails. + intLineNumber, _ := safecast.ToInt(oneIndexedLineNumber) + intColumnPosition, _ := safecast.ToInt(oneIndexedColumnPosition) + errorLineNumber := intLineNumber - 1 + for i := errorLineNumber - 3; i < errorLineNumber+3; i++ { + if i == errorLineNumber { + renderLine(sb, lines, i, sourceCodeString, errorLineNumber, intColumnPosition-1) + } else { + renderLine(sb, lines, i, "", errorLineNumber, -1) + } + } +} + +func outputDeveloperErrors(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError) { + outputDeveloperErrorsWithLineOffset(sb, validateContents, devErrors, 0) +} + +func outputDeveloperErrorsWithLineOffset(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError, lineOffset int) { + lines := strings.Split(string(validateContents), "\n") + + for _, devErr := range devErrors { + outputDeveloperError(sb, devErr, lines, lineOffset) + } +} + +func outputDeveloperError(sb *strings.Builder, devError *devinterface.DeveloperError, lines []string, lineOffset int) { + fmt.Fprintf(sb, "%s %s\n", errorPrefix(), errorMessageStyle().Render(devError.Message)) + errorLineNumber := int(devError.Line) - 1 + lineOffset // devError.Line is 1-indexed + for i := errorLineNumber - 3; i < errorLineNumber+3; i++ { + if i == errorLineNumber { + renderLine(sb, lines, i, devError.Context, errorLineNumber, -1) + } else { + renderLine(sb, lines, i, "", errorLineNumber, -1) + } + } + + if devError.CheckResolvedDebugInformation != nil && devError.CheckResolvedDebugInformation.Check != nil { + fmt.Fprintf(sb, "\n %s\n", traceStyle().Render("Explanation:")) + tp := printers.NewTreePrinter() + printers.DisplayCheckTrace(devError.CheckResolvedDebugInformation.Check, tp, true) + sb.WriteString(tp.Indented()) + } + + sb.WriteString("\n\n") +} + +func renderLine(sb *strings.Builder, lines []string, index int, highlight string, highlightLineIndex int, highlightStartingColumnIndex int) { + if index < 0 || index >= len(lines) { + return + } + + lineNumberLength := len(fmt.Sprintf("%d", len(lines))) + lineContents := strings.ReplaceAll(lines[index], "\t", " ") + lineDelimiter := "|" + + highlightLength := max(0, len(highlight)-1) + highlightColumnIndex := -1 + + // If the highlight string was provided, then we need to find the index of the highlight + // string in the line contents to determine where to place the caret. + if len(highlight) > 0 { + offset := 0 + for { + foundRelativeIndex := strings.Index(lineContents[offset:], highlight) + foundIndex := foundRelativeIndex + offset + if foundIndex >= highlightStartingColumnIndex { + highlightColumnIndex = foundIndex + break + } + + offset = foundIndex + 1 + if foundRelativeIndex < 0 || foundIndex > len(lineContents) { + break + } + } + } else if highlightStartingColumnIndex >= 0 { + // Otherwise, just show a caret at the specified starting column, if any. + highlightColumnIndex = highlightStartingColumnIndex + } + + lineNumberStr := fmt.Sprintf("%d", index+1) + noNumberSpaces := strings.Repeat(" ", lineNumberLength) + + lineNumberStyle := linePrefixStyle + lineContentsStyle := codeStyle + if index == highlightLineIndex { + lineNumberStyle = highlightedLineStyle + lineContentsStyle = highlightedCodeStyle + lineDelimiter = ">" + } + + lineNumberSpacer := strings.Repeat(" ", lineNumberLength-len(lineNumberStr)) + + if highlightColumnIndex < 0 { + fmt.Fprintf(sb, " %s%s %s %s\n", lineNumberSpacer, lineNumberStyle().Render(lineNumberStr), lineDelimiter, lineContentsStyle().Render(lineContents)) + } else { + fmt.Fprintf(sb, " %s%s %s %s%s%s\n", + lineNumberSpacer, + lineNumberStyle().Render(lineNumberStr), + lineDelimiter, + lineContentsStyle().Render(lineContents[0:highlightColumnIndex]), + highlightedSourceStyle().Render(highlight), + lineContentsStyle().Render(lineContents[highlightColumnIndex+len(highlight):])) + + fmt.Fprintf(sb, " %s %s %s%s%s\n", + noNumberSpaces, + lineDelimiter, + strings.Repeat(" ", highlightColumnIndex), + highlightedSourceStyle().Render("^"), + highlightedSourceStyle().Render(strings.Repeat("~", highlightLength))) + } +} + +// classifyErrors returns errors from the composable DSL, the standard DSL, and any other parsing errors. +func classifyErrors(err error) (error, error, error) { + if err == nil { + return nil, nil, nil + } + var standardErr compiler.BaseCompilerError + var composableErr composable.BaseCompilerError + var retStandard, retComposable, allOthers error + + ok := errors.As(err, &standardErr) + if ok { + retStandard = standardErr + } + ok = errors.As(err, &composableErr) + if ok { + retComposable = composableErr + } + + if retStandard == nil && retComposable == nil { + allOthers = err + } + + return retStandard, retComposable, allOthers +} |
