summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/zed/internal/cmd
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
committermo khan <mo@mokhan.ca>2025-07-22 17:35:49 -0600
commit20ef0d92694465ac86b550df139e8366a0a2b4fa (patch)
tree3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/zed/internal/cmd
parent44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff)
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/zed/internal/cmd')
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/backup.go868
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/cmd.go187
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/context.go202
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/import.go181
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/preview.go120
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/restorer.go446
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/schema.go319
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/validate.go414
-rw-r--r--vendor/github.com/authzed/zed/internal/cmd/version.go64
9 files changed, 2801 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/cmd/backup.go b/vendor/github.com/authzed/zed/internal/cmd/backup.go
new file mode 100644
index 0000000..57fa83c
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/backup.go
@@ -0,0 +1,868 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/mattn/go-isatty"
+ "github.com/rodaine/table"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "golang.org/x/exp/constraints"
+ "golang.org/x/exp/maps"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ schemapkg "github.com/authzed/spicedb/pkg/schema"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/generator"
+ "github.com/authzed/spicedb/pkg/tuple"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/pkg/backupformat"
+)
+
// Readability aliases for createBackupFile's returnIfExists parameter.
const (
	// returnIfExists allows reopening an existing backup file (resume path).
	returnIfExists = true
	// doNotReturnIfExists makes creation fail when the target already exists.
	doNotReturnIfExists = false
)
+
+// cobraRunEFunc is the signature of a cobra.Command.RunE function.
+type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error)
+
+// withErrorHandling is a wrapper that centralizes error handling, instead of having to scatter it around the command logic.
+func withErrorHandling(f cobraRunEFunc) cobraRunEFunc {
+ return func(cmd *cobra.Command, args []string) (err error) {
+ return addSizeErrInfo(f(cmd, args))
+ }
+}
+
var (
	// backupCmd is the parent "zed backup" command. Running it without a
	// subcommand performs a create, for back-compat with when create lived
	// on the root command.
	backupCmd = &cobra.Command{
		Use:   "backup <filename>",
		Short: "Create, restore, and inspect permissions system backups",
		Args:  commands.ValidationWrapper(cobra.MaximumNArgs(1)),
		// Create used to be on the root, so add it here for back-compat.
		RunE: withErrorHandling(backupCreateCmdFunc),
	}

	// backupCreateCmd writes a backup of the connected permissions system.
	backupCreateCmd = &cobra.Command{
		Use:   "create <filename>",
		Short: "Backup a permission system to a file",
		Args:  commands.ValidationWrapper(cobra.MaximumNArgs(1)),
		RunE:  withErrorHandling(backupCreateCmdFunc),
	}

	// backupRestoreCmd loads a backup (file or stdin) into a permissions system.
	backupRestoreCmd = &cobra.Command{
		Use:   "restore <filename>",
		Short: "Restore a permission system from a file",
		Args:  commands.ValidationWrapper(commands.StdinOrExactArgs(1)),
		RunE:  backupRestoreCmdFunc,
	}

	// backupParseSchemaCmd prints only the schema stored in a backup file.
	backupParseSchemaCmd = &cobra.Command{
		Use:   "parse-schema <filename>",
		Short: "Extract the schema from a backup file",
		Args:  commands.ValidationWrapper(cobra.ExactArgs(1)),
		RunE: func(cmd *cobra.Command, args []string) error {
			return backupParseSchemaCmdFunc(cmd, os.Stdout, args)
		},
	}

	// backupParseRevisionCmd prints only the ZedToken stored in a backup file.
	backupParseRevisionCmd = &cobra.Command{
		Use:   "parse-revision <filename>",
		Short: "Extract the revision from a backup file",
		Args:  commands.ValidationWrapper(cobra.ExactArgs(1)),
		RunE: func(cmd *cobra.Command, args []string) error {
			return backupParseRevisionCmdFunc(cmd, os.Stdout, args)
		},
	}

	// backupParseRelsCmd prints the relationships stored in a backup file.
	backupParseRelsCmd = &cobra.Command{
		Use:   "parse-relationships <filename>",
		Short: "Extract the relationships from a backup file",
		Args:  commands.ValidationWrapper(cobra.ExactArgs(1)),
		RunE: func(cmd *cobra.Command, args []string) error {
			return backupParseRelsCmdFunc(cmd, os.Stdout, args)
		},
	}

	// backupRedactCmd rewrites a backup with names/IDs replaced by redacted
	// placeholders, writing the result to "<filename>.redacted".
	backupRedactCmd = &cobra.Command{
		Use:   "redact <filename>",
		Short: "Redact a backup file to remove sensitive information",
		Args:  commands.ValidationWrapper(cobra.ExactArgs(1)),
		RunE: func(cmd *cobra.Command, args []string) error {
			return backupRedactCmdFunc(cmd, args)
		},
	}
)
+
// registerBackupCmd wires the backup command tree onto rootCmd, including the
// hidden root-level "restore" alias kept for backwards compatibility.
func registerBackupCmd(rootCmd *cobra.Command) {
	// "zed backup" itself creates a backup (back-compat), so it takes the
	// create flags directly as well.
	rootCmd.AddCommand(backupCmd)
	registerBackupCreateFlags(backupCmd)

	backupCmd.AddCommand(backupCreateCmd)
	registerBackupCreateFlags(backupCreateCmd)

	backupCmd.AddCommand(backupRestoreCmd)
	registerBackupRestoreFlags(backupRestoreCmd)

	backupCmd.AddCommand(backupRedactCmd)
	backupRedactCmd.Flags().Bool("redact-definitions", true, "redact definitions")
	backupRedactCmd.Flags().Bool("redact-relations", true, "redact relations")
	backupRedactCmd.Flags().Bool("redact-object-ids", true, "redact object IDs")
	backupRedactCmd.Flags().Bool("print-redacted-object-ids", false, "prints the redacted object IDs")

	// Restore used to be on the root, so add it there too, but hidden.
	restoreCmd := &cobra.Command{
		Use:    "restore <filename>",
		Short:  "Restore a permission system from a backup file",
		Args:   commands.ValidationWrapper(cobra.MaximumNArgs(1)),
		RunE:   backupRestoreCmdFunc,
		Hidden: true,
	}
	rootCmd.AddCommand(restoreCmd)
	registerBackupRestoreFlags(restoreCmd)

	backupCmd.AddCommand(backupParseSchemaCmd)
	backupParseSchemaCmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
	backupParseSchemaCmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")

	backupCmd.AddCommand(backupParseRevisionCmd)
	backupCmd.AddCommand(backupParseRelsCmd)
	backupParseRelsCmd.Flags().String("prefix-filter", "", "Include only relationships with a given prefix")
}
+
+func registerBackupRestoreFlags(cmd *cobra.Command) {
+ cmd.Flags().Uint("batch-size", 1_000, "restore relationship write batch size")
+ cmd.Flags().Uint("batches-per-transaction", 10, "number of batches per transaction")
+ cmd.Flags().String("conflict-strategy", "fail", "strategy used when a conflicting relationship is found. Possible values: fail, skip, touch")
+ cmd.Flags().Bool("disable-retries", false, "retries when an errors is determined to be retryable (e.g. serialization errors)")
+ cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
+ cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
+ cmd.Flags().Duration("request-timeout", 30*time.Second, "timeout for each request performed during restore")
+}
+
// registerBackupCreateFlags registers the flags shared by "zed backup" and
// "zed backup create".
func registerBackupCreateFlags(cmd *cobra.Command) {
	cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
	cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
	cmd.Flags().Uint32("page-limit", 0, "defines the number of relationships to be read by requested page during backup")
}
+
+func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, error) {
+ if filename == "-" {
+ log.Trace().Str("filename", "- (stdout)").Send()
+ return os.Stdout, false, nil
+ }
+
+ log.Trace().Str("filename", filename).Send()
+
+ if _, err := os.Stat(filename); err == nil {
+ if !returnIfExists {
+ return nil, false, fmt.Errorf("backup file already exists: %s", filename)
+ }
+
+ f, err := os.OpenFile(filename, os.O_RDWR|os.O_APPEND, 0o644)
+ if err != nil {
+ return nil, false, fmt.Errorf("unable to open existing backup file: %w", err)
+ }
+
+ return f, true, nil
+ }
+
+ f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644)
+ if err != nil {
+ return nil, false, fmt.Errorf("unable to create backup file: %w", err)
+ }
+
+ return f, false, nil
+}
+
var (
	// missingAllowedTypes matches relation declarations carrying a
	// "/* missing allowed types */" marker; rewriteLegacy deletes such lines.
	missingAllowedTypes = regexp.MustCompile(`(\s*)(relation)(.+)(/\* missing allowed types \*/)(.*)`)
	// shortRelations appears to match relation declarations with two-character
	// names followed by ':'; rewriteLegacy deletes them as legacy/broken
	// syntax — TODO confirm intended name-length semantics.
	shortRelations = regexp.MustCompile(`(\s*)relation [a-z][a-z0-9_]:(.+)`)
)
+
// partialPrefixMatch reports whether name lives under the prefix namespace,
// i.e. whether it begins with "<prefix>/".
func partialPrefixMatch(name, prefix string) bool {
	needle := prefix + "/"
	return len(name) >= len(needle) && name[:len(needle)] == needle
}
+
// filterSchemaDefs returns the subset of schema whose object and caveat
// definitions live under the given prefix (named "<prefix>/..."). An empty
// schema or empty prefix is returned unchanged. It is an error for the filter
// to remove every definition, or for surviving definitions to reference types
// that were filtered out (the regenerated schema is re-compiled and validated
// to catch that).
func filterSchemaDefs(schema, prefix string) (filteredSchema string, err error) {
	if schema == "" || prefix == "" {
		return schema, nil
	}

	// Validation is skipped here since the input may reference definitions
	// outside the prefix; the filtered output is validated below instead.
	compiledSchema, err := compiler.Compile(
		compiler.InputSchema{Source: "schema", SchemaString: schema},
		compiler.AllowUnprefixedObjectType(),
		compiler.SkipValidation(),
	)
	if err != nil {
		return "", fmt.Errorf("error reading schema: %w", err)
	}

	// Keep only the object and caveat definitions under the prefix namespace.
	var prefixedDefs []compiler.SchemaDefinition
	for _, def := range compiledSchema.ObjectDefinitions {
		if partialPrefixMatch(def.Name, prefix) {
			prefixedDefs = append(prefixedDefs, def)
		}
	}
	for _, def := range compiledSchema.CaveatDefinitions {
		if partialPrefixMatch(def.Name, prefix) {
			prefixedDefs = append(prefixedDefs, def)
		}
	}

	if len(prefixedDefs) == 0 {
		return "", errors.New("filtered all definitions from schema")
	}

	filteredSchema, _, err = generator.GenerateSchema(prefixedDefs)
	if err != nil {
		return "", fmt.Errorf("error generating filtered schema: %w", err)
	}

	// Validate that the type system for the generated schema is comprehensive.
	compiledFilteredSchema, err := compiler.Compile(
		compiler.InputSchema{Source: "generated-schema", SchemaString: filteredSchema},
		compiler.AllowUnprefixedObjectType(),
	)
	if err != nil {
		return "", fmt.Errorf("generated invalid schema: %w", err)
	}

	for _, rawDef := range compiledFilteredSchema.ObjectDefinitions {
		ts := schemapkg.NewTypeSystem(schemapkg.ResolverForCompiledSchema(*compiledFilteredSchema))
		def, err := schemapkg.NewDefinition(ts, rawDef)
		if err != nil {
			return "", fmt.Errorf("generated invalid schema: %w", err)
		}
		if _, err := def.Validate(context.Background()); err != nil {
			return "", fmt.Errorf("generated invalid schema: %w", err)
		}
	}

	// Bare return of the named results: (filteredSchema, nil) on success.
	return
}
+
+func hasRelPrefix(rel *v1.Relationship, prefix string) bool {
+ // Skip any relationships without the prefix on both sides.
+ return strings.HasPrefix(rel.Resource.ObjectType, prefix) &&
+ strings.HasPrefix(rel.Subject.Object.ObjectType, prefix)
+}
+
// backupCreateCmdFunc creates (or resumes) a backup of the connected
// permissions system: the schema plus a bulk export of relationships, written
// to a backup file ("-" writes to stdout).
//
// Resumption uses a sidecar ".lock" progress file holding the last bulk-export
// cursor; when both the backup file and a non-empty progress file exist, the
// export continues from that cursor instead of starting over.
func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
	prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter")
	pageLimit := cobrautil.MustGetUint32(cmd, "page-limit")

	backupFileName, err := computeBackupFileName(cmd, args)
	if err != nil {
		return err
	}

	backupFile, backupExists, err := createBackupFile(backupFileName, returnIfExists)
	if err != nil {
		return err
	}

	// Flush and close the backup file, folding failures into the named error
	// result.
	defer func(e *error) {
		*e = errors.Join(*e, backupFile.Sync())
		*e = errors.Join(*e, backupFile.Close())
	}(&err)

	// the goal of this file is to keep the bulk export cursor in case the process is terminated
	// and we need to resume from where we left off. OCF does not support in-place record updates.
	progressFile, cursor, err := openProgressFile(backupFileName, backupExists)
	if err != nil {
		return err
	}

	var backupCompleted bool
	defer func(e *error) {
		*e = errors.Join(*e, progressFile.Sync())
		*e = errors.Join(*e, progressFile.Close())

		// Only delete the progress marker after a clean finish; otherwise it
		// is still needed to resume.
		if backupCompleted {
			if err := os.Remove(progressFile.Name()); err != nil {
				log.Warn().
					Str("progress-file", progressFile.Name()).
					Msg("failed to remove progress file, consider removing it manually")
			}
		}
	}(&err)

	c, err := client.NewClient(cmd)
	if err != nil {
		return fmt.Errorf("unable to initialize client: %w", err)
	}

	// New backups get a fresh encoder and the schema's read revision; resumed
	// backups reuse the existing file's encoder (revision lives in the cursor).
	var zedToken *v1.ZedToken
	var encoder *backupformat.Encoder
	if backupExists {
		encoder, err = backupformat.NewEncoderForExisting(backupFile)
		if err != nil {
			return fmt.Errorf("error creating backup file encoder: %w", err)
		}
	} else {
		encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile)
		if err != nil {
			return err
		}
	}

	defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err)

	// A resumable backup needs at least one of: a revision (new backup) or a
	// stored cursor (resumed backup).
	if zedToken == nil && cursor == nil {
		return errors.New("malformed existing backup, consider recreating it")
	}

	req := &v1.BulkExportRelationshipsRequest{
		OptionalLimit:  pageLimit,
		OptionalCursor: cursor,
	}

	// if a cursor is present, zedtoken is not needed (it is already in the cursor)
	if zedToken != nil {
		req.Consistency = &v1.Consistency{
			Requirement: &v1.Consistency_AtExactSnapshot{
				AtExactSnapshot: zedToken,
			},
		}
	}

	ctx := cmd.Context()
	relationshipStream, err := c.BulkExportRelationships(ctx, req)
	if err != nil {
		return fmt.Errorf("error exporting relationships: %w", err)
	}

	relationshipReadStart := time.Now()
	tick := time.Tick(5 * time.Second)
	bar := console.CreateProgressBar("processing backup")
	var relsFilteredOut, relsProcessed uint64
	// Final progress/log line; the message depends on how the loop ended.
	defer func() {
		_ = bar.Finish()

		evt := log.Info().
			Uint64("filtered", relsFilteredOut).
			Uint64("processed", relsProcessed).
			Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))).
			Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second))
		if isCanceled(err) {
			evt.Msg("backup canceled - resume by restarting the backup command")
		} else if err != nil {
			evt.Msg("backup failed")
		} else {
			evt.Msg("finished backup")
		}
	}()

	for {
		// Check for cancellation between pages.
		if err := ctx.Err(); err != nil {
			if isCanceled(err) {
				return context.Canceled
			}

			return fmt.Errorf("aborted backup: %w", err)
		}

		relsResp, err := relationshipStream.Recv()
		if err != nil {
			if isCanceled(err) {
				return context.Canceled
			}

			// io.EOF is the normal end-of-stream signal.
			if !errors.Is(err, io.EOF) {
				return fmt.Errorf("error receiving relationships: %w", err)
			}
			break
		}

		for _, rel := range relsResp.Relationships {
			if hasRelPrefix(rel, prefixFilter) {
				if err := encoder.Append(rel); err != nil {
					return fmt.Errorf("error storing relationship: %w", err)
				}
			} else {
				relsFilteredOut++
			}

			relsProcessed++
			if err := bar.Add(1); err != nil {
				return fmt.Errorf("error incrementing progress bar: %w", err)
			}

			// progress fallback in case there is no TTY
			if !isatty.IsTerminal(os.Stderr.Fd()) {
				select {
				case <-tick:
					log.Info().
						Uint64("filtered", relsFilteredOut).
						Uint64("processed", relsProcessed).
						Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))).
						Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)).
						Msg("backup progress")
				default:
				}
			}
		}

		// Persist the page's cursor so an interrupted backup can resume here.
		if err := writeProgress(progressFile, relsResp); err != nil {
			return err
		}
	}

	backupCompleted = true
	return nil
}
+
// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup
// must be taken. The token comes from the schema read, so the relationship
// export is pinned to the same snapshot as the schema.
func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) {
	prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter")

	schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{})
	if err != nil {
		return nil, nil, fmt.Errorf("error reading schema: %w", err)
	}
	// Servers that do not report the read revision cannot produce a
	// consistent backup.
	if schemaResp.ReadAt == nil {
		return nil, nil, fmt.Errorf("`backup` is not supported on this version of SpiceDB")
	}
	schema := schemaResp.SchemaText

	// Remove any invalid relations generated from old, backwards-incompat
	// Serverless permission systems.
	if cobrautil.MustGetBool(cmd, "rewrite-legacy") {
		schema = rewriteLegacy(schema)
	}

	// Skip any definitions without the provided prefix

	if prefixFilter != "" {
		schema, err = filterSchemaDefs(schema, prefixFilter)
		if err != nil {
			return nil, nil, err
		}
	}

	zedToken := schemaResp.ReadAt

	encoder, err := backupformat.NewEncoder(backupFile, schema, zedToken)
	if err != nil {
		return nil, nil, fmt.Errorf("error creating backup file encoder: %w", err)
	}

	return encoder, zedToken, nil
}
+
+func writeProgress(progressFile *os.File, relsResp *v1.BulkExportRelationshipsResponse) error {
+ err := progressFile.Truncate(0)
+ if err != nil {
+ return fmt.Errorf("unable to truncate backup progress file: %w", err)
+ }
+
+ _, err = progressFile.Seek(0, 0)
+ if err != nil {
+ return fmt.Errorf("unable to seek backup progress file: %w", err)
+ }
+
+ _, err = progressFile.WriteString(relsResp.AfterResultCursor.Token)
+ if err != nil {
+ return fmt.Errorf("unable to write result cursor to backup progress file: %w", err)
+ }
+
+ return nil
+}
+
// openProgressFile returns the progress marker file and the stored progress cursor if it exists, or creates
// a new one if it does not exist. If the backup file exists, but the progress marker does not, it will return an error.
//
// The progress marker file keeps track of the last successful cursor received from the server, and is used to resume
// backups in case of failure.
func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.File, *v1.Cursor, error) {
	var cursor *v1.Cursor
	progressFileName := toLockFileName(backupFileName)
	var progressFile *os.File
	// if a backup existed
	var fileMode int
	readCursor, err := os.ReadFile(progressFileName)
	if backupAlreadyExisted && (os.IsNotExist(err) || len(readCursor) == 0) {
		// An existing backup without a usable cursor cannot be resumed, so it
		// is treated as a hard conflict.
		return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName)
	} else if backupAlreadyExisted && err == nil {
		cursor = &v1.Cursor{
			Token: string(readCursor),
		}

		// if backup existed and there is a progress marker, the latter should not be truncated to make sure the
		// cursor stays around in case of a failure before we even start ingesting from bulk export
		fileMode = os.O_WRONLY | os.O_CREATE
		log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume")
	} else {
		// if a backup did not exist, make sure to truncate the progress file
		fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
	}

	progressFile, err = os.OpenFile(progressFileName, fileMode, 0o644)
	if err != nil {
		return nil, nil, err
	}

	return progressFile, cursor, nil
}
+
// toLockFileName derives the sidecar progress-marker ("lock") file name for a
// backup file by appending the ".lock" extension.
func toLockFileName(backupFileName string) string {
	const lockSuffix = ".lock"
	return backupFileName + lockSuffix
}
+
// computeBackupFileName determines the backup file name to use. When a
// filename argument is given, it is used verbatim; otherwise a name of the
// form "<context>.zedbackup" is derived from the current zed context and
// placed in the directory containing the zed executable.
func computeBackupFileName(cmd *cobra.Command, args []string) (string, error) {
	if len(args) > 0 {
		return args[0], nil
	}

	// Resolve the current context (honoring CLI overrides) to name the file.
	configStore, secretStore := client.DefaultStorage()
	token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
	if err != nil {
		return "", fmt.Errorf("failed to determine current zed context: %w", err)
	}

	ex, err := os.Executable()
	if err != nil {
		return "", err
	}
	exPath := filepath.Dir(ex)

	backupFileName := filepath.Join(exPath, token.Name+".zedbackup")

	return backupFileName, nil
}
+
+func isCanceled(err error) bool {
+ if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled {
+ return true
+ }
+
+ return errors.Is(err, context.Canceled)
+}
+
// openRestoreFile opens the backup file to restore from. An empty filename
// selects stdin, reported with a size of -1 since it cannot be stat'ed.
// Otherwise it returns the opened file and its size in bytes.
func openRestoreFile(filename string) (*os.File, int64, error) {
	if filename == "" {
		log.Trace().Str("filename", "(stdin)").Send()
		return os.Stdin, -1, nil
	}

	log.Trace().Str("filename", filename).Send()

	// Stat first so a missing/unreadable file yields the stat error message.
	stats, err := os.Stat(filename)
	if err != nil {
		return nil, 0, fmt.Errorf("unable to stat restore file: %w", err)
	}

	f, err := os.Open(filename)
	if err != nil {
		return nil, 0, fmt.Errorf("unable to open restore file: %w", err)
	}

	return f, stats.Size(), nil
}
+
+func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error {
+ decoder, closer, err := decoderFromArgs(args...)
+ if err != nil {
+ return err
+ }
+
+ defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+ defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+ if loadedToken := decoder.ZedToken(); loadedToken != nil {
+ log.Debug().Str("revision", loadedToken.Token).Msg("parsed revision")
+ }
+
+ schema := decoder.Schema()
+
+ // Remove any invalid relations generated from old, backwards-incompat
+ // Serverless permission systems.
+ if cobrautil.MustGetBool(cmd, "rewrite-legacy") {
+ schema = rewriteLegacy(schema)
+ }
+
+ // Skip any definitions without the provided prefix
+ prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter")
+ if prefixFilter != "" {
+ schema, err = filterSchemaDefs(schema, prefixFilter)
+ if err != nil {
+ return err
+ }
+ }
+ log.Debug().Str("schema", schema).Bool("filtered", prefixFilter != "").Msg("parsed schema")
+
+ c, err := client.NewClient(cmd)
+ if err != nil {
+ return fmt.Errorf("unable to initialize client: %w", err)
+ }
+
+ batchSize := cobrautil.MustGetUint(cmd, "batch-size")
+ batchesPerTransaction := cobrautil.MustGetUint(cmd, "batches-per-transaction")
+
+ strategy, err := GetEnum[ConflictStrategy](cmd, "conflict-strategy", conflictStrategyMapping)
+ if err != nil {
+ return err
+ }
+ disableRetries := cobrautil.MustGetBool(cmd, "disable-retries")
+ requestTimeout := cobrautil.MustGetDuration(cmd, "request-timeout")
+
+ return newRestorer(schema, decoder, c, prefixFilter, batchSize, batchesPerTransaction, strategy,
+ disableRetries, requestTimeout).restoreFromDecoder(cmd.Context())
+}
+
+// GetEnum is a helper for getting an enum value from a string cobra flag.
+func GetEnum[E constraints.Integer](cmd *cobra.Command, name string, mapping map[string]E) (E, error) {
+ value := cobrautil.MustGetString(cmd, name)
+ value = strings.TrimSpace(strings.ToLower(value))
+ if enum, ok := mapping[value]; ok {
+ return enum, nil
+ }
+
+ var zeroValueE E
+ return zeroValueE, fmt.Errorf("unexpected flag '%s' value '%s': should be one of %v", name, value, maps.Keys(mapping))
+}
+
+func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
+ decoder, closer, err := decoderFromArgs(args...)
+ if err != nil {
+ return err
+ }
+
+ defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+ defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+ schema := decoder.Schema()
+
+ // Remove any invalid relations generated from old, backwards-incompat
+ // Serverless permission systems.
+ if cobrautil.MustGetBool(cmd, "rewrite-legacy") {
+ schema = rewriteLegacy(schema)
+ }
+
+ // Skip any definitions without the provided prefix
+ if prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter"); prefixFilter != "" {
+ schema, err = filterSchemaDefs(schema, prefixFilter)
+ if err != nil {
+ return err
+ }
+ }
+
+ _, err = fmt.Fprintln(out, schema)
+ return err
+}
+
+func backupParseRevisionCmdFunc(_ *cobra.Command, out io.Writer, args []string) error {
+ decoder, closer, err := decoderFromArgs(args...)
+ if err != nil {
+ return err
+ }
+
+ defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+ defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+ loadedToken := decoder.ZedToken()
+ if loadedToken == nil {
+ return fmt.Errorf("failed to parse decoded revision")
+ }
+
+ _, err = fmt.Fprintln(out, loadedToken.Token)
+ return err
+}
+
+func backupRedactCmdFunc(cmd *cobra.Command, args []string) error {
+ decoder, closer, err := decoderFromArgs(args...)
+ if err != nil {
+ return fmt.Errorf("error creating restore file decoder: %w", err)
+ }
+
+ defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+ defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+ filename := args[0] + ".redacted"
+ writer, _, err := createBackupFile(filename, doNotReturnIfExists)
+ if err != nil {
+ return err
+ }
+
+ defer func(e *error) { *e = errors.Join(*e, writer.Close()) }(&err)
+
+ redactor, err := backupformat.NewRedactor(decoder, writer, backupformat.RedactionOptions{
+ RedactDefinitions: cobrautil.MustGetBool(cmd, "redact-definitions"),
+ RedactRelations: cobrautil.MustGetBool(cmd, "redact-relations"),
+ RedactObjectIDs: cobrautil.MustGetBool(cmd, "redact-object-ids"),
+ })
+ if err != nil {
+ return fmt.Errorf("error creating redactor: %w", err)
+ }
+
+ defer func(e *error) { *e = errors.Join(*e, redactor.Close()) }(&err)
+ bar := console.CreateProgressBar("redacting backup")
+ var written int64
+ for {
+ if err := cmd.Context().Err(); err != nil {
+ return fmt.Errorf("aborted redaction: %w", err)
+ }
+
+ err := redactor.Next()
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ return fmt.Errorf("error redacting: %w", err)
+ }
+
+ written++
+ if err := bar.Set64(written); err != nil {
+ return fmt.Errorf("error incrementing progress bar: %w", err)
+ }
+ }
+
+ if err := bar.Finish(); err != nil {
+ return fmt.Errorf("error finalizing progress bar: %w", err)
+ }
+
+ fmt.Println("Redaction map:")
+ fmt.Println("--------------")
+ fmt.Println()
+
+ // Draw a table of definitions, caveats and relations mapped.
+ tbl := table.New("Definition Name", "Redacted Name")
+ for k, v := range redactor.RedactionMap().Definitions {
+ tbl.AddRow(k, v)
+ }
+
+ tbl.Print()
+ fmt.Println()
+
+ if len(redactor.RedactionMap().Caveats) > 0 {
+ tbl = table.New("Caveat Name", "Redacted Name")
+ for k, v := range redactor.RedactionMap().Caveats {
+ tbl.AddRow(k, v)
+ }
+ tbl.Print()
+ fmt.Println()
+ }
+
+ tbl = table.New("Relation/Permission Name", "Redacted Name")
+ for k, v := range redactor.RedactionMap().Relations {
+ tbl.AddRow(k, v)
+ }
+ tbl.Print()
+ fmt.Println()
+
+ if len(redactor.RedactionMap().ObjectIDs) > 0 && cobrautil.MustGetBool(cmd, "print-redacted-object-ids") {
+ tbl = table.New("Object ID", "Redacted Object ID")
+ for k, v := range redactor.RedactionMap().ObjectIDs {
+ tbl.AddRow(k, v)
+ }
+ tbl.Print()
+ fmt.Println()
+ }
+
+ return nil
+}
+
+func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
+ prefix := cobrautil.MustGetString(cmd, "prefix-filter")
+ decoder, closer, err := decoderFromArgs(args...)
+ if err != nil {
+ return err
+ }
+
+ defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err)
+ defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err)
+
+ for rel, err := decoder.Next(); rel != nil && err == nil; rel, err = decoder.Next() {
+ if !hasRelPrefix(rel, prefix) {
+ continue
+ }
+
+ relString, err := tuple.V1StringRelationship(rel)
+ if err != nil {
+ return err
+ }
+
+ if _, err = fmt.Fprintln(out, replaceRelString(relString)); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func decoderFromArgs(args ...string) (*backupformat.Decoder, io.Closer, error) {
+ filename := "" // Default to stdin.
+ if len(args) > 0 {
+ filename = args[0]
+ }
+
+ f, _, err := openRestoreFile(filename)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ decoder, err := backupformat.NewDecoder(f)
+ if err != nil {
+ return nil, nil, fmt.Errorf("error creating restore file decoder: %w", err)
+ }
+
+ return decoder, f, nil
+}
+
// replaceRelString converts a relationship from its "resource#relation@subject"
// wire form to zed's space-separated textual form: the first '@' and then the
// first '#' are each replaced with a single space.
func replaceRelString(rel string) string {
	out := strings.Replace(rel, "@", " ", 1)
	out = strings.Replace(out, "#", " ", 1)
	return out
}
+
+func rewriteLegacy(schema string) string {
+ schema = string(missingAllowedTypes.ReplaceAll([]byte(schema), []byte("\n/* deleted missing allowed type error */")))
+ return string(shortRelations.ReplaceAll([]byte(schema), []byte("\n/* deleted short relation name */")))
+}
+
+var sizeErrorRegEx = regexp.MustCompile(`received message larger than max \((\d+) vs. (\d+)\)`)
+
+func addSizeErrInfo(err error) error {
+ if err == nil {
+ return nil
+ }
+
+ code := status.Code(err)
+ if code != codes.ResourceExhausted {
+ return err
+ }
+
+ if !strings.Contains(err.Error(), "received message larger than max") {
+ return err
+ }
+
+ matches := sizeErrorRegEx.FindStringSubmatch(err.Error())
+ if len(matches) != 3 {
+ return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err)
+ }
+
+ necessaryByteCount, atoiErr := strconv.Atoi(matches[1])
+ if atoiErr != nil {
+ return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err)
+ }
+
+ return fmt.Errorf("%w: set flag --max-message-size=%d to increase the maximum allowable size", err, 2*necessaryByteCount)
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/cmd.go b/vendor/github.com/authzed/zed/internal/cmd/cmd.go
new file mode 100644
index 0000000..a2e5332
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/cmd.go
@@ -0,0 +1,187 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "io"
+ "os"
+ "os/signal"
+ "strings"
+ "syscall"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/cobrautil/v2/cobrazerolog"
+ "github.com/mattn/go-isatty"
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+
+ "github.com/authzed/zed/internal/commands"
+)
+
var (
	// SyncFlagsCmdFunc propagates values from the ZED_* environment (via
	// viper) into cobra flags before a command runs.
	SyncFlagsCmdFunc = cobrautil.SyncViperPreRunE("ZED")
	// errParsing is a sentinel returned when flag/argument parsing fails;
	// handleError uses it to suppress duplicate log output.
	errParsing = errors.New("parsing error")
)
+
func init() {
	// NOTE: this is mostly to set up logging in the case where
	// the command doesn't exist or the construction of the command
	// errors out before the PersistentPreRunE setup in the below function.
	// It helps keep log output visually consistent for a user even in
	// exceptional cases.
	var output io.Writer

	// Pretty-print when stdout is an interactive terminal; otherwise emit
	// plain machine-readable output. The log stream itself always goes to
	// stderr.
	if isatty.IsTerminal(os.Stdout.Fd()) {
		output = zerolog.ConsoleWriter{Out: os.Stderr}
	} else {
		output = os.Stderr
	}

	l := zerolog.New(output).With().Timestamp().Logger()

	// Replace zerolog's package-level default logger.
	log.Logger = l
}
+
// flagError is the handler installed via SetFlagErrorFunc; declared as a
// variable so it can be invoked indirectly from handleError as well.
var flagError = flagErrorFunc

