diff options
Diffstat (limited to 'vendor/github.com/authzed/zed/internal/cmd/backup.go')
| -rw-r--r-- | vendor/github.com/authzed/zed/internal/cmd/backup.go | 868 |
1 files changed, 868 insertions, 0 deletions
diff --git a/vendor/github.com/authzed/zed/internal/cmd/backup.go b/vendor/github.com/authzed/zed/internal/cmd/backup.go new file mode 100644 index 0000000..57fa83c --- /dev/null +++ b/vendor/github.com/authzed/zed/internal/cmd/backup.go @@ -0,0 +1,868 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/jzelinskie/cobrautil/v2" + "github.com/mattn/go-isatty" + "github.com/rodaine/table" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + schemapkg "github.com/authzed/spicedb/pkg/schema" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/generator" + "github.com/authzed/spicedb/pkg/tuple" + + "github.com/authzed/zed/internal/client" + "github.com/authzed/zed/internal/commands" + "github.com/authzed/zed/internal/console" + "github.com/authzed/zed/pkg/backupformat" +) + +const ( + returnIfExists = true + doNotReturnIfExists = false +) + +// cobraRunEFunc is the signature of a cobra.Command.RunE function. +type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error) + +// withErrorHandling is a wrapper that centralizes error handling, instead of having to scatter it around the command logic. +func withErrorHandling(f cobraRunEFunc) cobraRunEFunc { + return func(cmd *cobra.Command, args []string) (err error) { + return addSizeErrInfo(f(cmd, args)) + } +} + +var ( + backupCmd = &cobra.Command{ + Use: "backup <filename>", + Short: "Create, restore, and inspect permissions system backups", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + // Create used to be on the root, so add it here for back-compat. + RunE: withErrorHandling(backupCreateCmdFunc), + } + + backupCreateCmd = &cobra.Command{ + Use: "create <filename>", + Short: "Backup a permission system to a file", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + RunE: withErrorHandling(backupCreateCmdFunc), + } + + backupRestoreCmd = &cobra.Command{ + Use: "restore <filename>", + Short: "Restore a permission system from a file", + Args: commands.ValidationWrapper(commands.StdinOrExactArgs(1)), + RunE: backupRestoreCmdFunc, + } + + backupParseSchemaCmd = &cobra.Command{ + Use: "parse-schema <filename>", + Short: "Extract the schema from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseSchemaCmdFunc(cmd, os.Stdout, args) + }, + } + + backupParseRevisionCmd = &cobra.Command{ + Use: "parse-revision <filename>", + Short: "Extract the revision from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseRevisionCmdFunc(cmd, os.Stdout, args) + }, + } + + backupParseRelsCmd = &cobra.Command{ + Use: "parse-relationships <filename>", + Short: "Extract the relationships from a backup file", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupParseRelsCmdFunc(cmd, os.Stdout, args) + }, + } + + backupRedactCmd = &cobra.Command{ + Use: "redact <filename>", + Short: "Redact a backup file to remove sensitive information", + Args: commands.ValidationWrapper(cobra.ExactArgs(1)), + RunE: func(cmd *cobra.Command, args []string) error { + return backupRedactCmdFunc(cmd, args) + }, + } +) + +func registerBackupCmd(rootCmd *cobra.Command) { + rootCmd.AddCommand(backupCmd) + registerBackupCreateFlags(backupCmd) + + backupCmd.AddCommand(backupCreateCmd) + registerBackupCreateFlags(backupCreateCmd) + + backupCmd.AddCommand(backupRestoreCmd) + registerBackupRestoreFlags(backupRestoreCmd) + + backupCmd.AddCommand(backupRedactCmd) + backupRedactCmd.Flags().Bool("redact-definitions", true, "redact definitions") + backupRedactCmd.Flags().Bool("redact-relations", true, "redact relations") + backupRedactCmd.Flags().Bool("redact-object-ids", true, "redact object IDs") + backupRedactCmd.Flags().Bool("print-redacted-object-ids", false, "prints the redacted object IDs") + + // Restore used to be on the root, so add it there too, but hidden. + restoreCmd := &cobra.Command{ + Use: "restore <filename>", + Short: "Restore a permission system from a backup file", + Args: commands.ValidationWrapper(cobra.MaximumNArgs(1)), + RunE: backupRestoreCmdFunc, + Hidden: true, + } + rootCmd.AddCommand(restoreCmd) + registerBackupRestoreFlags(restoreCmd) + + backupCmd.AddCommand(backupParseSchemaCmd) + backupParseSchemaCmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + backupParseSchemaCmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + + backupCmd.AddCommand(backupParseRevisionCmd) + backupCmd.AddCommand(backupParseRelsCmd) + backupParseRelsCmd.Flags().String("prefix-filter", "", "Include only relationships with a given prefix") +} + +func registerBackupRestoreFlags(cmd *cobra.Command) { + cmd.Flags().Uint("batch-size", 1_000, "restore relationship write batch size") + cmd.Flags().Uint("batches-per-transaction", 10, "number of batches per transaction") + cmd.Flags().String("conflict-strategy", "fail", "strategy used when a conflicting relationship is found. Possible values: fail, skip, touch") + cmd.Flags().Bool("disable-retries", false, "retries when an errors is determined to be retryable (e.g. serialization errors)") + cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + cmd.Flags().Duration("request-timeout", 30*time.Second, "timeout for each request performed during restore") +} + +func registerBackupCreateFlags(cmd *cobra.Command) { + cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix") + cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax") + cmd.Flags().Uint32("page-limit", 0, "defines the number of relationships to be read by requested page during backup") +} + +func createBackupFile(filename string, returnIfExists bool) (*os.File, bool, error) { + if filename == "-" { + log.Trace().Str("filename", "- (stdout)").Send() + return os.Stdout, false, nil + } + + log.Trace().Str("filename", filename).Send() + + if _, err := os.Stat(filename); err == nil { + if !returnIfExists { + return nil, false, fmt.Errorf("backup file already exists: %s", filename) + } + + f, err := os.OpenFile(filename, os.O_RDWR|os.O_APPEND, 0o644) + if err != nil { + return nil, false, fmt.Errorf("unable to open existing backup file: %w", err) + } + + return f, true, nil + } + + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o644) + if err != nil { + return nil, false, fmt.Errorf("unable to create backup file: %w", err) + } + + return f, false, nil +} + +var ( + missingAllowedTypes = regexp.MustCompile(`(\s*)(relation)(.+)(/\* missing allowed types \*/)(.*)`) + shortRelations = regexp.MustCompile(`(\s*)relation [a-z][a-z0-9_]:(.+)`) +) + +func partialPrefixMatch(name, prefix string) bool { + return strings.HasPrefix(name, prefix+"/") +} + +func filterSchemaDefs(schema, prefix string) (filteredSchema string, err error) { + if schema == "" || prefix == "" { + return schema, nil + } + + compiledSchema, err := compiler.Compile( + compiler.InputSchema{Source: "schema", SchemaString: schema}, + compiler.AllowUnprefixedObjectType(), + compiler.SkipValidation(), + ) + if err != nil { + return "", fmt.Errorf("error reading schema: %w", err) + } + + var prefixedDefs []compiler.SchemaDefinition + for _, def := range compiledSchema.ObjectDefinitions { + if partialPrefixMatch(def.Name, prefix) { + prefixedDefs = append(prefixedDefs, def) + } + } + for _, def := range compiledSchema.CaveatDefinitions { + if partialPrefixMatch(def.Name, prefix) { + prefixedDefs = append(prefixedDefs, def) + } + } + + if len(prefixedDefs) == 0 { + return "", errors.New("filtered all definitions from schema") + } + + filteredSchema, _, err = generator.GenerateSchema(prefixedDefs) + if err != nil { + return "", fmt.Errorf("error generating filtered schema: %w", err) + } + + // Validate that the type system for the generated schema is comprehensive. + compiledFilteredSchema, err := compiler.Compile( + compiler.InputSchema{Source: "generated-schema", SchemaString: filteredSchema}, + compiler.AllowUnprefixedObjectType(), + ) + if err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + + for _, rawDef := range compiledFilteredSchema.ObjectDefinitions { + ts := schemapkg.NewTypeSystem(schemapkg.ResolverForCompiledSchema(*compiledFilteredSchema)) + def, err := schemapkg.NewDefinition(ts, rawDef) + if err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + if _, err := def.Validate(context.Background()); err != nil { + return "", fmt.Errorf("generated invalid schema: %w", err) + } + } + + return +} + +func hasRelPrefix(rel *v1.Relationship, prefix string) bool { + // Skip any relationships without the prefix on both sides. + return strings.HasPrefix(rel.Resource.ObjectType, prefix) && + strings.HasPrefix(rel.Subject.Object.ObjectType, prefix) +} + +func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + pageLimit := cobrautil.MustGetUint32(cmd, "page-limit") + + backupFileName, err := computeBackupFileName(cmd, args) + if err != nil { + return err + } + + backupFile, backupExists, err := createBackupFile(backupFileName, returnIfExists) + if err != nil { + return err + } + + defer func(e *error) { + *e = errors.Join(*e, backupFile.Sync()) + *e = errors.Join(*e, backupFile.Close()) + }(&err) + + // the goal of this file is to keep the bulk export cursor in case the process is terminated + // and we need to resume from where we left off. OCF does not support in-place record updates. + progressFile, cursor, err := openProgressFile(backupFileName, backupExists) + if err != nil { + return err + } + + var backupCompleted bool + defer func(e *error) { + *e = errors.Join(*e, progressFile.Sync()) + *e = errors.Join(*e, progressFile.Close()) + + if backupCompleted { + if err := os.Remove(progressFile.Name()); err != nil { + log.Warn(). + Str("progress-file", progressFile.Name()). + Msg("failed to remove progress file, consider removing it manually") + } + } + }(&err) + + c, err := client.NewClient(cmd) + if err != nil { + return fmt.Errorf("unable to initialize client: %w", err) + } + + var zedToken *v1.ZedToken + var encoder *backupformat.Encoder + if backupExists { + encoder, err = backupformat.NewEncoderForExisting(backupFile) + if err != nil { + return fmt.Errorf("error creating backup file encoder: %w", err) + } + } else { + encoder, zedToken, err = encoderForNewBackup(cmd, c, backupFile) + if err != nil { + return err + } + } + + defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err) + + if zedToken == nil && cursor == nil { + return errors.New("malformed existing backup, consider recreating it") + } + + req := &v1.BulkExportRelationshipsRequest{ + OptionalLimit: pageLimit, + OptionalCursor: cursor, + } + + // if a cursor is present, zedtoken is not needed (it is already in the cursor) + if zedToken != nil { + req.Consistency = &v1.Consistency{ + Requirement: &v1.Consistency_AtExactSnapshot{ + AtExactSnapshot: zedToken, + }, + } + } + + ctx := cmd.Context() + relationshipStream, err := c.BulkExportRelationships(ctx, req) + if err != nil { + return fmt.Errorf("error exporting relationships: %w", err) + } + + relationshipReadStart := time.Now() + tick := time.Tick(5 * time.Second) + bar := console.CreateProgressBar("processing backup") + var relsFilteredOut, relsProcessed uint64 + defer func() { + _ = bar.Finish() + + evt := log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)) + if isCanceled(err) { + evt.Msg("backup canceled - resume by restarting the backup command") + } else if err != nil { + evt.Msg("backup failed") + } else { + evt.Msg("finished backup") + } + }() + + for { + if err := ctx.Err(); err != nil { + if isCanceled(err) { + return context.Canceled + } + + return fmt.Errorf("aborted backup: %w", err) + } + + relsResp, err := relationshipStream.Recv() + if err != nil { + if isCanceled(err) { + return context.Canceled + } + + if !errors.Is(err, io.EOF) { + return fmt.Errorf("error receiving relationships: %w", err) + } + break + } + + for _, rel := range relsResp.Relationships { + if hasRelPrefix(rel, prefixFilter) { + if err := encoder.Append(rel); err != nil { + return fmt.Errorf("error storing relationship: %w", err) + } + } else { + relsFilteredOut++ + } + + relsProcessed++ + if err := bar.Add(1); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + + // progress fallback in case there is no TTY + if !isatty.IsTerminal(os.Stderr.Fd()) { + select { + case <-tick: + log.Info(). + Uint64("filtered", relsFilteredOut). + Uint64("processed", relsProcessed). + Uint64("throughput", perSec(relsProcessed, time.Since(relationshipReadStart))). + Stringer("elapsed", time.Since(relationshipReadStart).Round(time.Second)). + Msg("backup progress") + default: + } + } + } + + if err := writeProgress(progressFile, relsResp); err != nil { + return err + } + } + + backupCompleted = true + return nil +} + +// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup +// must be taken. +func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) { + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + + schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{}) + if err != nil { + return nil, nil, fmt.Errorf("error reading schema: %w", err) + } + if schemaResp.ReadAt == nil { + return nil, nil, fmt.Errorf("`backup` is not supported on this version of SpiceDB") + } + schema := schemaResp.SchemaText + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + + if prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return nil, nil, err + } + } + + zedToken := schemaResp.ReadAt + + encoder, err := backupformat.NewEncoder(backupFile, schema, zedToken) + if err != nil { + return nil, nil, fmt.Errorf("error creating backup file encoder: %w", err) + } + + return encoder, zedToken, nil +} + +func writeProgress(progressFile *os.File, relsResp *v1.BulkExportRelationshipsResponse) error { + err := progressFile.Truncate(0) + if err != nil { + return fmt.Errorf("unable to truncate backup progress file: %w", err) + } + + _, err = progressFile.Seek(0, 0) + if err != nil { + return fmt.Errorf("unable to seek backup progress file: %w", err) + } + + _, err = progressFile.WriteString(relsResp.AfterResultCursor.Token) + if err != nil { + return fmt.Errorf("unable to write result cursor to backup progress file: %w", err) + } + + return nil +} + +// openProgressFile returns the progress marker file and the stored progress cursor if it exists, or creates +// a new one if it does not exist. If the backup file exists, but the progress marker does not, it will return an error. +// +// The progress marker file keeps track of the last successful cursor received from the server, and is used to resume +// backups in case of failure. +func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.File, *v1.Cursor, error) { + var cursor *v1.Cursor + progressFileName := toLockFileName(backupFileName) + var progressFile *os.File + // if a backup existed + var fileMode int + readCursor, err := os.ReadFile(progressFileName) + if backupAlreadyExisted && (os.IsNotExist(err) || len(readCursor) == 0) { + return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName) + } else if backupAlreadyExisted && err == nil { + cursor = &v1.Cursor{ + Token: string(readCursor), + } + + // if backup existed and there is a progress marker, the latter should not be truncated to make sure the + // cursor stays around in case of a failure before we even start ingesting from bulk export + fileMode = os.O_WRONLY | os.O_CREATE + log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume") + } else { + // if a backup did not exist, make sure to truncate the progress file + fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC + } + + progressFile, err = os.OpenFile(progressFileName, fileMode, 0o644) + if err != nil { + return nil, nil, err + } + + return progressFile, cursor, nil +} + +func toLockFileName(backupFileName string) string { + return backupFileName + ".lock" +} + +// computeBackupFileName computes the backup file name based. +// If no file name is provided, it derives a backup on the current context +func computeBackupFileName(cmd *cobra.Command, args []string) (string, error) { + if len(args) > 0 { + return args[0], nil + } + + configStore, secretStore := client.DefaultStorage() + token, err := client.GetCurrentTokenWithCLIOverride(cmd, configStore, secretStore) + if err != nil { + return "", fmt.Errorf("failed to determine current zed context: %w", err) + } + + ex, err := os.Executable() + if err != nil { + return "", err + } + exPath := filepath.Dir(ex) + + backupFileName := filepath.Join(exPath, token.Name+".zedbackup") + + return backupFileName, nil +} + +func isCanceled(err error) bool { + if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled { + return true + } + + return errors.Is(err, context.Canceled) +} + +func openRestoreFile(filename string) (*os.File, int64, error) { + if filename == "" { + log.Trace().Str("filename", "(stdin)").Send() + return os.Stdin, -1, nil + } + + log.Trace().Str("filename", filename).Send() + + stats, err := os.Stat(filename) + if err != nil { + return nil, 0, fmt.Errorf("unable to stat restore file: %w", err) + } + + f, err := os.Open(filename) + if err != nil { + return nil, 0, fmt.Errorf("unable to open restore file: %w", err) + } + + return f, stats.Size(), nil +} + +func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + if loadedToken := decoder.ZedToken(); loadedToken != nil { + log.Debug().Str("revision", loadedToken.Token).Msg("parsed revision") + } + + schema := decoder.Schema() + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter") + if prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return err + } + } + log.Debug().Str("schema", schema).Bool("filtered", prefixFilter != "").Msg("parsed schema") + + c, err := client.NewClient(cmd) + if err != nil { + return fmt.Errorf("unable to initialize client: %w", err) + } + + batchSize := cobrautil.MustGetUint(cmd, "batch-size") + batchesPerTransaction := cobrautil.MustGetUint(cmd, "batches-per-transaction") + + strategy, err := GetEnum[ConflictStrategy](cmd, "conflict-strategy", conflictStrategyMapping) + if err != nil { + return err + } + disableRetries := cobrautil.MustGetBool(cmd, "disable-retries") + requestTimeout := cobrautil.MustGetDuration(cmd, "request-timeout") + + return newRestorer(schema, decoder, c, prefixFilter, batchSize, batchesPerTransaction, strategy, + disableRetries, requestTimeout).restoreFromDecoder(cmd.Context()) +} + +// GetEnum is a helper for getting an enum value from a string cobra flag. +func GetEnum[E constraints.Integer](cmd *cobra.Command, name string, mapping map[string]E) (E, error) { + value := cobrautil.MustGetString(cmd, name) + value = strings.TrimSpace(strings.ToLower(value)) + if enum, ok := mapping[value]; ok { + return enum, nil + } + + var zeroValueE E + return zeroValueE, fmt.Errorf("unexpected flag '%s' value '%s': should be one of %v", name, value, maps.Keys(mapping)) +} + +func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + schema := decoder.Schema() + + // Remove any invalid relations generated from old, backwards-incompat + // Serverless permission systems. + if cobrautil.MustGetBool(cmd, "rewrite-legacy") { + schema = rewriteLegacy(schema) + } + + // Skip any definitions without the provided prefix + if prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter"); prefixFilter != "" { + schema, err = filterSchemaDefs(schema, prefixFilter) + if err != nil { + return err + } + } + + _, err = fmt.Fprintln(out, schema) + return err +} + +func backupParseRevisionCmdFunc(_ *cobra.Command, out io.Writer, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + loadedToken := decoder.ZedToken() + if loadedToken == nil { + return fmt.Errorf("failed to parse decoded revision") + } + + _, err = fmt.Fprintln(out, loadedToken.Token) + return err +} + +func backupRedactCmdFunc(cmd *cobra.Command, args []string) error { + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return fmt.Errorf("error creating restore file decoder: %w", err) + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + filename := args[0] + ".redacted" + writer, _, err := createBackupFile(filename, doNotReturnIfExists) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, writer.Close()) }(&err) + + redactor, err := backupformat.NewRedactor(decoder, writer, backupformat.RedactionOptions{ + RedactDefinitions: cobrautil.MustGetBool(cmd, "redact-definitions"), + RedactRelations: cobrautil.MustGetBool(cmd, "redact-relations"), + RedactObjectIDs: cobrautil.MustGetBool(cmd, "redact-object-ids"), + }) + if err != nil { + return fmt.Errorf("error creating redactor: %w", err) + } + + defer func(e *error) { *e = errors.Join(*e, redactor.Close()) }(&err) + bar := console.CreateProgressBar("redacting backup") + var written int64 + for { + if err := cmd.Context().Err(); err != nil { + return fmt.Errorf("aborted redaction: %w", err) + } + + err := redactor.Next() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return fmt.Errorf("error redacting: %w", err) + } + + written++ + if err := bar.Set64(written); err != nil { + return fmt.Errorf("error incrementing progress bar: %w", err) + } + } + + if err := bar.Finish(); err != nil { + return fmt.Errorf("error finalizing progress bar: %w", err) + } + + fmt.Println("Redaction map:") + fmt.Println("--------------") + fmt.Println() + + // Draw a table of definitions, caveats and relations mapped. + tbl := table.New("Definition Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Definitions { + tbl.AddRow(k, v) + } + + tbl.Print() + fmt.Println() + + if len(redactor.RedactionMap().Caveats) > 0 { + tbl = table.New("Caveat Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Caveats { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + } + + tbl = table.New("Relation/Permission Name", "Redacted Name") + for k, v := range redactor.RedactionMap().Relations { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + + if len(redactor.RedactionMap().ObjectIDs) > 0 && cobrautil.MustGetBool(cmd, "print-redacted-object-ids") { + tbl = table.New("Object ID", "Redacted Object ID") + for k, v := range redactor.RedactionMap().ObjectIDs { + tbl.AddRow(k, v) + } + tbl.Print() + fmt.Println() + } + + return nil +} + +func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error { + prefix := cobrautil.MustGetString(cmd, "prefix-filter") + decoder, closer, err := decoderFromArgs(args...) + if err != nil { + return err + } + + defer func(e *error) { *e = errors.Join(*e, closer.Close()) }(&err) + defer func(e *error) { *e = errors.Join(*e, decoder.Close()) }(&err) + + for rel, err := decoder.Next(); rel != nil && err == nil; rel, err = decoder.Next() { + if !hasRelPrefix(rel, prefix) { + continue + } + + relString, err := tuple.V1StringRelationship(rel) + if err != nil { + return err + } + + if _, err = fmt.Fprintln(out, replaceRelString(relString)); err != nil { + return err + } + } + + return nil +} + +func decoderFromArgs(args ...string) (*backupformat.Decoder, io.Closer, error) { + filename := "" // Default to stdin. + if len(args) > 0 { + filename = args[0] + } + + f, _, err := openRestoreFile(filename) + if err != nil { + return nil, nil, err + } + + decoder, err := backupformat.NewDecoder(f) + if err != nil { + return nil, nil, fmt.Errorf("error creating restore file decoder: %w", err) + } + + return decoder, f, nil +} + +func replaceRelString(rel string) string { + rel = strings.Replace(rel, "@", " ", 1) + return strings.Replace(rel, "#", " ", 1) +} + +func rewriteLegacy(schema string) string { + schema = string(missingAllowedTypes.ReplaceAll([]byte(schema), []byte("\n/* deleted missing allowed type error */"))) + return string(shortRelations.ReplaceAll([]byte(schema), []byte("\n/* deleted short relation name */"))) +} + +var sizeErrorRegEx = regexp.MustCompile(`received message larger than max \((\d+) vs. (\d+)\)`) + +func addSizeErrInfo(err error) error { + if err == nil { + return nil + } + + code := status.Code(err) + if code != codes.ResourceExhausted { + return err + } + + if !strings.Contains(err.Error(), "received message larger than max") { + return err + } + + matches := sizeErrorRegEx.FindStringSubmatch(err.Error()) + if len(matches) != 3 { + return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) + } + + necessaryByteCount, atoiErr := strconv.Atoi(matches[1]) + if atoiErr != nil { + return fmt.Errorf("%w: set flag --max-message-size=bytecounthere to increase the maximum allowable size", err) + } + + return fmt.Errorf("%w: set flag --max-message-size=%d to increase the maximum allowable size", err, 2*necessaryByteCount) +} |
