diff options
| author | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-07-22 17:35:49 -0600 |
| commit | 20ef0d92694465ac86b550df139e8366a0a2b4fa (patch) | |
| tree | 3f14589e1ce6eb9306a3af31c3a1f9e1af5ed637 /vendor/github.com/authzed/zed/internal/cmd | |
| parent | 44e0d272c040cdc53a98b9f1dc58ae7da67752e6 (diff) | |
feat: connect to spicedb
Diffstat (limited to 'vendor/github.com/authzed/zed/internal/cmd')
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/backup.go | 868 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/cmd.go | 187 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/context.go | 202 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/import.go | 181 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/preview.go | 120 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/restorer.go | 446 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/schema.go | 319 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/validate.go | 414 | ||||
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/version.go | 64 |
9 files changed, 2801 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/cmd/backup.go b/vendor/github.com/authzed/zed/internal/cmd/backup.go new file mode 100644 index 0000000..57fa83c --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/backup.go @@ -0,0 +1,868 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/mattn/go-isatty" + "github.com/rodaine/table" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + schemapkg "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/tuple" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/pkg/backupformat" +) + +const ( + returnIfExists = true + doNotReturnIfExists = false +) + +// cobraRunEFunc is the signature of a cobra.Command.RunE function. +type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error) + +// withErrorHandling is a wrapper that centralizes error handling, instead of having to scatter it around the command logic. +func withErrorHandling(f cobraRunEFunc) cobraRunEFunc { + return func(cmd *cobra.Command, args []string) (err error) { + return addSizeErrInfo(f(cmd, args)) + } +} + +var ( + backupCmd = &cobra.Command{ + Use: "backup <filename>", + Short: "Create, restore, and inspect permissions system backups", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + // Create used to be on the root, so add it here for back-compat. + RunE: withErrorHandling(backupCreateCmdFunc), + } + + backupCreateCmd = &cobra.Command{ + Use: "create <filename>", + Short: "Backup a permission system to a file", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + RunE: withErrorHandling(backupCreateCmdFunc), + } + + backupRestoreCmd = &cobra.Command{ + Use: "restore <filename>", + Short: "Restore a permission system from a file", + Args: commands.ValidationWrapper(commands.StdinOrExactArgs(1)), + RunE: backupRestoreCmdFunc, + } + + backupParseSchemaCmd = &cobra.Command{ + Use: "parse-schema <filename>", + Short: "Extract the schema from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseSchemaCmdFunc(cmd, os.Stdout, args) + }, + } + + backupParseRevisionCmd = &cobra.Command{ + Use: "parse-revision <filename>", + Short: "Extract the revision from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseRevisionCmdFunc(cmd, os.Stdout, args) + }, + } + + backupParseRelsCmd = &cobra.Command{ + Use: "parse-relationships <filename>", + Short: "Extract the relationships from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseRelsCmdFunc(cmd, os.Stdout, args) + }, + } + + backupRedactCmd = &cobra.Command{ + Use: "redact <filename>", + Short: "Redact a backup file to remove sensitive information", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupRedactCmdFunc(cmd, args) + }, + } +) + +func registerBackupCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(backupCmd) + registerBackupCreateFlags(backupCmd) + + backupCmd.AddCommand(backupCreateCmd) + registerBackupCreateFlags(backupCreateCmd) + + backupCmd.AddCommand(backupRestoreCmd) + registerBackupRestoreFlags(backupRestoreCmd) + + backupCmd.AddCommand(backupRedactCmd) + backupRedactCmd.Flags().Bool("redact-definitions", true, "redact definitions") + backupRedactCmd.Flags().Bool("redact-relations", true, "redact relations") + backupRedactCmd.Flags().Bool("redact-object-ids", true, "redact object IDs") + backupRedactCmd.Flags().Bool("print-redacted-object-ids", false, "prints the redacted object IDs") + + // Restore used to be on the root, so add it there too, but hidden. + restoreCmd := &cobra.Command{ + Use: "restore <filename>", + Short: "Restore a permission system from a backup file", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + RunE: backupRestoreCmdFunc, + Hidden: true, + } + rootCmd.AddCommand(restoreCmd) + registerBackupRestoreFlags(restoreCmd) + + backupCmd.AddCommand(backupParseSchemaCmd) + backupParseSchemaCmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + backupParseSchemaCmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + + backupCmd.AddCommand(backupParseRevisionCmd) + backupCmd.AddCommand(backupParseRelsCmd) + backupParseRelsCmd.Flags().String("prefix-filter", "", "Include only relationships with a given prefix") +} + +func registerBackupRestoreFlags(cmd *cobra.Command) { + cmd.Flags().Uint("batch-size", 1_000, "restore relationship write batch size") + cmd.Flags().Uint("batches-per-transaction", 10, "number of batches per transaction") + cmd.Flags().String("conflict-strategy", "fail", "strategy used when a conflicting relationship is found. Possible values: fail, skip, touch") + cmd.Flags().Bool("disable-retries", false, "retries when an errors is determined to be retryable (e.g. serialization errors)") + cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + cmd.Flags().Duration("request-timeout", 30*time.Second, "timeout for each request performed during restore") +} + +func registerBackupCreateFlags(cmd *cobra.Command) { + cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + cmd.Flags().Uint32("page-limit", 0, "defines the number of relationships to be read by requested page during backup") +} + +func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, error) { + if filename == "-" { + log.Trace().Str("filename", "- (stdout)").Send() + return os.Stdout, false, nil + } + + log.Trace().Str("filename", filename).Send() + + if _, err := os.Stat(filename); err == nil { + if !returnIfExists { + return nil, false, fmt.Errorf("backup file already exists: %s", filename) + } + + f, err := os.OpenFile(filename, os.O_RDWR|os.O_APPEND, 0o644) + if err != nil { + return nil, false, fmt.Errorf("unable to open existing backup file: %w", err) + } + + return f, true, nil + } + + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644) + if err != nil { + return nil, false, fmt.Errorf("unable to create backup file: %w", err) + } + + return f, false, nil +} + +var ( + missingAllowedTypes = regexp.MustCompile(`(\s*)(relation)(.+)(/\* missing allowed types \*/)(.*)`) + shortRelations = regexp.MustCompile(`(\s*)relation [a-z][a-z0-9_]:(.+)`) +) + +func partialPrefixMatch(name, prefix string) bool { + return strings.HasPrefix(name, prefix+"/") +} + +func filterSchemaDefs(schema, prefix string) (filteredSchema string, err error) { + if schema == "" || prefix == "" { + return schema, nil + } + + compiledSchema, err := compiler.Compile( + compiler.InputSchema{Source: "schema", SchemaString: schema}, + compiler.AllowUnprefixedObjectType(), + compiler.SkipValidation(), + ) + if err != nil { + return "", fmt.Errorf("error reading schema: %w", err) + } + + var prefixedDefs []compiler.SchemaDefinition + for _, def := range compiledSchema.ObjectDefinitions { + if partialPrefixMatch(def.Name, prefix) { + prefixedDefs = append(prefixedDefs, def) + } + } + for _, def := range compiledSchema.CaveatDefinitions { + if partialPrefixMatch(def.Name, prefix) { + prefixedDefs = append(prefixedDefs, def) + } + } + + if len(prefixedDefs) == 0 { + return "", errors.New("filtered all definitions from schema") + } + + filteredSchema, _, err = generator.GenerateSchema(prefixedDefs) + if err != nil { + return "", fmt.Errorf("error generating filtered schema: %w", err) + } + + // Validate that the type system for the generated schema is comprehensive. + compiledFilteredSchema, err := compiler.Compile( + compiler.InputSchema{Source: "generated-schema", SchemaString: filteredSchema}, + compiler.AllowUnprefixedObjectType(), + ) + if err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + + for _, rawDef := range compiledFilteredSchema.ObjectDefinitions { + ts := schemapkg.NewTypeSystem(schemapkg.ResolverForCompiledSchema(*compiledFilteredSchema)) + def, err := schemapkg.NewDefinition(ts, rawDef) + if err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + if _, err := def.Validate(context.Background()); err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + } + + return +} + +func hasRelPrefix(rel *v1.Relationship, prefix string) bool { + // Skip any relationships without the prefix on both sides. + return strings.HasPrefix(rel.Resource.ObjectType, prefix) && + strings.HasPrefix(rel.Subject.Object.ObjectType, prefix) +} + +func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + pageLimit := cobrautil.MustGetUint32(cmd, "page-limit") + + backupFileName, err := computeBackupFileName(cmd, args) + if err != nil { + return err + } + + backupFile, backupExists, err := createBackupFile(backupFileName, returnIfExists) + if err != nil { + return err + } + + defer func(e *error) { + *e = errors.Join(*e, backupFile.Sync()) + *e = errors.Join(*e, backupFile.Close()) + }(&err) + + // the goal of this file is to keep the bulk export cursor in case the process is terminated + // and we need to resume from where we left off. OCF does not support in-place record updates. + progressFile, cursor, err := openProgressFile(backupFileName, backupExists) + if err != nil { + return err + } + + var backupCompleted bool + defer func(e *error) { + *e = errors.Join(*e, progressFile.Sync()) + *e = errors.Join(*e, progressFile.Close()) + + if backupCompleted { + if err := os.Remove(progressFile.Name()); err != nil { + log.Warn(). + Str("progress-file", progressFile.Name()). + Msg("failed to remove progress file, consider removing it manually") + } + } + }(&err) + + c, err := client.NewClient(cmd) + if err != nil { + return fmt.Errorf("unable to initialize client: %w", err) + } + + var zedToken *v1.ZedToken + var encoder *backupformat.Encoder + if backupExists { + encoder, err = backupformat.NewEncoderForExisting(backupFile) + if err != nil { + return fmt.Errorf("error creating backup file encoder: %w", err) + } + } else { + encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile) + if err != nil { + return err + } + } + + defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err) + + if zedToken == nil && cursor == nil { + return errors.New("malformed existing backup, consider recreating it") + } + + req := &v1.BulkExportRelationshipsRequest{ + OptionalLimit: pageLimit, + OptionalCursor: cursor, + } + + // if a cursor is present, zedtoken is not needed (it is already in the cursor) + if zedToken != nil { + req.Consistency = &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: zedToken, + }, + } + } + + ctx := cmd.Context() + relationshipStream, err := c.BulkExportRelationships(ctx, req) + if err != nil { + return fmt.Errorf("error exporting relationships: %w", err) + } + + relationshipReadStart := time.Now() + tick := time.Tick(5 * time.Second) + bar := console.CreateProgressBar("processing backup") + var relsFilteredOut, relsProcessed uint64 + defer func() { + _ = bar.Finish() + + evt := log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)) + if isCanceled(err) { + evt.Msg("backup canceled - resume by restarting the backup command") + } else if err != nil { + evt.Msg("backup failed") + } else { + evt.Msg("finished backup") + } + }() + + for { + if err := ctx.Err(); err != nil { + if isCanceled(err) { + return context.Canceled + } + + return fmt.Errorf("aborted backup: %w", err) + } + + relsResp, err := relationshipStream.Recv() + if err != nil { + if isCanceled(err) { + return context.Canceled + } + + if !errors.Is(err, io.EOF) { + return fmt.Errorf("error receiving relationships: %w", err) + } + break + } + + for _, rel := range relsResp.Relationships { + if hasRelPrefix(rel, prefixFilter) { + if err := encoder.Append(rel); err != nil { + return fmt.Errorf("error storing relationship: %w", err) + } + } else { + relsFilteredOut++ + } + + relsProcessed++ + if err := bar.Add(1); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + + // progress fallback in case there is no TTY + if !isatty.IsTerminal(os.Stderr.Fd()) { + select { + case <-tick: + log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)). + Msg("backup progress") + default: + } + } + } + + if err := writeProgress(progressFile, relsResp); err != nil { + return err + } + } + + backupCompleted = true + return nil +} + +// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup +// must be taken. +func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) { + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + + schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}) + if err != nil { + return nil, nil, fmt.Errorf("error reading schema: %w", err) + } + if schemaResp.ReadAt == nil { + return nil, nil, fmt.Errorf("`backup` is not supported on this version of SpiceDB") + } + schema := schemaResp.SchemaText + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + + if prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return nil, nil, err + } + } + + zedToken := schemaResp.ReadAt + + encoder, err := backupformat.NewEncoder(backupFile, schema, zedToken) + if err != nil { + return nil, nil, fmt.Errorf("error creating backup file encoder: %w", err) + } + + return encoder, zedToken, nil +} + +func writeProgress(progressFile *os.File, relsResp *v1.BulkExportRelationshipsResponse) error { + err := progressFile.Truncate(0) + if err != nil { + return fmt.Errorf("unable to truncate backup progress file: %w", err) + } + + _, err = progressFile.Seek(0, 0) + if err != nil { + return fmt.Errorf("unable to seek backup progress file: %w", err) + } + + _, err = progressFile.WriteString(relsResp.AfterResultCursor.Token) + if err != nil { + return fmt.Errorf("unable to write result cursor to backup progress file: %w", err) + } + + return nil +} + +// openProgressFile returns the progress marker file and the stored progress cursor if it exists, or creates +// a new one if it does not exist. If the backup file exists, but the progress marker does not, it will return an error. +// +// The progress marker file keeps track of the last successful cursor received from the server, and is used to resume +// backups in case of failure. +func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.File, *v1.Cursor, error) { + var cursor *v1.Cursor + progressFileName := toLockFileName(backupFileName) + var progressFile *os.File + // if a backup existed + var fileMode int + readCursor, err := os.ReadFile(progressFileName) + if backupAlreadyExisted && (os.IsNotExist(err) || len(readCursor) == 0) { + return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName) + } else if backupAlreadyExisted && err == nil { + cursor = &v1.Cursor{ + Token: string(readCursor), + } + + // if backup existed and there is a progress marker, the latter should not be truncated to make sure the + // cursor stays around in case of a failure before we even start ingesting from bulk export + fileMode = os.O_WRONLY | os.O_CREATE + log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume") + } else { + // if a backup did not exist, make sure to truncate the progress file + fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC + } + + progressFile, err = os.OpenFile(progressFileName, fileMode, 0o644) + if err != nil { + return nil, nil, err + } + + return progressFile, cursor, nil +} + +func toLockFileName(backupFileName string) string { + return backupFileName + ".lock" +} + +// computeBackupFileName computes the backup file name based. +// If no file name is provided, it derives a backup on the current context +func computeBackupFileName(cmd *cobra.Command, args []string) (string, error) { + if len(args) > 0 { + return args[0], nil + } + + configStore, secretStore := client.DefaultStorage() + token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + if err != nil { + return "", fmt.Errorf("failed to determine current zed context: %w", err) + } + + ex, err := os.Executable() + if err != nil { + return "", err + } + exPath := filepath.Dir(ex) + + backupFileName := filepath.Join(exPath, token.Name+".zedbackup") + + return backupFileName, nil +} + +func isCanceled(err error) bool { + if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled { + return true + } + + return errors.Is(err, context.Canceled) +} + +func openRestoreFile(filename string) (*os.File, int64, error) { + if filename == "" { + log.Trace().Str("filename", "(stdin)").Send() + return os.Stdin, -1, nil + } + + log.Trace().Str("filename", filename).Send() + + stats, err := os.Stat(filename) + if err != nil { + return nil, 0, fmt.Errorf("unable to stat restore file: %w", err) + } + + f, err := os.Open(filename) + if err != nil { + return nil, 0, fmt.Errorf("unable to open restore file: %w", err) + } + + return f, stats.Size(), nil +} + +func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + if loadedToken := decoder.ZedToken(); loadedToken != nil { + log.Debug().Str("revision", loadedToken.Token).Msg("parsed revision") + } + + schema := decoder.Schema() + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + if prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return err + } + } + log.Debug().Str("schema", schema).Bool("filtered", prefixFilter != "").Msg("parsed schema") + + c, err := client.NewClient(cmd) + if err != nil { + return fmt.Errorf("unable to initialize client: %w", err) + } + + batchSize := cobrautil.MustGetUint(cmd, "batch-size") + batchesPerTransaction := cobrautil.MustGetUint(cmd, "batches-per-transaction") + + strategy, err := GetEnum[ConflictStrategy](cmd, "conflict-strategy", conflictStrategyMapping) + if err != nil { + return err + } + disableRetries := cobrautil.MustGetBool(cmd, "disable-retries") + requestTimeout := cobrautil.MustGetDuration(cmd, "request-timeout") + + return newRestorer(schema, decoder, c, prefixFilter, batchSize, batchesPerTransaction, strategy, + disableRetries, requestTimeout).restoreFromDecoder(cmd.Context()) +} + +// GetEnum is a helper for getting an enum value from a string cobra flag. +func GetEnum[E constraints.Integer](cmd *cobra.Command, name string, mapping map[string]E) (E, error) { + value := cobrautil.MustGetString(cmd, name) + value = strings.TrimSpace(strings.ToLower(value)) + if enum, ok := mapping[value]; ok { + return enum, nil + } + + var zeroValueE E + return zeroValueE, fmt.Errorf("unexpected flag '%s' value '%s': should be one of %v", name, value, maps.Keys(mapping)) +} + +func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + schema := decoder.Schema() + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + if prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter"); prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return err + } + } + + _, err = fmt.Fprintln(out, schema) + return err +} + +func backupParseRevisionCmdFunc(_ *cobra.Command, out io.Writer, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + loadedToken := decoder.ZedToken() + if loadedToken == nil { + return fmt.Errorf("failed to parse decoded revision") + } + + _, err = fmt.Fprintln(out, loadedToken.Token) + return err +} + +func backupRedactCmdFunc(cmd *cobra.Command, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return fmt.Errorf("error creating restore file decoder: %w", err) + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + filename := args[0] + ".redacted" + writer, _, err := createBackupFile(filename, doNotReturnIfExists) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, writer.Close()) }(&err) + + redactor, err := backupformat.NewRedactor(decoder, writer, backupformat.RedactionOptions{ + RedactDefinitions: cobrautil.MustGetBool(cmd, "redact-definitions"), + RedactRelations: cobrautil.MustGetBool(cmd, "redact-relations"), + RedactObjectIDs: cobrautil.MustGetBool(cmd, "redact-object-ids"), + }) + if err != nil { + return fmt.Errorf("error creating redactor: %w", err) + } + + defer func(e *error) { *e = errors.Join(*e, redactor.Close()) }(&err) + bar := console.CreateProgressBar("redacting backup") + var written int64 + for { + if err := cmd.Context().Err(); err != nil { + return fmt.Errorf("aborted redaction: %w", err) + } + + err := redactor.Next() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("error redacting: %w", err) + } + + written++ + if err := bar.Set64(written); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + } + + if err := bar.Finish(); err != nil { + return fmt.Errorf("error finalizing progress bar: %w", err) + } + + fmt.Println("Redaction map:") + fmt.Println("--------------") + fmt.Println() + + // Draw a table of definitions, caveats and relations mapped. + tbl := table.New("Definition Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Definitions { + tbl.AddRow(k, v) + } + + tbl.Print() + fmt.Println() + + if len(redactor.RedactionMap().Caveats) > 0 { + tbl = table.New("Caveat Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Caveats { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + } + + tbl = table.New("Relation/Permission Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Relations { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + + if len(redactor.RedactionMap().ObjectIDs) > 0 && cobrautil.MustGetBool(cmd, "print-redacted-object-ids") { + tbl = table.New("Object ID", "Redacted Object ID") + for k, v := range redactor.RedactionMap().ObjectIDs { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + } + + return nil +} + +func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error { + prefix := cobrautil.MustGetString(cmd, "prefix-filter") + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + for rel, err := decoder.Next(); rel != nil && err == nil; rel, err = decoder.Next() { + if !hasRelPrefix(rel, prefix) { + continue + } + + relString, err := tuple.V1StringRelationship(rel) + if err != nil { + return err + } + + if _, err = fmt.Fprintln(out, replaceRelString(relString)); err != nil { + return err + } + } + + return nil +} + +func decoderFromArgs(args ...string) (*backupformat.Decoder, io.Closer, error) { + filename := "" // Default to stdin. + if len(args) > 0 { + filename = args[0] + } + + f, _, err := openRestoreFile(filename) + if err != nil { + return nil, nil, err + } + + decoder, err := backupformat.NewDecoder(f) + if err != nil { + return nil, nil, fmt.Errorf("error creating restore file decoder: %w", err) + } + + return decoder, f, nil +} + +func replaceRelString(rel string) string { + rel = strings.Replace(rel, "@", " ", 1) + return strings.Replace(rel, "#", " ", 1) +} + +func rewriteLegacy(schema string) string { + schema = string(missingAllowedTypes.ReplaceAll([]byte(schema), []byte("\n/* deleted missing allowed type error */"))) + return string(shortRelations.ReplaceAll([]byte(schema), []byte("\n/* deleted short relation name */"))) +} + +var sizeErrorRegEx = regexp.MustCompile(`received message larger than max \((\d+) vs. (\d+)\)`) + +func addSizeErrInfo(err error) error { + if err == nil { + return nil + } + + code := status.Code(err) + if code != codes.ResourceExhausted { + return err + } + + if !strings.Contains(err.Error(), "received message larger than max") { + return err + } + + matches := sizeErrorRegEx.FindStringSubmatch(err.Error()) + if len(matches) != 3 { + return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) + } + + necessaryByteCount, atoiErr := strconv.Atoi(matches[1]) + if atoiErr != nil { + return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) + } + + return fmt.Errorf("%w: set flag --max-message-size=%d to increase the maximum allowable size", err, 2*necessaryByteCount) +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/cmd.go b/vendor/github.com/authzed/zed/internal/cmd/cmd.go new file mode 100644 index 0000000..a2e5332 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/cmd.go @@ -0,0 +1,187 @@ +package cmd + +import ( + "context" + "errors" + "io" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/cobrautil/v2/cobrazerolog" + "github.com/mattn/go-isatty" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + + "github.com/authzed/zed/internal/commands" +) + +var ( + SyncFlagsCmdFunc = cobrautil.SyncViperPreRunE("ZED") + errParsing = errors.New("parsing error") +) + +func init() { + // NOTE: this is mostly to set up logging in the case where + // the command doesn't exist or the construction of the command + // errors out before the PersistentPreRunE setup in the below function. + // It helps keep log output visually consistent for a user even in + // exceptional cases. + var output io.Writer + + if isatty.IsTerminal(os.Stdout.Fd()) { + output = zerolog.ConsoleWriter{Out: os.Stderr} + } else { + output = os.Stderr + } + + l := zerolog.New(output).With().Timestamp().Logger() + + log.Logger = l +} + +var flagError = flagErrorFunc + +func flagErrorFunc(cmd *cobra.Command, err error) error { + cmd.Println(err) + cmd.Println(cmd.UsageString()) + return errParsing +} + +// InitialiseRootCmd This function is utilised to generate docs for zed +func InitialiseRootCmd(zl *cobrazerolog.Builder) *cobra.Command { + rootCmd := &cobra.Command{ + Use: "zed", + Short: "SpiceDB CLI, built by AuthZed", + Long: "A command-line client for managing SpiceDB clusters.", + PersistentPreRunE: cobrautil.CommandStack( + zl.RunE(), + SyncFlagsCmdFunc, + commands.InjectRequestID, + ), + SilenceErrors: true, + SilenceUsage: true, + } + rootCmd.SetFlagErrorFunc(func(command *cobra.Command, err error) error { + return flagError(command, err) + }) + + zl.RegisterFlags(rootCmd.PersistentFlags()) + + rootCmd.PersistentFlags().String("endpoint", "", "spicedb gRPC API endpoint") + rootCmd.PersistentFlags().String("permissions-system", "", "permissions system to query") + rootCmd.PersistentFlags().String("hostname-override", "", "override the hostname used in the connection to the endpoint") + rootCmd.PersistentFlags().String("token", "", "token used to authenticate to SpiceDB") + rootCmd.PersistentFlags().String("certificate-path", "", "path to certificate authority used to verify secure connections") + rootCmd.PersistentFlags().Bool("insecure", false, "connect over a plaintext connection") + rootCmd.PersistentFlags().Bool("skip-version-check", false, "if true, no version check is performed against the server") + rootCmd.PersistentFlags().Bool("no-verify-ca", false, "do not attempt to verify the server's certificate chain and host name") + rootCmd.PersistentFlags().Bool("debug", false, "enable debug logging") + rootCmd.PersistentFlags().String("request-id", "", "optional id to send along with SpiceDB requests for tracing") + rootCmd.PersistentFlags().Int("max-message-size", 0, "maximum size *in bytes* (defaults to 4_194_304 bytes ~= 4MB) of a gRPC message that can be sent or received by zed") + rootCmd.PersistentFlags().String("proxy", "", "specify a SOCKS5 proxy address") + rootCmd.PersistentFlags().Uint("max-retries", 10, "maximum number of sequential retries to attempt when a request fails") + _ = rootCmd.PersistentFlags().MarkHidden("debug") // This cannot return its error. + + versionCmd := &cobra.Command{ + Use: "version", + Short: "Display zed and SpiceDB version information", + RunE: versionCmdFunc, + } + cobrautil.RegisterVersionFlags(versionCmd.Flags()) + versionCmd.Flags().Bool("include-remote-version", true, "whether to display the version of Authzed or SpiceDB for the current context") + rootCmd.AddCommand(versionCmd) + + // Register root-level aliases + rootCmd.AddCommand(&cobra.Command{ + Use: "use <context>", + Short: "Alias for `zed context use`", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + RunE: contextUseCmdFunc, + ValidArgsFunction: ContextGet, + }) + + // Register CLI-only commands. + registerContextCmd(rootCmd) + registerImportCmd(rootCmd) + registerValidateCmd(rootCmd) + registerBackupCmd(rootCmd) + registerPreviewCmd(rootCmd) + + // Register shared commands. + commands.RegisterPermissionCmd(rootCmd) + + relCmd := commands.RegisterRelationshipCmd(rootCmd) + + commands.RegisterWatchCmd(rootCmd) + commands.RegisterWatchRelationshipCmd(relCmd) + + schemaCmd := commands.RegisterSchemaCmd(rootCmd) + registerAdditionalSchemaCmds(schemaCmd) + + return rootCmd +} + +func Run() { + if err := runWithoutExit(); err != nil { + os.Exit(1) + } +} + +func runWithoutExit() error { + zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel)) + + rootCmd := InitialiseRootCmd(zl) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + signalChan := make(chan os.Signal, 2) + signal.Notify(signalChan, os.Interrupt, syscall.SIGTERM) + defer func() { + signal.Stop(signalChan) + cancel() + }() + + go func() { + select { + case <-signalChan: + cancel() + case <-ctx.Done(): + } + }() + + return handleError(rootCmd, rootCmd.ExecuteContext(ctx)) +} + +func handleError(command *cobra.Command, err error) error { + if err == nil { + return nil + } + // this snippet of code is taken from Command.ExecuteC in order to determine the command that was ultimately + // parsed. This is necessary to be able to print the proper command-specific usage + var findErr error + var cmdToExecute *cobra.Command + args := os.Args[1:] + if command.TraverseChildren { + cmdToExecute, _, findErr = command.Traverse(args) + } else { + cmdToExecute, _, findErr = command.Find(args) + } + if findErr != nil { + cmdToExecute = command + } + + if errors.Is(err, commands.ValidationError{}) { + _ = flagError(cmdToExecute, err) + } else if err != nil && strings.Contains(err.Error(), "unknown command") { + _ = flagError(cmdToExecute, err) + } else if !errors.Is(err, errParsing) { + log.Err(err).Msg("terminated with errors") + } + + return err +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/context.go b/vendor/github.com/authzed/zed/internal/cmd/context.go new file mode 100644 index 0000000..d5a0eec --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/context.go @@ -0,0 +1,202 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/stringz" + "github.com/spf13/cobra" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/internal/printers" + "github.com/authzed/zed/internal/storage" +) + +func registerContextCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(contextCmd) + + contextCmd.AddCommand(contextListCmd) + contextListCmd.Flags().Bool("reveal-tokens", false, "display secrets in results") + + contextCmd.AddCommand(contextSetCmd) + contextCmd.AddCommand(contextRemoveCmd) + contextCmd.AddCommand(contextUseCmd) +} + +var contextCmd = &cobra.Command{ + Use: "context <subcommand>", + Short: "Manage configurations for connecting to SpiceDB deployments", + Aliases: []string{"ctx"}, +} + +var contextListCmd = &cobra.Command{ + Use: "list", + Short: "Lists all available contexts", + Aliases: []string{"ls"}, + Args: commands.ValidationWrapper(cobra.ExactArgs(0)), + ValidArgsFunction: cobra.NoFileCompletions, + RunE: contextListCmdFunc, +} + +var contextSetCmd = &cobra.Command{ + Use: "set <name> <endpoint> <api-token>", + Short: "Creates or overwrite a context", + Args: commands.ValidationWrapper(cobra.ExactArgs(3)), + ValidArgsFunction: cobra.NoFileCompletions, + RunE: contextSetCmdFunc, +} + +var contextRemoveCmd = &cobra.Command{ + Use: "remove <system>", + Short: "Removes a context", + Aliases: []string{"rm"}, + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + ValidArgsFunction: ContextGet, + RunE: contextRemoveCmdFunc, +} + +var contextUseCmd = &cobra.Command{ + Use: "use <system>", + Short: "Sets a context as the current context", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + ValidArgsFunction: ContextGet, + RunE: contextUseCmdFunc, +} + +func ContextGet(_ *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { + _, secretStore := client.DefaultStorage() + secrets, err := secretStore.Get() + if err != nil { + return nil, cobra.ShellCompDirectiveError + } + + names := make([]string, 0, len(secrets.Tokens)) + for _, token := range secrets.Tokens { + names = append(names, token.Name) + } + + return names, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace | cobra.ShellCompDirectiveKeepOrder +} + +func contextListCmdFunc(cmd *cobra.Command, _ []string) error { + cfgStore, secretStore := client.DefaultStorage() + secrets, err := secretStore.Get() + if err != nil { + return err + } + + cfg, err := cfgStore.Get() + if err != nil { + return err + } + + rows := make([][]string, 0, len(secrets.Tokens)) + for _, token := range secrets.Tokens { + current := "" + if token.Name == cfg.CurrentToken { + current = " ✓ " + } + secret := token.APIToken + if !cobrautil.MustGetBool(cmd, "reveal-tokens") { + secret = token.Redacted() + } + + var certStr string + if token.IsInsecure() { + certStr = "insecure" + } else if token.HasNoVerifyCA() { + certStr = "no-verify-ca" + } else if _, ok := token.Certificate(); ok { + certStr = "custom" + } else { + certStr = "system" + } + + rows = append(rows, []string{ + current, + token.Name, + token.Endpoint, + secret, + certStr, + }) + } + + printers.PrintTable(os.Stdout, []string{"current", "name", "endpoint", "token", "tls cert"}, rows) + + return nil +} + +func contextSetCmdFunc(cmd *cobra.Command, args []string) error { + var name, endpoint, apiToken string + err := stringz.Unpack(args, &name, &endpoint, &apiToken) + if err != nil { + return err + } + + certPath := cobrautil.MustGetStringExpanded(cmd, "certificate-path") + var certBytes []byte + if certPath != "" { + certBytes, err = os.ReadFile(certPath) + if err != nil { + return fmt.Errorf("failed to read ceritficate: %w", err) + } + } + + insecure := cobrautil.MustGetBool(cmd, "insecure") + noVerifyCA := cobrautil.MustGetBool(cmd, "no-verify-ca") + cfgStore, secretStore := client.DefaultStorage() + err = storage.PutToken(storage.Token{ + Name: name, + Endpoint: stringz.DefaultEmpty(endpoint, "grpc.authzed.com:443"), + APIToken: apiToken, + Insecure: &insecure, + NoVerifyCA: &noVerifyCA, + CACert: certBytes, + }, secretStore) + if err != nil { + return err + } + + return storage.SetCurrentToken(name, cfgStore, secretStore) +} + +func contextRemoveCmdFunc(_ *cobra.Command, args []string) error { + // If the token is what's currently being used, remove it from the config. + cfgStore, secretStore := client.DefaultStorage() + cfg, err := cfgStore.Get() + if err != nil { + return err + } + + if cfg.CurrentToken == args[0] { + cfg.CurrentToken = "" + } + + err = cfgStore.Put(cfg) + if err != nil { + return err + } + + return storage.RemoveToken(args[0], secretStore) +} + +func contextUseCmdFunc(_ *cobra.Command, args []string) error { + cfgStore, secretStore := client.DefaultStorage() + switch len(args) { + case 0: + cfg, err := cfgStore.Get() + if err != nil { + return err + } + console.Println(cfg.CurrentToken) + case 1: + return storage.SetCurrentToken(args[0], cfgStore, secretStore) + default: + panic("cobra command did not enforce valid number of args") + } + + return nil +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/import.go b/vendor/github.com/authzed/zed/internal/cmd/import.go new file mode 100644 index 0000000..687d42e --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/import.go @@ -0,0 +1,181 @@ +package cmd + +import ( + "bufio" + "context" + "fmt" + "net/url" + "strings" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/authzed/spicedb/pkg/validationfile" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/decode" + "github.com/authzed/zed/internal/grpcutil" +) + +func registerImportCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(importCmd) + importCmd.Flags().Int("batch-size", 1000, "import batch size") + importCmd.Flags().Int("workers", 1, "number of concurrent batching workers") + importCmd.Flags().Bool("schema", true, "import schema") + importCmd.Flags().Bool("relationships", true, "import relationships") + importCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before importing") +} + +var importCmd = &cobra.Command{ + Use: "import <url>", + Short: "Imports schema and relationships from a file or url", + Example: ` + From a gist: + zed import https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6 + + From a playground link: + zed import https://play.authzed.com/s/iksdFvCtvnkR/schema + + From pastebin: + zed import https://pastebin.com/8qU45rVK + + From a devtools instance: + zed import https://localhost:8443/download + + From a local file (with prefix): + zed import file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + From a local file (no prefix): + zed import authzed-x7izWU8_2Gw3.yaml + + Only schema: + zed import --relationships=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + Only relationships: + zed import --schema=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + With schema definition prefix: + zed import --schema-definition-prefix=mypermsystem file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml +`, + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: importCmdFunc, +} + +func importCmdFunc(cmd *cobra.Command, args []string) error { + client, err := client.NewClient(cmd) + if err != nil { + return err + } + u, err := url.Parse(args[0]) + if err != nil { + return err + } + + decoder, err := decode.DecoderForURL(u) + if err != nil { + return err + } + var p validationfile.ValidationFile + if _, _, err := decoder(&p); err != nil { + return err + } + + prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil) + if err != nil { + return err + } + + if cobrautil.MustGetBool(cmd, "schema") { + if err := importSchema(cmd.Context(), client, p.Schema.Schema, prefix); err != nil { + return err + } + } + + if cobrautil.MustGetBool(cmd, "relationships") { + batchSize := cobrautil.MustGetInt(cmd, "batch-size") + workers := cobrautil.MustGetInt(cmd, "workers") + if err := importRelationships(cmd.Context(), client, p.Relationships.RelationshipsString, prefix, batchSize, workers); err != nil { + return err + } + } + + return err +} + +func importSchema(ctx context.Context, client client.Client, schema string, definitionPrefix string) error { + log.Info().Msg("importing schema") + + // Recompile the schema with the specified prefix. + schemaText, err := rewriteSchema(schema, definitionPrefix) + if err != nil { + return err + } + + // Write the recompiled and regenerated schema. + request := &v1.WriteSchemaRequest{Schema: schemaText} + log.Trace().Interface("request", request).Str("schema", schemaText).Msg("writing schema") + + if _, err := client.WriteSchema(ctx, request); err != nil { + return err + } + + return nil +} + +func importRelationships(ctx context.Context, client client.Client, relationships string, definitionPrefix string, batchSize int, workers int) error { + relationshipUpdates := make([]*v1.RelationshipUpdate, 0) + scanner := bufio.NewScanner(strings.NewReader(relationships)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + if strings.HasPrefix(line, "//") { + continue + } + rel, err := tuple.ParseV1Rel(line) + if err != nil { + return fmt.Errorf("failed to parse %s as relationship", line) + } + log.Trace().Str("line", line).Send() + + // Rewrite the prefix on the references, if any. + if len(definitionPrefix) > 0 { + rel.Resource.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Resource.ObjectType) + rel.Subject.Object.ObjectType = fmt.Sprintf("%s/%s", definitionPrefix, rel.Subject.Object.ObjectType) + } + + relationshipUpdates = append(relationshipUpdates, &v1.RelationshipUpdate{ + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + Relationship: rel, + }) + } + if err := scanner.Err(); err != nil { + return err + } + + log.Info(). + Int("batch_size", batchSize). + Int("workers", workers). + Int("count", len(relationshipUpdates)). + Msg("importing relationships") + + err := grpcutil.ConcurrentBatch(ctx, len(relationshipUpdates), batchSize, workers, func(ctx context.Context, no int, start int, end int) error { + request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates[start:end]} + _, err := client.WriteRelationships(ctx, request) + if err != nil { + return err + } + + log.Info(). + Int("batch_no", no). + Int("count", len(relationshipUpdates[start:end])). + Msg("imported relationships") + return nil + }) + return err +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/preview.go b/vendor/github.com/authzed/zed/internal/cmd/preview.go new file mode 100644 index 0000000..279e310 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/preview.go @@ -0,0 +1,120 @@ +package cmd + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/ccoveille/go-safecast" + "github.com/jzelinskie/cobrautil/v2" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "golang.org/x/term" + + newcompiler "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" + newinput "github.com/authzed/spicedb/pkg/composableschemadsl/input" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + + "github.com/authzed/zed/internal/commands" +) + +func registerPreviewCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(previewCmd) + + previewCmd.AddCommand(schemaCmd) + + schemaCmd.AddCommand(schemaCompileCmd) + schemaCompileCmd.Flags().String("out", "", "output filepath; omitting writes to stdout") +} + +var previewCmd = &cobra.Command{ + Use: "preview <subcommand>", + Short: "Experimental commands that have been made available for preview", +} + +var schemaCmd = &cobra.Command{ + Use: "schema <subcommand>", + Short: "Manage schema for a permissions system", +} + +var schemaCompileCmd = &cobra.Command{ + Use: "compile <file>", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + Short: "Compile a schema that uses extended syntax into one that can be written to SpiceDB", + Example: ` + Write to stdout: + zed preview schema compile root.zed + Write to an output file: + zed preview schema compile root.zed --out compiled.zed + `, + ValidArgsFunction: commands.FileExtensionCompletions("zed"), + RunE: schemaCompileCmdFunc, +} + +// Compiles an input schema written in the new composable schema syntax +// and produces it as a fully-realized schema +func schemaCompileCmdFunc(cmd *cobra.Command, args []string) error { + stdOutFd, err := safecast.ToInt(uint(os.Stdout.Fd())) + if err != nil { + return err + } + outputFilepath := cobrautil.MustGetString(cmd, "out") + if outputFilepath == "" && !term.IsTerminal(stdOutFd) { + return fmt.Errorf("must provide stdout or output file path") + } + + inputFilepath := args[0] + inputSourceFolder := filepath.Dir(inputFilepath) + var schemaBytes []byte + schemaBytes, err = os.ReadFile(inputFilepath) + if err != nil { + return fmt.Errorf("failed to read schema file: %w", err) + } + log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file") + + if len(schemaBytes) == 0 { + return errors.New("attempted to compile empty schema") + } + + compiled, err := newcompiler.Compile(newcompiler.InputSchema{ + Source: newinput.Source(inputFilepath), + SchemaString: string(schemaBytes), + }, newcompiler.AllowUnprefixedObjectType(), + newcompiler.SourceFolder(inputSourceFolder)) + if err != nil { + return err + } + + // Attempt to cast one kind of OrderedDefinition to another + oldDefinitions := make([]compiler.SchemaDefinition, 0, len(compiled.OrderedDefinitions)) + for _, definition := range compiled.OrderedDefinitions { + oldDefinition, ok := definition.(compiler.SchemaDefinition) + if !ok { + return fmt.Errorf("could not convert definition to old schemadefinition: %v", oldDefinition) + } + oldDefinitions = append(oldDefinitions, oldDefinition) + } + + // This is where we functionally assert that the two systems are compatible + generated, _, err := generator.GenerateSchema(oldDefinitions) + if err != nil { + return fmt.Errorf("could not generate resulting schema: %w", err) + } + + // Add a newline at the end for hygiene's sake + terminated := generated + "\n" + + if outputFilepath == "" { + // Print to stdout + fmt.Print(terminated) + } else { + err = os.WriteFile(outputFilepath, []byte(terminated), 0o_600) + if err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/restorer.go b/vendor/github.com/authzed/zed/internal/cmd/restorer.go new file mode 100644 index 0000000..468064f --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/restorer.go @@ -0,0 +1,446 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "strings" + "time" + + "github.com/ccoveille/go-safecast" + "github.com/cenkalti/backoff/v4" + "github.com/mattn/go-isatty" + "github.com/rs/zerolog/log" + "github.com/samber/lo" + "github.com/schollz/progressbar/v3" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/pkg/backupformat" +) + +type ConflictStrategy int + +const ( + Fail ConflictStrategy = iota + Skip + Touch + + defaultBackoff = 50 * time.Millisecond + defaultMaxRetries = 10 +) + +var conflictStrategyMapping = map[string]ConflictStrategy{ + "fail": Fail, + "skip": Skip, + "touch": Touch, +} + +// Fallback for datastore implementations on SpiceDB < 1.29.0 not returning proper gRPC codes +// Remove once https://github.com/authzed/spicedb/pull/1688 lands +var ( + txConflictCodes = []string{ + "SQLSTATE 23505", // CockroachDB + "Error 1062 (23000)", // MySQL + } + retryableErrorCodes = []string{ + "retryable error", // CockroachDB, PostgreSQL + "try restarting transaction", "Error 1205", // MySQL + } +) + +type restorer struct { + schema string + decoder *backupformat.Decoder + client client.Client + prefixFilter string + batchSize uint + batchesPerTransaction uint + conflictStrategy ConflictStrategy + disableRetryErrors bool + bar *progressbar.ProgressBar + + // stats + filteredOutRels uint + writtenRels uint + writtenBatches uint + skippedRels uint + skippedBatches uint + duplicateRels uint + duplicateBatches uint + totalRetries uint + requestTimeout time.Duration +} + +func newRestorer(schema string, decoder *backupformat.Decoder, client client.Client, prefixFilter string, batchSize uint, + batchesPerTransaction uint, conflictStrategy ConflictStrategy, disableRetryErrors bool, + requestTimeout time.Duration, +) *restorer { + return &restorer{ + decoder: decoder, + schema: schema, + client: client, + prefixFilter: prefixFilter, + requestTimeout: requestTimeout, + batchSize: batchSize, + batchesPerTransaction: batchesPerTransaction, + conflictStrategy: conflictStrategy, + disableRetryErrors: disableRetryErrors, + bar: console.CreateProgressBar("restoring from backup"), + } +} + +func (r *restorer) restoreFromDecoder(ctx context.Context) error { + relationshipWriteStart := time.Now() + defer func() { + if err := r.bar.Finish(); err != nil { + log.Warn().Err(err).Msg("error finalizing progress bar") + } + }() + + r.bar.Describe("restoring schema from backup") + if _, err := r.client.WriteSchema(ctx, &v1.WriteSchemaRequest{ + Schema: r.schema, + }); err != nil { + return fmt.Errorf("unable to write schema: %w", err) + } + + relationshipWriter, err := r.client.BulkImportRelationships(ctx) + if err != nil { + return fmt.Errorf("error creating writer stream: %w", err) + } + + r.bar.Describe("restoring relationships from backup") + batch := make([]*v1.Relationship, 0, r.batchSize) + batchesToBeCommitted := make([][]*v1.Relationship, 0, r.batchesPerTransaction) + for rel, err := r.decoder.Next(); rel != nil && err == nil; rel, err = r.decoder.Next() { + if err := ctx.Err(); err != nil { + r.bar.Describe("backup restore aborted") + return fmt.Errorf("aborted restore: %w", err) + } + + if !hasRelPrefix(rel, r.prefixFilter) { + r.filteredOutRels++ + continue + } + + batch = append(batch, rel) + + if uint(len(batch))%r.batchSize == 0 { + batchesToBeCommitted = append(batchesToBeCommitted, batch) + err := relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{ + Relationships: batch, + }) + if err != nil { + // It feels non-idiomatic to check for error and perform an operation, but in gRPC, when an element + // sent over the stream fails, we need to call recvAndClose() to get the error. + if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil { + return fmt.Errorf("error committing batches: %w", err) + } + + // after an error + relationshipWriter, err = r.client.BulkImportRelationships(ctx) + if err != nil { + return fmt.Errorf("error creating new writer stream: %w", err) + } + + batchesToBeCommitted = batchesToBeCommitted[:0] + batch = batch[:0] + continue + } + + // The batch just sent is kept in batchesToBeCommitted, which is used for retries. + // Therefore, we cannot reuse the batch. Batches may fail on send, or on commit (CloseAndRecv). + batch = make([]*v1.Relationship, 0, r.batchSize) + + // if we've sent the maximum number of batches per transaction, proceed to commit + if uint(len(batchesToBeCommitted))%r.batchesPerTransaction != 0 { + continue + } + + if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil { + return fmt.Errorf("error committing batches: %w", err) + } + + relationshipWriter, err = r.client.BulkImportRelationships(ctx) + if err != nil { + return fmt.Errorf("error creating new writer stream: %w", err) + } + + batchesToBeCommitted = batchesToBeCommitted[:0] + } + } + + // Write the last batch + if len(batch) > 0 { + // Since we are going to close the stream anyway after the last batch, and given the actual error + // is only returned on CloseAndRecv(), we have to ignore the error here in order to get the actual + // underlying error that caused Send() to fail. It also gives us the opportunity to retry it + // in case it failed. + batchesToBeCommitted = append(batchesToBeCommitted, batch) + _ = relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{Relationships: batch}) + } + + if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil { + return fmt.Errorf("error committing last set of batches: %w", err) + } + + r.bar.Describe("completed import") + if err := r.bar.Finish(); err != nil { + log.Warn().Err(err).Msg("error finalizing progress bar") + } + + totalTime := time.Since(relationshipWriteStart) + log.Info(). + Uint("batches", r.writtenBatches). + Uint("relationships_loaded", r.writtenRels). + Uint("relationships_skipped", r.skippedRels). + Uint("duplicate_relationships", r.duplicateRels). + Uint("relationships_filtered_out", r.filteredOutRels). + Uint("retried_errors", r.totalRetries). + Uint64("perSecond", perSec(uint64(r.writtenRels), totalTime)). + Stringer("duration", totalTime). + Msg("finished restore") + return nil +} + +func (r *restorer) commitStream(ctx context.Context, bulkImportClient v1.ExperimentalService_BulkImportRelationshipsClient, + batchesToBeCommitted [][]*v1.Relationship, +) error { + var numLoaded, expectedLoaded, retries uint + for _, b := range batchesToBeCommitted { + expectedLoaded += uint(len(b)) + } + + resp, err := bulkImportClient.CloseAndRecv() // transaction commit happens here + + // Failure to commit transaction means the stream is closed, so it can't be reused any further + // The retry will be done using WriteRelationships instead of BulkImportRelationships + // This lets us retry with TOUCH semantics in case of failure due to duplicates + retryable := isRetryableError(err) + conflict := isAlreadyExistsError(err) + canceled, cancelErr := isCanceledError(ctx.Err(), err) + unknown := !retryable && !conflict && !canceled && err != nil + + numBatches := uint(len(batchesToBeCommitted)) + + switch { + case canceled: + r.bar.Describe("backup restore aborted") + return cancelErr + case unknown: + r.bar.Describe("failed with unrecoverable error") + return fmt.Errorf("error finalizing write of %d batches: %w", len(batchesToBeCommitted), err) + case retryable && r.disableRetryErrors: + return err + case conflict && r.conflictStrategy == Skip: + r.skippedRels += expectedLoaded + r.skippedBatches += numBatches + r.duplicateBatches += numBatches + r.duplicateRels += expectedLoaded + r.bar.Describe("skipping conflicting batch") + case conflict && r.conflictStrategy == Touch: + r.bar.Describe("touching conflicting batch") + r.duplicateRels += expectedLoaded + r.duplicateBatches += numBatches + r.totalRetries++ + numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted) + if err != nil { + return fmt.Errorf("failed to write retried batch: %w", err) + } + + retries++ // account for the initial attempt + r.writtenBatches += numBatches + r.writtenRels += numLoaded + case conflict && r.conflictStrategy == Fail: + r.bar.Describe("conflict detected, aborting restore") + return fmt.Errorf("duplicate relationships found") + case retryable: + r.bar.Describe("retrying after error") + r.totalRetries++ + numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted) + if err != nil { + return fmt.Errorf("failed to write retried batch: %w", err) + } + + retries++ // account for the initial attempt + r.writtenBatches += numBatches + r.writtenRels += numLoaded + default: + r.bar.Describe("restoring relationships from backup") + r.writtenBatches += numBatches + } + + // it was a successful transaction commit without duplicates + if resp != nil { + numLoaded, err := safecast.ToUint(resp.NumLoaded) + if err != nil { + return spiceerrors.MustBugf("could not cast numLoaded to uint") + } + r.writtenRels += numLoaded + if uint64(expectedLoaded) != resp.NumLoaded { + log.Warn().Uint64("loaded", resp.NumLoaded).Uint("expected", expectedLoaded).Msg("unexpected number of relationships loaded") + } + } + + writtenAndSkipped, err := safecast.ToInt64(r.writtenRels + r.skippedRels) + if err != nil { + return fmt.Errorf("too many written and skipped rels for an int64") + } + + if err := r.bar.Set64(writtenAndSkipped); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + + if !isatty.IsTerminal(os.Stderr.Fd()) { + log.Trace(). + Uint("batches_written", r.writtenBatches). + Uint("relationships_written", r.writtenRels). + Uint("duplicate_batches", r.duplicateBatches). + Uint("duplicate_relationships", r.duplicateRels). + Uint("skipped_batches", r.skippedBatches). + Uint("skipped_relationships", r.skippedRels). + Uint("retries", retries). + Msg("restore progress") + } + + return nil +} + +// writeBatchesWithRetry writes a set of batches using touch semantics and without transactional guarantees - +// each batch will be committed independently. If a batch fails, it will be retried up to 10 times with a backoff. +func (r *restorer) writeBatchesWithRetry(ctx context.Context, batches [][]*v1.Relationship) (uint, uint, error) { + backoffInterval := backoff.NewExponentialBackOff() + backoffInterval.InitialInterval = defaultBackoff + backoffInterval.MaxInterval = 2 * time.Second + backoffInterval.MaxElapsedTime = 0 + backoffInterval.Reset() + + var currentRetries, totalRetries, loadedRels uint + for _, batch := range batches { + updates := lo.Map[*v1.Relationship, *v1.RelationshipUpdate](batch, func(item *v1.Relationship, _ int) *v1.RelationshipUpdate { + return &v1.RelationshipUpdate{ + Relationship: item, + Operation: v1.RelationshipUpdate_OPERATION_TOUCH, + } + }) + + for { + cancelCtx, cancel := context.WithTimeout(ctx, r.requestTimeout) + _, err := r.client.WriteRelationships(cancelCtx, &v1.WriteRelationshipsRequest{Updates: updates}) + cancel() + + if isRetryableError(err) && currentRetries < defaultMaxRetries { + // throttle the writes so we don't overwhelm the server + bo := backoffInterval.NextBackOff() + r.bar.Describe(fmt.Sprintf("retrying write with backoff %s after error (attempt %d/%d)", bo, + currentRetries+1, defaultMaxRetries)) + time.Sleep(bo) + currentRetries++ + r.totalRetries++ + totalRetries++ + continue + } + if err != nil { + return 0, 0, err + } + + currentRetries = 0 + backoffInterval.Reset() + loadedRels += uint(len(batch)) + break + } + } + + return loadedRels, totalRetries, nil +} + +func isAlreadyExistsError(err error) bool { + if err == nil { + return false + } + + if isGRPCCode(err, codes.AlreadyExists) { + return true + } + + return isContainsErrorString(err, txConflictCodes...) +} + +func isRetryableError(err error) bool { + if err == nil { + return false + } + + if isGRPCCode(err, codes.Unavailable, codes.DeadlineExceeded) { + return true + } + + if isContainsErrorString(err, retryableErrorCodes...) { + return true + } + + return errors.Is(err, context.DeadlineExceeded) +} + +func isCanceledError(errs ...error) (bool, error) { + for _, err := range errs { + if err == nil { + continue + } + + if errors.Is(err, context.Canceled) { + return true, err + } + + if isGRPCCode(err, codes.Canceled) { + return true, err + } + } + + return false, nil +} + +func isContainsErrorString(err error, errStrings ...string) bool { + if err == nil { + return false + } + + for _, errString := range errStrings { + if strings.Contains(err.Error(), errString) { + return true + } + } + + return false +} + +func isGRPCCode(err error, codes ...codes.Code) bool { + if err == nil { + return false + } + + if s, ok := status.FromError(err); ok { + for _, code := range codes { + if s.Code() == code { + return true + } + } + } + + return false +} + +func perSec(i uint64, d time.Duration) uint64 { + secs := uint64(d.Seconds()) + if secs == 0 { + return i + } + return i / secs +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/schema.go b/vendor/github.com/authzed/zed/internal/cmd/schema.go new file mode 100644 index 0000000..bbe52f3 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/schema.go @@ -0,0 +1,319 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/ccoveille/go-safecast" + "github.com/jzelinskie/cobrautil/v2" + "github.com/jzelinskie/stringz" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "golang.org/x/term" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/caveats/types" + "github.com/authzed/spicedb/pkg/diff" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/schemadsl/input" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" +) + +func registerAdditionalSchemaCmds(schemaCmd *cobra.Command) { + schemaCmd.AddCommand(schemaCopyCmd) + schemaCopyCmd.Flags().Bool("json", false, "output as JSON") + schemaCopyCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing") + + schemaCmd.AddCommand(schemaWriteCmd) + schemaWriteCmd.Flags().Bool("json", false, "output as JSON") + schemaWriteCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before writing") + + schemaCmd.AddCommand(schemaDiffCmd) +} + +var schemaWriteCmd = &cobra.Command{ + Use: "write <file?>", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + Short: "Write a schema file (.zed or stdin) to the current permissions system", + ValidArgsFunction: commands.FileExtensionCompletions("zed"), + RunE: schemaWriteCmdFunc, +} + +var schemaCopyCmd = &cobra.Command{ + Use: "copy <src context> <dest context>", + Short: "Copy a schema from one context into another", + Args: commands.ValidationWrapper(cobra.ExactArgs(2)), + ValidArgsFunction: ContextGet, + RunE: schemaCopyCmdFunc, +} + +var schemaDiffCmd = &cobra.Command{ + Use: "diff <before file> <after file>", + Short: "Diff two schema files", + Args: commands.ValidationWrapper(cobra.ExactArgs(2)), + RunE: schemaDiffCmdFunc, +} + +func schemaDiffCmdFunc(_ *cobra.Command, args []string) error { + beforeBytes, err := os.ReadFile(args[0]) + if err != nil { + return fmt.Errorf("failed to read before schema file: %w", err) + } + + afterBytes, err := os.ReadFile(args[1]) + if err != nil { + return fmt.Errorf("failed to read after schema file: %w", err) + } + + before, err := compiler.Compile( + compiler.InputSchema{Source: input.Source(args[0]), SchemaString: string(beforeBytes)}, + compiler.AllowUnprefixedObjectType(), + ) + if err != nil { + return err + } + + after, err := compiler.Compile( + compiler.InputSchema{Source: input.Source(args[1]), SchemaString: string(afterBytes)}, + compiler.AllowUnprefixedObjectType(), + ) + if err != nil { + return err + } + + dbefore := diff.NewDiffableSchemaFromCompiledSchema(before) + dafter := diff.NewDiffableSchemaFromCompiledSchema(after) + + schemaDiff, err := diff.DiffSchemas(dbefore, dafter, types.Default.TypeSet) + if err != nil { + return err + } + + for _, ns := range schemaDiff.AddedNamespaces { + console.Printf("Added definition: %s\n", ns) + } + + for _, ns := range schemaDiff.RemovedNamespaces { + console.Printf("Removed definition: %s\n", ns) + } + + for nsName, ns := range schemaDiff.ChangedNamespaces { + console.Printf("Changed definition: %s\n", nsName) + for _, delta := range ns.Deltas() { + console.Printf("\t %s: %s\n", delta.Type, delta.RelationName) + } + } + + for _, caveat := range schemaDiff.AddedCaveats { + console.Printf("Added caveat: %s\n", caveat) + } + + for _, caveat := range schemaDiff.RemovedCaveats { + console.Printf("Removed caveat: %s\n", caveat) + } + + return nil +} + +func schemaCopyCmdFunc(cmd *cobra.Command, args []string) error { + _, secretStore := client.DefaultStorage() + srcClient, err := client.NewClientForContext(cmd, args[0], secretStore) + if err != nil { + return err + } + + destClient, err := client.NewClientForContext(cmd, args[1], secretStore) + if err != nil { + return err + } + + readRequest := &v1.ReadSchemaRequest{} + log.Trace().Interface("request", readRequest).Msg("requesting schema read") + + readResp, err := srcClient.ReadSchema(cmd.Context(), readRequest) + if err != nil { + log.Fatal().Err(err).Msg("failed to read schema") + } + log.Trace().Interface("response", readResp).Msg("read schema") + + prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), nil, &readResp.SchemaText) + if err != nil { + return err + } + + schemaText, err := rewriteSchema(readResp.SchemaText, prefix) + if err != nil { + return err + } + + writeRequest := &v1.WriteSchemaRequest{Schema: schemaText} + log.Trace().Interface("request", writeRequest).Msg("writing schema") + + resp, err := destClient.WriteSchema(cmd.Context(), writeRequest) + if err != nil { + log.Fatal().Err(err).Msg("failed to write schema") + } + log.Trace().Interface("response", resp).Msg("wrote schema") + + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := commands.PrettyProto(resp) + if err != nil { + log.Fatal().Err(err).Msg("failed to convert schema to JSON") + } + + console.Println(string(prettyProto)) + return nil + } + + return nil +} + +func schemaWriteCmdFunc(cmd *cobra.Command, args []string) error { + intFd, err := safecast.ToInt(uint(os.Stdout.Fd())) + if err != nil { + return err + } + if len(args) == 0 && term.IsTerminal(intFd) { + return fmt.Errorf("must provide file path or contents via stdin") + } + + client, err := client.NewClient(cmd) + if err != nil { + return err + } + var schemaBytes []byte + switch len(args) { + case 1: + schemaBytes, err = os.ReadFile(args[0]) + if err != nil { + return fmt.Errorf("failed to read schema file: %w", err) + } + log.Trace().Str("schema", string(schemaBytes)).Str("file", args[0]).Msg("read schema from file") + case 0: + schemaBytes, err = io.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("failed to read schema file: %w", err) + } + log.Trace().Str("schema", string(schemaBytes)).Msg("read schema from stdin") + default: + panic("schemaWriteCmdFunc called with incorrect number of arguments") + } + + if len(schemaBytes) == 0 { + return errors.New("attempted to write empty schema") + } + + prefix, err := determinePrefixForSchema(cmd.Context(), cobrautil.MustGetString(cmd, "schema-definition-prefix"), client, nil) + if err != nil { + return err + } + + schemaText, err := rewriteSchema(string(schemaBytes), prefix) + if err != nil { + return err + } + + request := &v1.WriteSchemaRequest{Schema: schemaText} + log.Trace().Interface("request", request).Msg("writing schema") + + resp, err := client.WriteSchema(cmd.Context(), request) + if err != nil { + log.Fatal().Err(err).Msg("failed to write schema") + } + log.Trace().Interface("response", resp).Msg("wrote schema") + + if cobrautil.MustGetBool(cmd, "json") { + prettyProto, err := commands.PrettyProto(resp) + if err != nil { + log.Fatal().Err(err).Msg("failed to convert schema to JSON") + } + + console.Println(string(prettyProto)) + return nil + } + + return nil +} + +// rewriteSchema rewrites the given existing schema to include the specified prefix on all definitions. +func rewriteSchema(existingSchemaText string, definitionPrefix string) (string, error) { + if definitionPrefix == "" { + return existingSchemaText, nil + } + + compiled, err := compiler.Compile( + compiler.InputSchema{Source: input.Source("schema"), SchemaString: existingSchemaText}, + compiler.ObjectTypePrefix(definitionPrefix), + compiler.SkipValidation(), + ) + if err != nil { + return "", err + } + + generated, _, err := generator.GenerateSchema(compiled.OrderedDefinitions) + return generated, err +} + +// determinePrefixForSchema determines the prefix to be applied to a schema that will be written. +// +// If specifiedPrefix is non-empty, it is returned immediately. +// If existingSchema is non-nil, it is parsed for the prefix. +// Otherwise, the client is used to retrieve the existing schema (if any), and the prefix is retrieved from there. +func determinePrefixForSchema(ctx context.Context, specifiedPrefix string, client client.Client, existingSchema *string) (string, error) { + if specifiedPrefix != "" { + return specifiedPrefix, nil + } + + var schemaText string + if existingSchema != nil { + schemaText = *existingSchema + } else { + readSchemaText, err := commands.ReadSchema(ctx, client) + if err != nil { + return "", nil + } + schemaText = readSchemaText + } + + // If there is no schema found, return the empty string. + if schemaText == "" { + return "", nil + } + + // Otherwise, compile the schema and grab the prefixes of the namespaces defined. + found, err := compiler.Compile( + compiler.InputSchema{Source: input.Source("schema"), SchemaString: schemaText}, + compiler.AllowUnprefixedObjectType(), + compiler.SkipValidation(), + ) + if err != nil { + return "", err + } + + foundPrefixes := make([]string, 0, len(found.OrderedDefinitions)) + for _, def := range found.OrderedDefinitions { + if strings.Contains(def.GetName(), "/") { + parts := strings.Split(def.GetName(), "/") + foundPrefixes = append(foundPrefixes, parts[0]) + } else { + foundPrefixes = append(foundPrefixes, "") + } + } + + prefixes := stringz.Dedup(foundPrefixes) + if len(prefixes) == 1 { + prefix := prefixes[0] + log.Debug().Str("prefix", prefix).Msg("found schema definition prefix") + return prefix, nil + } + + return "", nil +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/validate.go b/vendor/github.com/authzed/zed/internal/cmd/validate.go new file mode 100644 index 0000000..ffca1c6 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/validate.go @@ -0,0 +1,414 @@ +package cmd + +import ( + "errors" + "fmt" + "net/url" + "os" + "strings" + + "github.com/ccoveille/go-safecast" + "github.com/charmbracelet/lipgloss" + "github.com/jzelinskie/cobrautil/v2" + "github.com/muesli/termenv" + "github.com/spf13/cobra" + + composable "github.com/authzed/spicedb/pkg/composableschemadsl/compiler" + "github.com/authzed/spicedb/pkg/development" + core "github.com/authzed/spicedb/pkg/proto/core/v1" + devinterface "github.com/authzed/spicedb/pkg/proto/developer/v1" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/validationfile" + + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/internal/decode" + "github.com/authzed/zed/internal/printers" +) + +var ( + // NOTE: these need to be set *after* the renderer has been set, otherwise + // the forceColor setting can't work, hence the thunking. + success = func() string { + return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("10")).Render("Success!") + } + complete = func() string { return lipgloss.NewStyle().Bold(true).Render("complete") } + errorPrefix = func() string { return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("9")).Render("error: ") } + warningPrefix = func() string { + return lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("3")).Render("warning: ") + } + errorMessageStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true).Width(80) } + linePrefixStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("12")) } + highlightedSourceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) } + highlightedLineStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("9")) } + codeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("8")) } + highlightedCodeStyle = func() lipgloss.Style { return lipgloss.NewStyle().Foreground(lipgloss.Color("15")) } + traceStyle = func() lipgloss.Style { return lipgloss.NewStyle().Bold(true) } +) + +func registerValidateCmd(cmd *cobra.Command) { + validateCmd.Flags().Bool("force-color", false, "force color code output even in non-tty environments") + validateCmd.Flags().String("schema-type", "", "force validation according to specific schema syntax (\"\", \"composable\", \"standard\")") + cmd.AddCommand(validateCmd) +} + +var validateCmd = &cobra.Command{ + Use: "validate <validation_file_or_schema_file>", + Short: "Validates the given validation file (.yaml, .zaml) or schema file (.zed)", + Example: ` + From a local file (with prefix): + zed validate file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml + + From a local file (no prefix): + zed validate authzed-x7izWU8_2Gw3.yaml + + From a gist: + zed validate https://gist.github.com/ecordell/8e3b613a677e3c844742cf24421c08b6 + + From a playground link: + zed validate https://play.authzed.com/s/iksdFvCtvnkR/schema + + From pastebin: + zed validate https://pastebin.com/8qU45rVK + + From a devtools instance: + zed validate https://localhost:8443/download`, + Args: commands.ValidationWrapper(cobra.MinimumNArgs(1)), + ValidArgsFunction: commands.FileExtensionCompletions("zed", "yaml", "zaml"), + PreRunE: validatePreRunE, + RunE: func(cmd *cobra.Command, filenames []string) error { + result, shouldExit, err := validateCmdFunc(cmd, filenames) + if err != nil { + return err + } + console.Print(result) + if shouldExit { + os.Exit(1) + } + return nil + }, + + // A schema that causes the parser/compiler to error will halt execution + // of this command with an error. In that case, we want to just display the error, + // rather than showing usage for this command. + SilenceUsage: true, +} + +var validSchemaTypes = []string{"", "standard", "composable"} + +func validatePreRunE(cmd *cobra.Command, _ []string) error { + // Override lipgloss's autodetection of whether it's in a terminal environment + // and display things in color anyway. This can be nice in CI environments that + // support it. + setForceColor := cobrautil.MustGetBool(cmd, "force-color") + if setForceColor { + lipgloss.SetColorProfile(termenv.ANSI256) + } + + schemaType := cobrautil.MustGetString(cmd, "schema-type") + schemaTypeValid := false + for _, validType := range validSchemaTypes { + if schemaType == validType { + schemaTypeValid = true + } + } + if !schemaTypeValid { + return fmt.Errorf("schema-type must be one of \"\", \"standard\", \"composable\". received: %s", schemaType) + } + + return nil +} + +// validateCmdFunc returns the string to print to the user, whether to return a non-zero status code, and any errors. +func validateCmdFunc(cmd *cobra.Command, filenames []string) (string, bool, error) { + // Initialize variables for multiple files + var ( + totalFiles = len(filenames) + successfullyValidatedFiles = 0 + shouldExit = false + toPrint = &strings.Builder{} + schemaType = cobrautil.MustGetString(cmd, "schema-type") + ) + + for _, filename := range filenames { + // If we're running over multiple files, print the filename for context/debugging purposes + if totalFiles > 1 { + toPrint.WriteString(filename + "\n") + } + + u, err := url.Parse(filename) + if err != nil { + return "", false, err + } + + decoder, err := decode.DecoderForURL(u) + if err != nil { + return "", false, err + } + + var parsed validationfile.ValidationFile + validateContents, isOnlySchema, err := decoder(&parsed) + standardErrors, composableErrs, otherErrs := classifyErrors(err) + + switch schemaType { + case "standard": + if standardErrors != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(standardErrors, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, standardErrors + } + case "composable": + if composableErrs != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(composableErrs, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, composableErrs + } + default: + // By default, validate will attempt to validate a schema first according to composable schema rules, + // then standard schema rules, + // and if both fail it will show the errors from composable schema. + if composableErrs != nil && standardErrors != nil { + var errWithSource spiceerrors.WithSourceError + if errors.As(composableErrs, &errWithSource) { + outputErrorWithSource(toPrint, validateContents, errWithSource) + shouldExit = true + } + return "", shouldExit, composableErrs + } + } + + if otherErrs != nil { + return "", false, otherErrs + } + + tuples := make([]*core.RelationTuple, 0) + totalAssertions := 0 + totalRelationsValidated := 0 + + for _, rel := range parsed.Relationships.Relationships { + tuples = append(tuples, rel.ToCoreTuple()) + } + + // Create the development context for each run + ctx := cmd.Context() + devCtx, devErrs, err := development.NewDevContext(ctx, &devinterface.RequestContext{ + Schema: parsed.Schema.Schema, + Relationships: tuples, + }) + if err != nil { + return "", false, err + } + if devErrs != nil { + schemaOffset := parsed.Schema.SourcePosition.LineNumber + if isOnlySchema { + schemaOffset = 0 + } + + // Output errors + outputDeveloperErrorsWithLineOffset(toPrint, validateContents, devErrs.InputErrors, schemaOffset) + return toPrint.String(), true, nil + } + // Run assertions + adevErrs, aerr := development.RunAllAssertions(devCtx, &parsed.Assertions) + if aerr != nil { + return "", false, aerr + } + if adevErrs != nil { + outputDeveloperErrors(toPrint, validateContents, adevErrs) + return toPrint.String(), true, nil + } + successfullyValidatedFiles++ + + // Run expected relations for file + _, erDevErrs, rerr := development.RunValidation(devCtx, &parsed.ExpectedRelations) + if rerr != nil { + return "", false, rerr + } + if erDevErrs != nil { + outputDeveloperErrors(toPrint, validateContents, erDevErrs) + return toPrint.String(), true, nil + } + // Print out any warnings for file + warnings, err := development.GetWarnings(ctx, devCtx) + if err != nil { + return "", false, err + } + + if len(warnings) > 0 { + for _, warning := range warnings { + fmt.Fprintf(toPrint, "%s%s\n", warningPrefix(), warning.Message) + outputForLine(toPrint, validateContents, uint64(warning.Line), warning.SourceCode, uint64(warning.Column)) // warning.LineNumber is 1-indexed + toPrint.WriteString("\n") + } + + toPrint.WriteString(complete()) + } else { + toPrint.WriteString(success()) + } + totalAssertions += len(parsed.Assertions.AssertTrue) + len(parsed.Assertions.AssertFalse) + totalRelationsValidated += len(parsed.ExpectedRelations.ValidationMap) + + fmt.Fprintf(toPrint, " - %d relationships loaded, %d assertions run, %d expected relations validated\n", + len(tuples), + totalAssertions, + totalRelationsValidated) + } + + if totalFiles > 1 { + fmt.Fprintf(toPrint, "total files: %d, successfully validated files: %d\n", totalFiles, successfullyValidatedFiles) + } + + return toPrint.String(), shouldExit, nil +} + +func outputErrorWithSource(sb *strings.Builder, validateContents []byte, errWithSource spiceerrors.WithSourceError) { + fmt.Fprintf(sb, "%s%s\n", errorPrefix(), errorMessageStyle().Render(errWithSource.Error())) + outputForLine(sb, validateContents, errWithSource.LineNumber, errWithSource.SourceCodeString, 0) // errWithSource.LineNumber is 1-indexed +} + +func outputForLine(sb *strings.Builder, validateContents []byte, oneIndexedLineNumber uint64, sourceCodeString string, oneIndexedColumnPosition uint64) { + lines := strings.Split(string(validateContents), "\n") + // These should be fine to be zero if the cast fails. + intLineNumber, _ := safecast.ToInt(oneIndexedLineNumber) + intColumnPosition, _ := safecast.ToInt(oneIndexedColumnPosition) + errorLineNumber := intLineNumber - 1 + for i := errorLineNumber - 3; i < errorLineNumber+3; i++ { + if i == errorLineNumber { + renderLine(sb, lines, i, sourceCodeString, errorLineNumber, intColumnPosition-1) + } else { + renderLine(sb, lines, i, "", errorLineNumber, -1) + } + } +} + +func outputDeveloperErrors(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError) { + outputDeveloperErrorsWithLineOffset(sb, validateContents, devErrors, 0) +} + +func outputDeveloperErrorsWithLineOffset(sb *strings.Builder, validateContents []byte, devErrors []*devinterface.DeveloperError, lineOffset int) { + lines := strings.Split(string(validateContents), "\n") + + for _, devErr := range devErrors { + outputDeveloperError(sb, devErr, lines, lineOffset) + } +} + +func outputDeveloperError(sb *strings.Builder, devError *devinterface.DeveloperError, lines []string, lineOffset int) { + fmt.Fprintf(sb, "%s %s\n", errorPrefix(), errorMessageStyle().Render(devError.Message)) + errorLineNumber := int(devError.Line) - 1 + lineOffset // devError.Line is 1-indexed + for i := errorLineNumber - 3; i < errorLineNumber+3; i++ { + if i == errorLineNumber { + renderLine(sb, lines, i, devError.Context, errorLineNumber, -1) + } else { + renderLine(sb, lines, i, "", errorLineNumber, -1) + } + } + + if devError.CheckResolvedDebugInformation != nil && devError.CheckResolvedDebugInformation.Check != nil { + fmt.Fprintf(sb, "\n %s\n", traceStyle().Render("Explanation:")) + tp := printers.NewTreePrinter() + printers.DisplayCheckTrace(devError.CheckResolvedDebugInformation.Check, tp, true) + sb.WriteString(tp.Indented()) + } + + sb.WriteString("\n\n") +} + +func renderLine(sb *strings.Builder, lines []string, index int, highlight string, highlightLineIndex int, highlightStartingColumnIndex int) { + if index < 0 || index >= len(lines) { + return + } + + lineNumberLength := len(fmt.Sprintf("%d", len(lines))) + lineContents := strings.ReplaceAll(lines[index], "\t", " ") + lineDelimiter := "|" + + highlightLength := max(0, len(highlight)-1) + highlightColumnIndex := -1 + + // If the highlight string was provided, then we need to find the index of the highlight + // string in the line contents to determine where to place the caret. + if len(highlight) > 0 { + offset := 0 + for { + foundRelativeIndex := strings.Index(lineContents[offset:], highlight) + foundIndex := foundRelativeIndex + offset + if foundIndex >= highlightStartingColumnIndex { + highlightColumnIndex = foundIndex + break + } + + offset = foundIndex + 1 + if foundRelativeIndex < 0 || foundIndex > len(lineContents) { + break + } + } + } else if highlightStartingColumnIndex >= 0 { + // Otherwise, just show a caret at the specified starting column, if any. + highlightColumnIndex = highlightStartingColumnIndex + } + + lineNumberStr := fmt.Sprintf("%d", index+1) + noNumberSpaces := strings.Repeat(" ", lineNumberLength) + + lineNumberStyle := linePrefixStyle + lineContentsStyle := codeStyle + if index == highlightLineIndex { + lineNumberStyle = highlightedLineStyle + lineContentsStyle = highlightedCodeStyle + lineDelimiter = ">" + } + + lineNumberSpacer := strings.Repeat(" ", lineNumberLength-len(lineNumberStr)) + + if highlightColumnIndex < 0 { + fmt.Fprintf(sb, " %s%s %s %s\n", lineNumberSpacer, lineNumberStyle().Render(lineNumberStr), lineDelimiter, lineContentsStyle().Render(lineContents)) + } else { + fmt.Fprintf(sb, " %s%s %s %s%s%s\n", + lineNumberSpacer, + lineNumberStyle().Render(lineNumberStr), + lineDelimiter, + lineContentsStyle().Render(lineContents[0:highlightColumnIndex]), + highlightedSourceStyle().Render(highlight), + lineContentsStyle().Render(lineContents[highlightColumnIndex+len(highlight):])) + + fmt.Fprintf(sb, " %s %s %s%s%s\n", + noNumberSpaces, + lineDelimiter, + strings.Repeat(" ", highlightColumnIndex), + highlightedSourceStyle().Render("^"), + highlightedSourceStyle().Render(strings.Repeat("~", highlightLength))) + } +} + +// classifyErrors returns errors from the composable DSL, the standard DSL, and any other parsing errors. +func classifyErrors(err error) (error, error, error) { + if err == nil { + return nil, nil, nil + } + var standardErr compiler.BaseCompilerError + var composableErr composable.BaseCompilerError + var retStandard, retComposable, allOthers error + + ok := errors.As(err, &standardErr) + if ok { + retStandard = standardErr + } + ok = errors.As(err, &composableErr) + if ok { + retComposable = composableErr + } + + if retStandard == nil && retComposable == nil { + allOthers = err + } + + return retStandard, retComposable, allOthers +} diff --git a/vendor/github.com/authzed/zed/internal/cmd/version.go b/vendor/github.com/authzed/zed/internal/cmd/version.go new file mode 100644 index 0000000..b87ec66 --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/version.go @@ -0,0 +1,64 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/gookit/color" + "github.com/jzelinskie/cobrautil/v2" + "github.com/mattn/go-isatty" + "github.com/spf13/cobra" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/authzed/authzed-go/pkg/responsemeta" + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/console" +) + +func versionCmdFunc(cmd *cobra.Command, _ []string) error { + if !isatty.IsTerminal(os.Stdout.Fd()) { + color.Disable() + } + + includeRemoteVersion := cobrautil.MustGetBool(cmd, "include-remote-version") + hasContext := false + if includeRemoteVersion { + configStore, secretStore := client.DefaultStorage() + _, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + hasContext = err == nil + } + + if hasContext && includeRemoteVersion { + green := color.FgGreen.Render + fmt.Print(green("client: ")) + } + + console.Println(cobrautil.UsageVersion("zed", cobrautil.MustGetBool(cmd, "include-deps"))) + + if hasContext && includeRemoteVersion { + client, err := client.NewClient(cmd) + if err != nil { + return err + } + + // NOTE: we ignore the error here, as it may be due to a schema not existing, or + // the client being unable to connect, etc. We just treat all such cases as an unknown + // version. + var headerMD metadata.MD + _, _ = client.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}, grpc.Header(&headerMD)) + version := headerMD.Get(string(responsemeta.ServerVersion)) + + blue := color.FgLightBlue.Render + fmt.Print(blue("service: ")) + if len(version) == 1 { + console.Println(version[0]) + } else { + console.Println("(unknown)") + } + } + + return nil +} |