// flagErrorFunc prints the parse error followed by the command's usage text
// and returns the errParsing sentinel so the failure is not logged twice.
func flagErrorFunc(cmd *cobra.Command, err error) error {
	cmd.Println(err)
	cmd.Println(cmd.UsageString())
	return errParsing
}
+
// InitialiseRootCmd builds the zed root command: persistent connection and
// logging flags, the version command, root-level aliases, CLI-only commands
// and the commands shared with other tools. It is exported so documentation
// can be generated for zed.
func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command {
	rootCmd := &cobra.Command{
		Use:   "zed",
		Short: "SpiceDB CLI, built by AuthZed",
		Long:  "A command-line client for managing SpiceDB clusters.",
		PersistentPreRunE: cobrautil.CommandStack(
			zl.RunE(),
			SyncFlagsCmdFunc,
			commands.InjectRequestID,
		),
		SilenceErrors: true,
		SilenceUsage:  true,
	}
	// Route flag-parse failures through flagError so they return errParsing.
	rootCmd.SetFlagErrorFunc(func(command *cobra.Command, err error) error {
		return flagError(command, err)
	})

	zl.RegisterFlags(rootCmd.PersistentFlags())

	// Connection/behavior flags shared by every subcommand.
	rootCmd.PersistentFlags().String("endpoint", "", "spicedb gRPC API endpoint")
	rootCmd.PersistentFlags().String("permissions-system", "", "permissions system to query")
	rootCmd.PersistentFlags().String("hostname-override", "", "override the hostname used in the connection to the endpoint")
	rootCmd.PersistentFlags().String("token", "", "token used to authenticate to SpiceDB")
	rootCmd.PersistentFlags().String("certificate-path", "", "path to certificate authority used to verify secure connections")
	rootCmd.PersistentFlags().Bool("insecure", false, "connect over a plaintext connection")
	rootCmd.PersistentFlags().Bool("skip-version-check", false, "if true, no version check is performed against the server")
	rootCmd.PersistentFlags().Bool("no-verify-ca", false, "do not attempt to verify the server's certificate chain and host name")
	rootCmd.PersistentFlags().Bool("debug", false, "enable debug logging")
	rootCmd.PersistentFlags().String("request-id", "", "optional id to send along with SpiceDB requests for tracing")
	rootCmd.PersistentFlags().Int("max-message-size", 0, "maximum size *in bytes* (defaults to 4_194_304 bytes ~= 4MB) of a gRPC message that can be sent or received by zed")
	rootCmd.PersistentFlags().String("proxy", "", "specify a SOCKS5 proxy address")
	rootCmd.PersistentFlags().Uint("max-retries", 10, "maximum number of sequential retries to attempt when a request fails")
	_ = rootCmd.PersistentFlags().MarkHidden("debug") // This cannot return its error.

	versionCmd := &cobra.Command{
		Use:   "version",
		Short: "Display zed and SpiceDB version information",
		RunE:  versionCmdFunc,
	}
	cobrautil.RegisterVersionFlags(versionCmd.Flags())
	versionCmd.Flags().Bool("include-remote-version", true, "whether to display the version of Authzed or SpiceDB for the current context")
	rootCmd.AddCommand(versionCmd)

	// Register root-level aliases
	rootCmd.AddCommand(&cobra.Command{
		Use:               "use <context>",
		Short:             "Alias for `zed context use`",
		Args:              commands.ValidationWrapper(cobra.MaximumNArgs(1)),
		RunE:              contextUseCmdFunc,
		ValidArgsFunction: ContextGet,
	})

	// Register CLI-only commands.
	registerContextCmd(rootCmd)
	registerImportCmd(rootCmd)
	registerValidateCmd(rootCmd)
	registerBackupCmd(rootCmd)
	registerPreviewCmd(rootCmd)

	// Register shared commands.
	commands.RegisterPermissionCmd(rootCmd)

	relCmd := commands.RegisterRelationshipCmd(rootCmd)

	commands.RegisterWatchCmd(rootCmd)
	commands.RegisterWatchRelationshipCmd(relCmd)

	schemaCmd := commands.RegisterSchemaCmd(rootCmd)
	registerAdditionalSchemaCmds(schemaCmd)

	return rootCmd
}
+
+func Run() {
+ if err := runWithoutExit(); err != nil {
+ os.Exit(1)
+ }
+}
+
// runWithoutExit builds the root command and executes it with a context that
// is canceled on SIGINT/SIGTERM, returning the (possibly nil) execution
// error so Run can decide the process exit code.
func runWithoutExit() error {
	zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel))

	rootCmd := InitialiseRootCmd(zl)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered by 2 so a rapid second signal is not dropped while the first
	// is being handled.
	signalChan := make(chan os.Signal, 2)
	signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM)
	defer func() {
		signal.Stop(signalChan)
		cancel()
	}()

	// Cancel the command's context on the first signal; the goroutine exits
	// either way once the context is done.
	go func() {
		select {
		case <-signalChan:
			cancel()
		case <-ctx.Done():
		}
	}()

	return handleError(rootCmd, rootCmd.ExecuteContext(ctx))
}
+
+func handleError(command *cobra.Command, err error) error {
+ if err == nil {
+ return nil
+ }
+ // this snippet of code is taken from Command.ExecuteC in order to determine the command that was ultimately
+ // parsed. This is necessary to be able to print the proper command-specific usage
+ var findErr error
+ var cmdToExecute *cobra.Command
+ args := os.Args[1:]
+ if command.TraverseChildren {
+ cmdToExecute, _, findErr = command.Traverse(args)
+ } else {
+ cmdToExecute, _, findErr = command.Find(args)
+ }
+ if findErr != nil {
+ cmdToExecute = command
+ }
+
+ if errors.Is(err, commands.ValidationError{}) {
+ _ = flagError(cmdToExecute, err)
+ } else if err != nil && strings.Contains(err.Error(), "unknown command") {
+ _ = flagError(cmdToExecute, err)
+ } else if !errors.Is(err, errParsing) {
+ log.Err(err).Msg("terminated with errors")
+ }
+
+ return err
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/context.go b/vendor/github.com/authzed/zed/internal/cmd/context.go
new file mode 100644
index 0000000..d5a0eec
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/context.go
@@ -0,0 +1,202 @@
+package cmd
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/stringz"
+ "github.com/spf13/cobra"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/internal/printers"
+ "github.com/authzed/zed/internal/storage"
+)
+
// registerContextCmd attaches the `zed context` command tree (list, set,
// remove, use) to the root command.
func registerContextCmd(rootCmd *cobra.Command) {
	rootCmd.AddCommand(contextCmd)

	contextCmd.AddCommand(contextListCmd)
	// Tokens are redacted in `context list` output unless explicitly revealed.
	contextListCmd.Flags().Bool("reveal-tokens", false, "display secrets in results")

	contextCmd.AddCommand(contextSetCmd)
	contextCmd.AddCommand(contextRemoveCmd)
	contextCmd.AddCommand(contextUseCmd)
}
+
// contextCmd groups all context-management subcommands.
var contextCmd = &cobra.Command{
	Use:     "context <subcommand>",
	Short:   "Manage configurations for connecting to SpiceDB deployments",
	Aliases: []string{"ctx"},
}

// contextListCmd prints every stored context in a table.
var contextListCmd = &cobra.Command{
	Use:               "list",
	Short:             "Lists all available contexts",
	Aliases:           []string{"ls"},
	Args:              commands.ValidationWrapper(cobra.ExactArgs(0)),
	ValidArgsFunction: cobra.NoFileCompletions,
	RunE:              contextListCmdFunc,
}

// contextSetCmd creates or overwrites a named context and makes it current.
var contextSetCmd = &cobra.Command{
	Use:               "set <name> <endpoint> <api-token>",
	Short:             "Creates or overwrite a context",
	Args:              commands.ValidationWrapper(cobra.ExactArgs(3)),
	ValidArgsFunction: cobra.NoFileCompletions,
	RunE:              contextSetCmdFunc,
}

// contextRemoveCmd deletes a stored context.
var contextRemoveCmd = &cobra.Command{
	Use:               "remove <system>",
	Short:             "Removes a context",
	Aliases:           []string{"rm"},
	Args:              commands.ValidationWrapper(cobra.ExactArgs(1)),
	ValidArgsFunction: ContextGet,
	RunE:              contextRemoveCmdFunc,
}

// contextUseCmd switches the current context, or prints it when called with
// no arguments.
var contextUseCmd = &cobra.Command{
	Use:               "use <system>",
	Short:             "Sets a context as the current context",
	Args:              commands.ValidationWrapper(cobra.MaximumNArgs(1)),
	ValidArgsFunction: ContextGet,
	RunE:              contextUseCmdFunc,
}
+
+func ContextGet(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) {
+ _, secretStore := client.DefaultStorage()
+ secrets, err := secretStore.Get()
+ if err != nil {
+ return nil, cobra.ShellCompDirectiveError
+ }
+
+ names := make([]string, 0, len(secrets.Tokens))
+ for _, token := range secrets.Tokens {
+ names = append(names, token.Name)
+ }
+
+ return names, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace | cobra.ShellCompDirectiveKeepOrder
+}
+
+func contextListCmdFunc(cmd *cobra.Command, _ []string) error {
+ cfgStore, secretStore := client.DefaultStorage()
+ secrets, err := secretStore.Get()
+ if err != nil {
+ return err
+ }
+
+ cfg, err := cfgStore.Get()
+ if err != nil {
+ return err
+ }
+
+ rows := make([][]string, 0, len(secrets.Tokens))
+ for _, token := range secrets.Tokens {
+ current := ""
+ if token.Name == cfg.CurrentToken {
+ current = " ✓ "
+ }
+ secret := token.APIToken
+ if !cobrautil.MustGetBool(cmd, "reveal-tokens") {
+ secret = token.Redacted()
+ }
+
+ var certStr string
+ if token.IsInsecure() {
+ certStr = "insecure"
+ } else if token.HasNoVerifyCA() {
+ certStr = "no-verify-ca"
+ } else if _, ok := token.Certificate(); ok {
+ certStr = "custom"
+ } else {
+ certStr = "system"
+ }
+
+ rows = append(rows, []string{
+ current,
+ token.Name,
+ token.Endpoint,
+ secret,
+ certStr,
+ })
+ }
+
+ printers.PrintTable(os.Stdout, []string{"current", "name", "endpoint", "token", "tls cert"}, rows)
+
+ return nil
+}
+
+func contextSetCmdFunc(cmd *cobra.Command, args []string) error {
+ var name, endpoint, apiToken string
+ err := stringz.Unpack(args, &name, &endpoint, &apiToken)
+ if err != nil {
+ return err
+ }
+
+ certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path")
+ var certBytes []byte
+ if certPath != "" {
+ certBytes, err = os.ReadFile(certPath)
+ if err != nil {
+ return fmt.Errorf("failed to read ceritficate: %w", err)
+ }
+ }
+
+ insecure := cobrautil.MustGetBool(cmd, "insecure")
+ noVerifyCA := cobrautil.MustGetBool(cmd, "no-verify-ca")
+ cfgStore, secretStore := client.DefaultStorage()
+ err = storage.PutToken(storage.Token{
+ Name: name,
+ Endpoint: stringz.DefaultEmpty(endpoint, "grpc.authzed.com:443"),
+ APIToken: apiToken,
+ Insecure: &insecure,
+ NoVerifyCA: &noVerifyCA,
+ CACert: certBytes,
+ }, secretStore)
+ if err != nil {
+ return err
+ }
+
+ return storage.SetCurrentToken(name, cfgStore, secretStore)
+}
+
+func contextRemoveCmdFunc(_ *cobra.Command, args []string) error {
+ // If the token is what's currently being used, remove it from the config.
+ cfgStore, secretStore := client.DefaultStorage()
+ cfg, err := cfgStore.Get()
+ if err != nil {
+ return err
+ }
+
+ if cfg.CurrentToken == args[0] {
+ cfg.CurrentToken = ""
+ }
+
+ err = cfgStore.Put(cfg)
+ if err != nil {
+ return err
+ }
+
+ return storage.RemoveToken(args[0], secretStore)
+}
+
+func contextUseCmdFunc(_ *cobra.Command, args []string) error {
+ cfgStore, secretStore := client.DefaultStorage()
+ switch len(args) {
+ case 0:
+ cfg, err := cfgStore.Get()
+ if err != nil {
+ return err
+ }
+ console.Println(cfg.CurrentToken)
+ case 1:
+ return storage.SetCurrentToken(args[0], cfgStore, secretStore)
+ default:
+ panic("cobra command did not enforce valid number of args")
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/import.go b/vendor/github.com/authzed/zed/internal/cmd/import.go
new file mode 100644
index 0000000..687d42e
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/import.go
@@ -0,0 +1,181 @@
+package cmd
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "net/url"
+ "strings"
+
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/tuple"
+ "github.com/authzed/spicedb/pkg/validationfile"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/decode"
+ "github.com/authzed/zed/internal/grpcutil"
+)
+
// registerImportCmd attaches the `zed import` command and its batching and
// content-selection flags to the root command.
func registerImportCmd(rootCmd *cobra.Command) {
	rootCmd.AddCommand(importCmd)
	importCmd.Flags().Int("batch-size", 1000, "import batch size")
	importCmd.Flags().Int("workers", 1, "number of concurrent batching workers")
	importCmd.Flags().Bool("schema", true, "import schema")
	importCmd.Flags().Bool("relationships", true, "import relationships")
	importCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before importing")
}
+
// importCmd implements `zed import`, which loads a schema and/or
// relationships into the current permissions system from a local file or a
// supported URL (gist, playground, pastebin, devtools download).
var importCmd = &cobra.Command{
	Use:   "import <url>",
	Short: "Imports schema and relationships from a file or url",
	Example: `
	From a gist:
	zed import https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6

	From a playground link:
	zed import https://play.authzed.com/s/iksdFvCtvnkR/schema

	From pastebin:
	zed import https://pastebin.com/8qU45rVK

	From a devtools instance:
	zed import https://localhost:8443/download

	From a local file (with prefix):
	zed import file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml

	From a local file (no prefix):
	zed import authzed-x7izWU8_2Gw3.yaml

	Only schema:
	zed import --relationships=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml

	Only relationships:
	zed import --schema=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml

	With schema definition prefix:
	zed import --schema-definition-prefix=mypermsystem file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
`,
	Args: commands.ValidationWrapper(cobra.ExactArgs(1)),
	RunE: importCmdFunc,
}
+
+func importCmdFunc(cmd *cobra.Command, args []string) error {
+ client, err := client.NewClient(cmd)
+ if err != nil {
+ return err
+ }
+ u, err := url.Parse(args[0])
+ if err != nil {
+ return err
+ }
+
+ decoder, err := decode.DecoderForURL(u)
+ if err != nil {
+ return err
+ }
+ var p validationfile.ValidationFile
+ if _, _, err := decoder(&p); err != nil {
+ return err
+ }
+
+ prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil)
+ if err != nil {
+ return err
+ }
+
+ if cobrautil.MustGetBool(cmd, "schema") {
+ if err := importSchema(cmd.Context(), client, p.Schema.Schema, prefix); err != nil {
+ return err
+ }
+ }
+
+ if cobrautil.MustGetBool(cmd, "relationships") {
+ batchSize := cobrautil.MustGetInt(cmd, "batch-size")
+ workers := cobrautil.MustGetInt(cmd, "workers")
+ if err := importRelationships(cmd.Context(), client, p.Relationships.RelationshipsString, prefix, batchSize, workers); err != nil {
+ return err
+ }
+ }
+
+ return err
+}
+
+func importSchema(ctx context.Context, client client.Client, schema string, definitionPrefix string) error {
+ log.Info().Msg("importing schema")
+
+ // Recompile the schema with the specified prefix.
+ schemaText, err := rewriteSchema(schema, definitionPrefix)
+ if err != nil {
+ return err
+ }
+
+ // Write the recompiled and regenerated schema.
+ request := &v1.WriteSchemaRequest{Schema: schemaText}
+ log.Trace().Interface("request", request).Str("schema", schemaText).Msg("writing schema")
+
+ if _, err := client.WriteSchema(ctx, request); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func importRelationships(ctx context.Context, client client.Client, relationships string, definitionPrefix string, batchSize int, workers int) error {
+ relationshipUpdates := make([]*v1.RelationshipUpdate, 0)
+ scanner := bufio.NewScanner(strings.NewReader(relationships))
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" {
+ continue
+ }
+ if strings.HasPrefix(line, "//") {
+ continue
+ }
+ rel, err := tuple.ParseV1Rel(line)
+ if err != nil {
+ return fmt.Errorf("failed to parse %s as relationship", line)
+ }
+ log.Trace().Str("line", line).Send()
+
+ // Rewrite the prefix on the references, if any.
+ if len(definitionPrefix) > 0 {
+ rel.Resource.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Resource.ObjectType)
+ rel.Subject.Object.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Subject.Object.ObjectType)
+ }
+
+ relationshipUpdates = append(relationshipUpdates, &v1.RelationshipUpdate{
+ Operation: v1.RelationshipUpdate_OPERATION_TOUCH,
+ Relationship: rel,
+ })
+ }
+ if err := scanner.Err(); err != nil {
+ return err
+ }
+
+ log.Info().
+ Int("batch_size", batchSize).
+ Int("workers", workers).
+ Int("count", len(relationshipUpdates)).
+ Msg("importing relationships")
+
+ err := grpcutil.ConcurrentBatch(ctx, len(relationshipUpdates), batchSize, workers, func(ctx context.Context, no int, start int, end int) error {
+ request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates[start:end]}
+ _, err := client.WriteRelationships(ctx, request)
+ if err != nil {
+ return err
+ }
+
+ log.Info().
+ Int("batch_no", no).
+ Int("count", len(relationshipUpdates[start:end])).
+ Msg("imported relationships")
+ return nil
+ })
+ return err
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/preview.go b/vendor/github.com/authzed/zed/internal/cmd/preview.go
new file mode 100644
index 0000000..279e310
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/preview.go
@@ -0,0 +1,120 @@
+package cmd
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "golang.org/x/term"
+
+ newcompiler "github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
+ newinput "github.com/authzed/spicedb/pkg/composableschemadsl/input"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/generator"
+
+ "github.com/authzed/zed/internal/commands"
+)
+
// registerPreviewCmd attaches the `zed preview` command tree (currently only
// `preview schema compile`) to the root command.
func registerPreviewCmd(rootCmd *cobra.Command) {
	rootCmd.AddCommand(previewCmd)

	previewCmd.AddCommand(schemaCmd)

	schemaCmd.AddCommand(schemaCompileCmd)
	schemaCompileCmd.Flags().String("out", "", "output filepath; omitting writes to stdout")
}
+
// previewCmd is the umbrella for experimental, subject-to-change commands.
var previewCmd = &cobra.Command{
	Use:   "preview <subcommand>",
	Short: "Experimental commands that have been made available for preview",
}

// schemaCmd groups preview schema subcommands (distinct from the shared
// top-level `zed schema` command).
var schemaCmd = &cobra.Command{
	Use:   "schema <subcommand>",
	Short: "Manage schema for a permissions system",
}

// schemaCompileCmd compiles a composable-syntax schema into a standard one.
var schemaCompileCmd = &cobra.Command{
	Use:   "compile <file>",
	Args:  commands.ValidationWrapper(cobra.ExactArgs(1)),
	Short: "Compile a schema that uses extended syntax into one that can be written to SpiceDB",
	Example: `
	Write to stdout:
	zed preview schema compile root.zed
	Write to an output file:
	zed preview schema compile root.zed --out compiled.zed
	`,
	ValidArgsFunction: commands.FileExtensionCompletions("zed"),
	RunE:              schemaCompileCmdFunc,
}
+
+// Compiles an input schema written in the new composable schema syntax
+// and produces it as a fully-realized schema
+func schemaCompileCmdFunc(cmd *cobra.Command, args []string) error {
+ stdOutFd, err := safecast.ToInt(uint(os.Stdout.Fd()))
+ if err != nil {
+ return err
+ }
+ outputFilepath := cobrautil.MustGetString(cmd, "out")
+ if outputFilepath == "" && !term.IsTerminal(stdOutFd) {
+ return fmt.Errorf("must provide stdout or output file path")
+ }
+
+ inputFilepath := args[0]
+ inputSourceFolder := filepath.Dir(inputFilepath)
+ var schemaBytes []byte
+ schemaBytes, err = os.ReadFile(inputFilepath)
+ if err != nil {
+ return fmt.Errorf("failed to read schema file: %w", err)
+ }
+ log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file")
+
+ if len(schemaBytes) == 0 {
+ return errors.New("attempted to compile empty schema")
+ }
+
+ compiled, err := newcompiler.Compile(newcompiler.InputSchema{
+ Source: newinput.Source(inputFilepath),
+ SchemaString: string(schemaBytes),
+ }, newcompiler.AllowUnprefixedObjectType(),
+ newcompiler.SourceFolder(inputSourceFolder))
+ if err != nil {
+ return err
+ }
+
+ // Attempt to cast one kind of OrderedDefinition to another
+ oldDefinitions := make([]compiler.SchemaDefinition, 0, len(compiled.OrderedDefinitions))
+ for _, definition := range compiled.OrderedDefinitions {
+ oldDefinition, ok := definition.(compiler.SchemaDefinition)
+ if !ok {
+ return fmt.Errorf("could not convert definition to old schemadefinition: %v", oldDefinition)
+ }
+ oldDefinitions = append(oldDefinitions, oldDefinition)
+ }
+
+ // This is where we functionally assert that the two systems are compatible
+ generated, _, err := generator.GenerateSchema(oldDefinitions)
+ if err != nil {
+ return fmt.Errorf("could not generate resulting schema: %w", err)
+ }
+
+ // Add a newline at the end for hygiene's sake
+ terminated := generated + "\n"
+
+ if outputFilepath == "" {
+ // Print to stdout
+ fmt.Print(terminated)
+ } else {
+ err = os.WriteFile(outputFilepath, []byte(terminated), 0o_600)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/restorer.go b/vendor/github.com/authzed/zed/internal/cmd/restorer.go
new file mode 100644
index 0000000..468064f
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/restorer.go
@@ -0,0 +1,446 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/cenkalti/backoff/v4"
+ "github.com/mattn/go-isatty"
+ "github.com/rs/zerolog/log"
+ "github.com/samber/lo"
+ "github.com/schollz/progressbar/v3"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/pkg/backupformat"
+)
+
// ConflictStrategy selects how restore handles relationships that already
// exist in the target SpiceDB instance.
type ConflictStrategy int

const (
	// Fail aborts the restore on the first duplicate relationship.
	Fail ConflictStrategy = iota
	// Skip ignores conflicting batches and continues.
	Skip
	// Touch rewrites conflicting batches with TOUCH (upsert) semantics.
	Touch

	// defaultBackoff is the initial delay between retries of a failed write.
	defaultBackoff = 50 * time.Millisecond
	// defaultMaxRetries bounds per-batch retry attempts in writeBatchesWithRetry.
	defaultMaxRetries = 10
)

// conflictStrategyMapping maps --conflict-strategy flag values onto the enum.
var conflictStrategyMapping = map[string]ConflictStrategy{
	"fail":  Fail,
	"skip":  Skip,
	"touch": Touch,
}

// Fallback for datastore implementations on SpiceDB < 1.29.0 not returning proper gRPC codes
// Remove once https://github.com/authzed/spicedb/pull/1688 lands
var (
	txConflictCodes = []string{
		"SQLSTATE 23505",     // CockroachDB
		"Error 1062 (23000)", // MySQL
	}
	retryableErrorCodes = []string{
		"retryable error",                      // CockroachDB, PostgreSQL
		"try restarting transaction", "Error 1205", // MySQL
	}
)
+
// restorer replays a backup (schema followed by relationships) into a
// SpiceDB instance, batching relationship writes over the bulk-import API
// and applying a configurable strategy when writes conflict with existing
// data.
type restorer struct {
	schema                string
	decoder               *backupformat.Decoder
	client                client.Client
	prefixFilter          string // only restore relationships matching this prefix ("" = all)
	batchSize             uint   // relationships per bulk-import request
	batchesPerTransaction uint   // batches committed per stream (one transaction)
	conflictStrategy      ConflictStrategy
	disableRetryErrors    bool // when true, retryable errors are returned instead of retried
	bar                   *progressbar.ProgressBar

	// stats
	filteredOutRels  uint
	writtenRels      uint
	writtenBatches   uint
	skippedRels      uint
	skippedBatches   uint
	duplicateRels    uint
	duplicateBatches uint
	totalRetries     uint
	requestTimeout   time.Duration // per-request timeout for retried WriteRelationships calls
}
+
+func newRestorer(schema string, decoder *backupformat.Decoder, client client.Client, prefixFilter string, batchSize uint,
+ batchesPerTransaction uint, conflictStrategy ConflictStrategy, disableRetryErrors bool,
+ requestTimeout time.Duration,
+) *restorer {
+ return &restorer{
+ decoder: decoder,
+ schema: schema,
+ client: client,
+ prefixFilter: prefixFilter,
+ requestTimeout: requestTimeout,
+ batchSize: batchSize,
+ batchesPerTransaction: batchesPerTransaction,
+ conflictStrategy: conflictStrategy,
+ disableRetryErrors: disableRetryErrors,
+ bar: console.CreateProgressBar("restoring from backup"),
+ }
+}
+
// restoreFromDecoder writes the backup's schema and then streams its
// relationships into SpiceDB via the bulk-import API. Relationships are
// grouped into batches of r.batchSize; every r.batchesPerTransaction batches
// the stream is committed (one transaction) and a new stream is opened.
// Failed sends and commits are handed to commitStream, which applies the
// configured conflict/retry strategy. Progress is reported on r.bar and a
// summary is logged on success.
func (r *restorer) restoreFromDecoder(ctx context.Context) error {
	relationshipWriteStart := time.Now()
	defer func() {
		if err := r.bar.Finish(); err != nil {
			log.Warn().Err(err).Msg("error finalizing progress bar")
		}
	}()

	r.bar.Describe("restoring schema from backup")
	if _, err := r.client.WriteSchema(ctx, &v1.WriteSchemaRequest{
		Schema: r.schema,
	}); err != nil {
		return fmt.Errorf("unable to write schema: %w", err)
	}

	relationshipWriter, err := r.client.BulkImportRelationships(ctx)
	if err != nil {
		return fmt.Errorf("error creating writer stream: %w", err)
	}

	r.bar.Describe("restoring relationships from backup")
	batch := make([]*v1.Relationship, 0, r.batchSize)
	// batchesToBeCommitted retains every batch sent on the current stream so
	// the whole transaction can be retried via WriteRelationships if the
	// commit fails.
	batchesToBeCommitted := make([][]*v1.Relationship, 0, r.batchesPerTransaction)
	for rel, err := r.decoder.Next(); rel != nil && err == nil; rel, err = r.decoder.Next() {
		if err := ctx.Err(); err != nil {
			r.bar.Describe("backup restore aborted")
			return fmt.Errorf("aborted restore: %w", err)
		}

		// Relationships outside the requested prefix are not restored.
		if !hasRelPrefix(rel, r.prefixFilter) {
			r.filteredOutRels++
			continue
		}

		batch = append(batch, rel)

		if uint(len(batch))%r.batchSize == 0 {
			batchesToBeCommitted = append(batchesToBeCommitted, batch)
			err := relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{
				Relationships: batch,
			})
			if err != nil {
				// It feels non-idiomatic to check for error and perform an operation, but in gRPC, when an element
				// sent over the stream fails, we need to call recvAndClose() to get the error.
				if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
					return fmt.Errorf("error committing batches: %w", err)
				}

				// after an error
				relationshipWriter, err = r.client.BulkImportRelationships(ctx)
				if err != nil {
					return fmt.Errorf("error creating new writer stream: %w", err)
				}

				batchesToBeCommitted = batchesToBeCommitted[:0]
				batch = batch[:0]
				continue
			}

			// The batch just sent is kept in batchesToBeCommitted, which is used for retries.
			// Therefore, we cannot reuse the batch. Batches may fail on send, or on commit (CloseAndRecv).
			batch = make([]*v1.Relationship, 0, r.batchSize)

			// if we've sent the maximum number of batches per transaction, proceed to commit
			if uint(len(batchesToBeCommitted))%r.batchesPerTransaction != 0 {
				continue
			}

			if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
				return fmt.Errorf("error committing batches: %w", err)
			}

			relationshipWriter, err = r.client.BulkImportRelationships(ctx)
			if err != nil {
				return fmt.Errorf("error creating new writer stream: %w", err)
			}

			batchesToBeCommitted = batchesToBeCommitted[:0]
		}
	}

	// Write the last batch
	if len(batch) > 0 {
		// Since we are going to close the stream anyway after the last batch, and given the actual error
		// is only returned on CloseAndRecv(), we have to ignore the error here in order to get the actual
		// underlying error that caused Send() to fail. It also gives us the opportunity to retry it
		// in case it failed.
		batchesToBeCommitted = append(batchesToBeCommitted, batch)
		_ = relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{Relationships: batch})
	}

	if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
		return fmt.Errorf("error committing last set of batches: %w", err)
	}

	r.bar.Describe("completed import")
	if err := r.bar.Finish(); err != nil {
		log.Warn().Err(err).Msg("error finalizing progress bar")
	}

	totalTime := time.Since(relationshipWriteStart)
	log.Info().
		Uint("batches", r.writtenBatches).
		Uint("relationships_loaded", r.writtenRels).
		Uint("relationships_skipped", r.skippedRels).
		Uint("duplicate_relationships", r.duplicateRels).
		Uint("relationships_filtered_out", r.filteredOutRels).
		Uint("retried_errors", r.totalRetries).
		Uint64("perSecond", perSec(uint64(r.writtenRels), totalTime)).
		Stringer("duration", totalTime).
		Msg("finished restore")
	return nil
}
+
// commitStream closes the current bulk-import stream (which commits the
// transaction) and reconciles the outcome: successful commits update the
// written counters; duplicate-relationship conflicts are skipped, retried
// with TOUCH semantics, or treated as fatal depending on
// r.conflictStrategy; retryable transport/datastore errors are re-written
// via WriteRelationships unless retries are disabled. Progress-bar state
// and (off-terminal) trace logging are updated as a side effect.
func (r *restorer) commitStream(ctx context.Context, bulkImportClient v1.ExperimentalService_BulkImportRelationshipsClient,
	batchesToBeCommitted [][]*v1.Relationship,
) error {
	var numLoaded, expectedLoaded, retries uint
	for _, b := range batchesToBeCommitted {
		expectedLoaded += uint(len(b))
	}

	resp, err := bulkImportClient.CloseAndRecv() // transaction commit happens here

	// Failure to commit transaction means the stream is closed, so it can't be reused any further
	// The retry will be done using WriteRelationships instead of BulkImportRelationships
	// This lets us retry with TOUCH semantics in case of failure due to duplicates
	retryable := isRetryableError(err)
	conflict := isAlreadyExistsError(err)
	canceled, cancelErr := isCanceledError(ctx.Err(), err)
	unknown := !retryable && !conflict && !canceled && err != nil

	numBatches := uint(len(batchesToBeCommitted))

	switch {
	case canceled:
		r.bar.Describe("backup restore aborted")
		return cancelErr
	case unknown:
		r.bar.Describe("failed with unrecoverable error")
		return fmt.Errorf("error finalizing write of %d batches: %w", len(batchesToBeCommitted), err)
	case retryable && r.disableRetryErrors:
		return err
	case conflict && r.conflictStrategy == Skip:
		r.skippedRels += expectedLoaded
		r.skippedBatches += numBatches
		r.duplicateBatches += numBatches
		r.duplicateRels += expectedLoaded
		r.bar.Describe("skipping conflicting batch")
	case conflict && r.conflictStrategy == Touch:
		// Re-write the whole transaction with TOUCH (upsert) semantics.
		r.bar.Describe("touching conflicting batch")
		r.duplicateRels += expectedLoaded
		r.duplicateBatches += numBatches
		r.totalRetries++
		numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted)
		if err != nil {
			return fmt.Errorf("failed to write retried batch: %w", err)
		}

		retries++ // account for the initial attempt
		r.writtenBatches += numBatches
		r.writtenRels += numLoaded
	case conflict && r.conflictStrategy == Fail:
		r.bar.Describe("conflict detected, aborting restore")
		return fmt.Errorf("duplicate relationships found")
	case retryable:
		r.bar.Describe("retrying after error")
		r.totalRetries++
		numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted)
		if err != nil {
			return fmt.Errorf("failed to write retried batch: %w", err)
		}

		retries++ // account for the initial attempt
		r.writtenBatches += numBatches
		r.writtenRels += numLoaded
	default:
		r.bar.Describe("restoring relationships from backup")
		r.writtenBatches += numBatches
	}

	// it was a successful transaction commit without duplicates
	if resp != nil {
		numLoaded, err := safecast.ToUint(resp.NumLoaded)
		if err != nil {
			return spiceerrors.MustBugf("could not cast numLoaded to uint")
		}
		r.writtenRels += numLoaded
		if uint64(expectedLoaded) != resp.NumLoaded {
			log.Warn().Uint64("loaded", resp.NumLoaded).Uint("expected", expectedLoaded).Msg("unexpected number of relationships loaded")
		}
	}

	writtenAndSkipped, err := safecast.ToInt64(r.writtenRels + r.skippedRels)
	if err != nil {
		return fmt.Errorf("too many written and skipped rels for an int64")
	}

	if err := r.bar.Set64(writtenAndSkipped); err != nil {
		return fmt.Errorf("error incrementing progress bar: %w", err)
	}

	// Emit textual progress when stderr is not a terminal (no live bar shown).
	if !isatty.IsTerminal(os.Stderr.Fd()) {
		log.Trace().
			Uint("batches_written", r.writtenBatches).
			Uint("relationships_written", r.writtenRels).
			Uint("duplicate_batches", r.duplicateBatches).
			Uint("duplicate_relationships", r.duplicateRels).
			Uint("skipped_batches", r.skippedBatches).
			Uint("skipped_relationships", r.skippedRels).
			Uint("retries", retries).
			Msg("restore progress")
	}

	return nil
}
+
// writeBatchesWithRetry writes a set of batches using touch semantics and without transactional guarantees -
// each batch will be committed independently. If a batch fails, it will be retried up to 10 times with a backoff.
// It returns the number of relationships written and the number of retries
// performed, or the terminal error once a batch exhausts its retry budget.
func (r *restorer) writeBatchesWithRetry(ctx context.Context, batches [][]*v1.Relationship) (uint, uint, error) {
	backoffInterval := backoff.NewExponentialBackOff()
	backoffInterval.InitialInterval = defaultBackoff
	backoffInterval.MaxInterval = 2 * time.Second
	backoffInterval.MaxElapsedTime = 0 // no elapsed-time cutoff; retries are bounded by defaultMaxRetries instead
	backoffInterval.Reset()

	var currentRetries, totalRetries, loadedRels uint
	for _, batch := range batches {
		// Convert the batch into TOUCH (upsert) updates so duplicates succeed.
		updates := lo.Map[*v1.Relationship, *v1.RelationshipUpdate](batch, func(item *v1.Relationship, _ int) *v1.RelationshipUpdate {
			return &v1.RelationshipUpdate{
				Relationship: item,
				Operation:    v1.RelationshipUpdate_OPERATION_TOUCH,
			}
		})

		for {
			cancelCtx, cancel := context.WithTimeout(ctx, r.requestTimeout)
			_, err := r.client.WriteRelationships(cancelCtx, &v1.WriteRelationshipsRequest{Updates: updates})
			cancel()

			if isRetryableError(err) && currentRetries < defaultMaxRetries {
				// throttle the writes so we don't overwhelm the server
				bo := backoffInterval.NextBackOff()
				r.bar.Describe(fmt.Sprintf("retrying write with backoff %s after error (attempt %d/%d)", bo,
					currentRetries+1, defaultMaxRetries))
				time.Sleep(bo)
				currentRetries++
				r.totalRetries++
				totalRetries++
				continue
			}
			if err != nil {
				// Non-retryable, or retry budget exhausted.
				return 0, 0, err
			}

			// Success: reset the retry budget and backoff for the next batch.
			currentRetries = 0
			backoffInterval.Reset()
			loadedRels += uint(len(batch))
			break
		}
	}

	return loadedRels, totalRetries, nil
}
+
+func isAlreadyExistsError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ if isGRPCCode(err, codes.AlreadyExists) {
+ return true
+ }
+
+ return isContainsErrorString(err, txConflictCodes...)
+}
+
+func isRetryableError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ if isGRPCCode(err, codes.Unavailable, codes.DeadlineExceeded) {
+ return true
+ }
+
+ if isContainsErrorString(err, retryableErrorCodes...) {
+ return true
+ }
+
+ return errors.Is(err, context.DeadlineExceeded)
+}
+
+func isCanceledError(errs ...error) (bool, error) {
+ for _, err := range errs {
+ if err == nil {
+ continue
+ }
+
+ if errors.Is(err, context.Canceled) {
+ return true, err
+ }
+
+ if isGRPCCode(err, codes.Canceled) {
+ return true, err
+ }
+ }
+
+ return false, nil
+}
+
+func isContainsErrorString(err error, errStrings ...string) bool {
+ if err == nil {
+ return false
+ }
+
+ for _, errString := range errStrings {
+ if strings.Contains(err.Error(), errString) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func isGRPCCode(err error, codes ...codes.Code) bool {
+ if err == nil {
+ return false
+ }
+
+ if s, ok := status.FromError(err); ok {
+ for _, code := range codes {
+ if s.Code() == code {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+func perSec(i uint64, d time.Duration) uint64 {
+ secs := uint64(d.Seconds())
+ if secs == 0 {
+ return i
+ }
+ return i / secs
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/schema.go b/vendor/github.com/authzed/zed/internal/cmd/schema.go
new file mode 100644
index 0000000..bbe52f3
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/schema.go
@@ -0,0 +1,319 @@
+package cmd
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "strings"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/jzelinskie/stringz"
+ "github.com/rs/zerolog/log"
+ "github.com/spf13/cobra"
+ "golang.org/x/term"
+
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+ "github.com/authzed/spicedb/pkg/caveats/types"
+ "github.com/authzed/spicedb/pkg/diff"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/schemadsl/generator"
+ "github.com/authzed/spicedb/pkg/schemadsl/input"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+)
+
+func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) {
+ schemaCmd.AddCommand(schemaCopyCmd)
+ schemaCopyCmd.Flags().Bool("json", false, "output as JSON")
+ schemaCopyCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing")
+
+ schemaCmd.AddCommand(schemaWriteCmd)
+ schemaWriteCmd.Flags().Bool("json", false, "output as JSON")
+ schemaWriteCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing")
+
+ schemaCmd.AddCommand(schemaDiffCmd)
+}
+
// schemaWriteCmd writes a schema, read from a .zed file argument or from
// stdin, to the currently configured permissions system.
var schemaWriteCmd = &cobra.Command{
	Use:               "write <file?>",
	Args:              commands.ValidationWrapper(cobra.MaximumNArgs(1)),
	Short:             "Write a schema file (.zed or stdin) to the current permissions system",
	ValidArgsFunction: commands.FileExtensionCompletions("zed"),
	RunE:              schemaWriteCmdFunc,
}
+
// schemaCopyCmd copies the schema of one configured zed context into
// another; definition prefixes can be rewritten via flags.
var schemaCopyCmd = &cobra.Command{
	Use:               "copy <src context> <dest context>",
	Short:             "Copy a schema from one context into another",
	Args:              commands.ValidationWrapper(cobra.ExactArgs(2)),
	ValidArgsFunction: ContextGet,
	RunE:              schemaCopyCmdFunc,
}
+
// schemaDiffCmd compiles two schema files and prints the definitions and
// caveats that differ between them.
var schemaDiffCmd = &cobra.Command{
	Use:   "diff <before file> <after file>",
	Short: "Diff two schema files",
	Args:  commands.ValidationWrapper(cobra.ExactArgs(2)),
	RunE:  schemaDiffCmdFunc,
}
+
+func schemaDiffCmdFunc(_ *cobra.Command, args []string) error {
+ beforeBytes, err := os.ReadFile(args[0])
+ if err != nil {
+ return fmt.Errorf("failed to read before schema file: %w", err)
+ }
+
+ afterBytes, err := os.ReadFile(args[1])
+ if err != nil {
+ return fmt.Errorf("failed to read after schema file: %w", err)
+ }
+
+ before, err := compiler.Compile(
+ compiler.InputSchema{Source: input.Source(args[0]), SchemaString: string(beforeBytes)},
+ compiler.AllowUnprefixedObjectType(),
+ )
+ if err != nil {
+ return err
+ }
+
+ after, err := compiler.Compile(
+ compiler.InputSchema{Source: input.Source(args[1]), SchemaString: string(afterBytes)},
+ compiler.AllowUnprefixedObjectType(),
+ )
+ if err != nil {
+ return err
+ }
+
+ dbefore := diff.NewDiffableSchemaFromCompiledSchema(before)
+ dafter := diff.NewDiffableSchemaFromCompiledSchema(after)
+
+ schemaDiff, err := diff.DiffSchemas(dbefore, dafter, types.Default.TypeSet)
+ if err != nil {
+ return err
+ }
+
+ for _, ns := range schemaDiff.AddedNamespaces {
+ console.Printf("Added definition: %s\n", ns)
+ }
+
+ for _, ns := range schemaDiff.RemovedNamespaces {
+ console.Printf("Removed definition: %s\n", ns)
+ }
+
+ for nsName, ns := range schemaDiff.ChangedNamespaces {
+ console.Printf("Changed definition: %s\n", nsName)
+ for _, delta := range ns.Deltas() {
+ console.Printf("\t %s: %s\n", delta.Type, delta.RelationName)
+ }
+ }
+
+ for _, caveat := range schemaDiff.AddedCaveats {
+ console.Printf("Added caveat: %s\n", caveat)
+ }
+
+ for _, caveat := range schemaDiff.RemovedCaveats {
+ console.Printf("Removed caveat: %s\n", caveat)
+ }
+
+ return nil
+}
+
+func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error {
+ _, secretStore := client.DefaultStorage()
+ srcClient, err := client.NewClientForContext(cmd, args[0], secretStore)
+ if err != nil {
+ return err
+ }
+
+ destClient, err := client.NewClientForContext(cmd, args[1], secretStore)
+ if err != nil {
+ return err
+ }
+
+ readRequest := &v1.ReadSchemaRequest{}
+ log.Trace().Interface("request", readRequest).Msg("requesting schema read")
+
+ readResp, err := srcClient.ReadSchema(cmd.Context(), readRequest)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to read schema")
+ }
+ log.Trace().Interface("response", readResp).Msg("read schema")
+
+ prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), nil, &readResp.SchemaText)
+ if err != nil {
+ return err
+ }
+
+ schemaText, err := rewriteSchema(readResp.SchemaText, prefix)
+ if err != nil {
+ return err
+ }
+
+ writeRequest := &v1.WriteSchemaRequest{Schema: schemaText}
+ log.Trace().Interface("request", writeRequest).Msg("writing schema")
+
+ resp, err := destClient.WriteSchema(cmd.Context(), writeRequest)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to write schema")
+ }
+ log.Trace().Interface("response", resp).Msg("wrote schema")
+
+ if cobrautil.MustGetBool(cmd, "json") {
+ prettyProto, err := commands.PrettyProto(resp)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to convert schema to JSON")
+ }
+
+ console.Println(string(prettyProto))
+ return nil
+ }
+
+ return nil
+}
+
+func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error {
+ intFd, err := safecast.ToInt(uint(os.Stdout.Fd()))
+ if err != nil {
+ return err
+ }
+ if len(args) == 0 && term.IsTerminal(intFd) {
+ return fmt.Errorf("must provide file path or contents via stdin")
+ }
+
+ client, err := client.NewClient(cmd)
+ if err != nil {
+ return err
+ }
+ var schemaBytes []byte
+ switch len(args) {
+ case 1:
+ schemaBytes, err = os.ReadFile(args[0])
+ if err != nil {
+ return fmt.Errorf("failed to read schema file: %w", err)
+ }
+ log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file")
+ case 0:
+ schemaBytes, err = io.ReadAll(os.Stdin)
+ if err != nil {
+ return fmt.Errorf("failed to read schema file: %w", err)
+ }
+ log.Trace().Str("schema", string(schemaBytes)).Msg("read schema from stdin")
+ default:
+ panic("schemaWriteCmdFunc called with incorrect number of arguments")
+ }
+
+ if len(schemaBytes) == 0 {
+ return errors.New("attempted to write empty schema")
+ }
+
+ prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil)
+ if err != nil {
+ return err
+ }
+
+ schemaText, err := rewriteSchema(string(schemaBytes), prefix)
+ if err != nil {
+ return err
+ }
+
+ request := &v1.WriteSchemaRequest{Schema: schemaText}
+ log.Trace().Interface("request", request).Msg("writing schema")
+
+ resp, err := client.WriteSchema(cmd.Context(), request)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to write schema")
+ }
+ log.Trace().Interface("response", resp).Msg("wrote schema")
+
+ if cobrautil.MustGetBool(cmd, "json") {
+ prettyProto, err := commands.PrettyProto(resp)
+ if err != nil {
+ log.Fatal().Err(err).Msg("failed to convert schema to JSON")
+ }
+
+ console.Println(string(prettyProto))
+ return nil
+ }
+
+ return nil
+}
+
+// rewriteSchema rewrites the given existing schema to include the specified prefix on all definitions.
+func rewriteSchema(existingSchemaText string, definitionPrefix string) (string, error) {
+ if definitionPrefix == "" {
+ return existingSchemaText, nil
+ }
+
+ compiled, err := compiler.Compile(
+ compiler.InputSchema{Source: input.Source("schema"), SchemaString: existingSchemaText},
+ compiler.ObjectTypePrefix(definitionPrefix),
+ compiler.SkipValidation(),
+ )
+ if err != nil {
+ return "", err
+ }
+
+ generated, _, err := generator.GenerateSchema(compiled.OrderedDefinitions)
+ return generated, err
+}
+
+// determinePrefixForSchema determines the prefix to be applied to a schema that will be written.
+//
+// If specifiedPrefix is non-empty, it is returned immediately.
+// If existingSchema is non-nil, it is parsed for the prefix.
+// Otherwise, the client is used to retrieve the existing schema (if any), and the prefix is retrieved from there.
+func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client client.Client, existingSchema *string) (string, error) {
+ if specifiedPrefix != "" {
+ return specifiedPrefix, nil
+ }
+
+ var schemaText string
+ if existingSchema != nil {
+ schemaText = *existingSchema
+ } else {
+ readSchemaText, err := commands.ReadSchema(ctx, client)
+ if err != nil {
+ return "", nil
+ }
+ schemaText = readSchemaText
+ }
+
+ // If there is no schema found, return the empty string.
+ if schemaText == "" {
+ return "", nil
+ }
+
+ // Otherwise, compile the schema and grab the prefixes of the namespaces defined.
+ found, err := compiler.Compile(
+ compiler.InputSchema{Source: input.Source("schema"), SchemaString: schemaText},
+ compiler.AllowUnprefixedObjectType(),
+ compiler.SkipValidation(),
+ )
+ if err != nil {
+ return "", err
+ }
+
+ foundPrefixes := make([]string, 0, len(found.OrderedDefinitions))
+ for _, def := range found.OrderedDefinitions {
+ if strings.Contains(def.GetName(), "/") {
+ parts := strings.Split(def.GetName(), "/")
+ foundPrefixes = append(foundPrefixes, parts[0])
+ } else {
+ foundPrefixes = append(foundPrefixes, "")
+ }
+ }
+
+ prefixes := stringz.Dedup(foundPrefixes)
+ if len(prefixes) == 1 {
+ prefix := prefixes[0]
+ log.Debug().Str("prefix", prefix).Msg("found schema definition prefix")
+ return prefix, nil
+ }
+
+ return "", nil
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/validate.go b/vendor/github.com/authzed/zed/internal/cmd/validate.go
new file mode 100644
index 0000000..ffca1c6
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/validate.go
@@ -0,0 +1,414 @@
+package cmd
+
+import (
+ "errors"
+ "fmt"
+ "net/url"
+ "os"
+ "strings"
+
+ "github.com/ccoveille/go-safecast"
+ "github.com/charmbracelet/lipgloss"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/muesli/termenv"
+ "github.com/spf13/cobra"
+
+ composable "github.com/authzed/spicedb/pkg/composableschemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/development"
+ core "github.com/authzed/spicedb/pkg/proto/core/v1"
+ devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1"
+ "github.com/authzed/spicedb/pkg/schemadsl/compiler"
+ "github.com/authzed/spicedb/pkg/spiceerrors"
+ "github.com/authzed/spicedb/pkg/validationfile"
+
+ "github.com/authzed/zed/internal/commands"
+ "github.com/authzed/zed/internal/console"
+ "github.com/authzed/zed/internal/decode"
+ "github.com/authzed/zed/internal/printers"
+)
+
var (
	// NOTE: these need to be set *after* the renderer has been set, otherwise
	// the forceColor setting can't work, hence the thunking.
	//
	// Each value is a zero-argument function so the lipgloss style (or
	// rendered string) is built lazily at call time, after validatePreRunE
	// has had a chance to force the color profile.
	success = func() string {
		return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("10")).Render("Success!")
	}
	complete    = func() string { return lipgloss.NewStyle().Bold(true).Render("complete") }
	errorPrefix = func() string { return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("9")).Render("error: ") }
	warningPrefix = func() string {
		return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("3")).Render("warning: ")
	}
	errorMessageStyle      = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true).Width(80) }
	linePrefixStyle        = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("12")) }
	highlightedSourceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) }
	highlightedLineStyle   = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) }
	codeStyle              = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("8")) }
	highlightedCodeStyle   = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("15")) }
	traceStyle             = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true) }
)
+
+func registerValidateCmd(cmd *cobra.Command) {
+ validateCmd.Flags().Bool("force-color", false, "force color code output even in non-tty environments")
+ validateCmd.Flags().String("schema-type", "", "force validation according to specific schema syntax (\"\", \"composable\", \"standard\")")
+ cmd.AddCommand(validateCmd)
+}
+
// validateCmd validates validation files (.yaml/.zaml) or bare schema files
// (.zed) supplied as local paths or supported URLs. On validation failure it
// prints the findings and exits with status 1.
var validateCmd = &cobra.Command{
	Use:   "validate <validation_file_or_schema_file>",
	Short: "Validates the given validation file (.yaml, .zaml) or schema file (.zed)",
	Example: `
	From a local file (with prefix):
		zed validate file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml

	From a local file (no prefix):
		zed validate authzed-x7izWU8_2Gw3.yaml

	From a gist:
		zed validate https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6

	From a playground link:
		zed validate https://play.authzed.com/s/iksdFvCtvnkR/schema

	From pastebin:
		zed validate https://pastebin.com/8qU45rVK

	From a devtools instance:
		zed validate https://localhost:8443/download`,
	Args:              commands.ValidationWrapper(cobra.MinimumNArgs(1)),
	ValidArgsFunction: commands.FileExtensionCompletions("zed", "yaml", "zaml"),
	PreRunE:           validatePreRunE,
	RunE: func(cmd *cobra.Command, filenames []string) error {
		result, shouldExit, err := validateCmdFunc(cmd, filenames)
		if err != nil {
			return err
		}
		console.Print(result)
		// Exit non-zero so scripts/CI can detect validation failures.
		if shouldExit {
			os.Exit(1)
		}
		return nil
	},

	// A schema that causes the parser/compiler to error will halt execution
	// of this command with an error. In that case, we want to just display the error,
	// rather than showing usage for this command.
	SilenceUsage: true,
}
+
// validSchemaTypes enumerates the accepted values of the --schema-type flag;
// the empty string selects the default behavior (try composable rules first,
// then standard).
var validSchemaTypes = []string{"", "standard", "composable"}
+
+func validatePreRunE(cmd *cobra.Command, _ []string) error {
+ // Override lipgloss's autodetection of whether it's in a terminal environment
+ // and display things in color anyway. This can be nice in CI environments that
+ // support it.
+ setForceColor := cobrautil.MustGetBool(cmd, "force-color")
+ if setForceColor {
+ lipgloss.SetColorProfile(termenv.ANSI256)
+ }
+
+ schemaType := cobrautil.MustGetString(cmd, "schema-type")
+ schemaTypeValid := false
+ for _, validType := range validSchemaTypes {
+ if schemaType == validType {
+ schemaTypeValid = true
+ }
+ }
+ if !schemaTypeValid {
+ return fmt.Errorf("schema-type must be one of \"\", \"standard\", \"composable\". received: %s", schemaType)
+ }
+
+ return nil
+}
+
// validateCmdFunc validates each of the given files and accumulates a
// report. It returns the text to print to the user, whether to return a
// non-zero status code, and any hard (non-validation) error.
func validateCmdFunc(cmd *cobra.Command, filenames []string) (string, bool, error) {
	// Initialize variables for multiple files
	var (
		totalFiles                 = len(filenames)
		successfullyValidatedFiles = 0
		shouldExit                 = false
		toPrint                    = &strings.Builder{}
		schemaType                 = cobrautil.MustGetString(cmd, "schema-type")
	)

	for _, filename := range filenames {
		// If we're running over multiple files, print the filename for context/debugging purposes
		if totalFiles > 1 {
			toPrint.WriteString(filename + "\n")
		}

		// Filenames may also be URLs (gist, playground, pastebin, devtools);
		// pick a decoder appropriate to the parsed URL.
		u, err := url.Parse(filename)
		if err != nil {
			return "", false, err
		}

		decoder, err := decode.DecoderForURL(u)
		if err != nil {
			return "", false, err
		}

		var parsed validationfile.ValidationFile
		validateContents, isOnlySchema, err := decoder(&parsed)
		// Partition any decode error by originating DSL so --schema-type can
		// decide which class of error is fatal.
		standardErrors, composableErrs, otherErrs := classifyErrors(err)

		switch schemaType {
		case "standard":
			if standardErrors != nil {
				var errWithSource spiceerrors.WithSourceError
				if errors.As(standardErrors, &errWithSource) {
					outputErrorWithSource(toPrint, validateContents, errWithSource)
					shouldExit = true
				}
				return "", shouldExit, standardErrors
			}
		case "composable":
			if composableErrs != nil {
				var errWithSource spiceerrors.WithSourceError
				if errors.As(composableErrs, &errWithSource) {
					outputErrorWithSource(toPrint, validateContents, errWithSource)
					shouldExit = true
				}
				return "", shouldExit, composableErrs
			}
		default:
			// By default, validate will attempt to validate a schema first according to composable schema rules,
			// then standard schema rules,
			// and if both fail it will show the errors from composable schema.
			if composableErrs != nil && standardErrors != nil {
				var errWithSource spiceerrors.WithSourceError
				if errors.As(composableErrs, &errWithSource) {
					outputErrorWithSource(toPrint, validateContents, errWithSource)
					shouldExit = true
				}
				return "", shouldExit, composableErrs
			}
		}

		if otherErrs != nil {
			return "", false, otherErrs
		}

		tuples := make([]*core.RelationTuple, 0)
		totalAssertions := 0
		totalRelationsValidated := 0

		for _, rel := range parsed.Relationships.Relationships {
			tuples = append(tuples, rel.ToCoreTuple())
		}

		// Create the development context for each run
		ctx := cmd.Context()
		devCtx, devErrs, err := development.NewDevContext(ctx, &devinterface.RequestContext{
			Schema:        parsed.Schema.Schema,
			Relationships: tuples,
		})
		if err != nil {
			return "", false, err
		}
		if devErrs != nil {
			// Developer error positions are schema-relative; when the schema
			// is embedded in a larger validation file, shift them by the
			// schema's line position in that file.
			schemaOffset := parsed.Schema.SourcePosition.LineNumber
			if isOnlySchema {
				schemaOffset = 0
			}

			// Output errors
			outputDeveloperErrorsWithLineOffset(toPrint, validateContents, devErrs.InputErrors, schemaOffset)
			return toPrint.String(), true, nil
		}
		// Run assertions
		adevErrs, aerr := development.RunAllAssertions(devCtx, &parsed.Assertions)
		if aerr != nil {
			return "", false, aerr
		}
		if adevErrs != nil {
			outputDeveloperErrors(toPrint, validateContents, adevErrs)
			return toPrint.String(), true, nil
		}
		// NOTE(review): the file is counted as validated once assertions
		// pass, before expected relations are checked below — confirm this
		// ordering is intentional.
		successfullyValidatedFiles++

		// Run expected relations for file
		_, erDevErrs, rerr := development.RunValidation(devCtx, &parsed.ExpectedRelations)
		if rerr != nil {
			return "", false, rerr
		}
		if erDevErrs != nil {
			outputDeveloperErrors(toPrint, validateContents, erDevErrs)
			return toPrint.String(), true, nil
		}
		// Print out any warnings for file
		warnings, err := development.GetWarnings(ctx, devCtx)
		if err != nil {
			return "", false, err
		}

		if len(warnings) > 0 {
			for _, warning := range warnings {
				fmt.Fprintf(toPrint, "%s%s\n", warningPrefix(), warning.Message)
				outputForLine(toPrint, validateContents, uint64(warning.Line), warning.SourceCode, uint64(warning.Column)) // warning.LineNumber is 1-indexed
				toPrint.WriteString("\n")
			}

			toPrint.WriteString(complete())
		} else {
			toPrint.WriteString(success())
		}
		totalAssertions += len(parsed.Assertions.AssertTrue) + len(parsed.Assertions.AssertFalse)
		totalRelationsValidated += len(parsed.ExpectedRelations.ValidationMap)

		fmt.Fprintf(toPrint, " - %d relationships loaded, %d assertions run, %d expected relations validated\n",
			len(tuples),
			totalAssertions,
			totalRelationsValidated)
	}

	if totalFiles > 1 {
		fmt.Fprintf(toPrint, "total files: %d, successfully validated files: %d\n", totalFiles, successfullyValidatedFiles)
	}

	return toPrint.String(), shouldExit, nil
}
+
+func outputErrorWithSource(sb *strings.Builder, validateContents []byte, errWithSource spiceerrors.WithSourceError) {
+ fmt.Fprintf(sb, "%s%s\n", errorPrefix(), errorMessageStyle().Render(errWithSource.Error()))
+ outputForLine(sb, validateContents, errWithSource.LineNumber, errWithSource.SourceCodeString, 0) // errWithSource.LineNumber is 1-indexed
+}
+
+func outputForLine(sb *strings.Builder, validateContents []byte, oneIndexedLineNumber uint64, sourceCodeString string, oneIndexedColumnPosition uint64) {
+ lines := strings.Split(string(validateContents), "\n")
+ // These should be fine to be zero if the cast fails.
+ intLineNumber, _ := safecast.ToInt(oneIndexedLineNumber)
+ intColumnPosition, _ := safecast.ToInt(oneIndexedColumnPosition)
+ errorLineNumber := intLineNumber - 1
+ for i := errorLineNumber - 3; i < errorLineNumber+3; i++ {
+ if i == errorLineNumber {
+ renderLine(sb, lines, i, sourceCodeString, errorLineNumber, intColumnPosition-1)
+ } else {
+ renderLine(sb, lines, i, "", errorLineNumber, -1)
+ }
+ }
+}
+
// outputDeveloperErrors renders each developer error against the validated
// contents with no additional line offset.
func outputDeveloperErrors(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError) {
	outputDeveloperErrorsWithLineOffset(sb, validateContents, devErrors, 0)
}
+
+func outputDeveloperErrorsWithLineOffset(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError, lineOffset int) {
+ lines := strings.Split(string(validateContents), "\n")
+
+ for _, devErr := range devErrors {
+ outputDeveloperError(sb, devErr, lines, lineOffset)
+ }
+}
+
+func outputDeveloperError(sb *strings.Builder, devError *devinterface.DeveloperError, lines []string, lineOffset int) {
+ fmt.Fprintf(sb, "%s %s\n", errorPrefix(), errorMessageStyle().Render(devError.Message))
+ errorLineNumber := int(devError.Line) - 1 + lineOffset // devError.Line is 1-indexed
+ for i := errorLineNumber - 3; i < errorLineNumber+3; i++ {
+ if i == errorLineNumber {
+ renderLine(sb, lines, i, devError.Context, errorLineNumber, -1)
+ } else {
+ renderLine(sb, lines, i, "", errorLineNumber, -1)
+ }
+ }
+
+ if devError.CheckResolvedDebugInformation != nil && devError.CheckResolvedDebugInformation.Check != nil {
+ fmt.Fprintf(sb, "\n %s\n", traceStyle().Render("Explanation:"))
+ tp := printers.NewTreePrinter()
+ printers.DisplayCheckTrace(devError.CheckResolvedDebugInformation.Check, tp, true)
+ sb.WriteString(tp.Indented())
+ }
+
+ sb.WriteString("\n\n")
+}
+
// renderLine writes one gutter-prefixed source line to sb. index is the
// 0-based line to render; when it equals highlightLineIndex the line is
// emphasized, the gutter delimiter becomes ">", and — when a caret column
// can be determined — a second row with a caret (^) and tildes (~)
// underlining the highlighted span is emitted.
//
// highlight is the substring to emphasize (may be empty), and
// highlightStartingColumnIndex is the 0-based column at or after which the
// highlight must start (-1 when unknown).
func renderLine(sb *strings.Builder, lines []string, index int, highlight string, highlightLineIndex int, highlightStartingColumnIndex int) {
	// Lines outside the file are silently skipped; callers request a fixed
	// window that may extend past either end.
	if index < 0 || index >= len(lines) {
		return
	}

	// Gutter width is based on the widest line number in the file.
	lineNumberLength := len(fmt.Sprintf("%d", len(lines)))
	lineContents := strings.ReplaceAll(lines[index], "\t", " ")
	lineDelimiter := "|"

	// Number of "~" characters drawn after the "^" caret.
	highlightLength := max(0, len(highlight)-1)
	highlightColumnIndex := -1

	// If the highlight string was provided, then we need to find the index of the highlight
	// string in the line contents to determine where to place the caret.
	if len(highlight) > 0 {
		// Scan forward through successive occurrences of highlight until one
		// at/after highlightStartingColumnIndex is found.
		offset := 0
		for {
			foundRelativeIndex := strings.Index(lineContents[offset:], highlight)
			foundIndex := foundRelativeIndex + offset
			if foundIndex >= highlightStartingColumnIndex {
				highlightColumnIndex = foundIndex
				break
			}

			offset = foundIndex + 1
			if foundRelativeIndex < 0 || foundIndex > len(lineContents) {
				break
			}
		}
	} else if highlightStartingColumnIndex >= 0 {
		// Otherwise, just show a caret at the specified starting column, if any.
		highlightColumnIndex = highlightStartingColumnIndex
	}

	lineNumberStr := fmt.Sprintf("%d", index+1)
	noNumberSpaces := strings.Repeat(" ", lineNumberLength)

	// Default gutter/code styles; swapped for highlighted variants when this
	// is the error line.
	lineNumberStyle := linePrefixStyle
	lineContentsStyle := codeStyle
	if index == highlightLineIndex {
		lineNumberStyle = highlightedLineStyle
		lineContentsStyle = highlightedCodeStyle
		lineDelimiter = ">"
	}

	lineNumberSpacer := strings.Repeat(" ", lineNumberLength-len(lineNumberStr))

	if highlightColumnIndex < 0 {
		fmt.Fprintf(sb, " %s%s %s %s\n", lineNumberSpacer, lineNumberStyle().Render(lineNumberStr), lineDelimiter, lineContentsStyle().Render(lineContents))
	} else {
		// Emit the line with the highlighted span, then the caret/tilde row.
		fmt.Fprintf(sb, " %s%s %s %s%s%s\n",
			lineNumberSpacer,
			lineNumberStyle().Render(lineNumberStr),
			lineDelimiter,
			lineContentsStyle().Render(lineContents[0:highlightColumnIndex]),
			highlightedSourceStyle().Render(highlight),
			lineContentsStyle().Render(lineContents[highlightColumnIndex+len(highlight):]))

		fmt.Fprintf(sb, " %s %s %s%s%s\n",
			noNumberSpaces,
			lineDelimiter,
			strings.Repeat(" ", highlightColumnIndex),
			highlightedSourceStyle().Render("^"),
			highlightedSourceStyle().Render(strings.Repeat("~", highlightLength)))
	}
}
+
+// classifyErrors returns errors from the composable DSL, the standard DSL, and any other parsing errors.
+func classifyErrors(err error) (error, error, error) {
+ if err == nil {
+ return nil, nil, nil
+ }
+ var standardErr compiler.BaseCompilerError
+ var composableErr composable.BaseCompilerError
+ var retStandard, retComposable, allOthers error
+
+ ok := errors.As(err, &standardErr)
+ if ok {
+ retStandard = standardErr
+ }
+ ok = errors.As(err, &composableErr)
+ if ok {
+ retComposable = composableErr
+ }
+
+ if retStandard == nil && retComposable == nil {
+ allOthers = err
+ }
+
+ return retStandard, retComposable, allOthers
+}
diff --git a/vendor/github.com/authzed/zed/internal/cmd/version.go b/vendor/github.com/authzed/zed/internal/cmd/version.go
new file mode 100644
index 0000000..b87ec66
--- /dev/null
+++ b/vendor/github.com/authzed/zed/internal/cmd/version.go
@@ -0,0 +1,64 @@
+package cmd
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/gookit/color"
+ "github.com/jzelinskie/cobrautil/v2"
+ "github.com/mattn/go-isatty"
+ "github.com/spf13/cobra"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+
+ "github.com/authzed/authzed-go/pkg/responsemeta"
+ v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
+
+ "github.com/authzed/zed/internal/client"
+ "github.com/authzed/zed/internal/console"
+)
+
+func versionCmdFunc(cmd *cobra.Command, _ []string) error {
+ if !isatty.IsTerminal(os.Stdout.Fd()) {
+ color.Disable()
+ }
+
+ includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version")
+ hasContext := false
+ if includeRemoteVersion {
+ configStore, secretStore := client.DefaultStorage()
+ _, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore)
+ hasContext = err == nil
+ }
+
+ if hasContext && includeRemoteVersion {
+ green := color.FgGreen.Render
+ fmt.Print(green("client: "))
+ }
+
+ console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps")))
+
+ if hasContext && includeRemoteVersion {
+ client, err := client.NewClient(cmd)
+ if err != nil {
+ return err
+ }
+
+ // NOTE: we ignore the error here, as it may be due to a schema not existing, or
+ // the client being unable to connect, etc. We just treat all such cases as an unknown
+ // version.
+ var headerMD metadata.MD
+ _, _ = client.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}, grpc.Header(&headerMD))
+ version := headerMD.Get(string(responsemeta.ServerVersion))
+
+ blue := color.FgLightBlue.Render
+ fmt.Print(blue("service: "))
+ if len(version) == 1 {
+ console.Println(version[0])
+ } else {
+ console.Println("(unknown)")
+ }
+ }
+
+ return nil
+}